Spin/Charge System Conditioning #1080

Open
JonathanSchmidt1 wants to merge 91 commits into metatensor:main from JonathanSchmidt1:only-system-conditioning-rebased

@JonathanSchmidt1 JonathanSchmidt1 commented Mar 22, 2026

Contributor

System conditioning for PET (charge & spin)

Adds per-system charge and spin multiplicity conditioning to the PET architecture,
allowing a single model to be trained and evaluated across multiple charge and spin
states. The feature is activated through architecture.model.system_conditioning: true.
System conditioning is implemented in its own SystemConditioningEmbedding module (pet/modules/conditioning.py). The resulting embedding is added to
PET's node features via a zero-initialised gated projection, so the model starts as
the unconditioned baseline and learns to use charge/spin information only as needed.
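
The zero-initialised injection described above can be sketched in plain Python (not PyTorch). The class name ToyConditioning, the table layout, and the single projection matrix are assumptions made for illustration; the actual SystemConditioningEmbedding in pet/modules/conditioning.py may be structured differently.

```python
import random


class ToyConditioning:
    """Hypothetical stand-in: charge/spin lookup tables plus a zero-init projection."""

    def __init__(self, max_charge: int, max_spin: int, dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.max_charge = max_charge
        # One row per charge in [-max_charge, +max_charge].
        self.charge_table = [
            [rng.gauss(0, 1) for _ in range(dim)] for _ in range(2 * max_charge + 1)
        ]
        # One row per spin multiplicity in [1, max_spin].
        self.spin_table = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(max_spin)]
        # The projection starts at zero, so the conditioning term vanishes at init.
        self.projection = [[0.0] * dim for _ in range(dim)]

    def __call__(self, node_features, charge: int, spin: int):
        charge_row = self.charge_table[charge + self.max_charge]
        spin_row = self.spin_table[spin - 1]
        embedding = [c + s for c, s in zip(charge_row, spin_row)]
        update = [sum(w * e for w, e in zip(row, embedding)) for row in self.projection]
        # The same per-system update is added to every atom of the system.
        return [[x + u for x, u in zip(atom, update)] for atom in node_features]


features = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]  # two atoms, three features each
cond = ToyConditioning(max_charge=5, max_spin=5, dim=3)
unchanged = cond(features, charge=-1, spin=2)
```

With the projection still at zero, `unchanged` equals `features`, which is the "starts as the unconditioned baseline" behaviour; once training moves the projection away from zero, the charge/spin embedding begins to shift the node features.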

Charge and spin are supplied as mtt::charge (integer, elementary charges) and
mtt::spin (integer, spin multiplicity 2S+1) in the extra_data section of the
dataset config or via atoms.info in ASE (requires merging of a PR into metatrain).

Changes

New files

  • src/metatrain/pet/modules/conditioning.py: SystemConditioningEmbedding module
    and get_system_conditioning_transform (re-exported from utils/system_data.py)
  • src/metatrain/utils/system_data.py: generic get_system_data_transform callable
    for attaching per-system scalar TensorMaps to System objects in a CollateFn
  • src/metatrain/pet/tests/test_conditioning.py: test suite for the feature

PET model (pet/model.py)

  • Adds system_conditioning hyper; if enabled, builds SystemConditioningEmbedding
    and injects the embedding into node features during both initial featurisation and
    residual updates
  • Declares mtt::charge / mtt::spin in requested_inputs() so the exported model
    communicates its requirements to downstream tools (ASE calculator, eval pipeline)

Training (pet/trainer.py)

  • Reads model.system_conditioning.required_data_keys and registers
    get_system_conditioning_transform as a CollateFn callable so charge/spin are
    attached to System objects during training

Checkpoint upgrade (pet/checkpoints.py)

  • v11 → v12 upgrade detects the presence of system_conditioning.* weights in the
    state dict to auto-enable the hyper for checkpoints trained with conditioning
    (avoids silent neutral-singlet fallback when loading old muon-branch checkpoints)
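
A minimal sketch of the weight-detection idea, using plain dictionaries. The function name and the checkpoint layout ("model_state_dict", "model_data", "model_hypers") are assumptions for illustration; only the "system_conditioning." key prefix comes from the description above.

```python
def enable_conditioning_if_weights_present(checkpoint: dict) -> None:
    """Hypothetical upgrade step: switch the hyper on when conditioning weights exist."""
    state_dict = checkpoint["model_state_dict"]  # assumed key name
    has_weights = any(key.startswith("system_conditioning.") for key in state_dict)
    hypers = checkpoint["model_data"]["model_hypers"]  # assumed layout
    hypers["system_conditioning"] = has_weights


trained_with_conditioning = {
    "model_state_dict": {"system_conditioning.charge_embedding.weight": [0.1, 0.2]},
    "model_data": {"model_hypers": {}},
}
plain_checkpoint = {
    "model_state_dict": {"node_embedder.weight": [0.3]},
    "model_data": {"model_hypers": {}},
}
enable_conditioning_if_weights_present(trained_with_conditioning)
enable_conditioning_if_weights_present(plain_checkpoint)
```

Deriving the flag from the weights, rather than defaulting it to false, is what keeps a conditioned checkpoint from being loaded as an unconditioned model.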

Eval (cli/eval.py + utils/system_data.py)

  • mtt eval now reads extra_data from the dataset config and routes any keys
    present in the model's requested_inputs() through get_system_data_transform,
    so charge/spin reach the model during evaluation
  • The transform raises early if a TensorMap is per-atom rather than per-system,
    preventing silent index errors on mixed datasets
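
The routing and the early per-system check can be sketched as follows. Both helper names are hypothetical, and sample names are modelled as plain lists of strings rather than metatensor Labels; this is not the code in utils/system_data.py.

```python
def select_extra_inputs(extra_data_keys, requested_inputs):
    """Keep only the extra_data keys that the model actually requests."""
    return [key for key in extra_data_keys if key in requested_inputs]


def check_per_system(name, sample_names):
    """Raise early when the data is not indexed by system alone."""
    if list(sample_names) != ["system"]:
        raise ValueError(
            f"'{name}' must be per-system (samples == ['system']), "
            f"got samples {list(sample_names)}"
        )


routed = select_extra_inputs(
    ["mtt::charge", "mtt::spin", "mtt::unused"],
    {"mtt::charge", "mtt::spin"},
)
check_per_system("mtt::charge", ["system"])  # passes silently

try:
    check_per_system("mtt::charge", ["system", "atom"])
    per_atom_rejected = False
except ValueError:
    per_atom_rejected = True
```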

Hypers (pet/documentation.py, share/base_hypers.py)

  • New system_conditioning block in ModelHypers:
    system_conditioning: bool, max_charge: int = 10, max_spin: int = 10

Training config example

architecture:
  model:
    system_conditioning: true
    max_charge: 5   # embeds charges in [-5, +5]
    max_spin: 5     # embeds multiplicities in [1, 5]

training_set:
  - path: dataset.mtt
    extra_data:
      mtt::charge:
        field: charge
      mtt::spin:
        field: spin
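
The ranges in the comments above suggest a simple mapping from physical values to embedding rows. The sketch below assumes an offset scheme (charge + max_charge, multiplicity - 1); the function names and the scheme itself are illustrative and not confirmed by the PR.

```python
def charge_to_index(charge: int, max_charge: int) -> int:
    """Map a charge in [-max_charge, +max_charge] to a row in [0, 2 * max_charge]."""
    if not -max_charge <= charge <= max_charge:
        raise ValueError(f"charge {charge} outside [-{max_charge}, +{max_charge}]")
    return charge + max_charge


def spin_to_index(multiplicity: int, max_spin: int) -> int:
    """Map a multiplicity 2S+1 in [1, max_spin] to a row in [0, max_spin - 1]."""
    if not 1 <= multiplicity <= max_spin:
        raise ValueError(f"multiplicity {multiplicity} outside [1, {max_spin}]")
    return multiplicity - 1


neutral_singlet = (charge_to_index(0, max_charge=5), spin_to_index(1, max_spin=5))
anion_doublet = (charge_to_index(-1, max_charge=5), spin_to_index(2, max_spin=5))
```

Under this assumed scheme, max_charge: 5 needs 11 charge rows and max_spin: 5 needs 5 spin rows, and any value outside the configured range fails loudly instead of indexing out of bounds.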

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Maintainer/Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?
  • GPU tests passed (maintainer comment: "cscs-ci run")?

📚 Documentation preview 📚: https://metatrain--1080.org.readthedocs.build/en/1080/

JonathanSchmidt1 and others added 26 commits March 19, 2026 13:36
Adds an `extra_data_options` parameter to `MemmapDataset` that loads
per-system scalar arrays from `.bin` files alongside the training targets.
Each key (e.g. `mtt::charge`) maps to a `TensorMap` in the sample namedtuple
and is forwarded to the `extra` argument of `CollateFn` callables.

Also adds `get_extra_data_info()` to expose `TargetInfo` metadata for the
extra_data keys, mirroring `get_target_info()`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Drop stale system.add_data() calls for charge/spin from MemmapDataset.__getitem__
  (data now flows through extra_data_options + get_system_conditioning_transform)
- Remove charge/spin fields from SystemsHypers in base_hypers.py
  (config now lives under extra_data: {mtt::charge: {key: ...}})

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rade

Old muon2 checkpoints already contain system_conditioning.* weights.
Setting system_conditioning=False was dropping them silently. Now the
upgrade checks for the presence of those weights and enables the hyper
automatically, so converted checkpoints use the embedding they were
trained with.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  or per-atom extra data is never passed to the system transform
- system_data: raise if a TensorMap is per-atom (samples != ["system"])
  rather than silently misindexing into the systems list
- model: validate charge/spin are integer-valued before .long()
  conversion; log.debug when a system falls back to default 0/1
- conditioning: document zero-init gate design intent
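
The integer validation and the default fallback mentioned in the model bullet could look roughly like this. The helper names are hypothetical and the sketch works on plain Python numbers rather than tensors.

```python
import logging

log = logging.getLogger(__name__)


def as_integer(value: float, name: str) -> int:
    """Reject non-integer charge/spin values instead of silently truncating them."""
    if not float(value).is_integer():
        raise ValueError(f"{name} must be integer-valued, got {value!r}")
    return int(value)


def resolve(value, name: str, default: int) -> int:
    """Use the default (0 for charge, 1 for spin) when a system carries no value."""
    if value is None:
        log.debug("no %s supplied, falling back to default %d", name, default)
        return default
    return as_integer(value, name)


charge = resolve(-1.0, "mtt::charge", default=0)
spin = resolve(None, "mtt::spin", default=1)
```

Checking before the integer conversion matters because a cast would turn a value such as 1.5 into 1 without any error.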
…add test that confirms eval is working with spin/charge
@JonathanSchmidt1

Contributor Author

What's the policy now on how to update the classifier/llpr checkpoints?

sofiia-chorna and others added 6 commits May 20, 2026 16:31
# Conflicts:
#	pyproject.toml
#	src/metatrain/utils/scaler/_base_scaler.py
#	tests/utils/data/test_readers.py
…lse) merge

The reference fixture's storage changed slightly when the scaler switched
from the manual TensorBlock rebuild to block.copy(deep=False) (from PR
metatensor#1144). Regenerated via test_checkpoint_did_not_change.
JonathanSchmidt1 and others added 4 commits May 21, 2026 14:14
Use Mapping (covariant) so trainer call sites passing dict[str,
LossSpecification] type-check, and assign the normalized form to a
locally-typed variable so mypy narrows the union away inside the loop.
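
A small typing sketch of the point made in this commit message. LossSpec is a hypothetical stand-in for LossSpecification, and normalize is an illustrative name, not a function in the PR.

```python
from typing import Dict, Mapping, Union


class LossSpec:
    """Hypothetical stand-in for LossSpecification."""

    def __init__(self, kind: str) -> None:
        self.kind = kind


LossInput = Union[str, LossSpec]


def normalize(losses: Mapping[str, LossInput]) -> Dict[str, LossSpec]:
    # Mapping is covariant in its value type, so a Dict[str, LossSpec]
    # argument type-checks here; a Dict[str, LossInput] parameter would
    # reject it because Dict is invariant.
    normalized: Dict[str, LossSpec] = {}
    for target, spec in losses.items():
        normalized[target] = LossSpec(spec) if isinstance(spec, str) else spec
    # From here on, mypy sees plain LossSpec values, not the union.
    return normalized


only_specs: Dict[str, LossSpec] = {"energy": LossSpec("mse")}
mixed = normalize({"energy": "mse", "forces": LossSpec("mae")})
from_specs = normalize(only_specs)
```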
The origin/max-atom-sampler fork sits at model_ckpt_version 11 with an
earlier flavour of the system-conditioning module (spin_embedding,
max_spin) and pre-dates the per-property scaler split. Add a standalone
model_update_from_max_atom_sampler() in pet/checkpoints.py that renames
the spin_embedding state-dict keys to spin_multiplicity_embedding,
renames max_spin to max_spin_multiplicity, fills defaults for
adaptive_cutoff_method and the rest of the system_conditioning hypers,
runs the per-property scaler migration, and bumps the version to current
so the standard loader accepts it.

Round-trip tested by downgrading a freshly-built v14 checkpoint into
max-atom-sampler shape and confirming the migration restores something
PET.load_checkpoint accepts.
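
A simplified sketch of the renaming part of that migration. The function name, the "model_state_dict" and "model_hypers" keys, and the version value 14 (taken from the "v14" mention above) are assumptions; the default filling and the scaler migration are left out.

```python
MODEL_CHECKPOINT_VERSION = 14  # assumed value for illustration


def migrate_fork_checkpoint(checkpoint: dict) -> dict:
    """Hypothetical, reduced version of the fork migration: renames only."""
    state = checkpoint["model_state_dict"]  # assumed key name
    checkpoint["model_state_dict"] = {
        key.replace("spin_embedding", "spin_multiplicity_embedding"): value
        for key, value in state.items()
    }
    hypers = checkpoint["model_hypers"]  # assumed key name
    if "max_spin" in hypers:
        hypers["max_spin_multiplicity"] = hypers.pop("max_spin")
    checkpoint["model_ckpt_version"] = MODEL_CHECKPOINT_VERSION
    return checkpoint


fork_checkpoint = {
    "model_ckpt_version": 11,
    "model_state_dict": {"system_conditioning.spin_embedding.weight": [0.0]},
    "model_hypers": {"max_spin": 5},
}
migrated = migrate_fork_checkpoint(fork_checkpoint)
```

Because an already-renamed key no longer contains the substring "spin_embedding", running the rename twice leaves the keys unchanged.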
The merge with PR metatensor#1144 replaced our regex catch-alls for torch.jit
deprecations and the plural-name family ('features', 'positions',
'momenta', etc.) with PR metatensor#1144's narrower per-warning entries. macOS CI
was failing 61 tests on a metatomic 0.1.14 UserWarning about
Model.requested_inputs(use_new_names=False) that no remaining filter
caught.

Restore the regex catch-alls, add an explicit ignore for the new
requested_inputs deprecation, and fix the one offending call site
(cli/eval.py) to pass use_new_names=True so we no longer emit the
warning from our own code.
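
The effect of a regex catch-all filter can be illustrated with the standard warnings module. The PR configures its filters in pyproject.toml; the pattern and the warning texts below are made-up placeholders, not the actual metatomic messages.

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The message argument is a regex matched against the start of the text.
    warnings.filterwarnings(
        "ignore", message=r".*requested_inputs.*", category=UserWarning
    )
    warnings.warn("requested_inputs(use_new_names=False) is deprecated", UserWarning)
    warnings.warn("an unrelated warning", UserWarning)

remaining = [str(w.message) for w in caught]
```

Only the unrelated warning is recorded in `remaining`; the one matching the pattern is dropped, which is why a missing filter entry is enough to surface as test failures when warnings are treated as errors.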
Comment thread pyproject.toml Outdated
@sofiia-chorna

Collaborator

is this ready for review or still wip/testing? there are quite a few changes, just want to make sure they're all related to the spin/charge conditioning feature and nothing extra slipped in ^^

These three fixes were carried on this branch but are unrelated to the
spin/charge system-conditioning feature this PR is about:

* ``utils/loss.py`` — accept string / per-target string shorthand in
  ``LossAggregator``.
* ``utils/metrics.py`` — skip per-gradient metric when the prediction
  block lacks the gradient.
* ``utils/additive/remove.py`` — zero-fill missing additive gradients
  before the block-wise subtract.

They now live on a separate branch and will be proposed as their own
PR. PET tests (including conditioning) still pass without them on this
branch.
…oning-rebased

# Conflicts:
#	pyproject.toml
#	src/metatrain/cli/eval.py
PR metatensor#1075 renamed ``per_atom`` to ``sample_kind`` throughout the target
hypers, but the new extra_data path on this branch was still reading
``opts["per_atom"]``. After merging upstream, target_options that flow
through ``sanitize_target_hypers`` no longer carry ``per_atom`` and the
tests we added passed ``per_atom: False`` directly, so MemmapDataset
crashed with ``KeyError: 'sample_kind'``.

Switch the two extra_data sites in ``MemmapDataset`` to
``opts["sample_kind"] == "atom"`` and update the per-sample test
configs to use ``sample_kind`` instead of ``per_atom``.
@JonathanSchmidt1

Contributor Author

Sorry, there was still one unrelated metric change. It is ready for review now, @sofiia-chorna.

@Luthaf

Luthaf commented Jun 3, 2026

Member

cscs-ci run

@sofiia-chorna sofiia-chorna left a comment

Collaborator

thanks for the updates

it would be nice to add a doc with an example yaml for training and evaluation (it is also OK for me if it is added in a follow-up PR)

left some clean up comments, otherwise seems good

return checkpoint


def test_model_update_from_max_atom_sampler():

Collaborator

do we need it on this branch? or is it a leftover from your max-atom-sampler branch?

checkpoint[key] = new_state_dict


def model_update_from_max_atom_sampler(checkpoint: dict) -> dict:

Collaborator

same, is this a leftover from the max-atom-sampler branch?

Contributor Author

yes same

update_per_property_scales(checkpoint)


def model_update_v13_v14(checkpoint: dict) -> None:

Collaborator

if i understand correctly, has_conditioning_weights (and the edge_linear/spin_embedding renames) can only be triggered for checkpoints from your muon2 branch. so i am wondering if we actually need those changes? or do you need this to be able to load your trained checkpoints?

Contributor Author

yes, because the omol models have lived in my separate branch for so long, we need an extra checkpoint function to load them.

Collaborator

oh i see 🥺

question: are those checkpoints already published? if not, maybe it is easier to convert them to the "main" branch format instead...

Comment thread src/metatrain/utils/data/dataset.py Outdated
- Remove dead DiskDataset.extra_data_config / get_extra_data_info
  (never populated or called; MemmapDataset keeps its own)
- Hoist the model checkpoint version into
  checkpoints.MODEL_CHECKPOINT_VERSION so the max-atom-sampler
  migration no longer needs a lazy "from .model import PET"
- Document the batch-order invariant get_system_data_transform
  relies on (rows in batch order, label values are dataset indices)
- Add docs: example training/eval yaml for charge and spin
  conditioning in the PET architecture page
- Add tests: end-to-end mtt eval routing of extra_data to the
  conditioning module, non-integer charge/spin ValueError, and a
  trainer run proving the conditioning transform is wired into the
  collate pipeline

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@JonathanSchmidt1

Contributor Author

Locally the tests run. It is probably just GitHub having problems again; I will rerun in an hour.

@sofiia-chorna

Collaborator

thanks for the updates! i will quickly push a minor clean up commit to move imports to the top (sorrryyy i have this obsession 😂)

i will approve since the open question is non-blocking and can be fixed in a separate PR, but:

checkpoints that use this feature, are they already published/shared anywhere?
if yes => we probably need to keep your model_update_from_max_atom_sampler so the fork checkpoints can still be loaded on main
if no => I think it'd be cleaner to convert your fork checkpoints once into the new format of this PR and drop the migration function

also interested to know the opinion from @pfebrer and @Luthaf ^^

4 participants