################################################################################
# Authors: Zemin Xu
# License: MIT, see LICENSE.md
################################################################################
import warnings
from typing import Optional
import torch
from ase import units
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.mixing import SumCalculator
from torch_geometric.loader import DataLoader
from ...lightning import load_tace
from ...dataset.quantity import PROPERTY
from ...dataset.element import TorchElement
from ...dataset.graph import from_atoms
from ...dataset.quantity import (
PROPERTY,
KEYS,
KeySpecification,
update_keyspec_from_kwargs,
)
from ...utils._global import DTYPE, DEVICE
class TACEAseCalc(Calculator):
    """
    ASE calculator backed by a trained TACE model.

    We support the most fundamental potential energy surface properties as well as
    multi-fidelity, multi-head, etc. For some advanced features, the attributes that
    need to be embedded must be stored in ``atoms.info`` / ``atoms.arrays``, or you
    can add a function yourself. If you only need predictions, you can directly use
    the ``tace-eval`` command. It writes the predicted files, and with the ``--test``
    option it also reports the errors.

    Parameters
    ----------
    model : str
        Path to the trained model; the file ends with .pt, .pth or .ckpt.
    dtype : str, optional, default=None
        Dtype for computations, e.g. float32 or float64. If None, the dtype of the
        loaded model is used. A warning is emitted if it differs from the model's
        dtype. Note: this also sets torch's global default dtype.
    device : str | torch.device, optional
        Device to run computations on, e.g. cpu or cuda. If None, cuda is used
        when available, otherwise cpu.
    level : int, optional
        Fidelity level to use. If None, the model's current computing level is kept.
    spin_on : bool, optional
        If your model uses a spin_on (uie) embedding, controls whether the
        calculation enables spin polarization. If None, the model's current
        setting is kept.
    target_property : list[str], optional
        Extra properties to calculate, e.g. hessians, atomic_virials, conservative
        polarizability. If you use this parameter, you must provide all the
        required physical quantities. Forwarded to ``load_tace``.
    neighborlist_backend : str, default="matscipy"
        Neighbor-list backend, one of [ase, matscipy, vesin]; matscipy is recommended.
    **kwargs
        Additional keyword arguments passed to the ASE Calculator base class.
    """

    def __init__(
        self,
        model: str,
        *,
        dtype: Optional[str] = None,
        device: Optional[str] = None,
        level: Optional[int] = None,
        spin_on: Optional[bool] = None,
        target_property: Optional[list[str]] = None,
        neighborlist_backend: str = "matscipy",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # === load model ===
        model = load_tace(
            model,
            device,
            strict=True,
            use_ema=True,
            target_property=target_property,
        )

        # === dtype / device ===
        model_dtype = model.readout_fn.cutoff.dtype
        dtype = dtype or model_dtype
        self.dtype = DTYPE[dtype]
        self.device = DEVICE[device or torch.device("cuda" if torch.cuda.is_available() else "cpu")]
        # Global side effect: tensors created afterwards default to this dtype.
        torch.set_default_dtype(self.dtype)
        if DTYPE[dtype] != DTYPE[model_dtype]:
            warnings.warn(
                f"Model dtype {model_dtype} != default dtype {dtype}. "
                f"This may cause silent type conversions."
            )
        model = model.to(dtype=self.dtype)

        # === properties exposed to ASE ===
        self.target_property = model.get_target_property()
        self.embedding_property = model.get_embedding_property()
        self.implemented_properties = []
        for p in self.target_property:
            key = self._result_key(p)
            if key == 'energy':
                # free_energy is exposed alongside energy; calculate() sets both.
                self.implemented_properties.extend(["energy", "free_energy"])
            else:
                self.implemented_properties.append(key)

        # === graph construction settings ===
        self.max_neighbors = getattr(model.readout_fn, "max_neighbors", None)
        self.cutoff = float(model.readout_fn.cutoff.item())
        self.element = TorchElement([int(z) for z in model.readout_fn.atomic_numbers.cpu().tolist()])
        self.neighborlist_backend = neighborlist_backend

        # === inference mode: freeze parameters ===
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        self.keySpecification = KeySpecification()
        update_keyspec_from_kwargs(self.keySpecification, KEYS)

        # === fidelity level / spin (None keeps the model's current setting) ===
        if level is not None:
            self.level = level
            model.reset_computing_level(level)
        else:
            self.level = model.get_computing_level()
        if spin_on is not None:
            self.spin_on = 1 if spin_on else 0
            model.reset_spin_on(self.spin_on)
        else:
            self.spin_on = model.get_spin_on()

        self.model = model.to(self.device)

    @staticmethod
    def _result_key(p: str) -> str:
        """Return the ``results`` key for target property ``p``.

        Uses the ASE name registered in ``PROPERTY`` when one is set, else ``p``.
        Shared by ``__init__`` and ``calculate`` so both agree on the key.
        """
        ase_name = PROPERTY[p]['ase_name']
        return ase_name if ase_name else p

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Run the model on ``atoms`` and store every target property in ``self.results``.

        Parameters
        ----------
        atoms : ase.Atoms, optional
            Structure to evaluate. If None, the calculator's stored atoms are used.
        properties : list[str], optional
            Requested properties; all target properties are computed regardless.
        system_changes : list[str]
            Forwarded to the ASE base class.
        """
        super().calculate(atoms, properties, system_changes)
        # Work on the base class's private copy so the caller's Atoms.info is
        # not modified as a side effect (and atoms=None works).
        atoms = self.atoms
        atoms.info['level'] = self.level  # fidelity level

        # === dataloader ===
        data = [
            from_atoms(
                self.element,
                atoms,
                self.cutoff,
                max_neighbors=self.max_neighbors,
                target_property=self.target_property,
                embedding_property=self.embedding_property,
                keyspec=self.keySpecification,
                training=False,
                neighborlist_backend=self.neighborlist_backend,
            )
        ]
        dataloader = DataLoader(
            dataset=data,
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(dataloader)).to(self.device)

        # === forward ===
        outs = self.model(batch)

        # === update ===
        self.results = {}
        for p in self.target_property:
            key = self._result_key(p)
            value = outs[p].detach().cpu()
            if PROPERTY[p]['scope'] == 'per-system':
                if PROPERTY[p]['rank'] == 0:
                    self.results[key] = value.item()
                    if key == 'energy':
                        self.results['free_energy'] = self.results['energy']
                else:
                    # Drop the leading dimension (batch_size=1).
                    self.results[key] = value.numpy().squeeze(0)
            else:
                # per-atom, per-edge and any other scope are stored as-is.
                self.results[key] = value.numpy()
def add_dispersion(
    base_calc: Calculator,
    damping: str = "bj",  # choices: ["zero", "bj", "zerom", "bjm"]
    dispersion_xc: str = "pbe",
    dispersion_cutoff: float = 40.0 * units.Bohr,
    **kwargs,
) -> SumCalculator:
    """Combine ``base_calc`` with a DFT-D3 dispersion correction.

    Parameters
    ----------
    base_calc : Calculator
        Calculator to augment. If it exposes ``dtype`` / ``device`` attributes
        (as ``TACEAseCalc`` does), they are forwarded to the D3 calculator.
    damping : str
        D3 damping scheme, one of "zero", "bj", "zerom", "bjm".
    dispersion_xc : str
        Exchange-correlation functional the D3 parameters are taken for.
    dispersion_cutoff : float
        Cutoff for the dispersion interaction (default 40 Bohr, in ASE length units).
    **kwargs
        Extra keyword arguments for ``TorchDFTD3Calculator``. An explicit
        ``dtype`` or ``device`` here takes precedence over ``base_calc``'s.

    Returns
    -------
    SumCalculator
        Calculator summing the results of ``base_calc`` and the D3 correction.

    Raises
    ------
    RuntimeError
        If ``torch-dftd`` is not installed.
    """
    try:
        from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
    except ImportError as e:
        raise RuntimeError(
            "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)"
        ) from e

    d3_kwargs = dict(kwargs)
    # Forward dtype/device only when the base calculator has them, and never
    # override an explicit kwarg (that previously raised "multiple values").
    for attr in ("dtype", "device"):
        if attr not in d3_kwargs and hasattr(base_calc, attr):
            d3_kwargs[attr] = getattr(base_calc, attr)

    d3_calc = TorchDFTD3Calculator(
        damping=damping,
        xc=dispersion_xc,
        cutoff=dispersion_cutoff,
        **d3_kwargs,
    )
    return SumCalculator([base_calc, d3_calc])