# Source code for enerzyme.data.datahub

import pickle, os
from hashlib import md5
from typing import Union, List, Dict, Optional, Iterable, Literal
import h5py
import numpy as np
from addict import Dict
from tqdm import tqdm
from torch.utils.data import Dataset
from .datatype import is_atomic, is_rounded, is_int, register_data_type
from .transform import parse_Za, Transform
from ..utils import YamlHandler, logger


def load_from_pickle(data_path: str):
    """Load a pickled dataset and return it as a column-oriented dict.

    The pickle may contain either:

    * a dict, which is returned unchanged, or
    * a non-empty list of dicts (one per datapoint), which is converted into
      ``{key: [value for each datapoint]}``. Keys missing from a datapoint are
      filled with ``None``.

    Args:
        data_path: Path to the pickle file.

    Returns:
        A dict mapping field names to their values.

    Raises:
        TypeError: If the pickle holds anything else (including an empty list).
    """
    with open(data_path, "rb") as f:
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files from trusted sources.
        data = pickle.load(f)
    if isinstance(data, dict):
        return data
    if isinstance(data, list) and data and isinstance(data[0], dict):
        # dict.fromkeys keeps first-seen key order, so the column order is
        # deterministic (a set's iteration order depends on string hashing).
        keys = dict.fromkeys(key for datapoint in data for key in datapoint)
        logger.info(f"Collected keys from data: {set(keys)}")
        return {key: [datapoint.get(key, None) for datapoint in data] for key in keys}
    raise TypeError(f"Unknown data type in {data_path}!")
def _collect_types(types: Optional[Union[List, Dict]]) -> Dict: if types is None: return dict() elif isinstance(types, list): return {single_type: single_type for single_type in types} else: return {k: v if v is not None else k for k, v in types.items()}
def array_padding(data, max_N, pad_value=0):
    """Pad every entry of *data* along its first axis to *max_N* and stack them.

    Args:
        data: Sequence of array-likes. Entry ``i`` has shape ``(n_i, *rest)``,
            and ``rest`` must match across entries for the result to stack.
        max_N: Target length of the first axis.
        pad_value: Constant written into the padded tail.

    Returns:
        ``np.ndarray`` of shape ``(len(data), max_N, *rest)``. The input
        sequence is left unmodified.

    Raises:
        ValueError: If an entry is longer than *max_N* along its first axis.
    """
    padded = []
    for item in data:
        arr = np.asarray(item)
        n_missing = max_N - len(arr)
        if n_missing < 0:
            raise ValueError(f"Entry of length {len(arr)} exceeds max_N={max_N}")
        # Pad only the end of the first axis; trailing axes are untouched.
        pad_width = [(0, n_missing)] + [(0, 0)] * (arr.ndim - 1)
        padded.append(np.pad(arr, pad_width, constant_values=pad_value))
    return np.array(padded)
class FieldDataset(Dataset):
    """Column-oriented view over a ``{field name: sequence}`` mapping.

    A field whose stored sequence has length 1 is treated as *compressed*: its
    single entry is shared by every datapoint (see :meth:`loc`).
    """

    def __init__(self, data: Dict[str, Iterable]) -> None:
        self.data = data
        # Fields stored once and broadcast to every datapoint.
        self.compressed_keys = {k for k, v in self.data.items() if len(v) == 1}

    def __getitem__(self, k) -> Iterable:
        """Return the whole stored sequence of field *k*."""
        return self.data[k]

    def __setitem__(self, k, v) -> None:
        """Store field *k* and keep the compressed-key bookkeeping in sync."""
        self.data[k] = v
        if len(v) == 1:
            self.compressed_keys.add(k)
        else:
            # Overwriting a formerly compressed field with full-length data
            # must stop ``loc`` from always returning its first entry.
            self.compressed_keys.discard(k)

    def __contains__(self, k) -> bool:
        return k in self.data

    def __len__(self) -> int:
        """Return the number of datapoints.

        This is the length of the first field that is not compressed, or 1 if
        every field is compressed (or there are no fields).
        """
        return next((len(v) for v in self.data.values() if len(v) != 1), 1)

    def items(self):
        return self.data.items()

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def loc(self, idx) -> Dict[str, Iterable]:
        """Return ``{field: value}`` for datapoint *idx*.

        Compressed fields always contribute their single entry.
        """
        return {k: v[0 if k in self.compressed_keys else idx] for k, v in self.data.items()}

    def load_subset(self, indices: Iterable[int]) -> "FieldDataset":
        """Copy the selected datapoints into a new in-memory ``FieldDataset``.

        Compressed fields are copied whole. The other fields keep only the
        entries at *indices*.
        """
        # Materialise once: a generator would otherwise be exhausted by the
        # first non-compressed field, leaving later fields empty.
        indices = list(indices)
        subset = {}
        for k, v in self.data.items():
            if k in self.compressed_keys:
                subset[k] = np.array(v)
            else:
                subset[k] = np.array([v[idx] for idx in indices])
        return FieldDataset(subset)
