
JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas 

Enthought
67K subscribers
17K views

JAX is a system for high-performance machine learning research and numerical computing. It offers the familiarity of Python+NumPy together with hardware acceleration, and it enables the definition and composition of user-wielded function transformations. These transformations include automatic differentiation, automatic vectorized batching, end-to-end compilation (via XLA), parallelizing over multiple accelerators, and more.
JAX had its initial open-source release in December 2018 (github.com/google/jax).
This talk will introduce JAX and its core function transformations with a live demo. You’ll learn about JAX’s core design, how it’s powering new research, and how you can start using it too!
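As a taste of what those transformations look like in practice, here is a minimal sketch (not taken from the talk; the toy function is made up) composing grad, jit, and vmap:

import jax
import jax.numpy as jnp  # drop-in replacement for numpy

def f(x):
    return jnp.sin(x) * x ** 2

df = jax.grad(f)                 # automatic differentiation
fast_df = jax.jit(df)            # end-to-end compilation via XLA
batched_df = jax.vmap(fast_df)   # automatic vectorized batching

xs = jnp.linspace(0.0, 1.0, 8)
print(batched_df(xs))            # derivative of f evaluated at each point in xs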

Science

Published: 17 Jul 2024

Comments: 20
@jaycostello5823 4 years ago
This is exactly what the scientific computing community needs.
@nazardino1841 4 years ago
It saves time and reduces a lot of errors.
@victornoagbodji 3 years ago
🙏 🙏 😊 Very interesting, thank you!
@user-wr4yl7tx3w 2 years ago
wow, amazingly cool and really well explained.
@Khushpich 4 years ago
Very promising. Hope to see a unified deep learning library (convs, etc.) for JAX in the future.
@jeffcarp280 3 years ago
Hamza Flax is a higher-level library for JAX that includes Conv layers: github.com/google/flax
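For anyone curious, a minimal Flax sketch (assuming the flax.linen API; the model and layer sizes here are made up, not from the talk):

import jax
import jax.numpy as jnp
import flax.linen as nn  # Flax's neural-network API

class SmallCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)  # conv layer from Flax
        x = nn.relu(x)
        return x.mean(axis=(1, 2))                       # global average pool

model = SmallCNN()
x = jnp.ones((4, 28, 28, 1))                   # NHWC batch of dummy images
params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters
y = model.apply(params, x)                     # forward pass
print(y.shape)                                 # (4, 16)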
@sayakpaul3152 2 years ago
How can some people dislike this video? What's wrong, folks?
@MarioHari 3 years ago
*bug fix* in def predict: outputs = np.tanh(outputs)  # not `inputs`. Your current _predict_ function is just a linear projection without a non-linearity.
@QuintinMassey 2 years ago
Thank you!!! I thought I misunderstood something and just wrote a comment asking for clarification, then saw yours. Also, thanks for elaborating on what that line does.
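For reference, a hedged reconstruction of the kind of predict function being discussed, with the fix applied (the layer structure and parameter layout here are assumptions, not the talk's exact code):

import jax
import jax.numpy as np  # assuming the demo's style of importing jax.numpy as np

def predict(params, inputs):
    # params is assumed to be a list of (weights, biases) pairs, one per layer
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        outputs = np.tanh(outputs)   # the fix: tanh of `outputs`, not `inputs`
        inputs = outputs             # feed the activations into the next layer
    return outputs

# Tiny smoke test with random parameters
key = jax.random.PRNGKey(0)
params = [(jax.random.normal(key, (2, 4)), np.zeros(4)),
          (jax.random.normal(key, (4, 1)), np.zeros(1))]
print(predict(params, np.ones((5, 2))).shape)  # (5, 1)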
@deterministicalgorithmslab1744 4 years ago
Notes / important points:
1) Just replacing numpy with jax.numpy will make your code run on GPU/TPU.
2) Just passing the function through jit, as g = jit(f), will fuse together the components of f and optimize them.
3) Given the function f, to obtain f', use grad_f = grad(f); grad_f(x) now equals f'(x).
4) JAX doesn't use finite-difference methods etc. for computing gradients; it uses automatic differentiation. Note that it is always possible to find analytic gradients, but not analytic integrals.
5) vmap takes in a function and returns the batched version of that function, treating the first dimension as the batch dimension.
6) How does jit work? JAX converts the code to lax primitives, then passes a tracer value through to build an intermediate representation (IR) [21:00]. This IR is then compiled by XLA.
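To illustrate point 6 in the notes above, a small sketch (the function is made up): jax.make_jaxpr runs a function with abstract tracer values and returns the intermediate representation that jit then hands to XLA.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2 + 3.0 * x)

# Prints the jaxpr: the IR built by tracing f with abstract values.
print(jax.make_jaxpr(f)(jnp.arange(4.0)))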
@Karthikk-ln9ge 4 years ago
Can you please explain the 4th point above? If JAX doesn't use finite-difference methods etc. for computing gradients, then which methods or techniques does it use?
@yodalf3548 3 years ago
@@Karthikk-ln9ge It uses algorithmic differentiation (also called automatic differentiation).
@TernaryM01 3 years ago
@@Karthikk-ln9ge It does analytic/symbolic differentiation, instead of numerical.
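To make the point concrete, a quick sketch (made-up function): jax.grad differentiates the program itself, so the result matches the analytic derivative up to floating-point precision, unlike a finite-difference approximation.

import jax

f = lambda x: x ** 3

exact = jax.grad(f)(2.0)                            # autodiff: 3 * 2**2 = 12.0
eps = 1e-3
approx = (f(2.0 + eps) - f(2.0 - eps)) / (2 * eps)  # central finite difference

print(exact)   # 12.0 (exact, up to float precision)
print(approx)  # ~12.000001 (error from the finite step size)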
@QuintinMassey 2 years ago
Hrrmm, I'm having trouble understanding why I would want to ditch the well-established frameworks that TF or PyTorch already provide for the added acceleration you get using JAX. Is the use case for JAX over other frameworks prototyping? Are JAX acceleration and the TF or PT frameworks mutually exclusive? Now I have two separate code bases, one implemented in JAX and another in TF or PT?
@f3arbhy 1 year ago
I don't think JAX is competing with the well-established TF or PyTorch ecosystems. It also seems like you have unified TF and PyTorch in your reply, but clearly one has to decide between one or the other - there is little to no interoperability between the two libraries (unless one mixes and matches the dataloader and model handling). JAX operates with a completely different philosophy: you get a drop-in replacement for numpy for your math operations, along with grad, which enables you to compute gradients, and vmap and pmap, which allow you to completely vectorize your operations, even across multiple compute devices. It is this very simple API, and a "functional" framework (JAX functions are meant to be free of side effects, which lets you do function composition without much fear), that I love about JAX. It is tailored more towards the differentiable-programming paradigm and is very suitable for physics-informed networks etc. If you want to use vmap and grad, there is an implementation in torch named functorch that lets you use these features without using JAX.
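For the pmap part, a minimal sketch (it uses whatever devices are visible locally; with a single device the "parallel" map simply runs on that one device; the loss function is made up):

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.mean((x @ w) ** 2)

n_dev = jax.local_device_count()          # number of visible devices
ws = jnp.ones((n_dev, 3))                 # parameters replicated per device
xs = jnp.ones((n_dev, 8, 3))              # one data shard per device

# pmap compiles `loss` once and runs each shard on its own device in parallel.
per_device_loss = jax.pmap(loss)(ws, xs)
print(per_device_loss.shape)              # (n_dev,)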
@catchenal 4 years ago
Jake does not like quiet keyboards?
@mr.meesicks1801 3 years ago
Hey, lots of people love loud mechanical keyboards! :D
@SuperPaco0o 3 years ago
In my opinion this was a very shallow overview of JAX. The speaker didn't compare this tool with industry-standard machine learning frameworks like PyTorch and TensorFlow.
@QuintinMassey 2 years ago
Thank you! I thought I was the only one who didn't get the full story. I'm still having a tough time understanding when one would want to use JAX outside of the obvious advantages you get in acceleration and parallelism.