langchain-learning-kit/app/db/session.py

163 lines
4.2 KiB
Python
Raw Normal View History

"""
数据库会话管理和依赖注入
"""
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool
from app.config import Settings
class DatabaseManager:
"""
数据库连接和会话管理器
"""
def __init__(self, settings: Settings):
self.settings = settings
self.engine = None
self.SessionLocal = None
self._initialize_engine()
def _initialize_engine(self) -> None:
"""
使用连接池初始化数据库引擎
"""
self.engine = create_engine(
self.settings.database_url,
poolclass=QueuePool,
pool_size=5,
max_overflow=10,
pool_pre_ping=True, # 使用前验证连接
pool_recycle=300, # 5 分钟后回收连接(之前是 3600
pool_reset_on_return="rollback", # 返回时重置连接
pool_timeout=30, # 等待连接最多 30 秒
connect_args={
"connect_timeout": 10,
"read_timeout": 30,
"write_timeout": 30,
"charset": "utf8mb4",
"autocommit": False,
},
echo=self.settings.debug,
isolation_level="READ COMMITTED",
)
self.SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=self.engine,
)
def get_db(self) -> Generator[Session, None, None]:
"""
用于数据库会话的 FastAPI 依赖项
Yields:
Session: SQLAlchemy 数据库会话
"""
db = self.SessionLocal()
try:
yield db
except Exception:
db.rollback()
raise
finally:
try:
db.close()
except Exception:
# 忽略关闭时的错误(连接可能已经关闭)
pass
@contextmanager
def session_scope(self) -> Generator[Session, None, None]:
"""
数据库会话的上下文管理器
Yields:
Session: SQLAlchemy 数据库会话
Example:
with db_manager.session_scope() as session:
session.query(Model).all()
"""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def create_all_tables(self) -> None:
"""
创建所有数据库表用于测试/开发
"""
from app.db.models import Base
Base.metadata.create_all(bind=self.engine)
def drop_all_tables(self) -> None:
"""
删除所有数据库表用于测试
"""
from app.db.models import Base
Base.metadata.drop_all(bind=self.engine)
def close(self) -> None:
"""
关闭数据库引擎
"""
if self.engine:
try:
self.engine.dispose()
except Exception:
# 忽略处理时的错误
pass
# 全局数据库管理器实例(在 main.py 中初始化)
db_manager: DatabaseManager = None
def get_db_manager() -> DatabaseManager:
"""
获取全局数据库管理器实例
Returns:
DatabaseManager: 数据库管理器实例
"""
return db_manager
def init_db_manager(settings: Settings) -> DatabaseManager:
"""
初始化全局数据库管理器
Args:
settings: 应用程序设置
Returns:
DatabaseManager: 初始化的数据库管理器
"""
global db_manager
db_manager = DatabaseManager(settings)
return db_manager
def get_db() -> Generator[Session, None, None]:
"""
用于 FastAPI 依赖注入的全局 get_db 函数
Yields:
Session: SQLAlchemy 数据库会话
"""
if db_manager is None:
raise RuntimeError("Database manager not initialized. Call init_db_manager() first.")
yield from db_manager.get_db()