def _load_dftd3_parameters(
    dtype: Optional[torch.dtype] = None,
) -> Dict[str, torch.Tensor]:
    """Load the packaged Grimme D3 reference tables.

    The returned tensors are in the atomic-unit convention expected by
    :class:`DFTD3`: ``rcov`` and ``r4r2`` in Bohr, ``c6`` in
    Hartree * Bohr^6, and dimensionless ``cn_ref`` values. The table layout is
    the current pure-PyTorch wrapper layout: ``c6`` has shape ``(Z, Z, M, M)``
    and ``cn_ref`` has shape ``(Z, M)``.

    :param dtype: if given, every table is converted to this dtype; if
        ``None``, the dtypes stored in the archive are kept.
    :return: a mapping with one tensor for each name listed in
        ``_REQUIRED_D3_TABLES``.
    """
    # Chain single-segment ``joinpath`` calls: ``Traversable.joinpath`` only
    # accepts several path segments in one call from Python 3.11 on, so the
    # chained form also works for zip-imported packages on older versions.
    path = files("metatomic.torch").joinpath("data").joinpath(_D3_PARAMETERS_NPZ)
    with path.open("rb") as fd, np.load(fd) as data:
        # Copy each array so the resulting tensor owns its memory and does
        # not depend on the archive object after it is closed.
        params = {
            key: torch.from_numpy(data[key].copy()) for key in _REQUIRED_D3_TABLES
        }

    if dtype is not None:
        params = {key: value.to(dtype=dtype) for key, value in params.items()}
    return params
+ + ``selected_atoms`` is supported with the usual domain-decomposition + convention: the D3 environment is computed with all atoms in each + :class:`System`, while pair energies are split equally between the two pair + endpoints and only the shares belonging to selected atoms are added. + Per-atom energy outputs use the same half-pair split. + + ``excluded_atom_types`` can be used to disable D3 pair energies involving + specific atom types, while keeping all atoms in the D3 coordination-number + environment. This is useful for systems where D3 should not be applied to + pairs involving selected species, such as common cations. + + The D3 reference tables (``d3_params``) are shared across variants, + matching the convention that the Grimme reference data is + functional-independent. The reference tables include four sets of parameters, ``rcov`` + (dimension of length, Bohr), ``r4r2`` (dimension of length, Bohr), ``c6`` (dimension + of energy * length^6, Hartree * Bohr^6), and ``cn_ref`` (dimensionless). Damping + parameters (``a1``, ``a2``, ``s8``, ...) are provided per variant. Among these + damping parameters, only ``a2`` has a dimension of length (Bohr); the others are + dimensionless. All D3 tables and damping parameters **must** be passed in **atomic + units**. The wrapper converts the final D3 energy into the wrapped model's energy + unit of the corresponding output. 
def __init__(
    self,
    model: AtomisticModel,
    *,
    damping_params: Dict[str, Dict[str, float]],
    d3_params: Optional[Dict[str, torch.Tensor]] = None,
    cutoff: Optional[float] = None,
    cn_cutoff: Optional[float] = None,
    excluded_atom_types: Optional[List[int]] = None,
):
    """
    :param model: the :py:class:`AtomisticModel` to wrap
    :param damping_params: a mapping from an output key to a mapping of damping
        parameters for that output. The part of the key before the first
        ``/`` must be ``"energy"``, ``"non_conservative_force"`` or
        ``"non_conservative_stress"``. Each damping map must provide
        ``a1``, ``a2`` and ``s8``; ``s6`` is optional (defaults to 1.0).
    :param d3_params: shared DFT-D3 reference tables with keys ``"rcov"``
        (shape ``(Z,)``), ``"r4r2"`` (shape ``(Z,)``), ``"c6"`` (shape
        ``(Z, Z, M, M)``) and ``"cn_ref"`` (shape ``(Z, M)``, the per-element
        CN reference grid, with ``-1`` marking absent slots), where ``Z-1``
        is the maximum atomic number supported by the tables (because of the
        0-based indexing), and ``M`` is the number of CN reference points.
        Tables must be in D3 atomic units. If ``None``, the packaged Grimme
        D3 reference tables are used.
    :param cutoff: dispersion-pair cutoff in the wrapped model's length
        unit. If ``None``, defaults to the standard Grimme value of
        ``50 Bohr`` converted into the model's length unit.
    :param cn_cutoff: coordination-number cutoff in the wrapped model's
        length unit. If ``None``, defaults to ``25 Bohr`` converted into
        the model's length unit.
    :param excluded_atom_types: optional atom types for which D3 pair
        energies should be disabled. Any pair where either endpoint has one
        of these types contributes zero D3 energy. Coordination numbers are
        still computed with all atoms.

    The wrapped model's atomic types must be real atomic numbers; these
    are used directly to index the D3 parameter tables.

    :raises TypeError: if ``model`` is not an :py:class:`AtomisticModel`
    :raises KeyError: if a reference table or a required damping parameter
        is missing
    :raises ValueError: if ``damping_params`` is empty, a key does not name
        a supported output of the wrapped model, a unit is undefined, or a
        cutoff is not positive
    :raises NotImplementedError: if a damping other than ``"bj"`` is requested
    """
    super().__init__()

    # Explicit check instead of ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would silently skip this validation.
    if not isinstance(model, AtomisticModel):
        raise TypeError(
            "DFTD3 can only wrap an AtomisticModel, got "
            f"'{type(model).__name__}'"
        )

    if d3_params is None:
        d3_params = _load_dftd3_parameters()
    for key in _REQUIRED_D3_TABLES:
        if key not in d3_params:
            raise KeyError(f"missing required D3 parameter table '{key}'")

    if not damping_params:
        raise ValueError(
            "DFTD3 requires at least one corrected output in 'damping_params'"
        )

    capabilities = model.capabilities()
    if capabilities.length_unit == "":
        raise ValueError("DFTD3 requires the wrapped model to define a length unit")
    self._length_unit = capabilities.length_unit

    outputs = capabilities.outputs

    self._validate_d3_params(d3_params)
    rcov = d3_params["rcov"]
    r4r2 = d3_params["r4r2"]
    c6 = d3_params["c6"]
    cn_ref = d3_params["cn_ref"]

    # User-facing cutoffs are expressed in the model's length unit; the
    # built-in defaults are stored in Bohr and converted here.
    bohr_to_model = float(unit_conversion_factor("bohr", self._length_unit))
    if cutoff is None:
        cutoff = _D3_DISP_CUTOFF_BOHR * bohr_to_model
    if cn_cutoff is None:
        cn_cutoff = _D3_CN_CUTOFF_BOHR * bohr_to_model

    if cutoff <= 0.0:
        raise ValueError(f"DFTD3 cutoff must be positive, got {cutoff}")
    if cn_cutoff <= 0.0:
        raise ValueError(f"DFTD3 cn_cutoff must be positive, got {cn_cutoff}")

    # Every table is indexed by atomic type, so the usable range is set by
    # the *shortest* first axis (``rcov``/``r4r2`` are allowed to be longer
    # than ``c6``/``cn_ref``). ``default=-1`` keeps a model that declares no
    # atomic types from failing inside ``max``.
    n_covered_types = min(
        rcov.shape[0], r4r2.shape[0], c6.shape[0], cn_ref.shape[0]
    )
    max_atomic_type = max(capabilities.atomic_types, default=-1)
    if max_atomic_type >= n_covered_types:
        warnings.warn(
            "D3 tables do not cover all wrapped-model atomic types: "
            f"maximum atomic type is {max_atomic_type} but tables only "
            f"support up to {n_covered_types - 1}. This will likely cause "
            "out-of-bounds errors in D3 table lookups. Proceed at your own risk.",
            stacklevel=2,
        )

    # The D3 reference tables and damping math run in the model's compute
    # dtype to avoid lossy float32 round-trips for float64 wrapped models.
    if capabilities.dtype == "float64":
        buffer_dtype = torch.float64
    else:
        buffer_dtype = torch.float32

    self.register_buffer("_rcov", rcov.detach().to(dtype=buffer_dtype))
    self.register_buffer("_r4r2", r4r2.detach().to(dtype=buffer_dtype))
    self.register_buffer("_c6", c6.detach().to(dtype=buffer_dtype))
    self.register_buffer("_cn_ref", cn_ref.detach().to(dtype=buffer_dtype))
    if excluded_atom_types is None:
        excluded_atom_types = []
    self.register_buffer(
        "_excluded_atom_types",
        torch.tensor(excluded_atom_types, dtype=torch.int64),
    )

    self._energy_keys = []
    self._energy_units = {}
    self._a1 = {}
    self._a2 = {}
    self._s8 = {}
    self._s6 = {}
    self._force_keys = []
    self._stress_keys = []
    self._force_damping_keys = {}
    self._stress_damping_keys = {}
    self._force_units = {}
    self._stress_units = {}

    # Register the D3 corrections for the outputs explicitly listed in
    # ``damping_params``.
    damping_method = "bj"
    for output_key, params in damping_params.items():
        # Only the part before the first "/" selects the kind of output;
        # whatever follows is the variant name.
        base_name = output_key.split("/")[0]
        is_energy = base_name == "energy"
        is_force = base_name == "non_conservative_force"
        is_stress = base_name == "non_conservative_stress"
        if not (is_energy or is_force or is_stress):
            raise ValueError(
                "DFTD3 damping_params key must be 'energy[/]', "
                "'non_conservative_force[/]' or "
                f"'non_conservative_stress[/]', got '{output_key}'"
            )
        if output_key not in outputs:
            raise ValueError(
                f"DFTD3 cannot correct '{output_key}': the wrapped model "
                "does not expose this output"
            )
        if is_energy and outputs[output_key].sample_kind not in ["system", "atom"]:
            raise ValueError(
                f"DFTD3 requires output '{output_key}' to have "
                "sample_kind='system' or sample_kind='atom'"
            )
        if outputs[output_key].unit == "":
            raise ValueError(
                f"DFTD3 requires a defined unit for output '{output_key}'"
            )
        for required in _REQUIRED_DAMPING:
            if required not in params:
                raise KeyError(
                    f"missing required damping parameter '{required}' "
                    f"for output '{output_key}'"
                )
        method = str(params.get("damping", damping_method))
        if method != "bj":
            raise NotImplementedError(
                "DFTD3 only implements Becke-Johnson damping; "
                f"got '{method}' for output '{output_key}'"
            )

        self._a1[output_key] = float(params["a1"])
        self._a2[output_key] = float(params["a2"])
        self._s8[output_key] = float(params["s8"])
        self._s6[output_key] = float(params.get("s6", 1.0))

        if is_energy:
            self._energy_keys.append(output_key)
            self._energy_units[output_key] = outputs[output_key].unit
        elif is_force:
            # Non-conservative forces are expected with one sample per atom.
            self._register_non_conservative_output(
                output_key,
                outputs[output_key],
                "atom",
                self._force_keys,
                self._force_damping_keys,
                self._force_units,
                output_key,
            )
        else:
            # Non-conservative stress is expected with one sample per system.
            self._register_non_conservative_output(
                output_key,
                outputs[output_key],
                "system",
                self._stress_keys,
                self._stress_damping_keys,
                self._stress_units,
                output_key,
            )

    self._model = model.module
    self._cutoff = cutoff
    self._cn_cutoff = cn_cutoff
    # A single neighbor list is requested, large enough for both the
    # dispersion pairs and the coordination-number pairs.
    self._neighbor_cutoff = max(cutoff, cn_cutoff)

    self._requested_neighbor_lists = model.requested_neighbor_lists()
    self._neighbor_list = NeighborListOptions(
        cutoff=self._neighbor_cutoff,
        full_list=False,
        strict=True,
        requestor="DFTD3",
    )
