import pickle
import time
from pathlib import Path

import numpy as np

from pharmaforge.queries.query import Query
class Relabeler:
"""
A class to handle relabeling of data in the database. This class is cognizant of the fact that one
might want to relabel data in a database without having to have the database available at the time of relabeling.
Attributes
----------
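
Examples
--------
A minimal sketch of the intended workflow, assuming ``db`` is a ``pymongo.database.Database`` and ``iface`` is an interface object exposing ``level_of_theory``, ``ObtainQueryData``, and ``calculate`` (both names are placeholders):

>>> relabeler = Relabeler()
>>> relabeler.select_data(db, query="nmols gt 0")
>>> relabeler.relabel(interface=iface)
>>> relabeler.to_database(db)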
"""
def __init__(self):
self.collections = {}
self.input_data = {}
self.charge_data = {}
self.relabeled_data = {}
self.level_of_theory = None
def select_data(self, database, query=None, select_collection=None):
"""
Select data that must be relabeled. This can either be a query or an entire collection of data.
Parameters
----------
database : pymongo.database.Database
The database to select data from.
query : str, optional
The query to select data from the database. If not provided, the entire collection will be selected.
select_collection : str, optional
The collection to select data from. If not provided, all collections in the database are selected.
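
Examples
--------
A minimal sketch, assuming ``db`` is a ``pymongo.database.Database`` (the collection name ``"qm9"`` is a placeholder):

>>> relabeler = Relabeler()
>>> relabeler.select_data(db, query="nmols gt 0", select_collection="qm9")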
"""
self.collections = database.list_collection_names()
selected = self.collections if select_collection is None else [select_collection]
for collection in selected:
self.input_data[collection] = self._select_data(database, query, collection)
# Record the molecular charge for each molecule, defaulting to neutral (0).
self.charge_data[collection] = {}
for result in self.input_data[collection]:
charge = result.get("charge")
self.charge_data[collection][result.get("molecule_id")] = charge if charge is not None else 0
self.collections = selected
if all(len(data) == 0 for data in self.input_data.values()):
raise ValueError("No data selected. Please check your query or collection name.")
def _select_data(self, database, query, collection=None):
"""
Select data from the database by query.
Parameters
----------
database : pymongo.database.Database
The database to select data from.
query : str, optional
The query string. If None, a query matching all molecules ("nmols gt 0") is used.
collection : str
The collection to select data from. Must exist in the database.

Returns
-------
The results of applying ``Query(query)`` to the given collection.
"""
if collection is None:
raise ValueError("Collection name must be provided.")
elif collection not in self.collections:
raise ValueError(f"Collection {collection} not found in the database.")
if query is None:
query = "nmols gt 0" # This will select all data in the collection.
q = Query(query)
return q.apply(database, collection)
def relabel(self, interface=None, restart=False, **kwargs):
"""
Relabels the entire selection in one go. Ideal for systems with a small number of processors and for small datasets.

Parameters
----------
interface : object
The interface to use for relabeling. Must expose a ``level_of_theory`` attribute and a ``calculate`` method; a ValueError is raised if not provided.
restart : bool, optional
If True, reload existing checkpoint files instead of recomputing them. Default is False.
**kwargs : dict, optional
Additional arguments to pass to the interface.

Returns
-------
dict
The relabeled data, keyed by collection name.
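
Examples
--------
A minimal sketch, assuming data has already been selected with ``select_data`` and ``iface`` is a configured interface object (placeholder name):

>>> results = relabeler.relabel(interface=iface, restart=True)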
"""
if interface is None:
raise ValueError("No interface provided. Please provide an interface for relabeling.")
if not hasattr(interface, "level_of_theory"):
raise ValueError("Interface does not have a level_of_theory attribute. Please provide a valid interface.")
self.level_of_theory = interface.level_of_theory
for collection in self.input_data:
checkpoint = Path(f"{collection}_{self.level_of_theory}_checkpoint.pckl")
if restart and checkpoint.exists():
print(f"Checkpoint for {collection} already exists. Skipping relabeling.")
# Load the checkpoint with a context manager so the file handle is closed.
with open(checkpoint, "rb") as f:
self.relabeled_data[collection] = pickle.load(f)
continue
interface.ObtainQueryData(self.input_data[collection])
e, f = interface.calculate(**kwargs)
self.relabeled_data[collection] = {"energy": e, "forces": f}
# Save a checkpoint of the relabeled data
self._save_checkpoint(collection, self.relabeled_data[collection])
return self.relabeled_data
def _save_checkpoint(self, collection, data):
"""
Saves a checkpoint of the relabeled data to a file. This is useful for large datasets, where one might want to save the data in chunks.
Parameters
----------
collection : str
The name of the collection the data belongs to.
data : dict
The relabeled data to save.
"""
filename = Path(f"{collection}_{self.level_of_theory}_checkpoint.pckl")
with open(filename, "wb") as f:
pickle.dump(data, f)
print(f"Checkpoint for {collection} saved to {filename}")
def subdivided_relabel(self, nchunks=1):
"""
Splits the selected data into chunks containing a roughly equal number of configurations each. This is ideal for systems with a large number of processors, or for HPC environments where one might want to run many tasks at once.
Parameters
----------
nchunks : int, optional
The number of chunks to subdivide the relabeling process into. Default is 1.
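
Examples
--------
A minimal sketch, assuming data has already been selected with ``select_data``; this writes one ``chunk_<i>.pckl`` file per chunk plus ``input_data.pckl``:

>>> relabeler.subdivided_relabel(nchunks=4)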
"""
# Flatten the input data into a single list of (collection, molecule, nconfigs) tuples.
all_molecules = []
for collection, molecules in self.input_data.items():
print(f"Collection: {collection}, Number of molecules: {len(molecules)}")
for mol in molecules:
# Guard against a missing 'nconfigs' field so the sort below cannot fail on None.
nconfigs = mol.get('nconfigs') or 0
all_molecules.append((collection, mol, nconfigs))
# Sort molecules by configuration count, largest first, then greedily assign
# each molecule to the least-loaded chunk (longest-processing-time heuristic).
all_molecules.sort(key=lambda x: x[2], reverse=True)
chunks = [[] for _ in range(nchunks)]
chunk_config_counts = [0] * nchunks
for collection, mol, nconfigs in all_molecules:
min_chunk_index = chunk_config_counts.index(min(chunk_config_counts))
chunks[min_chunk_index].append((collection, mol))
chunk_config_counts[min_chunk_index] += nconfigs
# Save each chunk to a file
for i, chunk in enumerate(chunks):
filename = Path(f"chunk_{i}.pckl")
with open(filename, "wb") as f:
pickle.dump(chunk, f)
print(f"Chunk {i+1} saved to {filename} with {len(chunk)} molecules and {chunk_config_counts[i]} configurations total.")
with open("input_data.pckl", "wb") as f:
pickle.dump(self.input_data, f)
print("Input data saved to input_data.pckl")
@staticmethod
def relabel_chunk(chunk, interface=None, **kwargs):
"""
Relabels a chunk of data. This is useful for large datasets, where one might want to relabel the data in chunks.
Parameters
----------
chunk : str or pathlib.Path
Path to the chunk pickle file to relabel, as produced by ``subdivided_relabel``.
interface : object
The interface to use for relabeling. Must expose a ``level_of_theory`` attribute and a ``calculate`` method; a ValueError is raised if not provided.
**kwargs : dict, optional
Additional arguments to pass to the interface.
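
Examples
--------
A minimal sketch, typically run as one independent task per chunk file, assuming ``iface`` is a configured interface object (placeholder name):

>>> Relabeler.relabel_chunk("chunk_0.pckl", interface=iface)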
"""
if interface is None:
raise ValueError("No interface provided. Please provide an interface for relabeling.")
if not hasattr(interface, "level_of_theory"):
raise ValueError("Interface does not have a level_of_theory attribute. Please provide a valid interface.")
chunk = Path(chunk)
if not chunk.exists():
raise ValueError(f"Chunk file {chunk} does not exist.")
with open(chunk, 'rb') as f:
chunk_data = pickle.load(f)
chunk_results = {}
total_molecules = len(chunk_data)
start_time = time.time()
failed_molecules = 0
for idx, (collection, mol) in enumerate(chunk_data, start=1):
interface.ObtainQueryData([mol])
try:
e, f = interface.calculate(**kwargs)
except Exception as exc:  # use ``exc`` so the energies variable ``e`` is not shadowed
print(f"Error calculating for molecule {mol.get('molecule_id')} in {collection}: {exc}")
failed_molecules += 1
continue
mol_id = mol.get("molecule_id")
if collection not in chunk_results:
chunk_results[collection] = {"energy": {}, "forces": {}}
chunk_results[collection]["energy"][mol_id] = e[mol_id]
chunk_results[collection]["forces"][mol_id] = f[mol_id]
# Estimate time remaining
elapsed_time = time.time() - start_time
avg_time_per_molecule = elapsed_time / idx
remaining_time = avg_time_per_molecule * (total_molecules - idx)
print(f"Processed {idx}/{total_molecules} molecules. Estimated time remaining: {remaining_time:.2f} seconds.")
"""
for collection, mol in chunk_data:
interface.ObtainQueryData([mol])
try:
e, f = interface.calculate(**kwargs)
except Exception as e:
print(f"Error calculating for molecule {mol.get('molecule_id')} in {collection}: {e}")
continue
mol_id = mol.get("molecule_id")
chunk_results[mol_id] = {"energy": e, "forces": f}
"""
# Save the relabeled data to a file
output_file = Path(f"{chunk.stem}_relabel_{interface.level_of_theory}.pckl")
with open(output_file, "wb") as f:
pickle.dump(chunk_results, f)
print(f"Chunk results saved to {output_file}")
if failed_molecules > 0:
print(f"Warning: {failed_molecules} molecules failed to relabel.")
with open("failed_molecules.txt", "a") as f:
f.write(f"Failed molecules in {chunk}: {failed_molecules}\n")
def _read_input_data(self, input_pickle):
"""
Reads the input data from a pickle file and stores it in ``self.input_data``. This is useful for large datasets, where one might want to relabel the data in chunks.

Parameters
----------
input_pickle : str
The path to the input data pickle file.
"""
if not Path(input_pickle).exists():
raise ValueError(f"Input data file {input_pickle} does not exist.")
with open(input_pickle, 'rb') as f:
input_data = pickle.load(f)
self.input_data = input_data
def combine_subdivided(self, result_files):
"""
Combines the results of the subdivided relabeling process into a single result in the same format as the ``relabel`` output, then reloads the matching input data from ``input_data.pckl``.

Parameters
----------
result_files : list
Paths to the per-chunk result files produced by ``relabel_chunk``.

Returns
-------
dict
The combined relabeled data, keyed by collection name.
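
Examples
--------
A minimal sketch, assuming the result files follow the naming used by ``relabel_chunk`` (the level of theory ``b3lyp`` is a placeholder):

>>> files = sorted(Path(".").glob("chunk_*_relabel_b3lyp.pckl"))
>>> relabeled = relabeler.combine_subdivided(files)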
"""
combined_results = {}
for result_file in result_files:
if not Path(result_file).exists():
raise ValueError(f"Result file {result_file} does not exist.")
with open(result_file, 'rb') as f:
chunk_results = pickle.load(f)
# Merge the results into the combined dataset
for collection, results in chunk_results.items():
if collection not in combined_results:
combined_results[collection] = {"energy": {}, "forces": {}}
combined_results[collection]['energy'].update(results['energy'])
combined_results[collection]['forces'].update(results['forces'])
print(f"Combined results from {result_file}")
self.relabeled_data = combined_results
print("Combined results stored in relabeled_data.")
self._read_input_data("input_data.pckl")
return self.relabeled_data
def to_database(self, database, level_of_theory=None):
"""
Saves the relabeled data back to the database by adding a new theory level entry for each molecule. Molecules that already have the theory level are skipped.

Parameters
----------
database : pymongo.database.Database
The database to write the relabeled data to.
level_of_theory : str, optional
The label to store the new data under. Required if the relabeler does not already have a level of theory set (e.g. after ``combine_subdivided``).
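
Examples
--------
A minimal sketch, assuming ``db`` is a ``pymongo.database.Database`` (placeholder name) and relabeled data is present:

>>> relabeler.to_database(db, level_of_theory="b3lyp")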
"""
if self.level_of_theory is None:
if level_of_theory is None:
raise ValueError("No level of theory provided. Please provide a level of theory.")
else:
self.level_of_theory = level_of_theory
for collection in self.relabeled_data:
collection_reported = False
for mol_id in self.relabeled_data[collection]["energy"]:
# Grab the molecule from the database, and make sure it exists.
db_mol = database[collection].find_one({"molecule_id": mol_id})
if db_mol is None:
raise ValueError(f"Could not find molecule {mol_id} in the database.")
# Check if the theory level already exists in the database.
# If it does, don't make changes.
existing_theory = db_mol.get("theorylevels", {}).keys()
if not collection_reported:
print(f"Existing theory levels: {existing_theory}")
if self.level_of_theory in existing_theory:
print(f"Skipping {mol_id} as it already has a theory level of {self.level_of_theory}")
continue
if not collection_reported:
print(f"Adding theory level {self.level_of_theory} to collection {collection}")
collection_reported = True
#print("molecule_id: ", mol_id)
#print("forces: ", self.relabeled_data[collection]["forces"][mol_id])
# Add the new theory level to the database
new_theory_data = {
"energies": [
energy.tolist()[0] if isinstance(energy, np.ndarray) else energy
for energy in self.relabeled_data[collection]["energy"][mol_id]
],
"forces": [
force.flatten().tolist() if isinstance(force, np.ndarray) else force
for force in self.relabeled_data[collection]["forces"][mol_id]
],
}
try:
database[collection].update_one(
{"molecule_id": mol_id},
{"$set": {f"theorylevels.{self.level_of_theory}": new_theory_data}},
)
except Exception as exc:
print(f"Error updating molecule {mol_id} in collection {collection}: {exc}")
print(new_theory_data["energies"])
# Re-raise instead of calling exit() so the caller can handle the failure.
raise
@staticmethod
def chunk_calc(nchunk):
"""
Reads in a chunk input file, runs the relabeling process on that chunk, and saves the results to a new file.

Parameters
----------
nchunk : int
The chunk number to read in.

Returns
-------
None

Notes
-----
Not yet implemented; ``relabel_chunk`` currently performs the per-chunk work.
"""
pass