################################################################################
# Authors: Zemin Xu
# License: MIT, see LICENSE.md
################################################################################
import math
from typing import Dict, List, Optional, Type, Union

import torch
from e3nn import o3

from tace.utils.torch_scatter import scatter_sum
from ..mlp import ACTIVATION, FFN
from ..layout import LayoutTransform
from .base import Interaction
from ..linear import e3nnLinear, e3nnElementLinear
from .fused import O3ScatterTensorProduct, SO2ScatterTensorProduct
from .attn import SO2Attention
from .nonlinear import O3Gate, O3Norm
from .layer_norm import get_normalization_layer
# ("[docs]" link marker left over from the rendered documentation page; not part of the module)
class CgtpInteraction(Interaction):
    """
    An interaction module based on Clebsch-Gordan tensor products (CGTP).

    This module performs edge-level convolution using Clebsch-Gordan tensor
    products. It supports operator fusion via OpenEquivariance or CuEquivariance,
    which can significantly reduce memory consumption and improve efficiency.

    All hyper-parameters (``irreps_in``, ``irreps_out``, ``scatter_norm``,
    ``resnet_type``, ...) are read from attributes of ``self``; presumably the
    ``Interaction`` base class sets them before calling :meth:`_setup`.
    """

    def _setup(self) -> None:
        """Create the sub-modules consumed by :meth:`forward`.

        Keep the construction order stable: module registration order defines
        the ``state_dict`` ordering and, for randomly initialised layers, the
        order in which initial weights are drawn.
        """
        self.linear_up = e3nnLinear(
            self.irreps_in,
            self.irreps_in,
            bias=self.use_bias,
        )
        self.rejector = O3ScatterTensorProduct(
            self.irreps_in,
            self.irreps_sh,
            self.irreps_out,
            l1l2=self.l1l2,
        )

        # Attention needs non-scalar inputs (lmax > 0) and must be requested.
        if self.irreps_in.lmax > 0 and self.use_graph_softmax:
            self._build_attention()

        linear_down_irreps_out = self._build_nonlinearity()
        self.linear_down = e3nnLinear(
            self.rejector.irreps_out.simplify(),
            linear_down_irreps_out,
            bias=self.use_bias,
        )
        self.edge_info = self._build_edge_ffn(
            [self.edge_feats_channel] + self.radial_mlp + [self.rejector.weight_numel]
        )

        if self.scatter_norm in ('density', 'no_cutoff_density'):
            self._build_density_norm()
        self._build_resnets()
        self._build_pre_norms()

    def _build_edge_ffn(self, dims: list):
        """Return a radial network of type ``edge_info_type`` with layer widths ``dims``."""
        return FFN[self.edge_info_type](
            dims,
            bias=self.radial_bias,
            layer_norm=self.radial_layer_norm,
            act=self.radial_act,
        )

    def _build_attention(self) -> None:
        """Create the graph-softmax attention module and its radial network."""
        # With attention active, any scatter normalisation is switched off.
        # For settings other than 'avg_num_neighbors' the neighbour count is
        # additionally reset to 1.0.
        if self.scatter_norm is not None:
            if self.scatter_norm != 'avg_num_neighbors':
                self.avg_num_neighbors = 1.0
            self.scatter_norm = None

        self.attention = SO2Attention(
            mmax=0,
            lmax=self.Lmax,
            num_channel=self.num_channel,
            edge_wise_hidden=self.edge_wise_hidden,
            so2_angular_basis=self.so2_angular_basis,
            reshape_in=LayoutTransform(self.irreps_in),
            num_head=self.num_head,
            weights_shape=[
                ins.path_shape
                for ins in self.rejector.tp.instructions
                if ins.has_weight
            ],
        )
        self.edge_info_attn = self._build_edge_ffn(
            [self.edge_feats_channel] + self.radial_mlp + [self.attention.weight_numel]
        )

    def _build_nonlinearity(self) -> o3.Irreps:
        """Create ``nonlinearity`` / ``linear_nonlinearity``.

        Returns:
            The irreps that ``linear_down`` has to produce so that its output
            can be fed to ``nonlinearity``.
        """
        irreps_hidden = o3.Irreps(
            [(self.node_wise_hidden, ir) for _, ir in self.irreps_out]
        )

        if self.nonlinear_type == 'gate':
            # One block of even scalars (l=0, p=+1) per gated block, used as gates.
            irreps_gates = o3.Irreps([(mul, (0, 1)) for mul, _ in irreps_hidden])
            self.nonlinearity = O3Gate(
                irreps_gates=irreps_gates,
                act_gates=[ACTIVATION[self.nonlinear_act]()] * len(irreps_gates),
                irreps_gated=irreps_hidden,
            )
        elif self.nonlinear_type == 'norm':
            self.nonlinearity = O3Norm(
                irreps=irreps_hidden,
                activation=[ACTIVATION[self.nonlinear_act]()],
            )
        else:
            # No nonlinearity requested: both stages become pass-through.
            self.nonlinearity = torch.nn.Identity()
            self.linear_nonlinearity = torch.nn.Identity()
            return irreps_hidden

        self.linear_nonlinearity = e3nnLinear(
            irreps_hidden,
            self.irreps_out,
            bias=self.use_bias,
        )
        return self.nonlinearity.irreps_in.simplify()

    def _build_density_norm(self) -> None:
        """Create the learnable density normalisation (edge network + affine terms)."""
        self.edge_density = self._build_edge_ffn(
            [self.edge_feats_channel, 64, 1]
        )  # From MACE
        # NOTE(review): alpha/beta are only read in forward() when edge_density
        # exists, so they are created together with it - confirm against the
        # checkpoints in use.
        self.alpha = torch.nn.Parameter(torch.tensor(self.avg_num_neighbors))
        self.beta = torch.nn.Parameter(torch.tensor(0.0))

    def _build_skip_linear(self, irreps_in, irreps_out):
        """Return an element-agnostic or element-specific linear map for a skip path."""
        if self.resnet_linear_type == 'agnostic':
            return e3nnLinear(
                irreps_in=irreps_in,
                irreps_out=irreps_out,
                bias=self.use_bias,
            )
        return e3nnElementLinear(
            irreps_in=irreps_in,
            irreps_out=irreps_out,
            bias=self.use_bias,
            num_elements=self.num_elements,
        )

    def _build_resnets(self) -> None:
        """Create the residual projections selected by ``resnet_type``."""
        if not (self.use_first_resnet or self.layer > 0):
            return
        if self.resnet_type == 'BB':
            self.resnetBB = self._build_skip_linear(self.irreps_in, self.irreps_sc)
        if self.resnet_type == 'BAB':
            self.resnetBA = self._build_skip_linear(self.irreps_in, self.irreps_out)
        if self.resnet_type in ('AB', 'BAB'):
            self.resnetAB = self._build_skip_linear(self.irreps_out, self.irreps_sc)

    def _build_pre_norms(self) -> None:
        """Create the optional normalisation layers applied before / after the block."""
        if self.pre_norm_type is None:
            return
        if not (self.use_first_pre_norm or self.layer > 0):
            return
        if self.resnet_type in ('BB', 'BAB'):
            self.norm1 = get_normalization_layer(
                self.pre_norm_type,
                ls=self.irreps_in.lmax,
                num_channels=self.num_channel,
            )
            self.reshape1 = LayoutTransform(self.irreps_in)
        if self.resnet_type in ('AB', 'BAB'):
            self.norm2 = get_normalization_layer(
                self.pre_norm_type,
                ls=self.irreps_out.lmax,
                num_channels=self.num_channel,
            )
            self.reshape2 = LayoutTransform(self.irreps_out)

    def forward(
        self,
        node_feats: torch.Tensor,
        node_attrs_total: torch.Tensor,
        node_attrs_slice: torch.Tensor,
        edge_feats: Union[torch.Tensor, List[torch.Tensor]],
        edge_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        cutoff: Optional[torch.Tensor],
        graph,
    ):
        """Run one CGTP message-passing step.

        Args:
            node_feats: Node features, laid out according to ``irreps_in``.
            node_attrs_total: Only its first dimension is used, as the number
                of rows of the scattered density.
            node_attrs_slice: Node attributes handed to the element-specific
                residual linears (``resnet_linear_type == 'aware'``).
            edge_feats: Radial edge features; a list of tensors is concatenated
                along the last dimension.
            edge_attrs: Edge attributes passed to the tensor product
                (presumably spherical harmonics matching ``irreps_sh``).
            edge_index: Edge indices; row 1 is the scatter target of the density.
            cutoff: Optional per-edge cutoff envelope multiplied into the
                tensor-product weights (and optionally the density).
            graph: Object exposing ``lmp_data`` and ``lmp_natoms`` (LAMMPS
                bookkeeping; ``lmp_natoms[0]`` is taken as the local atom count).

        Returns:
            A tuple ``(m_i, sc)``: the updated node messages and the skip
            connection (``resnetBB`` or ``resnetAB`` output, else ``None``)
            after ``truncate_ghosts``.
        """
        if isinstance(edge_feats, list):
            edge_feats = torch.cat(edge_feats, dim=-1)

        lmp_data = graph.lmp_data
        lmp_natoms = graph.lmp_natoms
        nlocal = lmp_natoms[0] if lmp_data is not None else None
        element_aware = self.resnet_linear_type == 'aware'

        # Residual branches that start from the *input* features.
        resBB = None
        if hasattr(self, 'resnetBB'):
            if element_aware:
                resBB = self.resnetBB(node_feats, node_attrs_slice)
            else:
                resBB = self.resnetBB(node_feats)
        resBA = None
        if hasattr(self, 'resnetBA'):
            if element_aware:
                resBA = self.resnetBA(node_feats, node_attrs_slice)
            else:
                resBA = self.resnetBA(node_feats)

        if hasattr(self, 'norm1'):
            node_feats = self.reshape1.inverse(self.norm1(self.reshape1(node_feats)))

        node_feats = self.linear_up(node_feats)
        node_feats = self.handle_lammps(node_feats, lmp_data, lmp_natoms, self.layer)

        conv_weights = self.edge_info(edge_feats)
        if self.irreps_in.lmax > 0 and self.use_graph_softmax:
            # The attention module receives the cutoff itself.
            conv_weights = self.attention(
                node_feats,
                self.edge_info_attn(edge_feats),
                edge_index,
                cutoff,
                conv_weights,
            )
        elif cutoff is not None:
            conv_weights = conv_weights * cutoff

        messages = self.rejector(node_feats, edge_attrs, conv_weights, edge_index)
        m_i = self.linear_down(self.truncate_ghosts(messages, nlocal))

        density = None
        if hasattr(self, "edge_density"):
            density = torch.tanh(self.edge_density(edge_feats) ** 2)
            if cutoff is not None and self.apply_density_cutoff:
                density = density * cutoff
            density = scatter_sum(
                density, edge_index[1], dim=0, dim_size=node_attrs_total.size(0)
            )
            density = self.truncate_ghosts(density, nlocal)
            density = density * self.beta + self.alpha
            # Guard the division below against exact zeros.
            density = density.masked_fill(density == 0, 1e-9)

        if self.scatter_norm == 'avg_num_neighbors':
            m_i = m_i / self.avg_num_neighbors
        elif self.scatter_norm is not None:
            m_i = m_i / density

        m_i = self.linear_nonlinearity(self.nonlinearity(m_i))
        if resBA is not None:
            m_i = m_i + resBA

        # Residual branch that starts from the *output* features; it is taken
        # before the optional second normalisation.
        resAB = None
        if hasattr(self, 'resnetAB'):
            if element_aware:
                resAB = self.resnetAB(m_i, node_attrs_slice)
            else:
                resAB = self.resnetAB(m_i)

        if hasattr(self, 'norm2'):
            m_i = self.reshape2.inverse(self.norm2(self.reshape2(m_i)))

        if resBB is not None:
            sc = resBB
        elif resAB is not None:
            sc = resAB
        else:
            sc = None
        return m_i, self.truncate_ghosts(sc, nlocal)
