2022-07-05 08:37:34 +00:00
# Copyright (C) 2022-present MongoDB, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the Server Side Public License, version 1,
# as published by MongoDB, Inc.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Server Side Public License for more details.
#
# You should have received a copy of the Server Side Public License
# along with this program. If not, see
# <http://www.mongodb.com/licensing/server-side-public-license>.
#
# As a special exception, the copyright holders give permission to link the
# code of portions of this program with the OpenSSL library under certain
# conditions as described in each individual source file and distribute
# linked combinations including the program with the OpenSSL library. You
# must comply with the Server Side Public License in all respects for
# all of the code used other than as permitted herein. If you modify file(s)
# with this exception, you may extend this exception to your version of the
# file(s), but you are not obligated to do so. If you do not wish to do so,
# delete this exception statement from your version. If you delete this
# exception statement from all source files in the program, then also delete
# it in the license file.
#
""" A wrapper with useful methods over MongoDB database. """
from __future__ import annotations
2024-10-10 10:59:18 -07:00
2022-07-05 08:37:34 +00:00
import subprocess
2022-11-03 16:32:27 +00:00
from contextlib import asynccontextmanager
2024-10-10 10:59:18 -07:00
from typing import Any , Mapping , NewType , Sequence
2022-07-05 08:37:34 +00:00
from config import DatabaseConfig , RestoreMode
2025-05-30 08:40:04 -04:00
from pymongo import AsyncMongoClient
2022-07-05 08:37:34 +00:00
2025-07-23 14:58:46 -04:00
# Names exported by a star-import of this module; every other module-level name
# defined here has to be imported explicitly by name.
__all__ = [ " DatabaseInstance " , " Find " ]
2022-07-05 08:37:34 +00:00
# The fields of a MongoDB `find` command other than the collection name, keyed by field name.
# `DatabaseInstance.explain` merges this mapping into `{"find": <collection name>}` and asks the
# server to explain the resulting command.
Find = NewType("Find", Mapping[str, Any])
2022-07-05 08:37:34 +00:00
class DatabaseInstance:
    """Thin wrapper around a single MongoDB database.

    Holds an ``AsyncMongoClient`` built from ``config.connection_string`` and the database
    named ``config.database_name``. Used as a (synchronous) context manager it restores the
    database from a dump on entry and dumps it on exit, as requested by
    ``config.restore_from_dump`` and ``config.dump_on_exit``. All server operations are
    coroutines and must be awaited.
    """

    def __init__(self, config: DatabaseConfig) -> None:
        """Create the asynchronous client and select the configured database."""
        self.config = config
        self.client = AsyncMongoClient(config.connection_string)
        self.database = self.client[config.database_name]

    def __enter__(self):
        """Restore the database from a dump if the configured restore mode requires it.

        ``RestoreMode.ALWAYS`` restores unconditionally; ``RestoreMode.ONLY_NEW`` restores
        only when the configured database does not exist on the server yet.
        """
        mode = self.config.restore_from_dump
        if mode == RestoreMode.ALWAYS or (
            mode == RestoreMode.ONLY_NEW and not self._database_exists()
        ):
            self.restore()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Dump the database if ``config.dump_on_exit`` is set. Exceptions are not suppressed."""
        if self.config.dump_on_exit:
            self.dump()

    def _database_exists(self) -> bool:
        """Return whether the configured database is already present on the server."""
        # `self.client` is asynchronous: its `list_database_names()` returns a coroutine, which
        # cannot be awaited from the synchronous `__enter__` and does not support `in`.
        # A short-lived synchronous client answers the question instead.
        from pymongo import MongoClient

        with MongoClient(self.config.connection_string) as sync_client:
            return self.config.database_name in sync_client.list_database_names()

    async def drop(self):
        """Drop the database."""
        await self.client.drop_database(self.config.database_name)

    def restore(self):
        """Restore the configured database with ``mongorestore``, dropping existing collections.

        No dump directory is passed, so ``mongorestore`` reads from its default location.
        Raises ``subprocess.CalledProcessError`` if the tool exits with a non-zero status.
        """
        # NOTE(review): `config.connection_string` is not forwarded to the tool, so it connects
        # to its own default host -- confirm this is intended.
        # The argument list must be run without a shell: with `shell=True` only the first item
        # is treated as the command and the remaining options never reach `mongorestore`.
        subprocess.run(
            ["mongorestore", "--nsInclude", f"{self.config.database_name}.*", "--drop"],
            check=True,
        )

    def dump(self):
        """Dump the configured database with ``mongodump``.

        No output directory is passed, so ``mongodump`` writes to its default location.
        Raises ``subprocess.CalledProcessError`` if the tool exits with a non-zero status.
        """
        subprocess.run(["mongodump", "--db", self.config.database_name], check=True)

    async def set_parameter(self, name: str, value: Any) -> None:
        """Set the MongoDB server parameter `name` to `value` via the admin database."""
        await self.client.admin.command({"setParameter": 1, name: value})

    async def get_parameter(self, name: str) -> Any:
        """Return the current value of the MongoDB server parameter `name`."""
        reply = await self.client.admin.command({"getParameter": 1, name: 1})
        return reply[name]

    async def enable_sbe(self, state: bool) -> None:
        """Select the query execution engine.

        Sets `internalQueryFrameworkControl` to "trySbeEngine" when `state` is true and to
        "forceClassicEngine" otherwise.
        Throw pymongo.errors.OperationFailure in case of failure.
        """
        await self.set_parameter(
            "internalQueryFrameworkControl", "trySbeEngine" if state else "forceClassicEngine"
        )

    async def explain(self, collection_name: str, find: Find) -> dict[str, Any]:
        """Return the "executionStats" explain output for the given find command."""
        return await self.database.command(
            "explain",
            {"find": collection_name, **find},
            verbosity="executionStats",
        )

    async def hide_index(self, collection_name: str, index_name: str) -> None:
        """Hide the given index from the query optimizer."""
        await self.database.command(
            {"collMod": collection_name, "index": {"name": index_name, "hidden": True}}
        )

    async def unhide_index(self, collection_name: str, index_name: str) -> None:
        """Make the given index visible for the query optimizer."""
        await self.database.command(
            {"collMod": collection_name, "index": {"name": index_name, "hidden": False}}
        )

    async def _secondary_index_names(self, collection_name: str) -> list[str]:
        """Return the names of all indexes of the collection except the default `_id_` index."""
        # On the async client `list_indexes()` is a coroutine resolving to an async cursor, so
        # it has to be awaited and then consumed with `async for`. The names are collected
        # up front so that no `collMod` is issued while the cursor is still being read.
        cursor = await self.database[collection_name].list_indexes()
        return [index["name"] async for index in cursor if index["name"] != "_id_"]

    async def hide_all_indexes(self, collection_name: str) -> None:
        """Hide all indexes of the given collection, except `_id_`, from the query optimizer."""
        for index_name in await self._secondary_index_names(collection_name):
            await self.hide_index(collection_name, index_name)

    async def unhide_all_indexes(self, collection_name: str) -> None:
        """Make all indexes of the given collection, except `_id_`, visible for the optimizer."""
        for index_name in await self._secondary_index_names(collection_name):
            await self.unhide_index(collection_name, index_name)

    async def drop_collection(self, collection_name: str) -> None:
        """Drop collection."""
        await self.database[collection_name].drop()

    async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, Any]]) -> None:
        """Insert documents into the collection with the given name (unordered insert).

        An empty `docs` is a no-op: no request is sent to the server.
        """
        if len(docs) > 0:
            await self.database[collection_name].insert_many(docs, ordered=False)

    async def get_all_documents(self, collection_name: str):
        """Get all documents from the collection with the given name as a list."""
        return await self.database[collection_name].find({}).to_list(length=None)

    async def get_stats(self, collection_name: str):
        """Get collection statistics (the reply of the `collstats` command)."""
        return await self.database.command("collstats", collection_name)

    async def get_average_document_size(self, collection_name: str) -> float:
        """Get average document size for the given collection.

        Returns 0 when the statistics carry no `avgObjSize` field.
        """
        stats = await self.get_stats(collection_name)
        avg_size = stats.get("avgObjSize")
        return avg_size if avg_size is not None else 0
2022-11-03 16:32:27 +00:00
class DatabaseParameter:
    """Handle on one MongoDB parameter whose value can be saved and put back later."""

    def __init__(self, database: DatabaseInstance, parameter_name: str) -> None:
        """Bind the handle to a database wrapper and a parameter name; nothing is saved yet."""
        self.original_value = None
        self.parameter_name = parameter_name
        self.database = database

    async def set(self, value):
        """Assign `value` to the parameter."""
        name = self.parameter_name
        await self.database.set_parameter(name, value)

    async def remember(self):
        """Read the parameter's current value and keep it so it can be restored later."""
        current = await self.database.get_parameter(self.parameter_name)
        self.original_value = current

    async def restore(self):
        """Write the remembered value back; raise ValueError if nothing has been remembered."""
        if self.original_value is None:
            raise ValueError(f'The parameter "{self.parameter_name}" has not been remembered.')
        await self.set(self.original_value)
@asynccontextmanager
async def get_database_parameter(database: DatabaseInstance, parameter_name: str):
    """Async context manager for temporarily changing a database parameter.

    On entry a DatabaseParameter is created for `parameter_name` and its current value is
    remembered; the DatabaseParameter is then yielded to the caller. On exit the remembered
    value is restored, whether or not the body raised.
    """
    parameter = DatabaseParameter(database, parameter_name)
    await parameter.remember()
    try:
        yield parameter
    finally:
        await parameter.restore()