Developing software tools for accelerated and differentiable scientific computing using JAX
- Track: HPC, Big Data & Data Science
- Room: H.1308 (Rolin)
- Day: Sunday
- Start (UTC+1): 13:20
- End (UTC+1): 13:30
- Room livestream: h1308
JAX is an open-source Python package for high-performance numerical computing. It provides a familiar NumPy-style interface and adds two main capabilities: computations can be dispatched to accelerator devices such as graphics and tensor processing units (GPUs and TPUs), and functions can be transformed to automatically differentiate, vectorize and just-in-time compile them. Although JAX is used extensively in machine learning, its design also makes it well suited to scientific computing tasks such as simulating numerical models and fitting them to data.
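As a rough illustration of the three transformations named above, the following minimal sketch applies `jax.grad`, `jax.jit` and `jax.vmap` to a toy least-squares problem. The `loss` function, the linear model and the data are hypothetical and chosen only for illustration; they are not taken from the talk.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    """Mean squared error of a hypothetical linear model y ~ a * x + b."""
    a, b = params
    return jnp.mean((a * x + b - y) ** 2)


# Automatic differentiation: gradient with respect to the first argument (params).
grad_loss = jax.grad(loss)

# Just-in-time compilation of the gradient function.
fast_grad_loss = jax.jit(grad_loss)

# Synthetic data generated from a = 2, b = 1 (illustrative values only).
x = jnp.linspace(0.0, 1.0, 5)
y = 2.0 * x + 1.0

# At the true parameters the residuals vanish, so the gradient is close to zero.
params = (jnp.array(2.0), jnp.array(1.0))
g = fast_grad_loss(params, x, y)

# Vectorization: evaluate the loss for a batch of candidate slopes at once.
slopes = jnp.array([0.0, 1.0, 2.0])
batched_loss = jax.vmap(lambda a: loss((a, 1.0), x, y))
losses = batched_loss(slopes)
```

The transformations compose: here `jax.jit` wraps the output of `jax.grad`, and `jax.vmap` could equally be applied to the gradient function to obtain gradients for a batch of parameter values.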
This lightning talk will introduce JAX's interface and computation model, along with some of its key function transformations. I will also briefly introduce the Python Array API standard and explain how it can be used to write portable code that works across JAX, NumPy and other array backends.
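One way portable code of this kind might look is sketched below. The helpers `get_namespace` and `standardize` are hypothetical names introduced for illustration. The sketch assumes the array object implements the standard's `__array_namespace__` method, as NumPy arrays do from NumPy 2.0 and as JAX arrays are expected to in recent JAX releases; falling back to NumPy otherwise is an assumption of this example, not part of the standard.

```python
import numpy as np


def get_namespace(x):
    """Return the Array API namespace for x (hypothetical helper).

    Falls back to NumPy when x does not implement __array_namespace__;
    this fallback is an assumption made for the example.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def standardize(x):
    """Shift and scale x to zero mean and unit standard deviation."""
    xp = get_namespace(x)
    # mean and std are both part of the Array API standard.
    return (x - xp.mean(x)) / xp.std(x)


# NumPy input is shown here; passing an array from another compliant backend
# should dispatch to that backend's namespace instead.
z = standardize(np.asarray([1.0, 2.0, 3.0, 4.0]))
```

Because `standardize` only calls functions looked up on `xp`, it contains no backend-specific imports beyond the fallback.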
Speakers
- Matt Graham