diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py
index ecce5e29..ceb148d5 100644
--- a/src/state/tx/models/state_transition.py
+++ b/src/state/tx/models/state_transition.py
@@ -369,6 +369,62 @@ def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor:
         """Define how we embed basal state input, if needed."""
         return self.basal_encoder(expr)

+    def _compute_batch_token_loss(self, batch: Dict[str, torch.Tensor], padded: bool) -> Optional[torch.Tensor]:
+        """Compute CE loss for the optional batch token from cached token features.
+
+        Returns None if batch token training is disabled or cache is unavailable.
+        """
+        if not (self.use_batch_token and self.batch_classifier is not None and self._batch_token_cache is not None):
+            return None
+
+        logits = self.batch_classifier(self._batch_token_cache)  # [B, 1, C]
+        batch_token_targets = batch["batch"]
+
+        B = logits.shape[0]
+        C = logits.size(-1)
+
+        # Prepare one label per sequence (all S cells share the same batch)
+        if batch_token_targets.dim() > 1 and batch_token_targets.size(-1) == C:
+            # One-hot labels; reshape to [B, S, C]
+            if padded:
+                target_oh = batch_token_targets.reshape(-1, self.cell_sentence_len, C)
+            else:
+                target_oh = batch_token_targets.reshape(1, -1, C)
+            sentence_batch_labels = target_oh.argmax(-1)
+        else:
+            # Integer labels; reshape to [B, S]
+            if padded:
+                sentence_batch_labels = batch_token_targets.reshape(-1, self.cell_sentence_len)
+            else:
+                sentence_batch_labels = batch_token_targets.reshape(1, -1)
+
+        if sentence_batch_labels.shape[0] != B:
+            sentence_batch_labels = sentence_batch_labels.reshape(B, -1)
+
+        if self.basal_mapping_strategy == "batch":
+            uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1)
+            if not torch.all(uniform_mask):
+                bad_indices = torch.where(~uniform_mask)[0]
+                label_strings = []
+                for idx in bad_indices:
+                    labels = sentence_batch_labels[idx].detach().cpu().tolist()
+                    logger.error("Batch labels for sentence %d: %s", idx.item(), labels)
+                    label_strings.append(f"sentence {idx.item()}: {labels}")
+                raise ValueError(
+                    "Expected all cells in a sentence to share the same batch when "
+                    "basal_mapping_strategy is 'batch'. "
+                    f"Found mixed batch labels: {', '.join(label_strings)}"
+                )
+
+        target_idx = sentence_batch_labels[:, 0]
+
+        # Safety: ensure exactly one target per sequence
+        if target_idx.numel() != B:
+            target_idx = target_idx.reshape(-1)[:B]
+
+        ce_loss = F.cross_entropy(logits.reshape(B, -1, C).squeeze(1), target_idx.long())
+        return ce_loss
+
     def forward(self, batch: dict, padded=True) -> torch.Tensor:
         """
         The main forward call. Batch is a flattened sequence of cell sentences,
@@ -431,11 +487,11 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor:
         if self.hparams.get("mask_attn", False):
             batch_size, seq_length, _ = seq_input.shape
             device = seq_input.device
-            self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
+            self.transformer_backbone._attn_implementation = "eager"  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]

             # create a [1,1,S,S] mask (now S+1 if confidence token is used)
             base = torch.eye(seq_length, device=device, dtype=torch.bool).view(1, 1, seq_length, seq_length)
-
+
             # Get number of attention heads from model config
             num_heads = self.transformer_backbone.config.num_attention_heads

@@ -529,53 +585,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
         decoder_loss = None
         total_loss = main_loss

-        if self.use_batch_token and self.batch_classifier is not None and self._batch_token_cache is not None:
-            logits = self.batch_classifier(self._batch_token_cache)  # [B, 1, C]
-            batch_token_targets = batch["batch"]
-
-            B = logits.shape[0]
-            C = logits.size(-1)
-
-            # Prepare one label per sequence (all S cells share the same batch)
-            if batch_token_targets.dim() > 1 and batch_token_targets.size(-1) == C:
-                # One-hot labels; reshape to [B, S, C]
-                if padded:
-                    target_oh = batch_token_targets.reshape(-1, self.cell_sentence_len, C)
-                else:
-                    target_oh = batch_token_targets.reshape(1, -1, C)
-                sentence_batch_labels = target_oh.argmax(-1)
-            else:
-                # Integer labels; reshape to [B, S]
-                if padded:
-                    sentence_batch_labels = batch_token_targets.reshape(-1, self.cell_sentence_len)
-                else:
-                    sentence_batch_labels = batch_token_targets.reshape(1, -1)
-
-            if sentence_batch_labels.shape[0] != B:
-                sentence_batch_labels = sentence_batch_labels.reshape(B, -1)
-
-            if self.basal_mapping_strategy == "batch":
-                uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1)
-                if not torch.all(uniform_mask):
-                    bad_indices = torch.where(~uniform_mask)[0]
-                    label_strings = []
-                    for idx in bad_indices:
-                        labels = sentence_batch_labels[idx].detach().cpu().tolist()
-                        logger.error("Batch labels for sentence %d: %s", idx.item(), labels)
-                        label_strings.append(f"sentence {idx.item()}: {labels}")
-                    raise ValueError(
-                        "Expected all cells in a sentence to share the same batch when "
-                        "basal_mapping_strategy is 'batch'. "
-                        f"Found mixed batch labels: {', '.join(label_strings)}"
-                    )
-
-            target_idx = sentence_batch_labels[:, 0]
-
-            # Safety: ensure exactly one target per sequence
-            if target_idx.numel() != B:
-                target_idx = target_idx.reshape(-1)[:B]
-
-            ce_loss = F.cross_entropy(logits.reshape(B, -1, C).squeeze(1), target_idx.long())
+        ce_loss = self._compute_batch_token_loss(batch, padded=padded)
+        if ce_loss is not None:
             self.log("train/batch_token_loss", ce_loss)
             total_loss = total_loss + self.batch_token_weight * ce_loss

@@ -668,6 +679,11 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
             self.log("val/sinkhorn_loss", sinkhorn_component)
             self.log("val/energy_loss", energy_component)

+        # Log batch token loss during validation without adding to validation loss
+        ce_loss_val = self._compute_batch_token_loss(batch, padded=True)
+        if ce_loss_val is not None:
+            self.log("val/batch_token_loss", ce_loss_val)
+
         if self.gene_decoder is not None and "pert_cell_counts" in batch:
             gene_targets = batch["pert_cell_counts"]
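
For reference, a minimal standalone sketch (not part of the diff) of the label-collapsing and cross-entropy step that _compute_batch_token_loss performs on the padded, integer-label path; the toy shapes and label values below are assumed purely for illustration:

import torch
import torch.nn.functional as F

# Assumed toy dimensions: 2 cell sentences of 3 cells each, 4 batch classes.
B, S, C = 2, 3, 4

# Stand-in for the classifier output on the cached batch token: [B, 1, C] logits.
logits = torch.randn(B, 1, C)

# Per-cell integer batch labels as they arrive flattened in batch["batch"] ([B * S]).
flat_labels = torch.tensor([1, 1, 1, 3, 3, 3])

# padded=True path: reshape to [B, S] so each row is one cell sentence.
sentence_batch_labels = flat_labels.reshape(-1, S)

# Uniformity check used when basal_mapping_strategy == "batch":
# every cell in a sentence must carry the same batch label.
uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1)
assert torch.all(uniform_mask), "mixed batch labels within a sentence"

# Collapse to one target per sentence and score the batch token with CE.
target_idx = sentence_batch_labels[:, 0]
ce_loss = F.cross_entropy(logits.reshape(B, -1, C).squeeze(1), target_idx.long())
print(ce_loss.item())

With one-hot labels the helper performs the same collapse after an argmax over the class dimension.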