units[output_key] = output.unit + + @staticmethod + def _validate_d3_params(d3_params: Dict[str, torch.Tensor]): + rcov = d3_params["rcov"] + r4r2 = d3_params["r4r2"] + c6 = d3_params["c6"] + cn_ref = d3_params["cn_ref"] + + for name, tensor in [ + ("rcov", rcov), + ("r4r2", r4r2), + ("c6", c6), + ("cn_ref", cn_ref), + ]: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"D3 table '{name}' must be a torch.Tensor") + if rcov.ndim != 1: + raise ValueError(f"'rcov' must be 1D, got shape {tuple(rcov.shape)}") + if r4r2.ndim != 1: + raise ValueError(f"'r4r2' must be 1D, got shape {tuple(r4r2.shape)}") + if c6.ndim != 4: + raise ValueError(f"'c6' must be 4D, got shape {tuple(c6.shape)}") + if cn_ref.ndim != 2: + raise ValueError(f"'cn_ref' must be 2D, got shape {tuple(cn_ref.shape)}") + if c6.shape[0] != c6.shape[1]: + raise ValueError( + f"'c6' must be square in its first two axes, got shape " + f"{tuple(c6.shape)}" + ) + if c6.shape[2] != c6.shape[3]: + raise ValueError( + f"'c6' must be square in its last two axes, got shape {tuple(c6.shape)}" + ) + if cn_ref.shape[0] != c6.shape[0]: + raise ValueError( + f"'cn_ref' first axis must match 'c6' first axis, got " + f"{cn_ref.shape[0]} vs {c6.shape[0]} vs {c6.shape[0]}" + ) + if cn_ref.shape[1] != c6.shape[2]: + raise ValueError( + f"'cn_ref' second axis must match 'c6' last axis, got " + f"{cn_ref.shape[1]} vs {c6.shape[2]} vs {c6.shape[2]}" + ) + if rcov.shape[0] < c6.shape[0] or r4r2.shape[0] < c6.shape[0]: + raise ValueError( + f"'rcov' and 'r4r2' must cover at least 'c6' first axis " + f"length ({c6.shape[0]}), got {rcov.shape[0]} and {r4r2.shape[0]}" + ) + + @staticmethod + def wrap( + model: AtomisticModel, + *, + damping_params: Dict[str, Dict[str, float]], + d3_params: Optional[Dict[str, torch.Tensor]] = None, + cutoff: Optional[float] = None, + cn_cutoff: Optional[float] = None, + excluded_atom_types: Optional[List[int]] = None, + ) -> AtomisticModel: + """Wrap ``model`` with a differentiable 
DFT-D3(BJ) energy correction. + + The returned :py:class:`AtomisticModel` has the same outputs as the + input, but each output listed in ``damping_params`` is corrected by + the corresponding D3 contribution. The correction is differentiable so the + standard autograd path produces D3-corrected conservative forces + and stress. + """ + wrapper = DFTD3( + model=model.eval(), + damping_params=damping_params, + d3_params=d3_params, + cutoff=cutoff, + cn_cutoff=cn_cutoff, + excluded_atom_types=excluded_atom_types, + ) + + capabilities = model.capabilities() + supported_devices = [device for device in capabilities.supported_devices] + if len(supported_devices) == 0: + raise ValueError( + "DFTD3 only supports CPU and CUDA devices, but the wrapped " + f"model declares {capabilities.supported_devices}" + ) + + # ``AtomisticModel.capabilities()`` includes compatibility aliases for + # deprecated output names. Keep only the names declared by the input + # model; ``AtomisticModel`` will add aliases again if needed. + declared_outputs = {} + for name in model._model_capabilities_outputs_names: + declared_outputs[name] = capabilities.outputs[name] + + new_capabilities = ModelCapabilities( + outputs=declared_outputs, + atomic_types=capabilities.atomic_types, + interaction_range=max( + capabilities.interaction_range, wrapper._neighbor_cutoff + ), + length_unit=capabilities.length_unit, + supported_devices=supported_devices, + dtype=capabilities.dtype, + ) + + return AtomisticModel( + wrapper.eval(), model.metadata(), capabilities=new_capabilities + ) + + @staticmethod + def _cn_counting_term( + dist: torch.Tensor, + r_cov_pair: torch.Tensor, + compute_derivative: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Per-pair D3 CN counting function ``1 / (1 + exp(-k_cn * (r_cov / r - 1)))`` + and, optionally, its derivative ``d term / d r``. + + ``k_cn`` is hardcoded to match tad_dftd3. 
+ Ref: k_1 in Section 2.E of https://doi.org/10.1063/1.3382344 + In the ref above, you can also find a parameter k_2 = 4/3, but in the latest + tad-dftd3 codebase, this is set to 1, see https://github.com/tad-mctc/tad-mctc/blob/0d3bb31018520fb8a85bc79c000d4aae01f51235/src/tad_mctc/ncoord/count.py#L51-L72 + + The ``clamp(exponent, max=100.0)`` matches tad_mctc's stable sigmoid, see + https://github.com/tad-mctc/tad-mctc/blob/0d3bb31018520fb8a85bc79c000d4aae01f51235/src/tad_mctc/storch/elemental.py#L34-L78 + """ + k_cn: float = 16.0 + dist_safe = torch.clamp(dist, min=1e-10) + raw_exponent = -k_cn * (r_cov_pair / dist_safe - 1.0) + exponent = torch.clamp(raw_exponent, max=100.0) + term = 1.0 / (1.0 + torch.exp(exponent)) + if not compute_derivative: + return term, torch.empty(0, dtype=term.dtype, device=term.device) + + active = (raw_exponent <= 100.0) & (dist >= 1e-10) + dterm_ddist = -term * (1.0 - term) * k_cn * r_cov_pair / (dist_safe * dist_safe) + dterm_ddist = torch.where(active, dterm_ddist, torch.zeros_like(dterm_ddist)) + return term, dterm_ddist + + def _compute_cn( + self, + atomic_numbers: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + dist: torch.Tensor, + compute_derivative: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the D3 coordination number for every atom. + + The shared half neighbor list visits each pair once; the CN + contribution is symmetric, so we add it to both atoms. If the shared + list was built with a larger dispersion cutoff, filter back to the CN + cutoff before evaluating the counting function. + + Returns ``(cn, dterm_ddist, cn_pair_mask)`` where ``cn_pair_mask`` is a + bool tensor over the input pairs (``True`` for pairs within + ``cn_cutoff``), and ``dterm_ddist`` is the per-pair derivative of the + counting term on those kept pairs (empty if + ``compute_derivative=False``). 
+ """ + if self._cn_cutoff < self._neighbor_cutoff: + cn_pair_mask = dist <= self._cn_cutoff * unit_conversion_factor( + self._length_unit, "bohr" + ) + idx_i_cn = idx_i[cn_pair_mask] + idx_j_cn = idx_j[cn_pair_mask] + dist_cn = dist[cn_pair_mask] + else: + cn_pair_mask = torch.ones( + dist.shape[0], dtype=torch.bool, device=dist.device + ) + idx_i_cn, idx_j_cn, dist_cn = idx_i, idx_j, dist + + z_i = atomic_numbers[idx_i_cn] + z_j = atomic_numbers[idx_j_cn] + r_cov_pair = self._rcov[z_i] + self._rcov[z_j] + + term, dterm_ddist = self._cn_counting_term( + dist_cn, r_cov_pair, compute_derivative=compute_derivative + ) + + n_atoms = atomic_numbers.shape[0] + cn = torch.zeros(n_atoms, dtype=term.dtype, device=term.device) + cn = cn.index_add(0, idx_i_cn, term) + cn = cn.index_add(0, idx_j_cn, term) + return cn, dterm_ddist, cn_pair_mask + + def _compute_weights( + self, + atomic_numbers: torch.Tensor, + cn: torch.Tensor, + compute_derivatives: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Per-atom Gaussian weights over the (Z, M) reference CN grid. + + For each atom A, ``w_A^k = exp(-k_w * (CN_A - CN_ref_A^k)^2)``, + with absent reference slots (``CN_ref < 0``) zeroed out, then + normalized so that ``sum_k w_A^k = 1``. If ``compute_derivatives`` is + true, also return ``dw_A^k / dCN_A`` for the analytical direct-force + path. + """ + # Gaussian weighting steepness on (CN - CN_ref)^2, hardcoded to + # match tad_dftd3. + # Ref: k_3 in Section 2.E of https://doi.org/10.1063/1.3382344 + k_weight: float = 4.0 + + ref_cn_i = self._cn_ref[atomic_numbers] # (n_atoms, M) + # CN references can be zero, e.g. Noble gases, + # negative values (-1) mean invalid references + mask = ref_cn_i >= 0.0 + + # Match tad-dftd3's numerics, but use a log-sum-exp normalization + # instead of dividing by tiny raw Gaussian sums. 
In float32, + # transition-metal CN values far from all reference points can produce + # norms around 1e-30; a direct division creates unstable backward + # intermediates even if the forward value is finite. + diff = (cn.unsqueeze(1) - ref_cn_i).to(dtype=torch.float64) + log_w = -k_weight * diff.pow(2) + neg_inf = torch.full_like(log_w, -torch.inf) + log_w = torch.where(mask, log_w, neg_inf) + + max_log_w, _ = log_w.max(dim=1, keepdim=True) + w = torch.exp(log_w - max_log_w) + w = torch.where(mask, w, torch.zeros_like(w)) + + norm = w.sum(dim=1, keepdim=True) + w_normalized = w / norm + weights = w_normalized.to(dtype=cn.dtype) + + if not compute_derivatives: + return weights, torch.empty(0, dtype=cn.dtype, device=cn.device) + + dlog_w = -2.0 * k_weight * diff + dlog_w = torch.where(mask, dlog_w, torch.zeros_like(dlog_w)) + mean_dlog_w = (w_normalized * dlog_w).sum(dim=1, keepdim=True) + dw_dcn = w_normalized * (dlog_w - mean_dlog_w) + + return weights, dw_dcn.to(dtype=cn.dtype) + + def _compute_c6_pairs( + self, + atomic_numbers: torch.Tensor, + weights: torch.Tensor, + dweights_dcn: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + compute_derivatives: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Effective ``C6_AB`` for each dispersion pair via the (M, M) + reference grid contracted against the per-atom CN weights. + + Zero C6 reference entries are missing D3 reference points, not physical + zero-C6 environments, so exclude them from the pair-specific + normalization denominator. I haven't found a clear statement of this in the D3 + literature, but I checked the values of C6, and the smallest nonzero entry is + 0.9311. + + If ``compute_derivatives`` is true, also return derivatives + with respect to the two endpoint CN values. 
+ """ + z_i = atomic_numbers[idx_i] + z_j = atomic_numbers[idx_j] + + w_i = weights[idx_i] # (P, M) + w_j = weights[idx_j] + c6_ref_pairs = self._c6[z_i, z_j] # (P, M, M) + + weighted_c6 = torch.bmm(w_i.unsqueeze(1), c6_ref_pairs).squeeze(1) + numerator = (weighted_c6 * w_j).sum(dim=1) + + valid_reference = (c6_ref_pairs != 0.0).to(dtype=w_i.dtype) + valid_weight = torch.bmm(w_i.unsqueeze(1), valid_reference).squeeze(1) + denominator = (valid_weight * w_j).sum(dim=1) + + small = torch.full_like(denominator, 1e-20) + has_reference = denominator > small + safe_denominator = torch.where( + has_reference, denominator, torch.ones_like(denominator) + ) + zero = torch.zeros((), dtype=numerator.dtype, device=numerator.device) + c6_pairs = torch.where(has_reference, numerator / safe_denominator, zero) + + if not compute_derivatives: + empty = torch.empty(0, dtype=numerator.dtype, device=numerator.device) + return c6_pairs, empty, empty + + # Calculate the derivative + dw_i = dweights_dcn[idx_i] + dw_j = dweights_dcn[idx_j] + dnumerator_i = ( + torch.bmm(dw_i.unsqueeze(1), c6_ref_pairs).squeeze(1) * w_j + ).sum(dim=1) + dnumerator_j = (weighted_c6 * dw_j).sum(dim=1) + ddenominator_i = ( + torch.bmm(dw_i.unsqueeze(1), valid_reference).squeeze(1) * w_j + ).sum(dim=1) + ddenominator_j = (valid_weight * dw_j).sum(dim=1) + + dc6_dcn_i = torch.where( + has_reference, + (dnumerator_i - c6_pairs * ddenominator_i) / safe_denominator, + zero, + ) + dc6_dcn_j = torch.where( + has_reference, + (dnumerator_j - c6_pairs * ddenominator_j) / safe_denominator, + zero, + ) + return c6_pairs, dc6_dcn_i, dc6_dcn_j + + def _excluded_pair_mask( + self, + atomic_numbers: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + ) -> torch.Tensor: + """Per-pair bool mask: ``True`` for pairs to keep (i.e. 
neither endpoint + has a type in :attr:`_excluded_atom_types`).""" + excluded_atom_types = self._excluded_atom_types.to(device=atomic_numbers.device) + z_i = atomic_numbers[idx_i] + z_j = atomic_numbers[idx_j] + excluded_i = (z_i.unsqueeze(1) == excluded_atom_types.unsqueeze(0)).any(dim=1) + excluded_j = (z_j.unsqueeze(1) == excluded_atom_types.unsqueeze(0)).any(dim=1) + return ~(excluded_i | excluded_j) + + def _bj_damping_terms( + self, + atomic_numbers: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + dist: torch.Tensor, + a1: float, + a2: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Shared Becke-Johnson damping geometry for a set of pairs. + + Returns ``(qq, denom6, denom8)`` with ``qq = 3 * r4r2_i * r4r2_j`` + (so ``R0 = sqrt(qq)``) and ``denom_n = dist^n + cutoff_r^n`` where + ``cutoff_r = a1 * R0 + a2``. + """ + # C8 = C6 * 3 * Q_A * Q_B, with Q stored as r4r2 (length unit). + # This means R0 = sqrt(C8 / C6) simplifies to sqrt(3 * Q_A * Q_B). + z_i = atomic_numbers[idx_i] + z_j = atomic_numbers[idx_j] + qq = 3.0 * self._r4r2[z_i] * self._r4r2[z_j] + cutoff_r = a1 * torch.sqrt(qq) + a2 + + dist2 = dist * dist + dist4 = dist2 * dist2 + dist6 = dist4 * dist2 + dist8 = dist4 * dist4 + + cutoff2 = cutoff_r * cutoff_r + cutoff4 = cutoff2 * cutoff2 + cutoff6 = cutoff4 * cutoff2 + cutoff8 = cutoff4 * cutoff4 + + denom6 = dist6 + cutoff6 + denom8 = dist8 + cutoff8 + return qq, denom6, denom8 + + def _compute_pair_energies( + self, + atomic_numbers: torch.Tensor, + c6_pairs: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + dist: torch.Tensor, + a1: float, + a2: float, + s6: float, + s8: float, + ) -> torch.Tensor: + """Becke-Johnson damped energy for every half-list pair.""" + qq, denom6, denom8 = self._bj_damping_terms( + atomic_numbers, idx_i, idx_j, dist, a1, a2 + ) + c8_pairs = c6_pairs * qq + + e6 = -(c6_pairs / denom6) + e8 = -(c8_pairs / denom8) + + energy_pairs = s6 * e6 + s8 * e8 + return energy_pairs + + 
@staticmethod
def _selected_atoms_for_system(
    selected_atoms: Optional[Labels], system_index: int
) -> Optional[torch.Tensor]:
    """Return the atom indices selected for one system.

    :param selected_atoms: optional selection; column 0 of its values is
        compared against ``system_index`` and column 1 is returned.
    :param system_index: index of the system to extract the selection for.
    :return: ``None`` when no selection was given (meaning "all atoms"),
        otherwise a 1-D ``int64`` tensor with the entries of column 1 whose
        column 0 equals ``system_index``.
    """
    if selected_atoms is None:
        return None

    selected_values = selected_atoms.values.to(torch.int64)
    mask = selected_values[:, 0] == system_index
    return selected_values[mask, 1]

