# langchain-learning-kit/app/services/kb_manager.py

"""
用于文档摄取和向量搜索的知识库管理器
"""
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from langchain_core.documents import Document as LangChainDocument
from app.db.models import KnowledgeBase, Document
from app.services.model_manager import ModelManager
from app.services.job_manager import AsyncJobManager
from app.config import Settings
from app.utils.text_splitter import TextSplitter
from app.utils.faiss_helper import FAISSHelper
from app.utils.exceptions import (
ResourceNotFoundError,
DuplicateResourceError,
VectorStoreError,
)
from app.utils.logger import get_logger
logger = get_logger(__name__)
class KnowledgeBaseManager:
    """
    Manages knowledge bases with vector indexing and retrieval.

    Provides CRUD operations on KnowledgeBase rows, document ingestion into
    a per-KB FAISS index (optionally as a background job), similarity search,
    and status reporting.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        job_manager: AsyncJobManager,
        settings: Settings,
    ):
        """
        Initialize the knowledge base manager.

        Args:
            db_session: Database session
            model_manager: Model manager instance
            job_manager: Async job manager instance
            settings: Application settings
        """
        self.db = db_session
        self.model_manager = model_manager
        self.job_manager = job_manager
        self.settings = settings
        self.text_splitter = TextSplitter(
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        self.faiss_helper = FAISSHelper(settings.faiss_base_path)

    def create_kb(self, name: str, description: str = "") -> KnowledgeBase:
        """
        Create a new knowledge base.

        Args:
            name: Knowledge base name
            description: Optional description

        Returns:
            KnowledgeBase: The created knowledge base

        Raises:
            DuplicateResourceError: If a knowledge base with the same name exists
        """
        logger.info("creating_kb", name=name)
        # Name acts as a unique business key: reject duplicates up front.
        existing = self.db.query(KnowledgeBase).filter(KnowledgeBase.name == name).first()
        if existing:
            raise DuplicateResourceError(f"Knowledge base '{name}' already exists")
        kb = KnowledgeBase(name=name, description=description)
        self.db.add(kb)
        self.db.commit()
        # Refresh to populate server-generated fields (id, created_at).
        self.db.refresh(kb)
        logger.info("kb_created", kb_id=kb.id, name=name)
        return kb

    def get_kb(self, kb_id: int) -> KnowledgeBase:
        """
        Fetch a knowledge base by ID.

        Args:
            kb_id: Knowledge base ID

        Returns:
            KnowledgeBase: The knowledge base instance

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
        if not kb:
            raise ResourceNotFoundError(f"Knowledge base not found: {kb_id}")
        return kb

    def list_kb(self) -> List[KnowledgeBase]:
        """
        List all knowledge bases.

        Returns:
            List[KnowledgeBase]: All knowledge bases
        """
        return self.db.query(KnowledgeBase).all()

    def delete_kb(self, kb_id: int) -> bool:
        """
        Delete a knowledge base and its vector index.

        Args:
            kb_id: Knowledge base ID

        Returns:
            bool: True once deleted

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.get_kb(kb_id)
        logger.info("deleting_kb", kb_id=kb_id, name=kb.name)
        # Best-effort index removal: a missing/corrupt index must not block
        # deleting the database record, so failures are only logged.
        try:
            self.faiss_helper.delete_index(str(kb_id))
        except Exception as e:
            logger.warning("faiss_index_delete_failed", kb_id=kb_id, error=str(e))
        # Deleting the KB row cascades to its Document rows.
        self.db.delete(kb)
        self.db.commit()
        logger.info("kb_deleted", kb_id=kb_id)
        return True

    async def ingest_documents(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
        background: bool = True,
    ) -> str:
        """
        Ingest documents into a knowledge base.

        Args:
            kb_id: Knowledge base ID
            documents: List of document dicts with 'content', 'title',
                'source', and 'metadata' keys
            embedding_name: Optional embedding model name
            background: Run as a background job (default True)

        Returns:
            str: Job ID if background=True, otherwise "completed"

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        # Validate the KB exists before accepting work.
        self.get_kb(kb_id)
        logger.info("ingesting_documents", kb_id=kb_id, doc_count=len(documents))
        if not background:
            # Run synchronously in the caller's context.
            self._ingest_documents_sync(kb_id, documents, embedding_name)
            return "completed"
        # Submit as a background job; the worker opens its own DB session.
        job_id = await self.job_manager.submit_job(
            self._ingest_documents_sync,
            kb_id,
            documents,
            embedding_name,
        )
        logger.info("ingest_job_submitted", kb_id=kb_id, job_id=job_id)
        return job_id

    def _ingest_documents_sync(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
    ) -> None:
        """
        Synchronous document ingestion implementation.

        May run in a background worker, so it opens a fresh DB session rather
        than reusing ``self.db`` (which belongs to the request context).

        Args:
            kb_id: Knowledge base ID
            documents: List of document dicts
            embedding_name: Optional embedding model name

        Raises:
            VectorStoreError: If any step of the ingestion pipeline fails
        """
        # Local import to avoid a module-level import cycle with the session module.
        from app.db.session import get_db_manager

        db_manager = get_db_manager()
        with db_manager.session_scope() as db_session:
            try:
                logger.info("starting_document_ingestion", kb_id=kb_id)
                # Temporary ModelManager bound to this task's own session.
                temp_model_manager = ModelManager(db_session, self.settings)
                embeddings = temp_model_manager.get_embedding(embedding_name)

                logger.info("chunking_documents", kb_id=kb_id, doc_count=len(documents))
                # Split each document separately so we know how many chunks
                # each source document produced. This lets us map chunk ids
                # back to their source document below — the previous code
                # zipped `documents` against per-chunk ids, which misaligns
                # (and silently truncates) whenever a document yields more
                # than one chunk.
                chunked_docs: List[LangChainDocument] = []
                chunk_counts: List[int] = []
                for doc in documents:
                    metadata = {
                        "title": doc.get("title", ""),
                        "source": doc.get("source", ""),
                        **doc.get("metadata", {}),
                    }
                    chunks = self.text_splitter.split_documents(
                        [LangChainDocument(page_content=doc.get("content", ""), metadata=metadata)]
                    )
                    chunked_docs.extend(chunks)
                    chunk_counts.append(len(chunks))
                logger.info("documents_chunked", kb_id=kb_id, chunk_count=len(chunked_docs))

                # Create or extend the FAISS index; doc_ids holds one id per chunk.
                if self.faiss_helper.index_exists(str(kb_id)):
                    logger.info("loading_existing_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)
                    doc_ids = self.faiss_helper.add_documents(
                        str(kb_id), vector_store, chunked_docs
                    )
                else:
                    logger.info("creating_new_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.create_index(
                        str(kb_id), chunked_docs, embeddings
                    )
                    doc_ids = [str(i) for i in range(len(chunked_docs))]
                self.faiss_helper.save_index(str(kb_id), vector_store)

                # Persist one Document row per input document, recording the id
                # of the document's first chunk (None if it produced no chunks).
                # NOTE(review): assumes doc_ids are returned in chunk order —
                # confirm against FAISSHelper.add_documents.
                logger.info("saving_document_metadata", kb_id=kb_id)
                offset = 0
                for doc, count in zip(documents, chunk_counts):
                    first_chunk_id = (
                        doc_ids[offset] if count > 0 and offset < len(doc_ids) else None
                    )
                    offset += count
                    db_session.add(
                        Document(
                            kb_id=kb_id,
                            title=doc.get("title", ""),
                            content=doc.get("content", ""),
                            source=doc.get("source", ""),
                            doc_metadata=doc.get("metadata", {}),
                            embedding_id=first_chunk_id,
                        )
                    )
                # Commit is handled by the session_scope context manager.
                logger.info(
                    "document_ingestion_complete",
                    kb_id=kb_id,
                    doc_count=len(documents),
                    chunk_count=len(chunked_docs),
                )
            except Exception as e:
                import traceback

                # Some exceptions stringify to "" — fall back to repr.
                error_msg = str(e) or repr(e)
                logger.error(
                    "document_ingestion_failed",
                    kb_id=kb_id,
                    error=error_msg,
                    traceback=traceback.format_exc(),
                )
                # Rollback is handled by the session_scope context manager.
                raise VectorStoreError(f"Document ingestion failed: {error_msg}") from e

    def query_kb(
        self,
        kb_id: int,
        query: str,
        k: int = 5,
        embedding_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Query a knowledge base using vector similarity search.

        Args:
            kb_id: Knowledge base ID
            query: Search query
            k: Number of results to return
            embedding_name: Optional embedding model name

        Returns:
            List[Dict[str, Any]]: Matching documents with scores

        Raises:
            ResourceNotFoundError: If the knowledge base or its index is not found
            VectorStoreError: If the search fails
        """
        # Validate the KB exists (raises ResourceNotFoundError otherwise).
        self.get_kb(kb_id)
        logger.info("querying_kb", kb_id=kb_id, query=query[:50], k=k)
        if not self.faiss_helper.index_exists(str(kb_id)):
            raise ResourceNotFoundError(
                f"No index found for knowledge base {kb_id}. Please ingest documents first."
            )
        try:
            embeddings = self.model_manager.get_embedding(embedding_name)
            vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)
            results = vector_store.similarity_search_with_score(query, k=k)
            formatted_results = [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": float(score),
                }
                for doc, score in results
            ]
            logger.info("kb_query_complete", kb_id=kb_id, result_count=len(results))
            return formatted_results
        except Exception as e:
            logger.error("kb_query_failed", kb_id=kb_id, error=str(e))
            # Chain the original cause for debuggability.
            raise VectorStoreError(f"Knowledge base query failed: {str(e)}") from e

    def get_status(self, kb_id: int) -> Dict[str, Any]:
        """
        Get knowledge base status and statistics.

        Args:
            kb_id: Knowledge base ID

        Returns:
            Dict[str, Any]: Status information

        Raises:
            ResourceNotFoundError: If the knowledge base is not found
        """
        kb = self.get_kb(kb_id)
        doc_count = self.db.query(Document).filter(Document.kb_id == kb_id).count()
        index_exists = self.faiss_helper.index_exists(str(kb_id))
        return {
            "kb_id": kb_id,
            "name": kb.name,
            "description": kb.description,
            "document_count": doc_count,
            "index_exists": index_exists,
            "created_at": kb.created_at.isoformat() if kb.created_at else None,
        }