DFT-D3 wrapper#219
Conversation
Yes, ideally running this code should not require internet access
Isn't this implemented?
Yes!
This should be doable, but how would it look when working with heterogeneous systems?
Yes!
Why would someone want to do or not do this? I don't quite understand the implications. |
Luthaf
left a comment
There was a problem hiding this comment.
This will need more review, just sending the notes I took during our discussion
|
Here I upload the parity plot of the current version of d3 wrapper against #!/usr/bin/env python3
from __future__ import annotations
import argparse
import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import ase.io
import numpy as np
import torch
from metatomic.torch import load_atomistic_model
from metatomic.torch.dftd3 import DFTD3
from metatomic_ase import MetatomicCalculator
ROOT = Path('/home/qxu/repos/metatomic')
DAMPING = {'s6': 1.0, 's8': 0.7875, 'a1': 0.4289, 'a2': 4.4407}
CUTOFF = 15.0
CN_CUTOFF = 15.0
def _stress_allowed(atoms) -> bool:
return bool(np.all(atoms.pbc)) and abs(atoms.cell.volume) > 0.0
def _batched(frames, batch_size: int, max_atoms: int):
batch = []
atoms_in_batch = 0
stress_allowed = False
for frame in frames:
frame_stress_allowed = _stress_allowed(frame[1])
n_atoms = len(frame[1])
if batch and frame_stress_allowed != stress_allowed:
yield batch
batch = []
atoms_in_batch = 0
if batch and max_atoms > 0 and atoms_in_batch + n_atoms > max_atoms:
yield batch
batch = []
atoms_in_batch = 0
batch.append(frame)
atoms_in_batch += n_atoms
stress_allowed = frame_stress_allowed
if len(batch) == batch_size:
yield batch
batch = []
atoms_in_batch = 0
if batch:
yield batch
def _simple_dftd3(atoms):
from dftd3.ase import DFTD3 as SimpleDFTD3
atoms = atoms.copy()
atoms.calc = SimpleDFTD3(
damping='d3bj',
params_tweaks={**DAMPING, 's9': 0.0},
realspace_cutoff={'disp2': CUTOFF, 'disp3': CUTOFF, 'cn': CN_CUTOFF},
)
energy = atoms.get_potential_energy()
forces = atoms.get_forces()
stress = atoms.get_stress(voigt=False) if _stress_allowed(atoms) else None
return energy, forces, stress
def _metrics(reference: np.ndarray, value: np.ndarray) -> str:
delta = np.asarray(value, dtype=float).reshape(-1) - np.asarray(
reference, dtype=float
).reshape(-1)
if delta.size == 0:
return 'not available'
return (
f'RMSE={np.sqrt(np.mean(delta * delta)):.3e} '
f'MAE={np.mean(np.abs(delta)):.3e} '
f'max={np.max(np.abs(delta)):.3e}'
)
def _plot_panel(ax, reference, value, title: str, max_points: int):
reference = np.asarray(reference, dtype=float).reshape(-1)
value = np.asarray(value, dtype=float).reshape(-1)
finite = np.isfinite(reference) & np.isfinite(value)
reference = reference[finite]
value = value[finite]
if reference.size == 0:
ax.set_title(title)
ax.text(0.5, 0.5, 'not available', ha='center', va='center')
ax.set_axis_off()
return
if reference.size > max_points:
rng = np.random.default_rng(0)
keep = rng.choice(reference.size, size=max_points, replace=False)
reference_plot = reference[keep]
value_plot = value[keep]
else:
reference_plot = reference
value_plot = value
lo = min(reference.min(), value.min())
hi = max(reference.max(), value.max())
pad = 0.03 * (hi - lo if hi > lo else max(abs(hi), 1.0))
lo -= pad
hi += pad
ax.scatter(reference_plot, value_plot, s=6, alpha=0.35, linewidths=0)
ax.plot([lo, hi], [lo, hi], color='black', linewidth=1)
ax.set_xlim(lo, hi)
ax.set_ylim(lo, hi)
ax.set_title(title)
ax.set_xlabel('simple-dftd3')
ax.set_ylabel('metatomic wrapper correction')
ax.text(
0.04,
0.96,
f'N={reference.size}\n{_metrics(reference, value)}',
transform=ax.transAxes,
va='top',
ha='left',
fontsize=9,
bbox={'facecolor': 'white', 'alpha': 0.75, 'edgecolor': 'none'},
)
def _plot(path: Path, ref_energy, value_energy, ref_forces, value_forces, ref_stress, value_stress, max_points: int):
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(14, 4.2), constrained_layout=True)
_plot_panel(axes[0], ref_energy, value_energy, 'Energy (eV)', max_points)
_plot_panel(axes[1], ref_forces, value_forces, 'Force components (eV/A)', max_points)
_plot_panel(axes[2], ref_stress, value_stress, 'Stress components (eV/A^3)', max_points)
fig.savefig(path, dpi=220)
plt.close(fig)
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument('--xyz', default=str(ROOT / 'mad-train.xyz'))
parser.add_argument('--model', default=str(ROOT / 'pet-mad-s-v1.0.2.pt'))
parser.add_argument('--output', default='dftd3-simple-parity.png')
parser.add_argument('--npz', default='dftd3-simple-parity.npz')
parser.add_argument('--device', choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--batch-size', type=int, default=200)
parser.add_argument('--max-atoms-per-batch', type=int, default=1800)
parser.add_argument('--reference-workers', type=int, default=8)
parser.add_argument('--max-points', type=int, default=200_000)
args = parser.parse_args()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
model = load_atomistic_model(args.model)
wrapped = DFTD3.wrap(
model,
damping_params={'energy': DAMPING},
cutoff=CUTOFF,
cn_cutoff=CN_CUTOFF,
)
base_calc = MetatomicCalculator(
model,
device=args.device,
check_consistency=False,
non_conservative=False,
uncertainty_threshold=None,
)
wrapped_calc = MetatomicCalculator(
wrapped,
device=args.device,
check_consistency=False,
non_conservative=False,
uncertainty_threshold=None,
)
frame_ids = []
stress_frame_ids = []
reference_energy = []
reference_forces = []
reference_stress = []
wrapper_energy = []
wrapper_forces = []
wrapper_stress = []
reference_time = 0.0
wrapper_time = 0.0
frames = enumerate(ase.io.iread(args.xyz, ':'))
batches = _batched(frames, args.batch_size, args.max_atoms_per_batch)
for batch in batches:
indices = [index for index, _ in batch]
atoms_batch = [atoms for _, atoms in batch]
frame_ids.extend(indices)
batch_has_stress = all(_stress_allowed(atoms) for atoms in atoms_batch)
if args.device == 'cuda':
torch.cuda.synchronize()
start = time.perf_counter()
base = base_calc.compute_energy(atoms_batch, compute_forces_and_stresses=True)
wrapped_out = wrapped_calc.compute_energy(
atoms_batch, compute_forces_and_stresses=True
)
if args.device == 'cuda':
torch.cuda.synchronize()
wrapper_time += time.perf_counter() - start
wrapper_energy.append(
np.asarray(wrapped_out['energy'], dtype=float).reshape(-1)
- np.asarray(base['energy'], dtype=float).reshape(-1)
)
wrapper_forces.append(
np.concatenate(wrapped_out['forces'], axis=0)
- np.concatenate(base['forces'], axis=0)
)
if batch_has_stress:
wrapper_stress.append(
np.asarray(wrapped_out['stress'], dtype=float)
- np.asarray(base['stress'], dtype=float)
)
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=args.reference_workers) as pool:
references = list(pool.map(_simple_dftd3, atoms_batch))
reference_time += time.perf_counter() - start
for index, (energy, forces, stress) in zip(indices, references, strict=True):
reference_energy.append(energy)
reference_forces.append(forces)
if stress is not None:
stress_frame_ids.append(index)
reference_stress.append(stress.reshape(1, 3, 3))
reference_energy = np.asarray(reference_energy)
reference_forces = np.concatenate(reference_forces, axis=0)
reference_stress = (
np.concatenate(reference_stress, axis=0)
if reference_stress
else np.empty((0, 3, 3))
)
wrapper_energy = np.concatenate(wrapper_energy, axis=0)
wrapper_forces = np.concatenate(wrapper_forces, axis=0)
wrapper_stress = (
np.concatenate(wrapper_stress, axis=0)
if wrapper_stress
else np.empty((0, 3, 3))
)
print(f'frames={len(frame_ids)}')
print(f'stress_frames={len(stress_frame_ids)}')
print(f'simple_dftd3_time_s={reference_time:.3f}')
print(f'wrapper_minus_base_time_s={wrapper_time:.3f}')
print('energy:', _metrics(reference_energy, wrapper_energy))
print('forces:', _metrics(reference_forces, wrapper_forces))
print('stress:', _metrics(reference_stress, wrapper_stress))
_plot(
Path(args.output),
reference_energy,
wrapper_energy,
reference_forces,
wrapper_forces,
reference_stress,
wrapper_stress,
args.max_points,
)
np.savez_compressed(
args.npz,
frame_ids=np.asarray(frame_ids, dtype=np.int64),
stress_frame_ids=np.asarray(stress_frame_ids, dtype=np.int64),
reference_energy=reference_energy,
wrapper_energy=wrapper_energy,
reference_forces=reference_forces,
wrapper_forces=wrapper_forces,
reference_stress=reference_stress,
wrapper_stress=wrapper_stress,
simple_dftd3_time_s=np.asarray(reference_time),
wrapper_minus_base_time_s=np.asarray(wrapper_time),
)
print(f'plot: {args.output}')
print(f'data: {args.npz}')
return 0
if __name__ == '__main__':
raise SystemExit(main())To run this script, one should compile simple-dftd3 from source, because one would have to modify this line |
82c3701 to
f6b4e80
Compare
Co-authored-by: Guillaume Fraux <luthaf@luthaf.fr>
Co-authored-by: Guillaume Fraux <luthaf@luthaf.fr>
…d create a script for downloading and parsing the parameters
…ded more comments
Co-authored-by: Guillaume Fraux <luthaf@luthaf.fr>
f65b0b0 to
5750c58
Compare

This PR aims to introduce a DFT-D3 wrapper, which can be wrapped outside of MLIPs in
AtomisticModel. The wrapper currently only supports D3(BJ).This wrapper calculates the energy correction of the input structure, and add it to the energy calculated by the wrapped MLIP. In this way, and through the auto-differentiation, the D3-corrected forces and stresses (if with full-PBC) can be automatically calculated.
I have tested this wrapper on MAD dataset. The energies, forces, and stresses are in good agreement with another wrapper implemented with the DFT-D3 provided in the nvalchemiops package.
The overhead is about 15%-35% for s models on systems ranging from 600 atoms to 4k atoms. Mainly from the neighbor list calculation, the actual overheads of forward and backward are only ~3% and ~8%.
Remaining issues:
ModelCapability? Currently it's updatedContributor (creator of pull-request) checklist
Reviewer checklist