The numerical computing library JAX has attracted much attention since its release. Supporters believe it is the real deal and praise its fast performance; opponents counter that JAX is still too young and has too many rough edges. The debate recently flared up in a big discussion on Reddit.
Since its release in late 2018, JAX's audience has steadily grown. After DeepMind announced in 2020 that it was using JAX to accelerate its research, more and more code and projects from companies like Google Brain have adopted JAX.
Recently, Ryan O’Connor, a developer education practitioner, published an article examining in detail whether JAX will replace TensorFlow/PyTorch, and who should start using JAX in 2022.
What is JAX?
JAX is not a deep learning framework or library, nor was it designed to become one. Simply put, JAX is a numerical computing library built around composable function transformations; deep learning just happens to be one of the things JAX can do.
JAX sits at the intersection of function transformations and scientific computing, so it can train neural network models, but it can do much more than that.
JAX was originally started by Matt Johnson, Roy Frostig, Dougal Maclaurin, and Chris Leary of the Google Brain team. Using an updated version of Autograd combined with XLA, it can automatically differentiate Python and NumPy code, supporting loops, branches, recursion, and closures, and it can take derivatives of derivatives (e.g., third-order derivatives). Thanks to XLA, JAX can compile and run NumPy programs on GPUs and TPUs. Through grad, it supports both reverse-mode differentiation (backpropagation) and forward-mode differentiation, and the two can be composed arbitrarily in any order.
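As a minimal sketch of these composable transformations (the polynomial f below is an invented example), grad can simply be applied repeatedly to get higher-order derivatives:

```python
import jax

def f(x):
    return x**3 + 2.0 * x**2 - 3.0 * x + 1.0

df  = jax.grad(f)                      # f'(x)   = 3x^2 + 4x - 3
d3f = jax.grad(jax.grad(jax.grad(f)))  # f'''(x) = 6 for all x

print(df(1.0))   # 4.0
print(d3f(1.0))  # 6.0
```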
At present, JAX has earned more than 16,000 stars on GitHub, compared with TensorFlow’s 160,000 and PyTorch’s 54,000, so it still has a long way to go before it catches up with the two giants of deep learning.
Why should you know about JAX?
If a user chooses JAX, there is basically only one reason: speed.
For example, in one benchmark, the NumPy implementation of a function takes about 851 milliseconds, while the same computation in JAX finishes in 5.54 ms, a speedup of more than 150x over NumPy. Plotted as a bar chart, the performance gap is even more striking.
The reason the JAX computation is faster is that it runs on a TPU, while NumPy only uses the CPU. In practice it is not as simple as “use JAX and your program will be 150x faster,” but there are still plenty of reasons to use it. JAX provides a common foundation for scientific computing and serves people in different fields in different ways. Fundamentally, if you work in any field related to scientific computing, you should know about JAX.
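The article’s original benchmark is not reproduced here, but a hypothetical sketch of this kind of timing comparison looks like the following (the matrix-multiply workload and its size are illustrative assumptions; note the block_until_ready() calls, needed because JAX dispatches work asynchronously):

```python
import time
import numpy as np
import jax.numpy as jnp
from jax import jit

n = 4000
x_np = np.random.randn(n, n).astype(np.float32)
x_jx = jnp.asarray(x_np)  # copies the data onto the default device (CPU/GPU/TPU)

f_np = lambda x: (x @ x).sum()
f_jx = jit(lambda x: (x @ x).sum())

_ = f_jx(x_jx).block_until_ready()  # warm-up call triggers XLA compilation

t0 = time.perf_counter(); f_np(x_np); t_np = time.perf_counter() - t0
t0 = time.perf_counter(); f_jx(x_jx).block_until_ready(); t_jx = time.perf_counter() - t0
print(f"NumPy: {t_np * 1e3:.1f} ms | JAX: {t_jx * 1e3:.1f} ms")
```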
The author lists 6 reasons why you should use JAX:
1. Accelerated NumPy. NumPy is one of the fundamental packages for scientific computing in Python, but it only runs on CPUs. JAX provides a NumPy implementation (with a near-identical API) that runs easily on GPUs and TPUs. For many users, this alone is enough to justify using JAX.
2. XLA. XLA (Accelerated Linear Algebra) is a whole-program optimizing compiler designed specifically for linear algebra. JAX is built on top of XLA, which greatly raises the ceiling on computing speed.
3. JIT. JAX lets users transform a function into a JIT (just-in-time) compiled version using XLA. This means users can speed up a computation by adding a simple function decorator, sometimes gaining several orders of magnitude in performance (see the sketch after this list).
4. Automatic differentiation. The JAX documentation describes JAX as the combination of Autograd and XLA. The ability to differentiate automatically is critical in many areas of scientific computing, and JAX provides several powerful automatic differentiation tools.
5. Deep learning. While JAX itself is not a deep learning framework, it certainly provides a more than adequate foundation for deep learning. Many deep learning libraries are built on top of JAX, such as Flax, Haiku, and Elegy. Even some researchers in the PyTorch vs TensorFlow debate have highlighted JAX as a “framework” worth watching and recommended it for TPU-based deep learning research. JAX’s efficient computation of Hessians is also relevant to deep learning, since it makes higher-order optimization techniques much more practical.
6. A general differentiable programming paradigm. While JAX can be used to build and train deep learning models, it also provides a framework for general-purpose differentiable programming. This means JAX can tackle real-world problems using a model-based machine learning approach.
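To make points 1, 3, and 4 concrete, here is a minimal sketch (the predict and loss functions and all shapes are invented for illustration): jax.numpy stands in for NumPy, grad differentiates the loss, and jit compiles the whole gradient computation with XLA:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(x @ w)  # jax.numpy mirrors the NumPy API (point 1)

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

# grad differentiates loss w.r.t. its first argument (point 4);
# jit compiles the whole gradient computation via XLA (point 3).
grad_loss = jax.jit(jax.grad(loss))

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (3,))
x = jax.random.normal(key, (8, 3))
y = jnp.ones(8)
print(grad_loss(w, x, y))  # compiled on the first call, cached afterwards
```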
In 2022, should I learn JAX?
As with most thorny questions, the answer remains: it depends.
If you are interested in JAX for general scientific computing, the first question you should ask yourself is whether you just want to speed up NumPy.
If your answer is “yes”, then you should be using JAX yesterday.
If you’re not just crunching numbers but working on dynamic computational modeling, then whether you should use JAX depends on your use case. If most of your work is done in Python with a lot of custom code, it’s worth learning JAX to improve your workflow.
Or if most of your work is not in Python, but you want to build some sort of hybrid model/neural network based system, then it might be worthwhile to use JAX.
If most of your work is not in Python, or you use specialized software for your research (thermodynamics, semiconductors, etc.), then JAX is probably not the right tool for you, unless you want to export data from those programs for some kind of custom computation.
If your area of interest is closer to physics/mathematics and involves computational methods (dynamical systems, differential geometry, statistical physics), or if most of your work is done in, say, Mathematica, then sticking with what you are already using might be worthwhile, especially if you have a large custom codebase.
For deep learning practitioners: if you are new to deep learning and are interested in it for your own education, then I suggest you use JAX or PyTorch.
If you want to learn deep learning top-down and/or have some Python/software experience, then I suggest you start with PyTorch.
If you want to learn deep learning from the bottom up and/or come from a math background, you might find JAX intuitive and should give it a try. In this case, make sure you understand how to use JAX before going into any big project.
If you are interested in deep learning with a potential career change in mind, then you are better off using PyTorch or TensorFlow. While it’s best to be familiar with both frameworks, you should be aware that TensorFlow is considered the “industry” framework, while PyTorch is more “academic.”
If you are a complete beginner with no math or software background but want to learn deep learning, then you won’t want to use JAX; starting with Keras is a better choice.
When not to use JAX
While JAX has the potential to greatly improve the performance of your program, there are several situations in which it is not appropriate to use JAX.
1. JAX is still officially considered an experimental framework, not a full-fledged Google product, so if you’re considering moving to JAX, you need to think carefully.
2. Debugging JAX code can be more time-consuming, and you may run into bugs that have not yet been discovered. For anyone without a solid grasp of functional programming, using JAX may not be worth it. Before starting a serious project in JAX, make sure you understand its common pitfalls, two of which are sketched below.
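For illustration, here is a minimal sketch of two widely known pitfalls: JAX arrays are immutable, and Python side effects inside a jit-compiled function run only while the function is being traced:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0            # would raise an error: JAX arrays are immutable
x = x.at[0].set(1.0)    # use a functional update instead

@jax.jit
def f(a):
    print("tracing!")   # a side effect: executes only during tracing
    return a + 1

f(jnp.ones(2))  # prints "tracing!" on the first (compiling) call
f(jnp.ones(2))  # silent: the cached compiled version is reused
```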
3. JAX is not optimized for CPU computing. Because JAX was developed “accelerator-first,” the dispatch of individual operations is not fully optimized. For this reason, NumPy can actually be faster than JAX in some cases, especially for small programs.
4. JAX is not compatible with Windows; there is currently no official support for running JAX on Windows.