Модели отскейлили от 100M до 7B параметров, Griffin...
Модели отскейлили от 100M до 7B параметров, Griffin до 14B. Количество токенов в обучении скейлили по рецептам Шиншиллы (https://t.me/gonzo_ML/1216), для оценки на разных задачах модели обучали на 300B токенов. Все модели демонстрируют красивую степенную зависимость между лоссом и training FLOPs. Лоссы грифона стабильно чуть ниже трансформерного бейзлайна при том же бюджете. У ястреба повыше, но с тенденцией к уменьшению по мере роста бюджета.
Внешними бейзлайнами выступили Mamba-3B и Llama-2 (7B, 13B). Они обучены на больших (600B/2T) и отличающихся датасетах. Hawk и Griffin весьма хороши, бьют Мамбу, хоть и обучались на меньших датасетах.
Для обучения больших моделей на наборе устройств реализовали model parallel training через шардинг слоёв. Отдельный челлендж -- эффективная реализация рекуррентностей на устройствах, так как в отличие от классических архитектур они работают в режиме низкого FLOPs-to-byte ratio, и вычисления оказываются memory bound. Кастомные кернелы написали на Pallas (https://jax.readthedocs.io/en/latest/pallas/index.html), специальном расширении JAX. Как это выглядит, можно посмотреть в репе RecurrentGemma (https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/jax/pallas.py). Использовали linear scan, получилось в три раза быстрее родной реализации. Через associative scan (использовался в S5, https://arxiv.org/abs/2208.04933) получается медленнее, а через свёртки это не получается, механизм гейтинга RG-LRU не совместим со свёрточным представлением.
С ростом длины последовательности обучение Грифона идет быстрее обучения трансформера. Особенно эта разница заметна, когда длина последовательности заметно больше размерности модели и вычисление внимания занимает значимую долю всего времени.
По latency на инференсе Hawk и Griffin быстрее MQA трансформера (который в свою очередь быстрее классического MHA). Заметная разница проявляется на больших длинах, в основном после 2048 токенов. Throughput у новых моделей тоже лучше (особенно у Hawk), частично от лучшего latency, частично от меньшего размера кешей и возможности запихнуть больший батч на тот же девайс. Griffin поэтому же медленнее Hawk, его кеш локального внимания растёт с ростом батча.
На предсказании следующего токена в длинной последовательности новые модели лучше трансформеров и экстраполируют на сильно более длинные последовательности (по крайней мере 4x), чем были в обучении. Из интересных наблюдений, модели, обученные на меньшей длине (2k против 8k), перформят на малых длинах лучше. Поэтому важно выбирать длину последовательности при обучении под будущие задачи.
Одна свежая работа “Repeat After Me: Transformers are Better than State Space Models at Copying” (https://arxiv.org/abs/2402.01032) показала, что трансформеры лучше работают на задачах типа копирования или retrieval’а, чем SSM. Проверили новые модели на задачах Selective Copying и Induction Heads (как в работе про Мамбу, https://t.me/gonzo_ML/2149). Все три модели могут идеально решить задачу копирования (но Hawk обучается медленнее). На induction jeads все три решают задачу до определённого предела длины, дальше трансформер фейлится, не может экстраполировать. На этих задачах и у Мамбы всё было хорошо (https://t.me/gonzo_ML/2154).
В упомянутой работе про “Repeat After Me” была предложена задача retrieval с синтетической телефонной книгой, где по имени надо выбрать номер телефона. В промпте содержится “книга”, затем два примера и имя для которого надо извлечь телефон. На этой задаче Hawk быстро скатывается в ноль с ростом длины книги, это похоже на поведение Мамбы. Что в общем неудивительно, размер состояния у него маленький. Трансформер держится до длин знакомых по обучению и после скатывается в ноль. Griffin идеально держится до длины контекста локального внимания, затем начинает деградировать, но зато экстраполирует дальше трансформера.
Интересное развитие!
Внешними бейзлайнами выступили Mamba-3B и Llama-2 (7B, 13B). Они обучены на больших (600B/2T) и отличающихся датасетах. Hawk и Griffin весьма хороши, бьют Мамбу, хоть и обучались на меньших датасетах.
Для обучения больших моделей на наборе устройств реализовали model parallel training через шардинг слоёв. Отдельный челлендж -- эффективная реализация рекуррентностей на устройствах, так как в отличие от классических архитектур они работают в режиме низкого FLOPs-to-byte ratio, и вычисления оказываются memory bound. Кастомные кернелы написали на Pallas (https://jax.readthedocs.io/en/latest/pallas/index.html), специальном расширении JAX. Как это выглядит, можно посмотреть в репе RecurrentGemma (https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/jax/pallas.py). Использовали linear scan, получилось в три раза быстрее родной реализации. Через associative scan (использовался в S5, https://arxiv.org/abs/2208.04933) получается медленнее, а через свёртки это не получается, механизм гейтинга RG-LRU не совместим со свёрточным представлением.
С ростом длины последовательности обучение Грифона идет быстрее обучения трансформера. Особенно эта разница заметна, когда длина последовательности заметно больше размерности модели и вычисление внимания занимает значимую долю всего времени.
По latency на инференсе Hawk и Griffin быстрее MQA трансформера (который в свою очередь быстрее классического MHA). Заметная разница проявляется на больших длинах, в основном после 2048 токенов. Throughput у новых моделей тоже лучше (особенно у Hawk), частично от лучшего latency, частично от меньшего размера кешей и возможности запихнуть больший батч на тот же девайс. Griffin поэтому же медленнее Hawk, его кеш локального внимания растёт с ростом батча.
На предсказании следующего токена в длинной последовательности новые модели лучше трансформеров и экстраполируют на сильно более длинные последовательности (по крайней мере 4x), чем были в обучении. Из интересных наблюдений, модели, обученные на меньшей длине (2k против 8k), перформят на малых длинах лучше. Поэтому важно выбирать длину последовательности при обучении под будущие задачи.
Одна свежая работа “Repeat After Me: Transformers are Better than State Space Models at Copying” (https://arxiv.org/abs/2402.01032) показала, что трансформеры лучше работают на задачах типа копирования или retrieval’а, чем SSM. Проверили новые модели на задачах Selective Copying и Induction Heads (как в работе про Мамбу, https://t.me/gonzo_ML/2149). Все три модели могут идеально решить задачу копирования (но Hawk обучается медленнее). На induction jeads все три решают задачу до определённого предела длины, дальше трансформер фейлится, не может экстраполировать. На этих задачах и у Мамбы всё было хорошо (https://t.me/gonzo_ML/2154).
В упомянутой работе про “Repeat After Me” была предложена задача retrieval с синтетической телефонной книгой, где по имени надо выбрать номер телефона. В промпте содержится “книга”, затем два примера и имя для которого надо извлечь телефон. На этой задаче Hawk быстро скатывается в ноль с ростом длины книги, это похоже на поведение Мамбы. Что в общем неудивительно, размер состояния у него маленький. Трансформер держится до длин знакомых по обучению и после скатывается в ноль. Griffin идеально держится до длины контекста локального внимания, затем начинает деградировать, но зато экстраполирует дальше трансформера.
Интересное развитие!
Источник: gonzo-обзоры ML статей
2024-04-15 08:36:35