Apa itu Google JAX? Semua yang Perlu Anda Ketahui

Diterbitkan: 2022-08-05

Google JAX atau Just A after E x ecution adalah kerangka kerja yang dikembangkan oleh Google untuk mempercepat tugas pembelajaran mesin.

Anda dapat menganggapnya sebagai perpustakaan untuk Python, yang membantu dalam eksekusi tugas yang lebih cepat, komputasi ilmiah, transformasi fungsi, pembelajaran mendalam, jaringan saraf, dan banyak lagi.

Tentang Google JAX

Paket komputasi paling mendasar dalam Python adalah paket NumPy yang memiliki semua fungsi seperti agregasi, operasi vektor, aljabar linier, array n-dimensi dan manipulasi matriks, dan banyak fungsi lanjutan lainnya.

Bagaimana jika kita dapat lebih mempercepat kalkulasi yang dilakukan menggunakan NumPy – khususnya untuk kumpulan data yang sangat besar?

Apakah kami memiliki sesuatu yang dapat bekerja dengan baik pada berbagai jenis prosesor seperti GPU atau TPU, tanpa perubahan kode apa pun?

Bagaimana jika sistem dapat melakukan transformasi fungsi yang dapat dikomposisi secara otomatis dan lebih efisien?

Google JAX adalah perpustakaan (atau kerangka kerja, seperti yang dikatakan Wikipedia) yang melakukan hal itu dan mungkin lebih banyak lagi. Itu dibuat untuk mengoptimalkan kinerja dan secara efisien melakukan pembelajaran mesin (ML) dan tugas pembelajaran mendalam. Google JAX menyediakan fitur transformasi berikut yang membuatnya unik dari pustaka ML lainnya dan membantu dalam komputasi ilmiah tingkat lanjut untuk pembelajaran mendalam dan jaringan saraf:

  • Diferensiasi otomatis
  • Vektorisasi otomatis
  • Paralelisasi otomatis
  • Kompilasi just-in-time (JIT)
Fitur unik Google JAX

Semua transformasi menggunakan XLA (Accelerated Linear Algebra) untuk kinerja yang lebih tinggi dan optimalisasi memori. XLA adalah mesin kompiler pengoptimalan khusus domain yang menjalankan aljabar linier dan mempercepat model TensorFlow. Menggunakan XLA di atas kode Python Anda tidak memerlukan perubahan kode yang signifikan!

Mari kita telusuri secara detail masing-masing fitur tersebut.

Fitur Google JAX

Google JAX hadir dengan fungsi transformasi penting yang dapat dikomposisi untuk meningkatkan kinerja dan melakukan tugas pembelajaran mendalam dengan lebih efisien. Misalnya, diferensiasi otomatis untuk mendapatkan gradien suatu fungsi dan menemukan turunan dari urutan apa pun. Demikian pula, paralelisasi otomatis dan JIT untuk melakukan banyak tugas secara paralel. Transformasi ini adalah kunci untuk aplikasi seperti robotika, game, dan bahkan penelitian.

Fungsi transformasi yang dapat dikomposisi adalah fungsi murni yang mengubah sekumpulan data menjadi bentuk lain. Mereka disebut dapat dikomposisi karena mandiri (yaitu, fungsi-fungsi ini tidak memiliki ketergantungan dengan program lainnya) dan tidak memiliki status (yaitu, input yang sama akan selalu menghasilkan output yang sama).

Y(x) = T: (f(x))

Dalam persamaan di atas, f(x) adalah fungsi asli di mana transformasi diterapkan. Y(x) adalah fungsi yang dihasilkan setelah transformasi diterapkan.

Misalnya, jika Anda memiliki fungsi bernama 'total_bill_amt', dan Anda menginginkan hasilnya sebagai transformasi fungsi, Anda cukup menggunakan transformasi yang Anda inginkan, katakanlah gradien (grad):

lulusan_total_tagihan = lulusan(total_tagihan_amt)

Dengan mentransformasikan fungsi numerik menggunakan fungsi seperti grad(), kita dapat dengan mudah mendapatkan turunan orde yang lebih tinggi, yang dapat kita gunakan secara ekstensif dalam algoritme pengoptimalan pembelajaran mendalam seperti penurunan gradien, sehingga membuat algoritme lebih cepat dan lebih efisien. Demikian pula, dengan menggunakan jit(), kita dapat mengkompilasi program Python just-in-time (malas).

#1. Diferensiasi otomatis

Python menggunakan fungsi autograd untuk secara otomatis membedakan NumPy dan kode Python asli. JAX menggunakan versi autograd yang dimodifikasi (yaitu, lulusan) dan menggabungkan XLA (Accelerated Linear Algebra) untuk melakukan diferensiasi otomatis dan menemukan turunan dari urutan apa pun untuk GPU (Unit Pemrosesan Grafis) dan TPU (Unit Pemrosesan Tensor).]

