Source code for tace.models._e3nn.node

################################################################################
# Authors: Zemin Xu
# License: MIT, see LICENSE.md
################################################################################


import torch
from e3nn.nn import Activation
from e3nn import o3


from tace.utils.torch_scatter import scatter_sum
from ..layout import LayoutTransform
from ..mlp import MLP
from .base import NodeEmbedding
from ..linear import e3nnLinear
from .fused import O3ScatterTensorProduct


class LinearNodeEmbedding(NodeEmbedding):
    """Embed discrete node attributes with a single linear layer.

    The ``num_elements`` scalar (``0e``) input channels, e.g. an element-type
    encoding, are mapped onto ``num_channel`` scalar output channels.
    No activation is applied and no structural (edge) information is used.
    """

    def _setup(self) -> None:
        in_irreps = f"{self.num_elements}x0e"
        out_irreps = f"{self.num_channel}x0e"
        self.irreps_out = o3.Irreps(out_irreps)
        self.elem_emb1 = e3nnLinear(in_irreps, out_irreps, bias=self.bias)

    def forward(
        self,
        node_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attrs: torch.Tensor,
        cutoff: torch.Tensor,
    ) -> torch.Tensor:
        """Return the linear projection of ``node_attrs``.

        The edge-related arguments are accepted but unused, keeping the call
        signature identical to the other node embeddings in this module.
        """
        return self.elem_emb1(node_attrs)
