392 lines
11 KiB
Python
392 lines
11 KiB
Python
|
|
"""
|
||
|
|
知识库 API 端点。
|
||
|
|
"""
|
||
|
|
from typing import List, Optional
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
|
||
|
|
from app.db.session import get_db
|
||
|
|
from app.services.kb_manager import KnowledgeBaseManager
|
||
|
|
from app.services.model_manager import ModelManager
|
||
|
|
from app.services.job_manager import AsyncJobManager
|
||
|
|
from app.config import get_settings
|
||
|
|
from app.utils.exceptions import (
|
||
|
|
ResourceNotFoundError,
|
||
|
|
DuplicateResourceError,
|
||
|
|
VectorStoreError,
|
||
|
|
)
|
||
|
|
from app.utils.logger import get_logger
|
||
|
|
|
||
|
|
logger = get_logger(__name__)
|
||
|
|
router = APIRouter(prefix="/kb", tags=["knowledge-base"])
|
||
|
|
|
||
|
|
|
||
|
|
# Pydantic schemas
|
||
|
|
class KBCreate(BaseModel):
    """Request schema for creating a knowledge base."""

    # Required; limited to 128 characters.
    name: str = Field(
        ...,
        description="Knowledge base name",
        max_length=128,
    )
    # Optional; an empty string is used when omitted.
    description: str = Field(
        default="",
        description="Knowledge base description",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class KBResponse(BaseModel):
    """Response schema describing a single knowledge base."""

    id: int
    name: str
    description: str
    # The route handlers in this module populate this with
    # ``created_at.isoformat()``, or "" when the timestamp is unset.
    created_at: str

    class Config:
        # NOTE(review): presumably the Pydantic v2 option that allows building
        # the model from attribute-bearing objects (e.g. ORM rows) - confirm
        # against the installed Pydantic version.
        from_attributes = True
|
||
|
|
|
||
|
|
|
||
|
|
class DocumentIngest(BaseModel):
    """Schema for one document submitted for ingestion."""

    # Required document text.
    content: str = Field(
        ...,
        description="Document content",
    )
    # Optional descriptive fields; both default to an empty string.
    title: str = Field(
        default="",
        description="Document title",
    )
    source: str = Field(
        default="",
        description="Document source",
    )
    # Free-form mapping; a fresh empty dict is created per instance.
    metadata: dict = Field(
        default_factory=dict,
        description="Additional metadata",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class IngestRequest(BaseModel):
    """Request schema for ingesting a batch of documents."""

    # Required list of documents to ingest.
    documents: List[DocumentIngest] = Field(
        ...,
        description="List of documents to ingest",
    )
    # Optional embedding model name; None when the caller does not specify one.
    embedding_name: Optional[str] = Field(
        default=None,
        description="Embedding model name",
    )
    # Background execution is the default.
    background: bool = Field(
        default=True,
        description="Run in background",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class IngestResponse(BaseModel):
    """Response schema reporting the outcome of an ingestion request."""

    # Only populated for background submissions; None otherwise.
    job_id: Optional[str] = Field(
        default=None,
        description="Job ID if background=True",
    )
    # The ingest route in this module uses "submitted" or "completed".
    status: str = Field(
        ...,
        description="Ingestion status",
    )
    # Human-readable summary of what happened.
    message: str
|
||
|
|
|
||
|
|
|
||
|
|
class QueryRequest(BaseModel):
    """Request schema for querying a knowledge base."""

    # Required search text.
    query: str = Field(
        ...,
        description="Search query",
    )
    # Number of results to return; validated to lie in 1..20, default 5.
    k: int = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of results",
    )
    # Optional embedding model name; None when the caller does not specify one.
    embedding_name: Optional[str] = Field(
        default=None,
        description="Embedding model name",
    )
|
||
|
|
|
||
|
|
|
||
|
|
class QueryResult(BaseModel):
    """Schema for a single query result."""

    # The query route fills these from the "content", "metadata" and "score"
    # keys of each result returned by the knowledge base manager.
    content: str
    metadata: dict
    # NOTE(review): whether a higher or lower score means a better match is
    # not visible in this module - confirm against the vector store.
    score: float
|
||
|
|
|
||
|
|
|
||
|
|
class QueryResponse(BaseModel):
    """Response schema for a knowledge base query."""

    # Individual matches returned for the query.
    results: List[QueryResult]
    # The query route sets this to the number of results returned.
    result_count: int
|
||
|
|
|
||
|
|
|
||
|
|
class KBStatus(BaseModel):
    """Response schema for knowledge base status information."""

    # The status route builds this model by unpacking the mapping returned by
    # the manager's ``get_status`` call, so its keys must match these fields.
    kb_id: int
    name: str
    description: str
    document_count: int
    index_exists: bool
    # Annotated Optional, but no default is declared here.
    created_at: Optional[str]
|
||
|
|
|
||
|
|
|
||
|
|
# Dependencies
|
||
|
|
def get_kb_manager(
    db: Session = Depends(get_db),
) -> KnowledgeBaseManager:
    """
    Build a KnowledgeBaseManager for the current request.

    Args:
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A KnowledgeBaseManager wired with a new ModelManager, a new
        AsyncJobManager and the application settings.
    """
    config = get_settings()
    # Collaborators are created on every call, in the same order as before:
    # ModelManager first, then AsyncJobManager.
    return KnowledgeBaseManager(
        db,
        ModelManager(db, config),
        AsyncJobManager(),
        config,
    )
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/", response_model=KBResponse, status_code=status.HTTP_201_CREATED)
def create_kb(
    kb_data: KBCreate,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Create a new knowledge base.

    Args:
        kb_data: Name and description of the knowledge base to create.
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        A KBResponse describing the created knowledge base.

    Raises:
        HTTPException: 400 when the manager raises DuplicateResourceError,
            500 for any other failure.
    """
    try:
        kb = kb_manager.create_kb(name=kb_data.name, description=kb_data.description)

        return KBResponse(
            id=kb.id,
            name=kb.name,
            # Missing description / timestamp are reported as empty strings.
            description=kb.description or "",
            created_at=kb.created_at.isoformat() if kb.created_at else "",
        )

    except DuplicateResourceError as e:
        logger.warning("duplicate_kb", name=kb_data.name)
        # Chain explicitly so the original error is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("kb_creation_failed", error=str(e))
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create knowledge base: {e}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/", response_model=List[KBResponse])
def list_kb(
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    List all knowledge bases.

    Args:
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        A list of KBResponse objects, one per knowledge base returned by the
        manager.

    Raises:
        HTTPException: 500 when listing fails for any reason.
    """
    try:
        kbs = kb_manager.list_kb()

        return [
            KBResponse(
                id=kb.id,
                name=kb.name,
                # Missing description / timestamp are reported as empty strings.
                description=kb.description or "",
                created_at=kb.created_at.isoformat() if kb.created_at else "",
            )
            for kb in kbs
        ]

    except Exception as e:
        logger.error("kb_list_failed", error=str(e))
        # Chain explicitly so the original error is kept as __cause__.
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list knowledge bases: {e}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/{kb_id}", response_model=KBResponse)
def get_kb(
    kb_id: int,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Get a knowledge base by ID.

    Args:
        kb_id: Knowledge base ID.
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        A KBResponse with the knowledge base details.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        kb = kb_manager.get_kb(kb_id)

        return KBResponse(
            id=kb.id,
            name=kb.name,
            # Missing description / timestamp are reported as empty strings.
            description=kb.description or "",
            created_at=kb.created_at.isoformat() if kb.created_at else "",
        )

    except ResourceNotFoundError as e:
        # Chain explicitly so the original error is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("kb_get_failed", kb_id=kb_id, error=str(e))
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get knowledge base: {e}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.delete("/{kb_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_kb(
    kb_id: int,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Delete a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        None; the route responds with 204 No Content on success.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        kb_manager.delete_kb(kb_id)
        logger.info("kb_deleted_via_api", kb_id=kb_id)

    except ResourceNotFoundError as e:
        # Chain explicitly so the original error is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("kb_delete_failed", kb_id=kb_id, error=str(e))
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete knowledge base: {e}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/{kb_id}/ingest", response_model=IngestResponse)
async def ingest_documents(
    kb_id: int,
    ingest_data: IngestRequest,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Ingest documents into a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        ingest_data: Documents plus the embedding and background options.
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        An IngestResponse. With ``background=True`` the status is "submitted"
        and ``job_id`` carries the value returned by the manager; otherwise
        the status is "completed".

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        # Convert the request models into plain dicts for the manager.
        documents = [
            {
                "content": doc.content,
                "title": doc.title,
                "source": doc.source,
                "metadata": doc.metadata,
            }
            for doc in ingest_data.documents
        ]

        result = await kb_manager.ingest_documents(
            kb_id=kb_id,
            documents=documents,
            embedding_name=ingest_data.embedding_name,
            background=ingest_data.background,
        )

        if ingest_data.background:
            # NOTE(review): assumes the manager returns the job ID as a string
            # in background mode - confirm against KnowledgeBaseManager.
            return IngestResponse(
                job_id=result,
                status="submitted",
                message=f"Ingestion job submitted with ID: {result}",
            )

        return IngestResponse(
            status="completed",
            message=f"Successfully ingested {len(documents)} documents",
        )

    except ResourceNotFoundError as e:
        # Chain explicitly so the original error is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("ingest_failed", kb_id=kb_id, error=str(e))
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to ingest documents: {e}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/{kb_id}/query", response_model=QueryResponse)
def query_kb(
    kb_id: int,
    query_data: QueryRequest,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Query a knowledge base using vector similarity search.

    Args:
        kb_id: Knowledge base ID.
        query_data: Query text, result count and optional embedding name.
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        A QueryResponse with the matches and their count.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError
            (the original docs say this covers a missing knowledge base or a
            missing index - TODO confirm); 500 when it raises
            VectorStoreError or any other error.
    """
    try:
        results = kb_manager.query_kb(
            kb_id=kb_id,
            query=query_data.query,
            k=query_data.k,
            embedding_name=query_data.embedding_name,
        )

        return QueryResponse(
            results=[
                QueryResult(
                    content=r["content"],
                    metadata=r["metadata"],
                    score=r["score"],
                )
                for r in results
            ],
            result_count=len(results),
        )

    except ResourceNotFoundError as e:
        # Chain explicitly so the original error is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e
    except VectorStoreError as e:
        # Log this 500 like every other server-side failure in this module;
        # previously vector store errors were returned without a log entry.
        logger.error("query_vector_store_error", kb_id=kb_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("query_failed", kb_id=kb_id, error=str(e))
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to query knowledge base: {e}",
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/{kb_id}/status", response_model=KBStatus)
def get_kb_status(
    kb_id: int,
    kb_manager: KnowledgeBaseManager = Depends(get_kb_manager),
):
    """
    Get status and statistics for a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        kb_manager: KnowledgeBaseManager instance supplied by the dependency.

    Returns:
        A KBStatus built from the mapping returned by the manager.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure (including a mapping that does not
            match the KBStatus fields).
    """
    try:
        status_info = kb_manager.get_status(kb_id)

        # The mapping is unpacked as keyword arguments, so its keys must
        # match the KBStatus field names.
        return KBStatus(**status_info)

    except ResourceNotFoundError as e:
        # Chain explicitly so the original error is kept as __cause__.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("kb_status_failed", kb_id=kb_id, error=str(e))
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get KB status: {e}",
        ) from e
|