Patch summary: retrieval_tools.py now embeds the search query client-side and
queries Weaviate by vector instead of by text.
  * Imports get_embedder from memory.core and adds get_gpu_embedder(), which
    stores the embedder in the module-level _embedder on first use and returns
    that same object on later calls.
  * search_chunks_handler and search_summaries_handler call
    embedder.embed_single(input_data.query) and pass query_vector.tolist() to
    query.near_vector(...) in place of query.near_text(query=...).
  * The log_weaviate_query operation label changes from "near_text" to
    "near_vector" in both handlers.
NOTE(review): query_start is taken after embed_single() returns, so the timed
span appears to exclude the embedding step - confirm this is intended.
NOTE(review): leading indentation was lost in the collapsed original and is
reconstructed conventionally here; confirm it against the target file before
applying (or apply with whitespace-insensitive matching).

diff --git a/generations/library_rag/mcp_tools/retrieval_tools.py b/generations/library_rag/mcp_tools/retrieval_tools.py
index 3660b5a..f7bd618 100644
--- a/generations/library_rag/mcp_tools/retrieval_tools.py
+++ b/generations/library_rag/mcp_tools/retrieval_tools.py
@@ -69,9 +69,39 @@ from mcp_tools.logging_config import (
     log_weaviate_query,
 )
 
+# GPU embedder for BGE-M3 vectorization (replaces text2vec-transformers)
+from memory.core import get_embedder
+
 # Logger for this module - uses structured logging
 logger = get_tool_logger("retrieval")
 
+# =============================================================================
+# GPU Embedder Singleton (BGE-M3)
+# =============================================================================
+
+_embedder = None
+
+
+def get_gpu_embedder():
+    """Get or create GPU embedder singleton for BGE-M3 vectorization.
+
+    Returns the shared GPU embedding service instance. The embedder uses
+    BAAI/bge-m3 model (1024 dimensions) for semantic vectorization.
+
+    Returns:
+        GPUEmbeddingService instance.
+
+    Note:
+        This singleton pattern ensures the model is loaded only once,
+        avoiding repeated GPU memory allocation.
+    """
+    global _embedder
+    if _embedder is None:
+        logger.info("Initializing GPU embedder (BGE-M3) for retrieval...")
+        _embedder = get_embedder()
+        logger.info(f"GPU embedder ready: {_embedder.model_name}")
+    return _embedder
+
 
 # =============================================================================
 # Canonical Reference Extraction
@@ -472,10 +502,14 @@ async def search_chunks_handler(input_data: SearchChunksInput) -> SearchChunksOu
             )
             filters = (filters & lang_f) if filters else lang_f
 
-        # Perform near_text query with timing
+        # Vectorize query with GPU embedder (BGE-M3)
+        embedder = get_gpu_embedder()
+        query_vector = embedder.embed_single(input_data.query)
+
+        # Perform near_vector query with timing
         query_start = time.perf_counter()
-        result = chunks.query.near_text(
-            query=input_data.query,
+        result = chunks.query.near_vector(
+            near_vector=query_vector.tolist(),
             limit=input_data.limit,
             filters=filters,
             return_metadata=wvq.MetadataQuery(distance=True),
@@ -484,7 +518,7 @@ async def search_chunks_handler(input_data: SearchChunksInput) -> SearchChunksOu
 
         # Log Weaviate query
         log_weaviate_query(
-            operation="near_text",
+            operation="near_vector",
             collection="Chunk",
             filters={
                 "author": input_data.author_filter,
@@ -619,10 +653,14 @@ async def search_summaries_handler(
             )
             filters = (filters & max_filter) if filters else max_filter
 
-        # Perform near_text query with timing
+        # Vectorize query with GPU embedder (BGE-M3)
+        embedder = get_gpu_embedder()
+        query_vector = embedder.embed_single(input_data.query)
+
+        # Perform near_vector query with timing
         query_start = time.perf_counter()
-        result = summaries.query.near_text(
-            query=input_data.query,
+        result = summaries.query.near_vector(
+            near_vector=query_vector.tolist(),
             limit=input_data.limit,
             filters=filters,
             return_metadata=wvq.MetadataQuery(distance=True),
@@ -631,7 +669,7 @@ async def search_summaries_handler(
 
         # Log Weaviate query
         log_weaviate_query(
-            operation="near_text",
+            operation="near_vector",
             collection="Summary",
             filters={
                 "min_level": input_data.min_level,