langchain-learning-kit/app/tools/retriever.py

97 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
用于 LangChain 代理的知识库检索工具。
"""
import asyncio
from typing import Optional, Type

from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from app.services.kb_manager import KnowledgeBaseManager
from app.utils.logger import get_logger
# Module-level logger. Call sites in this module pass an event name plus
# keyword fields (e.g. logger.info("retriever_tool_called", query=..., kb_id=...)),
# so get_logger presumably returns a structured, structlog-style logger —
# confirm in app.utils.logger.
logger = get_logger(__name__)
class RetrieverInput(BaseModel):
    """Input schema for the knowledge-base retriever tool.

    Fields:
        query: Search query used to find relevant documents.
        kb_id: ID of the knowledge base to search.
        k: Number of results to return (defaults to 3).
    """

    query: str = Field(description="用于查找相关文档的搜索查询")
    kb_id: int = Field(description="要搜索的知识库 ID")
    # Field descriptions are part of the argument schema exposed through the
    # tool's args_schema, so the text must be well-formed (balanced parens).
    k: int = Field(default=3, description="要返回的结果数量(默认为 3)")
class KnowledgeBaseRetriever(BaseTool):
    """
    Tool that searches a knowledge base using vector similarity.

    The search itself is delegated to ``kb_manager.query_kb``. The hits are
    rendered as a numbered plain-text list for the calling agent.
    """

    name: str = "knowledge_base_search"
    description: str = (
        "Search a knowledge base for relevant information. "
        "Use this when you need to find specific information from stored documents. "
        "Input should include the search query and knowledge base ID."
    )
    args_schema: Type[BaseModel] = RetrieverInput
    # Live service object; excluded from the model's serialized output.
    kb_manager: KnowledgeBaseManager = Field(exclude=True)

    class Config:
        # Lets pydantic accept field types that are not pydantic models,
        # such as KnowledgeBaseManager (presumably a plain class).
        arbitrary_types_allowed = True

    def _run(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """
        Run the knowledge-base search synchronously.

        Args:
            query: Search query text.
            kb_id: ID of the knowledge base to search.
            k: Number of results to request from ``query_kb`` (default 3).

        Returns:
            str: One ``"[n] (Relevance: x.xx) (Source: title)\\ncontent"``
            entry per hit, separated by blank lines.
            ``"No relevant documents found."`` when the search returns nothing.
            ``"Error searching knowledge base: ..."`` if anything raises.
            Errors are returned rather than raised so the agent sees them as
            tool output.
        """
        logger.info("retriever_tool_called", query=query[:50], kb_id=kb_id, k=k)
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)
            if not results:
                return "No relevant documents found."

            formatted_parts = []
            for i, result in enumerate(results, 1):
                content = result["content"]
                score = result["score"]
                # `or {}` also covers hits whose metadata is present but None;
                # `.get("metadata", {})` would return None and the next
                # `.get` would turn a valid hit into an error message.
                metadata = result.get("metadata") or {}
                title = metadata.get("title")
                source_info = f" (Source: {title})" if title else ""
                formatted_parts.append(
                    f"[{i}] (Relevance: {score:.2f}){source_info}\n{content}"
                )

            output = "\n\n".join(formatted_parts)
            logger.info("retriever_tool_success", kb_id=kb_id, result_count=len(results))
            return output
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            logger.error("retriever_tool_failed", error=str(e), kb_id=kb_id)
            return f"Error searching knowledge base: {e}"

    async def _arun(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """
        Async variant of ``_run``.

        ``_run`` (and the ``kb_manager.query_kb`` call inside it) is
        synchronous. Awaiting it inline would block the event loop for the
        whole search, so it runs in the running loop's default executor.

        NOTE(review): this executes the query on a worker thread. Confirm that
        KnowledgeBaseManager (e.g. any DB session it holds) tolerates being
        used from a thread other than the one that created it.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, query, kb_id, k)