"""MongoDB utility functions for handling molecular data.
This module provides functions to create collections, add fields, modify documents,
and visualize data in a MongoDB database. It also includes functions to create HDF5 files
from query results and to count documents without SMILES strings.
"""
import numpy as np
import matplotlib.pyplot as plt
from pymongo import MongoClient
from collections import Counter
from pymongo.database import Database
from pathlib import Path
[docs]
def add_field_to_all_documents(database_name, collection_name, new_field_name, new_field_value,
                               mongo_uri="mongodb://localhost:27017"):
    """
    Add a new field to all documents in a MongoDB collection.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    new_field_name : str
        Name of the new field to add.
    new_field_value : any
        Value to set for the new field.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to ``"mongodb://localhost:27017"``.

    Returns
    -------
    dict
        Summary of the update operation with the keys ``matched_count``,
        ``modified_count`` and ``message``.
    """
    # TODO: Add error handling for connection issues
    # The context manager closes the client (and its connection pool) even
    # when the update raises, so repeated calls do not leak connections.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        # An empty filter matches every document in the collection.
        result = collection.update_many(
            {},
            {"$set": {new_field_name: new_field_value}},
        )

    return {
        "matched_count": result.matched_count,
        "modified_count": result.modified_count,
        "message": f"Added '{new_field_name}' to all documents with value '{new_field_value}'."
    }
[docs]
def modify_field_by_query(database_name, collection_name, query, field_name, new_value,
                          mongo_uri="mongodb://localhost:27017"):
    """
    Modify a specific field for documents matching the query criteria.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    query : dict
        MongoDB query to filter documents.
    field_name : str
        Name of the field to modify.
    new_value : any
        New value to set for the field.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to ``"mongodb://localhost:27017"``.

    Returns
    -------
    dict
        Summary of the update operation with the keys ``matched_count``,
        ``modified_count`` and ``message``.
    """
    # TODO: Add error handling for connection issues
    # Using the client as a context manager guarantees it is closed once the
    # update has finished or failed.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        result = collection.update_many(
            query,
            {"$set": {field_name: new_value}},
        )

    return {
        "matched_count": result.matched_count,
        "modified_count": result.modified_count,
        "message": f"Modified '{field_name}' for {result.modified_count} documents matching the query."
    }
[docs]
def add_configurations_count(database_name, collection_name,
                             mongo_uri="mongodb://localhost:27017"):
    """
    Add a 'configurations_count' field to all documents in the collection.

    The count is the length of the first axis of each document's
    ``coordinates`` field; documents without that field (or with it set to
    ``None``) get a count of 0.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to ``"mongodb://localhost:27017"``.

    Returns
    -------
    dict
        Summary with the keys ``modified_documents`` and ``message``. A
        document is counted only when its update reports a modification.
    """
    # TODO: Add error handling for connection issues
    updated_count = 0
    # The context manager closes the client even if an update fails midway.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        # Only 'coordinates' (plus the implicit '_id') is needed, so project
        # the query instead of transferring every full document.
        for doc in collection.find({}, {"coordinates": 1}):
            coords = doc.get("coordinates")
            configs_count = np.array(coords).shape[0] if coords is not None else 0
            result = collection.update_one(
                {"_id": doc["_id"]},
                {"$set": {"configurations_count": configs_count}},
            )
            if result.modified_count > 0:
                updated_count += 1

    return {
        "modified_documents": updated_count,
        "message": f"Added 'configurations_count' to {updated_count} documents."
    }
[docs]
def modify_single_field(database_name, collection_name, field_name, old_value, new_value,
                        mongo_uri="mongodb://localhost:27017"):
    """
    Modify a single attribute value across all documents that match the old value.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    field_name : str
        Name of the field to modify.
    old_value : any
        The value to be replaced.
    new_value : any
        The new value to set.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to ``"mongodb://localhost:27017"``.

    Returns
    -------
    dict
        Summary of the update operation with the keys ``matched``,
        ``modified`` and ``message``.
    """
    # TODO: Add error handling for connection issues
    # The context manager releases the client's connections when the block
    # exits, whether or not the update succeeded.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        result = collection.update_many(
            {field_name: old_value},           # documents currently holding the old value
            {"$set": {field_name: new_value}},  # replace it with the new value
        )

    return {
        "matched": result.matched_count,
        "modified": result.modified_count,
        "message": f"Updated '{field_name}' from '{old_value}' to '{new_value}' in {result.modified_count} documents."
    }
