From ff07bc318785cee9114a67df43dbd7d00463cc77 Mon Sep 17 00:00:00 2001 From: Han Yang Date: Thu, 28 May 2026 13:43:19 +0100 Subject: [PATCH 1/3] fix: normalize non-periodic systems in TorchSim graph construction The TorchSim code path (build_graph_from_simstate) was passing raw pbc=False and a zero cell for non-periodic molecules, while the ASE calculator path (_normalize_atoms) creates a large fake periodic cell and wraps positions into it. This caused the M3GNet model to receive different graph inputs for the same molecule, producing inconsistent energies and forces. Add _normalize_nonperiodic_systems() that replicates the same fake-cell normalization for the TorchSim path, ensuring both code paths feed identical inputs to the model. Fixes #160 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mattersim/torchsim/graph_construction.py | 73 +++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/src/mattersim/torchsim/graph_construction.py b/src/mattersim/torchsim/graph_construction.py index 245cad5..a2ed151 100644 --- a/src/mattersim/torchsim/graph_construction.py +++ b/src/mattersim/torchsim/graph_construction.py @@ -12,6 +12,62 @@ from mattersim.datasets.utils.converter import create_batch_graph_dict +def _normalize_nonperiodic_systems( + pos: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + system_idx: torch.Tensor, + n_systems: int, + twobody_cutoff: float, + threebody_cutoff: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Normalize non-periodic systems for graph construction. + + For fully non-periodic systems (pbc=False in all directions), creates a + large fake cubic cell and wraps positions into it, matching the behavior + of :func:`~mattersim.datasets.utils.converter._normalize_atoms` in the + ASE calculator code path. + + Args: + pos: [total_atoms, 3] raw Cartesian positions. + cell: [n_systems, 3, 3] unit cell matrices. + pbc: [n_systems, 3] periodic boundary condition flags. + system_idx: [total_atoms] system index per atom. + n_systems: Number of systems in the batch. + twobody_cutoff: Two-body cutoff radius in Angstrom. + threebody_cutoff: Three-body cutoff radius in Angstrom. + + Returns: + Tuple of (pos, cell, pbc) with non-periodic systems normalized. + """ + non_periodic = ~pbc.any(dim=1) # [n_systems] + if not non_periodic.any(): + return pos, cell, pbc + + device = pos.device + pos = pos.clone() + cell = cell.clone() + pbc = pbc.clone() + pad = max(twobody_cutoff, threebody_cutoff) * 5.0 + + for i in range(n_systems): + if non_periodic[i]: + mask = system_idx == i + sys_pos = pos[mask] + + extent = sys_pos.max(dim=0).values - sys_pos.min(dim=0).values + box_len = max(extent.max().item() + pad, pad) + + cell[i] = torch.eye(3, device=device, dtype=cell.dtype) * box_len + pbc[i] = True + + # Wrap positions into the fake cell (equivalent to ASE atoms.wrap()) + frac = sys_pos / box_len + pos[mask] = (frac % 1.0) * box_len + + return pos, cell, pbc + + def build_graph_from_simstate( sim_state: ts.SimState, *, @@ -48,9 +104,22 @@ def build_graph_from_simstate( else: pbc = sim_state.pbc - return create_batch_graph_dict( - pos=sim_state.wrap_positions, + # Normalize non-periodic systems: create a fake large periodic cell so + # that graph construction matches the ASE calculator code path + # (see _normalize_atoms in datasets/utils/converter.py). + pos, cell, pbc = _normalize_nonperiodic_systems( + pos=sim_state.positions, cell=sim_state.row_vector_cell, + pbc=pbc, + system_idx=sim_state.system_idx, + n_systems=sim_state.n_systems, + twobody_cutoff=twobody_cutoff, + threebody_cutoff=threebody_cutoff, + ) + + return create_batch_graph_dict( + pos=pos, + cell=cell, atomic_numbers=sim_state.atomic_numbers, num_atoms=n_atoms_per_graph, twobody_cutoff=twobody_cutoff, From 82b693868d2ec69e8b1edd1245953168f8687a5b Mon Sep 17 00:00:00 2001 From: Han Yang Date: Thu, 28 May 2026 13:43:25 +0100 Subject: [PATCH 2/3] test: add regression tests for non-periodic TorchSim normalization Add unit tests for _normalize_nonperiodic_systems() verifying: - Periodic systems pass through unchanged - Non-periodic molecules get a large fake periodic cell - Cell and positions match the ASE _normalize_atoms code path - Original SimState tensors are not mutated Add end-to-end test (test_wrapper_molecule_consistency) comparing TorchSimWrapper vs MatterSimCalculator on a water molecule. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/torchsim/test_torchsim.py | 168 +++++++++++++++++++++++++++++++- 1 file changed, 165 insertions(+), 3 deletions(-) diff --git a/tests/torchsim/test_torchsim.py b/tests/torchsim/test_torchsim.py index 3e17a49..ffb66cd 100644 --- a/tests/torchsim/test_torchsim.py +++ b/tests/torchsim/test_torchsim.py @@ -2,11 +2,13 @@ from typing import Literal +import numpy as np import pytest import torch import torch_sim as ts from mattersim.forcefield.potential import Potential +from mattersim.torchsim.graph_construction import _normalize_nonperiodic_systems from mattersim.torchsim.torchsim_wrapper import TorchSimWrapper requires_gpu = pytest.mark.skipif( @@ -45,6 +47,130 @@ def test_package_imports(): assert get_torchsim_wrapper is not None +# --------------------------------------------------------------------------- +# _normalize_nonperiodic_systems tests (lightweight, no model needed) +# --------------------------------------------------------------------------- + + +class TestNormalizeNonperiodicSystems: + """Tests for the non-periodic normalization logic.""" + + def test_periodic_system_unchanged(self, si_diamond_cubic): + """Periodic systems should pass through without modification.""" + state = ts.initialize_state( + [si_diamond_cubic], device="cpu", dtype=torch.float64 + ) + pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc + cell = state.row_vector_cell + pos = state.positions + + pos_out, cell_out, pbc_out = _normalize_nonperiodic_systems( + pos=pos, + cell=cell, + pbc=pbc, + system_idx=state.system_idx, + n_systems=state.n_systems, + twobody_cutoff=5.0, + threebody_cutoff=4.0, + ) + + # Should return the same tensor objects (no clone) when all periodic + assert pos_out is pos + assert cell_out is cell + assert pbc_out is pbc + + def test_nonperiodic_gets_fake_cell(self, water_molecule): + """Non-periodic molecules should get a large fake periodic cell.""" + state = ts.initialize_state([water_molecule], device="cpu", dtype=torch.float64) + pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc + + pos_out, cell_out, pbc_out = _normalize_nonperiodic_systems( + pos=state.positions, + cell=state.row_vector_cell, + pbc=pbc, + system_idx=state.system_idx, + n_systems=state.n_systems, + twobody_cutoff=5.0, + threebody_cutoff=4.0, + ) + + # PBC should now be True in all directions + assert pbc_out.all() + + # Cell should be a diagonal matrix with a large box length + assert cell_out[0, 0, 0] > 20.0 # at least pad = 5*5 = 25 Å + assert cell_out[0, 0, 0] == cell_out[0, 1, 1] == cell_out[0, 2, 2] + assert cell_out[0, 0, 1] == 0.0 # off-diagonal should be zero + + # All positions should be inside the box [0, box_len) + box_len = cell_out[0, 0, 0].item() + assert (pos_out >= 0).all() + assert (pos_out < box_len).all() + + def test_matches_ase_normalize_atoms(self, water_molecule): + """Normalization should produce the same cell and positions as + _normalize_atoms from the ASE code path.""" + from mattersim.datasets.utils.converter import _normalize_atoms + + twobody_cutoff = 5.0 + threebody_cutoff = 4.0 + + # ASE path + atoms_norm = _normalize_atoms(water_molecule, twobody_cutoff, threebody_cutoff) + ase_cell = np.array(atoms_norm.cell) + ase_pos = atoms_norm.positions + + # TorchSim path + state = ts.initialize_state([water_molecule], device="cpu", dtype=torch.float64) + pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc + + pos_out, cell_out, pbc_out = _normalize_nonperiodic_systems( + pos=state.positions, + cell=state.row_vector_cell, + pbc=pbc, + system_idx=state.system_idx, + n_systems=state.n_systems, + twobody_cutoff=twobody_cutoff, + threebody_cutoff=threebody_cutoff, + ) + + np.testing.assert_allclose( + cell_out[0].numpy(), + ase_cell, + atol=1e-10, + err_msg="Cell matrices should match between TorchSim and ASE paths", + ) + np.testing.assert_allclose( + pos_out.numpy(), + ase_pos, + atol=1e-10, + err_msg="Wrapped positions should match between TorchSim and ASE paths", + ) + + def test_original_tensors_not_mutated(self, water_molecule): + """Normalization should not mutate the original SimState tensors.""" + state = ts.initialize_state([water_molecule], device="cpu", dtype=torch.float64) + pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc + + orig_pos = state.positions.clone() + orig_cell = state.row_vector_cell.clone() + orig_pbc = pbc.clone() + + _normalize_nonperiodic_systems( + pos=state.positions, + cell=state.row_vector_cell, + pbc=pbc, + system_idx=state.system_idx, + n_systems=state.n_systems, + twobody_cutoff=5.0, + threebody_cutoff=4.0, + ) + + torch.testing.assert_close(state.positions, orig_pos) + torch.testing.assert_close(state.row_vector_cell, orig_cell) + torch.testing.assert_close(pbc, orig_pbc) + + # --------------------------------------------------------------------------- # TorchSimWrapper tests # --------------------------------------------------------------------------- @@ -62,9 +188,7 @@ def test_wrapper_creation(self, torchsim_wrapper: TorchSimWrapper): assert "forces" in torchsim_wrapper.implemented_properties assert "stress" in torchsim_wrapper.implemented_properties - def test_wrapper_forward( - self, torchsim_wrapper: TorchSimWrapper, si_diamond_cubic - ): + def test_wrapper_forward(self, torchsim_wrapper: TorchSimWrapper, si_diamond_cubic): state = ts.initialize_state( [si_diamond_cubic], device=DEVICE, dtype=torch.float64 ) @@ -76,3 +200,41 @@ def test_wrapper_forward( assert result["energy"].shape == (1,) assert result["forces"].shape == (len(si_diamond_cubic), 3) assert result["stress"].shape == (1, 3, 3) + + def test_wrapper_molecule_consistency( + self, torchsim_wrapper: TorchSimWrapper, water_molecule + ): + """TorchSimWrapper and MatterSimCalculator should agree on molecules. + + This is a regression test for GitHub issue #160. + """ + from mattersim.forcefield.potential import MatterSimCalculator + + # ASE calculator path + calc = MatterSimCalculator( + potential=torchsim_wrapper.model, + device=DEVICE, + direct_graph=True, + ) + water_molecule.calc = calc + ase_energy = water_molecule.get_potential_energy() + ase_forces = water_molecule.get_forces() + + # TorchSim wrapper path + state = ts.initialize_state( + [water_molecule], device=DEVICE, dtype=torch.float64 + ) + result = torchsim_wrapper(state) + + torch.testing.assert_close( + result["energy"].item(), + ase_energy, + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + result["forces"].cpu(), + torch.tensor(ase_forces, dtype=result["forces"].dtype), + rtol=1e-5, + atol=1e-5, + ) From 10366fce5a685385bf1d8e63d776f355fe203369 Mon Sep 17 00:00:00 2001 From: Han Yang Date: Thu, 28 May 2026 16:11:49 +0100 Subject: [PATCH 3/3] fix: request full FC format for phono3py v4 compatibility Phono3py v4 defaults to compact force constants format, changing fc2 shape from (n_satom, n_satom, 3, 3) to (n_patom, n_satom, 3, 3). Explicitly pass is_compact_fc=False to produce_fc2() and produce_fc3() to maintain the full format expected by downstream code. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mattersim/applications/bte.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/mattersim/applications/bte.py b/src/mattersim/applications/bte.py index dbef4b8..1403514 100644 --- a/src/mattersim/applications/bte.py +++ b/src/mattersim/applications/bte.py @@ -224,7 +224,7 @@ def compute_force_constants( second_force_sets.append(pa_second.get_forces()) phonon3.phonon_forces = second_force_sets - phonon3.produce_fc2(symmetrize_fc2=True) + phonon3.produce_fc2(symmetrize_fc2=True, is_compact_fc=False) # Compute 3rd force constants third_scs = phonon3.supercells_with_displacements @@ -237,7 +237,7 @@ def compute_force_constants( third_force_sets.append(pa_third.get_forces()) phonon3.forces = third_force_sets - phonon3.produce_fc3(symmetrize_fc3r=True) + phonon3.produce_fc3(symmetrize_fc3r=True, is_compact_fc=False) # Save to file if self.save_fcs: @@ -746,4 +746,3 @@ def get_kappa( finally: os.chdir(current_path) -