diff --git a/generations/library_rag/flask_app.py b/generations/library_rag/flask_app.py index d207da2..88571c9 100644 --- a/generations/library_rag/flask_app.py +++ b/generations/library_rag/flask_app.py @@ -917,7 +917,11 @@ def search() -> str: ) -def rag_search(query: str, limit: int = 5) -> List[Dict[str, Any]]: +def rag_search( + query: str, + limit: int = 5, + selected_works: Optional[List[str]] = None +) -> List[Dict[str, Any]]: """Search passages for RAG context with formatted results. Wraps the existing search_passages() function but returns results formatted @@ -927,6 +931,9 @@ def rag_search(query: str, limit: int = 5) -> List[Dict[str, Any]]: Args: query: The user's question or search query. limit: Maximum number of context chunks to retrieve. Defaults to 5. + selected_works: Optional list of work titles to filter results. + If provided and non-empty, only chunks from these works are returned. + If None or empty, all works are included (no filter). Returns: List of context dictionaries with keys: @@ -943,10 +950,19 @@ def rag_search(query: str, limit: int = 5) -> List[Dict[str, Any]]: 'Platon' >>> results[0]["work"] 'République' + + >>> # With work filter + >>> results = rag_search("la vertu", limit=5, selected_works=["Ménon"]) + >>> all(r["work"] == "Ménon" for r in results) + True """ import time start_time = time.time() + # Normalize selected_works + if selected_works is None: + selected_works = [] + try: with get_weaviate_client() as client: if client is None: @@ -955,10 +971,18 @@ def rag_search(query: str, limit: int = 5) -> List[Dict[str, Any]]: chunks = client.collections.get("Chunk") + # Build work filter if selected_works is provided + work_filter: Optional[Any] = None + if selected_works: + # Use contains_any for filtering multiple work titles + work_filter = wvq.Filter.by_property("workTitle").contains_any(selected_works) + print(f"[RAG Search] Applying work filter: {selected_works}") + # Query with properties needed for RAG context result = chunks.query.near_text( query=query, limit=limit, + filters=work_filter, return_metadata=wvq.MetadataQuery(distance=True), return_properties=[ "text", @@ -1001,7 +1025,8 @@ def diverse_author_search( limit: int = 10, initial_pool: int = 100, max_authors: int = 5, - chunks_per_author: int = 2 + chunks_per_author: int = 2, + selected_works: Optional[List[str]] = None ) -> List[Dict[str, Any]]: """Search passages with author diversity to avoid corpus imbalance bias. @@ -1023,6 +1048,9 @@ def diverse_author_search( initial_pool: Size of initial candidate pool (default: 100). max_authors: Maximum number of distinct authors to include (default: 5). chunks_per_author: Number of chunks per selected author (default: 2). + selected_works: Optional list of work titles to filter results. + If provided and non-empty, only chunks from these works are returned. + If None or empty, all works are included (no filter). Returns: List of context dictionaries with keys: @@ -1049,12 +1077,17 @@ def diverse_author_search( import time start_time = time.time() - print(f"[Diverse Search] CALLED with query='{query[:50]}...', initial_pool={initial_pool}, max_authors={max_authors}, chunks_per_author={chunks_per_author}") + # Normalize selected_works + if selected_works is None: + selected_works = [] + + works_filter_str = selected_works if selected_works else "all" + print(f"[Diverse Search] CALLED with query='{query[:50]}...', initial_pool={initial_pool}, max_authors={max_authors}, chunks_per_author={chunks_per_author}, selected_works={works_filter_str}") try: - # Step 1: Retrieve large initial pool - print(f"[Diverse Search] Calling rag_search with limit={initial_pool}") - candidates = rag_search(query, limit=initial_pool) + # Step 1: Retrieve large initial pool (with work filter if specified) + print(f"[Diverse Search] Calling rag_search with limit={initial_pool}, selected_works={works_filter_str}") + candidates = rag_search(query, limit=initial_pool, selected_works=selected_works) print(f"[Diverse Search] rag_search returned {len(candidates)} candidates") if not candidates: @@ -1126,9 +1159,9 @@ def diverse_author_search( import traceback print(f"[Diverse Search] EXCEPTION CAUGHT: {e}") print(f"[Diverse Search] Traceback: {traceback.format_exc()}") - print(f"[Diverse Search] Falling back to standard rag_search with limit={limit}") - # Fallback to standard search - return rag_search(query, limit) + print(f"[Diverse Search] Falling back to standard rag_search with limit={limit}, selected_works={works_filter_str}") + # Fallback to standard search (preserve work filter) + return rag_search(query, limit, selected_works=selected_works) def build_prompt_with_context(user_question: str, rag_context: List[Dict[str, Any]]) -> str: @@ -1707,13 +1740,15 @@ def run_chat_generation( # The question parameter here is the final chosen version (original or reformulated) # Step 1: Diverse author search (avoids corpus imbalance bias) + # Apply selected_works filter if specified session["status"] = "searching" rag_context = diverse_author_search( query=question, limit=25, # Get 25 diverse chunks initial_pool=200, # LARGE pool to find all relevant authors (increased from 100) max_authors=8, # Include up to 8 distinct authors (increased from 6) - chunks_per_author=3 # Max 3 chunks per author for balance + chunks_per_author=3, # Max 3 chunks per author for balance + selected_works=selected_works # Filter by selected works (empty = all) ) print(f"[Pipeline] diverse_author_search returned {len(rag_context)} chunks")