Source code for openff.qcsubmit.datasets.datasets

import abc
import json
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generator,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import qcelemental as qcel
import qcportal as ptl
from openff.toolkit import topology as off
from pydantic import Field, constr, validator
from qcelemental.models import AtomicInput, OptimizationInput
from qcelemental.models.procedures import OptimizationProtocols, QCInputSpecification
from qcportal import PortalClient, PortalRequestError
from qcportal.optimization import OptimizationDatasetNewEntry, OptimizationSpecification
from qcportal.singlepoint import (
    QCSpecification,
    SinglepointDatasetNewEntry,
    SinglepointDriver,
)
from qcportal.torsiondrive import TorsiondriveDatasetNewEntry, TorsiondriveSpecification
from typing_extensions import Literal

from openff.qcsubmit.common_structures import CommonBase, Metadata, MoleculeAttributes
from openff.qcsubmit.constraints import Constraints
from openff.qcsubmit.datasets.entries import (
    DatasetEntry,
    FilterEntry,
    OptimizationEntry,
    TorsionDriveEntry,
)
from openff.qcsubmit.exceptions import (
    DatasetCombinationError,
    MissingBasisCoverageError,
    UnsupportedFiletypeError,
)
from openff.qcsubmit.procedures import GeometricProcedure
from openff.qcsubmit.serializers import deserialize, serialize
from openff.qcsubmit.utils.smirnoff import smirnoff_coverage
from openff.qcsubmit.utils.visualize import molecules_to_pdf

if TYPE_CHECKING:
    from openff.toolkit.typing.engines.smirnoff import ForceField
    from qcportal.collections.collection import Collection

# Generic placeholder for the server-side collection object returned by
# ``_BaseDataset._generate_collection``.
# NOTE(review): ``Collection`` is only imported (under TYPE_CHECKING) from
# ``qcportal.collections.collection``; confirm that module still exists in the
# qcportal version this file targets.
C = TypeVar("C", bound="Collection")
# Generic placeholder for the entry model class returned by
# ``_BaseDataset._entry_class``.
E = TypeVar("E", bound=DatasetEntry)


