"""
|
||
用于文档摄取和向量搜索的知识库管理器。
|
||
"""
|
||
from typing import List, Dict, Any, Optional

from sqlalchemy.orm import Session
from langchain_core.documents import Document as LangChainDocument

from app.db.models import KnowledgeBase, Document
from app.services.model_manager import ModelManager
from app.services.job_manager import AsyncJobManager
from app.config import Settings
from app.utils.text_splitter import TextSplitter
from app.utils.faiss_helper import FAISSHelper
from app.utils.exceptions import (
    ResourceNotFoundError,
    DuplicateResourceError,
    VectorStoreError,
)
from app.utils.logger import get_logger

logger = get_logger(__name__)

class KnowledgeBaseManager:
    """
    Manages knowledge bases with vector indexing and retrieval.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        job_manager: AsyncJobManager,
        settings: Settings,
    ):
        """
        Initialize the knowledge base manager.

        Args:
            db_session: Database session
            model_manager: Model manager instance
            job_manager: Async job manager instance
            settings: Application settings
        """
        self.db = db_session
        self.model_manager = model_manager
        self.job_manager = job_manager
        self.settings = settings
        self.text_splitter = TextSplitter(
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        self.faiss_helper = FAISSHelper(settings.faiss_base_path)
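
    # Construction sketch (hypothetical wiring; `SessionLocal`, `model_manager`,
    # `job_manager`, and `settings` are assumed to come from the app's own
    # factories, not from this module):
    #
    #   db = SessionLocal()
    #   manager = KnowledgeBaseManager(db, model_manager, job_manager, settings)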

    def create_kb(self, name: str, description: str = "") -> KnowledgeBase:
        """
        Create a new knowledge base.

        Args:
            name: Knowledge base name
            description: Optional description

        Returns:
            KnowledgeBase: The created knowledge base

        Raises:
            DuplicateResourceError: If a knowledge base with the same name exists
        """
        logger.info("creating_kb", name=name)

        # Check whether a knowledge base with this name already exists
        existing = self.db.query(KnowledgeBase).filter(KnowledgeBase.name == name).first()
        if existing:
            raise DuplicateResourceError(f"Knowledge base '{name}' already exists")

        # Create the knowledge base
        kb = KnowledgeBase(name=name, description=description)
        self.db.add(kb)
        self.db.commit()
        self.db.refresh(kb)

        logger.info("kb_created", kb_id=kb.id, name=name)
        return kb
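
    # Usage sketch (hypothetical values; assumes `manager` from the sketch above):
    #
    #   try:
    #       kb = manager.create_kb("product-docs", description="Product documentation")
    #   except DuplicateResourceError:
    #       kb = next(k for k in manager.list_kb() if k.name == "product-docs")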

    def get_kb(self, kb_id: int) -> KnowledgeBase:
        """
        Get a knowledge base by ID.

        Args:
            kb_id: Knowledge base ID

        Returns:
            KnowledgeBase: The knowledge base instance

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
        if not kb:
            raise ResourceNotFoundError(f"Knowledge base not found: {kb_id}")
        return kb

    def list_kb(self) -> List[KnowledgeBase]:
        """
        List all knowledge bases.

        Returns:
            List[KnowledgeBase]: List of knowledge bases
        """
        return self.db.query(KnowledgeBase).all()

    def delete_kb(self, kb_id: int) -> bool:
        """
        Delete a knowledge base and its index.

        Args:
            kb_id: Knowledge base ID

        Returns:
            bool: True if deleted

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.get_kb(kb_id)

        logger.info("deleting_kb", kb_id=kb_id, name=kb.name)

        # Delete the FAISS index; a missing index should not block deletion
        try:
            self.faiss_helper.delete_index(str(kb_id))
        except Exception as e:
            logger.warning("faiss_index_delete_failed", kb_id=kb_id, error=str(e))

        # Delete the knowledge base (documents are removed via cascade)
        self.db.delete(kb)
        self.db.commit()

        logger.info("kb_deleted", kb_id=kb_id)
        return True
    async def ingest_documents(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
        background: bool = True,
    ) -> str:
        """
        Ingest documents into a knowledge base.

        Args:
            kb_id: Knowledge base ID
            documents: List of document dicts with 'content', 'title', 'source', 'metadata'
            embedding_name: Optional embedding model name
            background: Run in the background (default True)

        Returns:
            str: Job ID if background=True, otherwise "completed"

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        # Verify that the knowledge base exists
        self.get_kb(kb_id)

        logger.info("ingesting_documents", kb_id=kb_id, doc_count=len(documents))

        if background:
            # Submit as a background job
            job_id = await self.job_manager.submit_job(
                self._ingest_documents_sync,
                kb_id,
                documents,
                embedding_name,
            )
            logger.info("ingest_job_submitted", kb_id=kb_id, job_id=job_id)
            return job_id
        else:
            # Run synchronously
            self._ingest_documents_sync(kb_id, documents, embedding_name)
            return "completed"
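
    # Usage sketch (hypothetical payload; field names follow the docstring above):
    #
    #   docs = [
    #       {
    #           "content": "FAISS is a library for similarity search...",
    #           "title": "FAISS overview",
    #           "source": "notes/faiss.md",
    #           "metadata": {"lang": "en"},
    #       },
    #   ]
    #   job_id = await manager.ingest_documents(kb.id, docs)           # background job
    #   await manager.ingest_documents(kb.id, docs, background=False)  # blocking call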

    def _ingest_documents_sync(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
    ) -> None:
        """
        Synchronous document ingestion implementation.

        Args:
            kb_id: Knowledge base ID
            documents: List of document dicts
            embedding_name: Optional embedding model name
        """
        # Create a fresh database session for this background task
        from app.db.session import get_db_manager
        db_manager = get_db_manager()

        with db_manager.session_scope() as db_session:
            try:
                logger.info("starting_document_ingestion", kb_id=kb_id)

                # Create a temporary ModelManager bound to the new session
                temp_model_manager = ModelManager(db_session, self.settings)

                # Get the embedding model
                embeddings = temp_model_manager.get_embedding(embedding_name)

                # Prepare documents for chunking
                langchain_docs = []
                for doc in documents:
                    content = doc.get("content", "")
                    metadata = {
                        "title": doc.get("title", ""),
                        "source": doc.get("source", ""),
                        **doc.get("metadata", {}),
                    }
                    langchain_docs.append(
                        LangChainDocument(page_content=content, metadata=metadata)
                    )

                # Chunk the documents
                logger.info("chunking_documents", kb_id=kb_id, doc_count=len(langchain_docs))
                chunked_docs = self.text_splitter.split_documents(langchain_docs)
                logger.info("documents_chunked", kb_id=kb_id, chunk_count=len(chunked_docs))

                # Create the FAISS index, or load and extend an existing one
                if self.faiss_helper.index_exists(str(kb_id)):
                    logger.info("loading_existing_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)
                    # Add the new documents
                    doc_ids = self.faiss_helper.add_documents(
                        str(kb_id), vector_store, chunked_docs
                    )
                else:
                    logger.info("creating_new_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.create_index(
                        str(kb_id), chunked_docs, embeddings
                    )
                    doc_ids = [str(i) for i in range(len(chunked_docs))]

                # Save the index
                self.faiss_helper.save_index(str(kb_id), vector_store)

                # Persist document metadata to the database. Note that doc_ids
                # are chunk-level ids, so zip pairs each source document with
                # one chunk id and drops any surplus ids.
                logger.info("saving_document_metadata", kb_id=kb_id)
                for doc, doc_id in zip(documents, doc_ids):
                    db_doc = Document(
                        kb_id=kb_id,
                        title=doc.get("title", ""),
                        content=doc.get("content", ""),
                        source=doc.get("source", ""),
                        doc_metadata=doc.get("metadata", {}),
                        embedding_id=doc_id,
                    )
                    db_session.add(db_doc)

                # Commit is handled by the session_scope context manager

                logger.info(
                    "document_ingestion_complete",
                    kb_id=kb_id,
                    doc_count=len(documents),
                    chunk_count=len(chunked_docs),
                )

            except Exception as e:
                import traceback
                error_msg = str(e) if str(e) else repr(e)
                error_trace = traceback.format_exc()
                logger.error("document_ingestion_failed", kb_id=kb_id, error=error_msg, traceback=error_trace)
                # Rollback is handled by the session_scope context manager
                raise VectorStoreError(f"Document ingestion failed: {error_msg}")
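
    # Chunking sketch: TextSplitter is this app's own wrapper, but its
    # chunk_size/chunk_overlap parameters mirror LangChain's splitters.
    # A rough analogue (not used by this module):
    #
    #   from langchain_text_splitters import RecursiveCharacterTextSplitter
    #   splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    #   chunks = splitter.split_documents(langchain_docs)  # one doc -> one or more chunks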

    def query_kb(
        self,
        kb_id: int,
        query: str,
        k: int = 5,
        embedding_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Query a knowledge base using vector similarity search.

        Args:
            kb_id: Knowledge base ID
            query: Search query
            k: Number of results to return
            embedding_name: Optional embedding model name

        Returns:
            List[Dict[str, Any]]: Matching documents with scores

        Raises:
            ResourceNotFoundError: If the knowledge base or index is not found
            VectorStoreError: If the search fails
        """
        # Verify that the knowledge base exists
        self.get_kb(kb_id)

        logger.info("querying_kb", kb_id=kb_id, query=query[:50], k=k)

        # Check that the index exists
        if not self.faiss_helper.index_exists(str(kb_id)):
            raise ResourceNotFoundError(
                f"No index found for knowledge base {kb_id}. Please ingest documents first."
            )

        try:
            # Get the embedding model
            embeddings = self.model_manager.get_embedding(embedding_name)

            # Load the index
            vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)

            # Search
            results = vector_store.similarity_search_with_score(query, k=k)

            # Format the results
            formatted_results = []
            for doc, score in results:
                formatted_results.append(
                    {
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                        "score": float(score),
                    }
                )

            logger.info("kb_query_complete", kb_id=kb_id, result_count=len(results))
            return formatted_results

        except Exception as e:
            logger.error("kb_query_failed", kb_id=kb_id, error=str(e))
            raise VectorStoreError(f"Knowledge base query failed: {str(e)}")
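
    # Usage sketch (hypothetical query; FAISS scores are distances by default,
    # so a lower score usually means a closer match):
    #
    #   hits = manager.query_kb(kb.id, "How do I rebuild the index?", k=3)
    #   for hit in hits:
    #       print(hit["score"], hit["metadata"].get("title"), hit["content"][:80])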

    def get_status(self, kb_id: int) -> Dict[str, Any]:
        """
        Get knowledge base status and statistics.

        Args:
            kb_id: Knowledge base ID

        Returns:
            Dict[str, Any]: Status information

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.get_kb(kb_id)

        # Count documents
        doc_count = self.db.query(Document).filter(Document.kb_id == kb_id).count()

        # Check index status
        index_exists = self.faiss_helper.index_exists(str(kb_id))

        return {
            "kb_id": kb_id,
            "name": kb.name,
            "description": kb.description,
            "document_count": doc_count,
            "index_exists": index_exists,
            "created_at": kb.created_at.isoformat() if kb.created_at else None,
        }
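
    # Usage sketch (hypothetical guard before querying):
    #
    #   status = manager.get_status(kb.id)
    #   if not status["index_exists"]:
    #       await manager.ingest_documents(kb.id, docs)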