265 lines
7.5 KiB
Python
265 lines
7.5 KiB
Python
"""
|
||
Model management API endpoints.
|
||
"""
|
||
from typing import List, Optional
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.db.session import get_db
|
||
from app.services.model_manager import ModelManager
|
||
from app.config import get_settings
|
||
from app.utils.exceptions import ResourceNotFoundError, DuplicateResourceError
|
||
from app.utils.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
router = APIRouter(prefix="/models", tags=["models"])
|
||
|
||
|
||
# Pydantic schemas
|
||
class ModelCreate(BaseModel):
    """Request body accepted when creating a model configuration."""

    # Model name; validation rejects values longer than 128 characters.
    name: str = Field(..., max_length=128, description="Model name")
    # Kind of model; the description advertises 'llm' or 'embedding', but
    # no validator here enforces that set.
    type: str = Field(..., description="Model type: 'llm' or 'embedding'")
    # Free-form configuration mapping, handed to ModelManager unchanged by
    # the create endpoint.
    config: dict = Field(..., description="Model configuration")
    # Whether the new model should be flagged as the default; off unless
    # the caller asks for it.
    is_default: bool = Field(False, description="Set as default model")
|
||
|
||
|
||
class ModelUpdate(BaseModel):
    """Request body accepted when updating a model configuration.

    Every field is optional and defaults to ``None``.
    """

    # New configuration mapping.
    config: Optional[dict] = Field(default=None, description="Model configuration")
    # New value for the default-model flag.
    is_default: Optional[bool] = Field(default=None, description="Set as default model")
    # New status string; the allowed values are not constrained here.
    status: Optional[str] = Field(default=None, description="Model status")
|
||
|
||
|
||
class ModelResponse(BaseModel):
    """Shape of a model configuration as returned by the endpoints."""

    id: int
    name: str
    type: str
    config: dict
    is_default: bool
    status: str
    # Declared as a plain string: the endpoints in this module pass an
    # ISO-formatted timestamp, or "" when the record has none.
    created_at: str

    class Config:
        # Permit building the schema from objects exposing these fields as
        # attributes rather than only from dicts.
        from_attributes = True
|
||
|
||
|
||
# Dependency provider for ModelManager
|
||
def get_model_manager(
    db: Session = Depends(get_db),
) -> ModelManager:
    """Build a ModelManager for the current request.

    Args:
        db: Database session resolved through the ``get_db`` dependency.

    Returns:
        A ModelManager constructed from ``db`` and the value returned by
        ``get_settings()``.
    """
    return ModelManager(db, get_settings())
|
||
|
||
|
||
@router.post("/", response_model=ModelResponse, status_code=status.HTTP_201_CREATED)
def create_model(
    model_data: ModelCreate,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Create a new model configuration.

    Args:
        model_data: Validated creation payload.
        model_manager: ModelManager instance resolved by dependency injection.

    Returns:
        The created model as a ModelResponse.

    Raises:
        HTTPException: 400 when the manager raises DuplicateResourceError,
            500 for any other failure.
    """
    try:
        model = model_manager.create_model(
            name=model_data.name,
            type=model_data.type,
            config=model_data.config,
            is_default=model_data.is_default,
        )

        # Log successful writes the same way the update/delete endpoints do.
        logger.info("model_created", model_id=model.id)

        return ModelResponse(
            id=model.id,
            name=model.name,
            type=model.type,
            config=model.config,
            is_default=model.is_default or False,
            status=model.status or "active",
            created_at=model.created_at.isoformat() if model.created_at else "",
        )

    except DuplicateResourceError as e:
        logger.warning("duplicate_model", name=model_data.name)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("model_creation_failed", error=str(e))
        # This endpoint writes to the database, so roll the session back on
        # an unexpected failure, as the update and delete endpoints do.
        model_manager.db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create model: {str(e)}",
        ) from e
|
||
|
||
|
||
@router.get("/", response_model=List[ModelResponse])
def list_models(
    model_type: Optional[str] = None,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    List models, optionally restricted to one type.

    Args:
        model_type: Optional type filter, forwarded to the manager as ``type``.
        model_manager: ModelManager instance resolved by dependency injection.

    Returns:
        A list of ModelResponse objects, one per model returned by the manager.

    Raises:
        HTTPException: 500 when listing or serialization fails.
    """
    try:
        records = model_manager.list_models(type=model_type)

        return [
            ModelResponse(
                id=record.id,
                name=record.name,
                type=record.type,
                config=record.config,
                # Falsy flag/status values fall back to fixed defaults.
                is_default=record.is_default or False,
                status=record.status or "active",
                created_at="" if not record.created_at else record.created_at.isoformat(),
            )
            for record in records
        ]

    except Exception as exc:
        reason = str(exc)
        logger.error("model_list_failed", error=reason)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list models: {reason}",
        )
