Multi-Distribution Learning in 2025

A gentle intro to the basics and history of multi-distribution learning. Also, some cool new results from the community.

While PAC learning has historically been the dominant theoretical framework for studying learning, its inconsistency with the real world (e.g., overparameterization [3] and overfitting [4]) has become a bit of a meme. This blog centers on a limitation of PAC learning that has received growing attention from the learning theory community: it isn't expressive enough to capture the multi-domain aspects of learning. In particular, we'll talk a bit about the concepts and history behind multi-distribution learning [1], a framework for provable multi-domain learning that addresses this gap.

Towards a Multi-Domain Perspective

One of the first things we learn in ML 101 is that data collection is as important as model training. Most learning systems are expected to be performant in multiple domains, which requires robustness that you can only really get with intentional data sourcing. Large multimodal models have also demonstrated the importance of aggregating data from heterogeneous domains when it comes to learning generalizable concepts. We also increasingly see ML data being collected by the joint effort of multiple stakeholders (think users in federated learning, or institutions with a data-sharing agreement).

While PAC learning is compatible with notions of variable model class complexity and data complexity, it just isn't expressive enough to describe how learning scales with the number of domains one is expected to gather data from, learn about, and balance between. The many variants of PAC learning that have been proposed in the past few decades (online learning, bandits, active learning, membership queries, etc.) also come up short in this regard.

Multi-distribution learning (MDL) fills this gap. Its formal definition, as introduced by [1], doesn't look all that different from your classic agnostic learning setup:

Let \(\mathcal{H}\) be your hypothesis class, \(\ell\) a loss function, and \(\mathcal{D} = \{D_1, \dots, D_k\}\) a set of distributions on \(\mathcal{X} \times \mathcal{Y}\). Given example oracles for each \(D_i\) (i.e., you can call them to sample from any distribution of your choice), the goal of multi-distribution learning is to learn a (potentially randomized) hypothesis \(h\) that is good for all \(D_i \in \mathcal{D}\): \[ \max_{D_i \in \mathcal{D}} \mathop{\mathbb{E}}_{(x, y) \sim D_i} [\ell(h, (x, y))] \leq \varepsilon + \min_{h^* \in \mathcal{H}} \max_{D_i \in \mathcal{D}} \mathop{\mathbb{E}}_{(x, y) \sim D_i} [\ell(h^*, (x, y))], \] while collecting as few datapoints as possible.

While multi-distribution learning can be defined for other social welfare functions, what I've written above, i.e., min-max utility or Rawlsian fairness, has historically been the default. You can also write multi-distribution learning with multiple losses, i.e., minimizing \(\max_{D_i \in \mathcal{D}} \max_{\ell_j \in \mathcal{L}} \mathop{\mathbb{E}}_{(x, y) \sim D_i} [\ell_j(h, (x, y))]\).
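
To make the definition concrete, here's a minimal numpy sketch for a finite hypothesis class (all names here are hypothetical, and I'm assuming the per-distribution expected losses have already been estimated; in the actual problem, producing such estimates from as few oracle calls as possible is the whole game):

```python
import numpy as np

# Hypothetical setup: k distributions and m hypotheses, with
# losses[i, j] = E_{(x,y) ~ D_i}[loss(h_j, (x, y))]. In practice these
# expectations must be estimated from samples drawn via the example oracles.
rng = np.random.default_rng(0)
k, m = 4, 6
losses = rng.uniform(size=(k, m))

# Min-max (Rawlsian) value of the best deterministic hypothesis.
opt = losses.max(axis=0).min()

# A randomized hypothesis is a mixture p over hypotheses; its expected loss
# on D_i is losses[i] @ p. Mixing can strictly beat every single hypothesis.
p_uniform = np.full(m, 1.0 / m)

def is_eps_optimal(p, eps):
    """Check the MDL guarantee: max_i E_{D_i}[loss] <= eps + min-max value."""
    return (losses @ p).max() <= eps + opt

print(opt, (losses @ p_uniform).max(), is_eps_optimal(p_uniform, 0.1))
```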

Despite its apparent simplicity, multi-distribution learning introduces three new dimensions to the learning problem:

  1. The learner is no longer a passive recipient of data. Data collection must be intentional and strategic, and naive sampling can be exponentially bad (we'll see this in a second).
  2. The learner needs to learn concepts across multiple domains. This can make things harder (you need to provide more guarantees), but also easier (some domains may be informative about others).
  3. The learner needs to balance the considerations of different domains, especially in agnostic settings where domains may disagree on labels.

History of Multi-Distribution Learning

The earliest example of an MDL-like guarantee can be found in Blum, Haghtalab, Procaccia, and Qiao [5] in 2017, who studied a collaborative learning problem where multiple learners, each with a distinct data distribution, aim to collaboratively identify some ideal hypothesis that perfectly and noiselessly labels everyone's data. They demonstrated that to pick an \(\varepsilon\)-optimal hypothesis from \(\lvert \mathcal{H} \rvert\) options across \(k\) data distributions, only \(\varepsilon^{-1} \bigl((\log k)^2 \,\log \lvert \mathcal{H} \rvert + k \,\log k \bigr)\) samples are needed. That is, learning \(k\) distributions requires only \((\log k)^2\) times as many samples as learning one (in the worst case). This was improved to \(\varepsilon^{-1} \bigl(\log k \,\log \lvert \mathcal{H} \rvert + k \,\log k \bigr)\) by [6], [7]. [6] also considered the general (agnostic) setting, which lifts the assumption that there is no noise and that all distributions share a single perfect hypothesis, obtaining a sample complexity of \[ \frac{1}{\varepsilon^5}\log\bigl(\tfrac{1}{\varepsilon}\bigr) \log\bigl(\tfrac{k}{\delta}\bigr) \bigl(\log(\lvert \mathcal{H} \rvert) + k\bigr). \]

Independently, parallel threads developed in robustness and federated learning literature. Under the name of group distributionally robust optimization, [8] and [9] developed algorithms aiming to ensure uniform performance across predefined groups—rediscovering a mathematical equivalent of the collaborative learning problem, but from an empirical lens centered on robustness. In parallel, [10] formalized agnostic federated learning by similarly framing the federated learning objective explicitly as a min-max problem across diverse client distributions.

Multi-distribution learning arose from institutional data sharing, domain robustness, and federated learning applications.

We eventually unified these separate threads of work in Haghtalab, Jordan, Zhao [1] under the framework of multi-distribution learning, tightening the sample complexity bound of [6] for general settings to \[ \frac{\log\lvert \mathcal{H} \rvert + k \log \bigl(\tfrac{k}{\delta}\bigr)}{\varepsilon^2}. \] This is still the only bound we're aware of that avoids a multiplicative factor increase in sample complexity (i.e., \(\log(k) \log(\lvert \mathcal{H} \rvert)\)); it remains an open problem whether this is even possible for VC classes. This rate is also optimal for VC classes up to log-factors, except for a difficult high-error-tolerance regime where \(\varepsilon > 1/k\); closing this gap became an open problem at COLT [2] that was subsequently resolved by two teams in parallel [11], [12] (more on these great papers in a bit). For the setting with \(m\) different losses and \(k\) distributions, the same algorithm provides a rate of \(\varepsilon^{-2}\bigl(\log\lvert \mathcal{H} \rvert + k \log \bigl(\tfrac{k m}{\delta}\bigr)\bigr)\).

Notably, these MDL rates all arise rather elegantly from game dynamics. Specifically, we can view multi-distribution learning as computing the min-max equilibrium of a bilinear game between a hypothesis picker (the learner) and a data distribution picker (the adversary). Designing an algorithm then becomes a matter of designing a no-regret game dynamic that quickly converges to this equilibrium without requiring either player to collect too many samples. For example, we obtained our rate in [1] by having the learner run Hedge and the adversary run a bandit algorithm, while [2] and [12] have the learner run ERM and the adversary run Hedge.
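
As a toy illustration of this game-dynamics view, here's a stylized full-information sketch for a finite class (assuming, unrealistically, that the loss matrix is known exactly; the real algorithms must estimate everything from samples, which is where the difficulty lies):

```python
import numpy as np

def hedge_vs_hedge(losses, T=2000, lr=0.05):
    """Full-information no-regret dynamics for the MDL min-max game.

    losses[i, j] = loss of hypothesis j on distribution i. The learner runs
    Hedge over hypotheses (minimizing); the adversary runs Hedge over
    distributions (maximizing). The averaged plays converge to an approximate
    min-max equilibrium. This sidesteps sampling entirely, which is where the
    actual algorithmic work in [1], [11], [12] happens.
    """
    k, m = losses.shape
    p = np.full(m, 1.0 / m)   # learner's weights over hypotheses
    q = np.full(k, 1.0 / k)   # adversary's weights over distributions
    p_avg, q_avg = np.zeros(m), np.zeros(k)
    for _ in range(T):
        p_avg += p / T
        q_avg += q / T
        p *= np.exp(-lr * (q @ losses))  # learner: loss under adversary's mixture
        p /= p.sum()
        q *= np.exp(lr * (losses @ p))   # adversary: learner's loss per distribution
        q /= q.sum()
    return p_avg, q_avg

rng = np.random.default_rng(1)
losses = rng.uniform(size=(5, 8))
p_avg, _ = hedge_vs_hedge(losses)
print("worst-case loss of the averaged mixture:", (losses @ p_avg).max())
```

Note that this sketch runs Hedge for both players purely for simplicity; it's not the learner-Hedge/adversary-bandit dynamic of [1] or the ERM/Hedge dynamic of [2], [12].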

Recent Developments

Optimal MDL Rates for VC Classes with High Error Tolerance \(\varepsilon > 1/k\)

This was an open problem that we posed at COLT [2]. It was subsequently resolved by two concurrent efforts, using very different approaches.

1. Peng 2024 [11]: Peng provided an algorithm obtaining a sample complexity of \[ \widetilde{O}\bigl((d + k)\,\varepsilon^{-2}\bigr)\,\cdot\,\bigl(\tfrac{k}{\varepsilon}\bigr)^{o(1)}, \] where \(d\) is the VC dimension of \(\mathcal{H}\), resolving our open problem up to a subpolynomial factor.

One challenge we ran into when extending our rate in [1] to high error tolerance regimes was that taking a cover was very expensive, requiring \(\tfrac{k\,\log(\lvert \mathcal{H}\rvert)}{\varepsilon}\) samples. Peng bypasses this cost with a creative adaptive covering scheme based on recursive width reduction. In online learning, width is the gap between the loss of the best and worst option available to a no-regret agent at any given timestep (this is different from the gap used in e.g. UCB bounds, which is between the loss of the best and second-best option). Using the fact that bounded width translates into improved regret bounds, the scheme iteratively prunes hypotheses whose loss is either too high or too low. For example, if we suppose that \(\mathrm{OPT}\) is known, width reduction can be implemented by removing all hypotheses that do worse than \(\mathrm{OPT} + \varepsilon\) on some distribution, and only retaining "robust" hypotheses that, for any subset of distributions with sufficient probability mass (under the adversary's policy), have a loss no greater than \(\mathrm{OPT} + \varepsilon\). The subtlety is in the latter step, as \(\mathrm{OPT}\) does not directly imply a lower bound on the loss of the optimal hypothesis on any particular distribution.
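
Here's a tiny numpy sketch of my loose reading of that first pruning step (hypothetical names throughout; the actual scheme in [11] is recursive and works from samples rather than exact losses):

```python
import numpy as np

def prune_high_loss(losses, opt, eps):
    """Drop every hypothesis whose loss exceeds opt + eps on some distribution.

    losses[i, j] = loss of hypothesis j on distribution i; opt is the min-max
    value. The min-max optimal hypothesis always survives, since its loss is
    at most opt on every distribution. The subtler second step (pruning
    hypotheses with suspiciously low loss) is where the real analysis lives.
    """
    return (losses <= opt + eps).all(axis=0)  # boolean mask over hypotheses

losses = np.array([[0.2, 0.9, 0.3],
                   [0.4, 0.1, 0.8]])
opt = losses.max(axis=0).min()                 # min-max value: 0.4
print(prune_high_loss(losses, opt, eps=0.05))  # [ True False False]
```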

2. Zhang, Zhan, Chen, Du, Lee 2024 [12]: ZZCDL provided an explicit, oracle-efficient algorithm obtaining a rate, optimal up to \(\log\) factors, of \[ O\Bigl( \frac{ d\,\log\bigl(\tfrac{k\,d}{\varepsilon}\bigr)\;+\;k\,\log\bigl(\tfrac{k}{\delta\,\varepsilon}\bigr) }{\varepsilon^2} \;\cdot\; \log^8\bigl(\tfrac{k}{\delta\,\varepsilon}\bigr) \Bigr). \] This result provides the current tightest rate for our open problem. Note also that oracle efficiency was a second open problem in [2], which this result resolves. Their algorithm implements a no-regret best-response game dynamic where the learner runs a form of repeated ERM and the adversary runs Hedge. This has been attempted before, yet a direct implementation (as described in [7]) gives a suboptimal \(\varepsilon\)-dependence on the VC term. The key challenge was proving that the learner can reuse data when estimating their loss vector, which requires showing that the adversary's Hedge algorithm doesn't bounce between too many vertices of the simplex. This was a fairly neat technical lemma, requiring some very impressive and involved casework, which I paraphrase below:

Suppose you have a no-regret agent playing a bilinear game with an agent that approximately best-responds at each timestep. Let \(w_t(i)\) denote the weight of action \(i \in [n]\) at time \(t \in [T]\). If the no-regret agent uses Hedge, their iterates \(w_1, \dots, w_T\) satisfy: \[\sum_{i \in [n]} \max_{t \in [T]} w_{t}(i) \in O(\mathrm{polylog}(n, T)).\] That is, the no-regret agent never puts too much weight on too many actions.
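
If you'd like to see this lemma in action, here's a small illustrative experiment (a random bilinear game standing in for the MDL game) that runs Hedge against a best-responding opponent and measures the sum of maximum weights:

```python
import numpy as np

def sum_of_max_weights(n=50, T=5000, lr=0.1, seed=0):
    """Run Hedge (as the maximizing player) against an opponent that
    best-responds to its current weights in a random bilinear game, and
    measure sum_i max_t w_t(i). The lemma says this stays polylog(n, T)
    rather than growing like n."""
    rng = np.random.default_rng(seed)
    A = rng.uniform(size=(n, n))  # A[i, j]: Hedge's payoff for (action i, response j)
    w = np.full(n, 1.0 / n)
    max_w = w.copy()
    for _ in range(T):
        j = np.argmin(w @ A)          # opponent best-responds, minimizing Hedge's payoff
        w = w * np.exp(lr * A[:, j])  # Hedge update on the realized payoff column
        w /= w.sum()
        max_w = np.maximum(max_w, w)  # track each action's maximum weight so far
    return max_w.sum()

print(sum_of_max_weights())
```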

For both [11] and [12], it remains an open problem how to extend their rates to also get guarantees for the multi-distribution learning setting with multiple loss functions.

Randomization and Multi-Distribution Learning

Somewhat interestingly, multi-distribution learning can generally only be performed efficiently if you're happy with obtaining a randomized or improper hypothesis (though it's still an open problem whether one can find a deterministic improper solution to agnostic multi-distribution learning with efficient sample complexity). This isn't entirely surprising from a statistical perspective, as multi-objective learning is known to be easier when randomization is allowed (this is also the case for multicalibration [14]; see e.g. [15]).

However, Larsen, Montasser, and Zhivotovskiy [13] recently showed that the derandomization problem for multi-distribution learning is also computationally difficult (NP-hard in the oracle-efficient setting), via a reduction from discrepancy-like problems. That is, the process of aggregating the support of a randomized MDL solution into a (potentially improper) deterministic solution is necessarily computationally inefficient.

LMZ note that this computational hardness arises entirely from resolving conflicting label distributions. Specifically, there is a simple structural condition that sidesteps their hardness result: when there is no label shift between different distributions (i.e., \(\Pr_{D_1}[y \mid x] = \Pr_{D_2}[y \mid x]\) for all \(x\)). Note that this setting still allows for both label noise and a hypothesis class whose optimal hypothesis differs between distributions (e.g., \(h_1\) may be optimal for \(D_1\) while \(h_2\) is optimal for \(D_2\)). The latter may seem like it should introduce trade-offs that need to be negotiated between different distributions, potentially at significant computational expense. However, we know that such trade-offs aren't really problematic if you're allowed to be improper and your distributions are only covariate shifts of one another; this is the same intuition behind calibeating and multi-group learning.

Optimal Multiplicative Factor for VC Classes

This was also part of the open problem that we posed at COLT [2], but remains open. Here's the original statement:

We believe a \(\ln(k)\,d\) factor is missing from the best known sample complexity lower bound of \(\Theta\bigl(\varepsilon^{-2}\,(d + k \ln\bigl(\tfrac{\min\{k,d\}}{\delta}\bigr))\bigr).\) The absence of a \(\ln(k)\,d\) term would be significant as it would imply that, when model class complexity dominates sample complexity, handling more data distributions comes effectively for free. Interestingly, this \(\ln(k)\) factor does not appear in the upper bound when the complexity of \(H\) is characterized by Littlestone dimension, perhaps due to the stronger compression guarantees for online-learnable classes.

One of the reasons we suspect this \(\log(k)\) factor may be unavoidable in the VC case is that taking a cover over \(k\) dimensions necessarily adds a \(\log k\) factor to your metric entropy.

Selected Problems

There have been many more great papers over the past two years, but I'm hitting my self-imposed blog length limit, so I'll cover more in another post. I'll end with a summary of some of the most important open problems in multi-distribution learning, all of which I've mentioned in some form in this blog post.

  1. Optimal Multiplicative Factor for VC Classes: Is the sample complexity of multi-distribution learning for VC classes \(\Omega(\ln(k)\,d / \varepsilon^2)\)?

    An equivalent phrasing of the question: is there a fundamental separation between the multi-distribution learning of VC classes and online-learnable classes, where only the latter can be learned without a multiplicative increase in sample complexity?

  2. Improper Multi-Distribution Learning: Can you sample-efficiently learn a deterministic improper solution to multi-distribution learning? We know that this is possible in realizable settings.
  3. Other Social Welfare Functions: Multi-distribution learning is typically studied with a min-max objective, but you can also ask for a Pareto-optimal solution (which can only be found inefficiently) or a solution that minimizes the \(\ell_p\)-norm of the loss vector, where each component is the loss of the hypothesis on a particular distribution. For which social welfare functions, and under what conditions, can you efficiently perform multi-distribution learning?
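
To make the objectives in (3) concrete, here's a minimal sketch comparing min-max and \(\ell_p\) welfare on a per-distribution loss vector (the `welfare` helper is purely illustrative):

```python
import numpy as np

def welfare(loss_vector, kind="minmax", p=2):
    """Aggregate a per-distribution loss vector under a social welfare objective."""
    if kind == "minmax":  # Rawlsian: only the worst-off distribution matters
        return loss_vector.max()
    if kind == "lp":      # l_p-norm: trades off severity against spread
        return np.linalg.norm(loss_vector, ord=p)
    raise ValueError(f"unknown welfare function: {kind}")

losses = np.array([0.10, 0.12, 0.45])  # e.g., one much harder distribution
print(welfare(losses, "minmax"))       # 0.45
print(welfare(losses, "lp", p=2))      # ~0.476
```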



References

  1. Haghtalab, N., Jordan, M., & Zhao, E. (2022, May). On-demand sampling: Learning optimally from multiple distributions. Proceedings of the 36th Annual Conference on Neural Information Processing Systems.
  2. Awasthi, P., Haghtalab, N., & Zhao, E. (2023, July). The sample complexity of multi-distribution learning for VC classes. Proceedings of the 36th Annual Conference on Learning Theory.
  3. Nakkiran, P., Venkat, P., Kakade, S. M., & Ma, T. (2021, December). Deep double descent: Where bigger models and more data hurt. Journal of Statistical Mechanics: Theory and Experiment, 2021(12), 124003.
  4. Recht, B. (2017). Thou shalt not overfit. Retrieved from https://www.argmin.net/p/thou-shalt-not-overfit
  5. Blum, A., Haghtalab, N., Procaccia, A. D., & Qiao, M. (2017). Collaborative PAC learning. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, & R. Garnett (Eds.), Advances in Neural Information Processing Systems 30 (pp. 2392–2401). Curran Associates, Inc.
  6. Nguyen, H. L., & Zakynthinou, L. (2018). Improved algorithms for collaborative PAC learning. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, & R. Garnett (Eds.), Advances in Neural Information Processing Systems 31 (pp. 7642–7650). Curran Associates, Inc.
  7. Chen, J., Zhang, Q., & Zhou, Y. (2018). Tight bounds for collaborative PAC learning via multiplicative weights. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, & R. Garnett (Eds.), Advances in Neural Information Processing Systems 31 (pp. 3602–3611). Curran Associates, Inc.
  8. Sagawa, S., Koh, P. W., Hashimoto, T. B., & Liang, P. (2020). Distributionally robust neural networks. Proceedings of the International Conference on Learning Representations (ICLR). OpenReview.
  9. Sagawa, S., Raghunathan, A., Koh, P. W., & Liang, P. (2020). An investigation of why overparameterization exacerbates spurious correlations. In H. Daumé III & A. Singh (Eds.), Proceedings of the International Conference on Machine Learning (ICML), 119, (pp. 8346–8356). Proceedings of Machine Learning Research, PMLR.
  10. Mohri, M., Sivek, G., & Suresh, A. T. (2019). Agnostic federated learning. In K. Chaudhuri & R. Salakhutdinov (Eds.), Proceedings of the International Conference on Machine Learning (ICML), 97, (pp. 4615–4625). Proceedings of Machine Learning Research, PMLR.
  11. Peng, B. (2024). The sample complexity of multi-distribution learning. Proceedings of the 37th Annual Conference on Learning Theory.
  12. Zhang, Z., Zhan, W., Chen, Y., Du, S. S., & Lee, J. D. (2024). Optimal multi-distribution learning. Proceedings of the 37th Annual Conference on Learning Theory.
  13. Larsen, K. G., Montasser, O., & Zhivotovskiy, N. (2024). Derandomizing multi-distribution learning. Proceedings of the 38th Annual Conference on Neural Information Processing Systems.
  14. Hebert-Johnson, U., Kim, M. P., Reingold, O., & Rothblum, G. N. (2018). Multicalibration: Calibration for the (computationally-identifiable) masses. In J. G. Dy & A. Krause (Eds.), Proceedings of the International Conference on Machine Learning (ICML), Proceedings of Machine Learning Research (pp. 1944–1953). PMLR.
  15. Haghtalab, N., Jordan, M., & Zhao, E. (2023, February). A unifying perspective on multi-calibration: Game dynamics for multi-objective learning. Proceedings of the 37th Annual Conference on Neural Information Processing Systems.

Thanks for reading! Anonymous feedback can be left here. Feel free to reach out if you think there's something I should add or clarify.