langchain-learning-kit/app/services/model_manager.py

"""
模型管理器服务用于LLM和Embedding模型管理。
"""
import os
import re
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from langchain_core.language_models.llms import BaseLLM
from langchain_core.embeddings import Embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from app.db.models import Model
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError, InvalidConfigError, DuplicateResourceError
from app.utils.logger import get_logger
logger = get_logger(__name__)
class ModelManager:
"""
管理LLM和Embedding模型配置。
"""
def __init__(self, db_session: Session, settings: Settings):
"""
初始化模型管理器。
Args:
db_session: 数据库会话
settings: 应用程序配置
"""
self.db = db_session
self.settings = settings
def _replace_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""
替换配置中的环境变量占位符。
Args:
config: 配置字典
Returns:
Dict[str, Any]: 替换环境变量后的配置
"""
result = {}
for key, value in config.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                # Extract the variable name: ${VAR_NAME} -> VAR_NAME
var_name = value[2:-1]
                # Try os.environ first, then fall back to settings
env_value = os.getenv(var_name)
if env_value is None:
                    # Try the settings object
if hasattr(self.settings, var_name.lower()):
env_value = getattr(self.settings, var_name.lower())
if env_value is None:
raise InvalidConfigError(
f"Environment variable {var_name} not found",
details={"variable": var_name}
)
result[key] = env_value
else:
result[key] = value
return result
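
    # Example (illustrative only): given config = {"api_key": "${OPENAI_API_KEY}"},
    # _replace_env_vars first checks os.environ["OPENAI_API_KEY"], then
    # settings.openai_api_key, and returns {"api_key": "sk-..."}. If neither source
    # defines the variable, InvalidConfigError is raised.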
def create_model(
self,
name: str,
type: str,
config: Dict[str, Any],
is_default: bool = False,
) -> Model:
"""
创建新的模型配置。
Args:
name: 模型名称
type: 模型类型('llm''embedding'
config: 模型配置
is_default: 设置为默认模型
Returns:
Model: 创建的模型实例
Raises:
DuplicateResourceError: 如果存在同名模型
InvalidConfigError: 如果配置无效
"""
logger.info("creating_model", name=name, type=type)
        # Check whether a model with this name already exists
existing = self.db.query(Model).filter(Model.name == name).first()
if existing:
raise DuplicateResourceError(f"Model '{name}' already exists")
        # Validate the model type
if type not in ["llm", "embedding"]:
raise InvalidConfigError(f"Invalid model type: {type}. Must be 'llm' or 'embedding'")
        # If this model is set as default, unset any existing default of the same type
if is_default:
self.db.query(Model).filter(Model.type == type, Model.is_default == True).update(
{"is_default": False}
)
        # Create the model record
model = Model(
name=name,
type=type,
config=config,
is_default=is_default,
status="active",
)
self.db.add(model)
self.db.commit()
self.db.refresh(model)
logger.info("model_created", model_id=model.id, name=name)
return model
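
    # Usage sketch (the model name and parameters below are illustrative, not part
    # of this repo's fixtures):
    #   manager.create_model(
    #       name="gpt-4o-mini",
    #       type="llm",
    #       config={"model": "gpt-4o-mini", "temperature": 0.7,
    #               "api_key": "${OPENAI_API_KEY}"},
    #       is_default=True,
    #   )
    # Stored configs may use ${VAR} placeholders; they are resolved by
    # _replace_env_vars when the model is loaded, not at creation time.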
def get_model(self, model_id: int) -> Model:
"""
根据ID获取模型。
Args:
model_id: 模型ID
Returns:
Model: 模型实例
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.db.query(Model).filter(Model.id == model_id).first()
if not model:
raise ResourceNotFoundError(f"Model not found: {model_id}")
return model
def get_model_by_name(self, name: str) -> Model:
"""
根据名称获取模型。
Args:
name: 模型名称
Returns:
Model: 模型实例
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.db.query(Model).filter(Model.name == name).first()
if not model:
raise ResourceNotFoundError(f"Model not found: {name}")
return model
def list_models(self, type: Optional[str] = None) -> List[Model]:
"""
列出所有模型,可选择按类型过滤。
Args:
type: 可选,按模型类型过滤
Returns:
List[Model]: 模型列表
"""
query = self.db.query(Model).filter(Model.status == "active")
if type:
query = query.filter(Model.type == type)
return query.all()
def update_model(
self,
model_id: int,
config: Optional[Dict[str, Any]] = None,
is_default: Optional[bool] = None,
status: Optional[str] = None,
) -> Model:
"""
更新模型配置。
Args:
model_id: 模型ID
config: 新配置
is_default: 更新默认状态
status: 更新状态
Returns:
Model: 更新后的模型实例
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.get_model(model_id)
if config is not None:
model.config = config
if is_default is not None and is_default:
            # Unset any other default model of the same type
self.db.query(Model).filter(
Model.type == model.type,
Model.is_default == True,
Model.id != model_id
).update({"is_default": False})
model.is_default = True
if status is not None:
model.status = status
self.db.commit()
self.db.refresh(model)
logger.info("model_updated", model_id=model_id)
return model
def delete_model(self, model_id: int) -> bool:
"""
删除(软删除)模型。
Args:
model_id: 模型ID
Returns:
bool: 如果删除成功返回True
Raises:
ResourceNotFoundError: 如果模型未找到
"""
model = self.get_model(model_id)
model.status = "deleted"
self.db.commit()
logger.info("model_deleted", model_id=model_id)
return True
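
    # Note: soft-deleted models keep their rows but are excluded from list_models()
    # and from the default lookups in get_llm()/get_embedding(); get_model() by ID
    # will still return them.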
def get_llm(self, name: Optional[str] = None) -> BaseLLM:
"""
根据名称获取LLM实例或使用默认实例。
Args:
name: 模型名称None表示使用默认
Returns:
BaseLLM: LangChain LLM实例
Raises:
ResourceNotFoundError: 如果模型未找到
InvalidConfigError: 如果模型类型不是'llm'
"""
if name:
model = self.get_model_by_name(name)
else:
            # Get the default LLM
model = self.db.query(Model).filter(
Model.type == "llm",
Model.is_default == True,
Model.status == "active"
).first()
if not model:
raise ResourceNotFoundError("No default LLM model configured")
if model.type != "llm":
raise InvalidConfigError(f"Model '{model.name}' is not an LLM (type: {model.type})")
        # Resolve environment variables in the configuration
config = self._replace_env_vars(model.config)
        # Normalize the API key parameter name if needed
if "api_key" in config and "openai_api_key" not in config:
config["openai_api_key"] = config.pop("api_key")
        # If no API key is in the config, fall back to settings
if "openai_api_key" not in config:
if self.settings.openai_api_key:
config["openai_api_key"] = self.settings.openai_api_key
else:
raise InvalidConfigError("OPENAI_API_KEY not configured in settings or model config")
        # Add base_url from settings if configured and not already set
if self.settings.openai_base_url and "base_url" not in config:
config["base_url"] = self.settings.openai_base_url
logger.info("getting_llm", model_name=model.name, base_url=config.get("base_url"))
        # Currently only OpenAI is supported; this can be extended
return ChatOpenAI(**config)
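
    # Note: the resolved config dict is unpacked directly into ChatOpenAI, so its
    # keys are expected to be valid ChatOpenAI constructor arguments, e.g.
    # {"model": "gpt-4o-mini", "temperature": 0.7, "openai_api_key": "sk-..."}.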
def get_embedding(self, name: Optional[str] = None) -> Embeddings:
"""
根据名称获取Embedding实例或使用默认实例。
Args:
name: 模型名称None表示使用默认
Returns:
Embeddings: LangChain Embeddings实例
Raises:
ResourceNotFoundError: 如果模型未找到
InvalidConfigError: 如果模型类型不是'embedding'
"""
if name:
model = self.get_model_by_name(name)
else:
            # Get the default embedding model
model = self.db.query(Model).filter(
Model.type == "embedding",
Model.is_default == True,
Model.status == "active"
).first()
if not model:
raise ResourceNotFoundError("No default embedding model configured")
if model.type != "embedding":
raise InvalidConfigError(f"Model '{model.name}' is not an embedding model (type: {model.type})")
        # Resolve environment variables in the configuration
config = self._replace_env_vars(model.config)
        # Normalize the API key parameter name if needed
if "api_key" in config and "openai_api_key" not in config:
config["openai_api_key"] = config.pop("api_key")
        # If no API key is in the config, fall back to settings
if "openai_api_key" not in config:
if self.settings.openai_api_key:
config["openai_api_key"] = self.settings.openai_api_key
else:
raise InvalidConfigError("OPENAI_API_KEY not configured in settings or model config")
        # Remove parameters that are not valid for embeddings
config.pop("temperature", None)
        # Add base_url from settings if configured and not already set
if self.settings.openai_base_url and "base_url" not in config:
config["base_url"] = self.settings.openai_base_url
logger.info("getting_embedding", model_name=model.name, base_url=config.get("base_url"))
        # Currently only OpenAI is supported; this can be extended
return OpenAIEmbeddings(**config)
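

# Minimal usage sketch. Assumptions: `SessionLocal` is the project's SQLAlchemy
# session factory and `Settings()` can be constructed directly; both names are
# illustrative and may differ in this repo.
#
#     from app.config import Settings
#     from app.db.session import SessionLocal
#
#     db = SessionLocal()
#     manager = ModelManager(db_session=db, settings=Settings())
#     manager.create_model(
#         name="default-embedding",
#         type="embedding",
#         config={"model": "text-embedding-3-small", "api_key": "${OPENAI_API_KEY}"},
#         is_default=True,
#     )
#     embeddings = manager.get_embedding()  # returns an OpenAIEmbeddings instance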