langchain-learning-kit/app/services/model_manager.py

350 lines
11 KiB
Python
Raw Normal View History

"""
模型管理器服务用于LLM和Embedding模型管理
"""
import os
import re
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from langchain_core.language_models.llms import BaseLLM
from langchain_core.embeddings import Embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from app.db.models import Model
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError, InvalidConfigError, DuplicateResourceError
from app.utils.logger import get_logger
logger = get_logger(__name__)
class ModelManager:
"""
管理LLM和Embedding模型配置
"""
def __init__(self, db_session: Session, settings: Settings):
"""
初始化模型管理器
Args:
db_session: 数据库会话
settings: 应用程序配置
"""
self.db = db_session
self.settings = settings
def _replace_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""
替换配置中的环境变量占位符
Args:
config: 配置字典
Returns:
Dict[str, Any]: 替换环境变量后的配置
"""
result = {}
for key, value in config.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# 提取变量名: ${VAR_NAME} -> VAR_NAME
var_name = value[2:-1]
# 首先尝试从os.environ获取然后从settings获取
env_value = os.getenv(var_name)
if env_value is None:
# 尝试从settings对象获取
if hasattr(self.settings, var_name.lower()):
env_value = getattr(self.settings, var_name.lower())
if env_value is None:
raise InvalidConfigError(
f"Environment variable {var_name} not found",
details={"variable": var_name}
)
result[key] = env_value
else:
result[key] = value
return result
def create_model(
self,
name: str,
type: str,
config: Dict[str, Any],
is_default: bool = False,
) -> Model:
"""
创建新的模型配置
Args:
name: 模型名称
type: 模型类型'llm' 'embedding'
config: 模型配置
is_default: 设置为默认模型
Returns:
Model: 创建的模型实例
Raises:
DuplicateResourceError: 如果存在同名模型
InvalidConfigError: 如果配置无效
"""
logger.info("creating_model", name=name, type=type)
# 检查模型是否已存在
existing = self.db.query(Model).filter(Model.name == name).first()
if existing:
raise DuplicateResourceError(f"Model '{name}' already exists")
# 验证类型
if type not in ["llm", "embedding"]:
raise InvalidConfigError(f"Invalid model type: {type}. Must be 'llm' or 'embedding'")
# 如果设置为默认,取消其他默认设置
if is_default:
self.db.query(Model).filter(Model.type == type, Model.is_default == True).update(
{"is_default": False}
)
# 创建模型
model = Model(
name=name,
type=type,
config=config,
is_default=is_default,
status="active",
)
self.db.add(model)
self.db.commit()
self.db.refresh(model)
logger.info("model_created", model_id=model.id, name=name)
return model
def get_model(self, model_id: int) -> Model:
"""
根据ID获取模型
Args:
model_id: 模型ID
Returns:
Model: 模型实例
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.db.query(Model).filter(Model.id == model_id).first()
if not model:
raise ResourceNotFoundError(f"Model not found: {model_id}")
return model
def get_model_by_name(self, name: str) -> Model:
"""
根据名称获取模型
Args:
name: 模型名称
Returns:
Model: 模型实例
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.db.query(Model).filter(Model.name == name).first()
if not model:
raise ResourceNotFoundError(f"Model not found: {name}")
return model
def list_models(self, type: Optional[str] = None) -> List[Model]:
"""
列出所有模型可选择按类型过滤
Args:
type: 可选按模型类型过滤
Returns:
List[Model]: 模型列表
"""
query = self.db.query(Model).filter(Model.status == "active")
if type:
query = query.filter(Model.type == type)
return query.all()
def update_model(
self,
model_id: int,
config: Optional[Dict[str, Any]] = None,
is_default: Optional[bool] = None,
status: Optional[str] = None,
) -> Model:
"""
更新模型配置
Args:
model_id: 模型ID
config: 新配置
is_default: 更新默认状态
status: 更新状态
Returns:
Model: 更新后的模型实例
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.get_model(model_id)
if config is not None:
model.config = config
if is_default is not None and is_default:
# 取消其他默认设置
self.db.query(Model).filter(
Model.type == model.type,
Model.is_default == True,
Model.id != model_id
).update({"is_default": False})
model.is_default = True
if status is not None:
model.status = status
self.db.commit()
self.db.refresh(model)
logger.info("model_updated", model_id=model_id)
return model
def delete_model(self, model_id: int) -> bool:
"""
删除软删除模型
Args:
model_id: 模型ID
Returns:
bool: 如果删除成功返回True
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.get_model(model_id)
model.status = "deleted"
self.db.commit()
logger.info("model_deleted", model_id=model_id)
return True
def get_llm(self, name: Optional[str] = None) -> BaseLLM:
"""
根据名称获取LLM实例或使用默认实例
Args:
name: 模型名称None表示使用默认
Returns:
BaseLLM: LangChain LLM实例
Raises:
ResourceNotFoundError: 如果模型未找到
InvalidConfigError: 如果模型类型不是'llm'
"""
if name:
model = self.get_model_by_name(name)
else:
# 获取默认LLM
model = self.db.query(Model).filter(
Model.type == "llm",
Model.is_default == True,
Model.status == "active"
).first()
if not model:
raise ResourceNotFoundError("No default LLM model configured")
if model.type != "llm":
raise InvalidConfigError(f"Model '{model.name}' is not an LLM (type: {model.type})")
# 替换配置中的环境变量
config = self._replace_env_vars(model.config)
# 如果需要修正API密钥参数名称
if "api_key" in config and "openai_api_key" not in config:
config["openai_api_key"] = config.pop("api_key")
# 如果配置中没有API密钥从settings添加
if "openai_api_key" not in config:
if self.settings.openai_api_key:
config["openai_api_key"] = self.settings.openai_api_key
else:
raise InvalidConfigError("OPENAI_API_KEY not configured in settings or model config")
# 如果已配置从settings添加base_url
if self.settings.openai_base_url and "base_url" not in config:
config["base_url"] = self.settings.openai_base_url
logger.info("getting_llm", model_name=model.name, base_url=config.get("base_url"))
# 目前仅支持OpenAI可以扩展
return ChatOpenAI(**config)
def get_embedding(self, name: Optional[str] = None) -> Embeddings:
"""
根据名称获取Embedding实例或使用默认实例
Args:
name: 模型名称None表示使用默认
Returns:
Embeddings: LangChain Embeddings实例
Raises:
ResourceNotFoundError: 如果模型未找到
InvalidConfigError: 如果模型类型不是'embedding'
"""
if name:
model = self.get_model_by_name(name)
else:
# 获取默认embedding
model = self.db.query(Model).filter(
Model.type == "embedding",
Model.is_default == True,
Model.status == "active"
).first()
if not model:
raise ResourceNotFoundError("No default embedding model configured")
if model.type != "embedding":
raise InvalidConfigError(f"Model '{model.name}' is not an embedding model (type: {model.type})")
# 替换配置中的环境变量
config = self._replace_env_vars(model.config)
# 如果需要修正API密钥参数名称
if "api_key" in config and "openai_api_key" not in config:
config["openai_api_key"] = config.pop("api_key")
# 如果配置中没有API密钥从settings添加
if "openai_api_key" not in config:
if self.settings.openai_api_key:
config["openai_api_key"] = self.settings.openai_api_key
else:
raise InvalidConfigError("OPENAI_API_KEY not configured in settings or model config")
# 移除embedding的无效参数
config.pop("temperature", None)
# 如果已配置从settings添加base_url
if self.settings.openai_base_url and "base_url" not in config:
config["base_url"] = self.settings.openai_base_url
logger.info("getting_embedding", model_name=model.name, base_url=config.get("base_url"))
# 目前仅支持OpenAI可以扩展
return OpenAIEmbeddings(**config)