langchain-learning-kit/app/api/models.py

265 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模型管理 API 端点。
"""
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session

from app.db.session import get_db
from app.services.model_manager import ModelManager
from app.config import get_settings
from app.utils.exceptions import ResourceNotFoundError, DuplicateResourceError
from app.utils.logger import get_logger
logger = get_logger(__name__)

# Every endpoint below is mounted under "/models" and grouped under the
# "models" tag in the generated OpenAPI documentation.
router = APIRouter(tags=["models"], prefix="/models")
# Pydantic 模式
class ModelCreate(BaseModel):
"""创建模型的模式。"""
name: str = Field(..., description="Model name", max_length=128)
type: str = Field(..., description="Model type: 'llm' or 'embedding'")
config: dict = Field(..., description="Model configuration")
is_default: bool = Field(default=False, description="Set as default model")
class ModelUpdate(BaseModel):
    """Request body for a partial update of an existing model.

    Every field is optional and defaults to ``None``. The values are passed
    to ``ModelManager.update_model`` unchanged (presumably ``None`` means
    "leave as is" — confirm in ModelManager).
    """

    config: Optional[dict] = Field(default=None, description="Model configuration")
    is_default: Optional[bool] = Field(default=None, description="Set as default model")
    status: Optional[str] = Field(default=None, description="Model status")
class ModelResponse(BaseModel):
    """Response schema describing a single model configuration.

    ``created_at`` is carried as a string rather than a datetime.
    """

    # Pydantic v2 configuration. The class-based ``class Config`` form is
    # deprecated in v2. ``from_attributes`` allows validating directly from
    # attribute-bearing objects such as ORM instances.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    type: str
    config: dict
    is_default: bool
    status: str
    created_at: str
# Dependency provider for ModelManager


def get_model_manager(
    db: Session = Depends(get_db),
) -> ModelManager:
    """Construct a ModelManager for the current request.

    FastAPI injects ``db`` through the ``get_db`` dependency. The
    application settings are obtained from ``get_settings()`` on every call.
    """
    return ModelManager(db, get_settings())
@router.post("/", response_model=ModelResponse, status_code=status.HTTP_201_CREATED)
def create_model(
    model_data: ModelCreate,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Create a new model configuration.

    Args:
        model_data: Validated request body describing the new model.
        model_manager: ModelManager instance injected by FastAPI.

    Returns:
        The created model as a ModelResponse.

    Raises:
        HTTPException: 400 when ModelManager reports a duplicate
            (DuplicateResourceError); 500 for any other failure, after the
            database session has been rolled back.
    """
    try:
        model = model_manager.create_model(
            name=model_data.name,
            type=model_data.type,
            config=model_data.config,
            is_default=model_data.is_default,
        )
        return ModelResponse(
            id=model.id,
            name=model.name,
            type=model.type,
            config=model.config,
            # Map possibly-NULL columns onto the response's non-null fields.
            is_default=model.is_default or False,
            status=model.status or "active",
            created_at=model.created_at.isoformat() if model.created_at else "",
        )
    except DuplicateResourceError as e:
        logger.warning("duplicate_model", name=model_data.name)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
        ) from e
    except Exception as e:
        logger.error("model_creation_failed", error=str(e))
        # Roll back so a failure inside ModelManager.create_model does not
        # leave the session holding failed or pending changes. The other
        # write endpoints in this module already do this.
        model_manager.db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create model: {str(e)}",
        ) from e
@router.get("/", response_model=List[ModelResponse])
def list_models(
    model_type: Optional[str] = None,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    List registered models, optionally filtered by type.

    Args:
        model_type: Optional type filter ('llm' or 'embedding'), forwarded to
            ``ModelManager.list_models`` as its ``type`` argument.
        model_manager: ModelManager instance injected by FastAPI.

    Returns:
        A list of ModelResponse objects, one per model returned by the manager.

    Raises:
        HTTPException: 500 if fetching or converting the models fails.
    """
    try:
        responses = []
        for record in model_manager.list_models(type=model_type):
            responses.append(
                ModelResponse(
                    id=record.id,
                    name=record.name,
                    type=record.type,
                    config=record.config,
                    is_default=record.is_default or False,
                    status=record.status or "active",
                    created_at=(
                        record.created_at.isoformat() if record.created_at else ""
                    ),
                )
            )
        return responses
    except Exception as exc:
        logger.error("model_list_failed", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list models: {str(exc)}",
        )
@router.get("/{model_id}", response_model=ModelResponse)
def get_model(
    model_id: int,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Fetch a single model by its ID.

    Args:
        model_id: Primary key of the model to fetch.
        model_manager: ModelManager instance injected by FastAPI.

    Returns:
        The matching model as a ModelResponse.

    Raises:
        HTTPException: 404 if ModelManager raises ResourceNotFoundError;
            500 for any other failure.
    """
    try:
        record = model_manager.get_model(model_id)
        fields = dict(
            id=record.id,
            name=record.name,
            type=record.type,
            config=record.config,
            is_default=record.is_default or False,
            status=record.status or "active",
            created_at=record.created_at.isoformat() if record.created_at else "",
        )
        return ModelResponse(**fields)
    except ResourceNotFoundError as not_found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(not_found)
        )
    except Exception as exc:
        logger.error("model_get_failed", model_id=model_id, error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get model: {str(exc)}",
        )
@router.patch("/{model_id}", response_model=ModelResponse)
def update_model(
    model_id: int,
    model_data: ModelUpdate,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Apply a partial update to an existing model.

    Args:
        model_id: Primary key of the model to update.
        model_data: Fields to change; each is forwarded to
            ``ModelManager.update_model`` as-is, including ``None`` values.
        model_manager: ModelManager instance injected by FastAPI.

    Returns:
        The updated model as a ModelResponse.

    Raises:
        HTTPException: 404 if ModelManager raises ResourceNotFoundError;
            500 for any other failure, after rolling back the session.
    """
    try:
        updated = model_manager.update_model(
            model_id=model_id,
            config=model_data.config,
            is_default=model_data.is_default,
            status=model_data.status,
        )
        logger.info("model_updated", model_id=model_id)
        fields = dict(
            id=updated.id,
            name=updated.name,
            type=updated.type,
            config=updated.config,
            is_default=updated.is_default or False,
            status=updated.status or "active",
            created_at=updated.created_at.isoformat() if updated.created_at else "",
        )
        return ModelResponse(**fields)
    except ResourceNotFoundError as not_found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(not_found)
        )
    except Exception as exc:
        logger.error("model_update_failed", model_id=model_id, error=str(exc))
        # Discard whatever the failed update left in the session.
        model_manager.db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update model: {str(exc)}",
        )
@router.delete("/{model_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_model(
    model_id: int,
    model_manager: ModelManager = Depends(get_model_manager),
):
    """
    Remove a model by its ID; responds with 204 and no body on success.

    Args:
        model_id: Primary key of the model to delete.
        model_manager: ModelManager instance injected by FastAPI.

    Raises:
        HTTPException: 404 if ModelManager raises ResourceNotFoundError;
            500 for any other failure, after rolling back the session.
    """
    try:
        model_manager.delete_model(model_id)
        logger.info("model_deleted", model_id=model_id)
    except ResourceNotFoundError as not_found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(not_found)
        )
    except Exception as exc:
        logger.error("model_delete_failed", model_id=model_id, error=str(exc))
        # Discard whatever the failed delete left in the session.
        model_manager.db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete model: {str(exc)}",
        )