# Patch summary (commentary only; everything from the first "diff --git" header onward is unchanged):
#   Files touched: openequivariance/openequivariance/jax/jvp/conv_prim.py and
#   openequivariance/openequivariance/jax/jvp/tp_prim.py.
#   For each of the four primitives conv_fwd_jvp_p, conv_bwd_jvp_p, tp_fwd_jvp_p and
#   tp_bwd_jvp_p, two mlir.register_lowering(...) calls are added directly after the
#   primitive's def_impl / def_abstract_eval registration: one with platform="cuda"
#   and one with platform="rocm".
#   Each lowering is built with mlir.lower_fun(<primitive>_impl, ...), i.e. from the same
#   function that is passed to def_impl for that primitive.
#   The forward primitives (conv_fwd_jvp_p, tp_fwd_jvp_p) pass multiple_results=False;
#   the backward primitives (conv_bwd_jvp_p, tp_bwd_jvp_p) pass multiple_results=True.
#   The last tp_prim.py hunk (-225,7 +235,16) additionally contains one "-" line with no
#   visible content, which looks like the removal of a blank line -- TODO confirm.
# NOTE(review): the multiple_results flags presumably mirror whether each *_impl returns a
#   single array or a sequence of arrays; the impl bodies are not visible here -- confirm.
# NOTE(review): the line breaks of this patch appear to have been collapsed, so blank
#   context lines and the hunk line counts cannot be checked from this view -- verify
#   against the original patch before applying.
diff --git a/openequivariance/openequivariance/jax/jvp/conv_prim.py b/openequivariance/openequivariance/jax/jvp/conv_prim.py index ca91014..8d38646 100644 --- a/openequivariance/openequivariance/jax/jvp/conv_prim.py +++ b/openequivariance/openequivariance/jax/jvp/conv_prim.py @@ -161,6 +161,16 @@ def conv_fwd_jvp_abstract_eval( conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl) conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval) +mlir.register_lowering( + conv_fwd_jvp_p, + mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), + platform="cuda", +) +mlir.register_lowering( + conv_fwd_jvp_p, + mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), + platform="rocm", +) # ============================================================================== @@ -285,6 +295,16 @@ def conv_bwd_jvp_abstract_eval( conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl) conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval) +mlir.register_lowering( + conv_bwd_jvp_p, + mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), + platform="cuda", +) +mlir.register_lowering( + conv_bwd_jvp_p, + mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), + platform="rocm", +) # ============================================================================== diff --git a/openequivariance/openequivariance/jax/jvp/tp_prim.py b/openequivariance/openequivariance/jax/jvp/tp_prim.py index c31c3ec..3745f24 100644 --- a/openequivariance/openequivariance/jax/jvp/tp_prim.py +++ b/openequivariance/openequivariance/jax/jvp/tp_prim.py @@ -132,6 +132,16 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash): tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl) tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval) +mlir.register_lowering( + tp_fwd_jvp_p, + mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False), + platform="cuda", +) +mlir.register_lowering( + tp_fwd_jvp_p, + mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False), + platform="rocm", +) # 
============================================================================== @@ -225,7 +235,16 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash): tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl) tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval) - +mlir.register_lowering( + tp_bwd_jvp_p, + mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True), + platform="cuda", +) +mlir.register_lowering( + tp_bwd_jvp_p, + mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True), + platform="rocm", +) # ============================================================================== # 9. Transpose Rule for Backward JVP