350 lines
11 KiB
Python
350 lines
11 KiB
Python
"""
|
||
模型管理器服务,用于LLM和Embedding模型管理。
|
||
"""
|
||
import os
|
||
import re
|
||
from typing import Optional, List, Dict, Any
|
||
|
||
from sqlalchemy.orm import Session
|
||
from langchain_core.language_models.llms import BaseLLM
|
||
from langchain_core.embeddings import Embeddings
|
||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||
|
||
from app.db.models import Model
|
||
from app.config import Settings
|
||
from app.utils.exceptions import ResourceNotFoundError, InvalidConfigError, DuplicateResourceError
|
||
from app.utils.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class ModelManager:
|
||
"""
|
||
管理LLM和Embedding模型配置。
|
||
"""
|
||
|
||
def __init__(self, db_session: Session, settings: Settings):
|
||
"""
|
||
初始化模型管理器。
|
||
|
||
Args:
|
||
db_session: 数据库会话
|
||
settings: 应用程序配置
|
||
"""
|
||
self.db = db_session
|
||
self.settings = settings
|
||
|
||
def _replace_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
替换配置中的环境变量占位符。
|
||
|
||
Args:
|
||
config: 配置字典
|
||
|
||
Returns:
|
||
Dict[str, Any]: 替换环境变量后的配置
|
||
"""
|
||
result = {}
|
||
for key, value in config.items():
|
||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||
# 提取变量名: ${VAR_NAME} -> VAR_NAME
|
||
var_name = value[2:-1]
|
||
|
||
# 首先尝试从os.environ获取,然后从settings获取
|
||
env_value = os.getenv(var_name)
|
||
if env_value is None:
|
||
# 尝试从settings对象获取
|
||
if hasattr(self.settings, var_name.lower()):
|
||
env_value = getattr(self.settings, var_name.lower())
|
||
|
||
if env_value is None:
|
||
raise InvalidConfigError(
|
||
f"Environment variable {var_name} not found",
|
||
details={"variable": var_name}
|
||
)
|
||
result[key] = env_value
|
||
else:
|
||
result[key] = value
|
||
return result
|
||
|
||
def create_model(
|
||
self,
|
||
name: str,
|
||
type: str,
|
||
config: Dict[str, Any],
|
||
is_default: bool = False,
|
||
) -> Model:
|
||
"""
|
||
创建新的模型配置。
|
||
|
||
Args:
|
||
name: 模型名称
|
||
type: 模型类型('llm' 或 'embedding')
|
||
config: 模型配置
|
||
is_default: 设置为默认模型
|
||
|
||
Returns:
|
||
Model: 创建的模型实例
|
||
|
||
Raises:
|
||
DuplicateResourceError: 如果存在同名模型
|
||
InvalidConfigError: 如果配置无效
|
||
"""
|
||
logger.info("creating_model", name=name, type=type)
|
||
|
||
# 检查模型是否已存在
|
||
existing = self.db.query(Model).filter(Model.name == name).first()
|
||
if existing:
|
||
raise DuplicateResourceError(f"Model '{name}' already exists")
|
||
|
||
# 验证类型
|
||
if type not in ["llm", "embedding"]:
|
||
raise InvalidConfigError(f"Invalid model type: {type}. Must be 'llm' or 'embedding'")
|
||
|
||
# 如果设置为默认,取消其他默认设置
|
||
if is_default:
|
||
self.db.query(Model).filter(Model.type == type, Model.is_default == True).update(
|
||
{"is_default": False}
|
||
)
|
||
|
||
# 创建模型
|
||
model = Model(
|
||
name=name,
|
||
type=type,
|
||
config=config,
|
||
is_default=is_default,
|
||
status="active",
|
||
)
|
||
|
||
self.db.add(model)
|
||
self.db.commit()
|
||
self.db.refresh(model)
|
||
|
||
logger.info("model_created", model_id=model.id, name=name)
|
||
return model
|
||
|
||
def get_model(self, model_id: int) -> Model:
|
||
"""
|
||
根据ID获取模型。
|
||
|
||
Args:
|
||
model_id: 模型ID
|
||
|
||
Returns:
|
||
Model: 模型实例
|
||
|
||
Raises:
|
||
ResourceNotFoundError: 如果模型未找到
|
||
"""
|
||
model = self.db.query(Model).filter(Model.id == model_id).first()
|
||
if not model:
|
||
raise ResourceNotFoundError(f"Model not found: {model_id}")
|
||
return model
|
||
|
||
def get_model_by_name(self, name: str) -> Model:
|
||
"""
|
||
根据名称获取模型。
|
||
|
||
Args:
|
||
name: 模型名称
|
||
|
||
Returns:
|
||
Model: 模型实例
|
||
|
||
Raises:
|
||
ResourceNotFoundError: 如果模型未找到
|
||
"""
|
||
model = self.db.query(Model).filter(Model.name == name).first()
|
||
if not model:
|
||
raise ResourceNotFoundError(f"Model not found: {name}")
|
||
return model
|
||
|
||
def list_models(self, type: Optional[str] = None) -> List[Model]:
|
||
"""
|
||
列出所有模型,可选择按类型过滤。
|
||
|
||
Args:
|
||
type: 可选,按模型类型过滤
|
||
|
||
Returns:
|
||
List[Model]: 模型列表
|
||
"""
|
||
query = self.db.query(Model).filter(Model.status == "active")
|
||
if type:
|
||
query = query.filter(Model.type == type)
|
||
return query.all()
|
||
|
||
def update_model(
|
||
self,
|
||
model_id: int,
|
||
config: Optional[Dict[str, Any]] = None,
|
||
is_default: Optional[bool] = None,
|
||
status: Optional[str] = None,
|
||
) -> Model:
|
||
"""
|
||
更新模型配置。
|
||
|
||
Args:
|
||
model_id: 模型ID
|
||
config: 新配置
|
||
is_default: 更新默认状态
|
||
status: 更新状态
|
||
|
||
Returns:
|
||
Model: 更新后的模型实例
|
||
|
||
Raises:
|
||
ResourceNotFoundError: 如果模型未找到
|
||
"""
|
||
model = self.get_model(model_id)
|
||
|
||
if config is not None:
|
||
model.config = config
|
||
|
||
if is_default is not None and is_default:
|
||
# 取消其他默认设置
|
||
self.db.query(Model).filter(
|
||
Model.type == model.type,
|
||
Model.is_default == True,
|
||
Model.id != model_id
|
||
).update({"is_default": False})
|
||
model.is_default = True
|
||
|
||
if status is not None:
|
||
model.status = status
|
||
|
||
self.db.commit()
|
||
self.db.refresh(model)
|
||
|
||
logger.info("model_updated", model_id=model_id)
|
||
return model
|
||
|
||
def delete_model(self, model_id: int) -> bool:
|
||
"""
|
||
删除(软删除)模型。
|
||
|
||
Args:
|
||
model_id: 模型ID
|
||
|
||
Returns:
|
||
bool: 如果删除成功返回True
|
||
|
||
Raises:
|
||
ResourceNotFoundError: 如果模型未找到
|
||
"""
|
||
model = self.get_model(model_id)
|
||
model.status = "deleted"
|
||
self.db.commit()
|
||
|
||
logger.info("model_deleted", model_id=model_id)
|
||
return True
|
||
|
||
def get_llm(self, name: Optional[str] = None) -> BaseLLM:
|
||
"""
|
||
根据名称获取LLM实例或使用默认实例。
|
||
|
||
Args:
|
||
name: 模型名称(None表示使用默认)
|
||
|
||
Returns:
|
||
BaseLLM: LangChain LLM实例
|
||
|
||
Raises:
|
||
ResourceNotFoundError: 如果模型未找到
|
||
InvalidConfigError: 如果模型类型不是'llm'
|
||
"""
|
||
if name:
|
||
model = self.get_model_by_name(name)
|
||
else:
|
||
# 获取默认LLM
|
||
model = self.db.query(Model).filter(
|
||
Model.type == "llm",
|
||
Model.is_default == True,
|
||
Model.status == "active"
|
||
).first()
|
||
|
||
if not model:
|
||
raise ResourceNotFoundError("No default LLM model configured")
|
||
|
||
if model.type != "llm":
|
||
raise InvalidConfigError(f"Model '{model.name}' is not an LLM (type: {model.type})")
|
||
|
||
# 替换配置中的环境变量
|
||
config = self._replace_env_vars(model.config)
|
||
|
||
# 如果需要,修正API密钥参数名称
|
||
if "api_key" in config and "openai_api_key" not in config:
|
||
config["openai_api_key"] = config.pop("api_key")
|
||
|
||
# 如果配置中没有API密钥,从settings添加
|
||
if "openai_api_key" not in config:
|
||
if self.settings.openai_api_key:
|
||
config["openai_api_key"] = self.settings.openai_api_key
|
||
else:
|
||
raise InvalidConfigError("OPENAI_API_KEY not configured in settings or model config")
|
||
|
||
# 如果已配置,从settings添加base_url
|
||
if self.settings.openai_base_url and "base_url" not in config:
|
||
config["base_url"] = self.settings.openai_base_url
|
||
|
||
logger.info("getting_llm", model_name=model.name, base_url=config.get("base_url"))
|
||
|
||
# 目前仅支持OpenAI,可以扩展
|
||
return ChatOpenAI(**config)
|
||
|
||
def get_embedding(self, name: Optional[str] = None) -> Embeddings:
|
||
"""
|
||
根据名称获取Embedding实例或使用默认实例。
|
||
|
||
Args:
|
||
name: 模型名称(None表示使用默认)
|
||
|
||
Returns:
|
||
Embeddings: LangChain Embeddings实例
|
||
|
||
Raises:
|
||
ResourceNotFoundError: 如果模型未找到
|
||
InvalidConfigError: 如果模型类型不是'embedding'
|
||
"""
|
||
if name:
|
||
model = self.get_model_by_name(name)
|
||
else:
|
||
# 获取默认embedding
|
||
model = self.db.query(Model).filter(
|
||
Model.type == "embedding",
|
||
Model.is_default == True,
|
||
Model.status == "active"
|
||
).first()
|
||
|
||
if not model:
|
||
raise ResourceNotFoundError("No default embedding model configured")
|
||
|
||
if model.type != "embedding":
|
||
raise InvalidConfigError(f"Model '{model.name}' is not an embedding model (type: {model.type})")
|
||
|
||
# 替换配置中的环境变量
|
||
config = self._replace_env_vars(model.config)
|
||
|
||
# 如果需要,修正API密钥参数名称
|
||
if "api_key" in config and "openai_api_key" not in config:
|
||
config["openai_api_key"] = config.pop("api_key")
|
||
|
||
# 如果配置中没有API密钥,从settings添加
|
||
if "openai_api_key" not in config:
|
||
if self.settings.openai_api_key:
|
||
config["openai_api_key"] = self.settings.openai_api_key
|
||
else:
|
||
raise InvalidConfigError("OPENAI_API_KEY not configured in settings or model config")
|
||
|
||
# 移除embedding的无效参数
|
||
config.pop("temperature", None)
|
||
|
||
# 如果已配置,从settings添加base_url
|
||
if self.settings.openai_base_url and "base_url" not in config:
|
||
config["base_url"] = self.settings.openai_base_url
|
||
|
||
logger.info("getting_embedding", model_name=model.name, base_url=config.get("base_url"))
|
||
|
||
# 目前仅支持OpenAI,可以扩展
|
||
return OpenAIEmbeddings(**config)
|