Assisted Learning for Organizations with Limited Imbalanced Data

Abstract

We develop an assisted learning framework that helps organization-level learners improve their learning performance with limited and imbalanced data. Learners at the organization level usually have sufficient computational resources but are subject to stringent data-sharing and collaboration policies. Their limited, imbalanced data often cause biased inference and suboptimal decision-making. In our framework, an organizational learner purchases an assistance service from a service provider and aims to enhance its model performance within a few assistance rounds. We develop effective stochastic training algorithms for assisted deep learning and assisted reinforcement learning. Unlike existing distributed algorithms, which must frequently transmit gradients or models, our framework lets the learner share information with the service provider only occasionally and still obtain a near-oracle model, as if all the data were centralized.

A Generalization for Horizontally Distributed Data

We identify the need for an assisted learning framework that facilitates the deployment
of general machine learning in large organizations. This framework addresses the challenges
described in the abstract: limited and imbalanced data combined with stringent data-sharing
policies.

Comparison of AssistSGD, SGD, Learner-SGD and FedAvg when the learner’s data are balanced, using AlexNet (top row) and ResNet-18 (bottom row).

We first develop an assisted deep learning framework for organizational learners with limited and
imbalanced data, and propose a stochastic training algorithm named AssistSGD. Every assistance
round of AssistSGD consists of two phases. In the first phase, the learner performs local SGD
training for multiple iterations and sends the resulting trajectory of models, together with their
local loss values, to the service provider. In the second phase, the provider uses the learner’s
information to evaluate the global loss of the received models and takes the model with the
smallest global loss as its initialization. The provider then performs local SGD training for
multiple iterations and sends its own trajectory of models, together with their local loss values,
back to the learner. Finally, the learner uses the provider’s information to evaluate the global
loss of the received models and outputs the model with the smallest global loss. Under mild
technical assumptions, we formally prove that AssistSGD with full-batch gradient updates is
guaranteed to find a critical point of the global loss function in general nonconvex optimization.
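
The two-phase round described above can be sketched in a few lines of Python. This is only an illustrative toy, not the authors' implementation: it assumes a linear least-squares model in place of a deep network, assumes the global loss is the sample-count-weighted average of the two local losses, and assumes each participant knows the other's sample count. All function names (`local_sgd`, `select_best`, `assist_sgd`, `make_toy_data`) and hyperparameters are hypothetical.

```python
import numpy as np

def local_loss(theta, X, y):
    """Mean squared error of a linear model on one participant's data."""
    return float(np.mean((X @ theta - y) ** 2))

def local_sgd(theta, X, y, steps, lr, batch_size, rng):
    """Run local mini-batch SGD; return the trajectory [(model, local loss), ...]."""
    theta = theta.copy()
    trajectory = []
    for _ in range(steps):
        idx = rng.choice(len(y), size=min(batch_size, len(y)), replace=False)
        grad = 2.0 * X[idx].T @ (X[idx] @ theta - y[idx]) / len(idx)
        theta = theta - lr * grad
        trajectory.append((theta.copy(), local_loss(theta, X, y)))
    return trajectory

def select_best(trajectory, X_own, y_own, n_sender):
    """Receiver side: pick the received model with the smallest global loss.

    Assumption: global loss = sample-weighted average of the sender's
    reported local loss and the receiver's own local loss.
    """
    n_own = len(y_own)
    best_theta, best_loss = None, np.inf
    for theta, sender_loss in trajectory:
        own_loss = local_loss(theta, X_own, y_own)
        global_loss = (n_sender * sender_loss + n_own * own_loss) / (n_sender + n_own)
        if global_loss < best_loss:
            best_theta, best_loss = theta, global_loss
    return best_theta, best_loss

def assist_sgd(learner, provider, rounds=5, steps=20, lr=0.05, batch_size=8, seed=0):
    """Alternate learner and provider phases; only models and loss values are exchanged."""
    rng = np.random.default_rng(seed)
    (X_l, y_l), (X_p, y_p) = learner, provider
    theta = np.zeros(X_l.shape[1])
    global_loss = np.inf
    for _ in range(rounds):
        # Phase 1: learner trains locally, provider picks its initialization.
        traj_l = local_sgd(theta, X_l, y_l, steps, lr, batch_size, rng)
        theta_init, _ = select_best(traj_l, X_p, y_p, n_sender=len(y_l))
        # Phase 2: provider trains locally, learner picks the round's output.
        traj_p = local_sgd(theta_init, X_p, y_p, steps, lr, batch_size, rng)
        theta, global_loss = select_best(traj_p, X_l, y_l, n_sender=len(y_p))
    return theta, global_loss

def make_toy_data(seed=1, n_learner=20, n_provider=200):
    """Synthetic data: the learner holds far fewer samples than the provider."""
    rng = np.random.default_rng(seed)
    true_theta = np.array([1.0, -2.0, 0.5])
    X_l = rng.normal(size=(n_learner, 3))
    X_p = rng.normal(size=(n_provider, 3))
    y_l = X_l @ true_theta + 0.1 * rng.normal(size=n_learner)
    y_p = X_p @ true_theta + 0.1 * rng.normal(size=n_provider)
    return (X_l, y_l), (X_p, y_p), true_theta

if __name__ == "__main__":
    learner, provider, true_theta = make_toy_data()
    theta, loss = assist_sgd(learner, provider)
    print("estimated parameters:", theta)
    print("global loss:", loss)
```

Note that in this sketch the raw arrays `X` and `y` never cross between participants; each side sends only a list of parameter vectors with their local loss values, and communication happens twice per round rather than at every SGD step.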

Comparison of AssistPG, PG, Learner-PG and FedAvg in the CartPole and LunarLander games.

We further generalize the framework to enable assisted reinforcement learning, and develop a policy
gradient training algorithm named AssistPG, which follows the same training logic as AssistSGD.
Through extensive experiments with deep learning and reinforcement learning, we demonstrate that
the learner can achieve near-oracle performance with AssistSGD and AssistPG, as if all the data
were centralized. In particular, the more imbalanced the learner’s data, the larger the performance
gain that AssistSGD provides. Moreover, neither participant shares its raw data at any point in the
assisted learning process.

References

Chen, Cheng, Jiaying Zhou, Jie Ding, and Yi Zhou. “Assisted learning for organizations with limited data.” arXiv preprint arXiv:2109.09307 (2021).