langchain-learning-kit/app/api/kb.py

392 lines
11 KiB
Python
Raw Permalink Normal View History

"""
Knowledge base API endpoints.
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.services.kb_manager import KnowledgeBaseManager
from app.services.model_manager import ModelManager
from app.services.job_manager import AsyncJobManager
from app.config import get_settings
from app.utils.exceptions import (
ResourceNotFoundError,
DuplicateResourceError,
VectorStoreError,
)
from app.utils.logger import get_logger
# Module-level logger named after this module; call sites below pass an
# event name followed by keyword fields.
logger = get_logger(__name__)
# Router for every endpoint in this file; all paths are prefixed with /kb.
router = APIRouter(prefix="/kb", tags=["knowledge-base"])
# Pydantic schemas
class KBCreate(BaseModel):
    """Request body for creating a knowledge base."""

    # Required; limited to 128 characters.
    name: str = Field(..., max_length=128, description="Knowledge base name")
    # Optional free text; defaults to an empty string.
    description: str = Field("", description="Knowledge base description")
class KBResponse(BaseModel):
    """Response body describing a single knowledge base."""

    # `from_attributes` is a Pydantic v2 option; v2 deprecates the nested
    # `class Config` form in favour of the `model_config` class attribute.
    # A plain dict is accepted here, so no extra import is required.
    model_config = {"from_attributes": True}

    id: int
    name: str
    description: str
    # The endpoints in this module fill this with an ISO-8601 string, or ""
    # when the record has no creation timestamp.
    created_at: str
class DocumentIngest(BaseModel):
    """A single document submitted for ingestion."""

    # Required document text.
    content: str = Field(..., description="Document content")
    # Optional descriptive fields; both default to an empty string.
    title: str = Field("", description="Document title")
    source: str = Field("", description="Document source")
    # Free-form extra data; each instance gets its own fresh dict.
    metadata: dict = Field(description="Additional metadata", default_factory=dict)
class IngestRequest(BaseModel):
    """Request body for ingesting a batch of documents."""

    # Required list of documents.
    documents: List[DocumentIngest] = Field(
        ...,
        description="List of documents to ingest",
    )
    # Name of the embedding model to use; None when the caller omits it.
    embedding_name: Optional[str] = Field(
        default=None,
        description="Embedding model name",
    )
    # Background execution is the default.
    background: bool = Field(True, description="Run in background")
class IngestResponse(BaseModel):
    """Response body returned by the ingest endpoint."""

    # Set only when the ingestion was submitted as a background job.
    job_id: Optional[str] = Field(
        default=None,
        description="Job ID if background=True",
    )
    status: str = Field(..., description="Ingestion status")
    # Human-readable summary of the outcome.
    message: str
class QueryRequest(BaseModel):
    """Request body for searching a knowledge base."""

    query: str = Field(..., description="Search query")
    # Number of results to return: between 1 and 20, default 5.
    k: int = Field(5, ge=1, le=20, description="Number of results")
    # Name of the embedding model to use; None when the caller omits it.
    embedding_name: Optional[str] = Field(
        default=None,
        description="Embedding model name",
    )
class QueryResult(BaseModel):
    """One hit returned by a knowledge base query."""

    # Text of the matched item.
    content: str
    # Metadata dict attached to the matched item.
    metadata: dict
    # Score reported by the manager for this hit.
    score: float
class QueryResponse(BaseModel):
    """Response body wrapping all hits of a knowledge base query."""

    # The individual hits.
    results: List[QueryResult]
    # Number of hits; the query endpoint sets this to len(results).
    result_count: int
class KBStatus(BaseModel):
    """Response body for the knowledge base status endpoint."""

    kb_id: int
    name: str
    description: str
    document_count: int
    index_exists: bool
    # Explicit None default: under Pydantic v2 an Optional annotation without
    # a default is still a required field, so a status dict lacking this key
    # would otherwise fail validation.
    created_at: Optional[str] = None
# Dependencies
def get_kb_manager(
    db: Session = Depends(get_db),
) -> KnowledgeBaseManager:
    """
    Build a KnowledgeBaseManager wired to the given database session.

    Args:
        db: SQLAlchemy session supplied by the get_db dependency.

    Returns:
        A KnowledgeBaseManager constructed from the session, a new
        ModelManager, a new AsyncJobManager and the application settings.
    """
    cfg = get_settings()
    # NOTE(review): a fresh AsyncJobManager is created on every call; confirm
    # that job state is shared elsewhere if jobs must outlive one request.
    return KnowledgeBaseManager(db, ModelManager(db, cfg), AsyncJobManager(), cfg)
@router.post("/", response_model=KBResponse, status_code=status.HTTP_201_CREATED)
def create_kb(
    kb_data: KBCreate,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Create a new knowledge base.

    Args:
        kb_data: Creation payload (name and description).
        kb_manager: KnowledgeBaseManager instance.

    Returns:
        The created knowledge base.

    Raises:
        HTTPException: 400 when the manager raises DuplicateResourceError,
            500 for any other failure.
    """
    try:
        kb = kb_manager.create_kb(name=kb_data.name, description=kb_data.description)
        return KBResponse(
            id=kb.id,
            name=kb.name,
            # Normalise missing values to "" because the schema declares str.
            description=kb.description or "",
            created_at=kb.created_at.isoformat() if kb.created_at else "",
        )
    except DuplicateResourceError as e:
        logger.warning("duplicate_kb", name=kb_data.name)
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("kb_creation_failed", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create knowledge base: {e}",
        ) from e
