7 changes: 5 additions & 2 deletions ScaFFold/configs/benchmark_default.yml
@@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: -1 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
-disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
+starting_learning_rate: 0.0025 # Starting LR used at scale_reference 6.
+scale_learning_rate_factor: 0.5 # Multiply starting LR by this factor for each +1 increase in problem_scale.
+gamma: 0.99 # ExponentialLR decay factor applied each epoch when scheduler is enabled.
+min_learning_rate: 0.0001 # Floor for scheduler-adjusted learning rate.
+disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
7 changes: 5 additions & 2 deletions ScaFFold/configs/benchmark_testing.yml
@@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 10 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
-disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
+starting_learning_rate: 0.005 # Starting LR used at scale_reference 6.
+scale_learning_rate_factor: 0.5 # Multiply starting LR by this factor for each +1 increase in problem_scale.
+gamma: 0.99 # ExponentialLR decay factor applied each epoch when scheduler is enabled.
+min_learning_rate: 0.0001 # Floor for scheduler-adjusted learning rate.
+disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
10 changes: 9 additions & 1 deletion ScaFFold/utils/config_utils.py
@@ -61,7 +61,15 @@ def __init__(self, config_dict):
self.seed = config_dict["seed"]
self.dist = bool(config_dict["dist"])
self.framework = config_dict["framework"]
-self.learning_rate = config_dict["learning_rate"]
+self.starting_learning_rate = config_dict["starting_learning_rate"]
+self.scale_learning_rate_factor = config_dict["scale_learning_rate_factor"]
+self.starting_learning_rate = (
+    self.starting_learning_rate
+    * self.scale_learning_rate_factor
+    ** (self.problem_scale - 6)  # Reference problem scale is 6
+)
Collaborator (Author): Noticed instability at higher problem scales that is fixed by lowering the LR

Collaborator (Author): Noticed this is correlated with turning on AMP

Collaborator (Author): no AMP (image) vs AMP (image)

Collaborator (Author): probably should just omit this logic and avoid using AMP
+self.gamma = config_dict["gamma"]
+self.min_learning_rate = config_dict["min_learning_rate"]
self.variance_threshold = config_dict["variance_threshold"]
self.torch_amp = bool(config_dict["torch_amp"])
self.loss_freq = config_dict["loss_freq"]
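
For reference, a minimal sketch (not part of the diff) of the scale-dependent starting LR this computes, assuming the benchmark_default.yml values above and the reference problem scale of 6 noted in the code comment:

```python
# Standalone illustration of the scaled starting LR; values are the
# benchmark_default.yml defaults, with problem_scale 6 as the reference scale.
starting_learning_rate = 0.0025
scale_learning_rate_factor = 0.5

for problem_scale in (6, 7, 8, 9):
    lr = starting_learning_rate * scale_learning_rate_factor ** (problem_scale - 6)
    print(problem_scale, lr)
# Mathematically: 0.0025, 0.00125, 0.000625, 0.0003125 -- each +1 in
# problem_scale halves the starting LR, matching the config comment.
```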
9 changes: 7 additions & 2 deletions ScaFFold/utils/evaluate.py
@@ -32,6 +32,11 @@
def evaluate(
net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
):
+def foreground_dice_mean(dice_scores):
+    if dice_scores.size(1) > 1:
+        return dice_scores[:, 1:].mean()
+    return dice_scores.mean()

net.eval()
num_val_batches = len(dataloader)
total_dice_score = 0.0
@@ -118,11 +123,11 @@ def evaluate(
dice_score_probs = compute_sharded_dice(
mask_pred_probs, mask_true_onehot, spatial_mesh
)
-dice_loss_curr = 1.0 - dice_score_probs.mean()

Collaborator (Author): was background + foreground -> now foreground

+dice_loss_curr = 1.0 - foreground_dice_mean(dice_score_probs)

# Eval metric (excluding background class 0)
# dice_score_probs shape is [Batch, Channels]. We slice [:, 1:] to drop background
-batch_dice_score = dice_score_probs[:, 1:].mean()
+batch_dice_score = foreground_dice_mean(dice_score_probs)

# --- Combine and Accumulate ---
loss = CE_loss + dice_loss_curr
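
For reference, a small sketch (not part of the diff) of the foreground-only averaging the new helper implies, assuming dice scores shaped [Batch, Channels] with background in channel 0 as the comment above states:

```python
import torch

def foreground_dice_mean(dice_scores: torch.Tensor) -> torch.Tensor:
    # Drop the background channel before averaging; single-channel scores
    # fall back to a plain mean.
    if dice_scores.size(1) > 1:
        return dice_scores[:, 1:].mean()
    return dice_scores.mean()

scores = torch.tensor([[0.99, 0.40, 0.60],
                       [0.98, 0.50, 0.70]])
foreground_dice_mean(scores)  # foreground channels only -> 0.55
scores.mean()                 # previous behaviour -> ~0.70, pulled up by background
```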
56 changes: 34 additions & 22 deletions ScaFFold/utils/trainer.py
@@ -154,22 +154,24 @@ def setup_training_components(self):
if self.config.optimizer == "ADAM":
self.log.info("Using ADAM optimizer.")
self.optimizer = optim.Adam(
-self.model.parameters(), lr=self.config.learning_rate
+self.model.parameters(), lr=self.config.starting_learning_rate
)
elif self.config.optimizer == "SGD":
self.log.info("Using SGD optimizer.")
self.optimizer = optim.SGD(
-self.model.parameters(), lr=self.config.learning_rate
+self.model.parameters(), lr=self.config.starting_learning_rate
)
else:
self.log.info("Using RMSprop optimizer.")
self.optimizer = optim.RMSprop(
-self.model.parameters(), lr=self.config.learning_rate, foreach=True
+self.model.parameters(),
+lr=self.config.starting_learning_rate,
+foreach=True,
)

# Set up learning rate scheduler
-self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-    self.optimizer, "max", patience=25
+self.scheduler = optim.lr_scheduler.ExponentialLR(
+    self.optimizer, gamma=self.config.gamma
)

# Set up gradient scaler for AMP (Automatic Mixed Precision)
@@ -186,6 +188,24 @@ def setup_training_components(self):
f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}"
)

+@staticmethod
+def _foreground_dice_mean(dice_scores):
+    """Match optimization to the reported validation metric by excluding background."""
+    if dice_scores.size(1) > 1:
+        return dice_scores[:, 1:].mean()
+    return dice_scores.mean()
+
+def _maybe_step_scheduler(self):
+    """Apply scheduler updates when enabled."""
+    if self.config.disable_scheduler:
+        self.log.debug("scheduler disabled, no LR update this step")
+        return
+
+    self.scheduler.step()
+    for param_group in self.optimizer.param_groups:
+        if param_group["lr"] < self.config.min_learning_rate:
+            param_group["lr"] = self.config.min_learning_rate
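
For reference, a standalone sketch (not part of the diff) of how ExponentialLR combined with this floor behaves, assuming gamma and min_learning_rate from the benchmark configs and an already-scaled starting LR; the model and optimizer here are placeholders:

```python
import torch
from torch import optim

model = torch.nn.Linear(4, 2)                    # placeholder model
optimizer = optim.SGD(model.parameters(), lr=2.5e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
min_learning_rate = 1e-4

for epoch in range(3):
    # ... forward/backward and optimizer.step() would happen here ...
    scheduler.step()                             # lr *= gamma once per epoch
    for param_group in optimizer.param_groups:   # clamp to the configured floor
        param_group["lr"] = max(param_group["lr"], min_learning_rate)
    print(epoch, optimizer.param_groups[0]["lr"])
# The LR decays geometrically (2.5e-3, ~2.475e-3, ~2.450e-3, ...) and is never
# allowed below min_learning_rate.
```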


class PyTorchTrainer(BaseTrainer):
"""
@@ -436,7 +456,7 @@ def warmup(self):
dice_scores = compute_sharded_dice(
local_preds_softmax, local_labels_one_hot, self.spatial_mesh
)
-loss_dice = 1.0 - dice_scores.mean()
+loss_dice = 1.0 - self._foreground_dice_mean(dice_scores)

# 3. Combine Loss
loss = loss_ce + loss_dice
@@ -641,11 +661,13 @@ def train(self):
local_labels_one_hot,
self.spatial_mesh,
)
-loss_dice = 1.0 - dice_scores.mean()
+loss_dice = 1.0 - self._foreground_dice_mean(dice_scores)

# 3. Combine Loss
loss = loss_ce + loss_dice
-train_dice_total += dice_scores[:, 1:].mean().item()
+train_dice_total += self._foreground_dice_mean(
+    dice_scores
+).item()

end_code_region("calculate_loss")

@@ -698,19 +720,8 @@ def train(self):
dice_info, op=torch.distributed.ReduceOp.SUM
)
val_score = dice_info[0].item() / max(dice_info[1].item(), 1)
-if not self.config.disable_scheduler:
-    # The following is true when trying to overfit,
-    # in which case we only care about train loss
-    if self.n_train == 1 or "overfit" in self.outfile_path:
-        self.log.debug(
-            "WARNING: scheduler step by overall_loss, \
-            not val_score (n_train==1 or overfit in outfile_path)"
-        )
-        self.scheduler.step(overall_loss)
-    else:  # Otherwise, we're really trying to optimize for validation dice score
-        self.scheduler.step(val_score)
-else:
-    self.log.debug("scheduler disabled, no LR update this step")
+self._maybe_step_scheduler()
+current_lr = self.optimizer.param_groups[0]["lr"]

epoch_end_time = time.time()
epoch_duration = epoch_end_time - epoch_start_time
@@ -721,7 +732,8 @@ def train(self):
self.log.info(
f" epoch {epoch} \
| train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
| val_dice_score {val_score:.6f}"
| val_dice_score {val_score:.6f} \
| lr {current_lr:.8f}"
)
self.log.debug(f" writing to csv at {self.outfile_path}")
if self.world_rank == 0: