# Latent Reasoning Optimization
*What is the difference between Salesforce and SAP? I never read a good ML paper written by SAP lmao*
___
## Intro
The paper 'Language Models are Hidden Reasoners: Unlocking Latent Reasoning Capabilities via Self-Rewarding'^[https://arxiv.org/abs/2411.04282], which introduces **LaT**ent **R**easoning **O**ptimization (**LaTRO**), made quite the splash on X^[(formerly known as Twitter)], at least in my bubble. Also, it was published by SalesforceAI, of all organizations! For some reason, it eluded me for a few months, but I finally got around to reading it and it is just so beautiful. I love it when new developments tie back to methods that we^["You were just playing Dota 2 all day back then, so who's 'we', little bro?"] used in the past, long before the age of LLMs - back when dinosaurs roamed the earth and people would think you were clinically insane if you told them that you had trained a $7$ billion parameter model.
Since a lot of people began their journey into ML more recently, there is a real possibility that much of the math that went into LaTRO has gone over their heads. Even though the paper is quite readable (and they share their code), I still decided to take it upon myself to talk about these older ideas that are re-discovered in amazing ways in this paper. By the end of this post, you will hopefully understand everything *for real* and agree with me that this is a beautiful paper.
## Some Background + Intuition
Chain-of-Thought works. We have all seen o1, R1 and QwQ. But why does it work?
Jason Wei, the researcher who first-authored the original CoT paper, provides this intuition on his blog^[https://www.jasonwei.net/blog/some-intuitions-about-large-language-models]:
> **Tokens can have very different information density, so give language models time to think**
It's a great blogpost and you should check it out to see how people on the frontier conceptualize LLMs, but we are here to talk about Latent Reasoning and, as luck would have it, the LaTRO authors make a similar claim in the paper:
> Empirically, people observe that there is a higher chance for the LLM $\pi_\theta$ to generate the desired answer $y$ following the above procedure \[referring to adding "Let's think step by step" at the end of the prompt\] than directly sampling the response $y \sim \pi_\theta(\cdot | x)$. From a statistical perspective, we hypothesize that good reasoning rationales can significantly improve the probability of generating good answers $y$:
> $\boxed{\exists z: \pi_\theta(y| x \oplus z) \gg \pi_\theta(y| x)}$
So our goal is to find the $z$, which, if concatenated to $x$, makes our sample $y \sim \pi_\theta(\cdot|x \oplus z)$ more likely to be correct. Let's clarify why this makes sense with an example:
Try solving the following question, which we shall call $x$:
> You buy an item for \$50 and sell it for \$60. Then, you buy it back for \$70 and sell it again for \$80. What is your total profit?
Okay, if you are reading this blog, you are probably smart enough to solve this very easily, but, believe it or not, this trips up many people. Try it out yourself!
In contrast, if you provide good reasoning (let's call it $z$),
> You buy an item for \$50, spending \$50 initially. Then, you sell it for \$60, gaining \$60. At this point, your net profit is \$60 - \$50 = \$10. Next, you buy the item back for \$70, which means you spend \$70, reducing your net profit to \$10 - \$70 = -\$60. Finally, you sell the item again for \$80, gaining \$80. Adding this to your previous net profit, you now have \$80 - \$70 + \$10 = \$20.
then most people will find the correct answer to obviously be $y = \$20$.
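If you'd rather let a machine do the bookkeeping: the rationale is nothing more than a running cash balance. Here is a tiny Python sketch of it (the `running_balance` helper is a name I made up for illustration, it is not from the paper):

```python
def running_balance(transactions):
    """Return the cash balance after each transaction (negative = buy, positive = sell)."""
    balance = 0
    history = []
    for amount in transactions:
        balance += amount
        history.append(balance)
    return history

# Buy for 50, sell for 60, buy back for 70, sell again for 80.
steps = running_balance([-50, 60, -70, 80])
print(steps)  # [-50, 10, -60, 20]
```

The intermediate balances match the rationale step by step, and the last entry is the total profit.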
So ideally, we would go through all possible reasoning traces and change the parameters of the model such that those that lead us to the correct answer become more likely and those that do not become less likely.
Alas, going through all possible reasoning traces is intractable, so we will have to live with just approximating it through Monte Carlo sampling.
## The Math
So let's do this by introducing a "latent" random variable $z$ that we cannot directly observe. $z$ represents the "rationales" (what the authors use to describe the reasoning traces that get us right up to the answer).
Now we have to juggle the input $x$, the output $y$ and the rationale $z$. The formulae will get messy real quick and obfuscate the core idea, so let's instead look at a simpler case that we can use to understand what's going on.
Namely, we can imagine some abstract generative process (with learnable parameters $\theta$) that produces an observation $x$ given some latent variable $z$, so $p_\theta(x, z) = p_\theta(x|z)~p(z)$. What we are really interested in, however, is the marginal likelihood $p_\theta(x)$, which we get by marginalizing over the $z$ component:
$p_\theta (x) = \int p_\theta(x,z)~dz = \int p_\theta(x | z)~p(z)~dz$
In practice, integrating over $z$ is intractable. So we need to find a way to optimize the parameters $\theta$ without integrating over $z$.
Let's make some assumptions and embed them in a smart way, such that, even if we are off, we can improve our model parameters $\theta$ by improving the parameters $\phi$ of some proxy model which we can actually sample from. Let's call this proxy model $q_\phi$.
Now we can do some mathematical magic by multiplying the above equation by $1$. Everyone knows that $1$ is the neutral element when multiplying real numbers as $x \times 1 = x$, so the equation remains the same.
Okay, but instead of using $1$ to denote this neutral element, let's instead do $\frac{q_\phi}{q_\phi}$, which is obviously^[some handwaving here, but since we can generally choose $q_\phi$, we can assume infinite support for now, which makes the fraction valid] the same!
So let's take a look at $p_\theta(x)$ again:
$p_\theta(x) = \int p_\theta(x|z) ~ \frac{q_\phi(z|x)}{q_\phi(z|x)} ~ p(z)~dz$
If you know your basics well, you will realize that the term on the right can be rewritten as an expectation! Let me swap some terms around so it becomes clearer:
$p_\theta(x) = \int q_\phi(z|x) ~ p_\theta(x|z) ~ \frac{p(z)}{q_\phi(z|x)}~dz = \mathbb{E}_{z \sim q_\phi(z|x)}(p_\theta(x|z) ~ \frac{p(z)}{q_\phi(z|x)})$
Okay, cool, now we don't need to integrate over $z \sim p(z)$ anymore and we can instead sample from $q_\phi(z|x)$. Let's let this sink in for a moment.
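To make the "sample from $q_\phi$ instead" trick concrete, here is a minimal Python sketch on a toy model of my own choosing (it is not from the paper): $p(z) = \mathcal{N}(0, 1)$ and $p_\theta(x|z) = \mathcal{N}(z, 1)$, so the true marginal is known in closed form as $\mathcal{N}(0, 2)$. The proxy $q(z|x) = \mathcal{N}(x/2, 1)$ is picked by hand and is deliberately not the exact posterior.

```python
import math
import random

def normal_pdf(value, mean, std):
    """Density of a univariate Gaussian N(mean, std^2) at `value`."""
    return math.exp(-0.5 * ((value - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def importance_estimate(x, num_samples=20_000, seed=0):
    """Estimate p(x) = E_{z ~ q(z|x)}[ p(x|z) p(z) / q(z|x) ] by sampling from q."""
    rng = random.Random(seed)
    q_mean, q_std = x / 2.0, 1.0  # toy proxy q(z|x), chosen by hand (an assumption)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(q_mean, q_std)            # z ~ q(z|x)
        likelihood = normal_pdf(x, z, 1.0)      # p(x|z) = N(x; z, 1)
        prior = normal_pdf(z, 0.0, 1.0)         # p(z)   = N(z; 0, 1)
        proxy = normal_pdf(z, q_mean, q_std)    # q(z|x)
        total += likelihood * prior / proxy
    return total / num_samples

x = 1.0
exact = normal_pdf(x, 0.0, math.sqrt(2.0))  # closed-form marginal N(x; 0, 2)
estimate = importance_estimate(x)
print(f"exact p(x) = {exact:.4f}, importance-sampling estimate = {estimate:.4f}")
```

The two numbers should land close to each other even though we never integrated over $z$; all we did was average weighted samples from the proxy.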
Sadly, this approach has its downsides, specifically numerical instability. Well, when we do classical Maximum-Likelihood Estimation, we run into similar problems, and perhaps using the $\log$ to turn a product into a sum to avoid numerical instability could be the solution to our troubles.
So, let's do just that:
$\log p_\theta(x) = \log~\mathbb{E}_{z \sim q_\phi(z|x)}(p_\theta(x|z)~\frac{p(z)}{q_\phi(z|x)})$
Not so fast. We need to get the $\log$ inside of the expectation. Can we do that? Well, the $\log$ is a concave function, which is basically the opposite of a convex function (we can turn any concave function into a convex one and vice versa) and there is a famous result about moving convex functions into an expectation called *Jensen's inequality* - what are the chances!
Jensen's inequality states that for a convex function $f$, the following bound holds:
$f(\mathbb{E}(X)) \leq \mathbb{E}(f(X))$
Well, $- \log$ is convex, so we can just use that for our $f$ and multiply both sides by $-1$, which flips the inequality to $\log \mathbb{E}(X) \geq \mathbb{E}(\log X)$ and has exactly the effect that we want.
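If you don't trust inequalities until you have seen numbers, here is a quick sanity check in Python. The uniform distribution on $[0.1, 5]$ is an arbitrary choice of mine; any positive random variable would do. For the concave $\log$, we expect $\log \mathbb{E}(X) \geq \mathbb{E}(\log X)$.

```python
import math
import random

rng = random.Random(0)
# Any positive random variable works; a uniform on [0.1, 5.0] is an arbitrary choice.
samples = [rng.uniform(0.1, 5.0) for _ in range(10_000)]

log_of_mean = math.log(sum(samples) / len(samples))             # log E[X]
mean_of_log = sum(math.log(s) for s in samples) / len(samples)  # E[log X]

print(f"log E[X] = {log_of_mean:.4f}")
print(f"E[log X] = {mean_of_log:.4f}")
print("log E[X] >= E[log X]:", log_of_mean >= mean_of_log)
```

The gap between the two numbers is exactly the price we pay for pulling the $\log$ inside the expectation.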
Granted, we now have an inequality between our original objective and our current objective, so we are not taking the fastest route for optimization, but hey, at least we can optimize now. The formula looks like this:
$\log p_\theta(x) \geq \mathbb{E}_{z \sim q_\phi(z|x)}(\log(p_\theta(x|z)~\frac{p(z)}{q_\phi(z|x)})) = \mathbb{E}_{z \sim q_\phi(z|x)}(\log(p_\theta(x|z)) - \log(\frac{q_\phi(z|x)}{p(z)}))$
You have studied information theory, anon, haven't you? I sure hope you did, because what you are seeing here is literally the equation for *relative entropy*, also known as the *Kullback-Leibler divergence*. Look it up if you don't believe me! So let's rewrite this as such:
$\log p_\theta(x) \geq \mathbb{E}_{z \sim q_\phi(z|x)}(\log(p_\theta(x|z)))~ - ~ D_{KL}(q_\phi(z|x) || p(z))$
This inequality is called the *Evidence Lower Bound* (ELBO) and it is what we will use to optimize. It also offers a very nice probabilistic interpretation:
- The first term measures how well latents $z$ sampled from our proxy model explain the observed data $x$; maximizing it pushes the proxy towards useful $z$.
- The second term acts as a regularizing force, ensuring that our proxy model stays close to the original prior.
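Here is a small Python sketch of both terms, under assumptions that are mine and not the paper's: a toy model with $p(z) = \mathcal{N}(0, 1)$ and $p_\theta(x|z) = \mathcal{N}(z, 1)$, and a Gaussian proxy $q(z|x) = \mathcal{N}(m, s^2)$. For this setup everything has a closed form, so we can compare the ELBO against the true $\log p_\theta(x)$ directly.

```python
import math

def log_marginal(x):
    """log p(x) for the toy model z ~ N(0, 1), x | z ~ N(z, 1), i.e. x ~ N(0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x ** 2 / 4.0

def elbo(x, q_mean, q_std):
    """Closed-form ELBO for a Gaussian proxy q(z|x) = N(q_mean, q_std^2)."""
    # E_q[log p(x|z)]: expected log-likelihood of x under latents drawn from q
    reconstruction = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - q_mean) ** 2 + q_std ** 2)
    # KL(q(z|x) || p(z)) between two univariate Gaussians, with p(z) = N(0, 1)
    kl = 0.5 * (q_std ** 2 + q_mean ** 2 - 1.0 - math.log(q_std ** 2))
    return reconstruction - kl

x = 1.0
print("log p(x)              :", round(log_marginal(x), 4))
print("ELBO, sloppy proxy    :", round(elbo(x, q_mean=0.0, q_std=1.0), 4))
print("ELBO, exact posterior :", round(elbo(x, q_mean=x / 2.0, q_std=math.sqrt(0.5)), 4))
```

The sloppy proxy (which just copies the prior) sits strictly below $\log p_\theta(x)$, while plugging in the exact posterior $\mathcal{N}(x/2, 1/2)$ of this toy model closes the gap. That is why improving the proxy is worth the effort.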
Okay, but what does all of this stuff have to do with Latent Reasoning Optimization? It's not like we are doing some weird stuff with VAEs (you are too young to know what that acronym stands for, aren't you, anon?).
Well, you are now ready to properly understand and appreciate LaTRO. Remember that we stand on the shoulders of giants and the authors of LaTRO did not pluck everything out of thin air. They even cite the 2014 paper^[https://arxiv.org/abs/1312.6114] by Kingma and Welling which, to my knowledge, introduced ELBO - for ML, this is like one of the early quantum physicists in the 20th century citing Aristotle's 'Physica'.
Okay, so let's tie it all together in the next section.
## Latent Reasoning through Variational Methods
Remember what this was all about? In the end, it's about creating more capable LLMs, which means tuning the parameters. A standard objective could look like this:
$\max_\theta \mathbb{E}_{(x, y) \sim \mathcal{D}_{\text{gold}}} \left[\log \pi_\theta(y \mid x)\right]$
In the equation above, we have a dataset with labels — hey, that's SFT! You recognize that, don't you, anon?
Beware the change of variables: we now have a model that we condition on $x$ to output $y$, whereas earlier, when deriving the ELBO, we used $x$ to denote the output and $z$ was a hidden variable. Don't get confused!
At the very start of this post, we had this notion that, if we found a good rationale $z$, then we could augment our original prompt $x$ such that, if we sample from our LLM, conditioned on $x \oplus z$, we get the right solution $y$ to our problem. Well, let's take this seriously by treating $z$ as a hidden variable as before and writing $\pi_\theta(y | x)$ by again marginalizing over $z$:
$\log \pi_\theta(y \mid x) = \log \int \pi_\theta(y \mid x \oplus z) ~ \pi_0(z \mid x) \, dz$
Wait a minute, why can we do $x \oplus z$? Well, for LLMs, this is literally what conditioning on $x$ and then on $z$ is. Crazy how that works. Also note that $\pi_0(z \mid x)$ plays the role that the prior $p(z)$ played before: it is a fixed reference distribution over rationales. From here, we can again derive the ELBO by introducing a proxy model, which we have to do because, again, the integral is intractable. So we get
$\log \pi_\theta(y \mid x) = \log \int \pi_\theta(y \mid x \oplus z) \frac{q(z \mid x)}{q(z \mid x)} \pi_0(z \mid x) \, dz$
$\geq \max_{q(z \mid x)} \left\{ \mathbb{E}_{q(z \mid x)} \left[ \log \pi_\theta(y \mid x \oplus z) \right] - D_{\text{KL}} \left[q(z \mid x) \| \pi_0(z \mid x)\right] \right\}$
Now, what can we use for our proxy model $q(z|x)$, you ask? Well, we can just use a **reasoner**, i.e. an LLM that, given a task, creates a rationale (or a reasoning trace or whatever you wanna call it). We could even use the LLM $\pi_\theta$ as a 'naive' reasoner itself! Pretty cool, isn't it?
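To see the bound in action, here is a toy sketch in Python. Everything in it is made up for illustration: three hand-written rationales and hand-picked probabilities standing in for $\pi_\theta$, $\pi_0$ and $q$. It is not the paper's training code, just both sides of the inequality evaluated on a case small enough that we can sum over every $z$.

```python
import math
import random

# Hypothetical toy numbers: three candidate rationales z for one question x.
rationales = ["careful ledger", "eyeball it", "treat the buyback as a loss"]
answer_prob = [0.95, 0.40, 0.05]    # pi_theta(y | x (+) z): chance of the right answer given z
prior_prob = [0.20, 0.50, 0.30]     # pi_0(z | x): reference distribution over rationales
reasoner_prob = [0.70, 0.25, 0.05]  # q(z | x): the reasoner we sample rationales from

# Exact marginal likelihood; only feasible because z takes just three values here.
log_marginal = math.log(sum(a * p for a, p in zip(answer_prob, prior_prob)))

# Exact ELBO: E_q[log pi_theta(y | x (+) z)] - KL(q || pi_0).
expected_log_lik = sum(q * math.log(a) for q, a in zip(reasoner_prob, answer_prob))
kl = sum(q * math.log(q / p) for q, p in zip(reasoner_prob, prior_prob))
exact_elbo = expected_log_lik - kl

def monte_carlo_elbo(num_samples=5_000, seed=0):
    """Estimate the ELBO from rationales sampled from the reasoner q(z | x)."""
    rng = random.Random(seed)
    indices = rng.choices(range(len(rationales)), weights=reasoner_prob, k=num_samples)
    terms = [
        math.log(answer_prob[i]) - math.log(reasoner_prob[i] / prior_prob[i])
        for i in indices
    ]
    return sum(terms) / num_samples

print(f"log pi_theta(y|x) = {log_marginal:.4f}")
print(f"exact ELBO        = {exact_elbo:.4f}")
print(f"sampled ELBO      = {monte_carlo_elbo():.4f}")
```

The sampled ELBO hovers around the exact one, and both stay below $\log \pi_\theta(y \mid x)$. With a real LLM we can no longer enumerate rationales, so the sampled version is the only one we can actually compute.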