In my last blog post, I wrote an introduction to JAX and automatic differentiation. In this one, I describe my plan for the next stage of the implementation. I am currently re-designing the modeling notebook (https://github.com/StingraySoftware/notebooks/blob/main/Modeling/ModelingExamples.ipynb) with JAX, mainly to make the optimization more robust by letting JAX compute the gradients of the likelihood function.
My mentor Daniela pointed out that the current NumPy-based implementation is not robust. The plan is to keep working on the modeling notebook, replacing NumPy with jax.numpy and using JAX's grad, jit, vmap and random functionality.
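The sketch below shows one way these four pieces might fit together around a likelihood. It assumes a hypothetical light-curve model (constant background plus a Gaussian bump) and a Poisson likelihood; the model, parameter names and data are illustrative only. `jax.numpy` stands in for NumPy, `jax.random` uses explicit keys instead of global state, `grad` gives the gradient, `vmap` evaluates the likelihood for a whole batch of parameter vectors (for example, many sampler walkers), and `jit` compiles both.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def gaussian_bump(t, params):
    """Hypothetical model: constant background plus a Gaussian bump.
    params = (log background, log amplitude, centre, log width)."""
    log_bkg, log_amp, t0, log_width = params
    width = jnp.exp(log_width)
    return jnp.exp(log_bkg) + jnp.exp(log_amp) * jnp.exp(-0.5 * ((t - t0) / width) ** 2)

def poisson_loglike(params, t, counts):
    """Poisson log-likelihood of observed counts given the model rate."""
    mu = gaussian_bump(t, params)
    return jnp.sum(counts * jnp.log(mu) - mu - gammaln(counts + 1.0))

# Explicit PRNG keys: split once, use each subkey exactly once
key = jax.random.PRNGKey(42)
key, k_counts, k_batch = jax.random.split(key, 3)

t = jnp.linspace(0.0, 10.0, 500)
true_params = jnp.array([jnp.log(5.0), jnp.log(20.0), 4.0, jnp.log(0.5)])
counts = jax.random.poisson(k_counts, gaussian_bump(t, true_params)).astype(jnp.float32)

# grad: exact gradient of the log-likelihood with respect to params
loglike_grad = jax.jit(jax.grad(poisson_loglike))

# vmap: map over the first axis of params only; t and counts are shared
batched_loglike = jax.jit(jax.vmap(poisson_loglike, in_axes=(0, None, None)))

param_batch = true_params + 0.01 * jax.random.normal(k_batch, (64, 4))

print(loglike_grad(true_params, t, counts))
print(batched_loglike(param_batch, t, counts).shape)  # (64,)
```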
A re-design starts with understanding the current design, along with the possible drawbacks and issues of the packages it relies on, so I am trying those out first. One such challenge is importing emcee into a Jupyter notebook for sampling. Even though I installed the dependency in the current virtual environment and imported emcee in the notebook, it still behaves oddly and shows the error: emcee not installed! Can’t sample! It looks like a clash of dependencies.
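One general way to narrow down this kind of error, sketched below under the assumption that it is an environment problem rather than a bug in the modeling code, is to check which interpreter the notebook kernel is actually running and what happens when emcee is imported there. The helper `diagnose_import` is hypothetical; it reports the real exception, which matters because a failing dependency of emcee can surface as "not installed" when only the import failure is checked.

```python
import importlib
import sys

def diagnose_import(name: str) -> str:
    """Hypothetical helper: try to import `name` and report its origin or the failure."""
    try:
        module = importlib.import_module(name)
    except Exception as err:  # a broken dependency can raise more than ImportError
        return f"{name} failed to import: {type(err).__name__}: {err}"
    version = getattr(module, "__version__", "unknown version")
    location = getattr(module, "__file__", "built-in")
    return f"{name} {version} loaded from {location}"

# The Python interpreter this notebook kernel is actually using
print("Kernel interpreter:", sys.executable)
print(diagnose_import("emcee"))
```

If `sys.executable` does not point into the virtual environment, the kernel is running a different Python. Running `%pip install emcee` in a notebook cell installs into the kernel's own environment; restarting the kernel afterwards makes sure modules that checked for emcee at import time see it.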
For now, the plan is to fix each bug I run into along the way and then work out how all the pieces connect. After that, the next step is to write up a report on optimization with JAX. Stay tuned for more on how JAX can accelerate and augment the current modeling framework.
If you want to understand JAX's functionality better and follow along with my study, I recommend this video (click here).