langchain-learning-kit/app/utils/faiss_helper.py

242 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FAISS 向量存储辅助函数。
"""
import os
from typing import List, Optional
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document as LangChainDocument
from langchain_core.embeddings import Embeddings
from app.utils.exceptions import VectorStoreError
from app.utils.logger import get_logger
# Module-level logger; calls in this module pass an event name plus keyword fields.
logger = get_logger(__name__)
class FAISSHelper:
    """
    Helper class for FAISS vector store operations.

    Each knowledge base gets its own directory under ``base_path``. That
    directory holds two files: ``index.faiss`` (the raw FAISS index) and
    ``index.pkl`` (the pickled docstore plus the index-to-docstore-id map).
    """

    def __init__(self, base_path: str):
        """
        Initialize the FAISS helper.

        Args:
            base_path: Base directory for FAISS indexes. Created if missing.
        """
        self.base_path = base_path
        os.makedirs(base_path, exist_ok=True)

    def get_index_path(self, kb_name: str) -> str:
        """
        Get the FAISS index directory path for a knowledge base.

        Args:
            kb_name: Knowledge base name or ID (converted with ``str``).

        Returns:
            str: ``base_path`` joined with ``kb_name``.
        """
        # NOTE(review): kb_name is joined without validation, so a value that
        # contains path separators or ".." resolves outside base_path, and
        # delete_index() removes whatever directory this returns. Callers
        # must pass trusted names/IDs only.
        return os.path.join(self.base_path, str(kb_name))

    def index_exists(self, kb_name: str) -> bool:
        """
        Check whether a FAISS index exists for a knowledge base.

        Only the presence of ``index.faiss`` is checked; ``index.pkl`` is not.

        Args:
            kb_name: Knowledge base name or ID.

        Returns:
            bool: True if ``index.faiss`` exists, False otherwise.
        """
        index_path = self.get_index_path(kb_name)
        return os.path.exists(os.path.join(index_path, "index.faiss"))

    def create_index(
        self,
        kb_name: str,
        documents: List[LangChainDocument],
        embeddings: Embeddings,
    ) -> FAISS:
        """
        Create a new in-memory FAISS index from documents.

        The index is not written to disk; call ``save_index`` for that.

        Args:
            kb_name: Knowledge base name or ID (used for logging only).
            documents: Documents to index; must not be empty.
            embeddings: Embeddings instance used to vectorize the documents.

        Returns:
            FAISS: The new FAISS vector store.

        Raises:
            VectorStoreError: If ``documents`` is empty or creation fails.
        """
        try:
            logger.info(
                "creating_faiss_index",
                kb_name=kb_name,
                doc_count=len(documents),
            )

            if not documents:
                raise VectorStoreError("Cannot create index from empty document list")

            vector_store = FAISS.from_documents(documents, embeddings)

            logger.info("faiss_index_created", kb_name=kb_name)
            return vector_store
        except Exception as e:
            logger.error("faiss_index_creation_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to create FAISS index: {str(e)}") from e

    def save_index(self, kb_name: str, vector_store: FAISS) -> None:
        """
        Save a FAISS index to disk.

        Writes ``index.faiss`` and ``index.pkl`` into the knowledge base's
        index directory, creating the directory if needed.

        Args:
            kb_name: Knowledge base name or ID.
            vector_store: FAISS vector store to persist.

        Raises:
            VectorStoreError: If the save operation fails.
        """
        try:
            import pickle

            index_path = self.get_index_path(kb_name)
            os.makedirs(index_path, exist_ok=True)
            logger.info("saving_faiss_index", kb_name=kb_name, path=index_path)

            index_file = os.path.join(index_path, "index.faiss")
            pkl_file = os.path.join(index_path, "index.pkl")

            # The raw index goes through FAISS's own binary writer (no pickle).
            # faiss is imported here so an import failure is reported as a
            # VectorStoreError like any other save failure.
            import faiss

            faiss.write_index(vector_store.index, index_file)

            # The docstore and the index -> docstore-id mapping are pickled
            # together as one tuple, with the protocol pinned to 4.
            with open(pkl_file, "wb") as f:
                pickle.dump(
                    (vector_store.docstore, vector_store.index_to_docstore_id),
                    f,
                    protocol=4,
                )

            logger.info("faiss_index_saved", kb_name=kb_name)
        except Exception as e:
            logger.error("faiss_index_save_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to save FAISS index: {str(e)}") from e

    def load_index(
        self,
        kb_name: str,
        embeddings: Embeddings,
        allow_dangerous_deserialization: bool = True,
    ) -> FAISS:
        """
        Load a FAISS index from disk.

        Args:
            kb_name: Knowledge base name or ID.
            embeddings: Embeddings instance; its ``embed_query`` is used as the
                store's embedding function.
            allow_dangerous_deserialization: Must be True to load, because
                ``index.pkl`` is read with ``pickle.load``. When False the
                load is refused before any file is read.

        Returns:
            FAISS: The reconstructed FAISS vector store.

        Raises:
            VectorStoreError: If deserialization is not allowed, the index is
                missing, or loading fails.
        """
        try:
            import pickle

            # SECURITY: unpickling can execute arbitrary code. Honor the
            # caller's opt-out instead of silently ignoring the flag, and only
            # load index directories written by a trusted source.
            if not allow_dangerous_deserialization:
                raise VectorStoreError(
                    "Loading a FAISS index requires pickle deserialization; "
                    "pass allow_dangerous_deserialization=True only for "
                    "index files from a trusted source"
                )

            index_path = self.get_index_path(kb_name)
            if not self.index_exists(kb_name):
                raise VectorStoreError(f"FAISS index not found for kb: {kb_name}")

            logger.info("loading_faiss_index", kb_name=kb_name, path=index_path)

            index_file = os.path.join(index_path, "index.faiss")
            pkl_file = os.path.join(index_path, "index.pkl")

            # Imported inside the try block so an import failure is reported
            # as a VectorStoreError.
            import faiss

            index = faiss.read_index(index_file)

            # Counterpart of save_index(): a (docstore, id-mapping) tuple.
            with open(pkl_file, "rb") as f:
                docstore, index_to_docstore_id = pickle.load(f)

            # Rebuild the vector store from its three persisted parts.
            vector_store = FAISS(
                embedding_function=embeddings.embed_query,
                index=index,
                docstore=docstore,
                index_to_docstore_id=index_to_docstore_id,
            )

            logger.info("faiss_index_loaded", kb_name=kb_name)
            return vector_store
        except Exception as e:
            logger.error("faiss_index_load_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to load FAISS index: {str(e)}") from e

    def add_documents(
        self,
        kb_name: str,
        vector_store: FAISS,
        documents: List[LangChainDocument],
    ) -> List[str]:
        """
        Add documents to an existing FAISS index.

        Only the in-memory store is updated; call ``save_index`` to persist.

        Args:
            kb_name: Knowledge base name or ID (used for logging only).
            vector_store: FAISS vector store to add to.
            documents: Documents to add.

        Returns:
            List[str]: IDs returned by ``vector_store.add_documents``.

        Raises:
            VectorStoreError: If the add operation fails.
        """
        try:
            logger.info(
                "adding_documents_to_index",
                kb_name=kb_name,
                doc_count=len(documents),
            )

            ids = vector_store.add_documents(documents)

            logger.info(
                "documents_added_to_index",
                kb_name=kb_name,
                added_count=len(ids),
            )
            return ids
        except Exception as e:
            logger.error("add_documents_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to add documents to index: {str(e)}") from e

    def delete_index(self, kb_name: str) -> None:
        """
        Delete a knowledge base's FAISS index directory from disk.

        A missing directory is not an error; it is logged as a warning.

        Args:
            kb_name: Knowledge base name or ID.

        Raises:
            VectorStoreError: If the delete operation fails.
        """
        try:
            import shutil

            index_path = self.get_index_path(kb_name)

            if os.path.exists(index_path):
                logger.info("deleting_faiss_index", kb_name=kb_name, path=index_path)
                # Removes the whole per-knowledge-base directory.
                shutil.rmtree(index_path)
                logger.info("faiss_index_deleted", kb_name=kb_name)
            else:
                logger.warning("faiss_index_not_found", kb_name=kb_name)
        except Exception as e:
            logger.error("faiss_index_delete_failed", kb_name=kb_name, error=str(e))
            raise VectorStoreError(f"Failed to delete FAISS index: {str(e)}") from e