⚠️ jaxify is an experimental project under development
You're welcome to try it out and report any issues!
jaxify lets you apply JAX transformations (like @jax.jit and/or @jax.vmap) to functions that use common Python constructs JAX cannot handle on its own, such as if statements whose conditions depend on input values.
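To illustrate the limitation jaxify targets, here is a minimal sketch in plain JAX (jaxify is not involved). A Python `if` on a traced value works when called eagerly but fails under `jax.jit`, because at trace time `x` is an abstract tracer with no concrete boolean value. The exact exception subclass may differ between JAX versions; this sketch assumes a recent release where `jax.errors.ConcretizationTypeError` is the common base class.

```python
import jax
import jax.numpy as jnp


def absolute_value(x):
    # Plain Python branching on the *value* of x.
    if x >= 0:
        return x
    else:
        return -x


# Eager call: x is a concrete array, so bool(x >= 0) is well defined.
eager_result = absolute_value(jnp.float32(-3.0))

# Under jit, x is an abstract tracer, so `if x >= 0` cannot be decided.
try:
    jax.jit(absolute_value)(jnp.float32(-3.0))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

print(eager_result, jit_failed)
```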
Installation
pip install jaxify
Getting started
```python
import jax
import jax.numpy as jnp
from jaxify import jaxify


@jax.jit
@jax.vmap
@jaxify  # <-- Just decorate your function with @jaxify
def absolute_value(x):
    if x >= 0:  # <-- If block in a JIT-compiled function
        return x
    else:
        return -x


xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)
```
How it works
The @jaxify decorator uses static analysis to rewrite Python functions, replacing unsupported Python constructs with JAX-compatible alternatives. After this rewrite, the functions are traceable by JAX, so you can apply functional JAX transformations like @jax.jit and @jax.vmap to them seamlessly.
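For intuition, the sketch below writes by hand a JAX-traceable version of `absolute_value` from the example above, replacing the Python `if` with `jax.lax.cond`. It only approximates the kind of rewrite described here; it is an assumption for illustration and not the literal code that @jaxify generates.

```python
import jax
import jax.numpy as jnp


# Hand-written equivalent of the if/else in absolute_value.
# jax.lax.cond(pred, true_fun, false_fun, *operands) picks a branch
# based on a traced predicate, so the function stays traceable.
def absolute_value(x):
    return jax.lax.cond(
        x >= 0,
        lambda v: v,   # "if" branch
        lambda v: -v,  # "else" branch
        x,
    )


xs = jnp.arange(-3, 4)
ys = jax.jit(jax.vmap(absolute_value))(xs)
print(ys)  # [3 2 1 0 1 2 3]
```

Note that under `jax.vmap`, a `lax.cond` whose predicate is batched is lowered to a select, so both branches are evaluated for every element; for cheap branches like these, that costs little.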
Compatibility status
The following Python constructs are currently supported within @jaxify-decorated functions:
Conditionals
| Construct | Works? | Notes |
|---|---|---|
| if statements | ✅ | Fully supported, including elif and else clauses. Translated to calls to jax.lax.cond. |
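One plausible way an if/elif/else chain maps onto jax.lax.cond is by nesting: the elif/else part becomes the false branch of the outer cond. The sketch below assumes that nesting strategy and writes it by hand for a hypothetical `sign` function; it is not taken from jaxify's source and may differ from what the decorator actually emits.

```python
import jax
import jax.numpy as jnp

# Source-level function with an if/elif/else chain:
#
#     def sign(x):
#         if x > 0:
#             return 1.0
#         elif x < 0:
#             return -1.0
#         else:
#             return 0.0
#
# Hand-written nested-cond version: the elif/else part lives inside
# the false branch of the outer cond. All branches return arrays with
# the same shape and dtype as the input, as lax.cond requires.
def sign(x):
    return jax.lax.cond(
        x > 0,
        lambda v: jnp.ones_like(v),        # if x > 0
        lambda v: jax.lax.cond(
            v < 0,
            lambda w: -jnp.ones_like(w),   # elif x < 0
            lambda w: jnp.zeros_like(w),   # else
            v,
        ),
        x,
    )


xs = jnp.array([-2.0, 0.0, 5.0], dtype=jnp.float32)
signs = jax.jit(jax.vmap(sign))(xs)
print(signs)
```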