Google JAX คืออะไร? ทุกสิ่งที่คุณต้องรู้

เผยแพร่แล้ว: 2022-08-05

Google JAX หรือ J ust A หลังจาก E x ecution เป็นเฟรมเวิร์กที่ Google พัฒนาขึ้นเพื่อเร่งความเร็วของงานการเรียนรู้ของเครื่อง

คุณสามารถพิจารณาว่าเป็นไลบรารี่สำหรับ Python ซึ่งช่วยในการดำเนินงานได้เร็วขึ้น การคำนวณทางวิทยาศาสตร์ การแปลงฟังก์ชัน การเรียนรู้เชิงลึก โครงข่ายประสาทเทียม และอื่นๆ อีกมากมาย

เกี่ยวกับ Google JAX

แพ็คเกจการคำนวณพื้นฐานที่สุดใน Python คือแพ็คเกจ NumPy ซึ่งมีฟังก์ชันทั้งหมด เช่น การรวม การดำเนินการเวกเตอร์ พีชคณิตเชิงเส้น อาร์เรย์ n มิติและการจัดการเมทริกซ์ และฟังก์ชันขั้นสูงอื่นๆ อีกมากมาย

จะเกิดอะไรขึ้นถ้าเราสามารถเร่งการคำนวณที่ดำเนินการโดยใช้ NumPy – โดยเฉพาะสำหรับชุดข้อมูลขนาดใหญ่ได้

เรามีสิ่งที่สามารถทำงานได้ดีพอๆ กันกับโปรเซสเซอร์ประเภทต่างๆ เช่น GPU หรือ TPU โดยไม่มีการเปลี่ยนแปลงโค้ดใดๆ หรือไม่

แล้วถ้าระบบสามารถทำการแปลงฟังก์ชันแบบคอมโพสิทได้โดยอัตโนมัติและมีประสิทธิภาพมากขึ้นล่ะ?

Google JAX เป็นไลบรารี่ (หรือเฟรมเวิร์กตามที่ Wikipedia บอก) ที่ทำอย่างนั้นและอาจมากกว่านั้นอีก สร้างขึ้นเพื่อเพิ่มประสิทธิภาพและดำเนินการแมชชีนเลิร์นนิง (ML) และงานการเรียนรู้เชิงลึกอย่างมีประสิทธิภาพ Google JAX มีคุณลักษณะการแปลงต่อไปนี้ที่ทำให้แตกต่างจากไลบรารี ML อื่น ๆ และช่วยในการคำนวณทางวิทยาศาสตร์ขั้นสูงสำหรับการเรียนรู้เชิงลึกและเครือข่ายประสาท:

  • แยกความแตกต่างอัตโนมัติ
  • เวกเตอร์อัตโนมัติ
  • การขนานอัตโนมัติ
  • การรวบรวมแบบทันเวลาพอดี (JIT)
คุณลักษณะเฉพาะของ Google JAX

การแปลงทั้งหมดใช้ XLA (Accelerated Linear Algebra) เพื่อประสิทธิภาพที่สูงขึ้นและการเพิ่มประสิทธิภาพหน่วยความจำ XLA เป็นเอ็นจิ้นคอมไพเลอร์สำหรับปรับแต่งเฉพาะโดเมนที่ทำงานพีชคณิตเชิงเส้นและเร่งโมเดล TensorFlow การใช้ XLA ที่ด้านบนของโค้ด Python ไม่จำเป็นต้องมีการเปลี่ยนแปลงโค้ดที่สำคัญ!

มาสำรวจรายละเอียดคุณสมบัติเหล่านี้กัน

คุณสมบัติของ Google JAX

Google JAX มาพร้อมกับฟังก์ชันการแปลงที่เขียนได้ที่สำคัญเพื่อปรับปรุงประสิทธิภาพและทำงานการเรียนรู้เชิงลึกอย่างมีประสิทธิภาพมากขึ้น ตัวอย่างเช่น การแยกความแตกต่างอัตโนมัติเพื่อรับการไล่ระดับสีของฟังก์ชันและค้นหาอนุพันธ์ของลำดับใดๆ ในทำนองเดียวกัน auto parallelization และ JIT เพื่อทำงานหลายอย่างพร้อมกัน การเปลี่ยนแปลงเหล่านี้เป็นกุญแจสำคัญในการใช้งานต่างๆ เช่น หุ่นยนต์ เกม และแม้แต่การวิจัย

