New decoding technique in transformers significantly reduces hallucinations
DoLa decoding, published as a conference paper at ICLR '24, has just been merged into Transformers by @joaogante and Yung-Sung Chuang.
This new decoding method is simple yet extremely impressive!
Reminder: Decoder LLMs (the GPT kind of LLM, the most common one) generate their outputs one token at a time. At each step, given the current text, they compute a logit for each token in their vocabulary, and a softmax turns these logits into the probability of each token coming next.
Then they either pick the token with the highest logit (greedy decoding) or sample one according to the probabilities defined by the logits (sampling).
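To make the reminder concrete, here is a minimal sketch of both strategies on a toy three-token vocabulary. The vocabulary, the logit values, and the helper names (`softmax`, `greedy_pick`, `sample_pick`) are made up for illustration; a real model produces one logit per vocabulary entry at every step.

```python
import math
import random

def softmax(logits):
    """Turn raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits):
    """Greedy decoding: index of the highest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])

def sample_pick(logits, rng):
    """Sampling: draw an index with probability given by softmax(logits)."""
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

# Hypothetical next-token logits for a toy vocabulary.
vocab = ["Paris", "London", "banana"]
logits = [2.0, 1.0, -1.0]

probs = softmax(logits)
greedy_token = vocab[greedy_pick(logits)]
sampled_token = vocab[sample_pick(logits, random.Random(0))]
```

Greedy decoding always returns the same token for the same logits, while sampling can return any token with non-zero probability.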
The authors of DoLa wanted to improve that simple method.
They built on the established observation that transformer LMs encode low-level information (such as basic syntax) in early layers and higher-level information, such as factual knowledge, in later layers.
💡 This gave them their key idea: During decoding, rather than picking the token with the highest logit, why not pick the token with the most impressive increase in logit across layers?
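A simplified sketch of that contrast, under stated assumptions: it compares the final layer against a single fixed early layer, whereas the paper picks the early layer dynamically at each step. The logits are invented, and `contrast_layers` and `alpha=0.1` are illustrative names and values, not the Transformers implementation. The score is the gain in log-probability between the two layers, restricted to tokens the final layer already finds plausible.

```python
import math

def log_softmax(logits):
    """Log-probabilities from raw logits (numerically stable)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def contrast_layers(final_logits, early_logits, alpha=0.1):
    """Score each token by how much its log-probability grew from the
    early layer to the final layer. Tokens whose final probability is
    below alpha * (best final probability) are ruled out with -inf."""
    final_lp = log_softmax(final_logits)
    early_lp = log_softmax(early_logits)
    cutoff = max(final_lp) + math.log(alpha)
    return [
        f - e if f >= cutoff else float("-inf")
        for f, e in zip(final_lp, early_lp)
    ]

# Invented logits: "Seattle" is already likely in the early layer,
# while "Olympia" only becomes likely in the final layer.
vocab = ["Seattle", "Olympia", "zebra"]
early_logits = [3.0, 0.5, -4.0]
final_logits = [3.2, 3.0, -3.0]

scores = contrast_layers(final_logits, early_logits)
greedy_token = vocab[max(range(len(vocab)), key=lambda i: final_logits[i])]
dola_token = vocab[max(range(len(vocab)), key=lambda i: scores[i])]
```

With these numbers, plain greedy decoding picks "Seattle" while the contrast picks "Olympia", the token whose probability rose the most across layers. The plausibility cutoff keeps an unlikely token such as "zebra" from being promoted just because its tiny probability happened to grow.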
This gives impressive results:
Gains of 5 to 20 percentage points across the benchmarks
For instance, on TruthfulQA (open-ended), the increase in truthfulness across all model sizes is 14 percentage points, which is around a 40% improvement compared to standard decoding!
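A quick sanity check on how the two TruthfulQA numbers fit together: a 14-point absolute gain that equals roughly a 40% relative gain implies a baseline of about 35%. That baseline is derived here from the post's rounded figures; it is not quoted from the paper, and `implied_baseline` is just an illustrative helper.

```python
def implied_baseline(absolute_gain_points, relative_gain):
    """Baseline score (in points) implied by an absolute gain and the
    relative improvement that gain represents."""
    return absolute_gain_points / relative_gain

# Figures from the post: +14 points, about +40% relative.
baseline = implied_baseline(14.0, 0.40)
improved = baseline + 14.0
relative_check = (improved - baseline) / baseline
```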
🤔 Wouldn't decoding take longer because of this added contrasting step? The runtime increase is negligible, 1 to 8% only.
Paper added to my collection: m-ric/optimization-mechanics-661d543a5fc6ca1dc84284a0