[Newton] Adds Warp based inhand_manipulation env #4413
kellyguo11 merged 9 commits into isaac-sim:dev/newton
Conversation
1. Consecutive successes: 10.67
2. Uses the same MJWarpSolver as the Torch env
Attaching performance data
Performance summary (torch → warp)
Δ% is computed as $(\text{warp} - \text{torch}) / \text{torch} \times 100\%$, so a negative time Δ% means less time (better).
Greptile Overview
Greptile Summary
This PR implements a Warp-accelerated in-hand manipulation environment for the Allegro Hand robot, enabling high-performance parallel simulation across 8192 environments using GPU kernels.
Key changes:
Architecture:
Minor issues found:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant Gym
participant InHandManipulationWarpEnv
participant DirectRLEnvWarp
participant WarpKernels
participant Hand as Articulation (Hand)
participant Object as Articulation (Object)
User->>Gym: register environment
Gym->>InHandManipulationWarpEnv: create with AllegroHandWarpEnvCfg
InHandManipulationWarpEnv->>DirectRLEnvWarp: __init__(cfg)
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _setup_scene()
InHandManipulationWarpEnv->>Hand: create Articulation(robot_cfg)
InHandManipulationWarpEnv->>Object: create Articulation(object_cfg)
InHandManipulationWarpEnv->>WarpKernels: initialize_rng_state()
InHandManipulationWarpEnv->>WarpKernels: initialize_goal_constants()
InHandManipulationWarpEnv->>WarpKernels: initialize_xyz_unit_vecs()
User->>InHandManipulationWarpEnv: step(actions)
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _pre_physics_step(actions)
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _apply_action()
InHandManipulationWarpEnv->>WarpKernels: apply_actions_to_targets()
InHandManipulationWarpEnv->>Hand: set_joint_position_target()
InHandManipulationWarpEnv->>DirectRLEnvWarp: simulate physics
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _get_dones()
InHandManipulationWarpEnv->>WarpKernels: compute_intermediate_values()
InHandManipulationWarpEnv->>WarpKernels: get_dones()
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _get_observations()
InHandManipulationWarpEnv->>WarpKernels: compute_full_observations()
InHandManipulationWarpEnv->>WarpKernels: sanitize_and_print_once()
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _get_rewards()
InHandManipulationWarpEnv->>WarpKernels: compute_rewards()
InHandManipulationWarpEnv->>WarpKernels: update_consecutive_successes_from_stats()
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _reset_target_pose()
InHandManipulationWarpEnv->>WarpKernels: reset_target_pose()
alt Reset Required
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _reset_idx(mask)
InHandManipulationWarpEnv->>WarpKernels: reset_object()
InHandManipulationWarpEnv->>Object: update root_link_pose_w
InHandManipulationWarpEnv->>WarpKernels: reset_hand()
InHandManipulationWarpEnv->>Hand: update joint_pos/joint_vel
InHandManipulationWarpEnv->>WarpKernels: reset_successes()
end
InHandManipulationWarpEnv-->>User: obs, reward, done, info

    def rotation_distance(object_rot: wp.quatf, target_rot: wp.quatf) -> wp.float32:
        # Orientation alignment for the cube in hand and goal cube
        quat_diff = quat_mul(object_rot, quat_conjugate(target_rot))
        # Match Torch env convention: uses indices [1:4] for the vector part (see `rotation_distance` in Torch env).
        v_norm = wp.sqrt(quat_diff[1] * quat_diff[1] + quat_diff[2] * quat_diff[2] + quat_diff[3] * quat_diff[3])
        v_norm = wp.min(v_norm, wp.float32(1.0))
        return wp.float32(2.0) * wp.asin(v_norm)

Vector part indexing in comment inconsistent with implementation
The comment says "uses indices [1:4]" but the implementation correctly uses indices 1, 2, 3 (which is `[1:4]` in Python slicing). The comment should clarify that these are the xyz components (indices 1, 2, 3), not including the w component (index 3 in some conventions).
Suggested change:

    # Orientation alignment for the cube in hand and goal cube
    quat_diff = quat_mul(object_rot, quat_conjugate(target_rot))
    # Match Torch env convention: uses xyz components (indices 1, 2, 3) for the vector part (see `rotation_distance` in Torch env).
    v_norm = wp.sqrt(quat_diff[1] * quat_diff[1] + quat_diff[2] * quat_diff[2] + quat_diff[3] * quat_diff[3])
    v_norm = wp.min(v_norm, wp.float32(1.0))
    return wp.float32(2.0) * wp.asin(v_norm)

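Since the thread turns on which indices hold the vector part, a small plain-Python sketch may help. It is not the PR's kernel: the helper names and the two slice constants are made up for illustration, and the data is assumed to be stored as (x, y, z, w), as the `quat_mul` helper quoted later in this review states. Under that assumption the vector part sits at indices 0..2, while indices 1..3 would be the vector part only for (w, x, y, z) storage.

```python
import math

def quat_mul_xyzw(q1, q2):
    """Hamilton product for quaternions stored as (x, y, z, w)."""
    x1, y1, z1, w1 = q1
    x2, y2, z2, w2 = q2
    return (
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    )

def quat_conjugate_xyzw(q):
    """Conjugate of an (x, y, z, w) quaternion."""
    x, y, z, w = q
    return (-x, -y, -z, w)

def rotation_distance(q_obj, q_goal, vector_slice):
    """Angle 2 * asin(|v|), where v is the slice treated as the vector part."""
    diff = quat_mul_xyzw(q_obj, quat_conjugate_xyzw(q_goal))
    v_norm = math.sqrt(sum(c * c for c in diff[vector_slice]))
    return 2.0 * math.asin(min(v_norm, 1.0))

# Hypothetical names: where the vector part lives for each storage layout.
XYZW_VECTOR = slice(0, 3)  # (x, y, z, w) storage
WXYZ_VECTOR = slice(1, 4)  # (w, x, y, z) storage

identity = (0.0, 0.0, 0.0, 1.0)
quarter_turn_z = (0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4))

# Correct slice for xyzw data: a 90 degree turn about z measures as pi / 2.
angle_ok = rotation_distance(quarter_turn_z, identity, XYZW_VECTOR)
# Mismatched slice on the same data: identical orientations measure as pi.
angle_mismatch = rotation_distance(identity, identity, WXYZ_VECTOR)
```

If the arrays reaching the kernel really are (x, y, z, w), it may be worth double-checking that indices 1, 2, 3 are intended; this page alone does not settle which layout the Torch env uses.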

    # unit vectors
    self.x_unit_vecs = wp.zeros(self.num_envs, dtype=wp.vec3f, device=self.device)
    self.y_unit_vecs = wp.zeros(self.num_envs, dtype=wp.vec3f, device=self.device)
    self.z_unit_vecs = wp.zeros(self.num_envs, dtype=wp.vec3f, device=self.device)

z_unit_vecs initialized but never used
The z-axis unit vector is initialized but never referenced in any kernel or method. Consider removing it or adding a comment explaining why it's reserved for future use.
AntoineRichard left a comment
A couple of nits around the warp code. Otherwise it looks good to me.

    @wp.func
    def quat_mul(q1: wp.quatf, q2: wp.quatf) -> wp.quatf:
        # Hamilton product for quaternions in (x, y, z, w).
        x1, y1, z1, w1 = q1[0], q1[1], q1[2], q1[3]
        x2, y2, z2, w2 = q2[0], q2[1], q2[2], q2[3]
        x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
        y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
        z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
        w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
        return wp.quatf(x, y, z, w)

There is already native support for that in Warp: `q1 * q2` does it.

    def quat_conjugate(q: wp.quatf) -> wp.quatf:
        return wp.quatf(-q[0], -q[1], -q[2], q[3])

There is already Warp support for that: `wp.quat_inverse()`
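Swapping a conjugate for an inverse is only equivalent when the quaternion has unit length, which is the case for rotations. The plain-Python sketch below illustrates that point with hypothetical helpers; it makes no claim about how Warp implements `wp.quat_inverse()` internally.

```python
import math

def quat_conjugate(q):
    """Conjugate of an (x, y, z, w) quaternion: negate the vector part."""
    x, y, z, w = q
    return (-x, -y, -z, w)

def quat_inverse_general(q):
    """General inverse: conjugate divided by the squared norm."""
    x, y, z, w = q
    n2 = x * x + y * y + z * z + w * w
    return (-x / n2, -y / n2, -z / n2, w / n2)

# A unit quaternion (rotation of 0.6 rad about z) and a scaled copy of it.
unit_q = (0.0, 0.0, math.sin(0.3), math.cos(0.3))
scaled_q = tuple(2.0 * c for c in unit_q)

# For the unit quaternion the two operations agree component by component.
same_for_unit = all(
    math.isclose(a, b, abs_tol=1e-12)
    for a, b in zip(quat_conjugate(unit_q), quat_inverse_general(unit_q))
)
# For the scaled quaternion they differ by the squared norm (a factor of 4 here).
same_for_scaled = all(
    math.isclose(a, b, abs_tol=1e-12)
    for a, b in zip(quat_conjugate(scaled_q), quat_inverse_general(scaled_q))
)
```

As long as the object and goal rotations stay normalized, either operation should give the same difference quaternion.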

    @wp.func
    def quat_from_angle_axis(angle: wp.float32, axis: wp.vec3f) -> wp.quatf:
        # axis assumed to be unit-length in this task.
        half = angle * wp.float32(0.5)
        s = wp.sin(half)
        c = wp.cos(half)
        return wp.quatf(axis[0] * s, axis[1] * s, axis[2] * s, c)

There is already a function for that: `quat_from_axis_angle`
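For reference, the formula the helper implements can be sketched in plain Python. The function below is an illustrative stand-in, not the Warp built-in. It takes the axis first and the angle second, which is the reverse of the PR helper's `(angle, axis)` order, so call sites would need their arguments swapped if the built-in follows the same order as its name suggests. Unlike the helper, this sketch normalizes the axis.

```python
import math

def quat_from_axis_angle(axis, angle):
    """Unit quaternion (x, y, z, w) for a rotation of `angle` radians about `axis`."""
    norm = math.sqrt(sum(c * c for c in axis))
    ax, ay, az = (c / norm for c in axis)  # normalize so any non-zero axis works
    half = 0.5 * angle
    s = math.sin(half)
    return (ax * s, ay * s, az * s, math.cos(half))

# Half turn about z, passed with a non-unit axis to exercise the normalization.
half_turn_z = quat_from_axis_angle((0.0, 0.0, 2.0), math.pi)
# Quarter turn about x.
quarter_turn_x = quat_from_axis_angle((1.0, 0.0, 0.0), math.pi / 2)
```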

    wp.launch(
        initialize_xyz_unit_vecs,
        dim=self.num_envs,
        inputs=[
            self.x_unit_vecs,
            self.y_unit_vecs,
            self.z_unit_vecs,
        ],
        device=self.device,
    )

This could be replaced by

    self.x_unit_vec = wp.vec3f(1.0, 0.0, 0.0)
    self.y_unit_vec = wp.vec3f(0.0, 1.0, 0.0)
    self.z_unit_vec = wp.vec3f(0.0, 0.0, 1.0)

The Torch env does have per-env unit vectors for some reason. Confirming whether that's unnecessary now.

    x_unit_vecs: wp.array(dtype=wp.vec3f),
    y_unit_vecs: wp.array(dtype=wp.vec3f),

These do not need to be arrays.

    object_pos: wp.array(dtype=wp.vec3f),
    object_rot: wp.array(dtype=wp.quatf),

These could be fed as `object_pose`, a `transformf`. The risk is that if the tensor is not contiguous, we are launching kernels to split the poses.

    object_linvel: wp.array(dtype=wp.vec3f),
    object_angvel: wp.array(dtype=wp.vec3f),

These could be fed as a spatial vector directly. The risk is that if the tensor is not contiguous, we are launching kernels to split the velocities.

    object_pos: wp.array(dtype=wp.vec3f),
    object_rot: wp.array(dtype=wp.quatf),

Same here, this could be a transform directly.
Albeit there is not much gain, since there are no transform ops.
Yea. In this env, it looks like there's no direct operation on `transformf`. Converting to a transform type would require extracting pos and rot everywhere, which might be less convenient. Do you think it's required?

    target_pos: wp.array(dtype=wp.vec3f),
    target_rot: wp.array(dtype=wp.quatf),

This could also be a transform
Correct me if I am wrong. I think it gets a little complicated for this. `goal_rot` is the `target_rot`, but it's also used for goal marker visualization, which takes a `goal_pos` + `goal_rot` in Torch. Combining them into a single `transformf` will still require a separate kernel to extract those for goal marker visualization, which currently only takes an np array or a torch tensor.

    float(self.cfg.dist_reward_scale),
    float(self.cfg.rot_reward_scale),
    float(self.cfg.rot_eps),

Is this float conversion needed?
Performance summary (torch → warp)

| Metric | Torch | Warp | Δ% (torch→warp) |
|---|---|---|---|
| Action processing mean (us, N=9600) | 1502.42 | 36.77 | -97.55% |
| Newton simulation mean (us, N=9600) | 17604.23 | 17201.77 | -2.29% |
| Post-processing mean (us, N=4800) | 6708.68 | 168.72 | -97.49% |
| Total step mean (us, N=4800) | 45722.28 | 35529.03 | -22.29% |
| Throughput (steps/s) | 70185 | 85147 | +21.32% |
| Iteration time (s) | 0.93 | 0.77 | -17.20% |
| Collection time (s) | 0.762 | 0.600 | -21.26% |
| Learning time (s) | 0.171 | 0.170 | -0.58% |

Δ% is computed as $(\text{warp} - \text{torch}) / \text{torch} \times 100\%$, so a negative time Δ% means less time (better).
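The Δ% column can be rechecked from the Torch and Warp columns with a few lines of Python. This is just the formula above applied to values copied from the table; the function and dictionary names are illustrative.

```python
def delta_percent(torch_value, warp_value):
    """Relative change from torch to warp, in percent."""
    return (warp_value - torch_value) / torch_value * 100.0

# (torch, warp) pairs copied from the performance table.
rows = {
    "Action processing mean (us)": (1502.42, 36.77),
    "Newton simulation mean (us)": (17604.23, 17201.77),
    "Post-processing mean (us)": (6708.68, 168.72),
    "Total step mean (us)": (45722.28, 35529.03),
    "Throughput (steps/s)": (70185, 85147),
}

for name, (torch_value, warp_value) in rows.items():
    print(f"{name}: {delta_percent(torch_value, warp_value):+.2f}%")
```

Note that the sign convention flips for throughput: a positive Δ% there is the improvement, while for the timing rows the improvement is the negative Δ%.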
The potential joint sampling issue is summarized in #4404
It does seem that updating the sampling, which was previously included in the performance stats, puts the initial configuration into a harder case, so that the Newton simulation takes more time. Removing the sampling fix puts Warp and Torch under the same conditions for comparison.
Description
Add warp env for inhand_manipulation
Fixes # (issue)
Type of change
Screenshots
Please attach before and after screenshots of the change if applicable.
Checklist
- `pre-commit` checks with `./isaaclab.sh --format`
- `config/extension.toml` file
- `CONTRIBUTORS.md` or my name already exists there