Catatan singkat tentang TPU, GPU, dan CPU: CPU atau Central Processing Unit mengelola semua operasi di komputer. GPU adalah prosesor tambahan yang meningkatkan daya komputasi dan menjalankan operasi kelas atas. TPU adalah unit kuat yang secara khusus dikembangkan untuk beban kerja yang kompleks dan berat seperti AI dan algoritme pembelajaran mendalam.

Sejalan dengan fungsi autograd, yang dapat membedakan melalui loop, rekursi, cabang, dan sebagainya, JAX menggunakan fungsi grad() untuk gradien mode-balik (backpropagation). Juga, kita dapat membedakan fungsi ke urutan apa pun menggunakan grad:

lulusan(lulusan(lulus(sin ))) (1.0)

Diferensiasi otomatis dari urutan yang lebih tinggi

Seperti yang kami sebutkan sebelumnya, grad cukup berguna dalam menemukan turunan parsial dari suatu fungsi. Kita dapat menggunakan turunan parsial untuk menghitung penurunan gradien fungsi biaya sehubungan dengan parameter jaringan saraf dalam pembelajaran mendalam untuk meminimalkan kerugian.

Menghitung turunan parsial

Misalkan suatu fungsi memiliki beberapa variabel, x, y, dan z. Menemukan turunan dari satu variabel dengan menjaga variabel lain konstan disebut turunan parsial. Misalkan kita memiliki fungsi,

f(x,y,z) = x + 2y + z 2

Contoh untuk menunjukkan turunan parsial

Turunan parsial dari x akan menjadi f/∂x, yang memberitahu kita bagaimana suatu fungsi berubah untuk suatu variabel ketika yang lain konstan. Jika kita melakukan ini secara manual, kita harus menulis program untuk membedakan, menerapkannya untuk setiap variabel, dan kemudian menghitung penurunan gradien. Ini akan menjadi urusan yang kompleks dan memakan waktu untuk banyak variabel.

Diferensiasi otomatis memecah fungsi menjadi serangkaian operasi dasar, seperti +, -, *, / atau sin, cos, tan, exp, dll., lalu menerapkan aturan rantai untuk menghitung turunannya. Kita dapat melakukan ini dalam mode maju dan mundur.

Ini bukan ! Semua perhitungan ini terjadi begitu cepat (well, pikirkan tentang satu juta perhitungan yang mirip dengan di atas dan waktu yang diperlukan!). XLA menangani kecepatan dan kinerja.

#2. Aljabar Linier Dipercepat

Mari kita ambil persamaan sebelumnya. Tanpa XLA, komputasi akan membutuhkan tiga (atau lebih) kernel, di mana setiap kernel akan melakukan tugas yang lebih kecil. Sebagai contoh,

Kernel k1 -> x * 2y (perkalian)

k2 -> x * 2y + z (penjumlahan)

k3 -> Pengurangan

Jika tugas yang sama dilakukan oleh XLA, satu kernel menangani semua operasi perantara dengan menggabungkannya. Hasil antara dari operasi dasar dialirkan alih-alih menyimpannya dalam memori, sehingga menghemat memori dan meningkatkan kecepatan.

#3. Kompilasi tepat waktu

JAX secara internal menggunakan kompiler XLA untuk meningkatkan kecepatan eksekusi. XLA dapat meningkatkan kecepatan CPU, GPU, dan TPU. Semua ini dimungkinkan dengan menggunakan eksekusi kode JIT. Untuk menggunakan ini, kita dapat menggunakan jit melalui impor:

 from jax import jit def my_function(x): …………some lines of code my_function_jit = jit(my_function)

Cara lain adalah dengan mendekorasi jit di atas definisi fungsi:

 @jit def my_function(x): …………some lines of code

Kode ini jauh lebih cepat karena transformasi akan mengembalikan versi kode yang dikompilasi ke pemanggil daripada menggunakan interpreter Python. Ini sangat berguna untuk input vektor, seperti array dan matriks.

Hal yang sama juga berlaku untuk semua fungsi python yang ada. Misalnya, fungsi dari paket NumPy. Dalam hal ini, kita harus mengimpor jax.numpy sebagai jnp daripada NumPy:

 import jax import jax.numpy as jnp x = jnp.array([[1,2,3,4], [5,6,7,8]])

Setelah Anda melakukan ini, objek larik JAX inti yang disebut DeviceArray menggantikan larik NumPy standar. DeviceArray malas – nilainya disimpan di akselerator sampai dibutuhkan. Ini juga berarti bahwa program JAX tidak menunggu hasil untuk kembali ke program panggilan (Python), sehingga mengikuti pengiriman asinkron.

#4. Vektorisasi otomatis (vmap)

Dalam dunia pembelajaran mesin yang khas, kami memiliki kumpulan data dengan satu juta atau lebih titik data. Kemungkinan besar, kami akan melakukan beberapa perhitungan atau manipulasi pada setiap atau sebagian besar titik data ini – yang merupakan tugas yang sangat memakan waktu dan memori! Misalnya, jika Anda ingin menemukan kuadrat dari setiap titik data dalam kumpulan data, hal pertama yang Anda pikirkan adalah membuat lingkaran dan mengambil kuadratnya satu per satu – argh!