ฟังก์ชันการแปลงแบบผสม ได้คือฟังก์ชัน บริสุทธิ์ ที่แปลงชุดข้อมูลเป็นอีกรูปแบบหนึ่ง พวกเขาถูกเรียกว่า composable เนื่องจากมีความสมบูรณ์ในตัวเอง (กล่าวคือ ฟังก์ชันเหล่านี้ไม่มีการพึ่งพากับส่วนที่เหลือของโปรแกรม) และไร้สัญชาติ (เช่น อินพุตเดียวกันจะส่งผลให้เกิดเอาต์พุตเดียวกันเสมอ)

Y(x) = T: (f(x))

ในสมการข้างต้น f(x) คือฟังก์ชันดั้งเดิมที่ใช้การแปลง Y(x) เป็นฟังก์ชันผลลัพธ์หลังจากใช้การแปลง

ตัวอย่างเช่น หากคุณมีฟังก์ชันชื่อ 'total_bill_amt' และคุณต้องการผลลัพธ์เป็นการแปลงฟังก์ชัน คุณสามารถใช้การแปลงที่คุณต้องการได้ สมมติว่า gradient (grad):

grad_total_bill = ผู้สำเร็จการศึกษา (total_bill_amt)

ด้วยการแปลงฟังก์ชันตัวเลขโดยใช้ฟังก์ชันอย่าง grad() เราจึงสามารถรับอนุพันธ์อันดับสูงกว่าได้อย่างง่ายดาย ซึ่งเราสามารถใช้อย่างกว้างขวางในอัลกอริธึมการเพิ่มประสิทธิภาพการเรียนรู้เชิงลึก เช่น การไล่ระดับสี descent ซึ่งทำให้อัลกอริทึมเร็วขึ้นและมีประสิทธิภาพมากขึ้น ในทำนองเดียวกัน โดยใช้ jit() เราสามารถคอมไพล์โปรแกรม Python ได้ทันท่วงที (อย่างเกียจคร้าน)

#1. แยกความแตกต่างอัตโนมัติ

Python ใช้ฟังก์ชัน autograd เพื่อแยกความแตกต่างของ NumPy และโค้ด Python ดั้งเดิมโดยอัตโนมัติ JAX ใช้ autograd เวอร์ชันดัดแปลง (เช่น grad) และรวม XLA (Accelerated Linear Algebra) เพื่อดำเนินการสร้างความแตกต่างโดยอัตโนมัติและค้นหาอนุพันธ์ของลำดับใดๆ สำหรับ GPU (หน่วยประมวลผลกราฟิก) และ TPU (หน่วยประมวลผลเทนเซอร์)]

บันทึกย่อเกี่ยวกับ TPU, GPU และ CPU: CPU หรือหน่วยประมวลผลกลางจัดการการทำงานทั้งหมดบนคอมพิวเตอร์ GPU เป็นโปรเซสเซอร์เพิ่มเติมที่ช่วยเพิ่มพลังการประมวลผลและทำงานระดับไฮเอนด์ TPU เป็นหน่วยที่มีประสิทธิภาพซึ่งพัฒนาขึ้นโดยเฉพาะสำหรับปริมาณงานที่ซับซ้อนและหนัก เช่น AI และอัลกอริธึมการเรียนรู้เชิงลึก

JAX ใช้ฟังก์ชัน grad() สำหรับการไล่ระดับโหมดย้อนกลับ (backpropagation) ในลักษณะเดียวกับฟังก์ชัน autograd ซึ่งสามารถแยกความแตกต่างได้จากการวนซ้ำ การเรียกซ้ำ สาขา และอื่นๆ นอกจากนี้ เราสามารถแยกความแตกต่างของฟังก์ชันสำหรับคำสั่งใดๆ โดยใช้ grad:

grad(grad(grad(บาป θ)))) (1.0)

แยกความแตกต่างอัตโนมัติของการสั่งซื้อที่สูงขึ้น

ดังที่เราได้กล่าวไปแล้ว grad ค่อนข้างมีประโยชน์ในการหาอนุพันธ์ย่อยของฟังก์ชัน เราสามารถใช้อนุพันธ์บางส่วนในการคำนวณการไล่ระดับสีของฟังก์ชันต้นทุนที่สัมพันธ์กับพารามิเตอร์โครงข่ายประสาทเทียมในการเรียนรู้เชิงลึกเพื่อลดการสูญเสีย

