Source code for pharmaforge.dbutils.mongo_utils

"""MongoDB utility functions for handling molecular data.

This module provides functions to create collections, add fields, modify documents,
and visualize data in a MongoDB database. It also includes functions to create HDF5 files
from query results and to count documents without SMILES strings.

"""


import numpy as np
import matplotlib.pyplot as plt

from pymongo import MongoClient
from collections import Counter
from pymongo.database import Database
from pathlib import Path




def add_field_to_all_documents(database_name, collection_name, new_field_name, new_field_value,
                               *, mongo_uri="mongodb://localhost:27017"):
    """
    Add a new field to all documents in a MongoDB collection.

    The update uses ``$set`` with an empty filter, so every document in the
    collection receives ``new_field_name``; documents that already have the
    field get its value replaced.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    new_field_name : str
        Name of the new field to add.
    new_field_value : any
        Value to set for the new field.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to the local server on port 27017.

    Returns
    -------
    dict
        Summary of the update operation with the keys ``matched_count``,
        ``modified_count`` and ``message``.
    """
    # TODO: Add error handling for connection issues
    # The context manager closes the client (and its connection pool) on exit,
    # so repeated calls do not accumulate open connections.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        result = collection.update_many(
            {},  # Match all documents
            {"$set": {new_field_name: new_field_value}},
        )

    return {
        "matched_count": result.matched_count,
        "modified_count": result.modified_count,
        "message": f"Added '{new_field_name}' to all documents with value '{new_field_value}'."
    }
def modify_field_by_query(database_name, collection_name, query, field_name, new_value,
                          *, mongo_uri="mongodb://localhost:27017"):
    """
    Modify a specific field for documents matching the query criteria.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    query : dict
        MongoDB query to filter documents.
    field_name : str
        Name of the field to modify.
    new_value : any
        New value to set for the field.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to the local server on port 27017.

    Returns
    -------
    dict
        Summary of the update operation with the keys ``matched_count``,
        ``modified_count`` and ``message``.
    """
    # TODO: Add error handling for connection issues
    # The context manager closes the client once the update has completed.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        result = collection.update_many(
            query,
            {"$set": {field_name: new_value}},
        )

    return {
        "matched_count": result.matched_count,
        "modified_count": result.modified_count,
        "message": f"Modified '{field_name}' for {result.modified_count} documents matching the query."
    }
def add_configurations_count(database_name, collection_name,
                             *, mongo_uri="mongodb://localhost:27017"):
    """
    Add a 'configurations_count' field to every document in the collection.

    The count is the length of the first axis of the document's
    ``coordinates`` field (converted with ``np.array``); documents without a
    ``coordinates`` value get a count of 0.
    NOTE(review): presumably the leading axis of ``coordinates`` indexes
    configurations - confirm against the data layout.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to the local server on port 27017.

    Returns
    -------
    dict
        Summary with the keys ``modified_documents`` (number of documents
        whose stored value actually changed) and ``message``.
    """
    # TODO: Add error handling for connection issues
    updated_count = 0
    # The context manager closes the client when the loop is finished.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        # Only 'coordinates' (plus the implicit '_id') is needed, so project
        # the rest away instead of transferring whole documents.
        for doc in collection.find({}, {"coordinates": 1}):
            coords = doc.get("coordinates")
            configs_count = np.array(coords).shape[0] if coords is not None else 0
            result = collection.update_one(
                {"_id": doc["_id"]},
                {"$set": {"configurations_count": configs_count}},
            )
            # modified_count is 0 when the stored value was already correct,
            # so such documents are not counted as updated.
            if result.modified_count > 0:
                updated_count += 1

    return {
        "modified_documents": updated_count,
        "message": f"Added 'configurations_count' to {updated_count} documents."
    }
def modify_single_field(database_name, collection_name, field_name, old_value, new_value,
                        *, mongo_uri="mongodb://localhost:27017"):
    """
    Modify a single attribute value across all documents that match the old value.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    field_name : str
        Name of the field to modify.
    old_value : any
        The value to be replaced.
    new_value : any
        The new value to set.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to the local server on port 27017.

    Returns
    -------
    dict
        Summary of the update operation with the keys ``matched``,
        ``modified`` and ``message``.
    """
    # TODO: Add error handling for connection issues
    # The context manager closes the client once the update has completed.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        result = collection.update_many(
            {field_name: old_value},  # Match documents with this field value
            {"$set": {field_name: new_value}},  # Update the field value
        )

    return {
        "matched": result.matched_count,
        "modified": result.modified_count,
        "message": f"Updated '{field_name}' from '{old_value}' to '{new_value}' in {result.modified_count} documents."
    }
