242 lines
7.4 KiB
Python
242 lines
7.4 KiB
Python
|
|
"""
|
|||
|
|
FAISS 向量存储辅助函数。
|
|||
|
|
"""
|
|||
|
|
import os
|
|||
|
|
from typing import List, Optional
|
|||
|
|
|
|||
|
|
from langchain_community.vectorstores import FAISS
|
|||
|
|
from langchain.schema import Document as LangChainDocument
|
|||
|
|
from langchain.embeddings.base import Embeddings
|
|||
|
|
|
|||
|
|
from app.utils.exceptions import VectorStoreError
|
|||
|
|
from app.utils.logger import get_logger
|
|||
|
|
|
|||
|
|
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FAISSHelper:
    """Helper for creating, saving, loading and deleting FAISS vector stores.

    Every knowledge base gets its own sub-directory under ``base_path``.
    ``save_index`` writes two files into it: ``index.faiss`` (the raw FAISS
    index) and ``index.pkl`` (the pickled docstore together with the
    index-to-docstore-id mapping). ``load_index`` reads both back.
    """

    def __init__(self, base_path: str):
        """Initialize the helper and create the base directory if needed.

        Args:
            base_path: Base directory under which FAISS indexes are stored.
        """
        self.base_path = base_path
        os.makedirs(base_path, exist_ok=True)

    def get_index_path(self, kb_name: str) -> str:
        """Return the index directory for a knowledge base.

        Args:
            kb_name: Knowledge base name or ID (converted with ``str``).

        Returns:
            str: ``base_path`` joined with the stringified ``kb_name``.
        """
        return os.path.join(self.base_path, str(kb_name))

    def index_exists(self, kb_name: str) -> bool:
        """Check whether a saved FAISS index exists for a knowledge base.

        Only the presence of ``index.faiss`` is checked; ``index.pkl`` is not.

        Args:
            kb_name: Knowledge base name or ID.

        Returns:
            bool: True if ``index.faiss`` exists in the index directory.
        """
        index_path = self.get_index_path(kb_name)
        return os.path.exists(os.path.join(index_path, "index.faiss"))

    def create_index(
        self,
        kb_name: str,
        documents: List[LangChainDocument],
        embeddings: Embeddings,
    ) -> FAISS:
        """Build a new in-memory FAISS index from documents.

        The index is not written to disk here; call ``save_index`` for that.

        Args:
            kb_name: Knowledge base name or ID (used for logging only).
            documents: Documents to index; must not be empty.
            embeddings: Embeddings instance used to vectorize the documents.

        Returns:
            FAISS: The newly built vector store.

        Raises:
            VectorStoreError: If ``documents`` is empty or index creation fails.
        """
        try:
            logger.info("creating_faiss_index", kb_name=kb_name, doc_count=len(documents))

            if not documents:
                raise VectorStoreError("Cannot create index from empty document list")

            vector_store = FAISS.from_documents(documents, embeddings)

            logger.info("faiss_index_created", kb_name=kb_name)
            return vector_store

        except Exception as e:
            logger.error("faiss_index_creation_failed", kb_name=kb_name, error=str(e))
            # Chain the cause so the original traceback is preserved.
            raise VectorStoreError(f"Failed to create FAISS index: {str(e)}") from e

    def save_index(self, kb_name: str, vector_store: FAISS) -> None:
        """Persist a FAISS vector store to the knowledge base's index directory.

        Args:
            kb_name: Knowledge base name or ID.
            vector_store: FAISS vector store instance to save.

        Raises:
            VectorStoreError: If any step of the save fails.
        """
        try:
            import pickle

            import faiss

            index_path = self.get_index_path(kb_name)
            os.makedirs(index_path, exist_ok=True)

            logger.info("saving_faiss_index", kb_name=kb_name, path=index_path)

            index_file = os.path.join(index_path, "index.faiss")
            pkl_file = os.path.join(index_path, "index.pkl")

            # The raw index is written in FAISS's own binary format (no pickle).
            faiss.write_index(vector_store.index, index_file)

            # The docstore and the index -> docstore-id mapping are pickled
            # with an explicitly pinned protocol (4).
            with open(pkl_file, "wb") as f:
                pickle.dump(
                    (vector_store.docstore, vector_store.index_to_docstore_id),
                    f,
                    protocol=4,
                )

            logger.info("faiss_index_saved", kb_name=kb_name)

        except Exception as e:
            logger.error("faiss_index_save_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to save FAISS index: {str(e)}") from e

    def load_index(
        self,
        kb_name: str,
        embeddings: Embeddings,
        allow_dangerous_deserialization: bool = True,
    ) -> FAISS:
        """Load a FAISS vector store previously written by ``save_index``.

        Args:
            kb_name: Knowledge base name or ID.
            embeddings: Embeddings instance to attach to the loaded store.
            allow_dangerous_deserialization: Must be True for ``index.pkl`` to
                be unpickled. Unpickling can execute arbitrary code, so pass
                False when the index directory is not fully trusted; the load
                is then refused.

        Returns:
            FAISS: The reconstructed vector store.

        Raises:
            VectorStoreError: If deserialization is not allowed, the index is
                missing, or loading fails.
        """
        try:
            import pickle

            import faiss

            # Honor the opt-out before touching any file on disk.
            if not allow_dangerous_deserialization:
                raise VectorStoreError(
                    "Loading a FAISS index requires pickle deserialization; "
                    "pass allow_dangerous_deserialization=True only for trusted index files"
                )

            index_path = self.get_index_path(kb_name)

            if not self.index_exists(kb_name):
                raise VectorStoreError(f"FAISS index not found for kb: {kb_name}")

            logger.info("loading_faiss_index", kb_name=kb_name, path=index_path)

            index_file = os.path.join(index_path, "index.faiss")
            pkl_file = os.path.join(index_path, "index.pkl")

            index = faiss.read_index(index_file)

            # SECURITY: pickle.load can execute arbitrary code; this is only
            # reached when the caller explicitly allowed it (see above).
            with open(pkl_file, "rb") as f:
                docstore, index_to_docstore_id = pickle.load(f)

            # Pass the Embeddings object itself (as FAISS.from_documents does
            # in create_index) rather than the bare embed_query callable, so
            # documents added after a reload are embedded the same way as
            # documents in a freshly created index.
            vector_store = FAISS(
                embedding_function=embeddings,
                index=index,
                docstore=docstore,
                index_to_docstore_id=index_to_docstore_id,
            )

            logger.info("faiss_index_loaded", kb_name=kb_name)
            return vector_store

        except Exception as e:
            logger.error("faiss_index_load_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to load FAISS index: {str(e)}") from e

    def add_documents(
        self,
        kb_name: str,
        vector_store: FAISS,
        documents: List[LangChainDocument],
    ) -> List[str]:
        """Add documents to an existing in-memory FAISS vector store.

        The store is not re-saved here; call ``save_index`` to persist it.

        Args:
            kb_name: Knowledge base name or ID (used for logging only).
            vector_store: FAISS vector store instance to extend.
            documents: Documents to add.

        Returns:
            List[str]: IDs returned by ``vector_store.add_documents``.

        Raises:
            VectorStoreError: If adding the documents fails.
        """
        try:
            logger.info("adding_documents_to_index", kb_name=kb_name, doc_count=len(documents))
            ids = vector_store.add_documents(documents)
            logger.info("documents_added_to_index", kb_name=kb_name, added_count=len(ids))
            return ids

        except Exception as e:
            logger.error("add_documents_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to add documents to index: {str(e)}") from e

    def delete_index(self, kb_name: str) -> None:
        """Delete a knowledge base's index directory from disk.

        A missing directory is logged as a warning and is not an error.

        Args:
            kb_name: Knowledge base name or ID.

        Raises:
            VectorStoreError: If the resolved path is not strictly inside
                ``base_path`` or the deletion fails.
        """
        try:
            import shutil

            index_path = self.get_index_path(kb_name)

            # Guard the recursive delete: an empty, ".", ".." or absolute
            # kb_name would otherwise resolve to base_path itself or to a
            # location outside it.
            base_dir = os.path.abspath(self.base_path)
            target_dir = os.path.abspath(index_path)
            inside_base = os.path.commonpath([base_dir, target_dir]) == base_dir
            if not inside_base or target_dir == base_dir:
                raise VectorStoreError(
                    f"Refusing to delete path outside the index directory: {index_path}"
                )

            if os.path.exists(index_path):
                logger.info("deleting_faiss_index", kb_name=kb_name, path=index_path)
                shutil.rmtree(index_path)
                logger.info("faiss_index_deleted", kb_name=kb_name)
            else:
                logger.warning("faiss_index_not_found", kb_name=kb_name)

        except Exception as e:
            logger.error("faiss_index_delete_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to delete FAISS index: {str(e)}") from e
|