@router.get("/", response_model=List[KBResponse])
def list_kb(
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    List all knowledge bases.

    Args:
        kb_manager: KnowledgeBaseManager instance.

    Returns:
        A list with one KBResponse per knowledge base.

    Raises:
        HTTPException: 500 when listing fails for any reason.
    """
    try:
        return [
            KBResponse(
                id=kb.id,
                name=kb.name,
                # Normalise missing values to "" because the schema declares str.
                description=kb.description or "",
                created_at=kb.created_at.isoformat() if kb.created_at else "",
            )
            for kb in kb_manager.list_kb()
        ]
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("kb_list_failed", error=str(e))
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list knowledge bases: {e}",
        ) from e
@router.get("/{kb_id}", response_model=KBResponse)
def get_kb(
    kb_id: int,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Fetch a knowledge base by ID.

    Args:
        kb_id: Knowledge base ID.
        kb_manager: KnowledgeBaseManager instance.

    Returns:
        The knowledge base details.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        kb = kb_manager.get_kb(kb_id)
        return KBResponse(
            id=kb.id,
            name=kb.name,
            # Normalise missing values to "" because the schema declares str.
            description=kb.description or "",
            created_at=kb.created_at.isoformat() if kb.created_at else "",
        )
    except ResourceNotFoundError as e:
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("kb_get_failed", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get knowledge base: {e}",
        ) from e
@router.delete("/{kb_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_kb(
    kb_id: int,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Delete a knowledge base.

    Returns nothing; success is signalled by the 204 status code.

    Args:
        kb_id: Knowledge base ID.
        kb_manager: KnowledgeBaseManager instance.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        kb_manager.delete_kb(kb_id)
        logger.info("kb_deleted_via_api", kb_id=kb_id)
    except ResourceNotFoundError as e:
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("kb_delete_failed", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete knowledge base: {e}",
        ) from e
@router.post("/{kb_id}/ingest", response_model=IngestResponse)
async def ingest_documents(
    kb_id: int,
    ingest_data: IngestRequest,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Ingest documents into a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        ingest_data: Ingestion request (documents, embedding name, mode).
        kb_manager: KnowledgeBaseManager instance.

    Returns:
        An IngestResponse with status "submitted" and the manager's return
        value as job_id when background is true, otherwise status "completed".

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        # Convert the validated models into the plain dicts passed to the manager.
        documents = [
            {
                "content": doc.content,
                "title": doc.title,
                "source": doc.source,
                "metadata": doc.metadata,
            }
            for doc in ingest_data.documents
        ]
        result = await kb_manager.ingest_documents(
            kb_id=kb_id,
            documents=documents,
            embedding_name=ingest_data.embedding_name,
            background=ingest_data.background,
        )
        if ingest_data.background:
            # In background mode the manager's return value is used as the job ID.
            return IngestResponse(
                job_id=result,
                status="submitted",
                message=f"Ingestion job submitted with ID: {result}",
            )
        return IngestResponse(
            status="completed",
            message=f"Successfully ingested {len(documents)} documents",
        )
    except ResourceNotFoundError as e:
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("ingest_failed", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to ingest documents: {e}",
        ) from e
@router.post("/{kb_id}/query", response_model=QueryResponse)
def query_kb(
    kb_id: int,
    query_data: QueryRequest,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Query a knowledge base using vector similarity search.

    Args:
        kb_id: Knowledge base ID.
        query_data: Query request (query text, k, embedding name).
        kb_manager: KnowledgeBaseManager instance.

    Returns:
        The matching results together with their count.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 when it raises VectorStoreError or any other error.
    """
    try:
        results = kb_manager.query_kb(
            kb_id=kb_id,
            query=query_data.query,
            k=query_data.k,
            embedding_name=query_data.embedding_name,
        )
        return QueryResponse(
            results=[
                QueryResult(
                    content=r["content"],
                    metadata=r["metadata"],
                    score=r["score"],
                )
                for r in results
            ],
            result_count=len(results),
        )
    except ResourceNotFoundError as e:
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except VectorStoreError as e:
        # Log vector-store failures too, so this 500 is not silent in the
        # server logs while the generic 500 below is recorded.
        logger.error("query_vector_store_failed", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("query_failed", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to query knowledge base: {e}",
        ) from e
@router.get("/{kb_id}/status", response_model=KBStatus)
def get_kb_status(
    kb_id: int,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Get status and statistics for a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        kb_manager: KnowledgeBaseManager instance.

    Returns:
        A KBStatus built from the mapping returned by the manager.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure (including a status mapping that does
            not validate against KBStatus).
    """
    try:
        # The manager's mapping is unpacked directly into the schema, so its
        # keys must match the KBStatus field names.
        status_info = kb_manager.get_status(kb_id)
        return KBStatus(**status_info)
    except ResourceNotFoundError as e:
        # Chain the original error so its traceback is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except Exception as e:
        # Boundary handler: log, then surface as a 500. Note that the detail
        # echoes the underlying error text to the client.
        logger.error("kb_status_failed", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get KB status: {e}",
        ) from e