diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index ae5cb5a343..1150261bb8 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -160,7 +160,7 @@ async def run_async(self): await asyncio.sleep(0.1) memory.append({"role": "user", "content": content}) memory.append({"role": "assistant", "content": content.upper()}) - experience = self.process_messages_to_experience(memory, 0, {}) + experience = await self.process_messages_to_experience_async(memory, 0, {}) experience_list.append(experience) return experience_list diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py index 15050dc0eb..4bc34833cb 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py @@ -10,7 +10,7 @@ generate_default_empty_experience, get_jinja_env, parse_response, - process_messages_to_experience, + process_messages_to_experience_async, validate_trajectory_format, ) from trinity.common.workflows.workflow import Task, Workflow @@ -202,7 +202,7 @@ async def run_async(self) -> List[Experience]: if reward >= 1 and traj_format_valid: print("✅ Task completed successfully in the first attempt!") - experience = process_messages_to_experience( + experience = await process_messages_to_experience_async( self.model, trajectory, info={"success": success, "reward": reward, "steps": steps} ) return [experience] diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py index 589fb9a8e6..be2889a2b5 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py @@ -14,7 +14,7 @@ generate_default_empty_experience, generate_reward_feedback, parse_response, - process_messages_to_experience, + process_messages_to_experience_async, save_task_data, validate_trajectory_format, ) @@ -215,9 +215,9 @@ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info: or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0) ) - def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience: + async def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience: """Generate experience from SFT messages""" - return process_messages_to_experience(self.model, sft_messages, info=metrics) + return await process_messages_to_experience_async(self.model, sft_messages, info=metrics) async def run_async(self) -> List[Experience]: """Run the RAFT alfworld workflow and return experiences""" @@ -245,7 +245,7 @@ async def run_async(self) -> List[Experience]: # Handle first attempt success cases if reward >= 1 and traj_format_valid: print("✅ Task completed successfully in the first attempt!") - experience = process_messages_to_experience( + experience = await process_messages_to_experience_async( self.model, trajectory, info={"success": success, "reward": reward, "steps": steps} ) return [experience] @@ -275,7 +275,7 @@ async def run_async(self) -> List[Experience]: kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info) if kept_for_sft: - experience = self._generate_experience_from_sft(sft_messages, metrics) + experience = await self._generate_experience_from_sft(sft_messages, metrics) experiences.append(experience) print( f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}" diff --git a/trinity/common/workflows/envs/alfworld/RAFT_utils.py b/trinity/common/workflows/envs/alfworld/RAFT_utils.py index 46b6f356a6..5e57ba597a 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_utils.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_utils.py @@ -107,13 +107,13 @@ def create_alfworld_environment(game_file): raise ImportError(error_message) -def process_messages_to_experience(model, messages, info=None) -> Experience: +async def process_messages_to_experience_async(model, messages, info=None) -> Experience: """Convert messages to experience for training, with fallback to default empty experience""" if info is None: info = {} try: - converted_experience = model.convert_messages_to_experience(messages) + converted_experience = await model.convert_messages_to_experience_async(messages) metrics = {} for k, v in info.items(): diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 64fe07a6ce..9266483ee0 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -135,7 +135,7 @@ async def generate_env_inference_samples(self, env) -> List[Experience]: if done: final_reward = reward break - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0} ) # Close the env to save cpu memory diff --git a/trinity/common/workflows/envs/frozen_lake/workflow.py b/trinity/common/workflows/envs/frozen_lake/workflow.py index 604b50282d..6d26a4775b 100644 --- a/trinity/common/workflows/envs/frozen_lake/workflow.py +++ b/trinity/common/workflows/envs/frozen_lake/workflow.py @@ -353,9 +353,9 @@ async def run_async(self) -> List[Experience]: # Create experience from messages final_reward = sum(self.step_rewards) # print(f"final_reward: {final_reward}, terminate_reason: {terminate_reason}") - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( messages=messages, - reward=final_reward, + reward=float(final_reward), info={ "env_steps": self.step_count, "env_done": 1 if self.done else 0, diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index c9d9cdc684..beefc55295 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -107,7 +107,7 @@ async def generate_env_inference_samples(self, env, rollout_num) -> List[Experie if done: break final_reward = final_reward / 100.0 - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds}, diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 7514965eba..e5d48dd9c1 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -258,9 +258,9 @@ async def generate_env_inference_samples( final_reward = 0 else: final_reward = -0.1 - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( memory, - final_reward, + float(final_reward), {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0}, ) experience_list.append(experience) diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index cf3b4d449b..0c45466a70 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -176,11 +176,20 @@ def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base - def process_messages_to_experience( - self, messages, reward, info={}, truncate_status=None + def _build_experience_from_converted( + self, converted_experience, reward, info={}, truncate_status=None ) -> Experience: - converted_experience = self.model.convert_messages_to_experience(messages) + """Private helper method to build Experience from converted_experience. + + Args: + converted_experience: The converted experience from the model. + reward: The reward value. + info: Additional info dictionary. + truncate_status: Optional truncate status to override. + Returns: + Experience: The constructed Experience object. + """ if converted_experience.truncate_status == "response_truncated": reward = 0.0 @@ -209,6 +218,22 @@ def process_messages_to_experience( ) return experience + def process_messages_to_experience( + self, messages, reward, info={}, truncate_status=None + ) -> Experience: + converted_experience = self.model.convert_messages_to_experience(messages) + return self._build_experience_from_converted( + converted_experience, reward, info, truncate_status + ) + + async def process_messages_to_experience_async( + self, messages, reward, info={}, truncate_status=None + ) -> Experience: + converted_experience = await self.model.convert_messages_to_experience_async(messages) + return self._build_experience_from_converted( + converted_experience, reward, info, truncate_status + ) + class BaseSimpleWorkflow(Workflow): def __init__( diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index 640ee2b748..57aa3e0467 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -66,7 +66,8 @@ def to_data_proto( token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32) eos_mask_idx = cumsum.argmax(dim=-1) token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor( - [exp.reward for exp in experiences] + [exp.reward for exp in experiences], + dtype=torch.float32, ) token_level_rewards = token_level_rewards[:, max_prompt_length:] batch_dict.update(