Нейролента - подборка новостей о нейронных сетях, ChatGPT

Learning to Model the World with Language. Jessy...

Learning to Model the World with Language
Jessy Lin, Yuqing Du, Olivia Watkins, Danijar Hafner, Pieter Abbeel, Dan Klein, Anca Dragan
Статья: https://arxiv.org/abs/2308.01399
Сайт: https://dynalang.github.io/

Интересная работа из серии про World Models. Мы по этой теме практически ничего не успели написать (https://t.me/gonzo_ML/186), но она интересная, развивается уже не первый год, и относительно свежий толчок несколько лет назад дал ей наш любимый Шмидхубер (и не менее любимый Дэвид Ха, https://arxiv.org/abs/1803.10122). Идея там была в том, что агент может выучить модель мира и дальше оттачивать свои навыки в ней, то есть в симуляции. Получалось неплохо (https://worldmodels.github.io/).

С тех пор много всего появилось, всё не перечислишь, одна из популярных моделей была Dreamer (https://arxiv.org/abs/1912.01603), которая дошла до 3-й версии DreamerV3 (https://arxiv.org/abs/2301.04104). Один из соавторов текущей работы, Danijar Hafner, как раз автор Дримера. И на самом деле текущая модель это расширение DreamerV3 на работу с языком на входе и опционально на выходе.

Новая работа представляет агента Dynalang, который выучивает мультимодальную модель мира и добавляет в микс язык. Язык использовали и раньше, по крайней мере на входе, чтобы предсказывать действия агента (например, когда агент получал текстовую команду что-то сделать). Но маппинг языка в действия, особенно если единственным обучающим сигналом является награда, это довольно слабый сигнал чтобы выучить богатые текстовые репрезентации мира и понимать не только прямые инструкции, но и фразы, относящиеся к состоянию этого мира. Гипотеза авторов в том, что предсказание будущих репрезентаций даёт богатый сигнал, чтобы понять язык и как он соотносится с миром вокруг. Язык теперь также используется и чтобы предсказывать будущие языковые и видео наблюдения, а также награды.

Dynalang разъединяет (в смысле decouple) обучение моделированию мира с помощью языка (supervised learning with prediction objectives) и обучение действиям в этом мире c использованием модели (reinforcement learning with task rewards).

Задача модели мира (world model, или далее просто WM) -- сжать входной текстовый и зрительный сигналы в латентное представление и научиться предсказывать будущие латентные представления по набранным наблюдениям взаимодействия агента в среде. Это латентное представление от WM поступает на вход полиси, которая предсказывает действия и максимизирует награду.

Благодаря этому разделению, Dynalang можно предобучать на одиночных модальностях типа текста или видео без всяких действий и наград.

Во фреймворк можно также добавить генерацию текста, когда восприятие агента даёт сигнал его языковой модели и он получает возможность “говорить в среду”.

Более формально, в интерактивных задачах агент выбирает действие a_t в среде. В большинстве экспериментов это одно из дискретных действий, то есть просто целое число. Но опционально может быть ещё и языковой токен. Из среды в ответ поступает награда r_t, флажок продолжения эпизода c_t, и наблюдение o_t, состоящее из пары: картинка x_t и языковой токен l_t. То есть получается что на входе и выходе появляется лишь по одному токену на кадр, и в работе показали, что token-level представления работают лучше чем sentence-level. Задача как обычно максимизировать ожидаемую дисконтируемую сумму наград.

WM -- это Recurrent State Space Model (RSSM, https://arxiv.org/abs/1811.04551) на базе GRU со скрытым рекуррентным состоянием h_t.

В каждый момент времени (x_t, l_t, h_t) кодируется энкодером (VAE) в латентное состояние z_t:

z_t ∼ enc(x_t, l_t, h_t)

Sequence model (GRU) выдаёт (z’_t, h_t) по предыдущим (z, h, a) от момента t-1:

z’_t, h_t = seq(z_{t−1}, h_{t−1}, a_{t−1})

Наконец декодер по (z_t, h_t) восстанавливает (x_t, l_t, r_t, c_t):

x’_t, l’_t, r’_t, c’_t = dec(z_t, h_t)

При этом для картиночных входов и выходов используется CNN, а для всех остальных MLP.

WM обучается на сумме representation learning loss (L_repr) и future prediction loss (L_pred).