163 lines
4.2 KiB
Python
163 lines
4.2 KiB
Python
"""
|
||
数据库会话管理和依赖注入。
|
||
"""
|
||
from contextlib import contextmanager
|
||
from typing import Generator
|
||
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
from sqlalchemy.pool import QueuePool
|
||
|
||
from app.config import Settings
|
||
|
||
|
||
class DatabaseManager:
|
||
"""
|
||
数据库连接和会话管理器。
|
||
"""
|
||
|
||
def __init__(self, settings: Settings):
|
||
self.settings = settings
|
||
self.engine = None
|
||
self.SessionLocal = None
|
||
self._initialize_engine()
|
||
|
||
def _initialize_engine(self) -> None:
|
||
"""
|
||
使用连接池初始化数据库引擎。
|
||
"""
|
||
self.engine = create_engine(
|
||
self.settings.database_url,
|
||
poolclass=QueuePool,
|
||
pool_size=5,
|
||
max_overflow=10,
|
||
pool_pre_ping=True, # 使用前验证连接
|
||
pool_recycle=300, # 5 分钟后回收连接(之前是 3600)
|
||
pool_reset_on_return="rollback", # 返回时重置连接
|
||
pool_timeout=30, # 等待连接最多 30 秒
|
||
connect_args={
|
||
"connect_timeout": 10,
|
||
"read_timeout": 30,
|
||
"write_timeout": 30,
|
||
"charset": "utf8mb4",
|
||
"autocommit": False,
|
||
},
|
||
echo=self.settings.debug,
|
||
isolation_level="READ COMMITTED",
|
||
)
|
||
self.SessionLocal = sessionmaker(
|
||
autocommit=False,
|
||
autoflush=False,
|
||
bind=self.engine,
|
||
)
|
||
|
||
def get_db(self) -> Generator[Session, None, None]:
|
||
"""
|
||
用于数据库会话的 FastAPI 依赖项。
|
||
|
||
Yields:
|
||
Session: SQLAlchemy 数据库会话
|
||
"""
|
||
db = self.SessionLocal()
|
||
try:
|
||
yield db
|
||
except Exception:
|
||
db.rollback()
|
||
raise
|
||
finally:
|
||
try:
|
||
db.close()
|
||
except Exception:
|
||
# 忽略关闭时的错误(连接可能已经关闭)
|
||
pass
|
||
|
||
@contextmanager
|
||
def session_scope(self) -> Generator[Session, None, None]:
|
||
"""
|
||
数据库会话的上下文管理器。
|
||
|
||
Yields:
|
||
Session: SQLAlchemy 数据库会话
|
||
|
||
Example:
|
||
with db_manager.session_scope() as session:
|
||
session.query(Model).all()
|
||
"""
|
||
session = self.SessionLocal()
|
||
try:
|
||
yield session
|
||
session.commit()
|
||
except Exception:
|
||
session.rollback()
|
||
raise
|
||
finally:
|
||
session.close()
|
||
|
||
def create_all_tables(self) -> None:
|
||
"""
|
||
创建所有数据库表(用于测试/开发)。
|
||
"""
|
||
from app.db.models import Base
|
||
|
||
Base.metadata.create_all(bind=self.engine)
|
||
|
||
def drop_all_tables(self) -> None:
|
||
"""
|
||
删除所有数据库表(用于测试)。
|
||
"""
|
||
from app.db.models import Base
|
||
|
||
Base.metadata.drop_all(bind=self.engine)
|
||
|
||
def close(self) -> None:
|
||
"""
|
||
关闭数据库引擎。
|
||
"""
|
||
if self.engine:
|
||
try:
|
||
self.engine.dispose()
|
||
except Exception:
|
||
# 忽略处理时的错误
|
||
pass
|
||
|
||
|
||
# 全局数据库管理器实例(在 main.py 中初始化)
|
||
db_manager: DatabaseManager = None
|
||
|
||
|
||
def get_db_manager() -> DatabaseManager:
|
||
"""
|
||
获取全局数据库管理器实例。
|
||
|
||
Returns:
|
||
DatabaseManager: 数据库管理器实例
|
||
"""
|
||
return db_manager
|
||
|
||
|
||
def init_db_manager(settings: Settings) -> DatabaseManager:
|
||
"""
|
||
初始化全局数据库管理器。
|
||
|
||
Args:
|
||
settings: 应用程序设置
|
||
|
||
Returns:
|
||
DatabaseManager: 初始化的数据库管理器
|
||
"""
|
||
global db_manager
|
||
db_manager = DatabaseManager(settings)
|
||
return db_manager
|
||
|
||
|
||
def get_db() -> Generator[Session, None, None]:
|
||
"""
|
||
用于 FastAPI 依赖注入的全局 get_db 函数。
|
||
|
||
Yields:
|
||
Session: SQLAlchemy 数据库会话
|
||
"""
|
||
if db_manager is None:
|
||
raise RuntimeError("Database manager not initialized. Call init_db_manager() first.")
|
||
yield from db_manager.get_db()
|