Learn from your own latents and not from tokens: A sample-complexity theory
Generative models, from diffusion models to large language models, achieve remarkable performance, but at the cost of training data orders of magnitude larger than what biological learners require. An alternative paradigm has emerged in which networks are trained to predict their own latent representations of related views or masked regions, as in data2vec and JEPA, an idea related to predictive-coding accounts of the cortex. Despite strong empirical results, the theoretical understanding of these methods remains limited. Two central questions are: by how much does latent prediction actually improve data efficiency, and is there a benefit to stacking such methods into multi-scale hierarchies? We answer both using, as a model of data, a tractable probabilistic context-free grammar that captures the compositional structure of natural language and images. Such a grammar generates strings of visible tokens by recursively applying production rules along a tree of hidden symbols of depth L. For such data, supervised learning and token-level self-supervised learning (SSL) require a number of samples exponential in L to recover the latent tree; we prove that latent prediction achieves this with a number of samples constant in L, up to logarithmic factors.
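To make the data model concrete, here is a minimal sketch of a generator in the style of the Random Hierarchy Model named in the conclusion. The parametrization is an assumption, not taken from this excerpt: v symbols per level, m production rules per symbol, each rewriting a symbol into s lower-level symbols, rules drawn at random so that no s-tuple is produced by two different symbols, and productions chosen uniformly at each node. The function names `make_rules`, `sample` and `parse` are hypothetical.

```python
import itertools
import random
from typing import Dict, List, Tuple

# Hypothetical rule table for one level: symbol -> its m productions (s-tuples).
Rules = Dict[int, List[Tuple[int, ...]]]


def make_rules(v: int, m: int, s: int, L: int, seed: int = 0) -> List[Rules]:
    """Draw one rule table per level, root level first (assumed parametrization).

    Each of the v symbols gets m distinct s-tuples of lower-level symbols.
    Tuples are sampled without replacement across symbols, so no s-tuple is
    produced by two different symbols (the rules are unambiguous).
    """
    if v * m > v ** s:
        raise ValueError("need v * m <= v ** s for unambiguous rules")
    rng = random.Random(seed)
    all_tuples = list(itertools.product(range(v), repeat=s))
    levels = []
    for _ in range(L):
        chosen = rng.sample(all_tuples, v * m)
        levels.append({sym: chosen[sym * m:(sym + 1) * m] for sym in range(v)})
    return levels


def sample(levels: List[Rules], rng: random.Random) -> Tuple[int, List[int]]:
    """Sample a root symbol, then expand every node with a uniformly chosen rule."""
    v = len(levels[0])
    root = rng.randrange(v)
    seq = [root]
    for rules in levels:
        seq = [tok for sym in seq for tok in rng.choice(rules[sym])]
    return root, seq  # s**L visible tokens


def parse(levels: List[Rules], tokens: List[int], s: int) -> int:
    """Invert the known rules bottom-up; possible because they are unambiguous."""
    seq = list(tokens)
    for rules in reversed(levels):
        inverse = {t: sym for sym, prods in rules.items() for t in prods}
        seq = [inverse[tuple(seq[i:i + s])] for i in range(0, len(seq), s)]
    return seq[0]


if __name__ == "__main__":
    levels = make_rules(v=8, m=2, s=2, L=3, seed=0)
    root, tokens = sample(levels, random.Random(1))
    print(root, tokens)  # 2**3 = 8 visible tokens
    assert parse(levels, tokens, s=2) == root
```

Once the rules are known, the hidden tree is fully determined by the tokens, as `parse` shows; the learning problem the abstract refers to is recovering that latent tree from samples when the rules are unknown.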
Introduction. Generative models have achieved striking empirical success. Diffusion models produce realistic images and video [1–3], while large language models trained by next-token prediction master grammar, world knowledge and reasoning [4–6]. Both rest on a single recipe: predict masked or future fragments of the raw signal at massive scale. Yet this success comes at a cost that biological learners do not pay. Frontier LLMs are trained on $10^{13}$–$10^{14}$ tokens [7, 8], more than five orders of magnitude beyond what a child encounters before reaching adult-level competence [9–11]; state-of-the-art diffusion models likewise rely on billions of images [12]. This gap suggests that token-level pretraining is far from optimal, and several hypotheses have been advanced to explain it. One is that biological learning is multimodal and grounded [13, 14]. A second, which we pursue here, is that learning may be most efficient not at the level of raw tokens but in a more abstract latent space [15–17].
Discussion / Conclusion. We have proved that learning from one's own latents recovers the full non-root latent tree of the Random Hierarchy Model from a number of samples scaling as $m^3$, exponentially fewer (in the depth L) than the $m^{L+1}$ required by token-level self-supervised objectives. We confirmed this prediction with a hierarchical clustering algorithm, with an end-to-end neural network of predictor-clusterer modules, and through the first sample-complexity analysis of data2vec. The main lesson is that latents at the same level of the hierarchy are far more correlated with each other than with raw tokens, so predicting from one's own latents amplifies a signal that token-level prediction dilutes. This places the common intuition that token-level prediction is suboptimal on a firm quantitative footing.

Practical implications. Recent work suggests that empirical neural scaling laws in language are governed by the power-law decay of token correlations with context length [33, 34, 38].
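The size of the gap between the two rates can be illustrated with a back-of-the-envelope computation. This sketch keeps only the leading scalings quoted in the conclusion, about $m^3$ samples for latent prediction and $m^{L+1}$ for token-level objectives, and drops all constants and logarithmic factors, so the absolute counts are illustrative and only the ratio $m^{L-2}$ is meaningful. The function name `scaling_table` is hypothetical.

```python
from typing import Iterable, List, Tuple


def scaling_table(m: int, depths: Iterable[int]) -> List[Tuple[int, int, int, int]]:
    """Rows (L, n_latent, n_token, ratio) with n_latent ~ m**3, n_token ~ m**(L+1).

    Constants and log factors are dropped, so only the ratio m**(L-2) is meaningful.
    Depths must satisfy L >= 2 so that the integer ratio is exact.
    """
    rows = []
    for L in depths:
        if L < 2:
            raise ValueError("use L >= 2")
        n_latent = m ** 3        # constant in the depth L
        n_token = m ** (L + 1)   # exponential in the depth L
        rows.append((L, n_latent, n_token, n_token // n_latent))
    return rows


if __name__ == "__main__":
    for L, n_latent, n_token, ratio in scaling_table(m=4, depths=range(2, 7)):
        print(f"L={L}: latent ~ {n_latent}, token ~ {n_token}, ratio {ratio}")
```

Each extra level of the hierarchy multiplies the token-level requirement by m while leaving the latent-prediction requirement unchanged, which is what "constant in L" means in the abstract.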