A Test-Time Training End-to-End (TTT-E2E) kezeli a hosszú LLM bemeneteket
A nagy nyelvi modellek (LLM-ek) jellemzően pontatlanabbá és lassabbá válnak, amikor hosszabb kontextusokat dolgoznak fel, de a kutatók lehetővé tették egy LLM számára, hogy stabil pontosságot és állandó inference időt tartson fenn, miközben a kontextus mérete nőtt.
Mi az új: Arnuv Tandon, Karan Dalal és kollégáik a nonprofit Astera Institute, Nvidia, Stanford, UC Berkeley és UC San Diego intézményekben bemutatták a Test-Time Training, End-to-End (TTT-E2E) módszert, amely inference során történő tréningezéssel tömöríti a kontextust egy transformer súlyaiba.
Kulcsfontosságú felismerés: A transformer architektúrára épülő LLM-ek a teljes kontextusra (az eddig bemeneti és kimeneti tokenekre) figyelnek, hogy generálják a következő kimeneti tokent. Így minden új kimeneti token generálása több feldolgozást igényel, mint az előző, ami potenciálisan drágává és lassúvá teszi az inference-t. A teljes kontextus figyelése helyett egy transformer korlátozhatja a figyelmét egy kisebb, fix méretű ablakra – ami állandóan tartja az egyes kimeneti tokenek generálásához szükséges időt –, és a súlyainak frissítésével tanulhat a kontextusból.
Hogyan működik: A szerzők egy 3 milliárd paraméteres transformert építettek, amely sliding-window attentiont alkalmazott, ami a figyelmet egy fix, 8000 token méretű ablakra korlátozta. A modellt 8000 tokenből álló szekvenciákon – összesen 164 milliárd tokenen – pretrainelték, amelyeket a webről kapart szövegek szűrt adathalmazából vettek. Annak érdekében, hogy hosszabb kontextusokat is tudjon követni, finomhangolták (fine-tuned) 128 000 token hosszú szekvenciákon, amelyeket a The Pile Books alhalmazából vettek. A szerzők a meta-learning egy formáját alkalmazták, azaz a tanulás tanulását; ebben az esetben a modell azt tanulja meg, hogyan tanuljon az inference időben kapott bemenetből.
Eredmények: A szerzők összehasonlították a TTT-E2E-t egy konvencionális attentionnel rendelkező transformerrel, valamint olyan rendkívül hatékony architektúrákkal, mint a Mamba 2 (egy rekurens neurális hálózat stílusú modell) és a Gated DeltaNet (amely lineáris attention egyedi formáját használja). Pontossága kissé meghaladta a transformerét hosszú kontextusokon – kivéve a Needle-in-a-Haystack feladatot, amely egy rövid célstring visszaállítását jelenti egy hosszú kontextusból –, és ugyanolyan gyorsan generált kimeneti tokeneket, mint a hatékonyabb architektúrák, ahogy a kontextus nőtt. Kivételes inference sebessége lassabb és komplexebb tréning árán jött létre.
- A tréning és a fine-tuning két ciklusban történt, az egyik (amit belső ciklusnak nevezünk) a másik (külső ciklus) által volt magában foglalva. A belső ciklus egy kontextus darab tanulását szimulálta inference során, a külső ciklus pedig értékelte, hogy a modell mennyire teljesítene a tanulás után, és ennek megfelelően módosította a súlyokat.
- A belső ciklus egy tréning szekvenciát vett, és egymást követő 1000 tokenből álló darabokra osztotta. Minden darabhoz a modell sliding-window attentiont használt, hogy (i) sorban előrejelezze az egyes tokeneket, (ii) kiszámítsa a tipikus következő token előrejelzési veszteséget, és (iii) felhasználja a veszteséget annak kiszámítására, hogy a súlyok hogyan változzanak a hálózat utolsó negyedének teljesen összekapcsolt rétegeiben. Az eredmény egy súlyfrissítési sorozat volt, ezer tokenenként egy.
- A külső ciklus ezeket a súlyfrissítéseket használta a szimulált súlyfrissítések utáni átlagos következő token előrejelzési veszteség kiszámítására. Visszaterjesztette a szimulált súlyfrissítések sorozatán keresztül, és frissítette az egész modell súlyait. (Ez a folyamat növelte a tréning időt, mert a gradiensek gradienseinek számítását igényelte.)
- Inference során a modell a belső ciklust követte. Feldarabolta a bemeneti kontextust, kiszámította a következő token előrejelzési veszteséget a darabokon, és csak a hálózat utolsó negyedének teljesen összekapcsolt rétegeit frissítette. Aztán új tokeneket generált. (Mivel az inference csak a belső ciklust használta, nem volt szüksége a külső ciklus tréning folyamatában szükséges megnövelt időre, így a feldolgozási idő állandó maradt a kontextus hosszától függetlenül.)
- A TTT-E2E valamivel magasabb teljesítményt mutatott, mint egy vanilla transformer rövid és hosszú kontextusok esetén is, a következő token előrejelzési veszteség alapján. A vanilla transformer átlagos vesztesége 0,015-tel magasabb volt 8000 és 128 000 token közötti kontextus hosszokon. A Mamba 2 és a Gated DeltaNet veszteségei továbbra is 0,03-mal magasabbak voltak. A TTT-E2E a Needle-in-a-Haystack (NIAH) feladatban megegyezett ezekkel a modellekkel rövidebb kontextusok feldolgozásakor, de teljesítménye drámaian visszaesett 8000 token után. Például 128 000 tokennél a TTT-E2E (6 százalék) a Mamba 2 (7 százalék) és a Gated DeltaNet (7 százalék) alá esett, és messze elmaradt a vanilla transformertől (99 százalék).
- A TTT-E2E gyorsabban dolgozta fel a hosszú kontextusokat, mint a vanilla transformer, nagyjából megegyezve a Mamba 2 és a Gated DeltaNet teljesítményével. Egy H100 GPU-n futtatva a TTT-E2E első tokenjének generálási ideje lineárisan, 25 milliszekundummal nőtt 1000 tokenenként, ahogy a kontextus 8000-ről 128 000 tokenre nőtt. A vanilla transformer első tokenjének ideje 12 milliszekundumról 70 milliszekundumra nőtt 1000 tokenenként, 8000-ről 128 000 tokenre növelve a kontextust.
- A TTT-E2E tréning késleltetése, vagyis az az idő, amely a modellfrissítések feldolgozásához és végrehajtásához szükséges volt 1000 tréning tokenenként, meghaladta a Mamba 2 és a Gated DeltaNet késleltetését. A TTT-E2E tréning késleltetése körülbelül 0,25 másodpercről nőtt 8000 tréning token esetén, körülbelül 0,33 másodpercre 128 000 tréning token esetén. Ezzel szemben a Mamba 2 és a Gated DeltaNet nagyjából állandó maradt, körülbelül 0,06 másodpercnél. 8000 tréning token esetén a vanilla transformer (0,08 másodperc) négyszer gyorsabban trénelt. 128 000 tokennél ez az arány megfordult: a vanilla transformer (0,39 másodperc) körülbelül 1,2-szer lassabban trénelt.
Miért fontos?
Miért fontos ez: Az inference során történő tanulás egy olyan megközelítést kínál a hosszú kontextusok feldolgozására, amely egyszerűbb, mint egyedi attention mechanizmusok vagy rekurens architektúrák tervezése. Ez a munka a problémát a tréning és az inference közötti kompromisszumként fogalmazza meg: Az inference során történő feldolgozás olcsóbb és tokenenként konzisztensebb, de a tréning lassabb.