
Machine Learning with JAX - From Zero to Hero | Tutorial #1 

Aleksa Gordić - The AI Epiphany
53K subscribers
56K views

❤️ Become The AI Epiphany Patreon ❤️
/ theaiepiphany
👨‍👩‍👧‍👦 Join our Discord community 👨‍👩‍👧‍👦
/ discord
With this video I'm kicking off a series of tutorials on JAX!
JAX is a powerful and increasingly popular ML library built by the Google Research team. The two most popular deep learning frameworks built on top of JAX are Haiku (DeepMind) and Flax (Google Research).
In this video I cover the basics as well as the nitty-gritty details of jit, grad, vmap, and various other idiosyncrasies of JAX.
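To give a flavour of those transformations, here is a minimal sketch (my own toy example, not taken from the video) that composes grad, jit, and vmap:

import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.dot(x, w)                                    # toy linear model

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)               # mean squared error

grad_fn = jax.jit(jax.grad(loss))                           # compiled gradient w.r.t. w
batched_predict = jax.vmap(predict, in_axes=(None, 0))      # vectorize over a batch of x

w = jnp.ones(3)
x = jnp.arange(12.).reshape(4, 3)
y = jnp.ones(4)
print(grad_fn(w, x, y))                                     # gradient, shape (3,)
print(batched_predict(w, x))                                # per-example predictions, shape (4,)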
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
✅ JAX GitHub: github.com/google/jax
✅ JAX docs: jax.readthedocs.io/
✅ My notebook: github.com/gordicaleksa/get-s...
✅ Useful video on autodiff: • What is Automatic Diff...
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
⌚️ Timetable:
00:00:00 What is JAX? JAX ecosystem
00:03:35 JAX basics
00:10:05 JAX is accelerator agnostic
00:15:00 jit explained
00:17:45 grad explained
00:27:25 The power of JAX autodiff (Hessians and beyond)
00:31:00 vmap explained
00:36:50 JAX API (NumPy, lax, XLA)
00:39:40 The nitty-gritty details of jit
00:46:55 Static arguments
00:50:05 Gotcha 1: Pure functions
00:56:00 Gotcha 2: In-Place Updates
00:57:35 Gotcha 3: Out-of-Bounds Indexing
00:59:55 Gotcha 4: Non-Array Inputs
01:01:50 Gotcha 5: Random Numbers
01:09:40 Gotcha 6: Control Flow
01:13:45 Gotcha 7: NaNs and float32
01:15:25 Quick summary
01:16:00 Conclusion: who should be using JAX?
01:17:10 Outro
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💰 BECOME A PATREON OF THE AI EPIPHANY ❤️
If these videos, GitHub projects, and blogs help you,
consider helping me out by supporting me on Patreon!
The AI Epiphany - / theaiepiphany
One-time donation - www.paypal.com/paypalme/theai...
Huge thank you to these AI Epiphany patreons:
Eli Mahler
Petar Veličković
Bartłomiej Danek
Zvonimir Sabljic
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💼 LinkedIn - / aleksagordic
🐦 Twitter - / gordic_aleksa
👨‍👩‍👧‍👦 Discord - / discord
📺 YouTube - / theaiepiphany
📚 Medium - / gordicaleksa
💻 GitHub - github.com/gordicaleksa
📢 AI Newsletter - aiepiphany.substack.com/
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
#jax #machinelearning #framework

Published: 17 Jul 2024

Comments: 89
@user-wr4yl7tx3w
@user-wr4yl7tx3w 2 years ago
This video is such a great service to the community. Really great examples to help better understand JAX at a nuanced level.
@matthewkhoo1153
@matthewkhoo1153 2 years ago
When I started learning JAX, I personally thought it stands for JIT (J), Autograd (A), XLA (X), which is essentially an abbreviation for a bunch of abbreviations. Given that those features are the 'highlights' of JAX, it's very possible. If that's the case, pretty cool naming from DeepMind. Anyways, there aren't many comprehensive resources for JAX right now, so I'm really looking forward to this series! Cheers Aleksa.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Wow great point I totally overlooked it. 😅 That looks like a better hypothesis than the one I had. If you google "Jax meaning" it seems it's a legit name and means something like "God has been gracious; has shown favor". 😂
@matthewkhoo1153
@matthewkhoo1153 2 years ago
@@TheAIEpiphany Probably an alternative in case the name 'Jack' is too boring lmao. Had a similar experience, first time I googled "jax tutorial" it was a guide for a game character haha.
@DanielSuo
@DanielSuo 11 months ago
Believe it stands for "Just Another XLA" compiler.
@sarahel-sherif3318
@sarahel-sherif3318 2 years ago
Great material and great effort, excited to see Flax and Haiku, thank you!
@mikesmith853
@mikesmith853 2 years ago
Awesome work!! JAX is a fantastic library. This series is the reason I finally subscribed to your channel. Thanks for your work!
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you! 😄
@billykotsos4642
@billykotsos4642 2 years ago
This channel is reaching GOAT-level status.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Hahaha thank you! Step by step
@mariuskombou6729
@mariuskombou6729 4 months ago
Thanks for this video!! That was really interesting for a new user of JAX like me.
@jawadmansoor6064
@jawadmansoor6064 2 years ago
Great video, thanks. Kindly complete the tutorial series on Flax as well.
@vijayanand7270
@vijayanand7270 2 years ago
Great video!🔥 Would love to see paper implementations too💯
@deoabhijit5935
@deoabhijit5935 2 years ago
Looking forward to more such great videos.
@Khushpich
@Khushpich 2 years ago
Great stuff, looking forward to the next jax tutorials
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thanks!
@TheParkitny
@TheParkitny 2 years ago
Great video and content, this channel needs more recognition.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you! 🥳
@shyamalchandra6597
@shyamalchandra6597 2 years ago
Great job! Keep it up!😀
@maikkschischo1790
@maikkschischo1790 1 year ago
Great work. Thank you, Aleksa. I learned a lot. Coming from R, I like the functional approach here. I'd be interested to hear your current opinion on JAX now that you know it better.
@kaneyxx
@kaneyxx 2 years ago
Thanks for the tutorial! Love it!
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Glad to hear that ^^
@DullPigeon2750
@DullPigeon2750 1 year ago
Wonderful explanation about vmap function
@johanngerberding5956
@johanngerberding5956 2 years ago
Congrats on your DeepMind job, man (read your post). Nice channel, keep going!
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you!! 🙏😄
@broccoli322
@broccoli322 1 year ago
Great job. Very nice tutorial
@alexanderchernyavskiy9538
@alexanderchernyavskiy9538 2 years ago
Oh, quite an impressive series with perfect explanations.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you man! 🙏
@akashraut3581
@akashraut3581 2 years ago
Finally, some JAX tutorials... keep them coming!
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Yup it was about time I started learning it. Will do sir
@bionhoward3159
@bionhoward3159 2 years ago
I love JAX... thank you for your work!
@TheAIEpiphany
@TheAIEpiphany 2 years ago
You're welcome!
@RamithHettiarachchi
@RamithHettiarachchi 2 years ago
Thank you for the tutorial! By the way, according to their paper (Compiling machine learning programs via high-level tracing), JAX stands for just after execution 😃
@sallanmega1
@sallanmega1 2 years ago
Thank you for the amazing content. Greetings from Spain
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thanks, and greetings from Belgrade! 🙏 I have a lot of friends in Spain.
@MikeOxmol_
@MikeOxmol_ 2 years ago
Saw your video retweeted by someone, watched it and subbed, because your content is great :) How often will you be uploading the following Jax vids?
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thanks! Next one tomorrow or Thursday.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
I also open-sourced an accompanying repo here: github.com/gordicaleksa/get-started-with-JAX I recommend opening the notebook in parallel while watching the video so that you can play and tweak the code as well. Just open the notebook, click the Colab button on top of the file, and voila! You'll avoid having to set up the Python env and everything will just work! (you can potentially choose a GPU as an accelerator).
@user-vg4gv9zj3d
@user-vg4gv9zj3d 2 years ago
It's a really good tutorial!! Thanks :)
@alinajafistudent6906
@alinajafistudent6906 1 year ago
Very Cool, Thanks
@yagneshm.bhadiyadra4359
@yagneshm.bhadiyadra4359 1 year ago
Thank you for the great content!
@TheAIEpiphany
@TheAIEpiphany 1 year ago
Thanks!
@michaellaskin3407
@michaellaskin3407 2 years ago
great video!
@gim8377
@gim8377 4 months ago
Thank you so much !
@1potdish271
@1potdish271 2 years ago
Hi Aleksa, first of all, thank you very much for sharing great content. I learn a lot from you. Could you please explain some of the upsides of JAX over other frameworks? I really need motivation to get started with JAX. Thank you. Cheers :)
@sacramentofwilderness6656
@sacramentofwilderness6656 2 years ago
Thanks for the great tutorial pointing out the strong and weak points of the JAX framework, with its caveats and salient features. What confuses me somewhat is the behavior where indexing past the end of an array clips the index to the maximum or does nothing. In C/C++, if one does this with a small displacement, some memory outside the given data is usually modified, and for a large indexing mistake one would get a SEGMENTATION FAULT. Clipping the index makes the program safer, but in addition to the counterintuitive behavior it adds a small extra cost for fetching the index.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you! It is confusing. It'd be cool to understand why exactly it is difficult to handle this "correctly" (by throwing an exception).
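For anyone reading along, here is a minimal sketch of the clipping behaviour, plus the .at[] API that lets you opt into a different out-of-bounds policy (assuming a reasonably recent JAX version):

import jax.numpy as jnp

x = jnp.array([10, 20, 30])
print(x[5])                                       # 30 -- read index is clipped to the last element
print(x.at[5].get(mode='fill', fill_value=-1))    # -1 -- explicitly request a fill value instead
print(x.at[5].set(99))                            # [10 20 30] -- out-of-bounds update is dropped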
@mohammedelfatihsalahmohame7288
@mohammedelfatihsalahmohame7288 2 years ago
I am glad that I found this channel
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Welcome 🚀
@mrlovalovaa
@mrlovalovaa 2 years ago
Finally the Jax !!!!!!!
@promethesured
@promethesured 3 months ago
ty ty ty ty ty for this video
@nicecomment5411
@nicecomment5411 1 year ago
Thanks for the video. I have a question: how do I run TensorFlow/JAX in the browser? (Not in an online notebook.)
@yagneshm.bhadiyadra4359
@yagneshm.bhadiyadra4359 1 year ago
Can we say that if we made all arguments static, then it would be as good as normal code without JAX? Thank you for these videos btw.
@teetanrobotics5363
@teetanrobotics5363 2 years ago
amazing content
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you!
@peterkanini867
@peterkanini867 1 year ago
What font are you using for the Colab notebook?
@vaishnav4035
@vaishnav4035 8 months ago
Thank you so much ..
@kenbobcorn
@kenbobcorn 2 years ago
Can someone explain to me why at 20:50 jnp.sum() is required and why it returns [0, 2, 4]? I would assume it would return 0 + 2 + 4 = 6 like it's described in the comment and when using sum(), but it doesn't; it just returns a vector of the original size with all the elements squared.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
I made a mistake. It's going to return a vector (df/dx1, df/dx2, df/dx3) and not the sum. f = x1^2 + x2^2 + x3^2 and grad takes derivatives independently for x1, x2 and x3 since they are all bundled into the first argument of the function f. Hope that makes sense. You can always consult the docs and experiment yourself.
@Gannicus99
@Gannicus99 2 years ago
Good question. Took me a moment as well, but the function gives the sum, whereas grad(function) gives the three gradients, one per parameter, since the output of grad is used for the SGD parameter updates: w1, w2, w3 = w1 - lr*df/dx1, w2 - lr*df/dx2, w3 - lr*df/dx3.
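For reference, a small worked sketch of the example discussed around 20:50:

import jax.numpy as jnp
from jax import grad

def f(x):
    return jnp.sum(x ** 2)    # f(x) = x1^2 + x2^2 + x3^2 -- grad needs a scalar output

x = jnp.array([0., 1., 2.])
print(f(x))         # 5.0
print(grad(f)(x))   # [0. 2. 4.] -- one partial derivative 2*xi per input element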
@zumpitu
@zumpitu 11 months ago
Hi! Thank you for your video. Isn't that very similar to Numba?
@not_a_human_being
@not_a_human_being 10 months ago
50:40 - I would argue here that it's not strictly necessary to pass all the parameters into the function: as long as it doesn't change any of them, it's OK to use external globals, e.g. for reference tables. This definition (though academically thorough) makes practical application a bit more cumbersome. I believe point 2 alone is sufficient to make this work; there's no need to pass a long list of params. Just make sure not to update/change anything external inside the function, and remember that whatever is not passed in is treated as static. Alternatively, you can re-create the jitted function every time you anticipate that your globals might have changed, so it is effectively re-traced with the new globals. In some cases that feels preferable to passing everything in; for instance, you can use all sorts of globals inside it and just re-create it right before your training loop.
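A minimal sketch of the trade-off described above (the global name and values are made up for illustration): a jitted function bakes in whatever global it closed over at trace time, and re-creating the jitted function picks up the new value.

import jax
import jax.numpy as jnp

scale = 2.0                        # module-level "reference table"

def apply_scale(x):
    return x * scale               # impure: closes over the global

f = jax.jit(apply_scale)
print(f(jnp.arange(3.)))           # [0. 2. 4.] -- traced with scale == 2.0

scale = 10.0
print(f(jnp.arange(3.)))           # still [0. 2. 4.]: the cached trace baked in the old value

f = jax.jit(apply_scale)           # re-create the jitted function, e.g. right before the training loop
print(f(jnp.arange(3.)))           # [0. 10. 20.] -- the new trace sees the updated global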
@user-wr4yl7tx3w
@user-wr4yl7tx3w 2 years ago
Would it be better to use Julia and not have to worry about the gotchas? And still get the performance.
@Gannicus99
@Gannicus99 2 years ago
Function purity rule #2 means one cannot use closure variables (variables from a wrapping function)? That's good to know, since JAX states that it is functional but does not account for closure use - jit caching only considers explicitly passed function parameters. Closure variables are hacky, but they are valid Python code. Just not in JAX.
@adityakane5669
@adityakane5669 2 years ago
At around 1:10:34, you used static_argnums=(0,) for jit. Wouldn't this extremely slow down the program, as it will have to retrace for all new values of x? Code to reproduce:

def switch(x):
    print("traced")
    if x > 10.:
        return 1.
    else:
        return 0.

jit_switch = jit(switch, static_argnums=(0,))

x = 5
jit_switch(x)
x = 16
jit_switch(x)

'''
Output:
traced
DeviceArray(0., dtype=float32, weak_type=True)
traced
DeviceArray(1., dtype=float32, weak_type=True)
'''
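Each new static value of x does trigger a retrace. A minimal sketch (my own, not from the video) of the usual alternative, which traces both branches once with lax.cond and keeps x a traced value:

import jax
from jax import lax

@jax.jit
def switch(x):
    return lax.cond(x > 10., lambda _: 1., lambda _: 0., x)   # both branches staged once

print(switch(5.))    # 0.0
print(switch(16.))   # 1.0 -- same compiled function, no retrace for the new value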
@adels1388
@adels1388 2 years ago
Thanks a lot. Keep up the good work. Am I wrong, or should the derivative at 20:15 be (x1*2, x2*2, x3*2)? I mean, you take the gradient with respect to a vector, so you should take the derivative with respect to each variable separately.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Of course what did I do? 😂
@adels1388
@adels1388 2 years ago
@@TheAIEpiphany you wrote x1*2+x2*2+x3*2. I replaced + with comma :)
@TheAIEpiphany
@TheAIEpiphany 2 years ago
@@adels1388 My bad 😅 Thanks for noticing, the printed result was correct...
@yulanliu3839
@yulanliu3839 2 years ago
JAX = Just After eXecution (related to the tracing behaviour). JIT = Just In Time (related to the compilation behaviour).
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Thank you! A couple more people also pointed it out. Makes sense
@alvinhew1872
@alvinhew1872 2 years ago
Hi, are there any resources on how to freeze certain layers of the network for transfer learning?
@TheAIEpiphany
@TheAIEpiphany 2 years ago
jax.lax.stop_gradient
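A minimal sketch of that idea (the model and parameter names are hypothetical, not from the video): wrap the frozen parameters in lax.stop_gradient inside the forward pass, so grad only flows into the trainable ones.

import jax
import jax.numpy as jnp
from jax import lax

def forward(params, x):
    h = jnp.tanh(x @ lax.stop_gradient(params["w1"]))   # frozen layer: no gradient flows into w1
    return h @ params["w2"]                             # trainable layer

def loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

params = {"w1": jnp.ones((3, 4)), "w2": jnp.ones((4, 1))}
x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))

grads = jax.grad(loss)(params, x, y)
print(grads["w1"])          # all zeros -- w1 is effectively frozen
print(grads["w2"].shape)    # (4, 1) -- w2 still gets real gradients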
@tshegofatsotshego375
@tshegofatsotshego375 2 years ago
Is it possible to use JAX with Python statsmodels?
@PhucLe-qs7nx
@PhucLe-qs7nx 2 years ago
JAX is "Just After eXecution", representing the paradigm of tracing and transforming (grad, vmap, jit, ...) after the first execution.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Hmmm, source?
@PhucLe-qs7nx
@PhucLe-qs7nx 2 years ago
@@TheAIEpiphany Sorry, I can't remember it now. But it's somewhere in the documentation or a JAX GitHub issue/discussion.
@TheAIEpiphany
@TheAIEpiphany 2 years ago
@@PhucLe-qs7nx Thanks in any case! One of the other comments mentioned it simply stands for Jit Autograd XLA. 😄 That sounds reasonable as well.
@arshsharma8627
@arshsharma8627 1 month ago
bru youre great
@TheMazyProduction
@TheMazyProduction 2 years ago
Noice
@user-wr4yl7tx3w
@user-wr4yl7tx3w 2 years ago
at ru-vid.com/video/%D0%B2%D0%B8%D0%B4%D0%B5%D0%BE-SstuvS-tVc0.html, what I noticed was that when I tried, print(grad(f_jit)(2.)), even with the static_argnums.
@L4rsTrysToMakeTut
@L4rsTrysToMakeTut 2 years ago
Why not julia lang?
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Why would this video imply that you shouldn't give Julia a shot? I may make a video on Julia in the future. I personally wanted to learn JAX since I'll be using it in DeepMind.
@alokraj7120
@alokraj7120 2 years ago
Can you make a video on how to install JAX in Anaconda, plain Python, and other Python IDEs?
@adrianstaniec
@adrianstaniec 2 years ago
Great content, horrible font ;)
@TheAIEpiphany
@TheAIEpiphany 2 years ago
Hahaha thank you and thank you!
@kirtipandya4618
@kirtipandya4618 2 years ago
JAX = Just Autograd and Xla
@samueltrif5472
@samueltrif5472 1 year ago
one hour of nothing lol
@heyman620
@heyman620 2 years ago
49:00 jnp.reshape(x, (np.prod(x.shape),)) works.
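That works because np.prod runs at trace time on the static shape tuple and yields a plain integer, whereas the jnp version produces a traced value, and reshape needs a static shape under jit. A minimal sketch of the difference (my own illustration, not from the video):

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_ok(x):
    return jnp.reshape(x, (np.prod(x.shape),))           # np.prod(x.shape) is a concrete int

@jax.jit
def flatten_bad(x):
    return jnp.reshape(x, (jnp.array(x.shape).prod(),))  # traced value -> error under jit

x = jnp.ones((2, 3))
print(flatten_ok(x).shape)   # (6,)
# flatten_bad(x)             # raises: shapes must be concrete (static) under jit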