On-device control agents, especially on mobile devices, operate the device to fulfill users' requests, enabling seamless and intuitive interactions. Integrating Multimodal Large Language Models (MLLMs) into these agents enhances their ability to understand and execute complex commands, thereby improving user experience. However, fine-tuning MLLMs for on-device control presents significant challenges due to limited data availability and inefficient online training processes. This paper introduces DistRL, a novel framework designed to enhance the efficiency of online RL fine-tuning for mobile device control agents. DistRL employs centralized training and decentralized data acquisition to ensure efficient fine-tuning under dynamic online interactions. Additionally, the framework is backed by our tailor-made RL algorithm, which effectively balances exploration with the prioritized utilization of collected data to ensure stable and robust training. Our experiments show that, on average, DistRL delivers a 3× improvement in training efficiency and enables training-data collection 2.4× faster than the leading synchronous multi-machine methods. Notably, after training, DistRL achieves a 20% relative improvement in success rate over state-of-the-art methods on general Android tasks from an open benchmark, significantly outperforming existing approaches under the same training time. These results validate DistRL as a scalable and efficient solution, offering substantial improvements in both training efficiency and agent performance for real-world, in-the-wild device control tasks.
Our method employs advantage-based estimation to refine policy-gradient updates, extending Generalized Advantage Estimation (GAE) (Schulman et al., 2015) to better suit the asynchronous, distributed environments common in device control tasks.
With the trajectory-level rewards and state-value estimates obtained, we compute the advantage function $A(s_t, a_t)$ using a one-step estimate:

$$
A(s_t, a_t) = Q(s_t, a_t) - V(s_t) = r(s_t, a_t) + \gamma V(s_{t+1}) - V(s_t),
$$

which represents the advantage function as prescribed by the policy gradient theorem. Here, $r(s_t, a_t)$ is the immediate reward signal. The advantage and value estimates are further adjusted by off-policy corrections.
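As a concrete illustration, the following is a minimal NumPy sketch of this one-step advantage estimate, assuming per-step rewards and critic values for a single trajectory are already available (function and argument names are illustrative):

```python
import numpy as np

def one_step_advantage(rewards, values, next_values, gamma=0.99):
    """One-step advantage: A(s_t, a_t) = r(s_t, a_t) + gamma * V(s_{t+1}) - V(s_t).

    rewards, values, next_values: arrays of shape [T] for one trajectory.
    """
    return rewards + gamma * next_values - values
```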
Finally, the policy network optimizes the following loss:
$$
\mathcal{L} = -\,\mathbb{E}_{\mu}\big[\rho_t\, A(s_t, a_t)\, \log \pi(a_t \mid s_t)\big] \;-\; \beta\, \mathbb{E}_{\mu}\big[\mathbb{H}\big(\pi(a_t \mid s_t)\big)\big] \;+\; \lambda\, \mathbb{E}_{\mu}\big[\mathcal{P}_{\mathrm{invalid}}(a_t)\big],
$$

where $\rho_t = \pi(a_t \mid s_t)/\mu(a_t \mid s_t)$ is the importance sampling ratio between the target policy $\pi$ and the behavior policy $\mu$, $\mathbb{H}$ is the entropy term that prevents overfitting, $\mathcal{P}_{\mathrm{invalid}}(a_t)$ penalizes actions deemed invalid under task-specific criteria, $\beta$ controls the strength of the entropy regularization, and $\lambda$ modulates the penalty's influence. Invalid actions are identified by validating candidate actions with a pre-trained LLM such as Gemini, ensuring that inappropriate actions are penalized. This formulation encourages the agent to explore a diverse set of actions while constraining it to generate valid and meaningful commands, enhancing both exploration and policy robustness in the face of online non-stationarities.
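A minimal PyTorch sketch of this objective is shown below, assuming per-step log-probabilities, advantages, entropies, and invalid-action penalties have already been computed (tensor names and hyperparameter values are illustrative, not the paper's defaults):

```python
import torch

def policy_loss(log_pi, log_mu, advantages, entropy, invalid_penalty,
                beta=0.01, lam=0.1):
    """Sketch of the objective: -E[rho * A * log pi] - beta * E[H] + lam * E[P_invalid].

    log_pi:          log pi(a_t|s_t) under the current (target) policy, shape [T]
    log_mu:          log mu(a_t|s_t) under the behavior policy that collected the data
    advantages:      advantage estimates A(s_t, a_t), shape [T]
    entropy:         per-step policy entropy, shape [T]
    invalid_penalty: P_invalid(a_t), e.g. 1.0 for actions flagged invalid by the LLM validator
    """
    rho = torch.exp(log_pi - log_mu)  # importance sampling ratio pi/mu (in practice often clipped)
    pg_term = -(rho.detach() * advantages.detach() * log_pi).mean()
    entropy_term = -beta * entropy.mean()
    penalty_term = lam * invalid_penalty.mean()
    return pg_term + entropy_term + penalty_term
```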
To enhance the estimation of the state-value function $V(s_t)$ in off-policy and asynchronous settings, we apply Retrace($\lambda$) corrections directly to $V(s_t)$. The Retrace algorithm extends the TD($\lambda$) method to off-policy learning by incorporating importance sampling ratios and a trace decay parameter $\lambda$. Specifically, we update $V(s_t)$ using the correction term $\delta_t$, computed as:

$$
V(s_t) \leftarrow V(s_t) + \delta_t, \qquad
\delta_t = \sum_{k=t}^{H} \gamma^{\,k-t} \left( \prod_{i=t+1}^{k} c_i \right) \big[ r_k + \gamma V(s_{k+1}) - V(s_k) \big],
$$

where $c_i = \lambda \min(1, \rho_i)$ with $\lambda \in [0, 1]$ the trace decay parameter, and $\rho_i$ is the importance sampling ratio defined above. This correction adjusts the value estimates using future rewards and importance sampling, enabling off-policy learning while mitigating the variance introduced by importance weights. By applying Retrace($\lambda$), we improve the estimation of $V(s_t)$ in off-policy settings, enhancing the stability and convergence of the value network.
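The correction admits a simple backward recursion, $\delta_t = \big[r_t + \gamma V(s_{t+1}) - V(s_t)\big] + \gamma\, c_{t+1}\, \delta_{t+1}$, which the following NumPy sketch exploits (array names and default hyperparameters are illustrative):

```python
import numpy as np

def retrace_corrections(rewards, values, next_values, rho, lam=0.95, gamma=0.99):
    """Retrace(lambda)-style corrections delta_t for the value estimates of one trajectory.

    rewards, values, next_values, rho: arrays of shape [T] (steps t = 0..H).
    Returns the per-step correction delta_t to be added to V(s_t).
    """
    T = len(rewards)
    c = lam * np.minimum(1.0, rho)                 # truncated importance weights c_i
    td = rewards + gamma * next_values - values    # per-step TD errors
    corrections = np.zeros(T)
    acc = 0.0
    for t in reversed(range(T)):                   # delta_t = td_t + gamma * c_{t+1} * delta_{t+1}
        acc = td[t] + gamma * (c[t + 1] if t + 1 < T else 0.0) * acc
        corrections[t] = acc
    return corrections
```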
To improve sample efficiency, we employ Distributed Prioritized Experience Replay (DPER). For each trajectory $\tau = \{(s_t, a_t, r_t, s_{t+1})\}_{t=0}^{H}$, we compute the priority $p(\tau)$ as:

$$
p(\tau) = w_1 |\bar{\delta}| + w_2 \bar{\rho} + w_3 \bar{\mathbb{H}},
$$

where $|\bar{\delta}|$ is the average absolute temporal-difference (TD) error over the trajectory, with $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$; $\bar{\rho}$ is the average importance sampling ratio $\rho_t$; and $\bar{\mathbb{H}}$ is the average policy entropy, $\mathbb{H}_t = -\log \pi(a_t \mid s_t)$, which rewards policy uncertainty to encourage exploration and avoid premature convergence to suboptimal policies in dynamic environments. The weights $w_1$, $w_2$, and $w_3$ balance the contributions of the three components and are selected by grid search (see Appendix: Hyperparameters). Trajectories with higher priority are replayed more frequently, focusing learning on the most informative experiences. Priorities are periodically recomputed under the latest policy, ensuring continual adaptation to evolving behavior policies.
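For illustration, a minimal NumPy sketch of this priority computation is given below; the weights shown are placeholders, since $w_1$, $w_2$, and $w_3$ are selected by grid search:

```python
import numpy as np

def trajectory_priority(td_errors, rhos, log_pis, w=(1.0, 1.0, 1.0)):
    """Sketch of the DPER priority p(tau) = w1*|delta_bar| + w2*rho_bar + w3*H_bar.

    td_errors: per-step TD errors delta_t over the trajectory
    rhos:      per-step importance sampling ratios rho_t
    log_pis:   per-step log pi(a_t|s_t); the entropy term is H_t = -log pi(a_t|s_t)
    """
    w1, w2, w3 = w
    td_term = np.mean(np.abs(td_errors))        # |delta_bar|
    rho_term = np.mean(rhos)                    # rho_bar
    ent_term = np.mean(-np.asarray(log_pis))    # H_bar
    return w1 * td_term + w2 * rho_term + w3 * ent_term
```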
Success rate (%) on the General and Web Shopping task sets:

| Framework Type | Framework Name | General (Training) | General (Test) | Web Shopping (Training) | Web Shopping (Test) |
| --- | --- | --- | --- | --- | --- |
| Prompting | AppAgent + GPT-4v | 41.4 | 43.0 | 31.2 | 35.2 |
| Prompting | AppAgent + Gemini | 39.1 | 45.3 | 30.5 | 32.0 |
| Learning | AutoUI | 38.3 | 40.6 | 42.2 | 44.5 |
| Learning | DigiRL (single, online) | 64.6 ± 1.5 | 59.9 ± 2.1 | 63.3 ± 1.5 | 59.6 ± 3.1 |
| Learning | DigiRL (multi) | 67.7 ± 1.3 | 61.2 ± 2.4 | 64.5 ± 1.1 | 59.9 ± 2.8 |
| Learning | DistRL (Ours) | 75.5 ± 0.2 | 73.2 ± 1.1 | 69.8 ± 0.5 | 68.5 ± 1.7 |
@article{wang2024distrl,
title={DistRL: An Asynchronous Distributed Reinforcement Learning Framework for On-Device Control Agents},
author={Wang, Taiyi and Wu, Zhihao and Liu, Jianheng and Hao, Jianye and Wang, Jun and Shao, Kun},
journal={arXiv preprint arXiv:2410.14803},
year={2024}
}