class SingleDataHub:
    """Load one raw dataset into an HDF5 cache and expose its features/targets.

    Processed data lives in ``<dump_dir>/processed_dataset_<hash>/``. That
    directory holds ``pre_transformed.hdf5`` and a ``datahub.yaml`` config.

    If ``preload`` is true and the cached config covers every requested
    field, the HDF5 file is opened read-only. Otherwise the cache is rebuilt
    from ``data_path``:

    1. the raw fields are loaded;
    2. the optional neighbor list is computed;
    3. the preprocessing and global transforms are applied;
    4. the config is written and the file is reopened read-only.
    """

    def __init__(
        self,
        dump_dir=".",
        data_format: Optional[str] = None,
        data_path: str = "",
        preload: bool = True,
        features: Optional[Dict[str, str]] = None,
        targets: Optional[Dict[str, str]] = None,
        preprocessings: Optional[Dict[str, Union[str, bool]]] = None,
        global_transforms: Optional[Dict[str, Union[str, bool]]] = None,
        neighbor_list: Optional[str] = None,
        hash_length: int = 16,
        compressed: bool = True,
        max_memory: int = 10,
        **params,
    ):
        """Build or reuse the processed dataset.

        Args:
            dump_dir: Directory in which the ``processed_dataset_<hash>`` folder
                is placed.
            data_format: ``"hdf5"``, ``"pickle"``, ``"npz"`` or ``"sdf"``. When
                omitted, it is inferred from the file suffix.
            data_path: Path to the raw data file.
            preload: Reuse an existing processed cache when it is compatible.
            features / targets: Field specs; see ``_collect_types``.
            preprocessings / global_transforms: Transform configs passed to
                ``Transform``. They are also part of the cache hash.
            neighbor_list: ``"full"`` builds a full neighbor list. ``None``
                builds none.
            hash_length: Number of md5 hex digits used in the cache folder name.
            compressed: Store integer-typed fields that are identical for every
                datapoint (per ``is_int``) only once.
            max_memory: HDF5 chunk-cache size in GiB.
            **params: Accepted and ignored (extra config keys).
        """
        self.data_path = os.path.abspath(data_path)
        self.data_format = data_format
        self.preload = preload
        self.feature_types = _collect_types(features)
        self.target_types = _collect_types(targets)
        self.data_types = self.feature_types | self.target_types
        self.neighbor_list_type = neighbor_list
        self.compressed = compressed
        self.max_memory = max_memory
        # The cache key covers the raw path, the neighbor-list type and the
        # transform configs. Requested fields are checked in preload_data.
        # A None neighbor_list contributes "" (str + None would raise).
        datahub_str = (
            data_path
            + (neighbor_list if neighbor_list is not None else "")
            + str(sorted(preprocessings.items()) if preprocessings is not None else "")
            + str(sorted(global_transforms.items()) if global_transforms is not None else "")
        )
        self.hash = md5(datahub_str.encode("utf-8")).hexdigest()[:hash_length]
        self.preload_path = os.path.join(dump_dir, f"processed_dataset_{self.hash}")
        logger.info(f"Preload path {self.preload_path} is created")
        self.preprocessing = Transform(preprocessings, self.preload_path)
        self.global_transform = Transform(global_transforms, self.preload_path)
        self.preprocessings = preprocessings
        self.global_transforms = global_transforms
        if not self.preload or not self.preload_data():
            self.get_handle("w")
            self._init_data()
            self._init_neighbor_list()
            self.preprocessing.transform(self.data)
            self.global_transform.transform(self.data)
            self._save_config()
            self.reset_handle()

    def _preload_data(self, hdf5_path):
        """Copy requested fields from another processed HDF5 file into ``self.data``.

        NOTE(review): not called anywhere in this module; kept for external
        callers, if any.
        """
        with h5py.File(hdf5_path, mode="r") as loaded_file:
            loaded_data = loaded_file["data"]
            self.data["N"] = loaded_data["N"]
            for k in self.data_types:
                if k == "N":
                    continue
                elif is_atomic(k):
                    self._load_atomic_data(k, loaded_data)
                else:
                    self._load_molecular_data(k, loaded_data)

    def preload_data(self):
        """Open the cached HDF5 read-only if it holds every requested field.

        Returns:
            True if the cache was opened, False if it must be rebuilt.
        """
        hdf5_path = os.path.join(self.preload_path, "pre_transformed.hdf5")
        config_path = os.path.join(self.preload_path, "datahub.yaml")
        if not (
            os.path.isdir(self.preload_path)
            and os.path.isfile(hdf5_path)
            and os.path.isfile(config_path)
        ):
            return False
        datahub_config = YamlHandler(config_path).read_yaml()
        preload_data_types = _collect_types(datahub_config.feature) | _collect_types(datahub_config.target)
        # Every requested feature/target (and its raw key) must be contained in
        # the processed dataset. Extra cached fields are harmless because the
        # features/targets properties filter by the requested names.
        if self.data_types.items() <= preload_data_types.items():
            self.get_handle()
            logger.info(f"Data matched and preloaded from {self.preload_path}")
            return True
        return False

    def _expand(self, k: str, values: Iterable) -> np.ndarray:
        """Broadcast a single value of field *k* across the dataset.

        When *k* is integer-typed (``is_int``) and compression is on, one copy
        is kept. Otherwise the value is repeated ``n_datapoint`` times.

        Callers in this class pass either a Python scalar or a length-1
        sequence.
        """
        keep_single = is_int(k) and self.compressed
        if isinstance(values, (int, float)):
            if keep_single:
                return np.array([values])
            logger.info(f"Values of {k} (data type {self.data_types[k]}) are single and repeated")
            # np.full(shape, fill_value): one entry per datapoint.
            return np.full(self.n_datapoint, values)
        if keep_single:
            return values
        logger.info(f"Values of {k} (data type {self.data_types[k]}) are single and repeated")
        return np.repeat(values, self.n_datapoint, axis=0)

    def _compress(self, k: str, values: Iterable) -> np.ndarray:
        """Collapse identical integer-typed values of *k* into one entry.

        This only applies when compression is on. It only works for
        equal-length entries, since they must stack into one array.
        """
        value_array = np.array(values)
        if is_int(k) and self.compressed and (value_array == value_array[0]).all():
            logger.info(f"Values of {k} (data type {self.data_types[k]}) are all the same and compressed into a single value")
            return value_array[:1]
        return value_array

    def _load_molecular_data(self, k: str, raw_data: Dict) -> None:
        """Store per-datapoint field *k*.

        If its raw key is absent, *k* is derived by summing the atomic field
        ``k + "a"`` over each datapoint's first ``N`` atoms.
        """
        if self.data_types[k] in raw_data.keys():
            values = raw_data[self.data_types[k]]
            if isinstance(values, (int, float)) or len(values) == 1:
                self.data.create_dataset(k, data=self._expand(k, values))
            elif len(values) == self.n_datapoint:
                self.data.create_dataset(k, data=self._compress(k, values))
            else:
                raise IndexError(f"Length of '{k}' should be n_datapoint or 1")
        elif self.data_types[k + "a"] in raw_data.keys():
            # Reduce an atomic property into a molecular one (e.g. Qa -> Q).
            self._load_atomic_data(k + "a", raw_data)
            logger.info(f"Molecular property {k} are reduced from atomic property {k + 'a'} ({self.data_types[k + 'a']})")
            atomic = self.data[k + "a"]
            counts = self.data["N"]
            n_counts = len(counts)  # 1 when N is compressed
            values = [sum(atomic[i][:counts[i % n_counts]]) for i in tqdm(range(self.n_datapoint))]
            if is_rounded(k):
                values = [round(v) for v in values]
            # Sums of atomic properties are deliberately not compressed.
            self.data.create_dataset(k, data=np.array(values))

    def _load_atomic_data(self, k: str, raw_data: Dict) -> None:
        """Store per-atom field *k*, padded to ``(n_datapoint, max_N, ...)``.

        Padded slots keep the HDF5 dataset's fill value. A length-1 field is
        expanded instead. Already-stored fields are skipped.
        """
        if k in self.data:
            return
        values = raw_data[self.data_types[k]]
        v0 = np.array(values[0])
        if len(values) == self.n_datapoint:
            # One datapoint has shape (N, a, b, ...); the stored dataset has
            # shape (n_datapoint, max_N, a, b, ...).
            dataset = self.data.create_dataset(
                k, shape=(self.n_datapoint, self.max_N, *v0.shape[1:]), dtype=v0.dtype
            )
            logger.info(f"Storing atomic data {k} ({self.data_types[k]})")
            for i, v in tqdm(enumerate(values), total=self.n_datapoint):
                dataset[i, :len(v)] = v
        elif len(values) == 1:
            self.data.create_dataset(k, data=self._expand(k, values))
        else:
            raise IndexError(f"Length of {k} ({self.data_types[k]}) should be n_datapoint or 1")

    def _init_data(self) -> None:
        """Read the raw file at ``self.data_path`` and store every requested field.

        Raises:
            ValueError: If the file does not exist or its format is unknown.
            KeyError: If ``Ra`` or ``Za`` is not among the requested fields.
            IndexError: If a field's length matches neither n_datapoint nor 1.
        """
        from contextlib import ExitStack

        if not os.path.isfile(self.data_path):
            raise ValueError(f"Data path {self.data_path} doesn't exist.")
        # h5py groups / npz members are only valid while their file is open.
        # The ExitStack closes the owning file even if loading fails midway.
        with ExitStack() as stack:
            raw_data = self._open_raw_data(stack)
            self._load_raw_data(raw_data)

    def _open_raw_data(self, stack) -> Dict:
        """Open ``self.data_path`` according to its format.

        Any file handle that needs closing is registered on *stack*.
        """
        suffix = self.data_path.split(".")[-1]
        if self.data_format == "hdf5" or suffix == "hdf5":
            self.data_format = "hdf5"
            return stack.enter_context(h5py.File(self.data_path, mode="r"))["data"]
        if self.data_format == "pickle" or suffix in ("pkl", "pickle"):
            self.data_format = "pickle"
            return load_from_pickle(self.data_path)
        if self.data_format == "npz" or suffix == "npz":
            self.data_format = "npz"
            # NOTE(security): allow_pickle=True can execute pickled code in
            # object arrays; only load trusted files.
            return stack.enter_context(np.load(self.data_path, allow_pickle=True))
        if self.data_format == "sdf" or suffix == "sdf":
            self.data_format = "sdf"
            from .supplier import SDFSupplier
            supplier = SDFSupplier(self.data_path, supplying_fields=self.data_types.keys())
            return supplier.raw_data()
        raise ValueError(f"Data format of {self.data_path} is unknown")

    def _load_raw_data(self, raw_data: Dict) -> None:
        """Validate the mandatory fields, then store N, Za and every other field."""
        if "Ra" not in self.data_types:
            # Atomic positions must be provided.
            raise KeyError(f"Dataset must contain 'Ra' key (Atomic positions)")
        # The number of datapoints is the number of configurations.
        self.n_datapoint = len(raw_data[self.data_types["Ra"]])
        if "Za" not in self.data_types:
            # Atomic numbers/types must be provided.
            raise KeyError(f"Dataset must contain 'Za' key (Atomic numbers)")
        self._load_N_and_Za(raw_data)
        for k in self.data_types:
            if k in ["Za", "N"]:
                continue
            elif is_atomic(k):
                self._load_atomic_data(k, raw_data)
            else:
                self._load_molecular_data(k, raw_data)

    def _load_N_and_Za(self, raw_data: Dict) -> None:
        """Store atom counts ``N`` and atomic numbers ``Za``, and set ``max_N``."""
        raw_Za = raw_data[self.data_types["Za"]]
        n_Za = len(raw_Za)
        Zas = parse_Za(raw_Za)
        if n_Za == 1:
            if self.data_types["N"] not in raw_data.keys():
                # The atom count is taken from the length of the parsed atomic numbers.
                self.data.create_dataset("N", data=self._expand("N", len(Zas)))
            else:
                self._load_molecular_data("N", raw_data)
            self.data.create_dataset("Za", data=self._expand("Za", Zas))
            self.max_N = max(self.data["N"])
        elif n_Za == self.n_datapoint:
            if self.data_types["N"] not in raw_data.keys():
                self.data.create_dataset("N", data=self._compress("N", [len(Za) for Za in Zas]))
            else:
                self._load_molecular_data("N", raw_data)
            self.max_N = max(self.data["N"])
            Za0 = np.array(Zas[0])
            N0 = len(Za0)
            all_same = all(len(Za) == N0 and not (Za != Za0).any() for Za in Zas)
            if self.compressed and all_same:
                self.data.create_dataset("Za", data=[Za0])
            else:
                za_dataset = self.data.create_dataset("Za", shape=(self.n_datapoint, self.max_N), dtype=int)
                logger.info(f'Storing Za ({self.data_types["Za"]})')
                for i, Za in tqdm(enumerate(Zas), total=self.n_datapoint):
                    za_dataset[i, :len(Za)] = Za
        else:
            raise IndexError(f"Length of 'Za' should be n_datapoint or 1")

    def _init_neighbor_list(self) -> None:
        """Store ``idx_i``/``idx_j``/``N_pair`` when a full neighbor list is requested.

        When compressed with a single shared ``N``, one list is stored.
        Otherwise there is one row per datapoint, padded with -1 to
        ``max_N * (max_N - 1)`` pairs.
        """
        if self.neighbor_list_type != "full":
            return
        from .neighbor_list import full_neighbor_list
        logger.info("producing neighbor list")
        if self.compressed and len(self.data["N"]) == 1:
            idx_i, idx_j = full_neighbor_list(self.data["N"][0])
            self.data.create_dataset("idx_i", data=[idx_i])
            self.data.create_dataset("idx_j", data=[idx_j])
            self.data.create_dataset("N_pair", data=[len(idx_i)])
            return
        max_N_pairs = self.max_N * (self.max_N - 1)
        ds_i = self.data.create_dataset("idx_i", shape=(self.n_datapoint, max_N_pairs), dtype=int)
        ds_j = self.data.create_dataset("idx_j", shape=(self.n_datapoint, max_N_pairs), dtype=int)
        ds_pair = self.data.create_dataset("N_pair", shape=self.n_datapoint, dtype=int)
        for i in tqdm(range(self.n_datapoint)):
            idx_i, idx_j = full_neighbor_list(self.data["N"][i])
            ds_pair[i] = len(idx_i)
            # array_padding returns shape (1, max_N_pairs); take the single row.
            ds_i[i] = array_padding([idx_i], max_N_pairs, pad_value=-1)[0]
            ds_j[i] = array_padding([idx_j], max_N_pairs, pad_value=-1)[0]

    def get_handle(self, mode: Literal["r", "w"] = "r") -> None:
        """Open ``pre_transformed.hdf5`` in *mode* and bind its ``data`` group.

        In ``"w"`` mode, existing contents are cleared and a fresh ``data``
        group is created.
        """
        if mode == "w" and os.path.exists(self.preload_path):
            logger.warning(f"Preload path {self.preload_path} exists and will be overwritten")
        else:
            os.makedirs(self.preload_path, exist_ok=True)
        self.file = h5py.File(
            os.path.join(self.preload_path, "pre_transformed.hdf5"),
            mode=mode,
            rdcc_nbytes=1024 ** 3 * self.max_memory,
        )
        if mode == "r":
            self.data = self.file["data"]
        else:
            self.file.clear()
            self.data = self.file.create_group("data")

    def reset_handle(self):
        """Close the current HDF5 file and reopen it read-only."""
        self.file.close()
        self.get_handle()

    def _save_config(self):
        """Write the field specs and transform config to ``datahub.yaml``."""
        handler = YamlHandler(os.path.join(self.preload_path, "datahub.yaml"))
        datahub_config = Dict({
            "feature": self.feature_types,
            "target": self.target_types,
            "preprocessings": self.preprocessings,
            "global_transforms": self.global_transforms,
            "neighbor_list": self.neighbor_list_type,
        })
        handler.write_yaml(datahub_config)
        logger.info(f"Save preloaded dataset at {self.preload_path}")

    @property
    def features(self) -> FieldDataset:
        """Requested feature fields plus any stored neighbor-list fields."""
        wanted = self.feature_types.keys() | {"idx_i", "idx_j", "N_pair"}
        return FieldDataset({k: v for k, v in self.data.items() if k in wanted})

    @property
    def targets(self) -> FieldDataset:
        """Requested target fields."""
        return FieldDataset({k: v for k, v in self.data.items() if k in self.target_types})
