Source code for enerzyme.data.supplier

from rdkit import Chem
from typing import Iterator, Dict, Any, List
from abc import ABC, abstractmethod
import numpy as np
import pickle


[docs] class Supplier(ABC): name: str
[docs] def __init__(self, supplying_fields: List[str]=["atom_type", "Ra", "Q", "mol"]): self.supplying_fields = set(supplying_fields) self.field_to_value = { "atom_type": self._get_atom_type, "Ra": self._get_coord, "Q": self._get_charge, "mol": self._get_mol, "Za": self._get_nuclear_charge, "N": self._get_num_atoms, }
[docs] @abstractmethod def suppl(self) -> Iterator[Dict[str, Any]]: ...
def _get_atom_type(self, mol: Chem.Mol) -> np.ndarray: return np.array([atom.GetSymbol() for atom in mol.GetAtoms()]) def _get_coord(self, mol: Chem.Mol) -> np.ndarray: return np.array(mol.GetConformer().GetPositions()) def _get_charge(self, mol: Chem.Mol) -> int: return Chem.GetFormalCharge(mol) def _get_num_atoms(self, mol: Chem.Mol) -> int: return mol.GetNumAtoms() def _get_nuclear_charge(self, mol: Chem.Mol) -> np.ndarray: return np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()]) def _get_mol(self, mol: Chem.Mol) -> Chem.Mol: return mol
[docs] def get_package(self, mol: Chem.Mol) -> Dict[str, Any]: package = {} for field in self.supplying_fields: package[field] = self.field_to_value[field](mol) return package
[docs] def raw_data(self) -> Dict[str, Any]: raw_data = {field: [] for field in self.supplying_fields} for package in self.suppl(): for field in self.supplying_fields: raw_data[field].append(package[field]) return raw_data
[docs] class SDFSupplier(Supplier):
[docs] def __init__(self, sdf_file: str, start: int = 0, end: int = -1, supplying_fields: List[str]=["atom_type", "Ra", "Q", "mol"], **kwargs): super().__init__(supplying_fields) self.supplier = list(Chem.SDMolSupplier(sdf_file, removeHs=False)) self.start = start to_end = True if end >= 0: self.end = end to_end = False else: self.end = len(self.supplier) self.name = sdf_file.split("/")[-1].split(".")[0] + ( f"_{start}_{end}" if (start != 0 or not to_end) else "" )
[docs] def suppl(self): i = self.start for mol in self.supplier[self.start:self.end]: package = self.get_package(mol) package.update({"index": i}) yield package i += 1
[docs] class PickleSupplier:
[docs] def __init__(self, pkl_file: str, start: int = 0, end: int = -1, **kwargs): with open(pkl_file, "rb") as f: self.supplier = pickle.load(f) self.start = start to_end = True if end >= 0: self.end = end to_end = False else: self.end = len(self.supplier) self.name = pkl_file.split("/")[-1].split(".")[0] + ( f"_{start}_{end}" if (start != 0 or not to_end) else "" )
[docs] def suppl(self): i = self.start for item in self.supplier[self.start:self.end]: yield { "atom_type": item["Za"], "Ra": item["Ra"], "Q": item["Q"], "index": i, } i += 1
[docs] def get_supplier(path: str, start: int = 0, end: int = -1, **kwargs) -> Supplier: if path.endswith(".sdf"): return SDFSupplier(path, start, end, **kwargs) elif path.endswith(".pkl"): return PickleSupplier(path, start, end, **kwargs) else: raise ValueError(f"File type of {path} not supported")