""" 用于 LangChain 代理的知识库检索工具。 """ from typing import Optional, Type from pydantic import BaseModel, Field from langchain_core.tools import BaseTool from app.services.kb_manager import KnowledgeBaseManager from app.utils.logger import get_logger logger = get_logger(__name__) class RetrieverInput(BaseModel): """知识库检索器的输入模式。""" query: str = Field(description="用于查找相关文档的搜索查询") kb_id: int = Field(description="要搜索的知识库 ID") k: int = Field(default=3, description="要返回的结果数量(默认为 3)") class KnowledgeBaseRetriever(BaseTool): """ 使用向量相似度搜索知识库的工具。 """ name: str = "knowledge_base_search" description: str = ( "Search a knowledge base for relevant information. " "Use this when you need to find specific information from stored documents. " "Input should include the search query and knowledge base ID." ) args_schema: Type[BaseModel] = RetrieverInput kb_manager: KnowledgeBaseManager = Field(exclude=True) class Config: arbitrary_types_allowed = True def _run( self, query: str, kb_id: int, k: int = 3, ) -> str: """ 执行检索工具。 Args: query: 搜索查询 kb_id: 知识库 ID k: 结果数量 Returns: str: 格式化的搜索结果 """ logger.info("retriever_tool_called", query=query[:50], kb_id=kb_id, k=k) try: results = self.kb_manager.query_kb(kb_id, query, k=k) if not results: return "No relevant documents found." # 格式化结果 formatted_parts = [] for i, result in enumerate(results, 1): content = result["content"] score = result["score"] metadata = result.get("metadata", {}) source_info = "" if metadata.get("title"): source_info = f" (Source: {metadata['title']})" formatted_parts.append( f"[{i}] (Relevance: {score:.2f}){source_info}\n{content}" ) output = "\n\n".join(formatted_parts) logger.info("retriever_tool_success", kb_id=kb_id, result_count=len(results)) return output except Exception as e: error_msg = f"Error searching knowledge base: {str(e)}" logger.error("retriever_tool_failed", error=str(e), kb_id=kb_id) return error_msg async def _arun( self, query: str, kb_id: int, k: int = 3, ) -> str: """ _run 的异步版本。 注意:当前实现是同步的。 """ return self._run(query, kb_id, k)