@staticmethod
def _system_sample_indices(block: TensorBlock) -> torch.Tensor:
    """Return column 0 of the block's sample values as ``int64``.

    ``forward`` uses the result to pick, for every sample row of a
    per-system block, the matching row of a per-system correction.
    """
    return block.samples.values[:, 0].to(torch.int64)

@staticmethod
def _atom_sample_indices(block: TensorBlock, systems: List[System]) -> torch.Tensor:
    """Map the block's ``(system, atom)`` sample rows to flat atom indices.

    The flat index addresses per-atom data concatenated over ``systems`` in
    order: it is the number of atoms in all preceding systems plus the atom
    index inside its own system.

    :param block: block whose sample values hold the system index in column
        0 and the atom index in column 1.
    :param systems: systems in the order used for the concatenation.
    :return: 1-D ``int64`` tensor with one flat index per sample row.
    """
    sample_values = block.samples.values.to(torch.int64)
    # Explicitly typed empty list: an untyped ``[]`` is not accepted by
    # TorchScript here.
    offsets = torch.jit.annotate(List[int], [])
    offset = 0
    for system in systems:
        # Offset of a system = total number of atoms in the systems before it.
        offsets.append(offset)
        offset += system.positions.shape[0]

    offset_tensor = torch.tensor(
        offsets, dtype=torch.int64, device=sample_values.device
    )
    return offset_tensor[sample_values[:, 0]] + sample_values[:, 1]

