# Source code for tace.interface.ase.calculator

################################################################################
# Authors: Zemin Xu
# License: MIT, see LICENSE.md
################################################################################

import warnings
from typing import Union


import torch
from ase import units
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.mixing import SumCalculator
from torch_geometric.loader import DataLoader


from tace.models.adapter import TensorModel
from tace.lightning import load_tace
from tace.dataset.quantity import PROPERTY
from tace.dataset.graph import from_atoms
from tace.dataset.quantity import (
    PROPERTY,
    KEYS,
    KeySpecification,
    update_keyspec_from_kwargs,
)
from tace.utils._global import DTYPE, DEVICE


class TACEAseCalc(Calculator):
    """ASE calculator backed by a trained TACE model.

    We support the most fundamental potential energy surface properties as
    well as multi-fidelity, multi-head, etc. For some advanced features you
    need to store the required attributes in ``atoms.info`` or
    ``atoms.arrays``, or add a function yourself.

    If you only need to predict, you can directly use the ``tace-eval``
    command. It outputs the predicted files, and if you add the ``--test``
    option, it also outputs the errors.

    Parameters
    ----------
    model : str
        Path to the trained model; the file ends with ``.pt``, ``.pth`` or
        ``.ckpt``.
    dtype : str, optional
        Dtype for computations, e.g. ``float32`` or ``float64``. If None,
        the model's own dtype is used.
    device : str, optional
        Device to run computations on, e.g. ``cpu`` or ``cuda``. If None,
        ``cuda`` is used when available, otherwise ``cpu``.
    fidelity_idx : int, optional
        Which fidelity index to use. If None, the model's stored fidelity
        index is kept.
    target_property : list of str, optional
        Extra properties to calculate (hessian, atomic_virials, conservative
        polarizability, etc.). If you use this parameter, you must provide
        all the required physical quantities.
    neighborlist_backend : str
        Neighbor-list backend, one of ``ase``, ``matscipy``, ``vesin``
        (``matscipy`` is recommended).
    **kwargs
        Additional keyword arguments passed to the ASE ``Calculator`` base
        class.
    """

    def __init__(
        self,
        model: str,
        *,
        dtype: Union[str, None] = None,
        device: Union[str, None] = None,
        fidelity_idx: Union[int, None] = None,
        target_property: Union[list[str], None] = None,
        neighborlist_backend: str = "matscipy",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # === load model for inference ===
        model: TensorModel = load_tace(
            model, device, strict=True, use_ema=True, target_property=target_property
        )
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        # === dtype / device ===
        model_dtype = model.get_model_dtype()
        self.dtype = DTYPE[dtype or model_dtype]
        self.device = DEVICE[
            device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        ]
        # NOTE: process-wide side effect on torch's default dtype.
        torch.set_default_dtype(self.dtype)
        # Compare the *resolved* dtype. The raw ``dtype`` argument defaults to
        # None, and ``DTYPE[None]`` is not guaranteed to be a valid lookup.
        if self.dtype != DTYPE[model_dtype]:
            warnings.warn(
                f"Model dtype {model_dtype} does not match args.dtype {dtype}. "
                f"Forcing dtype to {self.dtype}",
                stacklevel=2,
            )
            model = model.to(dtype=self.dtype)

        # === model metadata used to build input graphs ===
        self.target_property = model.get_target_property()
        self.embedding_property = model.get_embedding_property()
        self.max_neighbors = model.get_max_neighbors()
        self.cutoff = model.get_cutoff()
        self.element = model.get_torch_element()
        self.neighborlist_backend = neighborlist_backend

        # Result keys this calculator can produce. "energy" is also exposed
        # as "free_energy"; calculate() fills both with the same value.
        self.implemented_properties = []
        for p in self.target_property:
            save_name = self._ase_name(p)
            if save_name == "energy":
                self.implemented_properties.extend(["energy", "free_energy"])
            else:
                self.implemented_properties.append(save_name)

        # === fidelity ===
        if fidelity_idx is not None:
            self.fidelity_idx = fidelity_idx
            model.reset_fidelity_idx(fidelity_idx)
        else:
            self.fidelity_idx = model.get_fidelity_idx()

        self.keySpecification = KeySpecification()
        update_keyspec_from_kwargs(self.keySpecification, KEYS)

        self.model = model.to(self.device)

    @staticmethod
    def _ase_name(p: str) -> str:
        """Return the result key for TACE property ``p``.

        Uses ``PROPERTY[p]['ase_name']`` when it is set (truthy), else ``p``.
        Shared by ``__init__`` and ``calculate`` so that
        ``implemented_properties`` and ``results`` always agree.
        """
        return PROPERTY[p]["ase_name"] or p

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Evaluate the model and store every target property in ``self.results``.

        Parameters
        ----------
        atoms : ase.Atoms, optional
            Structure to evaluate. If None, the structure already held by the
            calculator (``self.atoms``) is used.
        properties : list of str, optional
            Requested properties. Not used for selection: all target
            properties of the model are computed on every call.
        system_changes : list of str
            Forwarded to the ASE base class.
        """
        # The ASE base implementation stores a copy of ``atoms`` in self.atoms.
        Calculator.calculate(self, atoms, properties, system_changes)
        # Work on the calculator's own copy. This supports atoms=None and
        # avoids mutating the caller's Atoms object.
        atoms = self.atoms
        # NOTE(review): presumably read by from_atoms to select the fidelity.
        atoms.info["fidelity_idx"] = self.fidelity_idx

        # === build a single-structure batch ===
        data = [
            from_atoms(
                self.element,
                atoms,
                self.cutoff,
                max_neighbors=self.max_neighbors,
                target_property=self.target_property,
                embedding_property=self.embedding_property,
                keyspec=self.keySpecification,
                training=False,
                neighborlist_backend=self.neighborlist_backend,
            )
        ]
        dataloader = DataLoader(
            dataset=data,
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(dataloader)).to(self.device)

        # === forward ===
        # No torch.no_grad(): presumably the model needs autograd for
        # derivative properties (forces, stress, ...). Confirm before adding it.
        outs = self.model(batch)

        # === convert outputs to ASE results ===
        # Reset first so a failure part-way never leaves stale values behind.
        self.results = {}
        for p in self.target_property:
            spec = PROPERTY[p]
            save_name = self._ase_name(p)
            value = outs[p].detach().cpu()
            if spec["scope"] == "per-system":
                if spec["rank"] == 0:
                    # Scalar per structure -> plain Python float.
                    self.results[save_name] = value.item()
                else:
                    # Presumably the leading axis is the batch axis (batch_size=1).
                    self.results[save_name] = value.numpy().squeeze(0)
            else:
                # per-atom, per-edge and any other scope: keep the array as-is.
                self.results[save_name] = value.numpy()
            if save_name == "energy":
                # Mirror energy as free_energy, matching implemented_properties.
                self.results["free_energy"] = self.results["energy"]
def add_dispersion(
    base_calc: Calculator,
    damping: str = "bj",  # choices: ["zero", "bj", "zerom", "bjm"]
    dispersion_xc: str = "pbe",
    dispersion_cutoff: float = 40.0 * units.Bohr,
    **kwargs,
) -> SumCalculator:
    """Combine ``base_calc`` with a DFT-D3 dispersion correction.

    The D3 term is a ``TorchDFTD3Calculator`` from ``torch_dftd``. It is
    created with the same ``dtype`` and ``device`` attributes as
    ``base_calc``, so ``base_calc`` must expose both (``TACEAseCalc`` does).

    Parameters
    ----------
    base_calc : Calculator
        Calculator to augment; must have ``dtype`` and ``device`` attributes.
    damping : str
        D3 damping scheme, one of ``zero``, ``bj``, ``zerom``, ``bjm``.
    dispersion_xc : str
        Exchange-correlation functional name, passed to torch-dftd as ``xc``.
    dispersion_cutoff : float
        Dispersion cutoff passed as ``cutoff`` (default: 40 Bohr expressed in
        ASE length units).
    **kwargs
        Extra keyword arguments forwarded to ``TorchDFTD3Calculator``.

    Returns
    -------
    SumCalculator
        A calculator combining ``base_calc`` and the D3 correction.

    Raises
    ------
    RuntimeError
        If ``torch_dftd`` cannot be imported.
    """
    try:
        from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
    except ImportError as exc:
        install_hint = (
            "Please install torch-dftd to use dispersion corrections "
            "(see https://github.com/pfnet-research/torch-dftd)"
        )
        raise RuntimeError(install_hint) from exc

    # Keep the D3 term in the same precision and on the same device as the
    # wrapped calculator.
    dispersion_calc = TorchDFTD3Calculator(
        dtype=base_calc.dtype,
        device=base_calc.device,
        damping=damping,
        xc=dispersion_xc,
        cutoff=dispersion_cutoff,
        **kwargs,
    )
    calculators = [base_calc, dispersion_calc]
    return SumCalculator(calculators)