import numpy as np
from shutil import copy, rmtree
import os
import subprocess
from typing import Any, Dict, Literal, Optional
from pathlib import Path
from abc import ABC, abstractmethod
from tqdm import tqdm
from pickle import dump, load
from rdkit import Chem
from multiprocessing import Pool, cpu_count
from ..utils import logger
from ..data.supplier import Supplier
class QMDriver(ABC):
    '''
    Abstract base class for QM drivers.

    Subclasses supply ``make_input``, ``invoke_qm`` and ``collect_results``.
    This class handles the per-structure bookkeeping (temporary directories,
    cached per-structure pickles, optional stdout/Molden copies) and the
    batch loop in ``run``.
    '''

    def __init__(self,
                 supplier: Supplier, tmp_dir: str, output_dir: str, pickle_name: str,
                 bs: str, xc: str,
                 keep_molden: bool = False,
                 keep_stdout: bool = False,
                 clean_tmp: bool = True,
                 n_processes: int = 1,
                 dump_single_run: bool = True
                 ):
        '''
        Base class for QM drivers.

        Params:
        -------
        supplier: Supplier
            The supplier of molecular data. Its ``name`` is used as a
            sub-directory of both ``tmp_dir`` and ``output_dir``, and ``run``
            iterates over ``supplier.suppl()``.
        tmp_dir: str
            Root directory for temporary files. Each structure runs in
            ``<tmp_dir>/<supplier.name>/tmp.<index>``.
        output_dir: str
            Root directory for output files. Everything is written below
            ``<output_dir>/<supplier.name>``.
        pickle_name: str
            Name of the final results pickle. Its suffix (if any) is replaced
            by ``.pkl``.
        bs: str
            The basis set to use.
        xc: str
            The exchange-correlation functional to use.
        keep_molden: bool
            Copy each Molden file into ``<output>/moldens``.
        keep_stdout: bool
            Copy each QM stdout file into ``<output>/stdout``.
        clean_tmp: bool
            Delete the per-structure temporary directory after a successful
            calculation. Failed calculations keep their directory.
        n_processes: int
            Number of worker processes. Values <= 0 mean ``cpu_count()``.
        dump_single_run: bool
            Pickle each successful result to ``<output>/single_run/<index>.pkl``.
        '''
        self.supplier = supplier
        self.tmp_dir_base = Path(tmp_dir).absolute() / self.supplier.name / "tmp"
        self.output_dir = Path(output_dir).absolute() / self.supplier.name
        self.output_path = (self.output_dir / pickle_name).with_suffix(".pkl")
        self.keep_molden = keep_molden
        if keep_molden:
            os.makedirs(self.output_dir / "moldens", exist_ok=True)
        self.keep_stdout = keep_stdout
        if keep_stdout:
            os.makedirs(self.output_dir / "stdout", exist_ok=True)
        self.bs = bs
        self.xc = xc
        os.makedirs(self.output_dir, exist_ok=True)
        self.clean_tmp = clean_tmp
        self.dump_single_run = dump_single_run
        if self.dump_single_run:
            os.makedirs(self.output_dir / "single_run", exist_ok=True)
        self.n_processes = n_processes if n_processes > 0 else cpu_count()

    def make_input(self, package: Dict[str, Any], tmp_dir: Path) -> Path:
        '''
        Write the QM input file for ``package`` into ``tmp_dir`` and return its path.

        ``single_run`` calls this with the keyword arguments ``package`` and
        ``tmp_dir``. It passes the returned path to ``invoke_qm`` and
        ``collect_results``.

        Subclasses must override this. The base implementation only raises,
        so a missing override fails with a clear message instead of an
        ``AttributeError``.
        '''
        raise NotImplementedError(
            f"{type(self).__name__} does not implement make_input(package, tmp_dir)"
        )

    @abstractmethod
    def invoke_qm(self, input_file: Path, tmp_dir: Path) -> Path:
        '''
        Run the QM program on ``input_file`` inside ``tmp_dir``.

        Returns the path of the program's stdout file. ``single_run`` hands
        this path to ``copy_files``.
        '''
        ...

    @abstractmethod
    def collect_results(self, input_file: Path, package: Dict[Literal["index", "atom_type", "Ra", "Q", "mol"], Any], tmp_dir: Path) -> Dict[str, Any]:
        '''
        Parse the finished calculation into a result dict.

        ``run`` requires the keys "energy", "grad" and "dipole". An optional
        "molden_file" key (a Path) is used by ``copy_files``.

        Raise ``FileNotFoundError`` to signal a failed calculation.
        ``single_run`` catches it and skips the structure.
        '''
        ...

    def copy_files(self, output_file: Path, molden_file: Optional[Path]) -> None:
        '''
        Copy the stdout and/or Molden file into the output directory, as configured.

        Raises FileNotFoundError if a file that is supposed to be kept does
        not exist.
        '''
        if self.keep_stdout:
            if output_file.exists():
                copy(output_file, self.output_dir / "stdout")
            else:
                raise FileNotFoundError(f"Output file {output_file} not found")
        if self.keep_molden and molden_file is not None:
            if molden_file.exists():
                copy(molden_file, self.output_dir / "moldens")
            else:
                raise FileNotFoundError(f"Molden file {molden_file} not found")

    def _single_run_path(self, index: Any) -> Path:
        '''Return the per-structure cache pickle path for ``index``.'''
        return self.output_dir / "single_run" / f"{index}.pkl"

    def dump_results(self, result_package: Dict[str, Any]) -> None:
        '''Pickle ``result_package`` to ``<output>/single_run/<index>.pkl``.'''
        with open(self._single_run_path(result_package['index']), "wb") as f:
            dump(result_package, f)

    def single_run(self, package: Dict[Literal["index", "atom_type", "Ra", "Q", "mol", "spin"], Any]) -> Dict[str, Any]:
        '''
        Run (or load from cache) the calculation for one structure.

        If ``<output>/single_run/<index>.pkl`` already exists, it is loaded
        and returned without running anything. Otherwise the input is
        written, the QM program is run, and the results are collected.

        On success, the returned dict is augmented with "atom_type", "coord",
        "total_spin" (``package["spin"]``, default 0), "total_chrg" and
        "index", and is optionally pickled.

        On ``FileNotFoundError`` from ``collect_results``, a warning is
        logged and ``{}`` is returned. The temporary directory is then left
        in place.
        '''
        single_run_path = self._single_run_path(package['index'])
        if single_run_path.exists():
            # NOTE: unpickling trusts the contents of the output directory.
            with open(single_run_path, "rb") as f:
                result_package = load(f)
            logger.info(f"Single run results for {package['index']} already exist. Loading from {single_run_path}")
            return result_package

        tmp_dir = Path(str(self.tmp_dir_base) + f".{package['index']}")
        os.makedirs(tmp_dir, exist_ok=True)
        input_file = self.make_input(package=package, tmp_dir=tmp_dir)
        output_file = self.invoke_qm(input_file, tmp_dir)
        try:
            result_package = self.collect_results(input_file, package, tmp_dir)
        except FileNotFoundError as e:
            logger.warning(f"Calculation of {input_file} failed: {e}")
            return {}

        self.copy_files(output_file, result_package.get("molden_file", None))
        if self.clean_tmp:
            rmtree(tmp_dir)
        result_package["atom_type"] = package["atom_type"]
        result_package["coord"] = package["Ra"]
        result_package["total_spin"] = package.get("spin", 0)
        result_package["total_chrg"] = package["Q"]
        result_package["index"] = package["index"]
        if self.dump_single_run:
            self.dump_results(result_package)
        return result_package

    def run(self) -> None:
        '''
        Run ``single_run`` over every package from ``supplier.suppl()``.

        Uses a process pool when ``n_processes != 1``. Empty (failed) results
        are dropped. The selected keys of the remaining results are pickled
        as a list to ``self.output_path``.
        '''
        progress = dict(desc="Running QM", dynamic_ncols=True, leave=False, position=0)
        if self.n_processes == 1:
            result_packages = [
                self.single_run(package)
                for package in tqdm(self.supplier.suppl(), **progress)
            ]
        else:
            logger.info(f"Running QM calculations with {self.n_processes} processes")
            with Pool(self.n_processes) as p:
                result_packages = list(tqdm(
                    p.imap(self.single_run, self.supplier.suppl()),
                    **progress
                ))

        keys = ("energy", "grad", "dipole", "index", "atom_type", "coord", "total_spin", "total_chrg")
        datapoints = [
            {k: result_package[k] for k in keys}
            for result_package in result_packages
            if result_package
        ]
        # Use a context manager so the pickle is flushed and the handle closed.
        with open(self.output_path, "wb") as f:
            dump(datapoints, f)
        logger.info(f"QM calculations finished. Pickle saved to {self.output_path}")
