langchain-learning-kit/app/services/kb_manager.py

369 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
用于文档摄取和向量搜索的知识库管理器。
"""
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from langchain_core.documents import Document as LangChainDocument
from app.db.models import KnowledgeBase, Document
from app.services.model_manager import ModelManager
from app.services.job_manager import AsyncJobManager
from app.config import Settings
from app.utils.text_splitter import TextSplitter
from app.utils.faiss_helper import FAISSHelper
from app.utils.exceptions import (
ResourceNotFoundError,
DuplicateResourceError,
VectorStoreError,
)
from app.utils.logger import get_logger
logger = get_logger(__name__)
class KnowledgeBaseManager:
    """
    Manages knowledge bases with vector indexing and retrieval.

    Responsibilities visible here: CRUD on KnowledgeBase rows, document
    ingestion (chunking + FAISS index create/update, optionally as a
    background job), similarity-search queries, and status reporting.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        job_manager: AsyncJobManager,
        settings: Settings,
    ):
        """
        Initialize the knowledge base manager.

        Args:
            db_session: Database session.
            model_manager: Model manager instance (supplies embeddings).
            job_manager: Async job manager for background ingestion.
            settings: Application settings (chunking + FAISS path config).
        """
        self.db = db_session
        self.model_manager = model_manager
        self.job_manager = job_manager
        self.settings = settings
        self.text_splitter = TextSplitter(
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        self.faiss_helper = FAISSHelper(settings.faiss_base_path)

    def create_kb(self, name: str, description: str = "") -> KnowledgeBase:
        """
        Create a new knowledge base.

        Args:
            name: Knowledge base name (must be unique).
            description: Optional description.

        Returns:
            KnowledgeBase: The created knowledge base.

        Raises:
            DuplicateResourceError: If a knowledge base with this name exists.
        """
        logger.info("creating_kb", name=name)
        # Reject duplicates by name before inserting.
        existing = self.db.query(KnowledgeBase).filter(KnowledgeBase.name == name).first()
        if existing:
            raise DuplicateResourceError(f"Knowledge base '{name}' already exists")
        kb = KnowledgeBase(name=name, description=description)
        self.db.add(kb)
        self.db.commit()
        # Refresh to populate DB-generated fields (id, timestamps).
        self.db.refresh(kb)
        logger.info("kb_created", kb_id=kb.id, name=name)
        return kb

    def get_kb(self, kb_id: int) -> KnowledgeBase:
        """
        Get a knowledge base by ID.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            KnowledgeBase: The knowledge base instance.

        Raises:
            ResourceNotFoundError: If the knowledge base is not found.
        """
        kb = self.db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
        if not kb:
            raise ResourceNotFoundError(f"Knowledge base not found: {kb_id}")
        return kb

    def list_kb(self) -> List[KnowledgeBase]:
        """
        List all knowledge bases.

        Returns:
            List[KnowledgeBase]: All knowledge bases.
        """
        return self.db.query(KnowledgeBase).all()

    def delete_kb(self, kb_id: int) -> bool:
        """
        Delete a knowledge base and its index.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            bool: True once deleted.

        Raises:
            ResourceNotFoundError: If the knowledge base is not found.
        """
        kb = self.get_kb(kb_id)
        logger.info("deleting_kb", kb_id=kb_id, name=kb.name)
        # Best-effort FAISS index removal: a missing/broken index must not
        # block deletion of the DB record, so failures are only logged.
        try:
            self.faiss_helper.delete_index(str(kb_id))
        except Exception as e:
            logger.warning("faiss_index_delete_failed", kb_id=kb_id, error=str(e))
        # Delete the knowledge base (documents removed via cascade).
        self.db.delete(kb)
        self.db.commit()
        logger.info("kb_deleted", kb_id=kb_id)
        return True

    async def ingest_documents(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
        background: bool = True,
    ) -> str:
        """
        Ingest documents into a knowledge base.

        Args:
            kb_id: Knowledge base ID.
            documents: List of document dicts with 'content', 'title',
                'source', and 'metadata' keys.
            embedding_name: Optional embedding model name.
            background: Run as a background job (default True).

        Returns:
            str: Job ID if background=True, otherwise "completed".

        Raises:
            ResourceNotFoundError: If the knowledge base is not found.
        """
        # Validate existence up-front; raises if the KB does not exist.
        self.get_kb(kb_id)
        logger.info("ingesting_documents", kb_id=kb_id, doc_count=len(documents))
        if background:
            # Submit as a background job and return its handle.
            job_id = await self.job_manager.submit_job(
                self._ingest_documents_sync,
                kb_id,
                documents,
                embedding_name,
            )
            logger.info("ingest_job_submitted", kb_id=kb_id, job_id=job_id)
            return job_id
        else:
            # Run synchronously in the caller's context.
            self._ingest_documents_sync(kb_id, documents, embedding_name)
            return "completed"

    def _ingest_documents_sync(
        self,
        kb_id: int,
        documents: List[Dict[str, Any]],
        embedding_name: Optional[str] = None,
    ) -> None:
        """
        Synchronous document-ingestion implementation.

        Chunks the documents, creates or extends the FAISS index, and
        persists per-document metadata rows.

        Args:
            kb_id: Knowledge base ID.
            documents: List of document dicts.
            embedding_name: Optional embedding model name.

        Raises:
            VectorStoreError: If any step of ingestion fails (original
                exception attached as __cause__).
        """
        # Background tasks must not share the request-scoped session; open
        # a fresh one for the lifetime of this ingestion.
        from app.db.session import get_db_manager
        db_manager = get_db_manager()
        with db_manager.session_scope() as db_session:
            try:
                logger.info("starting_document_ingestion", kb_id=kb_id)
                # Temporary ModelManager bound to the fresh session.
                temp_model_manager = ModelManager(db_session, self.settings)
                embeddings = temp_model_manager.get_embedding(embedding_name)
                # Wrap raw dicts as LangChain documents for the splitter.
                langchain_docs = []
                for doc in documents:
                    content = doc.get("content", "")
                    metadata = {
                        "title": doc.get("title", ""),
                        "source": doc.get("source", ""),
                        **doc.get("metadata", {}),
                    }
                    langchain_docs.append(
                        LangChainDocument(page_content=content, metadata=metadata)
                    )
                logger.info("chunking_documents", kb_id=kb_id, doc_count=len(langchain_docs))
                chunked_docs = self.text_splitter.split_documents(langchain_docs)
                logger.info("documents_chunked", kb_id=kb_id, chunk_count=len(chunked_docs))
                # Create a new FAISS index or extend an existing one.
                if self.faiss_helper.index_exists(str(kb_id)):
                    logger.info("loading_existing_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)
                    doc_ids = self.faiss_helper.add_documents(
                        str(kb_id), vector_store, chunked_docs
                    )
                else:
                    logger.info("creating_new_index", kb_id=kb_id)
                    vector_store = self.faiss_helper.create_index(
                        str(kb_id), chunked_docs, embeddings
                    )
                    doc_ids = [str(i) for i in range(len(chunked_docs))]
                self.faiss_helper.save_index(str(kb_id), vector_store)
                # Persist per-document metadata rows.
                # NOTE(review): doc_ids is one id per CHUNK while `documents`
                # is the original (unchunked) list; zip truncates silently,
                # so embedding_id here is really the id of the i-th chunk,
                # not of the i-th source document. Confirm intended mapping
                # with FAISSHelper before relying on embedding_id.
                logger.info("saving_document_metadata", kb_id=kb_id)
                for doc, doc_id in zip(documents, doc_ids):
                    db_doc = Document(
                        kb_id=kb_id,
                        title=doc.get("title", ""),
                        content=doc.get("content", ""),
                        source=doc.get("source", ""),
                        doc_metadata=doc.get("metadata", {}),
                        embedding_id=doc_id,
                    )
                    db_session.add(db_doc)
                # Commit is handled by the session_scope context manager.
                logger.info(
                    "document_ingestion_complete",
                    kb_id=kb_id,
                    doc_count=len(documents),
                    chunk_count=len(chunked_docs),
                )
            except Exception as e:
                import traceback
                # Some exceptions stringify to "" — fall back to repr.
                error_msg = str(e) if str(e) else repr(e)
                error_trace = traceback.format_exc()
                logger.error("document_ingestion_failed", kb_id=kb_id, error=error_msg, traceback=error_trace)
                # Rollback is handled by the session_scope context manager.
                # Chain the original exception for debuggability.
                raise VectorStoreError(f"Document ingestion failed: {error_msg}") from e

    def query_kb(
        self,
        kb_id: int,
        query: str,
        k: int = 5,
        embedding_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Query a knowledge base with vector similarity search.

        Args:
            kb_id: Knowledge base ID.
            query: Search query text.
            k: Number of results to return.
            embedding_name: Optional embedding model name.

        Returns:
            List[Dict[str, Any]]: Matching documents with scores; each item
                has 'content', 'metadata', and 'score' keys.

        Raises:
            ResourceNotFoundError: If the knowledge base or its index is missing.
            VectorStoreError: If the search fails.
        """
        # Validate existence; raises if the KB does not exist.
        self.get_kb(kb_id)
        # Truncate the query in logs to avoid oversized log entries.
        logger.info("querying_kb", kb_id=kb_id, query=query[:50], k=k)
        if not self.faiss_helper.index_exists(str(kb_id)):
            raise ResourceNotFoundError(
                f"No index found for knowledge base {kb_id}. Please ingest documents first."
            )
        try:
            embeddings = self.model_manager.get_embedding(embedding_name)
            vector_store = self.faiss_helper.load_index(str(kb_id), embeddings)
            results = vector_store.similarity_search_with_score(query, k=k)
            # Flatten (Document, score) pairs into plain dicts.
            formatted_results = [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": float(score),
                }
                for doc, score in results
            ]
            logger.info("kb_query_complete", kb_id=kb_id, result_count=len(results))
            return formatted_results
        except Exception as e:
            logger.error("kb_query_failed", kb_id=kb_id, error=str(e))
            # Chain the original exception for debuggability.
            raise VectorStoreError(f"Knowledge base query failed: {str(e)}") from e

    def get_status(self, kb_id: int) -> Dict[str, Any]:
        """
        Get knowledge base status and statistics.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            Dict[str, Any]: Status info (name, description, document count,
                index existence, creation timestamp).

        Raises:
            ResourceNotFoundError: If the knowledge base is not found.
        """
        kb = self.get_kb(kb_id)
        doc_count = self.db.query(Document).filter(Document.kb_id == kb_id).count()
        index_exists = self.faiss_helper.index_exists(str(kb_id))
        return {
            "kb_id": kb_id,
            "name": kb.name,
            "description": kb.description,
            "document_count": doc_count,
            "index_exists": index_exists,
            "created_at": kb.created_at.isoformat() if kb.created_at else None,
        }