class _BaseDataset(abc.ABC, CommonBase):
    """
    A general base model for QCSubmit datasets which act as wrappers around a corresponding QCFractal collection.
    """

    dataset_name: str = Field(
        ...,
        description="The name of the dataset, this will be the name given to the collection in QCArchive.",
    )
    dataset_tagline: constr(min_length=8, regex="[a-zA-Z]") = Field(  # noqa
        ...,
        description="The tagline should be a short description of the dataset which will be displayed by the QCArchive client when the datasets are listed.",
    )
    type: Literal["_BaseDataset"] = Field(
        "_BaseDataset",
        description="The dataset type corresponds to the type of collection that will be made in QCArchive.",
    )
    description: constr(min_length=8, regex="[a-zA-Z]") = Field(  # noqa
        ...,
        description="A long description of the datasets purpose and details about the molecules within.",
    )
    metadata: Metadata = Field(
        Metadata(), description="The metadata describing the dataset."
    )
    provenance: Dict[str, str] = Field(
        {},
        description="A dictionary of the software and versions used to generate the dataset.",
    )
    dataset: Dict[str, DatasetEntry] = Field(
        {}, description="The actual dataset to be stored in QCArchive."
    )
    filtered_molecules: Dict[str, FilterEntry] = Field(
        {},
        description="The set of workflow components used to generate the dataset with any filtered molecules.",
    )
    _file_writers = {"json": json.dump}

    def __init__(self, **kwargs):
        """
        Make sure the metadata has been assigned correctly if not autofill some information.
        """

        super().__init__(**kwargs)

        # mirror the dataset level type and name onto the metadata
        self.metadata.collection_type = self.type
        self.metadata.dataset_name = self.dataset_name

        # the tagline / description double as the short / long descriptions when not given
        if self.metadata.short_description is None:
            self.metadata.short_description = self.dataset_tagline
        if self.metadata.long_description is None:
            self.metadata.long_description = self.description

    @classmethod
    @abc.abstractmethod
    def _entry_class(cls) -> Type[E]:
        """Return the entry model class used to store molecules in `self.dataset`."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _generate_collection(self, client: "PortalClient") -> C:
        """Generate the corresponding QCFractal Collection for this Dataset.

        Each QCSubmit Dataset class corresponds to and wraps
        a QCFractal Collection class. This method generates an instance
        of that corresponding Collection, with inputs applied from
        Dataset attributes.

        Args:
            client:
                Client to use for connecting to a QCFractal server instance.

        Returns:
            Collection instance corresponding to this Dataset.

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _get_specifications(self) -> Dict[str, Any]:
        """Get the specifications for this Dataset keyed by specification name.

        `submit` iterates the returned mapping and adds every specification to the
        collection under its name, so a dataset with no specifications should
        return an empty dictionary.

        Returns:
            A dictionary of specification name to the QCPortal specification object.

        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _get_entries(self) -> List[Any]:
        """Build the new entries for the Dataset's corresponding Collection.

        This method allows for handling of e.g. generating the index/name for
        the corresponding Collection from each item in `self.dataset`. Since
        each item may feature more than one conformer, appropriate handling
        differs between e.g. `OptimizationDataset` and `TorsiondriveDataset`.

        Returns:
            List of new entries to add to the dataset on the server.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def to_tasks(self) -> Dict[str, List[Union[AtomicInput, OptimizationInput]]]:
        """
        Create a dictionary of QCEngine tasks which correspond to this dataset stored by the program which should be used for the task.
        """
        raise NotImplementedError()

    def submit(
        self, client: "PortalClient", ignore_errors: bool = False, verbose: bool = False
    ) -> Dict:
        """
        Submit the dataset to a QCFractal server.

        Args:
            client:
                Instance of a portal client
            ignore_errors:
                If the user wants to submit the compute regardless of errors set this to ``True``.
                Mainly to override basis coverage.
            verbose:
                If progress bars and submission statistics should be printed ``True`` or not ``False``.

        Returns:
            A dictionary of the compute response from the client for each specification submitted.

        Raises:
            MissingBasisCoverageError:
                If the chosen basis set does not cover some of the elements in the dataset.

        """
        from openff.qcsubmit.datasets import legacy_qcsubmit_ds_type_to_next_qcf_ds_type

        # pre submission checks
        # make sure we have some QCSpec to submit
        self._check_qc_specs()
        # basis set coverage check
        self._get_missing_basis_coverage(raise_errors=(not ignore_errors))

        # see if collection already exists
        # if so, we'll extend it
        # if not, we'll create a new one

        try:
            qcf_ds_type = legacy_qcsubmit_ds_type_to_next_qcf_ds_type[self.type]
            collection = client.get_dataset(qcf_ds_type, self.dataset_name)
        except PortalRequestError:
            # metadata is only validated when a brand new collection is created
            self.metadata.validate_metadata(raise_errors=not ignore_errors)
            collection = self._generate_collection(client=client)

        # create specifications
        # TODO - check if specifications already exist
        specs = self._get_specifications()
        for spec_name, spec in specs.items():
            # Send the new specifications to the server
            collection.add_specification(
                name=spec_name,
                specification=spec,
                description=self.qc_specifications[spec_name].spec_description,
            )

        # add the molecules/entries to the database
        entries = self._get_entries()

        # TODO - check if entries already exist
        insert_metadata = collection.add_entries(entries)
        if verbose:
            print(
                f"Number of new entries: {len(insert_metadata.inserted_idx)}/{self.n_records}"
            )

        return collection.submit(
            tag=self.compute_tag,
            priority=self.priority,
        )

    @abc.abstractmethod
    def __add__(self, other: "_BaseDataset") -> "_BaseDataset":
        """
        Add two datasets together.
        """
        raise NotImplementedError()

    @classmethod
    def parse_file(cls, file_name: str):
        """
        Create a Dataset object from a (possibly compressed) json file.

        Args:
            file_name: The name of the file the dataset should be created from.
        """
        data = deserialize(file_name=file_name)
        return cls(**data)

    def get_molecule_entry(self, molecule: Union[off.Molecule, str]) -> List[str]:
        """
        Search through the dataset for a molecule and return the dataset index of any exact molecule matches.

        Args:
            molecule: The smiles string for the molecule or an openforcefield.topology.Molecule that is to be searched for.

        Returns:
            A list of dataset indices which contain the target molecule.
        """
        # if we have a smiles string convert it
        if isinstance(molecule, str):
            molecule = off.Molecule.from_smiles(molecule, allow_undefined_stereo=True)

        # the standard inchi key is a cheap pre-filter before the full molecule comparison
        inchi_key = molecule.to_inchikey(fixed_hydrogens=False)
        hits = []
        for entry in self.dataset.values():
            if inchi_key == entry.attributes.inchi_key:
                # they have same basic inchi now match the molecule
                if molecule == entry.get_off_molecule(include_conformers=False):
                    hits.append(entry.index)

        return hits

    @property
    def filtered(self) -> Generator[off.Molecule, None, None]:
        """
        A generator which yields a openff molecule representation for each molecule filtered while creating this dataset.

        Note:
            Modifying the molecule will have no effect on the data stored.
        """

        for data in self.filtered_molecules.values():
            for smiles in data.molecules:
                yield off.Molecule.from_smiles(smiles, allow_undefined_stereo=True)

    @property
    def n_filtered(self) -> int:
        """
        Calculate the total number of molecules filtered by the components used in a workflow to create this dataset.
        """
        return sum(len(data.molecules) for data in self.filtered_molecules.values())

    @property
    def n_records(self) -> int:
        """
        Return the total number of records that will be created on submission of the dataset.

        Note:
            * The number returned will be different depending on the dataset used.
            * The amount of unique molecule can be found using `n_molecules`
        """

        return sum(len(data.initial_molecules) for data in self.dataset.values())

    @property
    def n_molecules(self) -> int:
        """
        Calculate the number of unique molecules to be submitted.

        Entries are grouped by their standard InChIKey; within a group an entry only counts as a new molecule when
        its fixed-hydrogen InChIKey differs from those of the molecules already counted in that group.

        Notes:
            * Fixed-hydrogen InChIKeys are only computed when a group has more than one candidate, and each is
              cached so an entry is converted to an openff molecule at most once.
            * This function does not calculate the total number of entries of the dataset see `n_records`
        """
        # standard inchikey -> entries representing the distinct molecules counted so far
        representatives: Dict[str, List[DatasetEntry]] = defaultdict(list)
        # id(entry) -> fixed-hydrogen inchikey; entries stay alive in self.dataset for the whole call
        fixed_keys: Dict[int, str] = {}

        def fixed_hydrogen_key(data_entry: DatasetEntry) -> str:
            # compute lazily and memoize the (expensive) fixed-hydrogen inchikey
            key = fixed_keys.get(id(data_entry))
            if key is None:
                key = data_entry.get_off_molecule(False).to_inchikey(
                    fixed_hydrogens=True
                )
                fixed_keys[id(data_entry)] = key
            return key

        for entry in self.dataset.values():
            like_entries = representatives[entry.attributes.inchi_key]
            if like_entries:
                new_key = fixed_hydrogen_key(entry)
                if any(new_key == fixed_hydrogen_key(other) for other in like_entries):
                    # this molecule has already been counted
                    continue
            like_entries.append(entry)

        return sum(len(group) for group in representatives.values())

    @property
    def molecules(self) -> Generator[off.Molecule, None, None]:
        """
        A generator that creates an openforcefield.topology.Molecule one by one from the dataset.

        Note:
            Editing the molecule will not effect the data stored in the dataset as it is immutable.
        """

        for molecule_data in self.dataset.values():
            # create the molecule from the cmiles data
            yield molecule_data.get_off_molecule(include_conformers=True)

    @property
    def n_components(self) -> int:
        """
        Return the amount of components that have been ran during generating the dataset.
        """

        return len(self.filtered_molecules)

    @property
    def components(self) -> List[Dict[str, Union[str, Dict[str, str]]]]:
        """
        Gather the details of the components that were ran during the creation of this dataset.
        """

        return [
            component.dict(exclude={"molecules"})
            for component in self.filtered_molecules.values()
        ]

    def filter_molecules(
        self,
        molecules: Union[off.Molecule, List[off.Molecule]],
        component: str,
        component_settings: Dict[str, Any],
        component_provenance: Dict[str, str],
    ) -> None:
        """
        Filter a molecule or list of molecules by the component they failed.

        Args:
            molecules:
                A molecule or list of molecules to be filtered.
            component:
                The name of the component.
            component_settings:
                The dictionary representation of the component that filtered this set of molecules.
            component_provenance:
                The dictionary representation of the component provenance.
        """

        if isinstance(molecules, off.Molecule):
            # make into a list
            molecules = [molecules]

        if component in self.filtered_molecules:
            # extend the existing record with the smiles of the newly filtered molecules
            filter_mols = [
                molecule.to_smiles(isomeric=True, explicit_hydrogens=True)
                for molecule in molecules
            ]
            self.filtered_molecules[component].molecules.extend(filter_mols)
        else:
            filter_data = FilterEntry(
                off_molecules=molecules,
                component=component,
                component_provenance=component_provenance,
                component_settings=component_settings,
            )

            self.filtered_molecules[filter_data.component] = filter_data

    def add_molecule(
        self,
        index: str,
        molecule: Optional[off.Molecule],
        extras: Optional[Dict[str, Any]] = None,
        keywords: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        Add a molecule to the dataset under the given index with the passed cmiles.

        Args:
            index:
                The index that should be associated with the molecule in QCArchive.
            molecule:
                The instance of the molecule which contains its conformer information.
            extras:
                The extras that should be supplied into the qcportal.models.Molecule.
            keywords:
                Any extra keywords which are required for the calculation.

        Note:
            Each molecule in this basic dataset should have all of its conformers expanded out into separate entries.
            Thus here we take the general molecule index and increment it.
            Molecules for which a valid QCSchema cannot be built are recorded under the
            ``QCSchemaIssues`` filter component instead of being added.
        """
        # only use attributes if supplied else generate
        # Note we should only reuse attributes if making a dataset from a result so the attributes are consistent
        if "attributes" in kwargs:
            attributes = kwargs.pop("attributes")
        else:
            attributes = MoleculeAttributes.from_openff_molecule(molecule=molecule)

        try:
            data_entry = self._entry_class()(
                off_molecule=molecule,
                attributes=attributes,
                index=index,
                extras=extras or {},
                keywords=keywords or {},
                **kwargs,
            )
            self.dataset[index] = data_entry
            # add any extra elements to the metadata
            self.metadata.elements.update(data_entry.initial_molecules[0].symbols)

        except qcel.exceptions.ValidationError:
            # the molecule has some qcschema issue and should be removed
            self.filter_molecules(
                molecules=molecule,
                component="QCSchemaIssues",
                component_settings={
                    "component_description": "The molecule was removed as a valid QCSchema could not be made",
                    "type": "QCSchemaIssues",
                },
                component_provenance=self.provenance,
            )

    def _get_missing_basis_coverage(
        self, raise_errors: bool = True
    ) -> Optional[Dict[str, Set[str]]]:
        """
        Work out if the selected basis set covers all of the elements in the dataset for each specification if not return the missing
        element symbols.

        Args:
            raise_errors: If `True` the function will raise an error for missing basis coverage, else we return the missing data and just print warnings.

        Returns:
            When ``raise_errors`` is ``False``, a dictionary of specification name to the set of uncovered element
            symbols; otherwise ``None``.

        Raises:
            MissingBasisCoverageError: If ``raise_errors`` is ``True`` and any specification misses an element.
        """
        import warnings

        import basis_set_exchange as bse

        try:
            from openmm.app import Element
        except ImportError:
            from simtk.openmm.app import Element

        basis_report = {}
        for spec in self.qc_specifications.values():
            program = spec.program.lower()
            if program == "torchani":
                # check ani1 first
                ani_coverage = {
                    "ani1x": {"C", "H", "N", "O"},
                    "ani1ccx": {"C", "H", "N", "O"},
                    "ani2x": {"C", "H", "N", "O", "S", "F", "Cl"},
                }
                covered_elements = ani_coverage[spec.method.lower()]
                # this is validated at the spec level so we should not get an error here
                difference = self.metadata.elements.difference(covered_elements)

            elif program == "psi4":
                if spec.basis is not None:
                    # TODO this list should be updated with more basis transforms as we find them
                    psi4_converter = {"dzvp": "dgauss-dzvp"}
                    month_subs = {"jun-", "mar-", "apr-", "may-", "feb-"}
                    basis = psi4_converter.get(spec.basis.lower(), spec.basis.lower())
                    # apply the special character conversions used by basis_set_exchange:
                    # "*" -> "_st_", "/" -> "_sl_" and drop any "heavy-" tag
                    basis = (
                        basis.replace("*", "_st_")
                        .replace("/", "_sl_")
                        .replace("heavy-", "")
                    )
                    bse_metadata = bse.get_metadata()
                    try:
                        basis_meta = bse_metadata[basis]
                    except KeyError:
                        # the calendar (e.g. jun-) prefixes are not known to bse, strip them and retry
                        for month in month_subs:
                            basis = basis.replace(month, "")
                        basis_meta = bse_metadata[basis]

                    elements = basis_meta["versions"][basis_meta["latest_version"]][
                        "elements"
                    ]
                    covered_elements = {
                        Element.getByAtomicNumber(int(element)).symbol
                        for element in elements
                    }
                    difference = self.metadata.elements.difference(covered_elements)
                else:
                    # the basis is wrote with the method so print a warning about validation
                    warnings.warn(
                        f"The spec {spec.spec_name} has a basis of None, this will not be validated.",
                        UserWarning,
                    )
                    difference = set()

            elif program == "openmm":
                # smirnoff covered elements
                covered_elements = {"C", "H", "N", "O", "P", "S", "Cl", "Br", "F", "I"}
                difference = self.metadata.elements.difference(covered_elements)

            elif program == "rdkit":
                # all atoms are defined in the uff so return an empty set.
                difference = set()

            else:
                # xtb
                # all atoms are covered and this must be xtb
                difference = set()

            basis_report[spec.spec_name] = difference

        for spec_name, report in basis_report.items():
            if not report:
                continue
            missing_spec = self.qc_specifications[spec_name]
            message = f"The following elements: {report} are not covered by the selected basis : {missing_spec.basis} and method : {missing_spec.method}"
            if raise_errors:
                raise MissingBasisCoverageError(message)
            warnings.warn(message, UserWarning)

        if not raise_errors:
            return basis_report
        return None

    def export_dataset(self, file_name: str, compression: Optional[str] = None) -> None:
        """
        Export the dataset to file so that it can be used to make another dataset quickly.

        Args:
            file_name:
                The name of the file the dataset should be wrote to.
            compression:
                The type of compression that should be added to the export.

        Raises:
            UnsupportedFiletypeError: If the requested file type is not supported.


        Note:
            The supported file types are:

            - `json`

            Additionally, the file will automatically compressed depending on the
            final extension if compression is not explicitly supplied:

            - `json.xz`
            - `json.gz`
            - `json.bz2`

            Check serializers.py for more details. Right now bz2 seems to
            produce the smallest files.
        """

        # Check here early, just to filter anything non-json for now
        # Ideally the serializers should be checking this
        split = file_name.split(".")
        # look at the last two extensions so that e.g. `.json.xz` is accepted
        split = split[-1:] if len(split) == 1 else split[-2:]
        if "json" not in split:
            raise UnsupportedFiletypeError(
                f"The dataset export file name with leading extension {split[-1]} is not supported, "
                "please end the file name with json."
            )

        serialize(serializable=self, file_name=file_name, compression=compression)

    def coverage_report(
        self, force_field: "ForceField", verbose: bool = False
    ) -> Dict[str, Dict[str, int]]:
        """Returns a summary of how many molecules within this dataset would be assigned
        each of the parameters in a force field.

        Notes:
            * Parameters which would not be assigned to any molecules in the dataset
              will not be included in the returned summary.

        Args:
            force_field: The force field containing the parameters to summarize.
            verbose: If true a progress bar will be shown on screen.

        Returns:
            A dictionary of the form ``coverage[handler_name][parameter_smirks] = count``
            which stores the number of molecules within this dataset that would be
            assigned to each parameter.
        """

        return smirnoff_coverage(self.molecules, force_field, verbose)

    def visualize(
        self,
        file_name: str,
        columns: int = 4,
        toolkit: Optional[Literal["openeye", "rdkit"]] = None,
    ) -> None:
        """
        Create a pdf file of the molecules with any torsions highlighted using either openeye or rdkit.

        Args:
            file_name:
                The name of the pdf file which will be produced.
            columns:
                The number of molecules per row.
            toolkit:
                The option to specify the backend toolkit used to produce the pdf file.
        """

        molecules = []

        for data in self.dataset.values():
            off_mol = data.get_off_molecule(include_conformers=False)
            off_mol.name = None

            # entries that target torsions pass them on so they can be highlighted
            if hasattr(data, "dihedrals"):
                off_mol.properties["dihedrals"] = data.dihedrals

            molecules.append(off_mol)

        molecules_to_pdf(molecules, file_name, columns, toolkit)

    def molecules_to_file(self, file_name: str, file_type: str) -> None:
        """
        Write the molecules to the requested file type.

        Args:
            file_name:
                The name of the file the molecules should be stored in.
            file_type:
                The file format that should be used to store the molecules.

        Raises:
            UnsupportedFiletypeError: If the requested file type is not supported.

        Important:
            The supported file types are:

            - SMI
            - INCHI
            - INCHIKEY
        """

        file_writers = {
            "smi": self._molecules_to_smiles,
            "inchi": self._molecules_to_inchi,
            "inchikey": self._molecules_to_inchikey,
        }

        # only the dispatch lookup decides support; errors raised while writing must not be masked
        writer = file_writers.get(file_type.lower())
        if writer is None:
            raise UnsupportedFiletypeError(
                f"The requested file type {file_type} is not supported, supported types are "
                f"{list(file_writers)}."
            )

        molecules = writer()
        with open(file_name, "w") as output:
            output.writelines(f"{molecule}\n" for molecule in molecules)

    def _molecules_to_smiles(self) -> List[str]:
        """
        Create a list of molecules canonical isomeric smiles.
        """

        return [
            data.attributes.canonical_isomeric_smiles for data in self.dataset.values()
        ]

    def _molecules_to_inchi(self) -> List[str]:
        """
        Create a list of the molecules standard InChI.
        """

        return [data.attributes.standard_inchi for data in self.dataset.values()]

    def _molecules_to_inchikey(self) -> List[str]:
        """
        Create a list of the molecules standard InChIKey.
        """

        return [data.attributes.inchi_key for data in self.dataset.values()]


# TODO: SinglepointDataset
class BasicDataset(_BaseDataset):
    """
    The general QCFractal dataset class which contains all of the molecules and information about them prior to
    submission.

    The class is a simple holder of the dataset and information about it and can do simple checks on the data before
    submitting it such as ensuring that the molecules have cmiles information and a unique index to be identified by.

    Note:
        The molecules in this dataset are all expanded so that different conformers are unique submissions.
    """

    type: Literal["DataSet"] = "DataSet"

    @classmethod
    def _entry_class(cls) -> Type[DatasetEntry]:
        """Return the entry model used to store molecules in this dataset."""
        return DatasetEntry

    def __add__(self, other: "BasicDataset") -> "BasicDataset":
        """
        Add two basic datasets together.

        Molecules only present in ``other`` are copied into the result; molecules present in both have the
        conformers of ``other`` remapped onto the existing entry and appended when not already present.
        Neither ``self`` nor ``other`` is modified.

        Raises:
            DatasetCombinationError: If the two datasets are of different types.
        """
        import copy

        # make sure the dataset types match
        if self.type != other.type:
            raise DatasetCombinationError(
                f"The datasets must be the same type, you can not add types {self.type} and {other.type}"
            )

        # create a new dataset
        new_dataset = copy.deepcopy(self)
        # update the elements in the dataset
        new_dataset.metadata.elements.update(other.metadata.elements)

        for index, entry in other.dataset.items():
            # search for the molecule
            entry_ids = new_dataset.get_molecule_entry(
                entry.get_off_molecule(include_conformers=False)
            )
            if not entry_ids:
                # copy the entry: later entries of the same molecule are merged into it in place,
                # which would otherwise mutate ``other``.
                # NOTE(review): an existing entry with the same index but a different molecule is overwritten.
                new_dataset.dataset[index] = copy.deepcopy(entry)
                continue

            current_entry = new_dataset.dataset[entry_ids[0]]
            _, atom_map = off.Molecule.are_isomorphic(
                entry.get_off_molecule(include_conformers=False),
                current_entry.get_off_molecule(include_conformers=False),
                return_atom_map=True,
            )
            # remap the molecule and all conformers into the atom order of the existing entry
            entry_mol = entry.get_off_molecule(include_conformers=True)
            mapped_mol = entry_mol.remap(mapping_dict=atom_map, current_to_new=True)
            for i in range(mapped_mol.n_conformers):
                mapped_schema = mapped_mol.to_qcschema(
                    conformer=i, extras=current_entry.initial_molecules[0].extras
                )
                if mapped_schema not in current_entry.initial_molecules:
                    current_entry.initial_molecules.append(mapped_schema)

        return new_dataset

    def _generate_collection(
        self, client: "PortalClient"
    ) -> ptl.singlepoint.SinglepointDataset:
        """Create a new singlepoint dataset on the server from this dataset's settings."""
        return client.add_dataset(
            dataset_type="singlepoint",
            name=self.dataset_name,
            tagline=self.dataset_tagline,
            tags=self.dataset_tags,
            description=self.description,
            provenance=self.provenance,
            default_tag=self.compute_tag,
            default_priority=self.priority,
            metadata=self.metadata.dict(),
        )

    def _get_specifications(self) -> Dict[str, QCSpecification]:
        """Needed for `submit` usage."""
        ret = {}
        for spec_name, spec in self.qc_specifications.items():
            ret[spec_name] = QCSpecification(
                driver=self.driver,
                method=spec.method,
                basis=spec.basis,
                keywords=spec.qc_keywords,
                program=spec.program,
                protocols={"wavefunction": spec.store_wavefunction},
            )

        return ret

    def _get_entries(self) -> List[SinglepointDatasetNewEntry]:
        """
        Build one singlepoint entry per conformer.

        Entries with several conformers get a ``-<n>`` suffix, counting up from any numeric tag already on the index;
        single-conformer entries keep their index as the name.
        """
        entries: List[SinglepointDatasetNewEntry] = []

        for entry_name, entry in self.dataset.items():
            if len(entry.initial_molecules) > 1:
                # check if the index has a number tag
                # if so, start from this tag
                index, tag = self._clean_index(index=entry_name)

                for j, molecule in enumerate(entry.initial_molecules):
                    name = index + f"-{tag + j}"
                    entries.append(
                        SinglepointDatasetNewEntry(
                            name=name, molecule=molecule, attributes=entry.attributes
                        )
                    )
            else:
                entries.append(
                    SinglepointDatasetNewEntry(
                        name=entry_name,
                        molecule=entry.initial_molecules[0],
                        attributes=entry.attributes,
                    )
                )

        return entries

    def to_tasks(self) -> Dict[str, List[AtomicInput]]:
        """
        Build a dictionary of single QCEngine tasks that correspond to this dataset organised by program name.
        The tasks can be passed directly to qcengine.compute.
        """
        data = defaultdict(list)
        for spec in self.qc_specifications.values():
            qc_model = spec.qc_model
            keywords = spec.qc_keywords
            protocols = {"wavefunction": spec.store_wavefunction.value}
            program = spec.program.lower()
            for entry_name, entry in self.dataset.items():
                # check if the index has a number tag
                # if so, start from this tag
                index, tag = self._clean_index(index=entry_name)
                for j, molecule in enumerate(entry.initial_molecules):
                    name = index + f"-{tag + j}"
                    data[program].append(
                        AtomicInput(
                            id=name,
                            molecule=molecule,
                            driver=self.driver,
                            model=qc_model,
                            keywords=keywords,
                            protocols=protocols,
                        )
                    )

        return data
class OptimizationDataset(BasicDataset):
    """
    An optimisation dataset class which handles submission of settings differently from the basic dataset, and creates
    optimization datasets in the public or local qcarchive instance.
    """

    type: Literal["OptimizationDataset"] = "OptimizationDataset"
    driver: SinglepointDriver = SinglepointDriver.deferred
    optimization_procedure: GeometricProcedure = Field(
        GeometricProcedure(),
        description="The optimization program and settings that should be used.",
    )
    protocols: OptimizationProtocols = Field(
        OptimizationProtocols(),
        description="Protocols regarding the manipulation of Optimization output data.",
    )
    dataset: Dict[str, OptimizationEntry] = {}

    @classmethod
    def _entry_class(cls) -> Type[OptimizationEntry]:
        return OptimizationEntry

    @validator("driver")
    def _check_driver(cls, driver):
        """Make sure that the driver is set to deferred only and not changed."""
        return SinglepointDriver.deferred

    def __add__(self, other: "OptimizationDataset") -> "OptimizationDataset":
        """
        Add two Optimization datasets together, if the constraints are different then the entries are considered different.

        An entry of ``other`` matches an entry of ``self`` when the molecules are isomorphic and the constraints are
        equal after remapping to the atom order of the existing entry. The conformers of a matching entry are merged
        into the existing entry. Any other entry is copied into the new dataset. If the same molecule is already
        present with different constraints, the copy is given a new index.

        Neither ``self`` nor ``other`` is modified; the result holds copies of the entries.

        Raises:
            DatasetCombinationError: If the two datasets are of different types.
        """
        import copy

        # make sure the dataset types match
        if self.type != other.type:
            raise DatasetCombinationError(
                f"The datasets must be the same type, you can not add types {self.type} and {other.type}"
            )

        # create a new dataset
        new_dataset = copy.deepcopy(self)
        # update the elements in the dataset
        new_dataset.metadata.elements.update(other.metadata.elements)

        for entry in other.dataset.values():
            # the conformer-free molecule is reused for the lookup and every isomorphism check
            entry_molecule = entry.get_off_molecule(include_conformers=False)
            entry_ids = new_dataset.get_molecule_entry(entry_molecule)

            if not entry_ids:
                # if no other molecules just add it (a copy, so the result does not alias ``other``)
                new_dataset.dataset[entry.index] = copy.deepcopy(entry)
                continue

            records = 0
            for mol_id in entry_ids:
                current_entry = new_dataset.dataset[mol_id]
                # for each entry count the number of inputs incase we need a new entry
                records += len(current_entry.initial_molecules)
                _, atom_map = off.Molecule.are_isomorphic(
                    entry_molecule,
                    current_entry.get_off_molecule(include_conformers=False),
                    return_atom_map=True,
                )
                # compare constraints in the atom ordering of the existing entry
                entry_constraints = self._remap_constraints(entry.constraints, atom_map)
                if current_entry.constraints == entry_constraints:
                    self._merge_conformers(entry, current_entry, atom_map)
                    break
            else:
                # we did not break so add the entry with a new unique index.
                # Work on a copy: renaming the original would silently modify ``other``.
                core, tag = self._clean_index(entry.index)
                new_entry = copy.deepcopy(entry)
                new_entry.index = core + f"-{tag + records}"
                new_dataset.dataset[new_entry.index] = new_entry

        return new_dataset

    @staticmethod
    def _remap_constraints(
        constraints: Constraints, atom_map: Dict[int, int]
    ) -> Constraints:
        """Return a new ``Constraints`` with every constrained atom index passed through ``atom_map``."""
        from openff.qcsubmit.utils import remap_list

        remapped = Constraints()
        for constraint in constraints.freeze:
            remapped.add_freeze_constraint(
                constraint.type, remap_list(constraint.indices, atom_map)
            )
        for constraint in constraints.set:
            remapped.add_set_constraint(
                constraint.type,
                remap_list(constraint.indices, atom_map),
                constraint.value,
            )
        return remapped

    @staticmethod
    def _merge_conformers(
        entry: DatasetEntry, current_entry: DatasetEntry, atom_map: Dict[int, int]
    ) -> None:
        """
        Append the conformers of ``entry`` to ``current_entry.initial_molecules``.

        The conformers are first remapped into the atom order of ``current_entry``. They reuse the ``extras`` of
        ``current_entry``'s first input molecule, and schemas already present are skipped. Only ``current_entry`` is
        modified.
        """
        off_mol = entry.get_off_molecule(include_conformers=True)
        mapped_mol = off_mol.remap(mapping_dict=atom_map, current_to_new=True)
        extras = current_entry.initial_molecules[0].extras
        for i in range(mapped_mol.n_conformers):
            mapped_schema = mapped_mol.to_qcschema(conformer=i, extras=extras)
            if mapped_schema not in current_entry.initial_molecules:
                current_entry.initial_molecules.append(mapped_schema)

    def _generate_collection(
        self, client: "PortalClient"
    ) -> ptl.optimization.OptimizationDataset:
        """Create an "optimization" dataset on the server from this dataset's metadata and default compute settings."""
        return client.add_dataset(
            dataset_type="optimization",
            name=self.dataset_name,
            tagline=self.dataset_tagline,
            tags=self.dataset_tags,
            description=self.description,
            provenance=self.provenance,
            default_tag=self.compute_tag,
            default_priority=self.priority,
            metadata=self.metadata.dict(),
        )

    def _get_specifications(self) -> Dict[str, OptimizationSpecification]:
        """
        Build one ``OptimizationSpecification`` per QC specification, keyed by the specification name.

        Each one wraps a ``QCSpecification`` and carries the keywords of ``self.optimization_procedure`` and
        ``self.protocols``.
        """
        opt_kw = self.optimization_procedure.get_optimzation_keywords()

        ret = {}
        for spec_name, spec in self.qc_specifications.items():
            qc_spec = QCSpecification(
                driver=self.driver,
                method=spec.method,
                basis=spec.basis,
                keywords=spec.qc_keywords,
                program=spec.program,
                protocols={"wavefunction": spec.store_wavefunction},
            )
            ret[spec_name] = OptimizationSpecification(
                program=self.optimization_procedure.program,
                qc_specification=qc_spec,
                keywords=opt_kw,
                protocols=self.protocols,
            )
        return ret

    def _get_entries(self) -> List[OptimizationDatasetNewEntry]:
        """
        Convert the dataset entries into qcportal ``OptimizationDatasetNewEntry`` objects.

        The entry constraints and keywords are sent as additional optimization keywords. An entry with several input
        molecules is split into one new entry per molecule, named ``<core>-<tag + j>`` using ``self._clean_index``.
        """
        entries: List[OptimizationDatasetNewEntry] = []
        for entry_name, entry in self.dataset.items():
            # TODO this probably needs even more keywords
            opt_kw = dict(constraints=entry.constraints)
            opt_kw.update(entry.keywords)

            if len(entry.initial_molecules) > 1:
                # check if the index has a number tag
                # if so, start from this tag
                index, tag = self._clean_index(index=entry_name)
                for j, molecule in enumerate(entry.initial_molecules):
                    name = index + f"-{tag + j}"
                    entries.append(
                        OptimizationDatasetNewEntry(
                            name=name,
                            initial_molecule=molecule,
                            additional_keywords=opt_kw,
                            attributes=entry.attributes,
                        )
                    )
            else:
                entries.append(
                    OptimizationDatasetNewEntry(
                        name=entry_name,
                        initial_molecule=entry.initial_molecules[0],
                        additional_keywords=opt_kw,
                        attributes=entry.attributes,
                    )
                )
        return entries

    def to_tasks(self) -> Dict[str, List[OptimizationInput]]:
        """
        Build a list of QCEngine optimisation inputs organised by the optimisation engine which should be used to run the task.

        One ``OptimizationInput`` is created per (QC specification, entry, input molecule) combination. All tasks are
        grouped under the lower-cased optimization program name.
        """
        data = defaultdict(list)
        opt_program = self.optimization_procedure.program.lower()

        for spec in self.qc_specifications.values():
            qc_model = spec.qc_model
            qc_keywords = spec.qc_keywords
            qc_spec = QCInputSpecification(
                # TODO: self.driver is now set to "deferred" - is it safe to put "gradient" here?
                driver="gradient",
                model=qc_model,
                keywords=qc_keywords,
            )
            opt_spec = self.optimization_procedure.dict(exclude={"program"})
            # this needs to be the single point calculation program
            opt_spec["program"] = spec.program.lower()

            for index, entry in self.dataset.items():
                index, tag = self._clean_index(index=index)
                for j, molecule in enumerate(entry.initial_molecules):
                    name = index + f"-{tag + j}"
                    data[opt_program].append(
                        OptimizationInput(
                            id=name,
                            keywords=opt_spec,
                            input_specification=qc_spec,
                            initial_molecule=molecule,
                            protocols=self.protocols,
                        )
                    )

        return data


class TorsiondriveDataset(OptimizationDataset):
    """
    A torsiondrive dataset class which handles submission of settings differently from the basic dataset, and creates
    torsiondrive datasets in the public or local qcarchive instance.

    Important:
        The dihedral_ranges for the whole dataset can be defined here or if different scan ranges are required on a case
        by case basis they can be defined for each torsion in a molecule separately in the keywords of the torsiondrive entry.
    """

    dataset: Dict[str, TorsionDriveEntry] = {}
    type: Literal["TorsionDriveDataset"] = "TorsionDriveDataset"
    optimization_procedure: GeometricProcedure = GeometricProcedure.parse_obj(
        {"enforce": 0.1, "reset": True, "qccnv": True, "epsilon": 0.0}
    )
    grid_spacing: List[int] = Field(
        [15],
        description="The grid spcaing that should be used for all torsiondrives, this can be overwriten on a per entry basis.",
    )
    energy_upper_limit: float = Field(
        0.05,
        description="The upper energy limit to spawn new optimizations in the torsiondrive.",
    )
    dihedral_ranges: Optional[List[Tuple[int, int]]] = Field(
        None,
        description="The scan range that should be used for each torsiondrive, this can be overwriten on a per entry basis.",
    )
    energy_decrease_thresh: Optional[float] = Field(
        None,
        description="The energy lower threshold to trigger new optimizations in the torsiondrive.",
    )

    @classmethod
    def _entry_class(cls) -> Type[TorsionDriveEntry]:
        return TorsionDriveEntry

    def __add__(self, other: "TorsiondriveDataset") -> "TorsiondriveDataset":
        """
        Add two TorsiondriveDatasets together, if the central bond in the dihedral is the same the entries are considered the same.

        Two entries are the same when the molecules are isomorphic and the sets of central bonds of their dihedrals
        are equal, compared in both directions in the atom order of the existing entry. The conformers of a matching
        entry are merged into the existing entry; any other entry is copied in under its own index.

        Neither ``self`` nor ``other`` is modified.

        Raises:
            DatasetCombinationError: If the two datasets are of different types.
        """
        import copy

        # make sure the dataset types match
        if self.type != other.type:
            raise DatasetCombinationError(
                f"The datasets must be the same type, you can not add types {self.type} and {other.type}"
            )

        # create a new dataset
        new_dataset = copy.deepcopy(self)
        # update the elements in the dataset
        new_dataset.metadata.elements.update(other.metadata.elements)

        for index, entry in other.dataset.items():
            # search for the molecule
            entry_molecule = entry.get_off_molecule(include_conformers=False)
            entry_ids = new_dataset.get_molecule_entry(entry_molecule)
            for mol_id in entry_ids:
                current_entry = new_dataset.dataset[mol_id]
                _, atom_map = off.Molecule.are_isomorphic(
                    entry_molecule,
                    current_entry.get_off_molecule(include_conformers=False),
                    return_atom_map=True,
                )
                # Equality, not a one-sided difference: a subset test would wrongly merge (and lose)
                # an entry driving extra torsions, and would make ``a + b`` differ from ``b + a``.
                current_bonds = self._central_bonds(current_entry.dihedrals)
                other_bonds = self._central_bonds(entry.dihedrals, atom_map)
                if current_bonds == other_bonds:
                    # the entry is already there so add new conformers and skip
                    self._merge_conformers(entry, current_entry, atom_map)
                    break
            else:
                # none of the entries matched so add a copy of it
                new_dataset.dataset[index] = copy.deepcopy(entry)

        return new_dataset

    @staticmethod
    def _central_bonds(
        dihedrals, atom_map: Optional[Dict[int, int]] = None
    ) -> Set[Tuple[int, int]]:
        """
        Return the central bond (atoms 1 and 2) of each dihedral in both directions.

        If ``atom_map`` is given, the atom indices are first translated through it.
        """
        bonds: Set[Tuple[int, int]] = set()
        for dihedral in dihedrals:
            central = tuple(dihedral[1:3])
            if atom_map is not None:
                central = tuple(atom_map[i] for i in central)
            bonds.add(central)
            bonds.add(central[::-1])
        return bonds

    @property
    def n_records(self) -> int:
        """
        Calculate the number of records that will be submitted.
        """
        return len(self.dataset)

    def _generate_collection(
        self, client: "PortalClient"
    ) -> ptl.torsiondrive.TorsiondriveDataset:
        """Create a "torsiondrive" dataset on the server from this dataset's metadata and default compute settings."""
        return client.add_dataset(
            dataset_type="torsiondrive",
            name=self.dataset_name,
            tagline=self.dataset_tagline,
            tags=self.dataset_tags,
            description=self.description,
            provenance=self.provenance,
            default_tag=self.compute_tag,
            default_priority=self.priority,
            metadata=self.metadata.dict(),
        )

    def _get_specifications(self) -> Dict[str, TorsiondriveSpecification]:
        """
        Build one ``TorsiondriveSpecification`` per QC specification, keyed by the specification name.

        The dataset-wide torsiondrive settings become the specification keywords. The nested optimization
        specification carries the keywords of ``self.optimization_procedure`` and ``self.protocols``.
        """
        td_kw = dict(
            grid_spacing=self.grid_spacing,
            dihedral_ranges=self.dihedral_ranges,
            energy_decrease_thresh=self.energy_decrease_thresh,
            energy_upper_limit=self.energy_upper_limit,
        )
        # Without these, the optimizer settings of ``optimization_procedure`` (including this class's
        # torsiondrive defaults) would never reach the server; mirrors OptimizationDataset.
        opt_kw = self.optimization_procedure.get_optimzation_keywords()

        ret = {}
        for spec_name, spec in self.qc_specifications.items():
            qc_spec = QCSpecification(
                driver=self.driver,
                method=spec.method,
                basis=spec.basis,
                keywords=spec.qc_keywords,
                program=spec.program,
                protocols={"wavefunction": spec.store_wavefunction},
            )
            opt_spec = OptimizationSpecification(
                program=self.optimization_procedure.program,
                qc_specification=qc_spec,
                keywords=opt_kw,
                protocols=self.protocols,
            )
            ret[spec_name] = TorsiondriveSpecification(
                optimization_specification=opt_spec,
                keywords=td_kw,
            )
        return ret

    def _get_entries(self) -> List[TorsiondriveDatasetNewEntry]:
        """
        Convert the dataset entries into qcportal ``TorsiondriveDatasetNewEntry`` objects.

        The dataset-wide torsiondrive settings and the entry's dihedrals are overridden by any non-default entry
        keywords. The entry constraints are sent as additional optimization keywords.
        """
        entries: List[TorsiondriveDatasetNewEntry] = []
        for entry_name, entry in self.dataset.items():
            td_keywords = dict(
                grid_spacing=self.grid_spacing,
                energy_upper_limit=self.energy_upper_limit,
                energy_decrease_thresh=self.energy_decrease_thresh,
                dihedral_ranges=self.dihedral_ranges,
                dihedrals=entry.dihedrals,
            )
            td_keywords.update(entry.keywords.dict(exclude_defaults=True))
            opt_keywords = dict(constraints=entry.constraints)
            entries.append(
                TorsiondriveDatasetNewEntry(
                    name=entry_name,
                    initial_molecules=entry.initial_molecules,
                    additional_keywords=td_keywords,
                    additional_optimization_keywords=opt_keywords,
                    attributes=entry.attributes,
                )
            )
        return entries

    def to_tasks(self) -> Dict[str, List[OptimizationInput]]:
        """Build a list of QCEngine procedure tasks which correspond to this dataset."""
        raise NotImplementedError()