Looks great! I tend to default to NumPy when I want to do something that isn't fully supported in Keras or PyTorch, and if I can get parallelization on the GPU very easily from this, that's perfect!
This looks like the kind of problem that could be solved more easily in a language that supports multi-stage programming, with meta-programming as a first-class citizen, which is not really the case with Python. Like Rust, or Elixir via the Nx library, which is actually directly inspired by JAX.
Ok, this is seriously cool. Is this brand new? Haven't seen it before. Also, in the first code sample did you mean to import vmap and pmap instead of map, or is that some kind of namespace black magic I don't understand?
How are you going to compare Torch to TF/JAX when they are run on different accelerators? There is no way you can argue the two chips are comparable; they will be faster or slower at different types of computation regardless of the software used. The three should have been compared on a common GPU if for some reason Torch couldn't be run on the TPUv3.
Why do you need a new lib? TensorFlow can do 90+% of this, can't it? Is it a good idea to make a completely new thing instead of extending the old one? One more question: do you, or will you, have Keras support?
Numerical differentiation approximates f'(x) by evaluating the function around x: (f(x+h) - f(x-h)) / (2h) with a small h. Automatic differentiation represents the function's expression or code as a computational graph; it looks at the actual code of the function. The final derivative is obtained by propagating the values of local derivatives of simple expressions through the graph via the chain rule. The simple expressions are functions like +, -, cos(x), exp(x), for which we know the derivatives at a given x.
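To make that concrete, here is a minimal sketch in JAX; the function f below is just a made-up example, not anything from the article:

```python
import jax
import jax.numpy as jnp

# An arbitrary example function built from primitives (exp, cos, *)
# whose local derivatives the autodiff system already knows.
def f(x):
    return jnp.exp(x) * jnp.cos(x)

# Numerical differentiation: central difference with a small step h.
def numerical_grad(f, x, h=1e-4):
    return (f(x + h) - f(x - h)) / (2 * h)

# Automatic differentiation: JAX traces the code of f into a graph of
# primitive operations and chains their known derivatives together.
auto_grad = jax.grad(f)

x = 1.0
print(numerical_grad(f, x))  # approximation, sensitive to the choice of h
print(auto_grad(x))          # exact up to floating point, no step size needed
```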
I mean, it is kind of niche, but suppose you are solving a problem that relies heavily on many custom functions, e.g., a very specific algebra like quaternion operations. Then you can write super-fast basic operations and compose them into a complicated loss function that, as a whole, you can JIT-compile and let get optimized. Or differentiate it, or vectorize it, all with a tiny decorator. Roughly like the sketch below.
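A hedged sketch of the idea: the Hamilton product is standard, but the toy loss and the names below are made up purely for illustration.

```python
import jax
import jax.numpy as jnp

def quat_mul(q, r):
    # Hamilton product of two quaternions stored as (w, x, y, z).
    w1, x1, y1, z1 = q
    w2, x2, y2, z2 = r
    return jnp.array([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ])

def loss(q, target):
    # A toy loss composed from the basic operation above.
    composed = quat_mul(quat_mul(q, target), q)
    return jnp.sum((composed - target) ** 2)

# JIT-compile the whole composition, differentiate it, and vectorize it,
# each with a single transformation.
fast_loss = jax.jit(loss)
grad_loss = jax.grad(loss)      # gradient with respect to q
batched_loss = jax.vmap(loss)   # map over a leading batch axis

q = jnp.array([1.0, 0.0, 0.0, 0.0])
t = jnp.array([0.0, 1.0, 0.0, 0.0])
print(fast_loss(q, t), grad_loss(q, t))
print(batched_loss(jnp.stack([q, q]), jnp.stack([t, t])))
```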
Torch and Keras are "slow" and are only meant for the development phase. Not sure by how much JAX can outperform them. Edit: "slow" as in computation/inference time.