Perjalanan seorang ahli robotika dengan JAX: Menemukan efisiensi dalam kontrol dan simulasi optimal

29 JULI 2025
Srikanth Kilaru Senior Product Manager Google ML Frameworks
Max Muchen Sun Robotics Researcher Northwestern University

JAX semakin banyak diadopsi oleh developer untuk berbagai tugas komputasi, memperluas perannya di luar fokus awalnya pada AI berskala besar. Meskipun tetap menjadi framework populer untuk mengembangkan LLM dan model dasar, JAX juga mendapatkan momentum dalam berbagai domain ilmiah. Salah satu bidang yang sangat menarik perhatian adalah robotika, di sini JAX memungkinkan kemampuan simulasi, kontrol, dan integrasi metode berbasis pembelajaran yang kuat.

Baru-baru ini, saya mendapat kesempatan untuk berbicara dengan Max Muchen Sun, seorang kandidat Ph.D. Robotika dan peneliti di Northwestern University yang dibimbing Prof. Todd Murphey. Pengalamannya dengan jelas menggambarkan bagaimana JAX bisa mengatasi tantangan kritis dalam penelitian robotika, terutama seputar efisiensi komputasi untuk algoritme kontrol yang kompleks dan kombinasi yang mulus antara pendekatan berbasis model maupun pembelajaran. Perjalanan Max dari bergulat dengan alat tradisional hingga memanfaatkan fitur unik JAX seperti vmap dan scan adalah cerita yang relevan dan menginspirasi banyak orang di bidang ini.


Perjalanan Max: Dalam kata-katanya sendiri

Ketertarikan saya pada JAX dimulai dari perspektif efisiensi komputasi. Mentor saya pada saat itu, Ian Abraham (sekarang profesor di Yale University), menggunakan autograd dan kemudian membawa saya ke JAX. Kami mengerjakan penelitian menggunakan kontrol ergodic, yang merupakan framework kontrol untuk masalah cakupan. Dibandingkan formulasi kontrol standar, kompleksitas komputasi kontrol ergodic secara inheren lebih tinggi. Untuk mencapai kontrol ergodic real-time, saya awalnya menggunakan NumPy standar dan memanfaatkan fitur vektorisasi dan penyiaran.

Fitur JAX pertama yang menarik perhatian saya adalah vmap JAX. Bagi saya, fitur ini menggabungkan mekanisme vektorisasi dan penyiaran dari NumPy standar, dan menggeneralisasinya lebih jauh melalui transformasi fungsi dan abstraksi komposisi, memudahkan saya untuk melakukan penalaran dan mengimplementasikan paralelisasi untuk masalah yang saya tangani.

Kemudian, saya belajar tentang scan. Awalnya terasa kurang intuitif, tetapi pada akhirnya ini menjadi alat yang efisien untuk menyimulasikan lintasan sistem dinamis. Dalam pengoptimalan lintasan, simulasi maju dinamika sistem merupakan operasi inti yang harus dilakukan berulang kali dan sering kali menjadi bottleneck komputasi. Dengan scan, simulasi lintasan bisa dipercepat hingga dua kali lipat dibandingkan implementasi berbasis NumPy standar. Kemudahan penggunaan dan keunggulan kecepatan yang substansial menarik saya sepenuhnya ke dalam ekosistem JAX.

Di sisi lain, fokus utama PhD saya adalah mengintegrasikan kontrol berbasis model dengan representasi berbasis pembelajaran untuk eksplorasi otonom dan kerja sama multi-agen. Saya memandang metode berbasis model bukan sebagai solusi yang berdiri sendiri, tetapi sebagai struktur untuk meningkatkan efisiensi dan ketahanan pembelajaran. Kemampuan komposabilitas JAX membuatnya ideal untuk menggabungkan pipeline berbasis model dan pembelajaran.

Dalam salah satu makalah terbaru saya yang diterima di Robotics: Science and Systems (RSS), saya menggabungkan flow matching dari model generatif dengan kontrol optimal berbasis model untuk eksplorasi robot, menggunakan gradien aliran untuk memetakan aliran state-space ke kontrol melalui update berbasis LQR—analog dengan backpropagation tetapi pada sistem dinamis. Awalnya saya membuat modul flow matching di PyTorch dan menggunakan C++ untuk LQR, tetapi integrasinya lambat. Beralih ke JAX, saya mengimplementasikan ulang bagian flow matching menggunakan vmap dan grad, serta memanfaatkan alat berbasis JAX seperti OTT (Optimal Transport Toolbox). Bagian yang lainnya adalah pipeline LQR native JAX.

