OpenMM Tutorial
This tutorial demonstrates how to use a TACE model as a calculator within OpenMM. The model is connected to OpenMM through the ASE interface provided by openmm-ml. Example usage is shown below; for additional details, please refer to the OpenMM documentation.
OpenMM-ML GitHub repository: https://github.com/openmm/openmm-ml · OpenMM-ML documentation: docs
'''
Based on https://github.com/openmm/openmm-ml/blob/main/test/TestASEPotential.py
'''
import os
import numpy as np
import openmm as mm
import openmm.app as app
import openmm.unit as unit
from openmmml import MLPotential
from ase.io import read
from tace.foundations import tace_foundations
from tace.interface.ase import TACEAseCalc, add_dispersion
# Pretrained TACE foundation model, wrapped as a single ASE calculator that
# every test below shares (float32 weights, evaluated on the 'cuda' device).
model = tace_foundations["TACE-OAM-L"]
calculator = TACEAseCalc(model=model, device='cuda', dtype='float32')

# Directory containing the input structures (toluene, alanine-dipeptide) read by the tests.
test_data_dir = '../data/'
def kJmol2eV(energy):
    """Convert a molar OpenMM energy Quantity (e.g. kJ/mol) into a number in eV.

    Dividing by Avogadro's constant turns the per-mole energy into a
    per-particle energy, which is then stripped of units in electronvolts.
    """
    per_particle = energy / unit.AVOGADRO_CONSTANT_NA
    return per_particle.value_in_unit(unit.ev)
def testCalculator(platform_int):
    """Check that a system built from a bare ASE calculator reproduces the reference energy.

    Builds an OpenMM system for toluene with openmm-ml's ``'ase'`` potential,
    passing the shared TACE ``calculator`` directly. It then evaluates the
    potential energy on the OpenMM platform at index ``platform_int`` and
    compares the value, converted to eV, against a stored reference.

    Args:
        platform_int: index handed to ``mm.Platform.getPlatform``.

    Raises:
        AssertionError: if the energy differs from the reference beyond ``rtol=1e-6``.
    """
    pdb = app.PDBFile(os.path.join(test_data_dir, "toluene", "toluene.pdb"))
    potential = MLPotential('ase')
    system = potential.createSystem(pdb.topology, calculator=calculator)
    platform = mm.Platform.getPlatform(platform_int)
    context = mm.Context(system, mm.VerletIntegrator(0.001), platform)
    context.setPositions(pdb.getPositions(asNumpy=True))
    # ``getEnergy=True`` is the long-standing keyword (and the one the other
    # tests in this file use); the shorter ``energy=True`` spelling is not
    # recognised by older OpenMM releases.
    energyML = context.getState(getEnergy=True).getPotentialEnergy()  # kJ/mol
    energyRef = -92.44810485839844  # eV, compared against kJmol2eV(energyML)
    assert np.isclose(energyRef, kJmol2eV(energyML), rtol=1e-6)
    print('testCalculator pass')
def testAtoms(platform_int):
    """Check that a system built from an ASE ``Atoms`` object reproduces the reference energy.

    The toluene PDB is read twice: once by OpenMM for the topology and
    positions, and once by ASE. The shared TACE ``calculator`` is attached to
    the ASE atoms, and the atoms are passed to openmm-ml via ``aseAtoms=``
    instead of ``calculator=``.

    Args:
        platform_int: index handed to ``mm.Platform.getPlatform``.

    Raises:
        AssertionError: if the energy differs from the reference beyond ``rtol=1e-6``.
    """
    path = os.path.join(test_data_dir, "toluene", "toluene.pdb")
    pdb = app.PDBFile(path)
    atoms = read(path)
    atoms.calc = calculator
    potential = MLPotential('ase')
    system = potential.createSystem(pdb.topology, aseAtoms=atoms)
    platform = mm.Platform.getPlatform(platform_int)
    context = mm.Context(system, mm.VerletIntegrator(0.001), platform)
    context.setPositions(pdb.getPositions(asNumpy=True))
    # Same keyword as the rest of the file; ``energy=True`` is not accepted by
    # older OpenMM releases.
    energyML = context.getState(getEnergy=True).getPotentialEnergy()  # kJ/mol
    energyRef = -92.44810485839844  # eV; same structure as testCalculator
    assert np.isclose(energyRef, kJmol2eV(energyML), rtol=1e-6)
    print('testAtoms pass')