class TeraChemDriver(QMDriver):
    '''
    QM driver that runs TeraChem as an external ``terachem`` process.

    NOTE(review): this class does not define ``make_input``, which
    ``QMDriver.single_run`` calls to produce the input file. Confirm where it
    is meant to come from. The SCF/solvent/dispersion settings stored by
    ``__init__`` are not read by any method in this class. Presumably they
    are meant to be written into the input file by ``make_input``.
    '''

    def __init__(self,
                 supplier: Supplier, tmp_dir: str, output_dir: str, pickle_name: str,
                 bs: str, xc: str,
                 keep_molden: bool = False,
                 keep_stdout: bool = False,
                 clean_tmp: bool = True,
                 dftd: Optional[str] = None,
                 pcm: Optional[str] = None,
                 epsilon: Optional[float] = None,
                 pcm_radii_file: Optional[str] = None,
                 scf_method: Optional[str] = "diis+a",
                 scf_maxit: Optional[int] = 1000,
                 scf_guess: Optional[str] = "generate",
                 n_processes: int = 1,
                 dump_single_run: bool = True,
                 timeout: Optional[float] = None,
                 *args, **kwargs
                 ):
        '''
        Params:
        -------
        supplier, tmp_dir, output_dir, pickle_name, bs, xc, keep_molden,
        keep_stdout, clean_tmp, n_processes, dump_single_run
            Forwarded unchanged to ``QMDriver.__init__``.
        dftd: Optional[str]
            Dispersion-correction setting, stored as ``self.dftd``.
        pcm: Optional[str]
            Implicit-solvent model setting, stored as ``self.pcm``.
        epsilon: Optional[float]
            Solvent dielectric constant, stored as ``self.epsilon``.
        pcm_radii_file: Optional[str]
            Solvent-model radii file, stored as ``self.pcm_radii_file``.
        scf_method: Optional[str]
            SCF algorithm setting (default "diis+a").
        scf_maxit: Optional[int]
            Maximum number of SCF iterations (default 1000).
        scf_guess: Optional[str]
            Initial-guess setting (default "generate").
        timeout: Optional[float]
            Wall-clock limit in seconds for each ``terachem`` process, passed
            to ``subprocess.run``. None means no limit.
        *args, **kwargs
            Accepted and ignored.
        '''
        super().__init__(supplier, tmp_dir, output_dir, pickle_name, bs, xc, keep_molden, keep_stdout, clean_tmp, n_processes, dump_single_run)
        self.dftd = dftd
        self.pcm = pcm
        self.epsilon = epsilon
        self.pcm_radii_file = pcm_radii_file
        self.scf_method = scf_method
        self.scf_maxit = scf_maxit
        self.scf_guess = scf_guess
        self.timeout = timeout

    def invoke_qm(self, input_file: Path, tmp_dir: Path) -> Path:
        '''
        Run ``terachem <input_file>`` with ``tmp_dir`` as its working directory.

        Stdout is written to ``input_file`` with its suffix replaced by
        ``.out``; that path is returned. A relative path is resolved against
        ``tmp_dir``.

        A timeout or a non-zero exit status is logged, not raised. The
        failure surfaces later when ``collect_results`` cannot find the
        outputs.
        '''
        input_file = Path(input_file)
        tmp_dir = Path(tmp_dir)
        # Joining onto ``tmp_dir`` keeps an absolute path unchanged and places a
        # relative one where TeraChem (running with cwd=tmp_dir) would see it.
        output_file = tmp_dir / input_file.with_suffix(".out")
        try:
            with open(output_file, 'w') as f:
                # ``cwd=`` avoids mutating the process-wide working directory.
                completed = subprocess.run(
                    ["terachem", str(input_file)],
                    stdout=f,
                    stderr=subprocess.PIPE,
                    cwd=tmp_dir,
                    timeout=self.timeout,
                    check=False
                )
        except subprocess.TimeoutExpired:
            logger.warning(f"TeraChem calculation for {input_file} timed out after {self.timeout} seconds")
        else:
            if completed.returncode != 0:
                stderr = (completed.stderr or b"").decode(errors="replace").strip()
                logger.warning(
                    f"TeraChem exited with code {completed.returncode} for {input_file}: {stderr[-1000:]}"
                )
        return output_file

    def collect_results(self, input_file: Path, package: Dict[Literal["index", "atom_type", "Ra", "Q", "mol"], Any], tmp_dir: Path) -> Dict[str, Any]:
        '''
        Parse the dipole, gradient and energy from ``<tmp_dir>/scr_<stem>``.

        Assumes TeraChem's scratch output lands in ``scr_<input stem>``,
        presumably set in the input file. NOTE(review): confirm.

        Returns a dict with:
        - "dipole": in e*Angstrom.
        - "grad": an (n_atoms, 3) array, converted from per-Bohr to
          per-Angstrom.
        - "energy": a float.
        - "molden_file": a Path.

        Raises FileNotFoundError if an output file is missing or incomplete.
        ``single_run`` treats this as a failed calculation.
        '''
        scr_dir = tmp_dir / f"scr_{input_file.stem}"
        results_file = scr_dir / "results.dat"
        grad_file = scr_dir / "grad.xyz"
        if not results_file.exists():
            raise FileNotFoundError(f"Results file {results_file} not found")
        if not grad_file.exists():
            raise FileNotFoundError(f"Gradients file {grad_file} not found")

        with open(results_file, "r") as f:
            lines = f.readlines()
        com_line_index = -1
        dipole_line_index = -1
        for i, line in enumerate(lines):
            # The numeric row sits two lines below each header.
            if line.startswith("Center of Mass (Angs):"):
                com_line_index = i + 2
            if line.startswith("Ground state dipole moment (Debye):"):
                dipole_line_index = i + 2
        if com_line_index == -1 or dipole_line_index == -1:
            raise FileNotFoundError(f"Center of Mass or dipole moment line not found in {results_file}")
        if max(com_line_index, dipole_line_index) >= len(lines):
            # Header present but value row missing (e.g. truncated file). Use the
            # same exception as above so single_run skips this structure instead
            # of the IndexError aborting the whole batch.
            raise FileNotFoundError(f"Center of Mass or dipole moment values missing in {results_file}")

        com = np.array(list(map(float, lines[com_line_index].split())))
        dipole = np.array(list(map(float, lines[dipole_line_index].split()))) * 0.20819434  # Debye to e Angstrom
        # Presumably the dipole is reported about the centre of mass. Adding
        # Q * R_com re-references it to the coordinate origin.
        dipole = dipole + com * package["Q"]

        with open(grad_file, "r") as f:
            _ = f.readline()
            title = f.readline()
        # Energy is the 7th token of the grad.xyz comment line. This assumes
        # TeraChem's title format — TODO confirm.
        energy = float(title.split()[6])
        # ndmin=2 keeps the (n_atoms, 3) shape for single-atom systems, which
        # np.loadtxt would otherwise collapse to (3,).
        grad = np.loadtxt(grad_file, skiprows=2, usecols=(1, 2, 3), ndmin=2) / 0.5291772108  # per Bohr -> per Angstrom
        return {"dipole": dipole, "grad": grad, "energy": energy, "molden_file": scr_dir / (input_file.stem + ".molden")}
class ORCADriver(QMDriver):
    '''
    Not implemented yet: an empty subclass of ``QMDriver``.

    ``invoke_qm`` and ``collect_results`` are still abstract here, so
    instantiating this class raises ``TypeError``.
    '''
class PySCFDriver(QMDriver):
    '''
    Not implemented yet: an empty subclass of ``QMDriver``.

    ``invoke_qm`` and ``collect_results`` are still abstract here, so
    instantiating this class raises ``TypeError``.
    '''
class Psi4Driver(QMDriver):
    '''
    Not implemented yet: an empty subclass of ``QMDriver``.

    ``invoke_qm`` and ``collect_results`` are still abstract here, so
    instantiating this class raises ``TypeError``.
    '''