[docs]
def plot_molecule_frequency_histogram(database_name, collection_name, smiles_field="smiles",
                                      mongo_uri="mongodb://localhost:27017"):
    """
    Create a histogram of molecule counts for each type of molecule in a MongoDB collection.

    The number of molecules in a document is the number of '.'-separated
    parts of its SMILES string. Documents whose SMILES field is missing or
    empty are skipped.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    smiles_field : str
        Field name containing the SMILES string.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to ``"mongodb://localhost:27017"``.

    Returns
    -------
    None
        The function displays a histogram of molecule counts.
    """
    # The context manager closes the client once the cursor has been consumed.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        # Only the SMILES field is read, so project the query to that field.
        cursor = collection.find({}, {smiles_field: 1})
        frequency_counts = Counter(
            len(smiles.split("."))  # molecules are separated by '.'
            for smiles in (document.get(smiles_field, "") for document in cursor)
            if smiles
        )

    molecule_counts = sorted(frequency_counts)
    bar_heights = [frequency_counts[count] for count in molecule_counts]

    plt.figure(figsize=(10, 6))
    plt.bar(molecule_counts, bar_heights, color="skyblue", edgecolor="black")
    plt.xlabel("Number of Molecules")
    plt.ylabel("Frequency")
    plt.title(f"Frequency Histogram of Molecule Counts in '{collection_name}' Collection")
    # With no SMILES data there is no range to label; min()/max() on an empty
    # sequence would raise ValueError, so the ticks are only set when data exists.
    if molecule_counts:
        # Ensure every integer between the smallest and largest count is shown.
        plt.xticks(range(molecule_counts[0], molecule_counts[-1] + 1))
    plt.tight_layout()
    plt.show()
[docs]
def count_docs_without_smiles_for_all_collections(db: Database):
    """
    Report, per collection, how many documents have no SMILES value.

    Parameters
    ----------
    db : Database
        The MongoDB database instance whose collections are inspected.

    Returns
    -------
    None
        One line per collection is printed with the number of documents
        whose 'smiles' field is absent or set to None.
    """
    # Same filter for every collection: field missing, or present but None.
    missing_smiles_filter = {
        "$or": [
            {"smiles": {"$exists": False}},
            {"smiles": None},
        ]
    }

    names = db.list_collection_names()
    for name in names:
        n_missing = db[name].count_documents(missing_smiles_filter)
        print(f"{name}: {n_missing} documents without SMILES")
[docs]
def clientloader(mongoclient='mongodb://localhost:27017/'):
    """
    Create a MongoDB client and verify that the server is reachable.

    Parameters
    ----------
    mongoclient : str
        The MongoDB client connection string. Default is 'mongodb://localhost:27017/'.

    Returns
    -------
    client : MongoClient
        The MongoDB client object.

    Raises
    ------
    SystemExit
        If the client cannot be created or the server does not answer a
        ping. A hint is printed first and the exit status is 1.

    Notes
    -----
    The ping waits for the driver's server-selection timeout before giving
    up when no server is reachable.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import block.
    from pymongo.errors import PyMongoError

    client = None
    try:
        client = MongoClient(mongoclient)
        # Constructing a MongoClient does not contact the server, so an
        # explicit ping is needed to detect that MongoDB is not running.
        client.admin.command("ping")
    except PyMongoError as err:
        if client is not None:
            client.close()
        print("MongoDB is not running, or you haven't created the database in the previous example. Please start MongoDB and try again.")
        # Exit with a non-zero status so scripts can detect the failure.
        raise SystemExit(1) from err
    return client