Jika kita membuat titik-titik ini sebagai vektor, kita bisa melakukan semua kuadrat sekaligus dengan melakukan manipulasi vektor atau matriks pada titik data dengan NumPy favorit kita. Dan jika program Anda dapat melakukan ini secara otomatis – dapatkah Anda meminta lebih banyak lagi? Itulah yang dilakukan JAX! Ini dapat secara otomatis membuat vektor semua titik data Anda sehingga Anda dapat dengan mudah melakukan operasi apa pun pada mereka – membuat algoritme Anda jauh lebih cepat dan lebih efisien.

JAX menggunakan fungsi vmap untuk vektorisasi otomatis. Perhatikan larik berikut:

 x = jnp.array([1,2,3,4,5,6,7,8,9,10]) y = jnp.square(x)

Dengan melakukan hal di atas, metode kuadrat akan dijalankan untuk setiap titik dalam array. Tetapi jika Anda melakukan hal berikut:

 vmap(jnp.square(x))

Metode kuadrat akan dijalankan hanya sekali karena titik data sekarang divektorkan secara otomatis menggunakan metode vmap sebelum menjalankan fungsi, dan perulangan didorong ke bawah ke tingkat operasi dasar – menghasilkan perkalian matriks daripada perkalian skalar, sehingga memberikan kinerja yang lebih baik .

#5. Pemrograman SPMD (pmap)

SPMD – atau S ingle P rogram Pemrograman multi D a sangat penting dalam konteks pembelajaran mendalam – Anda akan sering menerapkan fungsi yang sama pada kumpulan data berbeda yang berada di beberapa GPU atau TPU. JAX memiliki fungsi bernama pump, yang memungkinkan pemrograman paralel pada beberapa GPU atau akselerator apa pun. Seperti JIT, program yang menggunakan pmap akan dikompilasi oleh XLA dan dieksekusi secara bersamaan di seluruh sistem. Paralelisasi otomatis ini berfungsi untuk komputasi maju dan mundur.

Bagaimana cara kerja pmap

Kami juga dapat menerapkan beberapa transformasi sekaligus dalam urutan apa pun pada fungsi apa pun sebagai:

pmap(vmap(jit(grad (f(x))))))

Beberapa transformasi yang dapat dikomposisi

Keterbatasan Google JAX

Pengembang Google JAX telah memikirkan dengan baik tentang mempercepat algoritme pembelajaran mendalam sambil memperkenalkan semua transformasi yang luar biasa ini. Fungsi dan paket komputasi ilmiah ada di baris NumPy, jadi Anda tidak perlu khawatir tentang kurva pembelajaran. Namun, JAX memiliki batasan berikut:

  • Google JAX masih dalam tahap awal pengembangan, dan meskipun tujuan utamanya adalah pengoptimalan kinerja, itu tidak memberikan banyak manfaat untuk komputasi CPU. NumPy tampaknya berkinerja lebih baik, dan menggunakan JAX hanya dapat menambah overhead.
  • JAX masih dalam tahap penelitian atau tahap awal dan membutuhkan lebih banyak penyesuaian untuk mencapai standar infrastruktur kerangka kerja seperti TensorFlow, yang lebih mapan dan memiliki lebih banyak model yang telah ditentukan sebelumnya, proyek sumber terbuka, dan materi pembelajaran.
  • Sampai sekarang, JAX tidak mendukung Sistem Operasi Windows – Anda memerlukan mesin virtual untuk membuatnya bekerja.
  • JAX hanya bekerja pada fungsi murni – yang tidak memiliki efek samping. Untuk fungsi dengan efek samping, JAX mungkin bukan pilihan yang baik.

Cara menginstal JAX di lingkungan Python Anda

Jika Anda memiliki pengaturan python di sistem Anda dan ingin menjalankan JAX di mesin lokal (CPU), gunakan perintah berikut:

 pip install --upgrade pip pip install --upgrade "jax[cpu]"

Jika Anda ingin menjalankan Google JAX pada GPU atau TPU, ikuti instruksi yang diberikan di halaman GitHub JAX. Untuk mengatur Python, kunjungi halaman unduhan resmi python.

Kesimpulan

Google JAX sangat bagus untuk menulis algoritme pembelajaran mendalam yang efisien, robotika, dan penelitian. Terlepas dari keterbatasannya, ini digunakan secara luas dengan kerangka kerja lain seperti Haiku, Flax, dan banyak lagi. Anda akan dapat menghargai apa yang JAX lakukan ketika Anda menjalankan program dan melihat perbedaan waktu dalam mengeksekusi kode dengan dan tanpa JAX. Anda dapat memulai dengan membaca dokumentasi resmi Google JAX, yang cukup lengkap.