Нейролента - подборка новостей о нейронных сетях, ChatGPT

Авторы предлагают mesa-layer, как вариант self-attention, который полностью...

Авторы предлагают mesa-layer, как вариант self-attention, который полностью решает оптимизационную задачу слоя (в смысле явно минимизирует L2 между предсказанием и таргетом, с регуляризацией) вместо лишь выполнения одного градиентного шага. В реализации этого варианта внимания есть дополнительная матрица R, которую если убрать, то получится стандартный линейный SA. Меза-слой вычислительно более тяжёлый, но главная проблема, что он не параллелится как и классические RNN.

Возвращаясь к экспериментам, берут линейную динамическую систему с шумом вида:
s_{t+1} = W∗ s_t + ϵ_t, где W* -- случайная ортогональная матрица. Для генерации каждой последовательности берут новую матрицу. Трансформер обучают на минимизацию авторегрессионного лосса и реверс-инжинирят.

Исследуют per-timestep loss L_t(s_{1:t}, θ) и его эволюцию в зависимости от длины контекста, то есть как улучшается качество предсказания при увеличении контекста. Это соответствует операционному определению in-context learning из классической статьи про скейлинг (https://arxiv.org/abs/2001.08361).

Гипотеза в том, что базовая оптимизация (собственно обучение трансформера) ведёт к появлению меза-оптимизации, и будущие значения последовательности предсказываются внутри forward pass. Процедура выглядит так:
1) линейная модель представляется меза-параметрами W
2) конструируется mesa-objective с использованием данных внутри контекста
3) W находится через минимизацию mesa-objective
4) полученная W используется для предсказания.

Репрезентация токена сделана хитрой трёхканальной, первый канал используется для предсказания будущего входа, а другие два содержат текущий и предыдущий входные элементы. Это ведёт к очень разреженным матрицам весов, которые легко реверсить.

Начинают с однослойного линейного трансформера и идентифицируют алгоритм, используемый для предсказания. Проверяют, что слой реализует шаг меза-градиентного спуска, 1) сравнивая с линейной авторегрессионной моделью, обученной одним шагом градиентного спуска, 2) изучая интерполированную модель, полученную усреднением выученных и сконструированных весов. Всё очень хорошо совпадает.

А если вместо линейного SA вставить mesa-layer, то качество на порядок лучше. То есть inductive bias для меза-оптимизации очень помогает.

Затем берут многослойный трансформер, линейный и с софтмаксом, но без FFN. Там тоже реверсят алгоритм, он описывается 16 параметрами (вместо 3200) на голову внимания. Но интерпретировать это как алгоритм меза-оптимизации сложно и авторы делают linear regression probing analysis. Например, ищут stacked multi-layer gradient descent construction, в ней выходы промежуточных слоёв должны постепенно приближаться к цели. Также ищут следы iterative preconditioning algorithm. Пробинг подтверждает гипотезы.

В конце обучают уже полноценные трансформеры без архитектурных упрощений, с позиционными энкодингами и без хитрых многоканальных представлений токена. Здесь гипотеза, что сначала модель в первом слое восстанавливает специальное представление токена, удобное для меза-обучения, а последующие слои его реализуют. Действительно, после первого слоя токен в основном зависит только от себя и предыдущего. Эту процедуру авторы назвали “creating a mesa-dataset". Дальше поведение очень похоже на наличие двухступенчатой процедуры с precondition + optimization.

Далее проверяют few-shot learning. Здесь трансформер обучают на ту же задачу авторегрессионного предсказания, что и раньше, но теперь после обучения модель просят через few-shot learning решить другую задачу -- регрессию. Выученный трансформером алгоритм меза-оптимизации справляется. Промпт-тюнинг и файнтюнинг одного EOS токена ещё всё улучшают. Есть ещё прикольный эксперимент с двумя задачами в промпте, когда через какое-то время новая задача сменяет текущую. Трансформер справляется с тем, чтобы по мере инференса переписать старую задачу и выучить новую.

Получается, трансформеры, обученные на задачу предсказания следующего элемента, можно перепрофилировать на новую задачу через in-context learning, поскольку алгоритм внутри forward pass остаётся похожим.