Нейролента - подборка новостей о нейронных сетях, ChatGPT

γ = 1 − 2^{−5−arange(0,h)} ∈ R^h. head_i...

γ = 1 − 2^{−5−arange(0,h)} ∈ R^h
head_i = Retention(X, γ_i)
Y = GroupNorm_h (Concat(head_1, · · · , head_h))
MSR(X) = (swish(XW_G) ⊙ Y )W_O

где W_G, W_O -- снова обучаемые матрицы.

Также внутри много всяких нормализаций. В дополнение к GroupNorm есть нормализация QK на sqrt(d), нормализация D и QK^⊺⊙D.

Резюмируя, ну это точно трансформер. Выглядит как очередная вариация на тему линейного трансформера, в которых я уже сам запутался. Ну то есть оно конечно отличается от многого в этом зоопарке -- разреженных вниманий нет, аппроксимации софтмакса нет, так что наверное больше вариация рекуррентного трансформера, которых тоже в достатке. Теперь подобрался набор компонентов, которые и быстрое обучение дают, и быстрый инференс.

Если внимательно посмотреть на разницу с другими моделями, то во-первых таки относительно обычного трансформера, как мы упоминали, есть архитектурная разница с софтмаксом, позиционными энкодингами, кучей нормализаций и появилась рекуррентная формулировка.

На практике на языковых задачах RetNet получше дефолтного трансформера везде, и в перплексии (но только начиная с 2B), и в куче задач типа BoolQ, Winograd, StoryCloze и т.д. При этом сравнивать с дефолтным трансформером при наличии такого безумного количества улучшений тоже странно. Ну лучше по перплексии, но не то чтобы намного, а тот же Lex Transformer был заметно лучше обычного по перплексии. А по всяким BoolQ, PIQA и т.п. ну первая Llama сопоставимого размера (7B vs. 6.7B) была лучше (но конечно это нечестно сравнивать, она дольше обучалась). Непонятно, не выглядит суперулучшением качества. Но точно и не ухудшение.

Более важная история про производительность и здесь RetNet однозначно лучше стандартного трансформера, но при этом не сильно лучше чем FlashAttention. А теперь есть FlashAttention-2 (https://arxiv.org/abs/2307.08691), который намного круче первого. Но его элементы можно, наверное, и в RetNet добавить.

По памяти RetNet хорош, KV кешей нет, с ростом длины последовательности память не растёт, вообще дополнительной памяти почти не потребляет (97% памяти занимают просто веса сети). Throughput с ростом длины тоже не падает, latency тоже хорошая и не растёт ни от длины, ни от батча.

Из интересной экзотики, кстати, обучали на 512 AMD MI200 GPUs. Ну наконец то!

Из продвинутых моделей сравнивают с одним из старых линейных трансформеров (https://arxiv.org/abs/2006.16236), RWKV (https://t.me/gonzo_ML/1647), Hungry Hungry Hippos или H3 (https://arxiv.org/abs/2212.14052, это свежая SSM типа S4, https://t.me/gonzo_ML/1424) и Hyena Hierarchy (свежая свёрточная модель, https://arxiv.org/abs/2302.10866). Перплексия получается лучше. Скорость обучения не репортят, хотя вроде как у RWKV сложность ниже. И непонятно почему в таблице со сравнением для RWKV поставили отсутствие параллелизации, это странно.

Резюмируя, выглядит интересно, как альтернатива дефолтным трансформерам пробовать стоит, но в такие моменты я всегда вспоминаю истории оптимизированных трансформеров, из которых не то чтобы какой-то конкретный всех вытеснил.

Очень жду обучения реально большой модели на RetNet. В коде заготовлен retnet_65b, сделать на нём аналог Шиншиллы или Llama 2 было бы интересно.