Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/mattersim/applications/bte.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def compute_force_constants(
second_force_sets.append(pa_second.get_forces())

phonon3.phonon_forces = second_force_sets
phonon3.produce_fc2(symmetrize_fc2=True)
phonon3.produce_fc2(symmetrize_fc2=True, is_compact_fc=False)

# Compute 3rd force constants
third_scs = phonon3.supercells_with_displacements
Expand All @@ -237,7 +237,7 @@ def compute_force_constants(
third_force_sets.append(pa_third.get_forces())

phonon3.forces = third_force_sets
phonon3.produce_fc3(symmetrize_fc3r=True)
phonon3.produce_fc3(symmetrize_fc3r=True, is_compact_fc=False)

# Save to file
if self.save_fcs:
Expand Down Expand Up @@ -746,4 +746,3 @@ def get_kappa(

finally:
os.chdir(current_path)

73 changes: 71 additions & 2 deletions src/mattersim/torchsim/graph_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,62 @@
from mattersim.datasets.utils.converter import create_batch_graph_dict


def _normalize_nonperiodic_systems(
pos: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
system_idx: torch.Tensor,
n_systems: int,
twobody_cutoff: float,
threebody_cutoff: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Normalize non-periodic systems for graph construction.

For fully non-periodic systems (pbc=False in all directions), creates a
large fake cubic cell and wraps positions into it, matching the behavior
of :func:`~mattersim.datasets.utils.converter._normalize_atoms` in the
ASE calculator code path.

Args:
pos: [total_atoms, 3] raw Cartesian positions.
cell: [n_systems, 3, 3] unit cell matrices.
pbc: [n_systems, 3] periodic boundary condition flags.
system_idx: [total_atoms] system index per atom.
n_systems: Number of systems in the batch.
twobody_cutoff: Two-body cutoff radius in Angstrom.
threebody_cutoff: Three-body cutoff radius in Angstrom.

Returns:
Tuple of (pos, cell, pbc) with non-periodic systems normalized.
"""
non_periodic = ~pbc.any(dim=1) # [n_systems]
if not non_periodic.any():
return pos, cell, pbc

device = pos.device
pos = pos.clone()
cell = cell.clone()
pbc = pbc.clone()
pad = max(twobody_cutoff, threebody_cutoff) * 5.0

for i in range(n_systems):
if non_periodic[i]:
mask = system_idx == i
sys_pos = pos[mask]

extent = sys_pos.max(dim=0).values - sys_pos.min(dim=0).values
box_len = max(extent.max().item() + pad, pad)

cell[i] = torch.eye(3, device=device, dtype=cell.dtype) * box_len
pbc[i] = True

# Wrap positions into the fake cell (equivalent to ASE atoms.wrap())
frac = sys_pos / box_len
pos[mask] = (frac % 1.0) * box_len

return pos, cell, pbc


def build_graph_from_simstate(
sim_state: ts.SimState,
*,
Expand Down Expand Up @@ -48,9 +104,22 @@ def build_graph_from_simstate(
else:
pbc = sim_state.pbc

return create_batch_graph_dict(
pos=sim_state.wrap_positions,
# Normalize non-periodic systems: create a fake large periodic cell so
# that graph construction matches the ASE calculator code path
# (see _normalize_atoms in datasets/utils/converter.py).
pos, cell, pbc = _normalize_nonperiodic_systems(
pos=sim_state.positions,
cell=sim_state.row_vector_cell,
pbc=pbc,
system_idx=sim_state.system_idx,
n_systems=sim_state.n_systems,
twobody_cutoff=twobody_cutoff,
threebody_cutoff=threebody_cutoff,
)

return create_batch_graph_dict(
pos=pos,
cell=cell,
atomic_numbers=sim_state.atomic_numbers,
num_atoms=n_atoms_per_graph,
twobody_cutoff=twobody_cutoff,
Expand Down
168 changes: 165 additions & 3 deletions tests/torchsim/test_torchsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from typing import Literal

import numpy as np
import pytest
import torch
import torch_sim as ts

from mattersim.forcefield.potential import Potential
from mattersim.torchsim.graph_construction import _normalize_nonperiodic_systems
from mattersim.torchsim.torchsim_wrapper import TorchSimWrapper

requires_gpu = pytest.mark.skipif(
Expand Down Expand Up @@ -45,6 +47,130 @@ def test_package_imports():
assert get_torchsim_wrapper is not None


# ---------------------------------------------------------------------------
# _normalize_nonperiodic_systems tests (lightweight, no model needed)
# ---------------------------------------------------------------------------


class TestNormalizeNonperiodicSystems:
"""Tests for the non-periodic normalization logic."""

def test_periodic_system_unchanged(self, si_diamond_cubic):
"""Periodic systems should pass through without modification."""
state = ts.initialize_state(
[si_diamond_cubic], device="cpu", dtype=torch.float64
)
pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc
cell = state.row_vector_cell
pos = state.positions

pos_out, cell_out, pbc_out = _normalize_nonperiodic_systems(
pos=pos,
cell=cell,
pbc=pbc,
system_idx=state.system_idx,
n_systems=state.n_systems,
twobody_cutoff=5.0,
threebody_cutoff=4.0,
)

# Should return the same tensor objects (no clone) when all periodic
assert pos_out is pos
assert cell_out is cell
assert pbc_out is pbc

def test_nonperiodic_gets_fake_cell(self, water_molecule):
"""Non-periodic molecules should get a large fake periodic cell."""
state = ts.initialize_state([water_molecule], device="cpu", dtype=torch.float64)
pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc

pos_out, cell_out, pbc_out = _normalize_nonperiodic_systems(
pos=state.positions,
cell=state.row_vector_cell,
pbc=pbc,
system_idx=state.system_idx,
n_systems=state.n_systems,
twobody_cutoff=5.0,
threebody_cutoff=4.0,
)

# PBC should now be True in all directions
assert pbc_out.all()

# Cell should be a diagonal matrix with a large box length
assert cell_out[0, 0, 0] > 20.0 # at least pad = 5*5 = 25 Å
assert cell_out[0, 0, 0] == cell_out[0, 1, 1] == cell_out[0, 2, 2]
assert cell_out[0, 0, 1] == 0.0 # off-diagonal should be zero

# All positions should be inside the box [0, box_len)
box_len = cell_out[0, 0, 0].item()
assert (pos_out >= 0).all()
assert (pos_out < box_len).all()

def test_matches_ase_normalize_atoms(self, water_molecule):
"""Normalization should produce the same cell and positions as
_normalize_atoms from the ASE code path."""
from mattersim.datasets.utils.converter import _normalize_atoms

twobody_cutoff = 5.0
threebody_cutoff = 4.0

# ASE path
atoms_norm = _normalize_atoms(water_molecule, twobody_cutoff, threebody_cutoff)
ase_cell = np.array(atoms_norm.cell)
ase_pos = atoms_norm.positions

# TorchSim path
state = ts.initialize_state([water_molecule], device="cpu", dtype=torch.float64)
pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc

pos_out, cell_out, pbc_out = _normalize_nonperiodic_systems(
pos=state.positions,
cell=state.row_vector_cell,
pbc=pbc,
system_idx=state.system_idx,
n_systems=state.n_systems,
twobody_cutoff=twobody_cutoff,
threebody_cutoff=threebody_cutoff,
)

np.testing.assert_allclose(
cell_out[0].numpy(),
ase_cell,
atol=1e-10,
err_msg="Cell matrices should match between TorchSim and ASE paths",
)
np.testing.assert_allclose(
pos_out.numpy(),
ase_pos,
atol=1e-10,
err_msg="Wrapped positions should match between TorchSim and ASE paths",
)

def test_original_tensors_not_mutated(self, water_molecule):
"""Normalization should not mutate the original SimState tensors."""
state = ts.initialize_state([water_molecule], device="cpu", dtype=torch.float64)
pbc = state.pbc.unsqueeze(0) if state.pbc.dim() == 1 else state.pbc

orig_pos = state.positions.clone()
orig_cell = state.row_vector_cell.clone()
orig_pbc = pbc.clone()

_normalize_nonperiodic_systems(
pos=state.positions,
cell=state.row_vector_cell,
pbc=pbc,
system_idx=state.system_idx,
n_systems=state.n_systems,
twobody_cutoff=5.0,
threebody_cutoff=4.0,
)

torch.testing.assert_close(state.positions, orig_pos)
torch.testing.assert_close(state.row_vector_cell, orig_cell)
torch.testing.assert_close(pbc, orig_pbc)


# ---------------------------------------------------------------------------
# TorchSimWrapper tests
# ---------------------------------------------------------------------------
Expand All @@ -62,9 +188,7 @@ def test_wrapper_creation(self, torchsim_wrapper: TorchSimWrapper):
assert "forces" in torchsim_wrapper.implemented_properties
assert "stress" in torchsim_wrapper.implemented_properties

def test_wrapper_forward(
self, torchsim_wrapper: TorchSimWrapper, si_diamond_cubic
):
def test_wrapper_forward(self, torchsim_wrapper: TorchSimWrapper, si_diamond_cubic):
state = ts.initialize_state(
[si_diamond_cubic], device=DEVICE, dtype=torch.float64
)
Expand All @@ -76,3 +200,41 @@ def test_wrapper_forward(
assert result["energy"].shape == (1,)
assert result["forces"].shape == (len(si_diamond_cubic), 3)
assert result["stress"].shape == (1, 3, 3)

def test_wrapper_molecule_consistency(
self, torchsim_wrapper: TorchSimWrapper, water_molecule
):
"""TorchSimWrapper and MatterSimCalculator should agree on molecules.

This is a regression test for GitHub issue #160.
"""
from mattersim.forcefield.potential import MatterSimCalculator

# ASE calculator path
calc = MatterSimCalculator(
potential=torchsim_wrapper.model,
device=DEVICE,
direct_graph=True,
)
water_molecule.calc = calc
ase_energy = water_molecule.get_potential_energy()
ase_forces = water_molecule.get_forces()

# TorchSim wrapper path
state = ts.initialize_state(
[water_molecule], device=DEVICE, dtype=torch.float64
)
result = torchsim_wrapper(state)

torch.testing.assert_close(
result["energy"].item(),
ase_energy,
rtol=1e-5,
atol=1e-5,
)
torch.testing.assert_close(
result["forces"].cpu(),
torch.tensor(ase_forces, dtype=result["forces"].dtype),
rtol=1e-5,
atol=1e-5,
)
Loading