-
Notifications
You must be signed in to change notification settings - Fork 55
Add reEWC #312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kskjs1203
wants to merge
3
commits into
MDIL-SNU:main
Choose a base branch
from
kskjs1203:reewc
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add reEWC #312
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Forgetting-prevented (Continual-learning) fine-tuning (reEWC) | ||
|
|
||
| Fine-tuning a pretrained model on a target system improves accuracy there, but the | ||
| model can lose accuracy on the original training domain (catastrophic forgetting). | ||
| reEWC mitigates this with two complementary mechanisms that can be used together or | ||
| separately: | ||
|
|
||
| - **Experience replay (rehearsal)** -- replay an old-task "memory" set each training | ||
| step so the model keeps fitting it while learning the target data. | ||
| - **Elastic Weight Consolidation (EWC)** -- add a penalty | ||
| `lambda/2 * sum_i F_i (theta_i - theta*_i)^2` that anchors parameters to their | ||
| pre-fine-tuning values `theta*`, weighted by a precomputed Fisher matrix `F`. | ||
|
|
||
| reEWC is for **single-modal** models (e.g. SevenNet-0, SevenNet-Nano). Multi-fidelity | ||
| (modal) models are not supported yet. | ||
|
|
||
| ## Getting started | ||
|
|
||
| A ready-to-edit input with both mechanisms is available as a preset: | ||
|
|
||
| ```bash | ||
| sevenn preset reewc > input.yaml | ||
| ``` | ||
|
|
||
| The preset documents every key inline. Replay lives in the `data:` block | ||
| (`rehearsal`, `load_memory_path`, `mem_batch_size`, `mem_ratio`) and EWC in the | ||
| `train.continue:` block (`fisher_information`, `opt_params`, `ewc_lambda`). Every | ||
| reEWC key is optional; when none are set, training is unchanged. Remove the replay | ||
| block or the EWC keys to run only one mechanism. Run training as usual: | ||
|
|
||
| ```bash | ||
| sevenn train input.yaml -s | ||
| ``` | ||
|
|
||
| ## Fisher information and reference parameters | ||
|
|
||
| `fisher_information` and `opt_params` are **precomputed and consumed** -- SevenNet | ||
| does not estimate the Fisher matrix. Both are `torch.save`d dictionaries keyed by | ||
| parameter name; `opt_params` is the parameter set of the checkpoint before | ||
| fine-tuning. They must satisfy: | ||
|
|
||
| - `fisher_information` and `opt_params` cover the **same parameter names** with the | ||
| **same shapes** (they are a matched pair). | ||
| - Names that overlap with the model's trainable parameters must have **matching | ||
| shapes**; a mismatch is an error (usually an incompatible checkpoint or SevenNet | ||
| version). | ||
| - At least one name must overlap with the model; no overlap is an error. | ||
| - A trainable parameter without a Fisher entry is **left unconstrained** and a | ||
| warning is emitted, so partial-coverage Fisher matrices are allowed but visible. | ||
|
|
||
| `ewc_lambda` must be `> 0`, and EWC requires both `fisher_information` and | ||
| `opt_params` to be set. | ||
|
|
||
| ## Notes | ||
|
|
||
| - Replay supports `dataset_type: 'graph'` (the default) only. | ||
| - reEWC does not support distributed (DDP) training. | ||
| - `load_memory_path` is reserved for replay: setting it without `rehearsal: True` | ||
| raises an error. | ||
| - When replay is enabled, the memory set is evaluated each epoch and logged as a | ||
| `memoryset` column group in `lc.csv`, alongside `trainset` and `validset`. | ||
| - A `cosineannealingwarmuplr` scheduler (cosine annealing with warm-up restarts, | ||
| used for the reEWC paper work) is also available for fine-tuning. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ def add_args(parser): | |
| 'base', | ||
| 'multi_modal', | ||
| 'mf_ompa_fine_tune', | ||
| 'reewc', | ||
| ], | ||
| help=preset_help | ||
| ) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| # Example input.yaml for forgetting-prevented fine-tuning (reEWC). | ||
| # Replay and EWC are independent; keep one block, both, or neither. | ||
| # reEWC is for single-modal models (e.g. SevenNet-0, SevenNet-Nano). | ||
|
|
||
| model: # keep consistent with the checkpoint being fine-tuned | ||
| chemical_species: 'Auto' | ||
| cutoff: 5.0 | ||
| channel: 128 | ||
| is_parity: False | ||
| lmax: 2 | ||
| num_convolution_layer: 5 | ||
| irreps_manual: | ||
| - "128x0e" | ||
| - "128x0e+64x1e+32x2e" | ||
| - "128x0e+64x1e+32x2e" | ||
| - "128x0e+64x1e+32x2e" | ||
| - "128x0e+64x1e+32x2e" | ||
| - "128x0e" | ||
|
|
||
| weight_nn_hidden_neurons: [64, 64] | ||
| radial_basis: | ||
| radial_basis_name: 'bessel' | ||
| bessel_basis_num: 8 | ||
| cutoff_function: | ||
| cutoff_function_name: 'XPLOR' | ||
| cutoff_on: 4.5 | ||
| self_connection_type: 'linear' | ||
|
|
||
| train_shift_scale: False | ||
| train_denominator: False | ||
|
|
||
| train: | ||
| random_seed: 1 | ||
| is_train_stress: True | ||
| epoch: 100 | ||
|
|
||
| loss: 'Huber' | ||
| loss_param: | ||
| delta: 0.01 | ||
|
|
||
| optimizer: 'adam' | ||
| optim_param: | ||
| lr: 0.004 | ||
| # cosineannealingwarmuplr (cosine annealing with warm-up restarts) was used | ||
| # for the reEWC work; exponentiallr also works. | ||
| scheduler: 'exponentiallr' | ||
| scheduler_param: | ||
| gamma: 0.99 | ||
|
|
||
| force_loss_weight: 1.0 | ||
| stress_loss_weight: 0.01 | ||
|
|
||
| per_epoch: 10 | ||
|
|
||
| error_record: | ||
| - ['Energy', 'RMSE'] | ||
| - ['Force', 'RMSE'] | ||
| - ['Stress', 'RMSE'] | ||
| - ['TotalLoss', 'None'] | ||
|
|
||
| continue: | ||
| reset_optimizer: True | ||
| reset_scheduler: True | ||
| reset_epoch: True | ||
| checkpoint: 'SevenNet-0_11July2024' | ||
|
|
||
| # EWC: anchor parameters to their pre-fine-tuning values via a | ||
| # precomputed Fisher matrix. fisher_information and opt_params are | ||
| # torch.save'd dicts {param_name: tensor} matching the model's | ||
| # trainable parameters; opt_params is the checkpoint before fine-tuning. | ||
| # All three keys are required together; remove them to disable EWC. | ||
| fisher_information: './fisher.pt' | ||
| opt_params: './opt_params.pt' | ||
| ewc_lambda: 100000 # EWC penalty weight (> 0) | ||
|
|
||
| data: | ||
| batch_size: 4 | ||
| data_divide_ratio: 0.1 | ||
| data_format_args: | ||
| index: ':' | ||
|
|
||
| load_trainset_path: ['./target_train.extxyz'] | ||
| load_validset_path: ['./valid.extxyz'] | ||
|
|
||
| # Replay (experience replay): replay an old-task memory set each step so the | ||
| # model keeps fitting it. Remove this block to disable replay. | ||
| rehearsal: True | ||
| load_memory_path: ['./memory.extxyz'] # requires rehearsal: True | ||
| mem_batch_size: 8 | ||
| mem_ratio: 1 # fraction (0, 1] of the memory set |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
이거 없었을때 기존 동작은 어떤 상태였나요
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
load state dict에서
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AtomGraphSequential:
Missing key(s) in state_dict: "0_convolution.convolution.weight", "0_convolution.convolution.output_mask", "1_convolution.convolution.weight", "1_convolution.convolution.output_mask", "1_convol
이런 오류가 계속 떴었어
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
저 코드랑 실질적으로 같은 로직이 아래있어
https://github.com/MDIL-SNU/SevenNet/blob/main/sevenn%2Fmodel_build.py#L306-L316
어떤 모델을 로딩할때 생기는 문제야? 내 생각에는 그 모델의 체크포인트가 갖고있는 weight가 flashTP의 형태인지, E3NN의 형태인지를 표기 안하고 있어서 SevenNet이 모르고, 모르는 상태에서 코드가 디폴트로 E3NN으로 가정하고 매핑하려고 해서 에러나는것 같은데 그러면 그 체크포인트가 본인이 뭘 가정하고 있는지 명시적으로 표기하게 만들어야 돼.
근데 이 PR은 체크포인트가 같이 없는데 이 부분이 수정되는게 부자연스러워. 결정적으로는 내 쪽에서 너가 겪은 불편함을 재현할 수가 없어서 정확한 문제진단을 못함.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
그냥 Omni 내장 포텐셜을 로딩했을 때를 포함해서 모든 flashtp 체크포인트 loading할 때 똑같은 문제 발생해서 그 로직이 적용안되는 거 같아 보여서 이번에 연결시킨건데 그래서 아마 omni가 fine-tuning이 안된다고 다른 분들이 알고 계셨던 거 같어 Omni를 loading 해보면 될 거 같은데
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
저번 주에 얘기한 것처럼 체크포인트에 명시하는 방향은 내가 이번에 건들기에는 어려울 거 같아서 이 부분 근본 해결이 필요할 듯