JAX-based automatic differentiation: Introducing modern statistical modeling to Stingray

I assume everyone reading this is already aware of the two classical forms of differentiation, namely symbolic differentiation and numerical differentiation by finite differences. Symbolic differentiation operates on expanded mathematical expressions, which can swell in size and lead to inefficient code, while finite differences suffer from both truncation and round-off errors. Efficient calculation of derivatives is crucial when training neural networks or doing mathematical modeling with Bayesian inference. Both classical methods are slow at computing partial derivatives of a function with respect to many inputs, as is needed for gradient-based optimization algorithms. Here, automatic differentiation comes to the rescue. Automatic differentiation is a powerful tool for automating the calculation of derivatives and is preferable to the more traditional methods, especially when differentiating complex algorithms and mathematical functions.
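To make the error argument concrete, here is a minimal sketch that compares a central finite difference with jax.grad. The test function f, the step sizes, and the helper names are my own choices for illustration and do not come from Stingray.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical test function, chosen only for illustration.
    return jnp.sin(x) * jnp.exp(-x)


def exact_derivative(x):
    # Hand-derived: d/dx [sin(x) exp(-x)] = (cos(x) - sin(x)) exp(-x)
    return (jnp.cos(x) - jnp.sin(x)) * jnp.exp(-x)


def central_difference(func, x, h):
    # Classical numerical derivative with step size h.
    return (func(x + h) - func(x - h)) / (2.0 * h)


x0 = 1.0
exact_value = exact_derivative(x0)
autodiff_value = jax.grad(f)(x0)

# A large step makes the truncation error visible.
fd_coarse = central_difference(f, x0, 1e-1)
# A tiny step lets round-off dominate (JAX uses float32 by default).
fd_tiny = central_difference(f, x0, 1e-7)

print("autodiff error:       ", abs(autodiff_value - exact_value))
print("finite diff (h=1e-1): ", abs(fd_coarse - exact_value))
print("finite diff (h=1e-7): ", abs(fd_tiny - exact_value))
```

The finite-difference result depends on how well the step size is tuned, while the automatic-differentiation result needs no step size at all.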

Photo Source: Wikipedia

Stingray is an astrophysical spectral timing package, a Python library built to perform time series analysis and related tasks on astronomical light curves. JAX is a Python library designed for high-performance numerical computing. Its API for numerical functions is based on NumPy, a collection of functions used in scientific computing. Both Python and NumPy are widely used and familiar, making JAX simple, flexible, and easy to adopt. It can differentiate through a large subset of Python’s features, including loops, branches, recursion, and closures, and it can even take derivatives of derivatives. Modern differentiation packages like this deploy a broad range of computational techniques to improve applicability, run time, and memory management.
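As a small illustration of differentiating through ordinary Python code, the sketch below takes gradients through a loop and an if statement, and then takes a derivative of a derivative. The functions power_by_loop and piecewise are hypothetical examples of mine; note that a plain Python if on the input value like this works under jax.grad, but would need extra care under jit.

```python
import jax


def power_by_loop(x, n=3):
    # Computes x**n with a plain Python loop.
    result = 1.0
    for _ in range(n):
        result = result * x
    return result


def piecewise(x):
    # Plain Python branch on the input value.
    if x > 0:
        return x ** 2
    return -x


# First derivative of x**3 is 3 x**2; grad differentiates w.r.t. the first argument.
d_power = jax.grad(power_by_loop)
# Derivative of the derivative: 6 x.
d2_power = jax.grad(d_power)

print(d_power(3.0))               # 3 * 3**2 = 27
print(d2_power(3.0))              # 6 * 3 = 18
print(jax.grad(piecewise)(3.0))   # slope of x**2 at 3 is 6
print(jax.grad(piecewise)(-3.0))  # slope of -x is -1
```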

Just like Autograd, JAX uses the grad function transformation to convert a function into a new function that returns the original function’s gradient. Beyond that, JAX offers the function transformation jit for just-in-time compilation of existing functions, and vmap and pmap for vectorization and parallelization, respectively.
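Here is a minimal sketch of how grad, jit, and vmap compose. The straight-line model, the loss function, and the toy data are assumptions made up for this example and are not Stingray code. pmap is left out because it needs several devices to be meaningful.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Hypothetical model: a straight line fitted with mean squared error.
    slope, intercept = params
    prediction = slope * x + intercept
    return jnp.mean((prediction - y) ** 2)


# grad: gradient of the loss with respect to its first argument (params).
grad_loss = jax.grad(loss)
# jit: compile the gradient function on first call.
fast_grad_loss = jax.jit(grad_loss)

# Toy data generated from y = 2 x + 1.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0
params = (0.0, 0.0)

grads = fast_grad_loss(params, x, y)
print("d loss / d slope:    ", grads[0])
print("d loss / d intercept:", grads[1])

# vmap: evaluate the loss separately for every data point,
# sharing params (None) and mapping over axis 0 of x and y.
per_point_loss = jax.vmap(loss, in_axes=(None, 0, 0))(params, x, y)
print("per-point loss:", per_point_loss)
```

The transformations are ordinary functions that return functions, so they can be stacked in any order that makes sense for the problem.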

Mini-MLP (multilayer perceptron) execution time for 10,000 updates with a batch size of 1024. Source: AI Zone

As seen in the above figure, PyTorch was much more effective than TensorFlow in terms of execution speed when it came to implementing fully connected neural layers. For low-level implementations, on the other hand, JAX offered impressive speed-ups of an order of magnitude or more over the comparable Autograd library. In this benchmark, JAX was faster than the other libraries when the MLP implementation was limited to matrix multiplication operations.

How do we decide the ideal library to go with?

Our choice will depend first on the history of the project we start working on: if the code already uses PyTorch, we will most probably end up writing our code in PyTorch. For general differentiable programming with low-level implementations of abstract mathematical concepts, JAX offers substantial advantages in speed and scale over Autograd while retaining much of Autograd’s simplicity and flexibility, and it also offers surprisingly competitive performance against PyTorch and TensorFlow.

I am introducing modern statistical modeling to the Stingray software as part of my GSoC project. Reference to Stingray source code: https://github.com/StingraySoftware. Reference to JAX-based automatic differentiation: https://jax.readthedocs.io/en/latest/jax-101/index.html.