Нейролента - подборка новостей о нейронных сетях, ChatGPT

[Singapore] TinyLlama: An Open-Source Small Language Model. Peiyuan...

[Singapore] TinyLlama: An Open-Source Small Language Model
Peiyuan Zhang, Guangtao Zeng, Tianduo Wang, Wei Lu
Статья:https://arxiv.org/abs/2401.02385
Код: https://github.com/jzhang38/TinyLlama

В полку SLM (Small Language Models) прибыло! TinyLlama — это моделька размера 1.1B, обученная на 3T токенов! Для сравнения намного большую 70B Шиншиллу (https://t.me/gonzo_ML/1216) обучали на меньшем датасете в 1.4T токенов. По рецептам Шиншиллы оптимальное обучение для 1B модели было бы на 20B токенов (https://t.me/gonzo_ML/1223), а тут 3T, почувствуйте разницу! Кажется, это в первый раз для настолько малой модели.

Из других SLM за последнее время были, например, Phi 1 и 1.5 с 1.3B (https://t.me/gonzo_ML/1871), Phi 2 c 2.7B (https://t.me/gonzo_ML/2173) или Gemini Nano с 1.8B и 3.2B (https://t.me/gonzo_ML/2117).

Это интересное направление, потому что в целом все бегут за большими размерами, и ниша малых моделей недоисследована, а с учётом важности инференса они не менее важны. При этом давно уже есть наблюдения, что можно пообучать модель сильно за пределами compute optimal рецептов Шиншиллы, и это продолжает приносить плоды.

Архитектура классическая, декодер трансформера по рецепту Llama 2 с её же токенизатором. Данные собрали из SlimPajama (почищенный вариант RedPajama) и Starcoderdata, суммарно 950B токенов, так что обучали примерно 3 эпохи. Сэмплили датасеты в пропорции 7:3.

При этом задействовали разные продвинутые штуки и взяли RoPE энкодинги, RMSNorm pre-norm, SwiGLU, grouped-query attention.

Для скейлинга и ускорения задействовали Fully Sharded Data Parallel (FSDP) из Пайторча, свежий Flash Attention 2, заменили fused SwiGLU из xFormers на оригинальный и сэкономили памяти (это, кстати, для меня удивительно, мои первые ожидания, что fused реализация должна быть лучше) -- это позволило уместить модель в 40Gb памяти.

В итоге на A100-40G получили training throughput в 24,000 токенов в секунду. Для обучения на 300B токенов TinyLlama-1.1B требуется 3,456 A100 GPU-часов, в то время как у Pythia эта цифра равна 4,830 и у MPT’s вообще 7,920 часов.

Использовали для обучения Lit-GPT (https://github.com/Lightning-AI/lit-gpt, базируется на nanoGPT). AdamW, cosine learning rate, warmup, gradient clipping.

Обучалось 90 дней на 16 A100-40G GPU. По ценам AWS на p4d (https://aws.amazon.com/ec2/instance-types/p4/) это было бы примерно $140k между прочим.

Результат хорошо бьёт бейзлайны в лице OPT-1.3B, Pythia-1.0B и Pythia-1.4B. На MMLU правда хуже. С увеличением вычислительного бюджета перформанс продолжает расти, не понял только почему он более шумным становится.

Кажется, работа -- верх открытости. Весь код обучения, промежуточные чекпойнты, все детали обучения доступны.

Респект!