feat: enable save/load after recovery without additional steps#619
feat: enable save/load after recovery without additional steps#619psyberck wants to merge 5 commits into
Conversation
…: get_save_before_smash_dir
…'s weights for "save_before_apply" algos
Not up to standards ⛔🔴 Issues
|
| Category | Results |
|---|---|
| Security | 3 high |
🟢 Metrics 12 complexity · 0 duplication
Metric Results Complexity 12 Duplication 0
TIP This summary will be updated as you push new changes. Give us feedback
|
Hi @psyberck! Thank you for your contribution! I've run the tests. Please take a look before we jump into any review :) |
|
Hey @sdiazlor , can you help kindly re-run the tests? There seems some network failures during the tests: |
|
Sure thing! I ran them again :) |
|
This PR has been inactive for 10 days and is now marked as stale. |
simlang
left a comment
There was a problem hiding this comment.
Thanks for looking at this! I left a view requests :)
| smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] | ||
| # Re-save with recovered weights | ||
| shutil.rmtree(save_dir, ignore_errors=True) | ||
| save_dir.mkdir(parents=True) |
There was a problem hiding this comment.
this line can be removed, as this is already handled in save_pruna_model
|
|
||
|
|
||
| @pytest.mark.cpu | ||
| def test_recovery_save_fn_is_none() -> None: |
|
|
||
|
|
||
| @pytest.mark.cpu | ||
| def test_recovery_does_not_add_to_save_fns(tmp_path) -> None: |
There was a problem hiding this comment.
i'm not sure what this is testing, as recovery is not applied
| # Simulate a save_before_apply algorithm having run before recovery: | ||
| # 1. Save original (pre-transformation) model to cache | ||
| save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config) | ||
| save_dir.mkdir(parents=True) |
There was a problem hiding this comment.
creating the directory is handled in the save
| adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_") | ||
| adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed) | ||
|
|
||
| def apply(self, model: Any, smash_config: SmashConfig) -> Any: |
There was a problem hiding this comment.
imo it would make sense to add a new saving function instead of overriding the apply here :)
| """Test that recovery refreshes a stale save_before_apply cache with recovered weights.""" | ||
| from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained("yujiepan/opt-tiny-random") |
There was a problem hiding this comment.
since we only change a single weight, i don't think it's necessary to load a HF model, but instead we probably should use a dummy model like other tests above.
| # 3. Simulate the transformation (e.g., half) + recovery modifying weights | ||
| model.lm_head.weight.data.fill_(0.99) # "recovered" weights | ||
|
|
||
| # 4. Simulate what recovery's apply() does: refresh the stale cache |
There was a problem hiding this comment.
i don't think that simulating is a good idea, because if the actual recovery-apply-save-logic changes, it's not actually being tested anymore.
|
Hi @psyberck! When you have time, could you take a look at the comments? :) |
Description
Extend save and load to also allow saving and loading after recovery without additional steps.
Related Issue
Fixes #603
Type of Change
Testing
uv run pytest -m "cpu and not slow")For full setup and testing instructions, see the Contributing Guide.
Checklist
Thanks for contributing to Pruna! We're excited to review your work.
New to contributing? Check out our Contributing Guide for everything you need to get started.
First Prune (1-year OSS anniversary)
First Prune marks one year of Pruna’s open-source work. During the initiative window, qualifying merged contributions count toward First Prune. You can earn credits for our performance models via our API.
If you’d like your contribution to count toward First Prune, here’s how it works: