100 changes: 99 additions & 1 deletion pymarlin/utils/checkpointer/checkpoint_utils.py
@@ -4,7 +4,7 @@

import os
import re
from typing import Optional, Dict
from typing import Optional, Dict, Union, Callable, List
from abc import ABC, abstractmethod
from operator import itemgetter
from dataclasses import dataclass
@@ -291,3 +291,101 @@ def check_mk_dir(self, dirpath: str) -> None:
os.makedirs(dirpath)
assert os.path.isdir(dirpath), "supplied checkpoint dirpath "\
"is not a directory"


@dataclass
class BestCheckpointerArguments(DefaultCheckpointerArguments):
"""Additional arguments for checkpointer

metric_name: name of metric where minimal is defined as best. Must be a registered buffer in module interface
save_intermediate_checkpoints: whether to produce a checkpointer every epoch in addition to latest and best.
load_best: whether to load best or latest checkpoint. Default behavior is to load latest.
"""
metric_name: str = "val_perplexity"
init_metric_value: Optional[float] = None
criteria: Union[str, Callable] = "min"
save_intermediate_checkpoints: bool = False # not usually necessary in practice
load_best: bool = False # default to load latest
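
For illustration, a minimal sketch of how these arguments might be configured. `val_accuracy` is a hypothetical metric name (only `val_perplexity` is confirmed above), and the inherited `DefaultCheckpointerArguments` fields are assumed to have defaults:

    # Sketch: track the largest value of a hypothetical "val_accuracy" buffer.
    best_args = BestCheckpointerArguments(
        metric_name="val_accuracy",
        criteria="max",
        load_best=True,
    )

    # criteria can also be any callable taking (new, old) and returning True when the
    # new value should replace the current best; init_metric_value seeds the first comparison.
    custom_args = BestCheckpointerArguments(
        metric_name="val_accuracy",
        criteria=lambda new, old: new > old,
        init_metric_value=0.0,
    )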


class BestCheckpointer(DefaultCheckpointer):
"""
Saves both the best and the latest checkpoint. The best checkpoint is determined by applying `criteria` to the value of
the metric named by `metric_name` in the module interface (smallest value wins by default). This checkpointer therefore
relies on that metric existing as a single value in the module interface state. By default it checks "val_perplexity",
which is a registered buffer in `AbstractUserMessageReplyModule` that gets updated after every call to `on_end_val_epoch`.
"""
def __init__(self, args: BestCheckpointerArguments):
super().__init__(args)
self.best_checkpoint_name = f"{self.args.file_prefix}_best_checkpoint.{self.args.file_ext}"
self.latest_checkpoint_name = f"{self.args.file_prefix}_latest_checkpoint.{self.args.file_ext}"
if self.args.criteria == 'min':
self.criteria_func = lambda new, old: new < old
self.best_metric = float('inf')
elif self.args.criteria == 'max':
self.criteria_func = lambda new, old: new > old
self.best_metric = -float('inf')
else:
self.criteria_func = self.args.criteria
self.best_metric = self.args.init_metric_value

if self.args.init_metric_value is not None:
self.best_metric = self.args.init_metric_value

def save(self, checkpoint_state: Checkpoint, index: int, force=False) -> List[str]:
"""
Saves trainer, optimizer, and module interface state.

Args:
checkpoint_state: instance of `Checkpoint` which contains trainer, optimizer, and module interface state
index: current epoch number
force: whether to force a save even if the current index does not line up with the checkpointing period

Returns:
list of paths checkpoint state was saved to
"""
paths = []
if self.args.save_intermediate_checkpoints:
paths.append(super().save(checkpoint_state, index, force))
if self.args.checkpoint:
# TODO grab this from logged metrics instead, checkpoint state is hacky
self.logger.debug(f"Available metrics {checkpoint_state.module_interface_state.keys()}")
metric = float(checkpoint_state.module_interface_state[self.args.metric_name])
Collaborator

this part feels hacky in terms of the design patterns used here if it is not pulled from metrics, but either way I think adding logging here would be good (available metrics, the metric selected, and its value)

self.logger.info(f"epoch {index}: metric {self.args.metric_name}={metric}, best score={self.best_metric}")

# optionally save best checkpoint
if self.criteria_func(metric, self.best_metric):
self.best_metric = metric
best_path = os.path.join(self.args.save_dir, self.best_checkpoint_name)
torch.save(checkpoint_state.__dict__, best_path)
paths.append(best_path)

# save latest
latest_path = os.path.join(self.args.save_dir, self.latest_checkpoint_name)
torch.save(checkpoint_state.__dict__, latest_path)
paths.append(latest_path)
return paths

def load(self) -> Checkpoint:
"""
Optionally loads a checkpoint from a given directory: either a specified filename, the best checkpoint, or
the latest checkpoint. Raises a `ValueError` upon failure to load a checkpoint.

Returns:
An instance of `Checkpoint`
"""
if self.args.load_dir:
if self.args.load_filename:
load_path = os.path.join(self.args.load_dir, self.args.load_filename)
elif self.args.load_best:
load_path = os.path.join(self.args.load_dir, self.best_checkpoint_name)
else:
load_path = os.path.join(self.args.load_dir, self.latest_checkpoint_name)

# TODO how to set best metric to match loaded checkpoint?
self.logger.debug(f"loading checkpoint from {load_path}")
checkpoint = torch.load(load_path, map_location=torch.device('cpu'))
self.logger.debug('Checkpoint loaded')
return Checkpoint(**checkpoint)

return Checkpoint()
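
Finally, a hedged usage sketch of how the new checkpointer might be driven from a training loop. The loop, the `build_checkpoint_state` helper, and the `checkpoints` directory are assumptions for illustration; only `save`, `load`, `save_dir`, and `load_dir` are taken from the code above:

    args = BestCheckpointerArguments(
        save_dir="checkpoints",           # read by save() above
        load_dir="checkpoints",           # read by load() above
        metric_name="val_perplexity",     # default metric, a registered buffer per the docstring
        criteria="min",
        load_best=True,
    )
    checkpointer = BestCheckpointer(args)

    num_epochs = 3  # illustrative

    def build_checkpoint_state() -> Checkpoint:
        """Hypothetical helper: in real use the trainer assembles the Checkpoint,
        including a module_interface_state that contains the tracked metric."""
        raise NotImplementedError

    for epoch in range(num_epochs):
        # ... run training and validation for one epoch ...
        state = build_checkpoint_state()
        saved_paths = checkpointer.save(state, index=epoch)  # returns best/latest paths written

    # On restart, load_best=True makes load() pick the *_best_checkpoint file.
    checkpoint = checkpointer.load()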