2025-10-02 17:22:19 +08:00
|
|
|
|
"""
|
|
|
|
|
|
用于 LangChain 代理的知识库检索工具。
|
|
|
|
|
|
"""
|
|
|
|
|
|
from typing import Optional, Type
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
2025-10-02 18:10:53 +08:00
|
|
|
|
from langchain_core.tools import BaseTool
|
2025-10-02 17:22:19 +08:00
|
|
|
|
|
|
|
|
|
|
from app.services.kb_manager import KnowledgeBaseManager
|
|
|
|
|
|
from app.utils.logger import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger obtained from the project helper app.utils.logger.get_logger.
# NOTE(review): call sites in this module pass keyword fields (e.g. kb_id=..., k=...),
# so this presumably returns a structlog-style logger rather than a stdlib
# logging.Logger — confirm against app.utils.logger.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RetrieverInput(BaseModel):
    """Arguments accepted by the knowledge base retriever tool.

    The field descriptions are surfaced to the agent through the tool's
    argument schema, so they are part of the runtime contract.
    """

    # Free-text search query; required (no default).
    query: str = Field(..., description="用于查找相关文档的搜索查询")
    # ID of the knowledge base to search; required (no default).
    kb_id: int = Field(..., description="要搜索的知识库 ID")
    # Maximum number of results to return; defaults to 3.
    k: int = Field(3, description="要返回的结果数量(默认为 3)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KnowledgeBaseRetriever(BaseTool):
    """
    Tool that searches a knowledge base using vector similarity search.

    The lookup is delegated to ``kb_manager.query_kb``. The hits are rendered as
    a numbered plain-text list (relevance score, optional source title, content).
    Errors are returned to the agent as the tool's output string instead of
    being raised.
    """

    name: str = "knowledge_base_search"
    description: str = (
        "Search a knowledge base for relevant information. "
        "Use this when you need to find specific information from stored documents. "
        "Input should include the search query and knowledge base ID."
    )
    args_schema: Type[BaseModel] = RetrieverInput
    # exclude=True keeps the manager out of pydantic serialization (dict/dump);
    # it is a service object, not tool data.
    kb_manager: KnowledgeBaseManager = Field(exclude=True)

    class Config:
        # Allows pydantic to accept field types it cannot validate natively
        # (presumably needed for KnowledgeBaseManager).
        arbitrary_types_allowed = True

    def _run(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """
        Execute the retrieval tool.

        Args:
            query: The search query.
            kb_id: ID of the knowledge base to search.
            k: Number of results to request from ``kb_manager.query_kb``.

        Returns:
            str: The formatted search results. If nothing matched, returns
            "No relevant documents found.". If the lookup or formatting raised,
            returns "Error searching knowledge base: <error>".
        """
        # Only the first 50 characters of the query are logged.
        logger.info("retriever_tool_called", query=query[:50], kb_id=kb_id, k=k)

        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)

            if not results:
                return "No relevant documents found."

            # Format results as "[n] (Relevance: x.xx) (Source: title)\ncontent".
            formatted_parts = []
            for i, result in enumerate(results, 1):
                content = result["content"]
                score = result["score"]
                # `or {}` also covers a metadata key that is present but None;
                # `.get("metadata", {})` would return None there and the
                # `.get("title")` below would raise.
                metadata = result.get("metadata") or {}

                source_info = ""
                if metadata.get("title"):
                    source_info = f" (Source: {metadata['title']})"

                formatted_parts.append(
                    f"[{i}] (Relevance: {score:.2f}){source_info}\n{content}"
                )

            output = "\n\n".join(formatted_parts)
            logger.info("retriever_tool_success", kb_id=kb_id, result_count=len(results))
            return output

        except Exception as e:
            # Broad on purpose: any failure is reported back to the agent as the
            # tool result string instead of propagating.
            error_msg = f"Error searching knowledge base: {str(e)}"
            logger.error("retriever_tool_failed", error=str(e), kb_id=kb_id)
            return error_msg

    async def _arun(
        self,
        query: str,
        kb_id: int,
        k: int = 3,
    ) -> str:
        """
        Async version of ``_run``.

        ``_run`` calls ``kb_manager.query_kb`` synchronously, so it is offloaded
        to a worker thread instead of blocking the event loop.
        ``asyncio.to_thread`` also copies the current contextvars into the thread.

        NOTE(review): this assumes kb_manager is safe to call from a worker
        thread — confirm if it holds thread-bound resources.
        """
        import asyncio

        return await asyncio.to_thread(self._run, query, kb_id, k)
|