|
||
|
||
|
||
@router.get("/{model_id}", response_model=ModelResponse)
def get_model(
    model_id: int,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Fetch a single model by its ID.

    Args:
        model_id: ID of the model to look up.
        model_manager: ModelManager instance resolved by dependency injection.

    Returns:
        The matching model as a ModelResponse.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        found = model_manager.get_model(model_id)

        # Gather the response fields first; falsy flag/status values fall
        # back to fixed defaults.
        fields = {
            "id": found.id,
            "name": found.name,
            "type": found.type,
            "config": found.config,
            "is_default": found.is_default or False,
            "status": found.status or "active",
        }
        if found.created_at:
            fields["created_at"] = found.created_at.isoformat()
        else:
            fields["created_at"] = ""

        return ModelResponse(**fields)

    except ResourceNotFoundError as missing:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(missing))
    except Exception as exc:
        reason = str(exc)
        logger.error("model_get_failed", model_id=model_id, error=reason)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get model: {reason}",
        )
|
||
|
||
|
||
@router.patch("/{model_id}", response_model=ModelResponse)
def update_model(
    model_id: int,
    model_data: ModelUpdate,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Update an existing model configuration.

    Args:
        model_id: ID of the model to update.
        model_data: Update payload; its fields are forwarded as-is,
            including any that are ``None``.
        model_manager: ModelManager instance resolved by dependency injection.

    Returns:
        The updated model as a ModelResponse.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure (after rolling the session back).
    """
    try:
        # The manager performs the update; this handler only maps the result.
        updated = model_manager.update_model(
            model_id=model_id,
            config=model_data.config,
            is_default=model_data.is_default,
            status=model_data.status,
        )

        logger.info("model_updated", model_id=model_id)

        # Falsy flag/status values fall back to fixed defaults.
        fields = {
            "id": updated.id,
            "name": updated.name,
            "type": updated.type,
            "config": updated.config,
            "is_default": updated.is_default or False,
            "status": updated.status or "active",
        }
        if updated.created_at:
            fields["created_at"] = updated.created_at.isoformat()
        else:
            fields["created_at"] = ""

        return ModelResponse(**fields)

    except ResourceNotFoundError as missing:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(missing))
    except Exception as exc:
        reason = str(exc)
        logger.error("model_update_failed", model_id=model_id, error=reason)
        # Roll the session back before reporting the failure.
        model_manager.db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update model: {reason}",
        )
|
||
|
||
|
||
@router.delete("/{model_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_model(
    model_id: int,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Delete a model.

    Args:
        model_id: ID of the model to delete.
        model_manager: ModelManager instance resolved by dependency injection.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure (after rolling the session back).
    """
    try:
        model_manager.delete_model(model_id)
        logger.info("model_deleted", model_id=model_id)

    except ResourceNotFoundError as missing:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(missing))
    except Exception as exc:
        reason = str(exc)
        logger.error("model_delete_failed", model_id=model_id, error=reason)
        # Roll the session back before reporting the failure.
        model_manager.db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete model: {reason}",
        )
|