langchain-learning-kit/app/tools/retriever.py

97 lines
2.8 KiB
Python
Raw Normal View History

"""
用于 LangChain 代理的知识库检索工具
"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
from app.services.kb_manager import KnowledgeBaseManager
from app.utils.logger import get_logger
# Module-level logger; calls in this file pass an event name plus keyword context.
logger = get_logger(__name__)
class RetrieverInput(BaseModel):
    """Input schema for the knowledge base retriever tool."""

    query: str = Field(description="用于查找相关文档的搜索查询")
    kb_id: int = Field(description="要搜索的知识库 ID")
    # The description is emitted in the tool's argument schema, so its
    # parenthesis is kept balanced.
    k: int = Field(default=3, description="要返回的结果数量(默认为 3)")
class KnowledgeBaseRetriever(BaseTool):
    """
    Tool that searches a knowledge base using vector similarity.

    The search itself is delegated to ``kb_manager.query_kb``; this class
    adapts that call to the tool interface and renders the hits as text.
    """

    name: str = "knowledge_base_search"
    description: str = (
        "Search a knowledge base for relevant information. "
        "Use this when you need to find specific information from stored documents. "
        "Input should include the search query and knowledge base ID."
    )
    args_schema: Type[BaseModel] = RetrieverInput
    # Backend that performs the actual query; excluded from model export.
    kb_manager: KnowledgeBaseManager = Field(exclude=True)

    class Config:
        # Presumably required so the kb_manager field can hold a
        # non-pydantic object -- confirm against KnowledgeBaseManager.
        arbitrary_types_allowed = True

    def _run(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """
        Run the retrieval tool.

        Args:
            query: Search query.
            kb_id: Knowledge base ID.
            k: Number of results to request.

        Returns:
            str: The formatted search results, one numbered entry per hit
            separated by blank lines; ``"No relevant documents found."``
            when the query yields nothing; or a message starting with
            ``"Error searching knowledge base: "`` if any exception is
            raised while querying or formatting (the exception is logged,
            not re-raised).
        """
        # Only the first 50 characters of the query are logged.
        logger.info("retriever_tool_called", query=query[:50], kb_id=kb_id, k=k)
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)
            if not results:
                return "No relevant documents found."

            # Format results
            formatted_parts = []
            for i, result in enumerate(results, 1):
                content = result["content"]
                score = result["score"]
                # ``or {}`` also covers a "metadata" key that is present but
                # None, which would otherwise raise on ``.get`` below and
                # turn a successful search into an error message.
                metadata = result.get("metadata") or {}
                source_info = ""
                if metadata.get("title"):
                    source_info = f" (Source: {metadata['title']})"
                formatted_parts.append(
                    f"[{i}] (Relevance: {score:.2f}){source_info}\n{content}"
                )

            output = "\n\n".join(formatted_parts)
            logger.info("retriever_tool_success", kb_id=kb_id, result_count=len(results))
            return output
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            error_msg = f"Error searching knowledge base: {str(e)}"
            logger.error("retriever_tool_failed", error=str(e), kb_id=kb_id)
            return error_msg

    async def _arun(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """
        Async version of ``_run``.

        Note: the current implementation is synchronous. It calls ``_run``
        directly without awaiting anything, so the calling coroutine does
        not yield until the search has finished.
        """
        return self._run(query, kb_id, k)