การคำนวณอนุพันธ์ย่อย

สมมติว่าฟังก์ชันมีหลายตัวแปร x, y และ z การหาอนุพันธ์ของตัวแปรหนึ่งโดยการรักษาตัวแปรอื่นให้คงที่เรียกว่าอนุพันธ์ย่อย สมมุติว่าเรามีฟังก์ชัน

f(x,y,z) = x + 2y + z 2

ตัวอย่างการแสดงอนุพันธ์ย่อย

อนุพันธ์ย่อยของ x จะเป็น ∂f/∂x ซึ่งบอกเราว่าฟังก์ชันเปลี่ยนแปลงอย่างไรสำหรับตัวแปรเมื่อค่าคงที่อื่นๆ หากเราดำเนินการด้วยตนเอง เราต้องเขียนโปรแกรมเพื่อแยกความแตกต่าง นำไปใช้กับตัวแปรแต่ละตัว แล้วคำนวณการลงระดับการไล่ระดับสี สิ่งนี้จะกลายเป็นเรื่องที่ซับซ้อนและใช้เวลานานสำหรับตัวแปรหลายตัว

การแยกความแตกต่างอัตโนมัติจะแบ่งฟังก์ชันออกเป็นชุดของการดำเนินการเบื้องต้น เช่น +, -, *, / หรือ sin, cos, tan, exp ฯลฯ จากนั้นใช้กฎลูกโซ่เพื่อคำนวณอนุพันธ์ เราสามารถทำได้ทั้งในโหมดเดินหน้าและถอยหลัง

นี้ ไม่ได้ มัน! การคำนวณทั้งหมดเหล่านี้เกิดขึ้นเร็วมาก (ลองนึกถึงการคำนวณหนึ่งล้านรายการที่คล้ายกันกับด้านบนและเวลาที่อาจต้องใช้!) XLA ดูแลความเร็วและประสิทธิภาพ

#2. พีชคณิตเชิงเส้นเร่ง

ลองใช้สมการก่อนหน้ากัน หากไม่มี XLA การคำนวณจะใช้เคอร์เนลสามตัว (หรือมากกว่า) โดยที่เคอร์เนลแต่ละตัวจะทำงานที่เล็กกว่า ตัวอย่างเช่น,

เคอร์เนล k1 –> x * 2y (การคูณ)

k2 –> x * 2y + z (เพิ่มเติม)

k3 –> การลด

หากงานเดียวกันดำเนินการโดย XLA เคอร์เนลตัวเดียวจะดูแลการดำเนินการระดับกลางทั้งหมดโดยการหลอมรวมเข้าด้วยกัน ผลลัพธ์ขั้นกลางของการดำเนินการเบื้องต้นจะถูกสตรีมแทนที่จะเก็บไว้ในหน่วยความจำ ซึ่งจะช่วยประหยัดหน่วยความจำและเพิ่มความเร็ว

#3. การรวบรวมแบบทันเวลา

JAX ใช้คอมไพเลอร์ XLA ภายในเพื่อเพิ่มความเร็วในการดำเนินการ XLA สามารถเพิ่มความเร็วของ CPU, GPU และ TPU ทั้งหมดนี้เป็นไปได้โดยใช้การเรียกใช้โค้ด JIT ในการใช้สิ่งนี้ เราสามารถใช้ jit ผ่านการนำเข้า:

 from jax import jit def my_function(x): …………some lines of code my_function_jit = jit(my_function)

อีกวิธีหนึ่งคือการตกแต่ง jit เหนือคำจำกัดความของฟังก์ชัน:

 @jit def my_function(x): …………some lines of code

รหัสนี้เร็วกว่ามากเพราะการแปลงจะส่งกลับเวอร์ชันที่คอมไพล์แล้วของรหัสไปยังผู้โทรแทนที่จะใช้ตัวแปล Python สิ่งนี้มีประโยชน์อย่างยิ่งสำหรับอินพุตเวกเตอร์ เช่น อาร์เรย์และเมทริกซ์

เช่นเดียวกับฟังก์ชันไพ ธ อนที่มีอยู่ทั้งหมดเช่นกัน ตัวอย่างเช่น ฟังก์ชันจากแพ็คเกจ NumPy ในกรณีนี้ เราควรนำเข้า jax.numpy เป็น jnp แทนที่จะเป็น NumPy:

 import jax import jax.numpy as jnp x = jnp.array([[1,2,3,4], [5,6,7,8]])

เมื่อคุณทำเช่นนี้ ออบเจ็กต์อาร์เรย์ JAX หลักที่เรียกว่า DeviceArray จะแทนที่อาร์เรย์ NumPy มาตรฐาน DeviceArray ขี้เกียจ – ค่าจะถูกเก็บไว้ในคันเร่งจนกว่าจะจำเป็น นี่ยังหมายความว่าโปรแกรม JAX ไม่รอให้ผลลัพธ์กลับมาที่โปรแกรมโทร (Python) ดังนั้นจึงติดตามการส่งแบบอะซิงโครนัส

#4. vectorization อัตโนมัติ (vmap)

ในโลกของแมชชีนเลิร์นนิงทั่วไป เรามีชุดข้อมูลที่มีจุดข้อมูลมากกว่าหนึ่งล้านจุด เป็นไปได้มากที่เราจะทำการคำนวณหรือปรับเปลี่ยนจุดข้อมูลแต่ละจุดหรือเกือบทั้งหมด ซึ่งเป็นงานที่ต้องใช้เวลาและหน่วยความจำมาก! ตัวอย่างเช่น หากคุณต้องการหากำลังสองของจุดข้อมูลแต่ละจุดในชุดข้อมูล สิ่งแรกที่คุณจะนึกถึงคือการสร้างลูปและหาค่ากำลังสองทีละจุด – อ๊ะ!

ถ้าเราสร้างจุดเหล่านี้เป็นเวกเตอร์ เราก็สามารถทำกำลังสองทั้งหมดได้ในคราวเดียวโดยดำเนินการจัดการเวกเตอร์หรือเมทริกซ์บนจุดข้อมูลด้วย NumPy ที่เราโปรดปราน และหากโปรแกรมของคุณทำสิ่งนี้ได้โดยอัตโนมัติ คุณขออะไรเพิ่มเติมได้ไหม นั่นคือสิ่งที่ JAX ทำ! มันสามารถแปลงจุดข้อมูลทั้งหมดของคุณให้เป็นเวกเตอร์โดยอัตโนมัติ เพื่อให้คุณสามารถดำเนินการใดๆ กับมันได้อย่างง่ายดาย ทำให้อัลกอริทึมของคุณเร็วขึ้นและมีประสิทธิภาพมากขึ้น

JAX ใช้ฟังก์ชัน vmap สำหรับเวกเตอร์อัตโนมัติ พิจารณาอาร์เรย์ต่อไปนี้:

 x = jnp.array([1,2,3,4,5,6,7,8,9,10]) y = jnp.square(x)

เมื่อทำเพียงข้างต้น เมธอดสแควร์จะดำเนินการกับแต่ละจุดในอาร์เรย์ แต่ถ้าคุณทำสิ่งต่อไปนี้:

 vmap(jnp.square(x))

สแควร์เมธอดจะดำเนินการเพียงครั้งเดียวเนื่องจากจุดข้อมูลจะถูกแปลงเป็นเวกเตอร์โดยอัตโนมัติโดยใช้เมธอด vmap ก่อนดำเนินการฟังก์ชัน และการวนซ้ำจะถูกผลักลงไปที่ระดับเบื้องต้นของการดำเนินการ ส่งผลให้มีการคูณเมทริกซ์มากกว่าการคูณสเกลาร์ จึงให้ประสิทธิภาพที่ดีขึ้น .

#5. การเขียนโปรแกรม SPMD (pmap)

SPMD – หรือ S ingle P rogram M ultiple D ata การเขียนโปรแกรมเป็นสิ่งจำเป็นในบริบทการเรียนรู้เชิงลึก – คุณมักจะใช้ฟังก์ชันเดียวกันนี้กับชุดข้อมูลต่างๆ ที่อยู่ใน GPU หรือ TPU หลายตัว JAX มีฟังก์ชันชื่อ pump ซึ่งช่วยให้สามารถตั้งโปรแกรมแบบขนานบน GPU หลายตัวหรือตัวเร่งความเร็วได้ เช่นเดียวกับ JIT โปรแกรมที่ใช้ pmap จะถูกรวบรวมโดย XLA และดำเนินการพร้อมกันทั่วทั้งระบบ การขนานอัตโนมัตินี้ใช้ได้กับการคำนวณทั้งแบบไปข้างหน้าและย้อนกลับ