def testPeriodicSystem(platform_int):
    """Check that the energy of a periodic system is invariant under rigid translation.

    Loads the explicit-solvent alanine-dipeptide PDB and builds an openmm-ml
    ``'ase'`` system with the shared TACE ``calculator``. It evaluates the
    energy after shifting every coordinate by 0, 0.9 and 1.8 nm. Each energy
    must match the reference, which exercises the periodic boundary handling.
    Finally the original coordinates are restored and checked once more.

    Args:
        platform_int: index handed to ``mm.Platform.getPlatform``.

    Raises:
        AssertionError: if any energy differs from the reference beyond ``rtol=1e-5``.
    """
    pdb = app.PDBFile(os.path.join(test_data_dir, "alanine-dipeptide", "alanine-dipeptide-explicit.pdb"))
    potential = MLPotential('ase')
    system = potential.createSystem(pdb.topology, calculator=calculator)
    platform = mm.Platform.getPlatform(platform_int)
    context = mm.Context(system, mm.VerletIntegrator(0.0001), platform)
    positionsOriginal = pdb.getPositions(asNumpy=True)
    energyRef = -10992.9296875  # eV
    for i in range(3):
        # Translate all atoms by the same offset; with PBC the energy must not change.
        positions = positionsOriginal + i * 0.9 * unit.nanometers
        context.setPositions(positions)
        energyML = context.getState(getEnergy=True).getPotentialEnergy()
        assert np.isclose(energyRef, kJmol2eV(energyML), rtol=1e-5)
    # Restore the untranslated coordinates. The energy computed here was
    # previously discarded; check it too so the recomputation is not dead code.
    context.setPositions(positionsOriginal)
    energyML = context.getState(getEnergy=True).getPotentialEnergy()
    assert np.isclose(energyRef, kJmol2eV(energyML), rtol=1e-5)
    print('testPeriodicSystem pass')
def testCreateMixedSystem(platform_int):
    """Compare openmm-ml mixed ML/MM systems against the pure MM system.

    Starting from the Amber ``toluene-explicit`` system (PME electrostatics),
    the first 15 atoms are described by the TACE ``calculator`` (presumably the
    toluene solute; confirm against the prmtop). Two variants are built:

    * a plain mixed system (``interpolate=False``);
    * an interpolating system (``interpolate=True``) whose energy should match
      the mixed system at the default ``lambda_interpolate``, and match the
      pure MM system once ``lambda_interpolate`` is set to 0.

    Args:
        platform_int: index handed to ``mm.Platform.getPlatform``.

    Raises:
        AssertionError: if either comparison fails at ``rtol=1e-5``.
    """
    toluene_dir = os.path.join(test_data_dir, "toluene")
    prmtop = app.AmberPrmtopFile(os.path.join(toluene_dir, "toluene-explicit.prm7"))
    inpcrd = app.AmberInpcrdFile(os.path.join(toluene_dir, "toluene-explicit.rst7"))
    mlAtoms = list(range(15))
    mmSystem = prmtop.createSystem(nonbondedMethod=app.PME)

    potential = MLPotential('ase')
    # Built in order: interpolate=False first, then interpolate=True.
    mixedSystem, interpSystem = (
        potential.createMixedSystem(prmtop.topology, mmSystem, mlAtoms, interpolate=flag, calculator=calculator)
        for flag in (False, True)
    )

    platform = mm.Platform.getPlatform(platform_int)
    mmContext, mixedContext, interpContext = [
        mm.Context(sys_, mm.VerletIntegrator(0.001), platform)
        for sys_ in (mmSystem, mixedSystem, interpSystem)
    ]
    for ctx in (mmContext, mixedContext, interpContext):
        ctx.setPositions(inpcrd.positions)

    def energy_kj(ctx):
        # Potential energy of ``ctx`` as a bare float in kJ/mol.
        return ctx.getState(getEnergy=True).getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)

    mmEnergy = energy_kj(mmContext)
    mixedEnergy = energy_kj(mixedContext)
    interpEnergy1 = energy_kj(interpContext)  # lambda_interpolate left at its default
    interpContext.setParameter('lambda_interpolate', 0)
    interpEnergy2 = energy_kj(interpContext)  # interpolated fully back to MM

    assert np.isclose(mixedEnergy, interpEnergy1, rtol=1e-5)
    assert np.isclose(mmEnergy, interpEnergy2, rtol=1e-5)
    print('testCreateMixedSystem pass')
def _pick_platform_index(preferred=("CUDA", "OpenCL", "CPU", "Reference")):
    """Return the index of the first available OpenMM platform named in ``preferred``.

    OpenMM assigns platform indices according to which plugins loaded, so a
    hard-coded index (e.g. ``2``) does not reliably refer to the same
    platform on every machine. Looking platforms up by name avoids that.

    Raises:
        ValueError: if none of the ``preferred`` platforms is registered.
    """
    names = [mm.Platform.getPlatform(index).getName() for index in range(mm.Platform.getNumPlatforms())]
    for wanted in preferred:
        if wanted in names:
            return names.index(wanted)
    raise ValueError(f"None of the OpenMM platforms {preferred!r} is available; found {names!r}")


print('Available Platform:')
for i in range(mm.Platform.getNumPlatforms()):
    p = mm.Platform.getPlatform(i)
    print(" -", i, p.getName())
print()

# Prefer CUDA (the TACE calculator itself runs on device='cuda'), falling back in order.
platform_index = _pick_platform_index()
print('Using platform:', platform_index, mm.Platform.getPlatform(platform_index).getName())
testCalculator(platform_index)
testAtoms(platform_index)
testPeriodicSystem(platform_index)
testCreateMixedSystem(platform_index)