class DataHub:
    """Group of ``SingleDataHub`` objects built from one configuration.

    * With no ``datasets``, one hub named ``"default"`` is built from
      ``**params``.
    * With a list, hubs are named ``"0"``, ``"1"``, ...
    * With a dict, the dict keys are used as names.

    In the list/dict cases the top-level ``global_transforms`` is passed to
    every hub.
    """

    def __init__(
        self,
        dump_dir=".",
        datasets: Optional[Union[List, Dict]] = None,
        fields: Optional[Dict[str, str]] = None,
        **params,
    ):
        self.dump_dir = dump_dir
        if fields is not None:
            # Register custom field types before the hubs below are built.
            for field_name, field_spec in fields.items():
                register_data_type(field_name, **field_spec)

        if datasets is None:
            # "transforms" is accepted as an alias when "global_transforms" is absent.
            params.setdefault("global_transforms", params.get("transforms", None))
            self.datahubs = {"default": SingleDataHub(dump_dir=dump_dir, **params)}
            return

        if isinstance(datasets, list):
            named_specs = [(str(position), spec) for position, spec in enumerate(datasets)]
        elif isinstance(datasets, dict):
            named_specs = list(datasets.items())
        else:
            raise ValueError(f"Unknown type of datasets: {type(datasets)}")

        shared_transforms = params.get("global_transforms", None)
        self.datahubs = {
            name: SingleDataHub(dump_dir=dump_dir, global_transforms=shared_transforms, **spec)
            for name, spec in named_specs
        }

    @property
    def features(self) -> Dict[str, FieldDataset]:
        """Feature datasets keyed by hub name."""
        return {name: hub.features for name, hub in self.datahubs.items()}

    @property
    def targets(self) -> Dict[str, FieldDataset]:
        """Target datasets keyed by hub name."""
        return {name: hub.targets for name, hub in self.datahubs.items()}

    @property
    def preload_path(self) -> Dict[str, str]:
        """Processed-cache directories keyed by hub name."""
        return {name: hub.preload_path for name, hub in self.datahubs.items()}

    @property
    def transform(self) -> Transform:
        """Global transform of the first hub."""
        first_hub = [*self.datahubs.values()][0]
        return first_hub.global_transform