Do LLMs not consider the probability distribution over all combinations of tokens up to a certain output length with regard to sequence prediction? I assumed they did that already.
If they don’t, I’m amazed they work as well as they do. Consider 2-bit sequence prediction with the following possible outcomes and associated probabilities:
00: p=0.36
01: p=0.04
10: p=0.30
11: p=0.30
So the most likely 2-bit sequence is 00. But on the basis of predicting the next token (bit) alone, we have:
0: p=0.40
1: p=0.60
which suggests that 1 is the next bit and leads to a suboptimal starting point for predicting the bit after that. The error is even more prominent with longer sequences, as the joint probability distribution becomes harder to factor into marginal distributions (as I would expect of any minimal algorithmic description of real-world data).
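To spell the mismatch out (this is just the toy joint distribution above, nothing LLM-specific; a quick sketch):

    # Toy joint distribution over 2-bit sequences from the example above.
    joint = {"00": 0.36, "01": 0.04, "10": 0.30, "11": 0.30}

    # Marginal distribution over the first bit.
    first_bit = {"0": 0.0, "1": 0.0}
    for seq, p in joint.items():
        first_bit[seq[0]] += p

    print("most likely sequence:", max(joint, key=joint.get))                       # 00
    print("first-bit marginals:", {b: round(p, 2) for b, p in first_bit.items()})   # 0: 0.4, 1: 0.6
    print("greedy first bit:", max(first_bit, key=first_bit.get))                   # 1, steering away from 00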
Edit: now that I think about this a bit more, a cool research project that would be really simple to carry out might be to modify the cross-entropy loss function to consider only the nth future token in the text training data, and then plot LLM performance vs n, assuming that for all current LLM models we just have n=1.
My hypothesis is that you can mostly bypass all of the resource blow-up involved in predicting the joint probability distribution over the next 1 through n tokens (which scales as x^n) by just predicting the nth token directly, since doing so would implicitly require a better data model (at least for human-generated text; this wouldn’t be the case for all types of data).
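To be concrete, the only change to training I have in mind is something like the sketch below (PyTorch-style; the function name and shapes are just for illustration, not anything that exists today):

    import torch.nn.functional as F

    def nth_token_loss(logits, tokens, n=1):
        """Cross-entropy where position t predicts the token at t+n instead of t+1.
        n=1 reproduces the ordinary next-token objective.

        logits: (batch, seq_len, vocab) from a causal decoder
        tokens: (batch, seq_len) integer token ids
        """
        pred = logits[:, :-n, :]    # positions that have a target n steps ahead
        target = tokens[:, n:]      # the token n steps ahead of each such position
        return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))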
I think you're not looking at this from the right perspective. An LLM is designed to sample text that follows the training distribution. It is not designed to tell you the "most likely" text that follows, and we don't actually want that. This would mean you have no diversity in your outputs.
In your example, sampling a 0 in 40% of cases and a 1 in 60% of cases does make sense for chat applications.
For applications where we do care about the most likely sentence (e.g. question answering), then beam search helps, as others have mentioned.
Another thing to consider is that the model can "look ahead" and precompute what the future tokens might be, and it can then use this to predict the current token. In fact, some work has investigated this, such as [1].
And a final note: predicting one token at a time is what we humans do when we speak, so clearly it is not a wrong approach. We do this "look ahead" in our minds before speaking.
> It is not designed to tell you the "most likely" text that follows, and we don't actually want that. This would mean you have no diversity in your outputs.
No, we specifically do want "most likely" to follow; the goal is to approximate Solomonoff induction as well as possible. See this recent paper by Hutter's team: https://arxiv.org/pdf/2401.14953
Quote from the paper:
"LLMs pretrained on long-range coherent documents can learn new tasks from a few examples by inferring a shared latent concept. They can do so because in-context learning does implicit Bayesian inference (in line with our CTW experiments) and builds world representations and algorithms (necessary to perform SI [Solomonoff Induction]). In fact, one could argue that the impressive in-context generalization capabilities of LLMs is a sign of a rough approximation of Solomonoff induction."
> In your example, sampling a 0 in 40% of cases and a 1 in 60% of cases does[n't] make sense for chat applications.
I didn't say anything about sampling. A sequence prediction model represents a mapping between an input sequence and a probability distribution over all possible output sequences up to a certain length.
My example uses a binary alphabet, but LLMs use an alphabet of tokens. Any chat application that expresses its output as a string of concatenated symbols from a given alphabet has a probability distribution defined over all possible output sequences. I'm simply comparing the fundamental limitations of any approach to inference that restricts its outcome space to sequences consisting of one symbol (and then layers on a meta-model to generate longer sequences by repeatedly calling the core inference capability) vs an approach that performs inference over an outcome space consisting of sequences longer than one symbol.
> "It is not designed to tell you the "most likely" text that follows,"
It is exactly designed to do that. At a temperature of 0, this is what you are approximating. The crucial point, though, is that it is the most likely next word given the preceding multi-token context, not just the previous token.
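Roughly, temperature just rescales the logits before the softmax, and as it goes to 0 sampling collapses onto the argmax. A toy sketch (my own code, not any particular library's API):

    import numpy as np

    rng = np.random.default_rng(0)

    def sample_next(logits, temperature=1.0):
        """Sample a token index from logits; temperature -> 0 approaches greedy argmax."""
        if temperature == 0.0:
            return int(np.argmax(logits))
        z = np.asarray(logits, dtype=float) / temperature
        z -= z.max()                              # numerical stability
        probs = np.exp(z) / np.exp(z).sum()
        return int(rng.choice(len(probs), p=probs))

    logits = np.log([0.40, 0.60])    # the toy next-bit marginals from upthread
    print(sample_next(logits, 0.0))  # always index 1 ("temperature 0" / greedy)
    print(sample_next(logits, 1.0))  # index 1 roughly 60% of the time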
Indeed. An interesting reference for this is the work Milman Parry did describing the key phrases in the Odyssey and the cues they gave to help someone memorize the poem.
Also, this is maybe a semantic point, but I am not predicting any words I speak, not in a statistical sense. I have intent behind my words, which means I have an abstraction of meaning that I want to convey, and I assemble the correct words to do that. No part of that is "predictive".
This is how they work and it's a real problem when doing prediction with low temperatures.
IIRC you see weird patterns in LLM outputs since "an" is often less likely than "a" so you end up with fewer nouns beginning with vowels than you would expect.
Language models factor the joint probability p(y, x) as p(y, x) = p(y|x) p(x), which is exact. I.e., if you train a language model on your distribution and sample with temperature 1, you will get the exact same distribution out. If you sample at a lower temperature, or even greedily, you will evidently get other distributions.
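You can check this on the toy 2-bit distribution from upthread: sampling each bit from the exact conditionals at temperature 1 reproduces the joint (a quick sketch):

    import random
    from collections import Counter

    joint = {"00": 0.36, "01": 0.04, "10": 0.30, "11": 0.30}

    # Exact factorization: p(b1, b2) = p(b1) * p(b2 | b1)
    p_b1 = {b: sum(p for s, p in joint.items() if s[0] == b) for b in "01"}
    p_b2_given_b1 = {b1: {b2: joint[b1 + b2] / p_b1[b1] for b2 in "01"} for b1 in "01"}

    random.seed(0)
    def sample():
        b1 = random.choices("01", weights=[p_b1[b] for b in "01"])[0]
        b2 = random.choices("01", weights=[p_b2_given_b1[b1][b] for b in "01"])[0]
        return b1 + b2

    counts = Counter(sample() for _ in range(100_000))
    print({s: round(counts[s] / 100_000, 3) for s in sorted(joint)})
    # close to {'00': 0.36, '01': 0.04, '10': 0.30, '11': 0.30}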
What you’ve described is basically the problem with greedy sampling in the decoder. Many other local optimization sampling strategies exist (e.g., beam search) and there’s been a lot of work on more global sampling (e.g., speculative decoding).
Training loss considers only the next single token, right? (I’m not up-to-date on the SOTA.)
I thought post-training prediction still only directly predicts the next token, and beam search is sort of a meta-model applied over that (i.e., a model on top of the output of the next-token model: at each iteration it keeps a subset of the current next-token predictions, ranked by their probabilities, as multiple starting points for predicting the following token, while tracking the joint probabilities to prune the set of candidate sequences at each step).
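For what it's worth, that "meta-model" really is small; here's a minimal sketch over the toy 2-bit distribution from upthread (my own toy code, not any library's implementation):

    import math

    def beam_search(next_probs, vocab, length, beam_width=2):
        """Minimal beam search over a next-token model.
        next_probs(prefix) -> {symbol: p(symbol | prefix)} for each symbol in vocab."""
        beams = [("", 0.0)]                              # (prefix, log joint probability)
        for _ in range(length):
            candidates = []
            for prefix, logp in beams:
                probs = next_probs(prefix)
                for sym in vocab:
                    if probs[sym] > 0:
                        candidates.append((prefix + sym, logp + math.log(probs[sym])))
            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        return beams

    # Next-token conditionals derived from the toy joint distribution upthread.
    joint = {"00": 0.36, "01": 0.04, "10": 0.30, "11": 0.30}
    def next_probs(prefix):
        if prefix == "":
            return {"0": 0.40, "1": 0.60}
        return {b: joint[prefix + b] / sum(joint[prefix + c] for c in "01") for b in "01"}

    print(beam_search(next_probs, "01", length=2, beam_width=2))
    # beam_width=1 is greedy decoding and ends at "10"/"11"; width 2 recovers "00" as the top beam.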
Seems like beam search would fail drastically in cases where the true (unknown) probability distribution over all sequences of tokens of length n gives very low conditional probabilities to the first few tokens, each conditioned on the previously predicted tokens. That is, the true values of p(t2|t1), p(t3|t2,t1), p(t4|t3,t2,t1), ... as derived from the unknown p(t1,t2,...,tn) are very small, but they come out very high when computed via a next-token prediction model.
I’m suggesting modifying both: use cross-entropy of the nth token for the training loss, use cross-entropy of the nth token for post-training prediction, and then work backward from there to the beginning of your sequence prediction.
The problem is that a position's probability output is conditioned via attention on all previous positions.
If you want to do better you need to switch to something like DDPMs (e.g. an encoder-only transformer that predicts diffusion transition probabilities in parallel, then applies steps of denoising).
The problem is just that these don't work as well as autoregressive decoder transformers, and encoder-decoder architectures like Google's T5 have fallen out of favor since about when LLaMA dropped.
> 0: p=0.40
> 1: p=0.60
> which suggests that 1 is the next bit and leads to a suboptimal starting point for predicting the bit after that. The error is even more prominent with longer sequences, as the joint probability distribution becomes harder to factor into marginal distributions (as I would expect of any minimal algorithmic description of real-world data).
Can someone explain this part a bit more? I'm not seeing the issue. From what I see, if the first token (t1) output is a 0, then the next token (t2) would have probabilities 0: p=0.90 and 1: p=0.10 (and 0.50/0.50 for t2 if t1=1).
Mathematically, those line up with the initial distribution, so what's the concern? That's how conditional probability works.
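Spelling out my arithmetic from the joint upthread, in case I'm missing something:

    joint = {"00": 0.36, "01": 0.04, "10": 0.30, "11": 0.30}
    for t1 in "01":
        total = sum(p for s, p in joint.items() if s[0] == t1)
        print(f"p(t2 | t1={t1}) =", {t2: round(joint[t1 + t2] / total, 2) for t2 in "01"})
    # p(t2 | t1=0) = {'0': 0.9, '1': 0.1}
    # p(t2 | t1=1) = {'0': 0.5, '1': 0.5}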
If I'm reading you right, you're saying that a simple way to do this would be to calculate logits not just for the next token, but also for token n+1 -- all at the same time. If one of the n+1 logits is chosen, then do an infill on the skipped token for the next step, then resume.
This could get us around the example that you gave for only a linear increase in the number of output logits -- so looking one extra token ahead only increases the output size by a factor of 2, and looking at a third token makes it a total factor of 3.
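Roughly what I'm picturing, as a sketch (made-up names; k separate logit heads over the same hidden state, so the output grows as k * vocab rather than vocab ** k):

    import torch
    import torch.nn as nn

    class MultiOffsetHead(nn.Module):
        """k output heads over a shared hidden state, one per lookahead offset."""
        def __init__(self, hidden_dim, vocab_size, k=2):
            super().__init__()
            self.heads = nn.ModuleList([nn.Linear(hidden_dim, vocab_size) for _ in range(k)])

        def forward(self, hidden):                 # hidden: (batch, seq_len, hidden_dim)
            # logits[i] predicts the token i+1 positions ahead of each input position
            return [head(hidden) for head in self.heads]

    heads = MultiOffsetHead(hidden_dim=16, vocab_size=100, k=2)
    logits_next, logits_skip = heads(torch.randn(1, 8, 16))
    print(logits_next.shape, logits_skip.shape)    # two (1, 8, 100) tensors, not one (1, 8, 100*100)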
...uhh, isn't this what beam search attempts to fix (approximate) over greedy decoding? You know, greedy optimum vs. global optimum. And most LMs already use beam search, because finding the true optimum is intractable even over modest lengths.
Sure, beam search makes the result less greedy than it originally was, but it’s still an extremely greedy approach overall.
It would be sort of like trying to make local optimization techniques less local by running the process multiple times from different starting points and choosing the lowest basin the process ended up in across the different runs. Quite a bit better, but in many cases not even close to the global optimum*.
For a slightly more apt analogy, it would be like the following (rough code sketch after the list):
1) Choose multiple starting points
2) For each starting point, perform local optimization until you get stuck in a basin (local minimum)
3) Keep the n overall lowest points, create m different perturbations for each point that pushes the point out of its basin, and go to 1), using these m*n points as your new set of starting points for the next round.
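Rough sketch of 1)-3), with a stand-in local optimizer and made-up names (keep/branch play the roles of n and m):

    import math
    import random

    def multistart(local_opt, perturb, starts, rounds=3, keep=2, branch=3):
        """Steps 1-3 above: locally optimize each start, keep the `keep` lowest basins,
        spawn `branch` perturbations of each as the next round's starting points."""
        points, best = list(starts), None
        for _ in range(rounds):
            results = sorted((local_opt(x) for x in points), key=lambda r: r[1])
            if best is None or results[0][1] < best[1]:
                best = results[0]
            points = [perturb(x) for x, _ in results[:keep] for _ in range(branch)]
        return best                                  # (location, value) of the lowest basin seen

    # Toy usage: a bumpy 1-D objective and a crude step-descent "local optimizer".
    f = lambda x: 0.05 * x * x + math.sin(3 * x)
    def local_opt(x, step=0.01):
        while f(x - step) < f(x) or f(x + step) < f(x):
            x = x - step if f(x - step) < f(x + step) else x + step
        return x, f(x)

    random.seed(0)
    print(multistart(local_opt, lambda x: x + random.uniform(-3, 3),
                     starts=[random.uniform(-10, 10) for _ in range(4)]))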
Note this process is totally agnostic to whatever the local optimization algorithm is. That’s why I called the beam search part of an LLM’s post-training prediction a “meta-model”, because it doesn’t matter if the core inference is performed by a transformer architecture or something entirely different.
*I say “in many cases”, but I am extremely curious, for this particular case of inferring sequences of human-generated text, about the degree to which we fail to capture the true joint probability distribution via single-token prediction + beam search.
It’s quite possible we are very far off—perhaps this is the missing “ingredient” for generalized reasoning. On the other hand, as I said in my original post, I never would have guessed current SOTA LLMs only use single-token prediction, and I’m astonished they work as well as they do based on that, so maybe we’re not actually too far off. Without further research, it’s just speculation either way though.
Hmm, I may have misunderstood this, but isn't this just beam search run multiple times (possibly*)? Also, usually the search is over the discrete token space directly; I am not sure if there are continuous surrogates that translate the discrete (combinatorial) inference problem into a continuous one, which would fit better with your local-minima and perturbation terminology. Although I am uncertain about the utility of running beam search multiple times, I am keen to look for literature casting the inference search problem as a continuous one.
*This might be just a detail/semantics, but for example, for a 3-word context, your starting points may look like "[some token][empty][empty]". Here your procedure simply reduces to a single run of beam search, not multiple runs, since beam search optimizes locally at every turn, producing n different "perturbations" each turn. Let me know if I misunderstood you on this.
But casting the combinatorial inference optimization as a continuous surrogate (which sounds like what you are conveying, and would naturally result in starting points that are full sentences and not just incomplete words/sequences) is something I never considered. There must be some literature on this... let's see.
It's called the Markov assumption. It was basically the single most important piece of mathematics in the field for decades. It allowed us to solve otherwise intractable problems given the limited compute budgets of the time.