97 lines
2.8 KiB
Python
97 lines
2.8 KiB
Python
"""Knowledge-base retrieval tool for LangChain agents."""
|
||
from typing import Optional, Type
|
||
from pydantic import BaseModel, Field
|
||
from langchain.tools import BaseTool
|
||
|
||
from app.services.kb_manager import KnowledgeBaseManager
|
||
from app.utils.logger import get_logger
|
||
|
||
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
|
||
|
||
|
||
class RetrieverInput(BaseModel):
    """Input schema for the knowledge-base retriever tool."""

    # Search text used to find relevant documents (required).
    query: str = Field(description="用于查找相关文档的搜索查询")
    # ID of the knowledge base to search (required).
    kb_id: int = Field(description="要搜索的知识库 ID")
    # Number of results to return; defaults to 3 when omitted.
    k: int = Field(default=3, description="要返回的结果数量(默认为 3)")
|
||
|
||
|
||
class KnowledgeBaseRetriever(BaseTool):
    """Tool that searches a knowledge base using vector similarity.

    The search itself is delegated to ``kb_manager.query_kb``; this class
    only turns the returned results into a text block for the agent.
    """

    name: str = "knowledge_base_search"
    description: str = (
        "Search a knowledge base for relevant information. "
        "Use this when you need to find specific information from stored documents. "
        "Input should include the search query and knowledge base ID."
    )
    args_schema: Type[BaseModel] = RetrieverInput
    # Backend that performs the actual search; excluded from serialization.
    kb_manager: KnowledgeBaseManager = Field(exclude=True)

    class Config:
        # Permit field types pydantic cannot validate natively
        # (presumably required for the kb_manager field -- confirm).
        arbitrary_types_allowed = True

    def _run(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """Search the knowledge base and format the hits as text.

        Args:
            query: Search query.
            kb_id: ID of the knowledge base to search.
            k: Number of results to request.

        Returns:
            One numbered entry per result, separated by blank lines;
            ``"No relevant documents found."`` when the search returns
            nothing; or an ``"Error searching knowledge base: ..."``
            message when anything in the search or formatting raises.
        """
        logger.info("retriever_tool_called", query=query[:50], kb_id=kb_id, k=k)

        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)

            if not results:
                return "No relevant documents found."

            formatted_parts = []
            for i, result in enumerate(results, 1):
                content = result["content"]
                score = result["score"]
                # ``or {}`` also covers a "metadata" key that is present but
                # None, which would otherwise fail the whole search below.
                metadata = result.get("metadata") or {}

                source_info = ""
                if metadata.get("title"):
                    source_info = f" (Source: {metadata['title']})"

                formatted_parts.append(
                    f"[{i}] (Relevance: {score:.2f}){source_info}\n{content}"
                )

            output = "\n\n".join(formatted_parts)
            logger.info("retriever_tool_success", kb_id=kb_id, result_count=len(results))
            return output

        except Exception as e:
            # Deliberately broad: failures are reported back to the agent as
            # text instead of being raised out of the tool.
            error_msg = f"Error searching knowledge base: {str(e)}"
            logger.error("retriever_tool_failed", error=str(e), kb_id=kb_id)
            return error_msg

    async def _arun(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """Async counterpart of ``_run``.

        ``_run`` is blocking, so it is executed in the event loop's default
        executor rather than called inline, which would stall every other
        coroutine for the duration of the search. Arguments and return value
        are the same as for ``_run``.
        """
        # Local imports keep this block self-contained.
        import asyncio
        import contextvars
        import functools

        loop = asyncio.get_running_loop()
        # Copy the caller's context so context variables set by the caller
        # remain visible inside the worker thread.
        ctx = contextvars.copy_context()
        return await loop.run_in_executor(
            None, functools.partial(ctx.run, self._run, query, kb_id, k)
        )
|