def _neighbor_pairs(
    self, system: System
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fetch the D3 neighbor list of ``system`` and split it into pairs.

    :return: ``(idx_i, idx_j, values)`` where ``idx_i`` / ``idx_j`` are the
        first two sample columns (``int64``) and ``values`` are the raw
        neighbor-list values, still in the wrapped model's length unit.
    """
    nl = system.get_neighbor_list(self._neighbor_list)
    sample_values = nl.samples.values.to(torch.int64)
    idx_i = sample_values[:, 0]
    idx_j = sample_values[:, 1]
    return idx_i, idx_j, nl.values

def _d3_atomic_energies(
    self,
    system: System,
    damping_key: str,
) -> torch.Tensor:
    """Per-atom D3 energies for one system.

    Distances are converted to Bohr before use. ``forward`` converts the
    result from ``"hartree"``, so the returned values are in Hartree. Every
    pair energy is split equally between its two atoms.

    No ``detach`` is applied here, so the result stays connected to the
    neighbor-list values for autograd.

    :param system: system carrying the neighbor list requested by this module.
    :param damping_key: key into the ``_a1`` / ``_a2`` / ``_s6`` / ``_s8``
        damping-parameter tables.
    :return: 1-D tensor with one energy per atom of ``system``.
    """
    idx_i, idx_j, neighbor_values = self._neighbor_pairs(system)

    atomic_numbers = system.types.to(torch.int64)

    # Norm over the Cartesian axis, then drop the trailing axis
    # (presumably values are (n_pairs, 3, 1) -- TODO confirm), giving one
    # distance per pair, converted to Bohr.
    dist = torch.linalg.vector_norm(neighbor_values, dim=1).squeeze(
        -1
    ) * unit_conversion_factor(self._length_unit, "bohr")

    # TODO: ``forward`` calls this method once per requested damping key, so
    # the damping-independent pair terms (CN/weights/C6) are recomputed for
    # every key; they could be computed once and reused.
    # The coordination numbers are computed from the full pair list, i.e.
    # before the energy cutoff / excluded-type filtering below.
    cn, _, _ = self._compute_cn(atomic_numbers, idx_i, idx_j, dist)
    weights, dummy_dweights_dcn = self._compute_weights(atomic_numbers, cn)

    # Keep only the pairs that are within the dispersion cutoff and whose
    # atom types are not excluded before computing C6 and pair energies.
    # ``self._cutoff`` is compared after conversion to Bohr.
    if self._cutoff < self._neighbor_cutoff:
        mask = dist <= self._cutoff * unit_conversion_factor(
            self._length_unit, "bohr"
        )
        idx_i, idx_j, dist = idx_i[mask], idx_j[mask], dist[mask]
    if self._excluded_atom_types.numel() != 0:
        keep_pair = self._excluded_pair_mask(atomic_numbers, idx_i, idx_j)
        idx_i, idx_j, dist = idx_i[keep_pair], idx_j[keep_pair], dist[keep_pair]

    # Derivatives are not requested here; ``dummy_dweights_dcn`` is only
    # passed to satisfy the signature.
    c6_pairs, _, _ = self._compute_c6_pairs(
        atomic_numbers, weights, dummy_dweights_dcn, idx_i, idx_j
    )

    energy_pairs = self._compute_pair_energies(
        atomic_numbers,
        c6_pairs,
        idx_i,
        idx_j,
        dist,
        a1=self._a1[damping_key],
        a2=self._a2[damping_key],
        s6=self._s6[damping_key],
        s8=self._s8[damping_key],
    )
    atomic_energies = torch.zeros(
        atomic_numbers.shape[0],
        dtype=energy_pairs.dtype,
        device=energy_pairs.device,
    )
    # Each pair contributes half of its energy to each of its two atoms.
    half_pair_energies = 0.5 * energy_pairs
    atomic_energies = atomic_energies.index_add(0, idx_i, half_pair_energies)
    atomic_energies = atomic_energies.index_add(0, idx_j, half_pair_energies)
    return atomic_energies

def _d3_energy(
    self,
    system: System,
    damping_key: str,
    selected_atom_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    """D3 energy of one system as a scalar tensor.

    :param system: system to evaluate.
    :param damping_key: key of the damping parameters to use.
    :param selected_atom_indices: when ``None`` the energies of all atoms are
        summed; otherwise only the atomic energies at these indices.
    :return: sum of the (selected) per-atom energies from
        :meth:`_d3_atomic_energies`.
    """
    atomic_energies = self._d3_atomic_energies(system, damping_key)
    if selected_atom_indices is None:
        return atomic_energies.sum()

    return atomic_energies.index_select(0, selected_atom_indices).sum()

def _d3_direct_derivatives(
    self,
    system: System,
    damping_key: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """D3 direct force and stress in ``Hartree / length`` and
    ``Hartree / length^3``, where ``length`` is the wrapped model's length
    unit.

    The derivative of the pair energy with respect to every neighbor vector
    is assembled analytically from two parts: the explicit distance
    dependence of the damped pair energy, and the dependence of the C6
    coefficients on the coordination numbers (chain rule through CN).

    The neighbor vectors are detached first, so the returned tensors are
    not connected to the autograd graph of the positions / cell.

    :param system: system carrying the neighbor list requested by this module.
    :param damping_key: key into the damping-parameter tables.
    :return: ``(forces, stress)`` with shapes ``(n_atoms, 3)`` and ``(3, 3)``.
        Both are zero when the neighbor list is empty; the stress stays zero
        when the cell volume is zero.
    """
    idx_i, idx_j, neighbor_values = self._neighbor_pairs(system)
    n_atoms = system.positions.shape[0]
    forces = torch.zeros(
        n_atoms,
        3,
        dtype=neighbor_values.dtype,
        device=neighbor_values.device,
    )
    stress = torch.zeros(
        3,
        3,
        dtype=neighbor_values.dtype,
        device=neighbor_values.device,
    )

    # Nothing to do without pairs: return the zero-initialized results.
    if neighbor_values.numel() == 0:
        return forces, stress

    # (n_pairs, 3) pair vectors in the model's length unit, cut from autograd.
    pair_vectors = neighbor_values.detach().squeeze(-1)
    dist_model = torch.linalg.vector_norm(pair_vectors, dim=1)
    length_to_bohr = unit_conversion_factor(self._length_unit, "bohr")
    dist = dist_model * length_to_bohr

    atomic_numbers = system.types.to(torch.int64)
    # ``cn_pair_mask`` flags the pairs that enter the CN sum; ``dcn_ddist``
    # is used below with exactly those pairs (one entry per masked pair).
    cn, dcn_ddist, cn_pair_mask = self._compute_cn(
        atomic_numbers, idx_i, idx_j, dist, compute_derivative=True
    )
    weights, dweights_dcn = self._compute_weights(
        atomic_numbers, cn, compute_derivatives=True
    )

    # Positions of the pairs in the full list, so that results computed on
    # masked subsets can be scattered back into full-length arrays.
    pair_indices = torch.arange(
        dist.shape[0], dtype=torch.int64, device=dist.device
    )
    # Pairs that contribute to the energy: inside the dispersion cutoff and
    # with no endpoint of an excluded type.
    energy_pair_mask = torch.ones(
        dist.shape[0], dtype=torch.bool, device=dist.device
    )
    if self._cutoff < self._neighbor_cutoff:
        energy_pair_mask = dist <= self._cutoff * length_to_bohr

    if self._excluded_atom_types.numel() != 0:
        energy_pair_mask = energy_pair_mask & self._excluded_pair_mask(
            atomic_numbers, idx_i, idx_j
        )

    # Accumulators: dE/dCN per atom and dE/d(dist) per pair (dist in Bohr).
    dE_dcn = torch.zeros(
        n_atoms,
        dtype=neighbor_values.dtype,
        device=neighbor_values.device,
    )
    dE_ddist = torch.zeros_like(dist)

    energy_pair_indices = pair_indices[energy_pair_mask]
    if energy_pair_indices.numel() != 0:
        idx_i_energy = idx_i[energy_pair_mask]
        idx_j_energy = idx_j[energy_pair_mask]
        dist_energy = dist[energy_pair_mask]

        c6_pairs, dc6_dcn_i, dc6_dcn_j = self._compute_c6_pairs(
            atomic_numbers,
            weights,
            dweights_dcn,
            idx_i_energy,
            idx_j_energy,
            compute_derivatives=True,
        )

        qq, denom6, denom8 = self._bj_damping_terms(
            atomic_numbers,
            idx_i_energy,
            idx_j_energy,
            dist_energy,
            self._a1[damping_key],
            self._a2[damping_key],
        )

        dist2 = dist_energy * dist_energy
        dist4 = dist2 * dist2
        dist5 = dist4 * dist_energy
        dist7 = dist4 * dist2 * dist_energy

        inv_denom6 = 1.0 / denom6
        inv_denom8 = 1.0 / denom8
        # Pair energy is E = -C6 * damp_sum, with
        # damp_sum = s6 / denom6 + s8 * qq / denom8.
        damp_sum = (
            self._s6[damping_key] * inv_denom6
            + self._s8[damping_key] * qq * inv_denom8
        )

        # Explicit distance derivative at fixed C6; denom_n = dist^n + const,
        # so d(1/denom_n)/d(dist) = -n * dist^(n-1) / denom_n^2.
        direct_dE_ddist = c6_pairs * (
            6.0 * self._s6[damping_key] * dist5 * inv_denom6 * inv_denom6
            + 8.0 * self._s8[damping_key] * qq * dist7 * inv_denom8 * inv_denom8
        )
        dE_ddist = dE_ddist.index_copy(0, energy_pair_indices, direct_dE_ddist)

        # dE/dCN of each endpoint through the CN dependence of C6.
        dE_dc6 = -damp_sum
        dE_dcn_i = dE_dc6 * dc6_dcn_i
        dE_dcn_j = dE_dc6 * dc6_dcn_j
        dE_dcn = dE_dcn.index_add(0, idx_i_energy, dE_dcn_i)
        dE_dcn = dE_dcn.index_add(0, idx_j_energy, dE_dcn_j)

    # Chain rule through the coordination numbers: the same per-pair
    # derivative ``dcn_ddist`` is applied to the CN of both endpoints.
    cn_pair_indices = pair_indices[cn_pair_mask]
    if cn_pair_indices.numel() != 0:
        idx_i_cn, idx_j_cn = idx_i[cn_pair_mask], idx_j[cn_pair_mask]
        chain_dE_ddist = (dE_dcn[idx_i_cn] + dE_dcn[idx_j_cn]) * dcn_ddist
        dE_ddist = dE_ddist.index_add(0, cn_pair_indices, chain_dE_ddist)

    # Convert dE/d(dist in Bohr) into a gradient with respect to the pair
    # vector in model units: multiply by d(dist)/d(dist_model) and by the
    # unit vector r / |r|. The clamp avoids a division by zero and the
    # ``where`` forces a zero gradient for zero-length vectors.
    dist_model_safe = torch.clamp(dist_model, min=1e-20)
    scale = dE_ddist * length_to_bohr / dist_model_safe
    dE_dr_vectors = pair_vectors * scale.unsqueeze(1)
    dE_dr_vectors = torch.where(
        (dist_model > 1e-20).unsqueeze(1),
        dE_dr_vectors,
        torch.zeros_like(dE_dr_vectors),
    )

    # Equal and opposite contributions on the two atoms of every pair.
    # NOTE(review): the signs presumably assume pair vectors pointing from
    # atom i to atom j -- confirm against the neighbor-list convention.
    forces = forces.index_add(0, idx_i, dE_dr_vectors)
    forces = forces.index_add(0, idx_j, -dE_dr_vectors)

    # Sum over pairs of the outer product r (x) dE/dr, per unit volume.
    volume = torch.abs(torch.linalg.det(system.cell))
    if volume > 0.0:
        stress = pair_vectors.T @ dE_dr_vectors / volume
    return forces, stress

def forward(
    self,
    systems: List[System],
    outputs: Dict[str, ModelOutput],
    selected_atoms: Optional[Labels],
) -> Dict[str, TensorMap]:
    """Evaluate the wrapped model and add the D3 corrections.

    All requested outputs are computed by the wrapped model. Outputs that
    have damping parameters registered in this wrapper are then replaced by
    ``base + D3``: energies through the differentiable per-atom energies,
    non-conservative forces / stresses through the analytic derivatives
    from :meth:`_d3_direct_derivatives`. All other outputs are returned
    unchanged.

    :param systems: systems to evaluate.
    :param outputs: requested outputs, keyed by output name.
    :param selected_atoms: optional atom selection, forwarded to the wrapped
        model and used to restrict per-system D3 energies.
    :return: the wrapped model's results with the corrected entries replaced.
    :raises NotImplementedError: if a corrected energy is requested with a
        sample kind other than ``"system"`` / ``"atom"``, a corrected force
        with one other than ``"atom"``, or a corrected stress with one other
        than ``"system"``.
    :raises ValueError: if the energy block returned by the wrapped model
        has sample names other than ``["system"]`` or ``["system", "atom"]``.
    """
    # Determine which of our variants the user requested.
    need_variants: List[str] = []
    for energy_key in self._energy_keys:
        if energy_key not in outputs:
            continue
        if outputs[energy_key].sample_kind not in ["system", "atom"]:
            raise NotImplementedError(
                "DFTD3 only supports system- or atom-sample corrected energies"
            )
        need_variants.append(energy_key)

    # Requested non-conservative outputs, plus the de-duplicated list of
    # damping keys they need, so each key's derivatives are computed once.
    need_force_keys: List[str] = []
    need_stress_keys: List[str] = []
    need_non_conservative_damping_keys: List[str] = []
    for force_key in self._force_keys:
        if force_key in outputs:
            if outputs[force_key].sample_kind != "atom":
                raise NotImplementedError(
                    "DFTD3 only supports atom-sample non-conservative forces"
                )
            need_force_keys.append(force_key)
            damping_key = self._force_damping_keys[force_key]
            if damping_key not in need_non_conservative_damping_keys:
                need_non_conservative_damping_keys.append(damping_key)

    for stress_key in self._stress_keys:
        if stress_key in outputs:
            if outputs[stress_key].sample_kind != "system":
                raise NotImplementedError(
                    "DFTD3 only supports system-sample non-conservative stress"
                )
            need_stress_keys.append(stress_key)
            damping_key = self._stress_damping_keys[stress_key]
            if damping_key not in need_non_conservative_damping_keys:
                need_non_conservative_damping_keys.append(damping_key)

    # Forward every requested output to the base model (skipped only when
    # nothing was requested). Non-D3 outputs pass through unchanged; the
    # D3-corrected outputs get the correction added below.
    if len(outputs) == 0:
        results = torch.jit.annotate(Dict[str, TensorMap], {})
    else:
        results = self._model(systems, outputs, selected_atoms)

    # Nothing to correct: return the base model's results as they are.
    if len(need_variants) == 0 and len(need_non_conservative_damping_keys) == 0:
        return results

    # First compute the D3 correction for energy variants, which will automatically
    # correct the corresponding conservative forces and stress via autograd.
    for energy_key in need_variants:
        energy_result = results[energy_key]
        block = energy_result.block()
        if block.samples.names == ["system"]:
            # One scalar per system, restricted to the selected atoms.
            d3_energies: List[torch.Tensor] = []
            for system_i, system in enumerate(systems):
                selected = self._selected_atoms_for_system(selected_atoms, system_i)
                d3_energies.append(self._d3_energy(system, energy_key, selected))

            if len(d3_energies) > 0:
                correction_by_system = torch.stack(d3_energies, dim=0)
                # Reorder / subset so rows line up with the block's samples.
                correction = correction_by_system.index_select(
                    0, self._system_sample_indices(block)
                ).reshape(-1, 1)
            else:
                correction = torch.zeros_like(block.values)

        elif block.samples.names == ["system", "atom"]:
            # Per-atom energies of all systems, concatenated in system order;
            # the block's samples decide which atoms are kept.
            d3_atomic_energies: List[torch.Tensor] = []
            for system in systems:
                d3_atomic_energies.append(
                    self._d3_atomic_energies(system, energy_key)
                )

            if len(d3_atomic_energies) > 0:
                correction_by_atom = torch.cat(d3_atomic_energies, dim=0)
                correction = correction_by_atom.index_select(
                    0, self._atom_sample_indices(block, systems)
                ).reshape(-1, 1)
            else:
                correction = torch.zeros_like(block.values)

        else:
            raise ValueError(
                "DFTD3 can only correct energy blocks with 'system' or "
                "'system', 'atom' samples"
            )

        # The D3 energy is in Hartree; convert to this output's unit.
        corrected_values = block.values + correction.to(
            dtype=block.values.dtype, device=block.values.device
        ) * unit_conversion_factor("hartree", self._energy_units[energy_key])

        # NOTE(review): the rebuilt block is created from values, samples,
        # components and properties only; if the base block carried explicit
        # gradients they are not transferred -- confirm this is intended.
        corrected_block = TensorBlock(
            values=corrected_values,
            samples=block.samples,
            components=block.components,
            properties=block.properties,
        )
        results[energy_key] = TensorMap(energy_result.keys, [corrected_block])

    # Calculate the corrections for non-conservative forces and stresses,
    # once per damping key and system (forces and stress come from the same
    # call, so both are stored even if only one of them is needed).
    non_conservative_forces = torch.jit.annotate(Dict[str, List[torch.Tensor]], {})
    non_conservative_stresses = torch.jit.annotate(
        Dict[str, List[torch.Tensor]], {}
    )
    for damping_key in need_non_conservative_damping_keys:
        d3_forces: List[torch.Tensor] = []
        d3_stresses: List[torch.Tensor] = []
        for system in systems:
            force, stress = self._d3_direct_derivatives(system, damping_key)
            d3_forces.append(force)
            d3_stresses.append(stress)
        non_conservative_forces[damping_key] = d3_forces
        non_conservative_stresses[damping_key] = d3_stresses

    # Add the non-conservative corrections to the base model's forces and stress.
    for force_key in need_force_keys:
        damping_key = self._force_damping_keys[force_key]
        force_result = results[force_key]
        block = force_result.block()
        correction_by_atom = torch.cat(non_conservative_forces[damping_key], dim=0)
        # Pick the rows matching the block's (system, atom) samples and add
        # the trailing property axis.
        correction = correction_by_atom.index_select(
            0, self._atom_sample_indices(block, systems)
        ).reshape(-1, 3, 1)
        # The analytic forces are in Hartree per model length unit.
        force_unit = f"hartree/{self._length_unit}"
        corrected_values = block.values + correction.to(
            dtype=block.values.dtype, device=block.values.device
        ) * unit_conversion_factor(force_unit, self._force_units[force_key])
        corrected_block = TensorBlock(
            values=corrected_values,
            samples=block.samples,
            components=block.components,
            properties=block.properties,
        )
        results[force_key] = TensorMap(force_result.keys, [corrected_block])

    for stress_key in need_stress_keys:
        damping_key = self._stress_damping_keys[stress_key]
        stress_result = results[stress_key]
        block = stress_result.block()
        correction_by_system = torch.stack(
            non_conservative_stresses[damping_key], dim=0
        )
        # One (3, 3) stress per system sample, plus the trailing property axis.
        correction = correction_by_system.index_select(
            0, self._system_sample_indices(block)
        ).unsqueeze(-1)
        # The analytic stress is in Hartree per cubed model length unit.
        stress_unit = f"hartree/{self._length_unit}^3"
        corrected_values = block.values + correction.to(
            dtype=block.values.dtype, device=block.values.device
        ) * unit_conversion_factor(stress_unit, self._stress_units[stress_key])
        corrected_block = TensorBlock(
            values=corrected_values,
            samples=block.samples,
            components=block.components,
            properties=block.properties,
        )
        results[stress_key] = TensorMap(stress_result.keys, [corrected_block])

    return results
a/python/metatomic_torch/pyproject.toml b/python/metatomic_torch/pyproject.toml index 40259291d..111900057 100644 --- a/python/metatomic_torch/pyproject.toml +++ b/python/metatomic_torch/pyproject.toml @@ -62,6 +62,7 @@ filterwarnings = [ "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", "ignore:`torch.jit.load` is deprecated. Please switch to `torch.export`:DeprecationWarning", + "ignore:`torch.jit.script_method` is not supported in Python 3.14+:DeprecationWarning", "ignore:`torch.jit.script` is not supported in Python 3.14+:DeprecationWarning", "ignore:`torch.jit.save` is not supported in Python 3.14+:DeprecationWarning", "ignore:`torch.jit.load` is not supported in Python 3.14+:DeprecationWarning", diff --git a/python/metatomic_torch/setup.py b/python/metatomic_torch/setup.py index 7f327b645..2ce1ae9fd 100644 --- a/python/metatomic_torch/setup.py +++ b/python/metatomic_torch/setup.py @@ -350,6 +350,7 @@ def create_version_number(version): "metatomic/torch/torch-*/bin/*", "metatomic/torch/torch-*/lib/*", "metatomic/torch/torch-*/include/*", - ] + "metatomic/torch/data/*", + ], }, ) diff --git a/python/metatomic_torch/tests/dftd3.py b/python/metatomic_torch/tests/dftd3.py new file mode 100644 index 000000000..c693c652f --- /dev/null +++ b/python/metatomic_torch/tests/dftd3.py @@ -0,0 +1,887 @@ +import os +import re +from typing import Dict, List, Optional + +import ase.build +import numpy as np +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from vesin.metatomic import compute_requested_neighbors_from_options + +from metatomic.torch import ( # noqa: E402 + AtomisticModel, + ModelCapabilities, + ModelEvaluationOptions, + ModelMetadata, + ModelOutput, + System, + load_atomistic_model, + systems_to_torch, + unit_conversion_factor, +) +from metatomic.torch.dftd3 import 
DFTD3 # noqa: E402 + + +ATOMIC_NUMBER = 18 +D3_CUTOFF = 5.0 + +# 5-point reference grid matching the standard Grimme tables. +_REF_GRID = 5 + +# Energy, forces, and stress reference values for a single snapshot of 64 Ar atoms with +# the artificial D3 parameters defined in _d3_params() and damping parameters `{"a1": +# 0.4, "a2": 4.0, "s8": 1.0}`. Calculated by tad-dftd3 with the same parameters. +_D3_REFERENCE = { + "default": { + "energy": -4.557688784512186, + "forces": np.array( + [ + [-0.00726508070174851, 0.00421613323677529, 0.00713047350689955], + [0.00444419132651295, -0.0138453480050549, -0.00191253781443601], + [-0.00308607387021654, 0.01232699781381951, 0.00900597104085257], + [-0.00610167736545015, 0.00276299380227561, -0.00353193039472426], + [0.01260083353435382, -0.00614033346641161, 0.00477616628028993], + [-0.00043465811545061, 0.00566025704614823, 0.00934931094503733], + [-0.00772758165588948, 0.00030335892908554, -0.01598145511929504], + [0.00775120111007042, -0.02001380204825312, -0.00881644834538752], + [0.0040023123129326, 0.00240434690975679, -0.01750620395815836], + [0.00102021733779979, 0.00279214703020113, 0.00530657675982078], + [0.00168171921365183, 0.00178250547909488, 0.00274908781463904], + [0.00679706473916125, -0.0127589706130313, -0.0089965376824202], + [0.00913127902302183, 0.01302685144580063, 0.00685386705634312], + [0.01109794534808769, -0.00438573870534469, -0.0053358145115363], + [-0.01208514260862648, 0.01365753638264123, 0.01206812791952485], + [-0.01362855558638936, -0.00146861040263824, 0.00906093284477281], + [0.00666587625609256, -0.00157434351635681, 0.00157821155234659], + [-0.00371968595647699, -0.00444630226580355, -0.00021902175920297], + [0.00791739330252776, 0.00810729776920197, -0.00481271020086399], + [-0.00212081385321295, -0.00873496197538598, -0.00383113551715393], + [0.00020810819027938, 0.01694212346347319, 0.01154587588359852], + [0.00102035905274625, -0.01338970052217113, -0.00740874907733129], + 
[-0.00809861648157284, 0.0039800283781339, -0.00353126147783306], + [-0.00614531290019489, -0.01571869766478938, -0.0011870192141982], + [-0.00697144589551845, 0.00835972553326233, 0.00344246550272658], + [-0.00045746125560518, -0.00053953437674736, -0.01648164832703997], + [0.00067758271174474, 0.0046551984437648, -0.00803243988498612], + [0.00104514491467035, -0.00015729419107496, 0.00718322804358889], + [-0.00335313838707656, 0.00512768136258867, 0.00740313395620435], + [-0.00709951943713622, -0.0034666077071482, -0.00246717309068727], + [0.0081580287143527, -0.00227504064164253, 0.00462850801794592], + [0.00407550698255931, 0.00281010307583006, 0.00797014925066365], + ] + ), + "stress": np.array( + [ + [ + 7.0830378606643353e-03, + 1.7653051504568702e-06, + 1.2574319088078316e-06, + ], + [ + 1.7653051504568702e-06, + 7.0822819743685574e-03, + 3.6606343262798787e-06, + ], + [ + 1.2574319088078316e-06, + 3.6606343262798787e-06, + 7.0824372365694095e-03, + ], + ] + ), + }, + "shifted": { + "energy": -6.195052365042283, + }, +} + + +@pytest.fixture +def model(): + return AtomisticModel( + ZeroEnergyModel().eval(), + ModelMetadata(), + ModelCapabilities( + outputs={ + "energy": ModelOutput( + sample_kind="system", unit="eV", description="energy" + ), + "energy/shifted": ModelOutput( + sample_kind="system", + unit="eV", + description="energy for shifted head", + ), + "non_conservative_force": ModelOutput( + sample_kind="atom", + unit="eV/Angstrom", + description="non-conservative forces", + ), + "non_conservative_stress": ModelOutput( + sample_kind="system", + unit="eV/Angstrom^3", + description="non-conservative stress", + ), + "non_conservative_force/direct": ModelOutput( + sample_kind="atom", + unit="eV/Angstrom", + description="non-conservative force head", + ), + "non_conservative_stress/direct": ModelOutput( + sample_kind="system", + unit="eV/Angstrom^3", + description="non-conservative stress head", + ), + }, + atomic_types=[ATOMIC_NUMBER], + 
interaction_range=0.0, + length_unit="Angstrom", + supported_devices=["cpu", "cuda"], + dtype="float64", + ), + ) + + +@pytest.fixture +def model_with_atomic_energy(): + return AtomisticModel( + ZeroEnergyModel().eval(), + ModelMetadata(), + ModelCapabilities( + outputs={ + "energy": ModelOutput( + sample_kind="atom", unit="eV", description="atomic energy" + ), + }, + atomic_types=[ATOMIC_NUMBER], + interaction_range=0.0, + length_unit="Angstrom", + supported_devices=["cpu", "cuda"], + dtype="float64", + ), + ) + + +class ZeroEnergyModel(torch.nn.Module): + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + results = torch.jit.annotate(Dict[str, TensorMap], {}) + + values = torch.jit.annotate(List[torch.Tensor], []) + for system in systems: + values.append((system.positions.sum() + system.cell.sum()) * 0.0) + + if len(values) == 0: + raise ValueError("ZeroEnergyModel requires at least one system") + + device = values[0].device + system_labels = Labels( + "system", + torch.arange(len(values), dtype=torch.int64, device=device).reshape(-1, 1), + ) + keys = Labels("_", torch.tensor([[0]], dtype=torch.int64, device=device)) + base_values = torch.stack(values, dim=0).reshape(-1, 1) + for name in outputs: + if name == "energy" or name == "energy/shifted": + properties = Labels( + "energy", torch.tensor([[0]], dtype=torch.int64, device=device) + ) + output = outputs[name] + + if output.sample_kind == "system": + output_values = base_values.clone() + if name == "energy/shifted": + output_values = 0.1 + output_values + + block = TensorBlock( + values=output_values, + samples=system_labels, + components=torch.jit.annotate(List[Labels], []), + properties=properties, + ) + + elif output.sample_kind == "atom": + atomic_values = torch.jit.annotate(List[torch.Tensor], []) + atomic_samples = torch.jit.annotate(List[torch.Tensor], []) + for i, system in enumerate(systems): + 
n_atoms = system.positions.shape[0] + atom_indices = torch.arange( + n_atoms, dtype=torch.int64, device=device + ) + if selected_atoms is not None: + selected_values = selected_atoms.values.to(torch.int64) + mask = selected_values[:, 0] == i + atom_indices = selected_values[mask, 1] + + values_i = ( + system.positions.index_select(0, atom_indices).sum( + dim=1, keepdim=True + ) + * 0.0 + ) + if name == "energy/shifted": + values_i = 0.1 + values_i + atomic_values.append(values_i) + atomic_samples.append( + torch.cat( + [ + torch.full( + (atom_indices.shape[0], 1), + i, + dtype=torch.int64, + device=device, + ), + atom_indices.reshape(-1, 1), + ], + dim=1, + ) + ) + + block = TensorBlock( + values=torch.cat(atomic_values, dim=0), + samples=Labels( + ["system", "atom"], torch.cat(atomic_samples, dim=0) + ), + components=torch.jit.annotate(List[Labels], []), + properties=properties, + ) + + else: + raise ValueError("unsupported energy sample kind") + + blocks = torch.jit.annotate(List[TensorBlock], [block]) + results[name] = TensorMap(keys, blocks) + + elif name == "non_conservative_force" or name.startswith( + "non_conservative_force/" + ): + force_values = torch.jit.annotate(List[torch.Tensor], []) + force_samples = torch.jit.annotate(List[torch.Tensor], []) + for i, system in enumerate(systems): + n_atoms = system.positions.shape[0] + atom_indices = torch.arange( + n_atoms, dtype=torch.int64, device=device + ) + if selected_atoms is not None: + selected_values = selected_atoms.values.to(torch.int64) + mask = selected_values[:, 0] == i + atom_indices = selected_values[mask, 1] + + force_values.append( + system.positions.index_select(0, atom_indices) * 0.0 + ) + force_samples.append( + torch.cat( + [ + torch.full( + (atom_indices.shape[0], 1), + i, + dtype=torch.int64, + device=device, + ), + atom_indices.reshape(-1, 1), + ], + dim=1, + ) + ) + + block = TensorBlock( + values=torch.cat(force_values, dim=0).unsqueeze(-1), + samples=Labels(["system", "atom"], 
torch.cat(force_samples, dim=0)), + components=[ + Labels( + "xyz", + torch.arange(3, dtype=torch.int64, device=device).reshape( + -1, 1 + ), + ) + ], + properties=Labels( + "non_conservative_force", + torch.tensor([[0]], dtype=torch.int64, device=device), + ), + ) + blocks = torch.jit.annotate(List[TensorBlock], [block]) + results[name] = TensorMap(keys, blocks) + + elif name == "non_conservative_stress" or name.startswith( + "non_conservative_stress/" + ): + stress_values = torch.jit.annotate(List[torch.Tensor], []) + for system in systems: + stress_values.append( + torch.zeros((3, 3), dtype=base_values.dtype, device=device) + + system.cell.sum() * 0.0 + ) + + block = TensorBlock( + values=torch.stack(stress_values, dim=0).unsqueeze(-1), + samples=system_labels, + components=[ + Labels( + "xyz_1", + torch.arange(3, dtype=torch.int64, device=device).reshape( + -1, 1 + ), + ), + Labels( + "xyz_2", + torch.arange(3, dtype=torch.int64, device=device).reshape( + -1, 1 + ), + ), + ], + properties=Labels( + "non_conservative_stress", + torch.tensor([[0]], dtype=torch.int64, device=device), + ), + ) + blocks = torch.jit.annotate(List[TensorBlock], [block]) + results[name] = TensorMap(keys, blocks) + + return results + + +@pytest.fixture +def atoms(): + rng = np.random.default_rng(0xDEADBEEF) + system = ase.build.bulk("Ar", "fcc", a=5.26, cubic=True).repeat((2, 2, 2)) + system.positions += 0.15 * rng.random(system.positions.shape) + return system + + +def _d3_params(): + """Synthetic D3 reference tables matching the values used to generate the + ground-truth snapshot. The function produces a single Ar-Ar C6 of 100, + one valid CN reference at 1.0, and rcov / r4r2 of 1.5 / 2.0. 
+ """ + size = ATOMIC_NUMBER + 1 + rcov = torch.zeros(size, dtype=torch.float64) + rcov[1:] = 1.0 + rcov[ATOMIC_NUMBER] = 1.5 + + r4r2 = torch.zeros(size, dtype=torch.float64) + r4r2[1:] = 1.0 + r4r2[ATOMIC_NUMBER] = 2.0 + + c6 = torch.zeros((size, size, _REF_GRID, _REF_GRID), dtype=torch.float64) + c6[ATOMIC_NUMBER, ATOMIC_NUMBER] = 100.0 + + cn_ref = torch.full((size, _REF_GRID), -1.0, dtype=torch.float64) + # All 5 reference points share the same CN value, so the weights are + # uniform and the effective C6 matches the per-pair (5, 5) test fixture. + cn_ref[ATOMIC_NUMBER, :] = 1.0 + + return {"rcov": rcov, "r4r2": r4r2, "c6": c6, "cn_ref": cn_ref} + + +def _eval( + model, + atoms, + outputs, + selected_atoms=None, + positions_requires_grad=False, + with_strain=False, +): + system = systems_to_torch( + atoms, + dtype=torch.float64, + positions_requires_grad=positions_requires_grad, + ) + strain = None + if with_strain: + strain = torch.eye( + 3, + dtype=system.positions.dtype, + device=system.positions.device, + requires_grad=True, + ) + positions = system.positions @ strain + if positions_requires_grad: + positions.retain_grad() + + system = System( + positions=positions, + cell=system.cell @ strain, + types=system.types, + pbc=system.pbc, + ) + + systems = [system] + compute_requested_neighbors_from_options( + systems, + model.requested_neighbor_lists(), + "Angstrom", + True, + ) + options = ModelEvaluationOptions( + length_unit="Angstrom", + outputs=outputs, + selected_atoms=selected_atoms, + ) + return model(systems, options, check_consistency=True), system, strain + + +def test_dftd3_default_cutoffs_use_grimme(model): + """When the caller does not specify cutoffs, the wrapper must default to + the Grimme values (50 / 25 Bohr) converted into the model's length unit + (here Angstrom).""" + wrapper = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={"energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}}, + ) + BOHR_IN_ANGSTROM = 
unit_conversion_factor("Bohr", "Angstrom") + nls = wrapper.requested_neighbor_lists() + assert len(nls) == 1 + assert nls[0].requestors() == ["DFTD3"] + assert nls[0].cutoff == pytest.approx(50.0 * BOHR_IN_ANGSTROM) + + +def test_dftd3_selected_atoms(atoms, model): + damping = {"a1": 0.4, "a2": 4.0, "s8": 1.0} + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={ + "energy": damping, + "non_conservative_force": damping, + "non_conservative_stress": damping, + }, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + + all_atoms = Labels( + ["system", "atom"], + torch.tensor([[0, atom_i] for atom_i in range(len(atoms))], dtype=torch.int64), + ) + even_atoms = Labels( + ["system", "atom"], + torch.tensor( + [[0, atom_i] for atom_i in range(len(atoms)) if atom_i % 2 == 0], + dtype=torch.int64, + ), + ) + odd_atoms = Labels( + ["system", "atom"], + torch.tensor( + [[0, atom_i] for atom_i in range(len(atoms)) if atom_i % 2 == 1], + dtype=torch.int64, + ), + ) + + energy_output = {"energy": ModelOutput(sample_kind="system")} + full_results, _, _ = _eval(wrapped, atoms, energy_output, selected_atoms=all_atoms) + even_results, _, _ = _eval(wrapped, atoms, energy_output, selected_atoms=even_atoms) + odd_results, _, _ = _eval(wrapped, atoms, energy_output, selected_atoms=odd_atoms) + + full = full_results["energy"].block().values.item() + even = even_results["energy"].block().values.item() + odd = odd_results["energy"].block().values.item() + + np.testing.assert_allclose(full, _D3_REFERENCE["default"]["energy"], rtol=1e-10) + np.testing.assert_allclose(even + odd, full, rtol=1e-10, atol=1e-12) + + direct_outputs_even, _, _ = _eval( + wrapped, + atoms, + { + "non_conservative_force": ModelOutput(sample_kind="atom"), + }, + selected_atoms=even_atoms, + ) + force_block = direct_outputs_even["non_conservative_force"].block() + assert torch.equal(force_block.samples.values.cpu(), even_atoms.values) + np.testing.assert_allclose( + 
force_block.values.squeeze(-1).detach().numpy(), + _D3_REFERENCE["default"]["forces"][::2], + atol=1e-10, + rtol=1e-8, + ) + + +def test_dftd3_excluded_atom_types_disable_all_outputs(atoms, model): + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={ + "energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "non_conservative_force": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "non_conservative_stress": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + }, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + excluded_atom_types=[ATOMIC_NUMBER], + ) + + results, system, strain = _eval( + wrapped, + atoms, + { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_force": ModelOutput(sample_kind="atom"), + "non_conservative_stress": ModelOutput(sample_kind="system"), + }, + positions_requires_grad=True, + with_strain=True, + ) + corrected_energy = float(results["energy"].block().values.item()) + results["energy"].block().values.backward() + forces = -system.positions.grad.detach().cpu().numpy() + stress = (strain.grad / atoms.cell.volume).detach().cpu().numpy() + + np.testing.assert_allclose(corrected_energy, 0.0, atol=1e-12) + np.testing.assert_allclose(forces, 0.0, atol=1e-12) + np.testing.assert_allclose(stress, 0.0, atol=1e-12) + np.testing.assert_allclose( + results["non_conservative_force"].block().values.squeeze(-1).detach().numpy(), + 0.0, + atol=1e-12, + ) + np.testing.assert_allclose( + results["non_conservative_stress"] + .block() + .values.squeeze(-1) + .detach() + .numpy()[0], + 0.0, + atol=1e-12, + ) + + +def test_dftd3_multiple_variants_use_independent_damping(atoms, model): + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={ + "energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "energy/shifted": {"a1": 0.3, "a2": 3.0, "s8": 2.0}, + }, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + + results, _, _ = _eval( + wrapped, + atoms, + { + "energy": ModelOutput(sample_kind="system"), + "energy/shifted": ModelOutput(sample_kind="system"), + 
}, + ) + corrected_default = float(results["energy"].block().values.item()) + corrected_shifted = float(results["energy/shifted"].block().values.item()) + + expected_default = _D3_REFERENCE["default"]["energy"] + expected_shifted = _D3_REFERENCE["shifted"]["energy"] + np.testing.assert_allclose( + corrected_default, expected_default, rtol=1e-10, atol=1e-12 + ) + np.testing.assert_allclose( + corrected_shifted, expected_shifted, rtol=1e-10, atol=1e-12 + ) + assert not np.isclose(expected_default, expected_shifted) + + +def test_dftd3_rejects_per_atom_corrected_energy(model, atoms): + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={"energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}}, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + with pytest.raises( + Exception, match="this model can not compute 'energy' per atom, only globally" + ): + _eval(wrapped, atoms, {"energy": ModelOutput(sample_kind="atom")}) + + +def test_dftd3_atomic_energy_matches_system_energy(model_with_atomic_energy, atoms): + wrapped = DFTD3.wrap( + model_with_atomic_energy, + d3_params=_d3_params(), + damping_params={"energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}}, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + + assert wrapped.capabilities().outputs["energy"].sample_kind == "atom" + + atom_output = {"energy": ModelOutput(sample_kind="atom")} + results, _, _ = _eval(wrapped, atoms, atom_output) + block = results["energy"].block() + assert block.samples.names == ["system", "atom"] + assert block.values.shape == (len(atoms), 1) + + atomic_energies = block.values.detach().numpy().reshape(-1) + np.testing.assert_allclose( + atomic_energies.sum(), + _D3_REFERENCE["default"]["energy"], + rtol=1e-10, + atol=1e-12, + ) + + even_atoms = Labels( + ["system", "atom"], + torch.tensor( + [[0, atom_i] for atom_i in range(len(atoms)) if atom_i % 2 == 0], + dtype=torch.int64, + ), + ) + odd_atoms = Labels( + ["system", "atom"], + torch.tensor( + [[0, atom_i] for atom_i in range(len(atoms)) if 
atom_i % 2 == 1], + dtype=torch.int64, + ), + ) + + even_results, _, _ = _eval(wrapped, atoms, atom_output, selected_atoms=even_atoms) + odd_results, _, _ = _eval(wrapped, atoms, atom_output, selected_atoms=odd_atoms) + + even_block = even_results["energy"].block() + odd_block = odd_results["energy"].block() + assert torch.equal(even_block.samples.values.cpu(), even_atoms.values) + assert torch.equal(odd_block.samples.values.cpu(), odd_atoms.values) + + even_atom_indices = even_block.samples.values[:, 1].cpu().numpy() + odd_atom_indices = odd_block.samples.values[:, 1].cpu().numpy() + np.testing.assert_allclose( + even_block.values.detach().numpy().reshape(-1), + atomic_energies[even_atom_indices], + rtol=1e-10, + ) + np.testing.assert_allclose( + odd_block.values.detach().numpy().reshape(-1), + atomic_energies[odd_atom_indices], + rtol=1e-10, + ) + + +def test_dftd3_non_conservative_save_and_reload(tmp_path, model, atoms): + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={ + "energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "non_conservative_force": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "non_conservative_stress": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + }, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + outputs, _, _ = _eval( + wrapped, + atoms, + { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_force": ModelOutput(sample_kind="atom"), + }, + ) + original_energy = float(outputs["energy"].block().values.item()) + original_forces = ( + outputs["non_conservative_force"].block().values.squeeze(-1).detach().numpy() + ) + + path = os.path.join(tmp_path, "dftd3.pt") + wrapped.save(path) + reloaded = load_atomistic_model(path) + + outputs, _, _ = _eval( + reloaded, + atoms, + { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_force": ModelOutput(sample_kind="atom"), + }, + ) + + reloaded_energy = float(outputs["energy"].block().values.item()) + reloaded_forces = ( + 
outputs["non_conservative_force"].block().values.squeeze(-1).detach().numpy() + ) + + assert np.isclose(reloaded_energy, original_energy, rtol=1e-10, atol=1e-12) + assert np.allclose(reloaded_forces, original_forces, rtol=1e-10, atol=1e-12) + + outputs, _, _ = _eval( + reloaded, + atoms, + { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_force": ModelOutput(sample_kind="atom"), + }, + ) + + +@pytest.mark.parametrize( + ("damping_params", "message"), + [ + ( + {"energy/does_not_exist": {"a1": 0.4, "a2": 4.0, "s8": 1.0}}, + "DFTD3 cannot correct 'energy/does_not_exist': the wrapped model does not " + "expose this output", + ), + ( + {"not_an_energy_key": {"a1": 0.4, "a2": 4.0, "s8": 1.0}}, + r"DFTD3 damping_params key must be 'energy[/]', " + "'non_conservative_force[/]' or " + "'non_conservative_stress[/]', got 'not_an_energy_key'", + ), + ], +) +def test_dftd3_rejects_invalid_damping_keys(model, damping_params, message): + with pytest.raises(ValueError, match=re.escape(message)): + DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params=damping_params, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + + +def test_dftd3_energy_forces_and_stress_match_reference(atoms, model): + """The pure-PyTorch corrected energy is naturally differentiable through + the neighbor-list distances. Verify that the autograd path + yields conservative forces and stresses matching the frozen D3 smoke reference. 
+ """ + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={"energy": {"a1": 0.4, "a2": 4.0, "s8": 1.0}}, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + + results, system, strain = _eval( + wrapped, + atoms, + {"energy": ModelOutput(sample_kind="system")}, + positions_requires_grad=True, + with_strain=True, + ) + corrected_energy = float(results["energy"].block().values.item()) + results["energy"].block().values.backward() + wrapped_forces = -system.positions.grad.detach().cpu().numpy() + wrapped_stress = (strain.grad / atoms.cell.volume).detach().cpu().numpy() + + np.testing.assert_allclose( + corrected_energy, _D3_REFERENCE["default"]["energy"], rtol=1e-10, atol=1e-12 + ) + d3_forces = _D3_REFERENCE["default"]["forces"] + d3_stress = _D3_REFERENCE["default"]["stress"] + + np.testing.assert_allclose(wrapped_forces, d3_forces, atol=1e-10, rtol=1e-8) + np.testing.assert_allclose(wrapped_stress, d3_stress, atol=1e-12, rtol=1e-8) + + # Shouldn't correct non-conservative outputs + outputs, _, _ = _eval( + wrapped, + atoms, + { + "non_conservative_force": ModelOutput(sample_kind="atom"), + "non_conservative_stress": ModelOutput(sample_kind="system"), + }, + ) + + assert torch.all(outputs["non_conservative_force"].block().values == 0.0) + assert torch.all(outputs["non_conservative_stress"].block().values == 0.0) + + +def test_dftd3_non_conservative_outputs_match_d3_reference(atoms, model): + """Non-conservative force/stress outputs get the same D3 correction as the + conservative autograd path. 
The mixed modes exercise non-conservative forces with + autograd stress and non-conservative stress with autograd forces.""" + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={ + "non_conservative_force": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "non_conservative_stress": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + }, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + outputs = { + "non_conservative_force": ModelOutput(sample_kind="atom"), + "non_conservative_stress": ModelOutput(sample_kind="system"), + } + + results, _, _ = _eval( + wrapped, + atoms, + outputs, + ) + + wrapped_forces = ( + results["non_conservative_force"].block().values.squeeze(-1).detach().numpy() + ) + wrapped_stress = ( + results["non_conservative_stress"] + .block() + .values.squeeze(-1) + .detach() + .numpy()[0] + ) + np.testing.assert_allclose( + wrapped_forces, + _D3_REFERENCE["default"]["forces"], + atol=1e-10, + rtol=1e-8, + ) + np.testing.assert_allclose( + wrapped_stress, + _D3_REFERENCE["default"]["stress"], + atol=1e-12, + rtol=1e-8, + ) + + +def test_dftd3_non_conservative_variant_without_energy_variant(atoms, model): + wrapped = DFTD3.wrap( + model, + d3_params=_d3_params(), + damping_params={ + "non_conservative_force/direct": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + "non_conservative_stress/direct": {"a1": 0.4, "a2": 4.0, "s8": 1.0}, + }, + cutoff=D3_CUTOFF, + cn_cutoff=D3_CUTOFF, + ) + + outputs, _, _ = _eval( + wrapped, + atoms, + { + "non_conservative_force/direct": ModelOutput(sample_kind="atom"), + "non_conservative_stress/direct": ModelOutput(sample_kind="system"), + }, + ) + + np.testing.assert_allclose( + outputs["non_conservative_force/direct"] + .block() + .values.squeeze(-1) + .detach() + .numpy(), + _D3_REFERENCE["default"]["forces"], + atol=1e-10, + rtol=1e-8, + ) + np.testing.assert_allclose( + outputs["non_conservative_stress/direct"] + .block() + .values.squeeze(-1) + .detach() + .numpy()[0], + _D3_REFERENCE["default"]["stress"], + atol=1e-12, + 
rtol=1e-8, + ) diff --git a/scripts/create-dftd3-params-npz.py b/scripts/create-dftd3-params-npz.py new file mode 100644 index 000000000..18a5fbb75 --- /dev/null +++ b/scripts/create-dftd3-params-npz.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +"""Create the packaged DFT-D3 reference-table npz. + +By default this downloads the reference tables from ``simple-dftd3`` and the +matching covalent radii from ``mctc-lib``, and writes the packaged ``npz`` used +by ``metatomic.torch.dftd3``. This currently provides DFT-D3 C6 references up +to element 103. + +The output is the layout used by ``metatomic.torch.dftd3.DFTD3``: + +* ``rcov``: ``(Z,)`` +* ``r4r2``: ``(Z,)`` +* ``c6``: ``(Z, Z, M, M)`` +* ``cn_ref``: ``(Z, M)`` +""" + +from __future__ import annotations + +import argparse +import re +import urllib.error +import urllib.request +from pathlib import Path + +import numpy as np + + +ROOT = Path(__file__).resolve().parents[1] +SIMPLE_DFTD3_REFERENCE_URL = ( + "https://raw.githubusercontent.com/dftd3/simple-dftd3/refs/heads/main/" + "src/dftd3/reference.f90" +) +SIMPLE_DFTD3_R4R2_URL = ( + "https://raw.githubusercontent.com/dftd3/simple-dftd3/refs/heads/main/" + "src/dftd3/data/r4r2.f90" +) +MCTC_COVRAD_URL = ( + "https://raw.githubusercontent.com/grimme-lab/mctc-lib/refs/heads/main/" + "src/mctc/data/covrad.f90" +) +ANGSTROM_TO_BOHR = 1.0 / 0.5291772105448199 +DEFAULT_OUTPUT = ( + ROOT + / "python" + / "metatomic_torch" + / "metatomic" + / "torch" + / "data" + / "dftd3_parameters.npz" +) +_FLOAT_RE = r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eEdD][-+]?\d+)?(?:_wp)?" 
# Matches either square bracket. Used to locate the end of a Fortran array
# constructor without walking the (multi-megabyte) source character by character.
_BRACKET_RE = re.compile(r"[\[\]]")


def _matching_bracket(content: str, open_index: int, what: str) -> int:
    """Return the index of the ``]`` that closes the ``[`` at ``open_index``.

    Nested brackets (e.g. implied-do loops inside an array constructor) are
    balanced. ``what`` names the construct being scanned and is only used in
    the error message.

    :raises ValueError: if the bracket opened at ``open_index`` is never closed.
    """
    depth = 0
    for bracket in _BRACKET_RE.finditer(content, open_index):
        depth += 1 if bracket.group() == "[" else -1
        if depth == 0:
            return bracket.start()

    raise ValueError(f"failed to find end of {what}")


def _find_integer_parameter(content: str, var_name: str) -> int:
    """Return the value of ``integer, parameter :: <var_name> = <int>``.

    :raises ValueError: if no such declaration exists in ``content``.
    """
    match = re.search(
        rf"integer\s*,\s*parameter\s*::\s*{re.escape(var_name)}\s*=\s*(\d+)",
        content,
        re.IGNORECASE,
    )
    if match is None:
        raise ValueError(f"integer parameter '{var_name}' not found")
    return int(match.group(1))


def _find_bracket_array(content: str, var_name: str) -> str:
    """Return the text between the brackets of the array assigned to ``var_name``.

    The assignment may carry a shape (``name(n) = ...``) and arbitrary text
    between ``=`` and the opening bracket (``reshape(``, ``aatoau *``, ...).

    :raises ValueError: if the assignment or its closing bracket is not found.
    """
    # The lookbehind keeps ``r4_over_r2`` from matching inside a longer
    # identifier such as ``sqrt_z_r4_over_r2``.
    pattern = (
        rf"(?<![A-Za-z0-9_]){re.escape(var_name)}"
        r"\s*(?:\([^)]*\))?\s*=\s*[^\[]*\["
    )
    match = re.search(pattern, content, re.IGNORECASE | re.DOTALL)
    if match is None:
        raise ValueError(f"array '{var_name}' not found")

    open_index = match.end() - 1
    close_index = _matching_bracket(content, open_index, f"array '{var_name}'")
    return content[open_index + 1 : close_index]


def _fortran_numbers(content: str) -> np.ndarray:
    """Parse every numeric literal in ``content`` into a float64 array.

    ``!`` comments are dropped first so that digits inside them are ignored.
    The ``_wp`` kind suffix is removed and ``d``/``D`` exponents are rewritten
    to ``e`` so that :func:`float` accepts them.
    """
    code = "\n".join(line.partition("!")[0] for line in content.splitlines())
    values = [
        float(token.replace("_wp", "").replace("D", "e").replace("d", "e"))
        for token in re.findall(_FLOAT_RE, code)
    ]
    return np.asarray(values, dtype=np.float64)


def _parse_c6(content: str, max_elem: int, max_ref: int) -> np.ndarray:
    """Build the dense ``(Z, Z, M, M)`` float32 C6 table from ``c6ab_view`` chunks.

    The source stores the table as a flat array filled by a series of
    ``c6ab_view(start:stop) = [...]`` assignments (1-based, inclusive). The
    flat array is read here as a column-major ``(max_ref, max_ref, n_pairs)``
    array holding one block per element pair of the lower triangle; it is
    expanded to both ``[z_i, z_j]`` and ``[z_j, z_i]``. Index 0 along the
    element axes is left as zeros.

    :raises ValueError: if no assignment is found, if an assignment holds the
        wrong number of values, if it addresses elements outside the table, or
        if its closing bracket is missing.
    """
    n_pairs = max_elem * (max_elem + 1) // 2
    flat = np.zeros(max_ref * max_ref * n_pairs, dtype=np.float64)

    n_assignments = 0
    for match in re.finditer(
        r"c6ab_view\(\s*(\d+)\s*:\s*(\d+)\s*\)\s*=\s*\[",
        content,
        re.IGNORECASE,
    ):
        # Fortran bounds are 1-based and inclusive; convert to a Python slice.
        start = int(match.group(1)) - 1
        stop = int(match.group(2))

        open_index = match.end() - 1
        close_index = _matching_bracket(content, open_index, "c6ab_view assignment")

        values = _fortran_numbers(content[open_index + 1 : close_index])
        if values.shape[0] != stop - start:
            raise ValueError(
                f"c6ab_view assignment {start + 1}:{stop} contains "
                f"{values.shape[0]} values"
            )
        if start < 0 or stop > flat.shape[0]:
            raise ValueError(
                f"c6ab_view assignment {start + 1}:{stop} is outside the "
                f"expected range 1:{flat.shape[0]}"
            )
        flat[start:stop] = values
        n_assignments += 1

    if n_assignments == 0:
        # Without this check a renamed upstream variable would silently yield
        # an all-zero table, i.e. no dispersion correction at all.
        raise ValueError("no 'c6ab_view' assignments found in reference.f90")

    triangular = flat.reshape((max_ref, max_ref, n_pairs), order="F")
    c6 = np.zeros((max_elem + 1, max_elem + 1, max_ref, max_ref), dtype=np.float32)

    for z_hi in range(1, max_elem + 1):
        for z_lo in range(1, z_hi + 1):
            # 1-based index of the (z_lo, z_hi) pair in the packed triangle.
            pair_index = z_lo + z_hi * (z_hi - 1) // 2
            block = triangular[:, :, pair_index - 1]
            c6[z_lo, z_hi] = block.T
            if z_lo != z_hi:
                c6[z_hi, z_lo] = block

    return c6


def _extract_from_simple_dftd3_sources(files: dict[str, str]) -> dict[str, np.ndarray]:
    """Convert the Fortran reference sources into the packaged table layout.

    :param files: contents of the Fortran sources, keyed by ``"reference.f90"``,
        ``"r4r2.f90"`` and ``"covrad.f90"``.
    :return: float32 arrays ``rcov`` ``(Z,)``, ``r4r2`` ``(Z,)``,
        ``c6`` ``(Z, Z, M, M)`` and ``cn_ref`` ``(Z, M)``, where
        ``Z = max_elem + 1`` and ``M = max_ref``. Index 0 along the element
        axis is a placeholder: zeros, except ``cn_ref`` which uses ``-1``.
    :raises ValueError: if a declaration is missing or a table has an
        unexpected number of values.
    """
    reference = files["reference.f90"]
    max_elem = _find_integer_parameter(reference, "max_elem")
    max_ref = _find_integer_parameter(reference, "max_ref")

    reference_cn_values = _fortran_numbers(
        _find_bracket_array(reference, "reference_cn")
    )
    expected_cn_values = max_elem * max_ref
    if reference_cn_values.shape[0] != expected_cn_values:
        raise ValueError(
            f"reference_cn contains {reference_cn_values.shape[0]} values, "
            f"expected {expected_cn_values}"
        )
    cn_ref = np.full((max_elem + 1, max_ref), -1.0, dtype=np.float32)
    cn_ref[1:] = reference_cn_values.reshape((max_elem, max_ref)).astype(np.float32)

    r4_over_r2 = _fortran_numbers(_find_bracket_array(files["r4r2.f90"], "r4_over_r2"))
    if r4_over_r2.shape[0] < max_elem:
        raise ValueError(
            f"r4_over_r2 covers only {r4_over_r2.shape[0]} elements, "
            f"but reference.f90 covers {max_elem}"
        )
    atomic_numbers = np.arange(1, max_elem + 1, dtype=np.float64)
    r4r2 = np.zeros(max_elem + 1, dtype=np.float32)
    r4r2[1:] = np.sqrt(0.5 * r4_over_r2[:max_elem] * np.sqrt(atomic_numbers)).astype(
        np.float32
    )

    covalent_radii = _fortran_numbers(
        _find_bracket_array(files["covrad.f90"], "covalent_rad_2009")
    )
    if covalent_radii.shape[0] < max_elem:
        raise ValueError(
            f"covalent radii cover only {covalent_radii.shape[0]} elements, "
            f"but reference.f90 covers {max_elem}"
        )
    # Pre-multiply by 4/3 and convert from Angstrom to Bohr
    # See https://github.com/tad-mctc/tad-mctc/blob/0d3bb31018520fb8a85bc79c000d4aae01f51235/src/tad_mctc/data/radii.py#L133
    rcov = np.zeros(max_elem + 1, dtype=np.float32)
    rcov[1:] = (4.0 / 3.0 * covalent_radii[:max_elem] * ANGSTROM_TO_BOHR).astype(
        np.float32
    )

    return {
        "rcov": rcov,
        "r4r2": r4r2,
        "c6": _parse_c6(reference, max_elem, max_ref),
        "cn_ref": cn_ref,
    }


def _download_text(url: str) -> str:
    """Fetch ``url`` and decode it as UTF-8, dropping undecodable bytes."""
    with urllib.request.urlopen(url, timeout=30) as response:
        return response.read().decode("utf-8", errors="ignore")


def _download_sources(
    reference_url: str,
    r4r2_url: str,
    covrad_url: str,
) -> dict[str, str]:
    """Download the three Fortran sources, keyed by their canonical file name."""
    return {
        "reference.f90": _download_text(reference_url),
        "r4r2.f90": _download_text(r4r2_url),
        "covrad.f90": _download_text(covrad_url),
    }


def _read_sources(source_dir: Path) -> dict[str, str]:
    """Read the three Fortran sources from ``source_dir``."""
    return {
        name: (source_dir / name).read_text(encoding="utf-8")
        for name in ("reference.f90", "r4r2.f90", "covrad.f90")
    }


def main() -> int:
    """Command-line entry point: build the tables and write the ``npz`` file.

    The sources are read from ``--source-dir`` when given and downloaded
    otherwise.

    :return: ``0`` on success.
    :raises RuntimeError: if downloading the sources fails.
    :raises ValueError: if the sources cannot be parsed or the resulting
        tables have inconsistent shapes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source-dir",
        type=Path,
        help="directory containing reference.f90, r4r2.f90 and covrad.f90 files",
    )
    parser.add_argument(
        "--simple-reference-url",
        default=SIMPLE_DFTD3_REFERENCE_URL,
        help="simple-dftd3 reference.f90 URL",
    )
    parser.add_argument(
        "--simple-r4r2-url",
        default=SIMPLE_DFTD3_R4R2_URL,
        help="simple-dftd3 r4r2.f90 URL",
    )
    parser.add_argument(
        "--covrad-url",
        default=MCTC_COVRAD_URL,
        help="mctc-lib covrad.f90 URL",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT,
        help="output npz file to create",
    )
    args = parser.parse_args()

    output_path = args.output.expanduser().resolve()

    if args.source_dir is not None:
        source_dir = args.source_dir.expanduser().resolve()
        sources = _read_sources(source_dir)
        source_description = str(source_dir)
    else:
        # Only the network access is guarded: parsing errors should surface
        # as-is rather than being reported as download failures.
        try:
            sources = _download_sources(
                args.simple_reference_url,
                args.simple_r4r2_url,
                args.covrad_url,
            )
        except (OSError, urllib.error.URLError) as e:
            raise RuntimeError("failed to download simple-dftd3 reference tables") from e
        source_description = (
            f"{args.simple_reference_url}, {args.simple_r4r2_url}, "
            f"and {args.covrad_url}"
        )

    params = _extract_from_simple_dftd3_sources(sources)

    if params["c6"].ndim != 4:
        raise ValueError(f"'c6' must be 4D, got {params['c6'].shape}")
    if params["cn_ref"].shape != (params["c6"].shape[0], params["c6"].shape[2]):
        raise ValueError(
            f"unexpected converted 'cn_ref' shape {params['cn_ref'].shape} "
            f"for c6 shape {params['c6'].shape}"
        )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        output_path,
        **params,
        source=np.array(source_description),
        layout=np.array("metatomic.torch.dftd3 v1"),
    )

    print(f"read {source_description}")
    print(f"wrote {output_path}")
    for name, value in params.items():
        print(f"  {name}: shape={value.shape} dtype={value.dtype}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())