[Combination of Gradient Loss Across Devices] Hello guys, firstly, thank you so much for the amazing tutorials @TheAIEpiphany! Secondly, I'd like to clarify the mathematics behind combining the gradients of the loss across multiple devices @55:34. The question is: is it correct to compute the gradient as the average of the gradients from the different devices? In other words, does it give the same gradient as if we were computing it on a single device? The answer is YES, it is correct, but only if the loss is defined as a weighted sum over the samples. This follows from the fact that the gradient of a weighted sum equals the weighted sum of the gradients. Here the loss is a mean over samples (or batches), which is a weighted sum, and the same reasoning applies to the cross-entropy loss. Additionally, the batch sizes across the devices should be equal; otherwise the combined gradient is not a plain mean but a weighted sum (with each device's weight equal to its normalised batch size). Hope my comment is clear and demystifies some questions one might have :) PS: For those who did not follow, the conclusion is: "it is fine to do as @TheAIEpiphany is doing" (because we are dealing with MSE/cross-entropy and the batch size is the same across devices).
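To illustrate the claim above, here is a minimal numerical check (just a sketch with a toy linear model and MSE loss; all names are made up): the average of the per-shard gradients matches the full-batch gradient exactly when the loss is a mean and the shards have equal size.

```python
import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    # Loss defined as a mean over samples -> a weighted sum with equal weights.
    preds = x @ w
    return jnp.mean((preds - y) ** 2)

x = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (8,))
w = jax.random.normal(jax.random.PRNGKey(2), (3,))

# Gradient computed on the full batch (single "device").
full_grad = jax.grad(mse_loss)(w, x, y)

# Split the batch into two equal-sized shards and average the per-shard gradients.
g1 = jax.grad(mse_loss)(w, x[:4], y[:4])
g2 = jax.grad(mse_loss)(w, x[4:], y[4:])
avg_grad = (g1 + g2) / 2

print(jnp.allclose(full_grad, avg_grad))  # should print True
```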
Code is here: github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb Just click the colab button and you're ready to play with the code yourself.
Great series of tutorials, congrats! It would be nice to see a performance comparison between JAX and PyTorch for some real-world use case (GPU and TPU) :)
Great videos, Aleksa! I found the names of the x and y arguments in the MLP forward function confusing, since they are really batches of xs and ys. You could use vmap there instead of writing it in already-batched form, but I guess it's a good exercise for your viewers to rewrite it in unbatched form and apply vmap :)
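For anyone curious, a rough sketch of that vmap rewrite (a hypothetical toy MLP, not the notebook's exact code): write the forward pass for a single example and let vmap add the batch dimension.

```python
import jax
import jax.numpy as jnp

def forward_single(params, x):
    # params: list of (w, b) pairs; x: a single example of shape (in_dim,)
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
    w_last, b_last = params[-1]
    return w_last @ x + b_last

# Map over the leading (batch) axis of x while sharing params across the batch.
forward_batched = jax.vmap(forward_single, in_axes=(None, 0))
```

Calling forward_batched(params, xs) with xs of shape (batch_size, in_dim) then returns one prediction per example.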
jax.tree_multimap was deprecated in JAX version 0.3.5 and removed in JAX version 0.3.16. What can we use to replace this function in the "Training an MLP in pure JAX" part of the video?
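If I'm not mistaken, jax.tree_util.tree_map accepts multiple pytrees with matching structure, so it can serve as a drop-in replacement for tree_multimap, e.g. for the SGD-style parameter update (toy values just for illustration):

```python
import jax

params = {"w": 1.0, "b": 0.5}    # toy pytree of parameters
grads  = {"w": 0.2, "b": -0.1}   # matching pytree of gradients
lr = 0.1

# The old jax.tree_multimap(f, params, grads) becomes:
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params)  # {'b': 0.51, 'w': 0.98}
```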
It depends on many factors, I guess we'll have to wait and see. If Google starts pushing JAX more than TF then sure, I'd be much more confident. But even like this I see that JAX is getting more and more love from the research community.
So the parallelism you demo'd with pmap, that was data parallelism, correct? Replicating the whole model across all the devices, sending a different batch to each device, and then collecting the mean model back on the host device after the forward and backward passes? Am I understanding that correctly?
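For reference, data parallelism with pmap looks roughly like the sketch below (hypothetical loss_fn and toy linear model, not the notebook's code): each device gets its own batch shard, and the per-device gradients are averaged with lax.pmean inside the pmapped update, so every replica applies the same step and the parameters stay in sync.

```python
from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model with MSE loss.
    return jnp.mean((x @ params - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    grads = jax.lax.pmean(grads, axis_name="devices")  # average gradients across devices
    return params - 0.1 * grads

n_dev = jax.local_device_count()
params = jnp.zeros(3)
replicated_params = jnp.broadcast_to(params, (n_dev, 3))  # one copy per device
x = jnp.ones((n_dev, 8, 3))  # a different batch of 8 examples for each device
y = jnp.ones((n_dev, 8))
new_params = update(replicated_params, x, y)  # identical across devices
```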
In the middle of the notebook, I saw the comment "# notice how we do jit only at the highest level - XLA will have plenty of space to optimize". Do you have a reference on when to jit only at the highest level and when to jit single nested functions and what the advantages/risks of each approach are? I used to jit every single function until now, so I'm curious what I can gain by a single high-level jit.
To be precise: it is true only if your function has some part that can actually be optimised as a whole. For example, if you define a train function that has a for loop over the number of epochs, you should not jit it! You should rather jit one level lower (here, the update function). The epochs cannot be parallelised (each epoch has to finish before the next one starts), and if you were to jit the train function, the Python loop would be unrolled during tracing and compilation would take far longer. But @TheAIEpiphany explains "why you should not jit everything" way better than me in his first tutorial :)
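A minimal sketch of that structure (toy model, names are made up): jit only the per-step update and keep the epoch loop as a plain Python loop.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@jax.jit  # compile the per-step update once...
def update(w, x, y):
    lr = 0.1
    return w - lr * jax.grad(loss_fn)(w, x, y)

w = jnp.zeros(3)
x, y = jnp.ones((8, 3)), jnp.ones(8)
for epoch in range(100):  # ...and leave the epoch loop un-jitted
    w = update(w, x, y)
```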
8:53 Why do we want to return the state if it doesn't change? (in general, I mean) Is it just good practice, so you don't need to think about whether it may or may not change and can always return it?
But how can you train a model in JAX? If you have to set everything up from scratch, then I don't think it is very useful; I'm sure PyTorch/TF are not far behind in terms of speed.
Check out the 3rd video for building models from scratch in pure JAX. As for frameworks, there are Flax and Haiku; I'll cover them next (they are the JAX equivalents of TF/PyTorch).