Source code for pharmaforge.labeling.relabeler

import pickle
import time
from pathlib import Path

import numpy as np

from pharmaforge.queries.query import Query


class Relabeler:
    """
    A class to handle relabeling of data in the database.

    This class is cognizant of the fact that one might want to relabel data in
    a database without having the database available at the time of relabeling.

    Attributes
    ----------
    collections : list
        Names of the collections selected for relabeling.
    input_data : dict
        Query results for each collection, keyed by collection name.
    charge_data : dict
        Molecular charges for each collection, keyed by molecule_id.
    relabeled_data : dict
        Relabeled energies and forces for each collection.
    level_of_theory : str or None
        The level of theory used for relabeling.
    """

    def __init__(self):
        self.collections = []
        self.input_data = {}
        self.charge_data = {}
        self.relabeled_data = {}
        self.level_of_theory = None
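A minimal sketch of the single-shot workflow, assuming a pymongo Database handle and an interface object exposing level_of_theory, ObtainQueryData, and calculate (names taken from this module). The connection URI, the "mydb" and "water_clusters" names, and SomeInterface are invented for illustration:

    # Hypothetical usage sketch; "mydb", "water_clusters", and SomeInterface
    # are placeholders, not names defined by this module.
    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")
    database = client["mydb"]

    relabeler = Relabeler()
    # Query strings follow the "<field> <op> <value>" form used by the
    # default query ("nmols gt 0") in _select_data.
    relabeler.select_data(database, query="nmols gt 0", select_collection="water_clusters")
    relabeler.relabel(interface=SomeInterface())  # interface must define level_of_theory
    relabeler.to_database(database)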
    def select_data(self, database, query=None, select_collection=None):
        """
        Select data that must be relabeled. This can either be a query or an
        entire collection of data.

        Parameters
        ----------
        database : pymongo.database.Database
            The database to select data from.
        query : str, optional
            The query to select data from the database. If not provided, the
            entire collection will be selected.
        select_collection : str, optional
            The collection to select data from. If not provided, every
            collection in the database will be selected.
        """
        self.collections = database.list_collection_names()
        if select_collection is None:
            for collection in self.collections:
                self.input_data[collection] = self._select_data(database, query, collection)
                self.charge_data[collection] = {}
                for result in self.input_data[collection]:
                    charge = result.get("charge")
                    self.charge_data[collection][result.get("molecule_id")] = charge if charge is not None else 0
        else:
            self.input_data[select_collection] = self._select_data(database, query, select_collection)
            # Record charges for the single selected collection as well.
            self.charge_data[select_collection] = {}
            for result in self.input_data[select_collection]:
                charge = result.get("charge")
                self.charge_data[select_collection][result.get("molecule_id")] = charge if charge is not None else 0
            self.collections = [select_collection]
        if all(len(data) == 0 for data in self.input_data.values()):
            raise ValueError("No data selected. Please check your query or collection name.")
    def _select_data(self, database, query, collection=None):
        """
        Select data from the database by query.

        Parameters
        ----------
        database : pymongo.database.Database
            The database to select data from.
        query : str
            The query to select data from the database.
        collection : str
            The collection to select data from.
        """
        if collection is None:
            raise ValueError("Collection name must be provided.")
        elif collection not in self.collections:
            raise ValueError(f"Collection {collection} not found in the database.")
        if query is None:
            query = "nmols gt 0"  # This will select all data in the collection.
        q = Query(query)
        return q.apply(database, collection)
    def relabel(self, interface=None, restart=False, **kwargs):
        """
        Relabels the data in the entire selection in one go. Ideal for systems
        with a small number of processors and for small datasets.

        Parameters
        ----------
        interface : object, optional
            The interface to use for relabeling. Must define a
            level_of_theory attribute.
        restart : bool, optional
            If True, reuse an existing checkpoint file for a collection
            instead of relabeling it again. Default is False.
        **kwargs : dict, optional
            Additional arguments to pass to the interface.
        """
        if interface is None:
            raise ValueError("No interface provided. Please provide an interface for relabeling.")
        if not hasattr(interface, "level_of_theory"):
            raise ValueError("Interface does not have a level_of_theory attribute. Please provide a valid interface.")
        self.level_of_theory = interface.level_of_theory
        for collection in self.input_data:
            checkpoint = Path(f"{collection}_{self.level_of_theory}_checkpoint.pckl")
            if restart and checkpoint.exists():
                print(f"Checkpoint for {collection} already exists. Skipping relabeling.")
                with open(checkpoint, "rb") as f:
                    self.relabeled_data[collection] = pickle.load(f)
                continue
            interface.ObtainQueryData(self.input_data[collection])
            e, f = interface.calculate(**kwargs)
            self.relabeled_data[collection] = {"energy": e, "forces": f}
            # Save a checkpoint of the relabeled data
            self._save_checkpoint(collection, self.relabeled_data[collection])
        return self.relabeled_data
    def _save_checkpoint(self, collection, data):
        """
        Saves a checkpoint of the relabeled data to a file. This is useful for
        large datasets, where one might want to save the data in chunks.

        Parameters
        ----------
        collection : str
            The name of the collection the data belongs to.
        data : dict
            The relabeled data to save.
        """
        filename = Path(f"{collection}_{self.level_of_theory}_checkpoint.pckl")
        with open(filename, "wb") as f:
            pickle.dump(data, f)
        print(f"Checkpoint for {collection} saved to {filename}")
    def subdivided_relabel(self, nchunks=1):
        """
        Sets up the relabeling process to be subdivided into smaller chunks
        with an approximately equal number of configurations per chunk. This
        is ideal for systems with a large number of processors, or for HPC
        environments where one might want to run many tasks at once.

        Parameters
        ----------
        nchunks : int, optional
            The number of chunks to subdivide the relabeling process into.
            Default is 1.
        """
        # Flatten the input data into a single list of (collection, molecule, nconfigs) tuples
        all_molecules = []
        for collection, molecules in self.input_data.items():
            print(f"Collection: {collection}, Number of molecules: {len(molecules)}")
            for mol in molecules:
                nconfigs = mol.get("nconfigs", 0)
                all_molecules.append((collection, mol, nconfigs))
        # Sort the molecules by number of configurations, largest first
        all_molecules.sort(key=lambda x: x[2], reverse=True)
        # Greedily assign each molecule to the currently least-loaded chunk so
        # the configuration counts stay balanced
        chunks = [[] for _ in range(nchunks)]
        chunk_config_counts = [0] * nchunks
        for collection, mol, nconfigs in all_molecules:
            min_chunk_index = chunk_config_counts.index(min(chunk_config_counts))
            chunks[min_chunk_index].append((collection, mol))
            chunk_config_counts[min_chunk_index] += nconfigs
        # Save each chunk to a file
        for i, chunk in enumerate(chunks):
            filename = Path(f"chunk_{i}.pckl")
            with open(filename, "wb") as f:
                pickle.dump(chunk, f)
            print(f"Chunk {i + 1} saved to {filename} with {len(chunk)} molecules and {chunk_config_counts[i]} configurations total.")
        with open("input_data.pckl", "wb") as f:
            pickle.dump(self.input_data, f)
        print("Input data saved to input_data.pckl")
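The distribution above is the classic longest-processing-time-first greedy heuristic: sort jobs by size, then always place the next one in the lightest bin. A toy demonstration of the balancing behavior, with invented configuration counts:

    # Toy demonstration of the greedy balancing used by subdivided_relabel;
    # the configuration counts here are invented for illustration.
    nconfigs_list = [8, 5, 4, 2, 1]
    nchunks = 2
    loads = [0] * nchunks
    assignment = [[] for _ in range(nchunks)]
    for n in sorted(nconfigs_list, reverse=True):
        i = loads.index(min(loads))  # lightest chunk so far
        assignment[i].append(n)
        loads[i] += n
    print(assignment, loads)  # [[8, 2], [5, 4, 1]] with loads [10, 10]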
    @staticmethod
    def relabel_chunk(chunk, interface=None, **kwargs):
        """
        Relabels a chunk of data. This is useful for large datasets, where one
        might want to relabel the data in chunks.

        Parameters
        ----------
        chunk : str or pathlib.Path
            The path to the chunk file to relabel.
        interface : object, optional
            The interface to use for relabeling. Must define a
            level_of_theory attribute.
        **kwargs : dict, optional
            Additional arguments to pass to the interface.
        """
        if interface is None:
            raise ValueError("No interface provided. Please provide an interface for relabeling.")
        if not hasattr(interface, "level_of_theory"):
            raise ValueError("Interface does not have a level_of_theory attribute. Please provide a valid interface.")
        chunk = Path(chunk)
        if not chunk.exists():
            raise ValueError(f"Chunk file {chunk} does not exist.")
        with open(chunk, "rb") as f:
            chunk_data = pickle.load(f)
        chunk_results = {}
        total_molecules = len(chunk_data)
        start_time = time.time()
        failed_molecules = 0
        for idx, (collection, mol) in enumerate(chunk_data, start=1):
            interface.ObtainQueryData([mol])
            try:
                e, f = interface.calculate(**kwargs)
            except Exception as exc:
                print(f"Error calculating for molecule {mol.get('molecule_id')} in {collection}: {exc}")
                failed_molecules += 1
                continue
            mol_id = mol.get("molecule_id")
            if collection not in chunk_results:
                chunk_results[collection] = {"energy": {}, "forces": {}}
            chunk_results[collection]["energy"][mol_id] = e[mol_id]
            chunk_results[collection]["forces"][mol_id] = f[mol_id]
            # Estimate time remaining
            elapsed_time = time.time() - start_time
            avg_time_per_molecule = elapsed_time / idx
            remaining_time = avg_time_per_molecule * (total_molecules - idx)
            print(f"Processed {idx}/{total_molecules} molecules. Estimated time remaining: {remaining_time:.2f} seconds.")
        # Save the relabeled data to a file
        output_file = Path(f"{chunk.stem}_relabel_{interface.level_of_theory}.pckl")
        with open(output_file, "wb") as f:
            pickle.dump(chunk_results, f)
        print(f"Chunk results saved to {output_file}")
        if failed_molecules > 0:
            print(f"Warning: {failed_molecules} molecules failed to relabel.")
            with open("failed_molecules.txt", "a") as f:
                f.write(f"Failed molecules in {chunk}: {failed_molecules}\n")
    def _read_input_data(self, input_pickle):
        """
        Reads in the input data from a pickle file and stores it in
        self.input_data. This is useful for large datasets, where one might
        want to relabel the data in chunks.

        Parameters
        ----------
        input_pickle : str
            The path to the input data pickle file.
        """
        if not Path(input_pickle).exists():
            raise ValueError(f"Input data file {input_pickle} does not exist.")
        with open(input_pickle, "rb") as f:
            self.input_data = pickle.load(f)
    def combine_subdivided(self, result_files):
        """
        Combines the results of the subdivided relabeling process into a
        single result in the same format as the relabel output.

        Parameters
        ----------
        result_files : list of str
            The paths to the per-chunk result files to combine.

        Returns
        -------
        dict
            The combined relabeled data, keyed by collection.
        """
        combined_results = {}
        for result_file in result_files:
            if not Path(result_file).exists():
                raise ValueError(f"Result file {result_file} does not exist.")
            with open(result_file, "rb") as f:
                chunk_results = pickle.load(f)
            # Merge the results into the combined dataset
            for collection, results in chunk_results.items():
                if collection not in combined_results:
                    combined_results[collection] = {"energy": {}, "forces": {}}
                combined_results[collection]["energy"].update(results["energy"])
                combined_results[collection]["forces"].update(results["forces"])
            print(f"Combined results from {result_file}")
        self.relabeled_data = combined_results
        print("Combined results stored in relabeled_data.")
        self._read_input_data("input_data.pckl")
        return self.relabeled_data
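Putting the pieces together, a sketch of the chunked HPC workflow, reusing the database handle and the SomeInterface placeholder from the sketch above. The chunk loop would typically be submitted as independent jobs rather than run serially:

    # Hypothetical chunked workflow; SomeInterface is a placeholder for a
    # real interface class with a level_of_theory attribute.
    from glob import glob

    relabeler = Relabeler()
    relabeler.select_data(database)          # query the full database
    relabeler.subdivided_relabel(nchunks=4)  # writes chunk_0.pckl .. chunk_3.pckl

    # Typically run as separate HPC jobs, one per chunk:
    for i in range(4):
        Relabeler.relabel_chunk(f"chunk_{i}.pckl", interface=SomeInterface())

    # Back on a single process, merge the per-chunk outputs:
    result_files = glob("chunk_*_relabel_*.pckl")
    relabeler.combine_subdivided(result_files)
    relabeler.to_database(database, level_of_theory=SomeInterface().level_of_theory)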
    def to_database(self, database, level_of_theory=None):
        """
        Saves the relabeled data back to the database.

        Parameters
        ----------
        database : pymongo.database.Database
            The database to write the relabeled data to.
        level_of_theory : str, optional
            The level of theory to store the data under. Required if it was
            not already set during relabeling.
        """
        if self.level_of_theory is None:
            if level_of_theory is None:
                raise ValueError("No level of theory provided. Please provide a level of theory.")
            self.level_of_theory = level_of_theory
        for collection in self.relabeled_data:
            collection_reported = False
            for mol_id in self.relabeled_data[collection]["energy"]:
                # Grab the molecule from the database, and make sure it exists.
                db_mol = database[collection].find_one({"molecule_id": mol_id})
                if db_mol is None:
                    raise ValueError(f"Could not find molecule {mol_id} in the database.")
                # Check if the theory level already exists in the database.
                # If it does, don't make changes.
                existing_theory = db_mol.get("theorylevels", {}).keys()
                if not collection_reported:
                    print(f"Existing theory levels: {existing_theory}")
                if self.level_of_theory in existing_theory:
                    print(f"Skipping {mol_id} as it already has a theory level of {self.level_of_theory}")
                    continue
                if not collection_reported:
                    print(f"Adding theory level {self.level_of_theory} to collection {collection}")
                    collection_reported = True
                # Add the new theory level to the database, converting numpy
                # arrays to plain lists so they can be stored in MongoDB.
                new_theory_data = {
                    "energies": [
                        energy.tolist()[0] if isinstance(energy, np.ndarray) else energy
                        for energy in self.relabeled_data[collection]["energy"][mol_id]
                    ],
                    "forces": [
                        force.flatten().tolist() if isinstance(force, np.ndarray) else force
                        for force in self.relabeled_data[collection]["forces"][mol_id]
                    ],
                }
                try:
                    database[collection].update_one(
                        {"molecule_id": mol_id},
                        {"$set": {f"theorylevels.{self.level_of_theory}": new_theory_data}},
                    )
                except Exception as exc:
                    print(f"Error updating molecule {mol_id} in collection {collection}: {exc}")
                    print(new_theory_data["energies"])
                    raise
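For reference, a sketch of the document shape this produces. Only molecule_id and theorylevels are field names used by this module; the "wb97x-d3" label and all numeric values are invented for illustration:

    # Illustrative MongoDB document after to_database; values are made up.
    example_doc = {
        "molecule_id": "mol-0001",                 # hypothetical identifier
        "theorylevels": {
            "wb97x-d3": {                          # invented level-of-theory label
                "energies": [-76.3821, -76.3790],  # one energy per configuration
                "forces": [                        # per configuration, flattened (3 * natoms values)
                    [0.01, -0.02, 0.00, 0.00, 0.01, -0.01, -0.01, 0.01, 0.01],
                    [0.00, 0.00, 0.01, -0.01, 0.00, 0.00, 0.01, -0.01, -0.01],
                ],
            },
        },
    }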
    @staticmethod
    def chunk_calc(nchunk):
        """
        Reads in a chunk input file, runs the relabeling process on that
        chunk, and saves the results to a new file.

        Parameters
        ----------
        nchunk : int
            The chunk number to read in.

        Returns
        -------
        None
        """
        # Not yet implemented; see relabel_chunk for the per-chunk workflow.
        pass