class NonLinearNodeEmbedding(NodeEmbedding):
    """Embed discrete node attributes with a linear layer followed by SiLU.

    The projection is the same as in the linear embedding
    (``num_elements x 0e`` -> ``num_channel x 0e``). Its output is then passed
    through an e3nn ``Activation`` that applies ``torch.nn.SiLU`` to the
    scalar channels.
    """

    def _setup(self) -> None:
        in_irreps = f"{self.num_elements}x0e"
        out_irreps = f"{self.num_channel}x0e"
        self.irreps_out = o3.Irreps(out_irreps)
        self.elem_emb1 = e3nnLinear(in_irreps, out_irreps, bias=self.bias)
        self.act1 = Activation(self.elem_emb1.irreps_out, [torch.nn.SiLU()])

    def forward(
        self,
        node_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attrs: torch.Tensor,
        cutoff: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``SiLU(linear(node_attrs))``; edge arguments are unused."""
        projected = self.elem_emb1(node_attrs)
        return self.act1(projected)
# class GroupNodeEmbedding(NodeEmbedding): # """ # A group-based node embedding module. # This class augments basic element embeddings with a learned grouping # mechanism, where nodes are softly assigned to 32 latent groups and # group-level representations are combined with element-wise features # to enhance expressiveness. # """ # def _setup(self) -> None: # self.irreps_out = o3.Irreps(f"{self.num_channel}x0e") # self.num_groups = 32 # self.elem_emb1 = e3nnLinear( # f"{self.num_elements}x0e", # f"{self.num_channel}x0e", # bias=self.bias # ) # self.elem_emb2 = e3nnLinear( # f"{self.num_elements}x0e", # f"{self.num_groups}x0e", # bias=self.bias # ) # self.group_emb1 = e3nnLinear( # f"{self.num_groups}x0e", # f"{self.num_channel}x0e", # bias=self.bias # ) # self.act1 = Activation(self.elem_emb1.irreps_out, [torch.nn.SiLU()]) # self.act2 = torch.nn.Softmax(dim=-1) # def forward( # self, # node_attrs: torch.Tensor, # edge_feats: torch.Tensor, # edge_index: torch.Tensor, # edge_attrs: torch.Tensor, # cutoff: torch.Tensor # ) -> torch.Tensor: # elem_emb = self.elem_emb1(node_attrs) # elem_emb = self.act1(elem_emb) # logits = self.elem_emb2(node_attrs) # scores = self.act2(logits) # group_emb = self.group_emb1(scores) # return elem_emb + group_emb
class TensorNodeEmbedding(NodeEmbedding):
    """Element embedding augmented with an equivariant tensor-product message.

    Each node receives a scalar element embedding (``num_channel x 0e``).
    In addition, ``O3ScatterTensorProduct`` combines:

    * a constant all-ones scalar input, with
    * the edge attributes (irreps ``l = 0..lmax``, parity ``(-1)**l``),

    using per-edge weights. An MLP predicts these weights from the radial edge
    features and from source/target element embeddings. The result is divided
    by ``avg_num_neighbors``, and the element embedding is added to the first
    ``num_channel`` output columns. The output irreps are those reported by the
    tensor product (requested up to degree ``Lmax``).
    """

    def _setup(self) -> None:
        self.node_embedding = e3nnLinear(
            f"{self.num_elements}x0e", f"{self.num_channel}x0e", bias=self.bias
        )
        self.source_embedding = e3nnLinear(
            f"{self.num_elements}x0e", f"{self.num_channel}x0e", bias=self.bias
        )
        self.target_embedding = e3nnLinear(
            f"{self.num_elements}x0e", f"{self.num_channel}x0e", bias=self.bias
        )
        # Start the per-edge element embeddings near zero.
        torch.nn.init.uniform_(self.source_embedding.weight, a=-0.001, b=0.001)
        torch.nn.init.uniform_(self.target_embedding.weight, a=-0.001, b=0.001)
        self.rejector = O3ScatterTensorProduct(
            [(self.num_channel, (0, 1))],
            [(1, (l, (-1) ** l)) for l in range(self.lmax + 1)],
            [(1, (l, (-1) ** l)) for l in range(self.Lmax + 1)],
        )
        self.irreps_out = self.rejector.irreps_out
        # Edge MLP input: radial features plus the source and target embeddings.
        self.edge_info = MLP(
            channels=[
                self.num_radial_basis + self.num_channel * 2,
                self.num_channel,
                self.num_channel,
                self.rejector.weight_numel,
            ],
            bias=True,
            layer_norm=True,
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attrs: torch.Tensor,
        cutoff: torch.Tensor,
    ) -> torch.Tensor:
        """Compute node features of irreps ``self.irreps_out``.

        Args:
            node_attrs: per-node attributes with ``num_elements`` columns.
            edge_feats: per-edge radial features (``num_radial_basis`` columns).
            edge_index: ``edge_index[0]`` indexes the source embedding and
                ``edge_index[1]`` indexes the target embedding for each edge.
            edge_attrs: per-edge irreps ``l = 0..lmax`` fed to the tensor product.
            cutoff: optional per-edge factor multiplied into the weights
                (skipped when ``None``).
        """
        src, dst = edge_index[0], edge_index[1]
        base_node_feats = self.node_embedding(node_attrs)
        # The embeddings are linear layers applied row by row. Embedding the
        # N nodes once and then gathering equals embedding the E gathered rows,
        # but costs less whenever there are more edges than nodes.
        # (Assumes e3nnLinear acts independently on each row.)
        source_feats = self.source_embedding(node_attrs)[src]
        target_feats = self.target_embedding(node_attrs)[dst]
        conv_weights = self.edge_info(
            torch.cat([edge_feats, source_feats, target_feats], dim=-1)
        )
        if cutoff is not None:
            conv_weights = conv_weights * cutoff
        node_feats = self.rejector(
            torch.ones_like(base_node_feats),
            edge_attrs,
            conv_weights,
            edge_index,
        ) / self.avg_num_neighbors
        # NOTE(review): assumes the first num_channel output columns form the
        # 0e block of rejector.irreps_out -- confirm in O3ScatterTensorProduct.
        node_feats[:, : self.num_channel] = (
            node_feats.narrow(1, 0, self.num_channel) + base_node_feats
        )
        return node_feats
class SO2TensorNodeEmbedding(NodeEmbedding):
    """Element embedding augmented with a message built from SO(2) coefficients.

    For each edge, an MLP maps the radial edge features plus the source/target
    element embeddings to ``(Lmax + 1) * num_channel`` coefficients. Reshaped
    to ``(Lmax + 1, num_channel)``, these are presumably one coefficient per
    degree ``l`` in the edge frame. They are mapped to per-edge SO(3) features
    with ``so2_angular_basis.wigner_inv``. The results are summed onto
    ``edge_index[1]``, divided by ``avg_num_neighbors``, and the scalar element
    embedding is added to row 0. Output irreps are
    ``num_channel x l`` with parity ``(-1)**l`` for ``l = 0..Lmax``.

    ``self.so2_angular_basis`` is not created in ``_setup``; it must be
    provided elsewhere before ``forward`` is called.
    """

    def _setup(self) -> None:
        self.node_embedding = e3nnLinear(
            f"{self.num_elements}x0e", f"{self.num_channel}x0e", bias=self.bias
        )
        self.source_embedding = e3nnLinear(
            f"{self.num_elements}x0e", f"{self.num_channel}x0e", bias=self.bias
        )
        self.target_embedding = e3nnLinear(
            f"{self.num_elements}x0e", f"{self.num_channel}x0e", bias=self.bias
        )
        # Start the per-edge element embeddings near zero.
        torch.nn.init.uniform_(self.source_embedding.weight, a=-0.001, b=0.001)
        torch.nn.init.uniform_(self.target_embedding.weight, a=-0.001, b=0.001)
        self.edge_info = MLP(
            channels=[
                self.num_radial_basis + self.num_channel * 2,
                self.num_channel,
                self.num_channel,
                (self.Lmax + 1) * self.num_channel,
            ],
            bias=True,
            layer_norm=True,
        )
        self.irreps_out = o3.Irreps(
            [(self.num_channel, (l, (-1) ** l)) for l in range(self.Lmax + 1)]
        )
        self.reshape = LayoutTransform(self.irreps_out)

    def forward(
        self,
        node_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attrs: torch.Tensor,
        cutoff: torch.Tensor,
    ) -> torch.Tensor:
        """Compute node features of irreps ``self.irreps_out``.

        Args:
            node_attrs: per-node attributes with ``num_elements`` columns.
            edge_feats: per-edge radial features (``num_radial_basis`` columns).
            edge_index: ``edge_index[0]`` indexes the source embedding.
                ``edge_index[1]`` indexes the target embedding and is the
                node that each edge message is summed onto.
            edge_attrs: unused by this embedding.
            cutoff: optional per-edge factor multiplied into the MLP output
                (skipped when ``None``).
        """
        # Both the coefficient view and the wigner narrowing below must use
        # this count; they are the inner dimension of the bmm.
        num_l = self.Lmax + 1
        src, dst = edge_index[0], edge_index[1]

        base_node_feats = self.node_embedding(node_attrs)
        # Linear layers act row by row, so embed the N nodes once and then
        # gather, rather than embedding E gathered rows.
        # (Assumes e3nnLinear acts independently on each row.)
        source_feats = self.source_embedding(node_attrs)[src]
        target_feats = self.target_embedding(node_attrs)[dst]

        edge_coeffs = self.edge_info(
            torch.cat([edge_feats, source_feats, target_feats], dim=-1)
        )
        if cutoff is not None:
            edge_coeffs = edge_coeffs * cutoff
        edge_coeffs = edge_coeffs.view(edge_coeffs.size(0), num_l, self.num_channel)

        # Use Lmax + 1 here (not lmax + 1) so the column count matches the
        # coefficient count above and the Lmax-based irreps_out. Otherwise the
        # bmm fails whenever lmax != Lmax.
        # NOTE(review): assumes the first Lmax + 1 columns of wigner_inv are the
        # m = 0 components of l = 0..Lmax -- confirm against so2_angular_basis.
        edge_msgs = torch.bmm(
            self.so2_angular_basis.wigner_inv.narrow(2, 0, num_l), edge_coeffs
        )  # (edge, so3_m, C)

        node_feats = scatter_sum(
            edge_msgs, dst, dim=0, dim_size=base_node_feats.size(0)
        ) / self.avg_num_neighbors
        # Add the scalar element embedding to the first (l = 0) row.
        node_feats[:, 0:1, :] = node_feats.narrow(1, 0, 1) + base_node_feats.unsqueeze(1)
        return self.reshape.inverse(node_feats)
# Lookup table from string key to node-embedding class.
NODE_EMBEDDING = dict(
    linear=LinearNodeEmbedding,
    nonlinear=NonLinearNodeEmbedding,
    # group=GroupNodeEmbedding,  # disabled: GroupNodeEmbedding is commented out
    tensor=TensorNodeEmbedding,
    so2_tensor=SO2TensorNodeEmbedding,
)