Spin/Charge System Conditioning#1080
Conversation
Adds an `extra_data_options` parameter to `MemmapDataset` that loads per-system scalar arrays from `.bin` files alongside the training targets. Each key (e.g. `mtt::charge`) maps to a `TensorMap` in the sample namedtuple and is forwarded to the `extra` argument of `CollateFn` callables. Also adds `get_extra_data_info()` to expose `TargetInfo` metadata for the extra_data keys, mirroring `get_target_info()`. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Drop stale system.add_data() calls for charge/spin from MemmapDataset.__getitem__
(data now flows through extra_data_options + get_system_conditioning_transform)
- Remove charge/spin fields from SystemsHypers in base_hypers.py
(config now lives under extra_data: {mtt::charge: {key: ...}})
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rade Old muon2 checkpoints already contain system_conditioning.* weights. Setting system_conditioning=False was dropping them silently. Now the upgrade checks for the presence of those weights and enables the hyper automatically, so converted checkpoints use the embedding they were trained with. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
or per-atom extra data is never passed to the system transform - system_data: raise if a TensorMap is per-atom (samples != ["system"]) rather than silently misindexing into the systems list - model: validate charge/spin are integer-valued before .long() conversion; log.debug when a system falls back to default 0/1 - conditioning: document zero-init gate design intent
…add test that confirms eval is working with spin/charge
|
What's the policy now on how to update the classifier/llpr checkpoints? |
# Conflicts: # pyproject.toml # src/metatrain/utils/scaler/_base_scaler.py # tests/utils/data/test_readers.py
…lse) merge The reference fixture's storage changed slightly when the scaler switched from the manual TensorBlock rebuild to block.copy(deep=False) (from PR metatensor#1144). Regenerated via test_checkpoint_did_not_change.
Use Mapping (covariant) so trainer call sites passing dict[str, LossSpecification] type-check, and assign the normalized form to a locally-typed variable so mypy narrows the union away inside the loop.
The origin/max-atom-sampler fork sits at model_ckpt_version 11 with an earlier flavour of the system-conditioning module (spin_embedding, max_spin) and pre-dates the per-property scaler split. Add a standalone model_update_from_max_atom_sampler() in pet/checkpoints.py that renames the spin_embedding state-dict keys to spin_multiplicity_embedding, renames max_spin to max_spin_multiplicity, fills defaults for adaptive_cutoff_method and the rest of the system_conditioning hypers, runs the per-property scaler migration, and bumps the version to current so the standard loader accepts it. Round-trip tested by downgrading a freshly-built v14 checkpoint into max-atom-sampler shape and confirming the migration restores something PET.load_checkpoint accepts.
The merge with PR metatensor#1144 replaced our regex catch-alls for torch.jit deprecations and the plural-name family ('features', 'positions', 'momenta', etc.) with PR metatensor#1144's narrower per-warning entries. macOS CI was failing 61 tests on a metatomic 0.1.14 UserWarning about Model.requested_inputs(use_new_names=False) that no remaining filter caught. Restore the regex catch-alls, add an explicit ignore for the new requested_inputs deprecation, and fix the one offending call site (cli/eval.py) to pass use_new_names=True so we no longer emit the warning from our own code.
…names=False\\) is deprecated:UserWarning",
|
is this ready for review or still wip/testing? there are quite a few changes, just want to make sure they're all related to the spin/charge conditioning feature and nothing extra slipped in ^^ |
These three fixes were carried on this branch but are unrelated to the spin/charge system-conditioning feature this PR is about: * ``utils/loss.py`` — accept string / per-target string shorthand in ``LossAggregator``. * ``utils/metrics.py`` — skip per-gradient metric when the prediction block lacks the gradient. * ``utils/additive/remove.py`` — zero-fill missing additive gradients before the block-wise subtract. They now live on a separate branch and will be proposed as their own PR. PET tests (including conditioning) still pass without them on this branch.
…oning-rebased # Conflicts: # pyproject.toml # src/metatrain/cli/eval.py
PR metatensor#1075 renamed ``per_atom`` to ``sample_kind`` throughout the target hypers, but the new extra_data path on this branch was still reading ``opts["per_atom"]``. After merging upstream, target_options that flow through ``sanitize_target_hypers`` no longer carry ``per_atom`` and the tests we added passed ``per_atom: False`` directly, so MemmapDataset crashed with ``KeyError: 'sample_kind'``. Switch the two extra_data sites in ``MemmapDataset`` to ``opts["sample_kind"] == "atom"`` and update the per-sample test configs to use ``sample_kind`` instead of ``per_atom``.
|
Sorry there was still one unrelated metric change. it is ready for review now @sofiia-chorna |
|
cscs-ci run |
sofiia-chorna
left a comment
There was a problem hiding this comment.
thanks for the updates
it would be nice to add a doc with an example yaml for training and evaluation (it is also OK for me to add it in the follow up PR)
left some clean up comments, otherwise seems good
| return checkpoint | ||
|
|
||
|
|
||
| def test_model_update_from_max_atom_sampler(): |
There was a problem hiding this comment.
do we need it on this branch? or it is a leftover from your max-atom-sampler branch?
| checkpoint[key] = new_state_dict | ||
|
|
||
|
|
||
| def model_update_from_max_atom_sampler(checkpoint: dict) -> dict: |
There was a problem hiding this comment.
same, is a leftover from the max-atom-sampler branch?
| update_per_property_scales(checkpoint) | ||
|
|
||
|
|
||
| def model_update_v13_v14(checkpoint: dict) -> None: |
There was a problem hiding this comment.
if i understand correctly, has_conditioning_weights (and edge_linear/spin_embedding renames) can only be triggerred for checkpoints from your muon2 branch. so i am wondering if we actually need those changes? or you need this to be able to load your trained checkpoints?
There was a problem hiding this comment.
yes due to how long the omol models have lived in my separated branch we need an extra checkpoint function to load them.
There was a problem hiding this comment.
oh i see 🥺
question: those checkpoints are already published? if not, maybe it is easier to convert them to the "main" branch format instead...
- Remove dead DiskDataset.extra_data_config / get_extra_data_info (never populated or called; MemmapDataset keeps its own) - Hoist the model checkpoint version into checkpoints.MODEL_CHECKPOINT_VERSION so the max-atom-sampler migration no longer needs a lazy "from .model import PET" - Document the batch-order invariant get_system_data_transform relies on (rows in batch order, label values are dataset indices) - Add docs: example training/eval yaml for charge and spin conditioning in the PET architecture page - Add tests: end-to-end mtt eval routing of extra_data to the conditioning module, non-integer charge/spin ValueError, and a trainer run proving the conditioning transform is wired into the collate pipeline Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
|
locally the tests run. probably just github having problems again I will rerun in an hour ... |
|
thanks for the updates! i will quickly push a minor clean up commit to move imports to the top (sorrryyy i have this obsession 😂) i will approve since the open question is non-blocking and can be fixed in the separate PR but: checkpoints that uses this feature, are they already published/shared anywhere? also interested to know the opinion from @pfebrer and @Luthaf ^^ |
System conditioning for PET (charge & spin)
Adds per-system charge and spin multiplicity conditioning to the PET architecture,
allowing a single model to be trained and evaluated across multiple charge and spin
states. The feature is activated through
architecture.model.system_conditioning: true.System Conditioning is separated into its own
SystemConditioningEmbeddingmodule (pet/modules/conditioning.py). The resulting embedding is added toPET's node features via a zero-initialised gated projection, so the model starts as
the unconditioned baseline and learns to use charge/spin information only as needed.
Charge and spin are supplied as
mtt::charge(integer, elementary charges) andmtt::spin(integer, spin multiplicity 2S+1) in theextra_datasection of thedataset config or via
atoms.infoin ASE (requires merging of a PR into metatrain).Changes
New files
src/metatrain/pet/modules/conditioning.py—SystemConditioningEmbeddingmoduleand
get_system_conditioning_transform(re-exported fromutils/system_data.py)src/metatrain/utils/system_data.py— genericget_system_data_transformcallablefor attaching per-system scalar TensorMaps to
Systemobjects in aCollateFnsrc/metatrain/pet/tests/test_conditioning.py— test suite for the featurePET model (
pet/model.py)system_conditioninghyper; if enabled, buildsSystemConditioningEmbeddingand injects the embedding into node features during both initial featurisation and
residual updates
mtt::charge/mtt::spininrequested_inputs()so the exported modelcommunicates its requirements to downstream tools (ASE calculator, eval pipeline)
Training (
pet/trainer.py)model.system_conditioning.required_data_keysand registersget_system_conditioning_transformas aCollateFncallable so charge/spin areattached to
Systemobjects during trainingCheckpoint upgrade (
pet/checkpoints.py)system_conditioning.*weights in thestate dict to auto-enable the hyper for checkpoints trained with conditioning
(avoids silent neutral-singlet fallback when loading old muon-branch checkpoints)
Eval (
cli/eval.py+utils/system_data.py)mtt evalnow readsextra_datafrom the dataset config and routes any keyspresent in the model's
requested_inputs()throughget_system_data_transform,so charge/spin reach the model during evaluation
preventing silent index errors on mixed datasets
Hypers (
pet/documentation.py,share/base_hypers.py)system_conditioningblock inModelHypers:system_conditioning: bool,max_charge: int = 10,max_spin: int = 10Training config example
Contributor (creator of pull-request) checklist
Maintainer/Reviewer checklist
📚 Documentation preview 📚: https://metatrain--1080.org.readthedocs.build/en/1080/