pmap ทำงานอย่างไร

เรายังสามารถใช้การแปลงหลายแบบในครั้งเดียวในลำดับใดก็ได้ในฟังก์ชันใดๆ ดังนี้:

pmap(vmap(jit(grad (f(x))))))

แปลงร่างได้หลายแบบ

ข้อจำกัดของ Google JAX

นักพัฒนาซอฟต์แวร์ Google JAX คิดมาอย่างดีเกี่ยวกับการเร่งความเร็วอัลกอริธึมการเรียนรู้เชิงลึกในขณะที่แนะนำการเปลี่ยนแปลงที่ยอดเยี่ยมเหล่านี้ ฟังก์ชันและแพ็คเกจการคำนวณทางวิทยาศาสตร์อยู่ในบรรทัดของ NumPy ดังนั้นคุณจึงไม่ต้องกังวลเกี่ยวกับเส้นโค้งการเรียนรู้ อย่างไรก็ตาม JAX มีข้อจำกัดดังต่อไปนี้:

  • Google JAX ยังอยู่ในช่วงเริ่มต้นของการพัฒนา และแม้ว่าจุดประสงค์หลักคือการเพิ่มประสิทธิภาพการทำงาน แต่ก็ไม่ได้ให้ประโยชน์มากนักสำหรับการประมวลผลของ CPU ดูเหมือนว่า NumPy จะทำงานได้ดีกว่า และการใช้ JAX อาจเพิ่มเฉพาะโอเวอร์เฮดเท่านั้น
  • JAX ยังอยู่ในการวิจัยหรืออยู่ในช่วงเริ่มต้น และต้องการการปรับแต่งเพิ่มเติมเพื่อให้เข้าถึงมาตรฐานโครงสร้างพื้นฐานของเฟรมเวิร์ก เช่น TensorFlow ซึ่งเป็นที่ยอมรับมากขึ้นและมีโมเดลที่กำหนดไว้ล่วงหน้า โครงการโอเพ่นซอร์ส และสื่อการเรียนรู้
  • ณ ตอนนี้ JAX ยังไม่รองรับระบบปฏิบัติการ Windows คุณต้องมีเครื่องเสมือนจึงจะใช้งานได้
  • JAX ใช้งานได้กับฟังก์ชันล้วนๆ เท่านั้น ซึ่งไม่มีผลข้างเคียงใดๆ สำหรับฟังก์ชันที่มีผลข้างเคียง JAX อาจไม่ใช่ตัวเลือกที่ดี

วิธีการติดตั้ง JAX ในสภาพแวดล้อม Python ของคุณ

หากคุณมีการตั้งค่าหลามในระบบของคุณและต้องการเรียกใช้ JAX บนเครื่องท้องถิ่น (CPU) ให้ใช้คำสั่งต่อไปนี้:

 pip install --upgrade pip pip install --upgrade "jax[cpu]"

หากคุณต้องการเรียกใช้ Google JAX บน GPU หรือ TPU ให้ทำตามคำแนะนำที่ให้ไว้ในหน้า GitHub JAX หากต้องการตั้งค่า Python ให้ไปที่หน้าดาวน์โหลดอย่างเป็นทางการของ python

บทสรุป

Google JAX เหมาะอย่างยิ่งสำหรับการเขียนอัลกอริธึมการเรียนรู้เชิงลึก วิทยาการหุ่นยนต์ และการค้นคว้าที่มีประสิทธิภาพ แม้จะมีข้อจำกัด แต่ก็มีการใช้งานอย่างกว้างขวางกับเฟรมเวิร์กอื่นๆ เช่น ไฮกุ แฟลกซ์ และอื่นๆ อีกมากมาย คุณจะสามารถชื่นชมสิ่งที่ JAX ทำเมื่อคุณเรียกใช้โปรแกรมและเห็นความแตกต่างของเวลาในการรันโค้ดที่มีและไม่มี JAX คุณสามารถเริ่มต้นด้วยการอ่านเอกสารอย่างเป็นทางการของ Google JAX ซึ่งค่อนข้างครอบคลุม