Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ repos:
hooks:
- id: codespell
stages: [pre-commit, commit-msg]
args: ["--ignore-words-list", "Commun"]
args: ["--ignore-words-list", "Commun,Mater"]
exclude: |
(?x)(
^example_inputs/data/|
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ All notable changes to this project will be documented in this file.

## [0.12.2.dev]
### Added
- reEWC fine-tuning with forgetting prevention for single-modal models: optional experience replay (`rehearsal`, `load_memory_path`, `mem_batch_size`, `mem_ratio`) and an Elastic Weight Consolidation penalty from a precomputed Fisher matrix (`continue.fisher_information`, `continue.opt_params`, `continue.ewc_lambda`), plus a `cosineannealingwarmuplr` scheduler
- Support OpenEquivariance
- Per-atom stress (atomic virial) support in LAMMPS pair_e3gnn and ASE calculator
- `compute_atomic_virial` option in `SevenNetCalculator`
Expand All @@ -14,6 +15,9 @@ All notable changes to this project will be documented in this file.
- LAMMPS pair_e3gnn refactored to use pair-wise force (dE/dr) instead of position-based gradient.
- Deploy no longer replaces force_output with ForceStressOutput; force/stress computed in LAMMPS C++ side.

### Fixed
- Load FlashTP-saved checkpoints (e.g. SevenNet-Nano) when FlashTP is unavailable by falling back to the e3nn backend, so they work for inference and fine-tuning without FlashTP installed. An explicit `enable_flash=True` still fails loud.

## [0.12.1]
### Fixed
- FlashTP with LAMMPS parallel in torch
Expand Down
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Full documentation, including **installation**, **usage**, and **pretrained mode
- GPU-parallelized molecular dynamics with LAMMPS
- CUDA-accelerated D3 (van der Waals) dispersion
- Multi-fidelity training for combining multiple databases with different calculation settings
- Fine-tuning with forgetting prevention (experience replay + Elastic Weight Consolidation) for continual learning
- [Tensor product accelerators](https://sevennet.readthedocs.io/en/latest/user_guide/accelerator.html)


Expand Down Expand Up @@ -71,3 +72,16 @@ If you utilize the pretrained model SevenNet-Omni or multi-task training strateg
year = {2026},
}
```

If you utilize the reEWC forgetting-aware fine-tuning strategy for continual learning of pretrained universal machine-learning interatomic potentials, please cite the following paper:
```bib
@article{kim_efficient_2026,
title = {An Efficient Forgetting-Aware Fine-Tuning Framework for Pretrained Universal Machine-Learning Interatomic Potentials},
volume = {12},
doi = {10.1038/s41524-025-01895-w},
number = {26},
journal = {npj Comput. Mater.},
author = {Kim, Jisu and Lee, Jiho and Oh, Sangmin and Park, Yutack and Hwang, Seungwoo and Han, Seungwu and Kang, Sungwoo and Kang, Youngho},
year = {2026},
}
```
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ SevenNet (Scalable EquiVariance-Enabled Neural Network) is a graph neural networ
* GPU-parallelized molecular dynamics with LAMMPS
* CUDA-accelerated D3 (van der Waals) dispersion
* Multi-fidelity training for combining multiple databases with different calculation settings
* Fine-tuning with forgetting prevention (Experience replay + Elastic Weight Consolidation) for continual learning


Installation
Expand Down
1 change: 1 addition & 0 deletions docs/source/user_guide/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ SevenNet offers various pretrained models, MD engines (ASE, LAMMPS), and user in
ase_calculator
torchsim
cli
reewc
d3
note_book
63 changes: 63 additions & 0 deletions docs/source/user_guide/reewc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Forgetting-prevented (Continual-learning) fine-tuning (reEWC)

Fine-tuning a pretrained model on a target system improves accuracy there, but the
model can lose accuracy on the original training domain (catastrophic forgetting).
reEWC mitigates this with two complementary mechanisms that can be used together or
separately:

- **Experience replay (rehearsal)** -- replay an old-task "memory" set each training
step so the model keeps fitting it while learning the target data.
- **Elastic Weight Consolidation (EWC)** -- add a penalty
`lambda/2 * sum_i F_i (theta_i - theta*_i)^2` that anchors parameters to their
pre-fine-tuning values `theta*`, weighted by a precomputed Fisher matrix `F`.

reEWC is for **single-modal** models (e.g. SevenNet-0, SevenNet-Nano). Multi-fidelity
(modal) models are not supported yet.

## Getting started

A ready-to-edit input with both mechanisms is available as a preset:

```bash
sevenn preset reewc > input.yaml
```

The preset documents every key inline. Replay lives in the `data:` block
(`rehearsal`, `load_memory_path`, `mem_batch_size`, `mem_ratio`) and EWC in the
`train.continue:` block (`fisher_information`, `opt_params`, `ewc_lambda`). Every
reEWC key is optional; when none are set, training is unchanged. Remove the replay
block or the EWC keys to run only one mechanism. Run training as usual:

```bash
sevenn train input.yaml -s
```

## Fisher information and reference parameters

`fisher_information` and `opt_params` are **precomputed and consumed** -- SevenNet
does not estimate the Fisher matrix. Both are `torch.save`d dictionaries keyed by
parameter name; `opt_params` is the parameter set of the checkpoint before
fine-tuning. They must satisfy:

- `fisher_information` and `opt_params` cover the **same parameter names** with the
**same shapes** (they are a matched pair).
- Names that overlap with the model's trainable parameters must have **matching
shapes**; a mismatch is an error (usually an incompatible checkpoint or SevenNet
version).
- At least one name must overlap with the model; no overlap is an error.
- A trainable parameter without a Fisher entry is **left unconstrained** and a
warning is emitted, so partial-coverage Fisher matrices are allowed but visible.

`ewc_lambda` must be `> 0`, and EWC requires both `fisher_information` and
`opt_params` to be set.

## Notes

- Replay supports `dataset_type: 'graph'` (the default) only.
- reEWC does not support distributed (DDP) training.
- `load_memory_path` is reserved for replay: setting it without `rehearsal: True`
raises an error.
- When replay is enabled, the memory set is evaluated each epoch and logged as a
`memoryset` column group in `lc.csv`, alongside `trainset` and `validset`.
- A `cosineannealingwarmuplr` scheduler (cosine annealing with warm-up restarts,
used for the reEWC paper work) is also available for fine-tuning.
12 changes: 12 additions & 0 deletions example_inputs/training/input_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ train:
#checkpoint: 'checkpoint_best.pth' # Checkpoint of pre-trained model or a model want to continue training.
#reset_optimizer: False # Set True for fine-tuning
#reset_scheduler: False # Set True for fine-tuning
# reEWC (single-modal models only): add an Elastic Weight Consolidation penalty from a
# precomputed Fisher matrix and reference parameters to preserve prior-task accuracy.
#fisher_information: './fisher.pt' # dict {param_name: tensor} of precomputed Fisher information
#opt_params: './opt_params.pt' # dict {param_name: tensor} of reference (pre-finetune) parameters
#ewc_lambda: 100000 # EWC penalty weight (must be > 0 when fisher/opt are given)

data:
batch_size: 4 # Per GPU batch size.
Expand All @@ -91,3 +96,10 @@ data:
load_trainset_path: ['./structure_list'] # Example of using ase as data_format, support multiple files and expansion(*)
#load_validset_path: ['./valid.extxyz']
#load_testset_path: ['./sevenn_data/mydata.pt'] # Graph can be preprocessed using `sevenn_graph_build` and accessible like this

# reEWC rehearsal (experience replay, single-modal models only): replay an old-task memory set
# each training step to mitigate catastrophic forgetting while fine-tuning on the target data.
#rehearsal: False # Set True to enable replay
#load_memory_path: ['./memory.extxyz'] # memory (old-task) set; this key is reserved for rehearsal
#mem_batch_size: 8 # batch size for the replayed memory set
#mem_ratio: 1 # fraction (0, 1] of the memory set to use
8 changes: 8 additions & 0 deletions sevenn/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def model_defaults(config):
KEY.USE_MODAL_WISE_SCALE: False,
KEY.SHIFT: 'per_atom_energy_mean',
KEY.SCALE: 'force_rms',
KEY.REHEARSAL: False,
KEY.MEM_BATCH_SIZE: 0,
KEY.MEM_RATIO: 1,
# KEY.DATA_SHUFFLE: True,
# KEY.DATA_WEIGHT: False,
# KEY.DATA_MODALITY: False,
Expand All @@ -233,6 +236,11 @@ def model_defaults(config):
KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
KEY.USE_MODAL_WISE_SHIFT: bool,
KEY.USE_MODAL_WISE_SCALE: bool,
KEY.REHEARSAL: lambda x: isinstance(x, bool),
KEY.MEM_BATCH_SIZE: lambda x: isinstance(x, int) and not isinstance(x, bool),
KEY.MEM_RATIO: lambda x: isinstance(x, (int, float))
and not isinstance(x, bool)
and 0 < x <= 1,
# KEY.DATA_SHUFFLE: bool,
KEY.COMPUTE_STATISTICS: bool,
# KEY.DATA_WEIGHT: bool,
Expand Down
7 changes: 7 additions & 0 deletions sevenn/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@
LOAD_TRAINSET = 'load_trainset_path'
LOAD_VALIDSET = 'load_validset_path'
LOAD_TESTSET = 'load_testset_path'
LOAD_MEMORY_PATH = 'load_memory_path' # reEWC rehearsal memory set
REHEARSAL = 'rehearsal'
MEM_BATCH_SIZE = 'mem_batch_size'
MEM_RATIO = 'mem_ratio'
FORMAT_OUTPUTS = 'format_outputs_for_ase'
COMPUTE_STATISTICS = 'compute_statistics'
DATASET_TYPE = 'dataset_type'
Expand Down Expand Up @@ -135,6 +139,9 @@
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = (
'use_statistic_values_for_cp_modal_only'
)
OPT_PARAMS = 'opt_params' # reEWC: reference (optimal) params pickle
FISHER = 'fisher_information' # reEWC: precomputed Fisher information pickle
EWC_LAMBDA = 'ewc_lambda' # reEWC: EWC penalty weight

CSV_LOG = 'csv_log'

Expand Down
18 changes: 18 additions & 0 deletions sevenn/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,29 @@ def build_model(
enable_cueq = cp_using_cueq if enable_cueq is None else enable_cueq

cp_using_flash = self.config.get(KEY.USE_FLASH_TP, False)
flash_requested_explicitly = enable_flash is True
enable_flash = cp_using_flash if enable_flash is None else enable_flash

cp_using_oeq = self.config.get(KEY.USE_OEQ, False)
enable_oeq = cp_using_oeq if enable_oeq is None else enable_oeq

# FlashTP-saved checkpoints must still load where FlashTP is unavailable.
if enable_flash:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

이거 없었을때 기존 동작은 어떤 상태였나요

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load state dict에서
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AtomGraphSequential:
Missing key(s) in state_dict: "0_convolution.convolution.weight", "0_convolution.convolution.output_mask", "1_convolution.convolution.weight", "1_convolution.convolution.output_mask", "1_convol
이런 오류가 계속 떴었어

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

저 코드랑 실질적으로 같은 로직이 아래있어
https://github.com/MDIL-SNU/SevenNet/blob/main/sevenn%2Fmodel_build.py#L306-L316

어떤 모델을 로딩할때 생기는 문제야? 내 생각에는 그 모델의 체크포인트가 갖고있는 weight가 flashTP의 형태인지, E3NN의 형태인지를 표기 안하고 있어서 SevenNet이 모르고, 모르는 상태에서 코드가 디폴트로 E3NN으로 가정하고 매핑하려고 해서 에러나는것 같은데 그러면 그 체크포인트가 본인이 뭘 가정하고 있는지 명시적으로 표기하게 만들어야 돼.

근데 이 PR은 체크포인트가 같이 없는데 이 부분이 수정되는게 부자연스러워. 결정적으로는 내 쪽에서 너가 겪은 불편함을 재현할 수가 없어서 정확한 문제진단을 못함.

@kskjs1203 kskjs1203 Jun 17, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

그냥 Omni 내장 포텐셜을 로딩했을 때를 포함해서 모든 flashtp 체크포인트 loading할 때 똑같은 문제 발생해서 그 로직이 적용안되는 거 같아 보여서 이번에 연결시킨건데 그래서 아마 omni가 fine-tuning이 안된다고 다른 분들이 알고 계셨던 거 같어 Omni를 loading 해보면 될 거 같은데

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

저번 주에 얘기한 것처럼 체크포인트에 명시하는 방향은 내가 이번에 건들기에는 어려울 거 같아서 이 부분 근본 해결이 필요할 듯

from sevenn.nn.flash_helper import is_flash_available

if not is_flash_available():
if flash_requested_explicitly or _flash_lammps:
raise ValueError(
'FlashTP was requested but is not available (package '
'not installed or no GPU available).'
)
warnings.warn(
'FlashTP is unavailable; loading the checkpoint with the '
'e3nn backend instead.',
UserWarning,
)
enable_flash = False

if sum([enable_cueq, enable_flash, enable_oeq]) > 1:
raise ValueError('Only one TP accelerator can be enabled.')

Expand Down
1 change: 1 addition & 0 deletions sevenn/main/sevenn_preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def add_args(parser):
'base',
'multi_modal',
'mf_ompa_fine_tune',
'reewc',
],
help=preset_help
)
Expand Down
90 changes: 90 additions & 0 deletions sevenn/presets/reewc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Example input.yaml for forgetting-prevented fine-tuning (reEWC).
# Replay and EWC are independent; keep one block, both, or neither.
# reEWC is for single-modal models (e.g. SevenNet-0, SevenNet-Nano).

model: # keep consistent with the checkpoint being fine-tuned
chemical_species: 'Auto'
cutoff: 5.0
channel: 128
is_parity: False
lmax: 2
num_convolution_layer: 5
irreps_manual:
- "128x0e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e+64x1e+32x2e"
- "128x0e"

weight_nn_hidden_neurons: [64, 64]
radial_basis:
radial_basis_name: 'bessel'
bessel_basis_num: 8
cutoff_function:
cutoff_function_name: 'XPLOR'
cutoff_on: 4.5
self_connection_type: 'linear'

train_shift_scale: False
train_denominator: False

train:
random_seed: 1
is_train_stress: True
epoch: 100

loss: 'Huber'
loss_param:
delta: 0.01

optimizer: 'adam'
optim_param:
lr: 0.004
# cosineannealingwarmuplr (cosine annealing with warm-up restarts) was used
# for the reEWC work; exponentiallr also works.
scheduler: 'exponentiallr'
scheduler_param:
gamma: 0.99

force_loss_weight: 1.0
stress_loss_weight: 0.01

per_epoch: 10

error_record:
- ['Energy', 'RMSE']
- ['Force', 'RMSE']
- ['Stress', 'RMSE']
- ['TotalLoss', 'None']

continue:
reset_optimizer: True
reset_scheduler: True
reset_epoch: True
checkpoint: 'SevenNet-0_11July2024'

# EWC: anchor parameters to their pre-fine-tuning values via a
# precomputed Fisher matrix. fisher_information and opt_params are
# torch.save'd dicts {param_name: tensor} matching the model's
# trainable parameters; opt_params is the checkpoint before fine-tuning.
# All three keys are required together; remove them to disable EWC.
fisher_information: './fisher.pt'
opt_params: './opt_params.pt'
ewc_lambda: 100000 # EWC penalty weight (> 0)

data:
batch_size: 4
data_divide_ratio: 0.1
data_format_args:
index: ':'

load_trainset_path: ['./target_train.extxyz']
load_validset_path: ['./valid.extxyz']

# Replay (experience replay): replay an old-task memory set each step so the
# model keeps fitting it. Remove this block to disable replay.
rehearsal: True
load_memory_path: ['./memory.extxyz'] # requires rehearsal: True
mem_batch_size: 8
mem_ratio: 1 # fraction (0, 1] of the memory set
18 changes: 17 additions & 1 deletion sevenn/scripts/processing_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def processing_epoch_v2(
config, trainer.loss_functions
)
recorders = {k: deepcopy(recorder) for k in loaders}
# reEWC: log the replayed memory set as a separate 'memoryset' column group.
memory_recorder = (
deepcopy(recorder)
if getattr(trainer, 'memory_loader', None) is not None
else None
)

best_val = float('inf')
best_key = None
Expand All @@ -58,6 +64,8 @@ def processing_epoch_v2(
head = ['epoch', 'lr']
for k, rec in recorders.items():
head.extend(list(rec.get_dct(prefix=k)))
if memory_recorder is not None:
head.extend(list(memory_recorder.get_dct(prefix='memoryset')))
with open(csv_path, 'w') as f:
f.write(','.join(head) + '\n')

Expand Down Expand Up @@ -88,9 +96,17 @@ def processing_epoch_v2(
loader.sampler.set_epoch(epoch)

rec = recorders[k]
trainer.run_one_epoch(loader, is_train, rec)
trainer.run_one_epoch(
loader,
is_train,
rec,
memory_error_recorder=memory_recorder if is_train else None,
)
csv_dct.update(rec.get_dct(prefix=k))
errors[k] = rec.epoch_forward()
if memory_recorder is not None:
csv_dct.update(memory_recorder.get_dct(prefix='memoryset'))
errors['memoryset'] = memory_recorder.epoch_forward()
log.write_full_table(list(errors.values()), list(errors))
trainer.scheduler_step(best_val)

Expand Down
Loading
Loading