def plot_molecule_frequency_histogram(database_name, collection_name, smiles_field="smiles",
                                      *, mongo_uri="mongodb://localhost:27017"):
    """
    Create a histogram of molecule counts for each type of molecule in a MongoDB collection.

    For every document with a non-empty SMILES string, the number of molecules
    is taken as the number of '.'-separated components of that string. The
    plot shows how many documents have each molecule count.

    Parameters
    ----------
    database_name : str
        Name of the MongoDB database.
    collection_name : str
        Name of the collection within the database.
    smiles_field : str
        Field name containing the SMILES string.
    mongo_uri : str, optional
        MongoDB connection string. Defaults to the local server on port 27017.

    Returns
    -------
    None
        The function displays a histogram of molecule counts. If no document
        has a SMILES string, an empty plot is shown.
    """
    # Only the data fetch needs the connection; the client is closed before plotting.
    with MongoClient(mongo_uri) as client:
        collection = client[database_name][collection_name]
        frequency_counts = Counter(
            len(smiles.split('.'))  # Count molecules separated by '.'
            for smiles in (document.get(smiles_field, "") for document in collection.find())
            if smiles
        )

    # Plain sorted lists keep the bar and tick inputs explicit.
    molecule_counts = sorted(frequency_counts)
    occurrences = [frequency_counts[count] for count in molecule_counts]

    # Plot the histogram
    plt.figure(figsize=(10, 6))
    plt.bar(molecule_counts, occurrences, color="skyblue", edgecolor="black")
    plt.xlabel("Number of Molecules")
    plt.ylabel("Frequency")
    plt.title(f"Frequency Histogram of Molecule Counts in '{collection_name}' Collection")
    # min()/max() would raise ValueError on an empty collection, so only set
    # integer ticks when there is at least one bar.
    if molecule_counts:
        plt.xticks(range(molecule_counts[0], molecule_counts[-1] + 1))  # Ensure all integers are displayed
    plt.tight_layout()
    plt.show()
def count_docs_without_smiles_for_all_collections(db: Database):
    """
    Print, for each collection in the database, how many documents lack SMILES.

    A document is counted when its ``smiles`` field does not exist or is null.

    Parameters
    ----------
    db : Database
        The MongoDB database instance.

    Returns
    -------
    None
        One line per collection is printed with the count of documents
        without SMILES.
    """
    # The same filter applies to every collection, so build it once.
    no_smiles_filter = {
        "$or": [
            {"smiles": {"$exists": False}},
            {"smiles": None}
        ]
    }

    for name in db.list_collection_names():
        n_missing = db[name].count_documents(no_smiles_filter)
        print(f"{name}: {n_missing} documents without SMILES")
def clientloader(mongoclient='mongodb://localhost:27017/'):
    """
    Create a MongoDB client and verify that the server is reachable.

    Parameters
    ----------
    mongoclient : str
        The MongoDB client connection string. Default is 'mongodb://localhost:27017/'.

    Returns
    -------
    client : MongoClient
        The MongoDB client object.

    Raises
    ------
    SystemExit
        If the client cannot be created or the server does not answer a
        ``ping``. A hint is printed before exiting with status 1.
    """
    # Local import: keeps this function self-contained without touching the
    # file-level import block (pymongo is already a dependency of the module).
    from pymongo.errors import PyMongoError

    failure_hint = "MongoDB is not running, or you haven't created the database in the previous example. Please start MongoDB and try again."

    try:
        client = MongoClient(mongoclient)
    except PyMongoError as err:
        # Raised for an invalid URI or configuration.
        print(failure_hint)
        raise SystemExit(1) from err

    try:
        # MongoClient connects lazily, so constructing it succeeds even when
        # no server is listening. A cheap 'ping' forces a round trip and
        # surfaces an unreachable server here rather than at first use.
        client.admin.command("ping")
    except PyMongoError as err:
        client.close()
        print(failure_hint)
        raise SystemExit(1) from err

    return client