""" 知识库 API 端点。 """ from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks from pydantic import BaseModel, Field from sqlalchemy.orm import Session from app.db.session import get_db from app.services.kb_manager import KnowledgeBaseManager from app.services.model_manager import ModelManager from app.services.job_manager import AsyncJobManager from app.config import get_settings from app.utils.exceptions import ( ResourceNotFoundError, DuplicateResourceError, VectorStoreError, ) from app.utils.logger import get_logger logger = get_logger(__name__) router = APIRouter(prefix="/kb", tags=["knowledge-base"]) # Pydantic 模式 class KBCreate(BaseModel): """创建知识库的模式。""" name: str = Field(..., description="Knowledge base name", max_length=128) description: str = Field(default="", description="Knowledge base description") class KBResponse(BaseModel): """知识库响应的模式。""" id: int name: str description: str created_at: str class Config: from_attributes = True class DocumentIngest(BaseModel): """文档摄取的模式。""" content: str = Field(..., description="Document content") title: str = Field(default="", description="Document title") source: str = Field(default="", description="Document source") metadata: dict = Field(default_factory=dict, description="Additional metadata") class IngestRequest(BaseModel): """批量文档摄取的模式。""" documents: List[DocumentIngest] = Field(..., description="List of documents to ingest") embedding_name: Optional[str] = Field(None, description="Embedding model name") background: bool = Field(default=True, description="Run in background") class IngestResponse(BaseModel): """摄取响应的模式。""" job_id: Optional[str] = Field(None, description="Job ID if background=True") status: str = Field(..., description="Ingestion status") message: str class QueryRequest(BaseModel): """知识库查询的模式。""" query: str = Field(..., description="Search query") k: int = Field(default=5, ge=1, le=20, description="Number of results") embedding_name: Optional[str] = Field(None, description="Embedding model name") class QueryResult(BaseModel): """单个查询结果的模式。""" content: str metadata: dict score: float class QueryResponse(BaseModel): """查询响应的模式。""" results: List[QueryResult] result_count: int class KBStatus(BaseModel): """知识库状态响应的模式。""" kb_id: int name: str description: str document_count: int index_exists: bool created_at: Optional[str] # 依赖项 def get_kb_manager( db: Session = Depends(get_db), ) -> KnowledgeBaseManager: """获取 KnowledgeBaseManager 实例。""" settings = get_settings() model_manager = ModelManager(db, settings) job_manager = AsyncJobManager() return KnowledgeBaseManager(db, model_manager, job_manager, settings) @router.post("/", response_model=KBResponse, status_code=status.HTTP_201_CREATED) def create_kb( kb_data: KBCreate, kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 创建一个新的知识库。 Args: kb_data: 知识库创建数据 kb_manager: KnowledgeBaseManager 实例 Returns: 创建的知识库 Raises: 400: 如果存在同名知识库 """ try: kb = kb_manager.create_kb(name=kb_data.name, description=kb_data.description) return KBResponse( id=kb.id, name=kb.name, description=kb.description or "", created_at=kb.created_at.isoformat() if kb.created_at else "", ) except DuplicateResourceError as e: logger.warning("duplicate_kb", name=kb_data.name) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: logger.error("kb_creation_failed", error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create knowledge base: {str(e)}", ) @router.get("/", response_model=List[KBResponse]) def list_kb( kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 列出所有知识库。 Args: kb_manager: KnowledgeBaseManager 实例 Returns: 知识库列表 """ try: kbs = kb_manager.list_kb() return [ KBResponse( id=kb.id, name=kb.name, description=kb.description or "", created_at=kb.created_at.isoformat() if kb.created_at else "", ) for kb in kbs ] except Exception as e: logger.error("kb_list_failed", error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to list knowledge bases: {str(e)}", ) @router.get("/{kb_id}", response_model=KBResponse) def get_kb( kb_id: int, kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 根据 ID 获取知识库。 Args: kb_id: 知识库 ID kb_manager: KnowledgeBaseManager 实例 Returns: 知识库详情 Raises: 404: 如果找不到知识库 """ try: kb = kb_manager.get_kb(kb_id) return KBResponse( id=kb.id, name=kb.name, description=kb.description or "", created_at=kb.created_at.isoformat() if kb.created_at else "", ) except ResourceNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: logger.error("kb_get_failed", kb_id=kb_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to get knowledge base: {str(e)}", ) @router.delete("/{kb_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_kb( kb_id: int, kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 删除知识库。 Args: kb_id: 知识库 ID kb_manager: KnowledgeBaseManager 实例 Raises: 404: 如果找不到知识库 """ try: kb_manager.delete_kb(kb_id) logger.info("kb_deleted_via_api", kb_id=kb_id) except ResourceNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: logger.error("kb_delete_failed", kb_id=kb_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to delete knowledge base: {str(e)}", ) @router.post("/{kb_id}/ingest", response_model=IngestResponse) async def ingest_documents( kb_id: int, ingest_data: IngestRequest, kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 将文档摄取到知识库。 Args: kb_id: 知识库 ID ingest_data: 摄取请求数据 kb_manager: KnowledgeBaseManager 实例 Returns: 摄取状态 Raises: 404: 如果找不到知识库 """ try: # 将文档转换为字典格式 documents = [ { "content": doc.content, "title": doc.title, "source": doc.source, "metadata": doc.metadata, } for doc in ingest_data.documents ] result = await kb_manager.ingest_documents( kb_id=kb_id, documents=documents, embedding_name=ingest_data.embedding_name, background=ingest_data.background, ) if ingest_data.background: return IngestResponse( job_id=result, status="submitted", message=f"Ingestion job submitted with ID: {result}", ) else: return IngestResponse( status="completed", message=f"Successfully ingested {len(documents)} documents", ) except ResourceNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: logger.error("ingest_failed", kb_id=kb_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to ingest documents: {str(e)}", ) @router.post("/{kb_id}/query", response_model=QueryResponse) def query_kb( kb_id: int, query_data: QueryRequest, kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 使用向量相似度搜索查询知识库。 Args: kb_id: 知识库 ID query_data: 查询请求数据 kb_manager: KnowledgeBaseManager 实例 Returns: 查询结果 Raises: 404: 如果找不到知识库或索引不存在 """ try: results = kb_manager.query_kb( kb_id=kb_id, query=query_data.query, k=query_data.k, embedding_name=query_data.embedding_name, ) return QueryResponse( results=[ QueryResult( content=r["content"], metadata=r["metadata"], score=r["score"], ) for r in results ], result_count=len(results), ) except ResourceNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except VectorStoreError as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) except Exception as e: logger.error("query_failed", kb_id=kb_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query knowledge base: {str(e)}", ) @router.get("/{kb_id}/status", response_model=KBStatus) def get_kb_status( kb_id: int, kb_manager: KnowledgeBaseManager = Depends(get_kb_manager), ): """ 获取知识库状态和统计信息。 Args: kb_id: 知识库 ID kb_manager: KnowledgeBaseManager 实例 Returns: 知识库状态信息 Raises: 404: 如果找不到知识库 """ try: status_info = kb_manager.get_status(kb_id) return KBStatus(**status_info) except ResourceNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except Exception as e: logger.error("kb_status_failed", kb_id=kb_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to get KB status: {str(e)}", )