Dalam makalah terbaru lainnya yang dipresentasikan di IEEE International Conference on Robotics and Automation (ICRA), saya mengintegrasikan pipeline kontrol teori-game berbasis model ke dalam model lintasan generatif untuk mempelajari kerja sama multi-agen dari demonstrasi. Alih-alih menggunakan kontrol teori game sebagai solusi menyeluruh—yang sering kali mahal secara komputasi dan membutuhkan spesifikasi kerugian manual—saya menyematkan komputasi teori-game sebagai lapisan terstruktur di dalam conditional variational autoencoder (CVAE). Ini meningkatkan efisiensi data tanpa mengorbankan performa. Kedua komponen tersebut diimplementasikan dalam JAX—CVAE dengan Flax dan lapisan kontrol dari awal. JAX membuatnya mudah: grad bisa melakukan diferensiasi langsung melalui ekuilibrium. Saya juga membuat pemecah masalah iLQGames berbasis JAX untuk menghasilkan data sintetis.

Setelah project ini, saya menyadari bahwa saya menggunakan kembali sebagian besar kode JAX untuk perhitungan sistem dinamis, terutama yang berbasis LQR. Karena saya menggunakan LQR untuk mengintegrasikan kontrol berbasis pembelajaran dan berbasis model secara tidak standar, saya mengemasnya ke dalam pemecah masalah native JAX yang berdiri sendiri—LQRax. Ini mendukung akselerasi GPU, vmap, scan, dan grad, yang memungkinkan LQR tervektorisasi dan dapat dibedakan. Saya menyertakan contoh seperti ergodic dan kontrol teori-game untuk menyoroti bagaimana metode berbasis model bisa melengkapi pembelajaran.

Saya menggunakan JAX di CPU dan GPU, sering kali berbeda dengan komunitas ML. Sebagai contoh, dalam project flow matching, LQR berjalan lebih cepat di CPU, sementara gradien flow matching lebih cepat di GPU. Saya belum pernah menggunakan TPU karena saya biasanya menjalankan semua komputasi secara lokal. Beberapa tahun yang lalu, saya mencoba JAX di Nvidia Jetson, dan penginstalannya sulit. Saya senang JAX sekarang didukung di platform tersemat ini, yang sangat penting untuk robotika. Saya telah menguji algoritme navigasi kerumunan pada robot berkaki empat menggunakan Jetson dengan semua komputasi dilakukan di dalamnya, dan saya berencana untuk mengintegrasikan JAX ke dalam project ini dalam waktu dekat.

Ke depannya, saya akan terus menggunakan JAX untuk alasan yang sama seperti saat saya memulainya. Pertama, efisiensi komputasi, terutama paralelisasi berbasis GPU, semakin penting dalam robotika. Selain pelatihan, JAX memungkinkan kontrol berbasis model baru seperti simulasi paralel yang masif dan update parameter real-time, mirip dengan pembelajaran aktif yang diwujudkan. Kedua, JAX membuat pengintegrasian struktur berbasis model ke dalam pipeline pembelajaran intuitif—baik untuk dinamika, pembentukan kerugian, atau pemecah masalah yang dapat dibedakan. Fleksibilitas tersebut membuat saya bersemangat untuk terus maju.


Jelajahi ekosistem robotika JAX: Dari LQRax hingga MJX

Pengalaman Max menunjukkan beberapa keuntungan utama yang ditawarkan JAX kepada komunitas robotika. Kecepatan signifikan yang dicapai dengan vmap untuk operasi paralel dan scan untuk simulasi lintasan sangatlah penting untuk kontrol real-time dan perencanaan yang kompleks. Selain itu, kemampuan paradigma fungsional dan diferensiasi otomatis membuatnya sangat cocok untuk mengintegrasikan teknik berbasis model klasik dengan komponen berbasis pembelajaran modern.

Kami percaya cerita seperti Max merupakan tanda dari ekosistem yang berkembang pesat dan semakin matang. Paket LQRax miliknya merupakan tambahan yang luar biasa dalam lanskap alat robotika native JAX yang dinamis, dan kami mendorong Anda untuk menjelajahi project ini di GitHub serta mencobanya sendiri. Dalam dunia simulasi, JAX menyediakan fondasi yang kuat dengan mesin paralel masif seperti Brax dan MuJoCo XLA (MJX) baru, yang menghadirkan physics engine MuJoCo populer dan standar secara langsung ke JAX. Kami juga melihat alat khusus dari komunitas, seperti library JaxSim untuk dinamika multibody yang berfokus pada kontrol.

Dalam domain pengoptimalan lintasan, di mana pionir seperti Trajax pertama kali memulainya, LQRax hadir sebagai library modern yang disambut baik oleh peneliti yang sedang membangun sistem kontrol generasi berikutnya. Ia dengan sempurna mewujudkan semangat JAX dengan menyediakan alat yang kuat dan dapat dikomposisikan yang menjembatani kesenjangan antara kontrol berbasis model dan deep learning.

Terima kasih banyak kepada Max yang telah membagikan perjalanannya yang penuh inspirasi kepada kami. Kami sangat senang melihat bagaimana dia dan peneliti lainnya terus memanfaatkan JAX untuk membangun sistem robot cerdas generasi berikutnya. Tim JAX di Google berkomitmen untuk mendukung dan mengembangkan ekosistem yang dinamis ini.