# ("[docs]" link marker left over from the rendered documentation page; not part of the module)
class SO2Interaction(Interaction):
    """
    An interaction module whose edge-level convolution is performed by
    ``SO2ScatterTensorProduct``.

    It requires ``parity`` to be disabled, input features with ``lmax > 0``
    and a configured ``edge_nonlinear`` (see the assertions in :meth:`_setup`).
    All hyper-parameters are read from attributes of ``self``; presumably the
    ``Interaction`` base class sets them before calling :meth:`_setup`.
    """

    def _setup(self) -> None:
        """Validate the configuration and create the sub-modules used by :meth:`forward`.

        Keep the construction order stable: module registration order defines
        the ``state_dict`` ordering and, for randomly initialised layers, the
        order in which initial weights are drawn.

        Raises:
            AssertionError: If ``parity`` is enabled, ``irreps_in.lmax`` is 0
                or ``edge_nonlinear`` is ``None``.
        """
        assert not self.parity, "SO2Interaction not support O(3) group"
        assert self.irreps_in.lmax > 0, (
            "SO2Interaction's irreps_in.lmax must > 0, "
            "use SO2Interaction from the second layer or use other node_embedding with l > 0"
        )
        assert self.edge_nonlinear is not None, "SO2Interaction forces to use edge nonlinear"

        # With graph softmax active, no scatter normalisation is applied.
        if self.use_graph_softmax:
            self.scatter_norm = None

        self.linear_up = e3nnLinear(
            self.irreps_in,
            self.irreps_in,
            bias=self.use_bias,
        )
        self.rejector = SO2ScatterTensorProduct(
            mmax=self.mmax,
            lmax=self.lmax,
            num_channel=self.num_channel,
            edge_wise_hidden=self.edge_wise_hidden,
            num_elements=self.num_elements,
            so2_angular_basis=self.so2_angular_basis,
            reshape_in=LayoutTransform(self.irreps_in),
            reshape_out=LayoutTransform(self.irreps_out),
            num_head=self.num_head,
            use_graph_softmax=self.use_graph_softmax,
            use_so2_edge_ace=self.use_so2_edge_ace,
        )

        linear_down_irreps_out = self._build_nonlinearity()
        self.linear_down = e3nnLinear(
            o3.Irreps([(self.edge_wise_hidden, ir) for _, ir in self.irreps_out]),
            linear_down_irreps_out,
            bias=self.use_bias,
        )
        self.edge_info = self._build_edge_ffn(
            [self.edge_feats_channel] + self.radial_mlp + [self.rejector.weight_numel]
        )

        if self.scatter_norm in ('density', 'no_cutoff_density'):
            self._build_density_norm()
        self._build_resnets()
        self._build_pre_norms()

    def _build_edge_ffn(self, dims: list):
        """Return a radial network of type ``edge_info_type`` with layer widths ``dims``."""
        return FFN[self.edge_info_type](
            dims,
            bias=self.radial_bias,
            layer_norm=self.radial_layer_norm,
            act=self.radial_act,
        )

    def _build_nonlinearity(self) -> o3.Irreps:
        """Create ``nonlinearity`` / ``linear_nonlinearity``.

        Returns:
            The irreps that ``linear_down`` has to produce so that its output
            can be fed to ``nonlinearity``.
        """
        irreps_hidden = o3.Irreps(
            [(self.node_wise_hidden, ir) for _, ir in self.irreps_out]
        )

        if self.nonlinear_type == 'gate':
            # One block of even scalars (l=0, p=+1) per gated block, used as gates.
            irreps_gates = o3.Irreps([(mul, (0, 1)) for mul, _ in irreps_hidden])
            self.nonlinearity = O3Gate(
                irreps_gates=irreps_gates,
                act_gates=[ACTIVATION[self.nonlinear_act]()] * len(irreps_gates),
                irreps_gated=irreps_hidden,
            )
        elif self.nonlinear_type == 'norm':
            self.nonlinearity = O3Norm(
                irreps=irreps_hidden,
                activation=[ACTIVATION[self.nonlinear_act]()],
            )
        else:
            # No nonlinearity requested: both stages become pass-through.
            self.nonlinearity = torch.nn.Identity()
            self.linear_nonlinearity = torch.nn.Identity()
            return irreps_hidden

        self.linear_nonlinearity = e3nnLinear(
            irreps_hidden,
            self.irreps_out,
            bias=self.use_bias,
        )
        return self.nonlinearity.irreps_in.simplify()

    def _build_density_norm(self) -> None:
        """Create the learnable density normalisation (edge network + affine terms)."""
        self.edge_density = self._build_edge_ffn(
            [self.edge_feats_channel, 64, 1]
        )  # From MACE
        # NOTE(review): alpha/beta are only read in forward() when edge_density
        # exists, so they are created together with it - confirm against the
        # checkpoints in use.
        self.alpha = torch.nn.Parameter(torch.tensor(self.avg_num_neighbors))
        self.beta = torch.nn.Parameter(torch.tensor(0.0))

    def _build_skip_linear(self, irreps_in, irreps_out):
        """Return an element-agnostic or element-specific linear map for a skip path."""
        if self.resnet_linear_type == 'agnostic':
            return e3nnLinear(
                irreps_in=irreps_in,
                irreps_out=irreps_out,
                bias=self.use_bias,
            )
        return e3nnElementLinear(
            irreps_in=irreps_in,
            irreps_out=irreps_out,
            bias=self.use_bias,
            num_elements=self.num_elements,
        )

    def _build_resnets(self) -> None:
        """Create the residual projections selected by ``resnet_type``."""
        if not (self.use_first_resnet or self.layer > 0):
            return
        if self.resnet_type == 'BB':
            self.resnetBB = self._build_skip_linear(self.irreps_in, self.irreps_sc)
        if self.resnet_type == 'BAB':
            self.resnetBA = self._build_skip_linear(self.irreps_in, self.irreps_out)
        if self.resnet_type in ('AB', 'BAB'):
            self.resnetAB = self._build_skip_linear(self.irreps_out, self.irreps_sc)

    def _build_pre_norms(self) -> None:
        """Create the optional normalisation layers applied before / after the block."""
        if self.pre_norm_type is None:
            return
        if not (self.use_first_pre_norm or self.layer > 0):
            return
        if self.resnet_type in ('BB', 'BAB'):
            self.norm1 = get_normalization_layer(
                self.pre_norm_type,
                ls=self.irreps_in.lmax,
                num_channels=self.num_channel,
            )
            self.reshape1 = LayoutTransform(self.irreps_in)
        if self.resnet_type in ('AB', 'BAB'):
            self.norm2 = get_normalization_layer(
                self.pre_norm_type,
                ls=self.irreps_out.lmax,
                num_channels=self.num_channel,
            )
            self.reshape2 = LayoutTransform(self.irreps_out)

    def forward(
        self,
        node_feats: torch.Tensor,
        node_attrs_total: torch.Tensor,
        node_attrs_slice: torch.Tensor,
        edge_feats: Union[torch.Tensor, List[torch.Tensor]],
        edge_attrs: torch.Tensor,
        edge_index: torch.Tensor,
        cutoff: Optional[torch.Tensor],
        graph,
    ):
        """Run one SO(2) message-passing step.

        Args:
            node_feats: Node features, laid out according to ``irreps_in``.
            node_attrs_total: Only its first dimension is used, as the number
                of rows of the scattered density.
            node_attrs_slice: Node attributes handed to the tensor product and
                to the element-specific residual linears.
            edge_feats: Radial edge features; a list of tensors is concatenated
                along the last dimension.
            edge_attrs: Unused by this interaction; kept so that all
                interactions share one call signature.
            edge_index: Edge indices; row 1 is the scatter target of the density.
            cutoff: Optional per-edge cutoff envelope. It is multiplied into
                the radial weights and also passed to the tensor product.
            graph: Object exposing ``lmp_data`` and ``lmp_natoms`` (LAMMPS
                bookkeeping; ``lmp_natoms[0]`` is taken as the local atom count).

        Returns:
            A tuple ``(m_i, sc)``: the updated node messages and the skip
            connection (``resnetBB`` or ``resnetAB`` output, else ``None``)
            after ``truncate_ghosts``.
        """
        # Accept list-valued edge features exactly like CgtpInteraction.forward,
        # since both interactions are called through the same interface.
        if isinstance(edge_feats, list):
            edge_feats = torch.cat(edge_feats, dim=-1)

        lmp_data = graph.lmp_data
        lmp_natoms = graph.lmp_natoms
        nlocal = lmp_natoms[0] if lmp_data is not None else None
        element_aware = self.resnet_linear_type == 'aware'

        # Residual branches that start from the *input* features.
        resBB = None
        if hasattr(self, 'resnetBB'):
            if element_aware:
                resBB = self.resnetBB(node_feats, node_attrs_slice)
            else:
                resBB = self.resnetBB(node_feats)
        resBA = None
        if hasattr(self, 'resnetBA'):
            if element_aware:
                resBA = self.resnetBA(node_feats, node_attrs_slice)
            else:
                resBA = self.resnetBA(node_feats)

        if hasattr(self, 'norm1'):
            node_feats = self.reshape1.inverse(self.norm1(self.reshape1(node_feats)))

        node_feats = self.linear_up(node_feats)
        node_feats = self.handle_lammps(node_feats, lmp_data, lmp_natoms, self.layer)

        conv_weights = self.edge_info(edge_feats)
        if cutoff is not None:
            conv_weights = conv_weights * cutoff

        # TODO: check node_attrs_slice
        messages = self.rejector(
            node_feats, node_attrs_slice, conv_weights, edge_index, cutoff
        )
        m_i = self.linear_down(self.truncate_ghosts(messages, nlocal))

        density = None
        if hasattr(self, "edge_density"):
            density = torch.tanh(self.edge_density(edge_feats) ** 2)
            if cutoff is not None and self.apply_density_cutoff:
                density = density * cutoff
            density = scatter_sum(
                density, edge_index[1], dim=0, dim_size=node_attrs_total.size(0)
            )
            density = self.truncate_ghosts(density, nlocal)
            density = density * self.beta + self.alpha
            # Guard the division below against exact zeros.
            density = density.masked_fill(density == 0, 1e-9)

        if self.scatter_norm == 'avg_num_neighbors':
            m_i = m_i / self.avg_num_neighbors
        elif self.scatter_norm is not None:
            m_i = m_i / density

        m_i = self.linear_nonlinearity(self.nonlinearity(m_i))
        if resBA is not None:
            m_i = m_i + resBA

        # Residual branch that starts from the *output* features; it is taken
        # before the optional second normalisation.
        resAB = None
        if hasattr(self, 'resnetAB'):
            if element_aware:
                resAB = self.resnetAB(m_i, node_attrs_slice)
            else:
                resAB = self.resnetAB(m_i)

        if hasattr(self, 'norm2'):
            m_i = self.reshape2.inverse(self.norm2(self.reshape2(m_i)))

        if resBB is not None:
            sc = resBB
        elif resAB is not None:
            sc = resAB
        else:
            sc = None
        return m_i, self.truncate_ghosts(sc, nlocal)
# Registry mapping an interaction name to the class implementing it (the values
# are classes, not instances). "normal", "spectral" and "cgtp" all resolve to
# the Clebsch-Gordan tensor-product interaction.
INTERACTION: Dict[str, Type[Interaction]] = {
    "normal": CgtpInteraction,
    "spectral": CgtpInteraction,
    "cgtp": CgtpInteraction,
    "so2": SO2Interaction,
}