Skip to content

Add reEWC#312

Open
kskjs1203 wants to merge 2 commits into
MDIL-SNU:mainfrom
kskjs1203:reewc
Open

Add reEWC#312
kskjs1203 wants to merge 2 commits into
MDIL-SNU:mainfrom
kskjs1203:reewc

Conversation

@kskjs1203

Copy link
Copy Markdown
Collaborator

Adds reEWC for single-modal models only (sevennet-0, sevennet-nano)

  • Replay (rehearsal, load_memory_path, mem_batch_size, mem_ratio), logged as a memoryset column.
  • EWC penalty from a precomputed Fisher matrix (continue.fisher_information, continue.opt_params, continue.ewc_lambda).
  • cosineannealingwarmuplr scheduler import
  • FlashTP-saved checkpoints load on the e3nn backend when FlashTP is unavailable.

Comment thread sevenn/train/loss.py Outdated
Comment on lines +236 to +275
def _check_and_align(self, model: Callable) -> None:
if len(self.fisher_dict) == 0 or len(self.opt_params_dict) == 0:
raise ValueError('EWC fisher_information/opt_params is empty')
model_params = {
n: p for n, p in model.named_parameters() if p.requires_grad
}
if len(model_params) == 0:
raise ValueError('EWC requires the model to have trainable parameters')
if len(set(self.fisher_dict) & set(model_params)) == 0:
raise ValueError(
'EWC fisher/opt_params parameter names do not match the model; '
'the pickle was likely produced by an incompatible SevenNet '
f'version. example model param: {next(iter(model_params))}; '
f'example fisher key: {next(iter(self.fisher_dict))}'
)
# every trainable parameter must be covered by Fisher and reference
# params with the right shape, so EWC never silently skips a parameter
# it should constrain.
for name, param in model_params.items():
if name not in self.fisher_dict:
raise ValueError(
f'EWC fisher_information is missing trainable param {name}'
)
if name not in self.opt_params_dict:
raise ValueError(
f'EWC opt_params is missing trainable param {name}'
)
if self.fisher_dict[name].shape != param.shape:
raise ValueError(
f'EWC fisher shape mismatch for {name}: '
f'{tuple(self.fisher_dict[name].shape)} != {tuple(param.shape)}'
)
if self.opt_params_dict[name].shape != param.shape:
raise ValueError(
f'EWC opt_params shape mismatch for {name}: '
f'{tuple(self.opt_params_dict[name].shape)} != '
f'{tuple(param.shape)}'
)
self.to(next(iter(model_params.values())).device)
self._checked = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_check_and_align: 제 생각에는 아래와 같이 작동해야할 것 같은데요.

  1. fisher_dict vs. opt_params_dict
  • set(key)와 value.shape 가 모두 동일해야 함.
  1. fisher_dict vs. model_params
  • set(key of fisher_dict)가 set(key of model_params)에 포함됨.
  • 포함된 key들에 대해 value.shape가 모두 동일해야 함.
  • fisher_dict에 없는 model_param key의 경우 패널티 없이 warning (해당 키에 fisher 없음) 띄우고 조용히 넘어가기.

현재 버전의 로직은 특히 아래 부분에서 좀 다른 것으로 보입니다.

  • model_params에 없는 키가 fisher_dict에 존재할 수 있음.
  • fisher_dict에 없는 model_param 키가 있으면 raise error.

당연히 제가 잘못 이해했을 수 있습니다

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

이 부분이 EWC sanity check하는데 핵심처럼 보여서, 위 제약 조건은 docs에도 적히는 게 좋아 보입니다.

reEWC is for **single-modal** models (e.g. SevenNet-0, SevenNet-Nano). Multi-fidelity
(modal) models are not supported yet.

## Configuration

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yaml 내용이랑 설명을 sevenn/presets 아래로 옮기고 document 에서는 해당 file ref 하는 식으로

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sevenn preset 커맨드로 해당 폴더에있는 yaml 은 그대로 가져올 수 있음.

Comment thread sevenn/checkpoint.py
enable_oeq = cp_using_oeq if enable_oeq is None else enable_oeq

# FlashTP-saved checkpoints must still load where FlashTP is unavailable.
if enable_flash:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

이거 없었을때 기존 동작은 어떤 상태였나요

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants