Нейролента - подборка новостей о нейронных сетях, ChatGPT

Сделать такое эффективно -- это челлендж, авторы реализовали...

Сделать такое эффективно -- это челлендж, авторы реализовали алгоритм parallel scan с умным использованием иерархии памяти GPU, что-то происходит в быстрой SRAM, что-то в более медленной HBM. В сочетании с kernel fusion и recomputation получается весьма эффективная реализация с требованиями к памяти как у оптимизированной реализации трансформера с FlashAttention (соавтор текущей работы Tri Dao является и соавтором FlashAttention).

Модели selective SSM в работе иногда называют S6 моделями, потому что S4 + selection mechanism + computed with a scan.

Итоговая архитектура представляет собой микс SSM (здесь, H3, https://arxiv.org/abs/2212.14052) и MLP блоков из трансформера в одном новом блоке, который дальше можно гомогенно стыковать. Внутри блока model dimension D сначала увеличивается на фактор E=2 и из-за этого большую часть параметров блока составляют линейные проекции на входе и выходе, а не сама SSM. Полученный блок в чередовании со стандартной нормализацией (кажется, это RMSNorm или LayerNorm) и residual connection даёт архитектуру под названием Mamba. Там же активации SiLU / Swish и опциональный LayerNorm в той же позиции, что и у RetNet (https://t.me/gonzo_ML/1753).

Модель по дефолту использует действительные числа (многие предыдущие SSM использовали комплексные), и это хорошо работает везде кроме одной задачи. Авторы предполагают, что комплексные числа могут быть полезными в непрерывных модальностях типа аудио/видео, но не в дискретных типа текста или ДНК. Инициализация взята из S4D-Lin/S4D-Real.

Проверяли много на чём.

Сначала синтетические задачи. Selective Copying работает отлично, очень близко к 100%. На задачках с Induction Heads тоже всё супер.

Проверили на языковом моделировании с обучением на Pile и по рецептам из статьи про GPT-3. Сравниваются со стандартной архитектурой (здесь GPT-3), а также с продвинутыми трансформерами (обозначены как Transformer++), основанными на архитектурах PaLM и LLaMa. Тестировали на размерах от 125M до 1.3B параметров. В итоге Mamba -- первая модель без внимания, достигшая качества сильных трансформерных рецептов.

На разных downstream zero-shot задачах качество выше, чем у сопоставимых по размеру Pythia, GPT-Neo, OPT, RWKV (https://t.me/gonzo_ML/1647). А иногда выше, чем и у в два раза более тяжёлых.

На задачах моделирования последовательности ДНК кривые скейлинга тоже отличные, качество на downstream задачах зачётное.

На аудио сравнились с SaShiMi (https://arxiv.org/abs/2202.09729), вроде как на авторегрессионном обучении там она была SoTA. Побили. На генерации речи (датасет SC09) бьёт и её же, и WaveNet с WaveGAN.

По производительности SSM scan текущая имплементация очень хороша, лучше лучшей трансформерной имплементации (FlashAttention-2) и в 20-40 раз лучше пайторчового скана. На инференсе throughput выше сопоставимого трансформера в 4-5 раз (потому что за ненадобностью KV кеша можно делать большие батчи). Так у Mamba-6.9B throughput на инференсе выше, чем у Transformer-1.3B.

Много интересных абляций. И блоки S6, и архитектура Mamba рулят. S6 явно лучше S4, а мамба сравнима с H3 и проще её.

Бомбическая архитектура в общем. Ждём натренированное что-то очень большое. Кстати, на днях также появилась нетрансформерная StripedHyena-7B (https://www.together.ai/blog/stripedhyena-7b) тоже из когорты SSM. Про гиену мы пока так и не написали, но может быть доберёмся таки (как и про бегемотов). На бенчмарках выглядит как сравнимая с Mistral 7B, что круто. Мамба наверное ещё круче должна быть, обычную гиену она бьёт (тут, правда, необычная).

Вангую, 2024-й должен быть годом SSM-LLM.