Что такое Google JAX? Все, что Вам нужно знать

Опубликовано: 2022-08-05

Google JAX или Just After Execution это платформа, разработанная Google для ускорения задач машинного обучения.

Вы можете считать это библиотекой для Python, которая помогает ускорить выполнение задач, научные вычисления, преобразования функций, глубокое обучение, нейронные сети и многое другое.

О Google JAX

Самый фундаментальный пакет вычислений в Python — это пакет NumPy, в котором есть все функции, такие как агрегирование, векторные операции, линейная алгебра, n-мерные массивы и манипуляции с матрицами, а также многие другие расширенные функции.

Что, если бы мы могли еще больше ускорить вычисления, выполняемые с помощью NumPy, особенно для огромных наборов данных?

Есть ли у нас что-то, что могло бы одинаково хорошо работать на разных типах процессоров, таких как GPU или TPU, без каких-либо изменений кода?

Как насчет того, чтобы система могла выполнять преобразования составных функций автоматически и более эффективно?

Google JAX — это библиотека (или фреймворк, как говорит Википедия), которая делает именно это и, возможно, многое другое. Он был создан для оптимизации производительности и эффективного выполнения задач машинного обучения (ML) и глубокого обучения. Google JAX предоставляет следующие функции преобразования, которые делают его уникальным по сравнению с другими библиотеками ML и помогают в расширенных научных вычислениях для глубокого обучения и нейронных сетей:

  • Автоматическая дифференциация
  • Автоматическая векторизация
  • Автоматическое распараллеливание
  • Компиляция точно в срок (JIT)
Уникальные возможности Google JAX

Все преобразования используют XLA (ускоренную линейную алгебру) для повышения производительности и оптимизации памяти. XLA — это механизм компилятора, оптимизирующий предметную область, который выполняет линейную алгебру и ускоряет модели TensorFlow. Использование XLA поверх вашего кода Python не требует значительных изменений кода!

Рассмотрим подробно каждую из этих функций.

Возможности Google JAX.

Google JAX поставляется с важными составляющими функциями преобразования для повышения производительности и более эффективного выполнения задач глубокого обучения. Например, автодифференцирование для получения градиента функции и нахождения производных любого порядка. Точно так же автоматическое распараллеливание и JIT для параллельного выполнения нескольких задач. Эти преобразования являются ключевыми для таких приложений, как робототехника, игры и даже исследования.

Составная функция преобразования — это чистая функция, которая преобразует набор данных в другую форму. Они называются компонуемыми, поскольку являются автономными (т. е. эти функции не зависят от остальной части программы) и не имеют состояния (т. е. один и тот же ввод всегда приводит к одному и тому же результату).

Y(x) = T: (f(x))

В приведенном выше уравнении f(x) — исходная функция, к которой применено преобразование. Y(x) — результирующая функция после применения преобразования.

Например, если у вас есть функция с именем 'total_bill_amt', и вы хотите получить результат в виде функционального преобразования, вы можете просто использовать желаемое преобразование, скажем, градиент (град):

grad_total_bill = grad(total_bill_amt)

Преобразуя числовые функции с помощью таких функций, как grad(), мы можем легко получить их производные более высокого порядка, которые мы можем широко использовать в алгоритмах оптимизации глубокого обучения, таких как градиентный спуск, тем самым делая алгоритмы быстрее и эффективнее. Точно так же, используя jit(), мы можем компилировать программы Python точно в срок (лениво).

№1. Автоматическая дифференциация

Python использует функцию автоградации, чтобы автоматически различать код NumPy и собственный код Python. JAX использует модифицированную версию autograd (т. е. grad) и объединяет XLA (ускоренную линейную алгебру) для выполнения автоматического дифференцирования и поиска производных любого порядка для GPU (графических процессоров) и TPU (тензорных процессоров).]

