Developing software tools for accelerated and differentiable scientific computing using JAX
- Track: HPC, Big Data & Data Science
- Room: H.1308 (Rolin)
- Day: Sunday
- Start: 13:20
- End: 13:30
- Video only: h1308
JAX is an open-source Python package for high-performance numerical computing. It provides a familiar NumPy-style interface, with the added advantages of allowing computations to be dispatched to accelerator devices such as graphics processing units (GPUs) and tensor processing units (TPUs), and of supporting transformations that automatically differentiate, vectorize and just-in-time compile functions. While JAX is used extensively in machine learning applications, its design also makes it well suited to scientific computing tasks such as simulating numerical models and fitting them to data.
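As a rough illustration of the transformations mentioned above, the sketch below composes `jax.grad`, `jax.vmap` and `jax.jit` on a least-squares loss for a toy exponential-decay model. The model, loss, data and parameter values are hypothetical choices made for illustration only; they are not taken from the talk.

```python
import jax
import jax.numpy as jnp


def model(params, t):
    # Hypothetical toy model: exponential decay y = a * exp(-k * t).
    a, k = params[0], params[1]
    return a * jnp.exp(-k * t)


def loss(params, t, y):
    # Mean squared error between model predictions and observations.
    return jnp.mean((model(params, t) - y) ** 2)


# Gradient of the loss with respect to its first argument (params),
# traced and compiled with XLA on the first call.
grad_loss = jax.jit(jax.grad(loss))

# Loss for a batch of parameter vectors (mapped over the leading axis of
# params), with the same t and y shared across the whole batch.
batched_loss = jax.jit(jax.vmap(loss, in_axes=(0, None, None)))

# Synthetic noise-free data generated from known parameters.
t = jnp.linspace(0.0, 5.0, 50)
true_params = jnp.array([2.0, 0.5])
y = model(true_params, t)

print(grad_loss(jnp.array([1.0, 0.5]), t, y))
print(batched_loss(jnp.array([[2.0, 0.5], [1.0, 0.5], [2.0, 1.0]]), t, y))
```

Because `loss` is written as an ordinary NumPy-style function, the same definition can be differentiated, batched and compiled without modification, and it runs on whichever device (CPU, GPU or TPU) the installed JAX build targets.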
This talk will introduce JAX's interface and computation model, and discuss my experiences in developing two open-source software tools that exploit JAX as a key dependency: S2FFT, a Python package providing Fourier-like transforms for spherical data, and Mici, a Python package implementing algorithms for fitting probabilistic models to data. I will also introduce the Python Array API standard and explain how it can be used to write portable code that works across JAX, NumPy and other array backends.
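A minimal sketch of the backend-agnostic style the Array API standard enables: the function below receives an array namespace `xp` and calls only functions the standard defines, so it works with both NumPy and `jax.numpy`. Passing `xp` explicitly is an assumption made to keep the example dependency-free; code often instead looks the namespace up from its input arrays, for example with `array_namespace` from the separate `array_api_compat` package. The `logsumexp` function is a hypothetical example, not code from S2FFT or Mici.

```python
import numpy as np
import jax.numpy as jnp


def logsumexp(x, xp):
    """Numerically stable log(sum(exp(x))) over the last axis of x.

    xp is an array namespace such as numpy or jax.numpy. Only functions
    defined in the Python Array API standard (max, exp, sum, log, squeeze)
    are used, so the same code runs with either backend.
    """
    # Subtract the per-row maximum before exponentiating to avoid overflow.
    x_max = xp.max(x, axis=-1, keepdims=True)
    shifted_sum = xp.sum(xp.exp(x - x_max), axis=-1)
    return xp.log(shifted_sum) + xp.squeeze(x_max, axis=-1)


x = np.array([[0.0, 1.0, 2.0], [1000.0, 1000.0, 1000.0]])
print(logsumexp(x, np))                # NumPy backend
print(logsumexp(jnp.asarray(x), jnp))  # JAX backend
```

For the second row the result is 1000 + log(3) rather than an overflow, since every exponential is taken of a shifted value that is at most zero. With JAX's default settings the JAX result is computed in single precision.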
Speakers
| Matt Graham |