diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e3b3d2d..7feb303a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -59,7 +59,7 @@ repos: hooks: - id: codespell stages: [pre-commit, commit-msg] - args: ["--ignore-words-list", "Commun"] + args: ["--ignore-words-list", "Commun,Mater"] exclude: | (?x)( ^example_inputs/data/| diff --git a/CHANGELOG.md b/CHANGELOG.md index cf571068..5e0d22f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ All notable changes to this project will be documented in this file. ## [0.12.2.dev] ### Added +- reEWC fine-tuning with forgetting prevention for single-modal models: optional experience replay (`rehearsal`, `load_memory_path`, `mem_batch_size`, `mem_ratio`) and an Elastic Weight Consolidation penalty from a precomputed Fisher matrix (`continue.fisher_information`, `continue.opt_params`, `continue.ewc_lambda`), plus a `cosineannealingwarmuplr` scheduler - Support OpenEquivariance - Per-atom stress (atomic virial) support in LAMMPS pair_e3gnn and ASE calculator - `compute_atomic_virial` option in `SevenNetCalculator` @@ -14,6 +15,9 @@ All notable changes to this project will be documented in this file. - LAMMPS pair_e3gnn refactored to use pair-wise force (dE/dr) instead of position-based gradient. - Deploy no longer replaces force_output with ForceStressOutput; force/stress computed in LAMMPS C++ side. +### Fixed +- Load FlashTP-saved checkpoints (e.g. SevenNet-Nano) when FlashTP is unavailable by falling back to the e3nn backend, so they work for inference and fine-tuning without FlashTP installed. An explicit `enable_flash=True` still fails loud. + ## [0.12.1] ### Fixed - FlashTP with LAMMPS parallel in torch diff --git a/README.md b/README.md index 9ed904f6..420f67ea 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Full documentation, including **installation**, **usage**, and **pretrained mode - GPU-parallelized molecular dynamics with LAMMPS - CUDA-accelerated D3 (van der Waals) dispersion - Multi-fidelity training for combining multiple databases with different calculation settings + - Fine-tuning with forgetting prevention (experience replay + Elastic Weight Consolidation) for continual learning - [Tensor product accelerators](https://sevennet.readthedocs.io/en/latest/user_guide/accelerator.html) @@ -71,3 +72,16 @@ If you utilize the pretrained model SevenNet-Omni or multi-task training strateg year = {2026}, } ``` + +If you utilize the reEWC forgetting-aware fine-tuning strategy for continual learning of pretrained universal machine-learning interatomic potentials, please cite the following paper: +```bib +@article{kim_efficient_2026, + title = {An Efficient Forgetting-Aware Fine-Tuning Framework for Pretrained Universal Machine-Learning Interatomic Potentials}, + volume = {12}, + doi = {10.1038/s41524-025-01895-w}, + number = {26}, + journal = {npj Comput. Mater.}, + author = {Kim, Jisu and Lee, Jiho and Oh, Sangmin and Park, Yutack and Hwang, Seungwoo and Han, Seungwu and Kang, Sungwoo and Kang, Youngho}, + year = {2026}, +} +``` diff --git a/docs/source/index.rst b/docs/source/index.rst index 2004844c..bc809039 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -16,6 +16,7 @@ SevenNet (Scalable EquiVariance-Enabled Neural Network) is a graph neural networ * GPU-parallelized molecular dynamics with LAMMPS * CUDA-accelerated D3 (van der Waals) dispersion * Multi-fidelity training for combining multiple databases with different calculation settings +* Fine-tuning with forgetting prevention (Experience replay + Elastic Weight Consolidation) for continual learning Installation diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst index 349c10c4..44370601 100644 --- a/docs/source/user_guide/index.rst +++ b/docs/source/user_guide/index.rst @@ -26,5 +26,6 @@ SevenNet offers various pretrained models, MD engines (ASE, LAMMPS), and user in ase_calculator torchsim cli + reewc d3 note_book diff --git a/docs/source/user_guide/reewc.md b/docs/source/user_guide/reewc.md new file mode 100644 index 00000000..2a1099de --- /dev/null +++ b/docs/source/user_guide/reewc.md @@ -0,0 +1,63 @@ +# Forgetting-prevented (Continual-learning) fine-tuning (reEWC) + +Fine-tuning a pretrained model on a target system improves accuracy there, but the +model can lose accuracy on the original training domain (catastrophic forgetting). +reEWC mitigates this with two complementary mechanisms that can be used together or +separately: + +- **Experience replay (rehearsal)** -- replay an old-task "memory" set each training + step so the model keeps fitting it while learning the target data. +- **Elastic Weight Consolidation (EWC)** -- add a penalty + `lambda/2 * sum_i F_i (theta_i - theta*_i)^2` that anchors parameters to their + pre-fine-tuning values `theta*`, weighted by a precomputed Fisher matrix `F`. + +reEWC is for **single-modal** models (e.g. SevenNet-0, SevenNet-Nano). Multi-fidelity +(modal) models are not supported yet. + +## Getting started + +A ready-to-edit input with both mechanisms is available as a preset: + +```bash +sevenn preset reewc > input.yaml +``` + +The preset documents every key inline. Replay lives in the `data:` block +(`rehearsal`, `load_memory_path`, `mem_batch_size`, `mem_ratio`) and EWC in the +`train.continue:` block (`fisher_information`, `opt_params`, `ewc_lambda`). Every +reEWC key is optional; when none are set, training is unchanged. Remove the replay +block or the EWC keys to run only one mechanism. Run training as usual: + +```bash +sevenn train input.yaml -s +``` + +## Fisher information and reference parameters + +`fisher_information` and `opt_params` are **precomputed and consumed** -- SevenNet +does not estimate the Fisher matrix. Both are `torch.save`d dictionaries keyed by +parameter name; `opt_params` is the parameter set of the checkpoint before +fine-tuning. They must satisfy: + +- `fisher_information` and `opt_params` cover the **same parameter names** with the + **same shapes** (they are a matched pair). +- Names that overlap with the model's trainable parameters must have **matching + shapes**; a mismatch is an error (usually an incompatible checkpoint or SevenNet + version). +- At least one name must overlap with the model; no overlap is an error. +- A trainable parameter without a Fisher entry is **left unconstrained** and a + warning is emitted, so partial-coverage Fisher matrices are allowed but visible. + +`ewc_lambda` must be `> 0`, and EWC requires both `fisher_information` and +`opt_params` to be set. + +## Notes + +- Replay supports `dataset_type: 'graph'` (the default) only. +- reEWC does not support distributed (DDP) training. +- `load_memory_path` is reserved for replay: setting it without `rehearsal: True` + raises an error. +- When replay is enabled, the memory set is evaluated each epoch and logged as a + `memoryset` column group in `lc.csv`, alongside `trainset` and `validset`. +- A `cosineannealingwarmuplr` scheduler (cosine annealing with warm-up restarts, + used for the reEWC paper work) is also available for fine-tuning. diff --git a/example_inputs/training/input_full.yaml b/example_inputs/training/input_full.yaml index 69787538..57b75412 100644 --- a/example_inputs/training/input_full.yaml +++ b/example_inputs/training/input_full.yaml @@ -73,6 +73,11 @@ train: #checkpoint: 'checkpoint_best.pth' # Checkpoint of pre-trained model or a model want to continue training. #reset_optimizer: False # Set True for fine-tuning #reset_scheduler: False # Set True for fine-tuning + # reEWC (single-modal models only): add an Elastic Weight Consolidation penalty from a + # precomputed Fisher matrix and reference parameters to preserve prior-task accuracy. + #fisher_information: './fisher.pt' # dict {param_name: tensor} of precomputed Fisher information + #opt_params: './opt_params.pt' # dict {param_name: tensor} of reference (pre-finetune) parameters + #ewc_lambda: 100000 # EWC penalty weight (must be > 0 when fisher/opt are given) data: batch_size: 4 # Per GPU batch size. @@ -91,3 +96,10 @@ data: load_trainset_path: ['./structure_list'] # Example of using ase as data_format, support multiple files and expansion(*) #load_validset_path: ['./valid.extxyz'] #load_testset_path: ['./sevenn_data/mydata.pt'] # Graph can be preprocessed using `sevenn_graph_build` and accessible like this + + # reEWC rehearsal (experience replay, single-modal models only): replay an old-task memory set + # each training step to mitigate catastrophic forgetting while fine-tuning on the target data. + #rehearsal: False # Set True to enable replay + #load_memory_path: ['./memory.extxyz'] # memory (old-task) set; this key is reserved for rehearsal + #mem_batch_size: 8 # batch size for the replayed memory set + #mem_ratio: 1 # fraction (0, 1] of the memory set to use diff --git a/sevenn/_const.py b/sevenn/_const.py index 6dc45589..0e411d77 100644 --- a/sevenn/_const.py +++ b/sevenn/_const.py @@ -212,6 +212,9 @@ def model_defaults(config): KEY.USE_MODAL_WISE_SCALE: False, KEY.SHIFT: 'per_atom_energy_mean', KEY.SCALE: 'force_rms', + KEY.REHEARSAL: False, + KEY.MEM_BATCH_SIZE: 0, + KEY.MEM_RATIO: 1, # KEY.DATA_SHUFFLE: True, # KEY.DATA_WEIGHT: False, # KEY.DATA_MODALITY: False, @@ -233,6 +236,11 @@ def model_defaults(config): KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE, KEY.USE_MODAL_WISE_SHIFT: bool, KEY.USE_MODAL_WISE_SCALE: bool, + KEY.REHEARSAL: lambda x: isinstance(x, bool), + KEY.MEM_BATCH_SIZE: lambda x: isinstance(x, int) and not isinstance(x, bool), + KEY.MEM_RATIO: lambda x: isinstance(x, (int, float)) + and not isinstance(x, bool) + and 0 < x <= 1, # KEY.DATA_SHUFFLE: bool, KEY.COMPUTE_STATISTICS: bool, # KEY.DATA_WEIGHT: bool, diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 4b8d63d7..31d29ae5 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -103,6 +103,10 @@ LOAD_TRAINSET = 'load_trainset_path' LOAD_VALIDSET = 'load_validset_path' LOAD_TESTSET = 'load_testset_path' +LOAD_MEMORY_PATH = 'load_memory_path' # reEWC rehearsal memory set +REHEARSAL = 'rehearsal' +MEM_BATCH_SIZE = 'mem_batch_size' +MEM_RATIO = 'mem_ratio' FORMAT_OUTPUTS = 'format_outputs_for_ase' COMPUTE_STATISTICS = 'compute_statistics' DATASET_TYPE = 'dataset_type' @@ -135,6 +139,9 @@ USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = ( 'use_statistic_values_for_cp_modal_only' ) +OPT_PARAMS = 'opt_params' # reEWC: reference (optimal) params pickle +FISHER = 'fisher_information' # reEWC: precomputed Fisher information pickle +EWC_LAMBDA = 'ewc_lambda' # reEWC: EWC penalty weight CSV_LOG = 'csv_log' diff --git a/sevenn/checkpoint.py b/sevenn/checkpoint.py index e0422ec2..9bdbb60a 100644 --- a/sevenn/checkpoint.py +++ b/sevenn/checkpoint.py @@ -327,11 +327,29 @@ def build_model( enable_cueq = cp_using_cueq if enable_cueq is None else enable_cueq cp_using_flash = self.config.get(KEY.USE_FLASH_TP, False) + flash_requested_explicitly = enable_flash is True enable_flash = cp_using_flash if enable_flash is None else enable_flash cp_using_oeq = self.config.get(KEY.USE_OEQ, False) enable_oeq = cp_using_oeq if enable_oeq is None else enable_oeq + # FlashTP-saved checkpoints must still load where FlashTP is unavailable. + if enable_flash: + from sevenn.nn.flash_helper import is_flash_available + + if not is_flash_available(): + if flash_requested_explicitly or _flash_lammps: + raise ValueError( + 'FlashTP was requested but is not available (package ' + 'not installed or no GPU available).' + ) + warnings.warn( + 'FlashTP is unavailable; loading the checkpoint with the ' + 'e3nn backend instead.', + UserWarning, + ) + enable_flash = False + if sum([enable_cueq, enable_flash, enable_oeq]) > 1: raise ValueError('Only one TP accelerator can be enabled.') diff --git a/sevenn/main/sevenn_preset.py b/sevenn/main/sevenn_preset.py index 10e11e22..1b4e526e 100644 --- a/sevenn/main/sevenn_preset.py +++ b/sevenn/main/sevenn_preset.py @@ -27,6 +27,7 @@ def add_args(parser): 'base', 'multi_modal', 'mf_ompa_fine_tune', + 'reewc', ], help=preset_help ) diff --git a/sevenn/presets/reewc.yaml b/sevenn/presets/reewc.yaml new file mode 100644 index 00000000..f1410b4a --- /dev/null +++ b/sevenn/presets/reewc.yaml @@ -0,0 +1,90 @@ +# Example input.yaml for forgetting-prevented fine-tuning (reEWC). +# Replay and EWC are independent; keep one block, both, or neither. +# reEWC is for single-modal models (e.g. SevenNet-0, SevenNet-Nano). + +model: # keep consistent with the checkpoint being fine-tuned + chemical_species: 'Auto' + cutoff: 5.0 + channel: 128 + is_parity: False + lmax: 2 + num_convolution_layer: 5 + irreps_manual: + - "128x0e" + - "128x0e+64x1e+32x2e" + - "128x0e+64x1e+32x2e" + - "128x0e+64x1e+32x2e" + - "128x0e+64x1e+32x2e" + - "128x0e" + + weight_nn_hidden_neurons: [64, 64] + radial_basis: + radial_basis_name: 'bessel' + bessel_basis_num: 8 + cutoff_function: + cutoff_function_name: 'XPLOR' + cutoff_on: 4.5 + self_connection_type: 'linear' + + train_shift_scale: False + train_denominator: False + +train: + random_seed: 1 + is_train_stress: True + epoch: 100 + + loss: 'Huber' + loss_param: + delta: 0.01 + + optimizer: 'adam' + optim_param: + lr: 0.004 + # cosineannealingwarmuplr (cosine annealing with warm-up restarts) was used + # for the reEWC work; exponentiallr also works. + scheduler: 'exponentiallr' + scheduler_param: + gamma: 0.99 + + force_loss_weight: 1.0 + stress_loss_weight: 0.01 + + per_epoch: 10 + + error_record: + - ['Energy', 'RMSE'] + - ['Force', 'RMSE'] + - ['Stress', 'RMSE'] + - ['TotalLoss', 'None'] + + continue: + reset_optimizer: True + reset_scheduler: True + reset_epoch: True + checkpoint: 'SevenNet-0_11July2024' + + # EWC: anchor parameters to their pre-fine-tuning values via a + # precomputed Fisher matrix. fisher_information and opt_params are + # torch.save'd dicts {param_name: tensor} matching the model's + # trainable parameters; opt_params is the checkpoint before fine-tuning. + # All three keys are required together; remove them to disable EWC. + fisher_information: './fisher.pt' + opt_params: './opt_params.pt' + ewc_lambda: 100000 # EWC penalty weight (> 0) + +data: + batch_size: 4 + data_divide_ratio: 0.1 + data_format_args: + index: ':' + + load_trainset_path: ['./target_train.extxyz'] + load_validset_path: ['./valid.extxyz'] + + # Replay (experience replay): replay an old-task memory set each step so the + # model keeps fitting it. Remove this block to disable replay. + rehearsal: True + load_memory_path: ['./memory.extxyz'] # requires rehearsal: True + mem_batch_size: 8 + mem_ratio: 1 # fraction (0, 1] of the memory set diff --git a/sevenn/scripts/processing_epoch.py b/sevenn/scripts/processing_epoch.py index 38e671db..34338403 100644 --- a/sevenn/scripts/processing_epoch.py +++ b/sevenn/scripts/processing_epoch.py @@ -41,6 +41,12 @@ def processing_epoch_v2( config, trainer.loss_functions ) recorders = {k: deepcopy(recorder) for k in loaders} + # reEWC: log the replayed memory set as a separate 'memoryset' column group. + memory_recorder = ( + deepcopy(recorder) + if getattr(trainer, 'memory_loader', None) is not None + else None + ) best_val = float('inf') best_key = None @@ -58,6 +64,8 @@ def processing_epoch_v2( head = ['epoch', 'lr'] for k, rec in recorders.items(): head.extend(list(rec.get_dct(prefix=k))) + if memory_recorder is not None: + head.extend(list(memory_recorder.get_dct(prefix='memoryset'))) with open(csv_path, 'w') as f: f.write(','.join(head) + '\n') @@ -88,9 +96,17 @@ def processing_epoch_v2( loader.sampler.set_epoch(epoch) rec = recorders[k] - trainer.run_one_epoch(loader, is_train, rec) + trainer.run_one_epoch( + loader, + is_train, + rec, + memory_error_recorder=memory_recorder if is_train else None, + ) csv_dct.update(rec.get_dct(prefix=k)) errors[k] = rec.epoch_forward() + if memory_recorder is not None: + csv_dct.update(memory_recorder.get_dct(prefix='memoryset')) + errors['memoryset'] = memory_recorder.epoch_forward() log.write_full_table(list(errors.values()), list(errors)) trainer.scheduler_step(best_val) diff --git a/sevenn/scripts/train.py b/sevenn/scripts/train.py index 28371857..06eba2e7 100644 --- a/sevenn/scripts/train.py +++ b/sevenn/scripts/train.py @@ -11,6 +11,12 @@ from sevenn.scripts.processing_continue import ( convert_modality_of_checkpoint_state_dct, ) +from sevenn.train.reewc import ( + ReewcTrainer, + build_memory_loader, + reewc_dataset_keys, + validate_reewc_config, +) from sevenn.train.trainer import Trainer @@ -55,6 +61,8 @@ def train_v2(config: Dict[str, Any], working_dir: str) -> None: log.writeline('***************************************************') config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET) + validate_reewc_config(config) + # config updated start_epoch = 1 state_dicts: Optional[List[dict]] = None @@ -64,7 +72,9 @@ def train_v2(config: Dict[str, Any], working_dir: str) -> None: if config.get(KEY.USE_MODALITY, False): datasets = modal_dataset.from_config(config, working_dir) elif config[KEY.DATASET_TYPE] == 'graph': - datasets = graph_dataset.from_config(config, working_dir) + datasets = graph_dataset.from_config( + config, working_dir, dataset_keys=reewc_dataset_keys(config) + ) elif config[KEY.DATASET_TYPE] == 'atoms': datasets = atoms_dataset.from_config(config, working_dir) else: @@ -74,11 +84,19 @@ def train_v2(config: Dict[str, Any], working_dir: str) -> None: for k, v in datasets.items() } + rehearsal = config.get(KEY.REHEARSAL, False) + memory_loader = build_memory_loader(config) if rehearsal else None + log.write('\nModel building...\n') model = build_E3_equivariant_model(config) log.print_model_info(model, config) - trainer = Trainer.from_config(model, config) + if memory_loader is not None: + trainer = ReewcTrainer.from_config( + model, config, memory_loader=memory_loader + ) + else: + trainer = Trainer.from_config(model, config) if state_dicts: trainer.load_state_dicts(*state_dicts, strict=False) diff --git a/sevenn/train/loss.py b/sevenn/train/loss.py index a6f8a769..3b271be7 100644 --- a/sevenn/train/loss.py +++ b/sevenn/train/loss.py @@ -227,4 +227,8 @@ def get_loss_functions_from_config( if loss_function.criterion is None: loss_function.assign_criteria(criterion) + from sevenn.train.reewc.loss import append_ewc_loss + + append_ewc_loss(loss_functions, config) + return loss_functions diff --git a/sevenn/train/optim.py b/sevenn/train/optim.py index 10e75790..1a6c7d74 100644 --- a/sevenn/train/optim.py +++ b/sevenn/train/optim.py @@ -1,6 +1,10 @@ +import math + +import torch import torch.nn as nn import torch.optim.lr_scheduler as scheduler from torch.optim import adagrad, adam, adamw, radam, sgd +from torch.optim.lr_scheduler import _LRScheduler optim_dict = { 'sgd': sgd.SGD, @@ -11,11 +15,138 @@ } +# Adapted from the cosine_annealing_warmup package, MIT License, +# Copyright (c) 2022 Naoki Katsura. +class CosineAnnealingWarmupRestarts(_LRScheduler): + """ + Cosine annealing scheduler with linear warmup and warm restarts. + + Args: + optimizer: wrapped optimizer. + first_cycle_steps: number of steps in the first cycle. + cycle_mult: cycle length magnification applied at each restart. + max_lr: maximum (post-warmup) learning rate of the first cycle. + min_lr: minimum learning rate. + warmup_steps: number of linear warmup steps. + gamma: max_lr decay factor applied each cycle. + last_epoch: index of the last epoch. + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + first_cycle_steps: int, + cycle_mult: float = 1.0, + max_lr: float = 0.1, + min_lr: float = 0.001, + warmup_steps: int = 0, + gamma: float = 1.0, + last_epoch: int = -1, + ) -> None: + assert warmup_steps < first_cycle_steps + + self.first_cycle_steps = first_cycle_steps + self.cycle_mult = cycle_mult + self.base_max_lr = max_lr + self.max_lr = max_lr + self.min_lr = min_lr + self.warmup_steps = warmup_steps + self.gamma = gamma + + self.cur_cycle_steps = first_cycle_steps + self.cycle = 0 + self.step_in_cycle = last_epoch + + super().__init__(optimizer, last_epoch) + + self.init_lr() + + def init_lr(self) -> None: + self.base_lrs = [] + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.min_lr + self.base_lrs.append(self.min_lr) + + def get_lr(self): + if self.step_in_cycle == -1: + return self.base_lrs + elif self.step_in_cycle < self.warmup_steps: + return [ + (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + + base_lr + for base_lr in self.base_lrs + ] + else: + return [ + base_lr + + (self.max_lr - base_lr) + * ( + 1 + + math.cos( + math.pi + * (self.step_in_cycle - self.warmup_steps) + / (self.cur_cycle_steps - self.warmup_steps) + ) + ) + / 2 + for base_lr in self.base_lrs + ] + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.step_in_cycle = self.step_in_cycle + 1 + if self.step_in_cycle >= self.cur_cycle_steps: + self.cycle += 1 + self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps + self.cur_cycle_steps = ( + int( + (self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult + ) + + self.warmup_steps + ) + else: + if epoch >= self.first_cycle_steps: + if self.cycle_mult == 1.0: + self.step_in_cycle = epoch % self.first_cycle_steps + self.cycle = epoch // self.first_cycle_steps + else: + n = int( + math.log( + ( + epoch + / self.first_cycle_steps + * (self.cycle_mult - 1) + + 1 + ), + self.cycle_mult, + ) + ) + self.cycle = n + self.step_in_cycle = epoch - int( + self.first_cycle_steps + * (self.cycle_mult**n - 1) + / (self.cycle_mult - 1) + ) + self.cur_cycle_steps = ( + self.first_cycle_steps * self.cycle_mult**n + ) + else: + self.cur_cycle_steps = self.first_cycle_steps + self.step_in_cycle = epoch + + self.max_lr = self.base_max_lr * (self.gamma**self.cycle) + self.last_epoch = math.floor(epoch) + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + + scheduler_dict = { 'steplr': scheduler.StepLR, 'multisteplr': scheduler.MultiStepLR, 'exponentiallr': scheduler.ExponentialLR, 'cosineannealinglr': scheduler.CosineAnnealingLR, + 'cosineannealingwarmuplr': CosineAnnealingWarmupRestarts, 'reducelronplateau': scheduler.ReduceLROnPlateau, 'linearlr': scheduler.LinearLR, } diff --git a/sevenn/train/reewc/__init__.py b/sevenn/train/reewc/__init__.py new file mode 100644 index 00000000..d2640282 --- /dev/null +++ b/sevenn/train/reewc/__init__.py @@ -0,0 +1,16 @@ +from .loss import EWCLoss, append_ewc_loss +from .rehearsal import ( + build_memory_loader, + reewc_dataset_keys, + validate_reewc_config, +) +from .trainer import ReewcTrainer + +__all__ = [ + 'EWCLoss', + 'append_ewc_loss', + 'ReewcTrainer', + 'build_memory_loader', + 'reewc_dataset_keys', + 'validate_reewc_config', +] diff --git a/sevenn/train/reewc/loss.py b/sevenn/train/reewc/loss.py new file mode 100644 index 00000000..e69d8933 --- /dev/null +++ b/sevenn/train/reewc/loss.py @@ -0,0 +1,141 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch + +import sevenn._keys as KEY +from sevenn.train.loss import LossDefinition + + +class EWCLoss(LossDefinition): + """ + Elastic Weight Consolidation penalty: sum_i F_i (theta_i - theta*_i)^2, + with precomputed Fisher information F and reference parameters theta*. + Consumes precomputed Fisher/optimal-params dicts; it does not compute them. + """ + + def __init__( + self, + fisher_dict: Dict[str, torch.Tensor], + opt_params_dict: Dict[str, torch.Tensor], + name: str = 'EWC', + device: Optional[str] = None, + **kwargs, + ) -> None: + if not isinstance(fisher_dict, dict) or not isinstance( + opt_params_dict, dict + ): + raise ValueError('EWC fisher_information/opt_params must be dicts') + super().__init__(name=name, use_weight=False, **kwargs) + self.fisher_dict = fisher_dict + self.opt_params_dict = opt_params_dict + self._checked = False + if device is not None: + self.to(device) + + def to(self, device) -> None: + self.fisher_dict = {k: v.to(device) for k, v in self.fisher_dict.items()} + self.opt_params_dict = { + k: v.to(device) for k, v in self.opt_params_dict.items() + } + + def _check_and_align(self, model: Callable) -> None: + if len(self.fisher_dict) == 0 or len(self.opt_params_dict) == 0: + raise ValueError('EWC fisher_information/opt_params is empty') + + # Fisher and reference params are a matched pair; they must agree on + # both parameter names and shapes regardless of the model. + if set(self.fisher_dict) != set(self.opt_params_dict): + raise ValueError( + 'EWC fisher_information and opt_params cover different parameters' + ) + for name, fisher in self.fisher_dict.items(): + if fisher.shape != self.opt_params_dict[name].shape: + raise ValueError( + f'EWC fisher/opt_params shape mismatch for {name}: ' + f'{tuple(fisher.shape)} != ' + f'{tuple(self.opt_params_dict[name].shape)}' + ) + + model_params = { + n: p for n, p in model.named_parameters() if p.requires_grad + } + if len(model_params) == 0: + raise ValueError('EWC requires the model to have trainable parameters') + + shared = set(self.fisher_dict) & set(model_params) + if len(shared) == 0: + raise ValueError( + 'EWC fisher/opt_params parameter names do not match the model; ' + 'the pickle was likely produced by an incompatible SevenNet ' + f'version. example model param: {next(iter(model_params))}; ' + f'example fisher key: {next(iter(self.fisher_dict))}' + ) + for name in shared: + if self.fisher_dict[name].shape != model_params[name].shape: + raise ValueError( + f'EWC fisher shape mismatch for {name}: ' + f'{tuple(self.fisher_dict[name].shape)} != ' + f'{tuple(model_params[name].shape)}' + ) + + # A trainable param without a Fisher entry is left unconstrained. + unconstrained = [n for n in model_params if n not in self.fisher_dict] + if unconstrained: + warnings.warn( + f'EWC has no Fisher information for {len(unconstrained)} ' + f'trainable parameter(s); they stay unconstrained ' + f'(e.g. {unconstrained[0]})', + UserWarning, + ) + + self.to(next(iter(model_params.values())).device) + self._checked = True + + def get_loss( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ): + _ = batch_data + if model is None: + raise ValueError('EWCLoss requires the model to compute the penalty') + if not self._checked: + self._check_and_align(model) + device = next(model.parameters()).device + ewc_loss = torch.zeros(1, device=device) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if name not in self.fisher_dict or name not in self.opt_params_dict: + continue + fisher = self.fisher_dict[name] + opt_param = self.opt_params_dict[name] + ewc_loss = ewc_loss + torch.sum(fisher * (param - opt_param) ** 2) + return ewc_loss + + +def append_ewc_loss( + loss_functions: List[Tuple[LossDefinition, float]], + config: Dict[str, Any], +) -> None: + """reEWC: append the EWC penalty as an extra loss term when a precomputed + Fisher information and reference parameters are given under continue.""" + cont = config.get(KEY.CONTINUE, {}) + fisher_path = cont.get(KEY.FISHER, False) + opt_path = cont.get(KEY.OPT_PARAMS, False) + ewc_lambda = cont.get(KEY.EWC_LAMBDA, 0) + if not (fisher_path or opt_path or ewc_lambda): + return + if not (fisher_path and opt_path): + raise ValueError( + 'EWC requires both continue.fisher_information and ' + 'continue.opt_params to be set' + ) + if not ( + isinstance(ewc_lambda, (int, float)) + and not isinstance(ewc_lambda, bool) + and ewc_lambda > 0 + ): + raise ValueError('EWC requires continue.ewc_lambda > 0') + fisher = torch.load(fisher_path, map_location='cpu', weights_only=True) + opt = torch.load(opt_path, map_location='cpu', weights_only=True) + loss_functions.append((EWCLoss(fisher, opt), ewc_lambda / 2.0)) diff --git a/sevenn/train/reewc/rehearsal.py b/sevenn/train/reewc/rehearsal.py new file mode 100644 index 00000000..c74f0ee0 --- /dev/null +++ b/sevenn/train/reewc/rehearsal.py @@ -0,0 +1,84 @@ +import random +from typing import Any, Dict, List, Optional + +from torch_geometric.loader import DataLoader + +import sevenn._keys as KEY +from sevenn.logger import Logger + + +def validate_reewc_config(config: Dict[str, Any]) -> None: + """Fail-loud guards for reEWC. No-op when no reEWC keys are set.""" + rehearsal = config.get(KEY.REHEARSAL, False) + cont = config.get(KEY.CONTINUE, {}) + ewc_active = bool(cont.get(KEY.FISHER, False)) or bool( + cont.get(KEY.OPT_PARAMS, False) + ) + memory_paths = config.get(KEY.LOAD_MEMORY_PATH, False) + if (rehearsal or ewc_active) and config.get(KEY.IS_DDP, False): + raise NotImplementedError( + 'reEWC (rehearsal/EWC) does not support distributed training' + ) + if memory_paths and not rehearsal: + raise ValueError( + 'load_memory_path is set but rehearsal is False; load_memory_path ' + 'is reserved for reEWC rehearsal' + ) + if rehearsal and config.get(KEY.DATASET_TYPE) == 'atoms': + raise NotImplementedError( + 'reEWC rehearsal supports dataset_type="graph" only' + ) + if (rehearsal or ewc_active) and config.get(KEY.USE_MODALITY, False): + raise ValueError( + 'reEWC (rehearsal/EWC) supports single-modal models only; ' + 'multifidelity/modal models are not supported' + ) + + +def reewc_dataset_keys(config: Dict[str, Any]) -> Optional[List[str]]: + """Dataset keys for normal discovery, excluding the reserved memory set so + it is not run as an extra (validation-style) loader every epoch. Returns + None (load all) when no rehearsal memory set is configured.""" + memory_paths = config.get(KEY.LOAD_MEMORY_PATH, False) + if not memory_paths: + return None + return [ + k + for k in config + if k.startswith('load_') + and k.endswith('_path') + and k != KEY.LOAD_MEMORY_PATH + ] + + +def build_memory_loader(config: Dict[str, Any]) -> DataLoader: + """Build the reEWC rehearsal (replay) memory loader from load_memory_path.""" + from sevenn.train.graph_dataset import SevenNetGraphDataset + + memory_paths = config.get(KEY.LOAD_MEMORY_PATH, False) + if not memory_paths: + raise ValueError('rehearsal is True but load_memory_path is not set') + if isinstance(memory_paths, str): + memory_paths = [memory_paths] + mem_batch_size = config.get(KEY.MEM_BATCH_SIZE, 0) + if not (isinstance(mem_batch_size, int) and mem_batch_size > 0): + raise ValueError('rehearsal requires mem_batch_size > 0') + mem_ratio = config.get(KEY.MEM_RATIO, 1) + if not (0 < mem_ratio <= 1): + raise ValueError('rehearsal requires 0 < mem_ratio <= 1') + + graphs = [] + for file in memory_paths: + graphs.extend( + SevenNetGraphDataset.file_to_graph_list(file, cutoff=config[KEY.CUTOFF]) + ) + if mem_ratio < 1: + random.Random(config.get(KEY.RANDOM_SEED, 1)).shuffle(graphs) + graphs = graphs[: int(len(graphs) * mem_ratio)] + if len(graphs) == 0: + raise ValueError('reEWC rehearsal memory set is empty after loading') + Logger().writeline( + f'Rehearsal enabled: {len(graphs)} memory graphs, ' + f'mem_batch_size={mem_batch_size}' + ) + return DataLoader(graphs, batch_size=mem_batch_size, shuffle=True) diff --git a/sevenn/train/reewc/trainer.py b/sevenn/train/reewc/trainer.py new file mode 100644 index 00000000..02cd7845 --- /dev/null +++ b/sevenn/train/reewc/trainer.py @@ -0,0 +1,103 @@ +from typing import Any, Dict, Iterable, Optional, Union + +import torch +from tqdm import tqdm + +import sevenn._keys as KEY +from sevenn.error_recorder import ErrorRecorder +from sevenn.train.loss import get_loss_functions_from_config +from sevenn.train.optim import optim_dict, scheduler_dict +from sevenn.train.trainer import Trainer + + +class ReewcTrainer(Trainer): + """ + Trainer with reEWC experience replay: one extra optimizer step on a memory + (old-task) batch per training batch, to mitigate catastrophic forgetting. + """ + + def __init__( + self, *args, memory_loader: Optional[Iterable] = None, **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.memory_loader = memory_loader + + @staticmethod + def from_config( + model: torch.nn.Module, + config: Dict[str, Any], + memory_loader: Optional[Iterable] = None, + ) -> 'ReewcTrainer': + trainer = ReewcTrainer( + model, + loss_functions=get_loss_functions_from_config(config), + optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()], + optimizer_args=config.get(KEY.OPTIM_PARAM, {}), + scheduler_cls=scheduler_dict[ + config.get(KEY.SCHEDULER, 'exponentiallr').lower() + ], + scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}), + device=config.get(KEY.DEVICE, 'auto'), + distributed=config.get(KEY.IS_DDP, False), + distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'), + memory_loader=memory_loader, + ) + return trainer + + def run_one_epoch( + self, + loader: Iterable, + is_train: bool = False, + error_recorder: Optional[ErrorRecorder] = None, + memory_error_recorder: Optional[ErrorRecorder] = None, + wrap_tqdm: Union[bool, int] = False, + ) -> None: + # Without a memory set to replay, behave exactly like the base Trainer. + if not (is_train and self.memory_loader is not None): + super().run_one_epoch( + loader, is_train, error_recorder, wrap_tqdm=wrap_tqdm + ) + return + + self.model.train() + if wrap_tqdm: + total_len = wrap_tqdm if isinstance(wrap_tqdm, int) else None + loader = tqdm(loader, total=total_len) + + mem_iter = iter(self.memory_loader) + for _, batch in enumerate(loader): + self.optimizer.zero_grad() + batch = batch.to(self.device, non_blocking=True) + output = self.model(batch) + if error_recorder is not None: + error_recorder.update(output) + total_loss = torch.tensor([0.0], device=self.device) + for loss_def, w in self.loss_functions: + indv_loss = loss_def.get_loss(output, self.model) + if indv_loss is not None: + total_loss += (indv_loss * w) + total_loss.backward() + self.optimizer.step() + + # reEWC rehearsal: replay one memory batch with an independent + # optimizer step. + try: + mem_batch = next(mem_iter) + except StopIteration: + mem_iter = iter(self.memory_loader) + mem_batch = next(mem_iter) + mem_batch = mem_batch.to(self.device, non_blocking=True) + mem_output = self.model(mem_batch) + if memory_error_recorder is not None: + memory_error_recorder.update(mem_output) + self.optimizer.zero_grad() + mem_loss = torch.tensor([0.0], device=self.device) + for loss_def, w in self.loss_functions: + indv_loss = loss_def.get_loss(mem_output, self.model) + if indv_loss is not None: + mem_loss += (indv_loss * w) + mem_loss.backward() + self.optimizer.step() + + if self.distributed and error_recorder is not None: + self.recorder_all_reduce(error_recorder) diff --git a/sevenn/train/trainer.py b/sevenn/train/trainer.py index 9a7e186e..22d952e8 100644 --- a/sevenn/train/trainer.py +++ b/sevenn/train/trainer.py @@ -82,7 +82,10 @@ def __init__( self.loss_functions = loss_functions @staticmethod - def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer': + def from_config( + model: torch.nn.Module, + config: Dict[str, Any], + ) -> 'Trainer': trainer = Trainer( model, loss_functions=get_loss_functions_from_config(config), @@ -137,6 +140,7 @@ def run_one_epoch( loader: Iterable, is_train: bool = False, error_recorder: Optional[ErrorRecorder] = None, + memory_error_recorder: Optional[ErrorRecorder] = None, wrap_tqdm: Union[bool, int] = False, ) -> None: """ @@ -155,6 +159,7 @@ def run_one_epoch( if wrap_tqdm: total_len = wrap_tqdm if isinstance(wrap_tqdm, int) else None loader = tqdm(loader, total=total_len) + for _, batch in enumerate(loader): if is_train: self.optimizer.zero_grad() diff --git a/tests/unit_tests/test_flash.py b/tests/unit_tests/test_flash.py index 54473e9c..c2ea076e 100644 --- a/tests/unit_tests/test_flash.py +++ b/tests/unit_tests/test_flash.py @@ -259,3 +259,46 @@ def test_calculator(tmp_path): atoms2.calc = calc2 assert_atoms(atoms, atoms2) + + +def _make_flash_like_checkpoint(tmp_path): + """A FlashTP-style checkpoint (use_flash_tp=True) whose state_dict lacks the + e3nn-only convolution keys, as a FlashTP-saved checkpoint would. Built on CPU + so it can be loaded where FlashTP is unavailable.""" + cf = get_model_config() + cf['use_flash_tp'] = False + model = build_E3_equivariant_model(cf, parallel=False) + sd = model.state_dict() + drop = [ + k + for k in sd + if k.endswith('.convolution.weight') + or k.endswith('.convolution.output_mask') + or '._compiled_main_left_right._w3j' in k + ] + state_dict = {k: v for k, v in sd.items() if k not in drop} + cfg = get_model_config() + cfg.update({'use_flash_tp': True, 'version': sevenn.__version__}) + path = str(tmp_path / 'flash_like_cp.pth') + torch.save({'model_state_dict': state_dict, 'config': cfg}, path) + return path + + +def test_flash_checkpoint_loads_when_flash_unavailable(tmp_path, monkeypatch): + monkeypatch.setattr('sevenn.nn.flash_helper.is_flash_available', lambda: False) + path = _make_flash_like_checkpoint(tmp_path) + from sevenn.util import load_checkpoint + + # default build_model() must fall back to e3nn, not fail on missing keys + model = load_checkpoint(path).build_model() + assert isinstance(model, AtomGraphSequential) + + +def test_explicit_flash_request_raises_when_unavailable(tmp_path, monkeypatch): + monkeypatch.setattr('sevenn.nn.flash_helper.is_flash_available', lambda: False) + path = _make_flash_like_checkpoint(tmp_path) + from sevenn.util import load_checkpoint + + # an explicit enable_flash=True cannot be honored -> fail loud + with pytest.raises(ValueError): + load_checkpoint(path).build_model(enable_flash=True) diff --git a/tests/unit_tests/test_reewc.py b/tests/unit_tests/test_reewc.py new file mode 100644 index 00000000..e7e79fe4 --- /dev/null +++ b/tests/unit_tests/test_reewc.py @@ -0,0 +1,380 @@ +import pathlib + +import ase.io +import pytest +import torch +from torch_geometric.loader import DataLoader + +from sevenn.logger import Logger +from sevenn.train.dataload import graph_build +from sevenn.train.reewc.loss import EWCLoss +from sevenn.train.reewc.trainer import ReewcTrainer +from sevenn.train.trainer import Trainer + +Logger() # init singleton used by train_v2 + + +class _TinyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.zeros(3)) + self.b = torch.nn.Parameter(torch.ones(2)) + + +def _fisher_opt(model): + fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()} + opt = {n: p.detach().clone() for n, p in model.named_parameters()} + return fisher, opt + + +def test_ewc_loss_zero_at_optimum(): + model = _TinyModel() + fisher, opt = _fisher_opt(model) + ewc = EWCLoss(fisher, opt) + assert float(ewc.get_loss({}, model)) == pytest.approx(0.0) + + +def test_ewc_loss_positive_after_perturbation(): + model = _TinyModel() + fisher, opt = _fisher_opt(model) + ewc = EWCLoss(fisher, opt) + with torch.no_grad(): + model.a += 2.0 + # sum_i fisher_i * (p_i - opt_i)^2 = 3 * (1 * 2^2) = 12 + assert float(ewc.get_loss({}, model)) == pytest.approx(12.0) + + +def test_ewc_loss_requires_model(): + model = _TinyModel() + fisher, opt = _fisher_opt(model) + ewc = EWCLoss(fisher, opt) + with pytest.raises(ValueError): + ewc.get_loss({}, None) + + +def test_ewc_loss_fisher_must_be_dict(): + with pytest.raises(ValueError): + EWCLoss([1, 2, 3], {}) + + +def test_ewc_loss_no_matching_param_raises(): + model = _TinyModel() + fisher = {'bogus.param': torch.ones(3)} + opt = {'bogus.param': torch.zeros(3)} + ewc = EWCLoss(fisher, opt) + with pytest.raises(ValueError): + ewc.get_loss({}, model) + + +def test_ewc_loss_shape_mismatch_raises(): + model = _TinyModel() + fisher = {'a': torch.ones(5)} # model.a has shape (3,) + opt = {'a': torch.zeros(5)} + ewc = EWCLoss(fisher, opt) + with pytest.raises(ValueError): + ewc.get_loss({}, model) + + +def test_ewc_loss_partial_fisher_warns_and_skips(): + # a trainable param without a Fisher entry is left unconstrained, not fatal + model = _TinyModel() + fisher = {'a': torch.ones(3)} # no entry for 'b' + opt = {'a': torch.zeros(3)} + ewc = EWCLoss(fisher, opt) + with torch.no_grad(): + model.a += 1.0 # 'b' stays at its optimum + with pytest.warns(UserWarning): + loss = float(ewc.get_loss({}, model)) + assert loss == pytest.approx(3.0) # 3 * (1 * 1^2); 'b' contributes nothing + + +def test_ewc_loss_opt_shape_mismatch_raises(): + model = _TinyModel() + fisher = {'a': torch.ones(3), 'b': torch.ones(2)} + opt = {'a': torch.zeros(3), 'b': torch.zeros(5)} # wrong opt shape for 'b' + ewc = EWCLoss(fisher, opt) + with pytest.raises(ValueError): + ewc.get_loss({}, model) + + +def test_ewc_loss_empty_fisher_raises(): + model = _TinyModel() + ewc = EWCLoss({}, {}) + with pytest.raises(ValueError): # clean error, not StopIteration + ewc.get_loss({}, model) + + +def test_continue_ewc_keys_survive_parse(): + # continue.fisher_information / opt_params / ewc_lambda must survive config + # normalization so EWC is not silently disabled on a real YAML run. + from sevenn.parse_input import init_train_config + + cfg = { + 'optimizer': 'adam', + 'scheduler': 'exponentiallr', + 'continue': { + 'checkpoint': '7net-0', + 'fisher_information': '/x/fisher.pt', + 'opt_params': '/x/opt.pt', + 'ewc_lambda': 1000, + }, + } + train_meta = init_train_config(cfg) + cont = train_meta['continue'] + assert cont['fisher_information'] == '/x/fisher.pt' + assert cont['opt_params'] == '/x/opt.pt' + assert cont['ewc_lambda'] == 1000 + + +# --- Scheduler: native CosineAnnealingWarmupRestarts --- + + +def _make_opt(): + return torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.0) + + +def test_warmup_scheduler_registered(): + from sevenn.train.optim import scheduler_dict + + assert 'cosineannealingwarmuplr' in scheduler_dict + + +def test_warmup_scheduler_known_lrs(): + from sevenn.train.optim import CosineAnnealingWarmupRestarts + + opt = _make_opt() + sch = CosineAnnealingWarmupRestarts( + opt, first_cycle_steps=10, max_lr=1.0, min_lr=0.0, warmup_steps=4 + ) + assert opt.param_groups[0]['lr'] == pytest.approx(0.0) # init = min_lr + sch.step() # warmup step 1 of 4 -> max_lr * 1/4 + assert opt.param_groups[0]['lr'] == pytest.approx(0.25) + for _ in range(3): # warmup end (step 4) -> max_lr + sch.step() + assert opt.param_groups[0]['lr'] == pytest.approx(1.0) + for _ in range(6): # full cycle (step 10) -> restart to min_lr + sch.step() + assert opt.param_groups[0]['lr'] == pytest.approx(0.0) + + +# --- Config injection + validation (get_loss_functions_from_config) --- + + +def _base_loss_cfg(): + return { + 'loss': 'huber', + 'loss_param': {'delta': 0.01}, + 'use_weight': False, + 'force_loss_weight': 1.0, + 'stress_loss_weight': 0.01, + 'is_train_stress': True, + 'device': 'cpu', + 'continue': {'checkpoint': False}, + } + + +def _write_ewc_pkls(tmp_path): + fisher = {'p': torch.ones(2)} + opt = {'p': torch.zeros(2)} + fp, op = tmp_path / 'fisher.pkl', tmp_path / 'opt.pkl' + torch.save(fisher, fp) + torch.save(opt, op) + return str(fp), str(op) + + +def test_loss_functions_inject_ewc(tmp_path): + from sevenn.train.loss import get_loss_functions_from_config + from sevenn.train.reewc.loss import EWCLoss + + fp, op = _write_ewc_pkls(tmp_path) + cfg = _base_loss_cfg() + cfg['continue'] = { + 'checkpoint': False, + 'fisher_information': fp, + 'opt_params': op, + 'ewc_lambda': 100.0, + } + lfs = get_loss_functions_from_config(cfg) + ewc = [(ld, w) for ld, w in lfs if isinstance(ld, EWCLoss)] + assert len(ewc) == 1 + assert ewc[0][1] == pytest.approx(50.0) # ewc_lambda / 2 + + +def test_no_ewc_keys_no_ewc_loss(): + from sevenn.train.loss import get_loss_functions_from_config + from sevenn.train.reewc.loss import EWCLoss + + lfs = get_loss_functions_from_config(_base_loss_cfg()) + assert not any(isinstance(ld, EWCLoss) for ld, _ in lfs) + + +def test_loss_ewc_requires_both_fisher_and_opt(tmp_path): + from sevenn.train.loss import get_loss_functions_from_config + + fp, _ = _write_ewc_pkls(tmp_path) + cfg = _base_loss_cfg() + cfg['continue'] = { + 'checkpoint': False, + 'fisher_information': fp, + 'ewc_lambda': 100.0, + } + with pytest.raises(ValueError): + get_loss_functions_from_config(cfg) + + +def test_loss_ewc_lambda_only_raises(): + from sevenn.train.loss import get_loss_functions_from_config + + cfg = _base_loss_cfg() + cfg['continue'] = {'checkpoint': False, 'ewc_lambda': 100.0} + with pytest.raises(ValueError): + get_loss_functions_from_config(cfg) + + +def test_loss_ewc_nonpositive_lambda_raises(tmp_path): + from sevenn.train.loss import get_loss_functions_from_config + + fp, op = _write_ewc_pkls(tmp_path) + cfg = _base_loss_cfg() + cfg['continue'] = { + 'checkpoint': False, + 'fisher_information': fp, + 'opt_params': op, + 'ewc_lambda': 0, + } + with pytest.raises(ValueError): + get_loss_functions_from_config(cfg) + + +# --- Config keys parsed + default values --- + + +def test_rehearsal_data_keys_parsed(): + from sevenn.parse_input import init_data_config + + data_meta = init_data_config( + {'batch_size': 8, 'rehearsal': True, 'mem_batch_size': 4, 'mem_ratio': 0.5} + ) + assert data_meta['rehearsal'] is True + assert data_meta['mem_batch_size'] == 4 + assert data_meta['mem_ratio'] == 0.5 + + +def test_rehearsal_defaults_off(): + from sevenn.parse_input import init_data_config + + data_meta = init_data_config({'batch_size': 8}) + assert data_meta['rehearsal'] is False # default is off + + +# --- Replay (rehearsal) in Trainer.run_one_epoch --- + +_data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +_hfo2_path = str(_data_root / 'systems' / 'hfo2.extxyz') +_cp_0_path = str(_data_root / 'checkpoints' / 'cp_0.pth') + + +@pytest.fixture(scope='module') +def hfo2_loader(): + atoms = ase.io.read(_hfo2_path, index=':') + graphs = graph_build(atoms, 4.0, y_from_calc=True) + return DataLoader(graphs, batch_size=2) + + +class _CountingLoader: + def __init__(self, base): + self.base = base + self.pulls = 0 + + def __iter__(self): + for batch in self.base: + self.pulls += 1 + yield batch + + +def test_replay_consumed_only_in_train(hfo2_loader): + targs, _, _ = Trainer.args_from_checkpoint(_cp_0_path) + mem = _CountingLoader(hfo2_loader) + trainer = ReewcTrainer(**targs, device='cpu', memory_loader=mem) + n_batches = sum(1 for _ in hfo2_loader) + trainer.run_one_epoch(hfo2_loader, is_train=True) + assert mem.pulls == n_batches # one memory batch per training batch + mem.pulls = 0 + trainer.run_one_epoch(hfo2_loader, is_train=False) + assert mem.pulls == 0 # replay is skipped during evaluation + + +def test_replay_changes_params(hfo2_loader): + targs, _, _ = Trainer.args_from_checkpoint(_cp_0_path) + trainer = ReewcTrainer(**targs, device='cpu', memory_loader=hfo2_loader) + before = [p.detach().clone() for p in trainer.model.parameters()] + trainer.run_one_epoch(hfo2_loader, is_train=True) + after = list(trainer.model.parameters()) + assert any(not torch.allclose(b, a) for b, a in zip(before, after)) + + +def test_replay_records_memory_metrics(hfo2_loader): + # the replayed memory set is recorded so it can be logged as 'memoryset' + from sevenn.util import get_error_recorder + + targs, _, _ = Trainer.args_from_checkpoint(_cp_0_path) + trainer = ReewcTrainer(**targs, device='cpu', memory_loader=hfo2_loader) + mem_rec = get_error_recorder([('Energy', 'RMSE'), ('Force', 'RMSE')]) + trainer.run_one_epoch( + hfo2_loader, is_train=True, memory_error_recorder=mem_rec + ) + errs = mem_rec.epoch_forward() + assert len(errs) > 0 + assert all(v == v for v in errs.values()) # finite (recorder was updated) + + +# --- train_v2 reEWC guards (fail-loud, no silent fallback) --- + + +def test_train_v2_reewc_ddp_guard(tmp_path): + from sevenn.scripts.train import train_v2 + + cfg = {'rehearsal': True, 'is_ddp': True, 'continue': {'checkpoint': False}} + with pytest.raises(NotImplementedError): + train_v2(cfg, str(tmp_path)) + + +def test_train_v2_memory_path_requires_rehearsal(tmp_path): + from sevenn.scripts.train import train_v2 + + cfg = { + 'rehearsal': False, + 'load_memory_path': ['/nonexistent.sevenn_data'], + 'is_ddp': False, + 'continue': {'checkpoint': False}, + } + with pytest.raises(ValueError): + train_v2(cfg, str(tmp_path)) + + +def test_train_v2_single_modal_guard(tmp_path): + from sevenn.scripts.train import train_v2 + + cfg = { + 'rehearsal': True, + 'is_ddp': False, + 'use_modality': True, + 'load_memory_path': ['/nonexistent.sevenn_data'], + 'continue': {'checkpoint': False}, + } + with pytest.raises(ValueError): + train_v2(cfg, str(tmp_path)) + + +def test_train_v2_rehearsal_atoms_guard(tmp_path): + from sevenn.scripts.train import train_v2 + + cfg = { + 'rehearsal': True, + 'is_ddp': False, + 'dataset_type': 'atoms', + 'load_memory_path': ['/nonexistent.sevenn_data'], + 'continue': {'checkpoint': False}, + } + with pytest.raises(NotImplementedError): + train_v2(cfg, str(tmp_path))