Краткое примечание о TPU, GPU и CPU: CPU или центральный процессор управляет всеми операциями на компьютере. GPU — это дополнительный процессор, который увеличивает вычислительную мощность и выполняет высокопроизводительные операции. TPU — это мощное устройство, специально разработанное для сложных и тяжелых рабочих нагрузок, таких как алгоритмы искусственного интеллекта и глубокого обучения.

В том же духе, что и функция autograd, которая может различать циклы, рекурсии, переходы и т. д., JAX использует функцию grad() для градиентов обратного режима (обратное распространение). Кроме того, мы можем дифференцировать функцию в любом порядке, используя grad:

град(град(град(sin θ))) (1.0)

Автодифференциация высшего порядка

Как мы упоминали ранее, grad весьма полезен при нахождении частных производных функции. Мы можем использовать частную производную для вычисления градиентного спуска функции стоимости по отношению к параметрам нейронной сети в глубоком обучении, чтобы минимизировать потери.

Вычисление частной производной

Предположим, что функция имеет несколько переменных x, y и z. Нахождение производной одной переменной при неизменности остальных переменных называется частной производной. Предположим, у нас есть функция,

f(x,y,z) = x + 2y + z 2

Пример, показывающий частную производную

Частная производная x будет равна ∂f/∂x, что говорит нам о том, как функция изменяется для переменной, когда другие постоянны. Если мы выполним это вручную, мы должны написать программу для дифференцирования, применить ее для каждой переменной, а затем вычислить градиентный спуск. Это стало бы сложным и трудоемким делом для нескольких переменных.

Автодифференциация разбивает функцию на набор элементарных операций, таких как +, -, *, / или sin, cos, tan, exp и т. д., а затем применяет цепное правило для вычисления производной. Мы можем сделать это как в прямом, так и в обратном режиме.

Это не так! Все эти вычисления происходят так быстро (ну, подумайте о миллионе вычислений, подобных приведенным выше, и времени, которое это может занять!). XLA заботится о скорости и производительности.

№ 2. Ускоренная линейная алгебра

Возьмем предыдущее уравнение. Без XLA для вычисления потребуется три (или более) ядра, где каждое ядро ​​будет выполнять меньшую задачу. Например,

Ядро k1 -> x * 2y (умножение)

k2 -> x * 2y + z (сложение)

k3 -> Редукция

Если ту же задачу выполняет XLA, одно ядро ​​берет на себя все промежуточные операции, объединяя их. Промежуточные результаты элементарных операций передаются в потоковом режиме, а не сохраняются в памяти, что позволяет экономить память и повышать скорость.

№3. Своевременная компиляция

JAX внутри использует компилятор XLA для повышения скорости выполнения. XLA может повысить скорость процессора, графического процессора и TPU. Все это возможно с помощью JIT-исполнения кода. Чтобы использовать это, мы можем использовать jit через импорт:

 from jax import jit def my_function(x): …………some lines of code my_function_jit = jit(my_function)

Другой способ — украсить jit над определением функции:

 @jit def my_function(x): …………some lines of code

Этот код намного быстрее, потому что преобразование вернет скомпилированную версию кода вызывающей стороне, а не использует интерпретатор Python. Это особенно полезно для векторных входных данных, таких как массивы и матрицы.

То же самое верно и для всех существующих функций Python. Например, функции из пакета NumPy. В этом случае мы должны импортировать jax.numpy как jnp, а не как NumPy:

 import jax import jax.numpy as jnp x = jnp.array([[1,2,3,4], [5,6,7,8]])

Как только вы это сделаете, основной объект массива JAX с именем DeviceArray заменит стандартный массив NumPy. DeviceArray ленив — значения хранятся в ускорителе до тех пор, пока они не потребуются. Это также означает, что программа JAX не ждет возврата результатов вызывающей программе (Python), следуя асинхронной отправке.

№ 4. Автоматическая векторизация (vmap)

В типичном мире машинного обучения у нас есть наборы данных с миллионом или более точек данных. Скорее всего, мы бы выполнили некоторые вычисления или манипуляции с каждой или большинством этих точек данных, что требует очень много времени и памяти! Например, если вы хотите найти квадрат каждой из точек данных в наборе данных, первое, о чем вы подумали бы, это создать цикл и взять квадрат один за другим — ааа!

Если мы создадим эти точки в виде векторов, мы сможем создать все квадраты за один раз, выполняя векторные или матричные манипуляции с точками данных с помощью нашего любимого NumPy. И если бы ваша программа могла делать это автоматически — можете ли вы требовать чего-то большего? Это именно то, что делает JAX! Он может автоматически векторизовать все ваши точки данных, чтобы вы могли легко выполнять с ними любые операции, что делает ваши алгоритмы намного быстрее и эффективнее.

JAX использует функцию vmap для автоматической векторизации. Рассмотрим следующий массив:

 x = jnp.array([1,2,3,4,5,6,7,8,9,10]) y = jnp.square(x)

Выполняя все вышеперечисленное, метод Square будет выполняться для каждой точки в массиве. Но если сделать следующее:

 vmap(jnp.square(x))

Метод Square будет выполняться только один раз, потому что точки данных теперь автоматически векторизируются с использованием метода vmap перед выполнением функции, а циклы переходят на элементарный уровень операций, что приводит к матричному умножению, а не скалярному умножению, что обеспечивает лучшую производительность. .

№ 5. Программирование СПМД (pmap)

SPMD или одиночная программа Программирование нескольких данных очень важно в контексте глубокого обучения — вы часто применяете одни и те же функции к разным наборам данных, хранящихся на нескольких графических процессорах или TPU. В JAX есть функция pump, которая позволяет выполнять параллельное программирование на нескольких графических процессорах или любом ускорителе. Как и JIT, программы, использующие pmap, будут компилироваться XLA и выполняться одновременно во всех системах. Это автоматическое распараллеливание работает как для прямых, так и для обратных вычислений.

Как работает пмап

Мы также можем применить несколько преобразований за один раз в любом порядке к любой функции, например:

pmap (vmap (jit (град (f (x)))))

Несколько составных преобразований

Ограничения Google JAX

Разработчики Google JAX хорошо подумали об ускорении алгоритмов глубокого обучения, внедрив все эти потрясающие преобразования. Функции и пакеты научных вычислений аналогичны NumPy, поэтому вам не нужно беспокоиться о кривой обучения. Однако JAX имеет следующие ограничения:

  • Google JAX все еще находится на ранних стадиях разработки, и хотя его основная цель — оптимизация производительности, он не дает особых преимуществ для вычислений на ЦП. NumPy работает лучше, а использование JAX может только увеличить накладные расходы.
  • JAX все еще находится в стадии исследования или на ранних стадиях и нуждается в более тонкой настройке, чтобы достичь стандартов инфраструктуры таких фреймворков, как TensorFlow, которые более устоялись и имеют больше предопределенных моделей, проектов с открытым исходным кодом и учебных материалов.
  • На данный момент JAX не поддерживает операционную систему Windows — для его работы вам понадобится виртуальная машина.
  • JAX работает только с чистыми функциями, не имеющими побочных эффектов. Для функций с побочными эффектами JAX может оказаться не лучшим вариантом.

Как установить JAX в вашей среде Python

Если в вашей системе установлен Python и вы хотите запустить JAX на своем локальном компьютере (ЦП), используйте следующие команды:

 pip install --upgrade pip pip install --upgrade "jax[cpu]"

Если вы хотите запустить Google JAX на GPU или TPU, следуйте инструкциям на странице GitHub JAX. Чтобы настроить Python, посетите официальную страницу загрузки Python.

Вывод

Google JAX отлично подходит для написания эффективных алгоритмов глубокого обучения, робототехники и исследований. Несмотря на ограничения, он широко используется с другими фреймворками, такими как Haiku, Flax и многими другими. Вы сможете оценить, что делает JAX при запуске программ, и увидеть разницу во времени выполнения кода с JAX и без него. Вы можете начать с прочтения официальной документации Google JAX, которая достаточно обширна.