369 lines
12 KiB
Python
369 lines
12 KiB
Python
|
|
"""
|
|||
|
|
Knowledge base manager for document ingestion and vector search.
|
|||
|
|
"""
|
|||
|
|
from typing import List, Dict, Any, Optional
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
from langchain.schema import Document as LangChainDocument
|
|||
|
|
|
|||
|
|
from app.db.models import KnowledgeBase, Document
|
|||
|
|
from app.services.model_manager import ModelManager
|
|||
|
|
from app.services.job_manager import AsyncJobManager
|
|||
|
|
from app.config import Settings
|
|||
|
|
from app.utils.text_splitter import TextSplitter
|
|||
|
|
from app.utils.faiss_helper import FAISSHelper
|
|||
|
|
from app.utils.exceptions import (
|
|||
|
|
ResourceNotFoundError,
|
|||
|
|
DuplicateResourceError,
|
|||
|
|
VectorStoreError,
|
|||
|
|
)
|
|||
|
|
from app.utils.logger import get_logger
|
|||
|
|
|
|||
|
|
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KnowledgeBaseManager:
    """
    Manage knowledge bases with vector indexing and retrieval.

    Wraps knowledge-base CRUD (backed by SQLAlchemy), document ingestion
    (chunking + FAISS index creation), and similarity search.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        job_manager: AsyncJobManager,
        settings: Settings,
    ):
        """
        Initialize the knowledge base manager.

        Args:
            db_session: Database session
            model_manager: Model manager instance
            job_manager: Async job manager instance
            settings: Application settings
        """
        self.db = db_session
        self.model_manager = model_manager
        self.job_manager = job_manager
        self.settings = settings
        self.text_splitter = TextSplitter(
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        self.faiss_helper = FAISSHelper(settings.faiss_base_path)

    def create_kb(self, name: str, description: str = "") -> KnowledgeBase:
        """
        Create a new knowledge base.

        Args:
            name: Knowledge base name
            description: Optional description

        Returns:
            KnowledgeBase: The created knowledge base

        Raises:
            DuplicateResourceError: If a knowledge base with the same name exists
        """
        logger.info("creating_kb", name=name)

        # Reject duplicate names up front.
        existing = self.db.query(KnowledgeBase).filter(KnowledgeBase.name == name).first()
        if existing:
            raise DuplicateResourceError(f"Knowledge base '{name}' already exists")

        # Create and persist the knowledge base.
        kb = KnowledgeBase(name=name, description=description)
        self.db.add(kb)
        self.db.commit()
        self.db.refresh(kb)

        logger.info("kb_created", kb_id=kb.id, name=name)
        return kb

    def get_kb(self, kb_id: int) -> KnowledgeBase:
        """
        Get a knowledge base by ID.

        Args:
            kb_id: Knowledge base ID

        Returns:
            KnowledgeBase: The knowledge base instance

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
        if not kb:
            raise ResourceNotFoundError(f"Knowledge base not found: {kb_id}")
        return kb

    def list_kb(self) -> List[KnowledgeBase]:
        """
        List all knowledge bases.

        Returns:
            List[KnowledgeBase]: List of knowledge bases
        """
        return self.db.query(KnowledgeBase).all()

    def delete_kb(self, kb_id: int) -> bool:
        """
        Delete a knowledge base and its index.

        Args:
            kb_id: Knowledge base ID

        Returns:
            bool: True if deleted

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.get_kb(kb_id)

        logger.info("deleting_kb", kb_id=kb_id, name=kb.name)

        # Delete the FAISS index; a missing/broken index must not block
        # deletion of the database record, so failures are only logged.
        try:
            self.faiss_helper.delete_index(str(kb_id))
        except Exception as e:
            logger.warning("faiss_index_delete_failed", kb_id=kb_id, error=str(e))

        # Delete the knowledge base (documents are removed via cascade).
        self.db.delete(kb)
        self.db.commit()

        logger.info("kb_deleted", kb_id=kb_id)
        return True

    async def ingest_documents(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
        background: bool = True,
    ) -> str:
        """
        Ingest documents into a knowledge base.

        Args:
            kb_id: Knowledge base ID
            documents: List of document dicts with 'content', 'title', 'source', 'metadata'
            embedding_name: Optional embedding model name
            background: Run in the background (default True)

        Returns:
            str: Job ID if background=True, otherwise "completed"

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        # Validate that the knowledge base exists (raises if not).
        self.get_kb(kb_id)

        logger.info("ingesting_documents", kb_id=kb_id, doc_count=len(documents))

        if background:
            # Submit as a background job.
            job_id = await self.job_manager.submit_job(
                self._ingest_documents_sync,
                kb_id,
                documents,
                embedding_name,
            )
            logger.info("ingest_job_submitted", kb_id=kb_id, job_id=job_id)
            return job_id
        else:
            # Run synchronously.
            self._ingest_documents_sync(kb_id, documents, embedding_name)
            return "completed"

    def _ingest_documents_sync(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
    ) -> None:
        """
        Synchronous document ingestion implementation.

        Args:
            kb_id: Knowledge base ID
            documents: List of document dicts
            embedding_name: Optional embedding model name
        """
        # Create a fresh database session for this background task; the
        # session held by self.db belongs to the request thread.
        from app.db.session import get_db_manager
        db_manager = get_db_manager()

        with db_manager.session_scope() as db_session:
            try:
                logger.info("starting_document_ingestion", kb_id=kb_id)

                # Temporary ModelManager bound to the new session.
                temp_model_manager = ModelManager(db_session, self.settings)

                # Resolve the embedding model.
                embeddings = temp_model_manager.get_embedding(embedding_name)

                # Prepare documents for chunking.
                langchain_docs = []
                for doc in documents:
                    content = doc.get("content", "")
                    metadata = {
                        "title": doc.get("title", ""),
                        "source": doc.get("source", ""),
                        **doc.get("metadata", {}),
                    }
                    langchain_docs.append(
                        LangChainDocument(page_content=content, metadata=metadata)
                    )

                # Chunk the documents.
                logger.info("chunking_documents", kb_id=kb_id, doc_count=len(langchain_docs))
                chunked_docs = self.text_splitter.split_documents(langchain_docs)
                logger.info("documents_chunked", kb_id=kb_id, chunk_count=len(chunked_docs))

                # Create or load the FAISS index.
                if self.faiss_helper.index_exists(str(kb_id)):
                    logger.info("loading_existing_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)
                    # Append the new chunks to the existing index.
                    doc_ids = self.faiss_helper.add_documents(
                        str(kb_id), vector_store, chunked_docs
                    )
                else:
                    logger.info("creating_new_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.create_index(
                        str(kb_id), chunked_docs, embeddings
                    )
                    doc_ids = [str(i) for i in range(len(chunked_docs))]

                # Persist the index to disk.
                self.faiss_helper.save_index(str(kb_id), vector_store)

                # Save document metadata to the database.
                # NOTE(review): doc_ids has one entry per *chunk* while
                # `documents` has one entry per source document, so this zip
                # truncates and each Document row receives the id of the i-th
                # chunk rather than one of its own chunks. Confirm the intended
                # semantics of Document.embedding_id before relying on it.
                logger.info("saving_document_metadata", kb_id=kb_id)
                for doc, doc_id in zip(documents, doc_ids):
                    db_doc = Document(
                        kb_id=kb_id,
                        title=doc.get("title", ""),
                        content=doc.get("content", ""),
                        source=doc.get("source", ""),
                        doc_metadata=doc.get("metadata", {}),
                        embedding_id=doc_id,
                    )
                    db_session.add(db_doc)

                # Commit is handled by the session_scope context manager.

                logger.info(
                    "document_ingestion_complete",
                    kb_id=kb_id,
                    doc_count=len(documents),
                    chunk_count=len(chunked_docs),
                )

            except Exception as e:
                import traceback
                # Some exception types stringify empty; fall back to repr so
                # the log entry is never blank.
                error_msg = str(e) if str(e) else repr(e)
                error_trace = traceback.format_exc()
                logger.error("document_ingestion_failed", kb_id=kb_id, error=error_msg, traceback=error_trace)
                # Rollback is handled by the session_scope context manager.
                # Chain the original exception so the root cause is preserved.
                raise VectorStoreError(f"Document ingestion failed: {error_msg}") from e

    def query_kb(
        self,
        kb_id: int,
        query: str,
        k: int = 5,
        embedding_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Query a knowledge base using vector similarity search.

        Args:
            kb_id: Knowledge base ID
            query: Search query
            k: Number of results to return
            embedding_name: Optional embedding model name

        Returns:
            List[Dict[str, Any]]: List of matching documents with scores

        Raises:
            ResourceNotFoundError: If the knowledge base or index is not found
            VectorStoreError: If the search fails
        """
        # Validate that the knowledge base exists (raises if not).
        self.get_kb(kb_id)

        logger.info("querying_kb", kb_id=kb_id, query=query[:50], k=k)

        # Check that an index exists before attempting to load it.
        if not self.faiss_helper.index_exists(str(kb_id)):
            raise ResourceNotFoundError(
                f"No index found for knowledge base {kb_id}. Please ingest documents first."
            )

        try:
            # Resolve the embedding model.
            embeddings = self.model_manager.get_embedding(embedding_name)

            # Load the index.
            vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)

            # Search.
            results = vector_store.similarity_search_with_score(query, k=k)

            # Format results as plain dicts for the API layer.
            formatted_results = []
            for doc, score in results:
                formatted_results.append(
                    {
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                        "score": float(score),
                    }
                )

            logger.info("kb_query_complete", kb_id=kb_id, result_count=len(results))
            return formatted_results

        except Exception as e:
            logger.error("kb_query_failed", kb_id=kb_id, error=str(e))
            # Chain the original exception so the root cause is preserved.
            raise VectorStoreError(f"Knowledge base query failed: {str(e)}") from e

    def get_status(self, kb_id: int) -> Dict[str, Any]:
        """
        Get knowledge base status and statistics.

        Args:
            kb_id: Knowledge base ID

        Returns:
            Dict[str, Any]: Status information

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.get_kb(kb_id)

        # Count documents belonging to this knowledge base.
        doc_count = self.db.query(Document).filter(Document.kb_id == kb_id).count()

        # Check whether a FAISS index exists on disk.
        index_exists = self.faiss_helper.index_exists(str(kb_id))

        return {
            "kb_id": kb_id,
            "name": kb.name,
            "description": kb.description,
            "document_count": doc_count,
            "index_exists": index_exists,
            "created_at": kb.created_at.isoformat() if kb.created_at else None,
        }
|