163 lines
4.2 KiB
Python
163 lines
4.2 KiB
Python
|
|
"""
|
|||
|
|
数据库会话管理和依赖注入。
|
|||
|
|
"""
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
from typing import Generator
|
|||
|
|
|
|||
|
|
from sqlalchemy import create_engine
|
|||
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|||
|
|
from sqlalchemy.pool import QueuePool
|
|||
|
|
|
|||
|
|
from app.config import Settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DatabaseManager:
|
|||
|
|
"""
|
|||
|
|
数据库连接和会话管理器。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, settings: Settings):
|
|||
|
|
self.settings = settings
|
|||
|
|
self.engine = None
|
|||
|
|
self.SessionLocal = None
|
|||
|
|
self._initialize_engine()
|
|||
|
|
|
|||
|
|
def _initialize_engine(self) -> None:
|
|||
|
|
"""
|
|||
|
|
使用连接池初始化数据库引擎。
|
|||
|
|
"""
|
|||
|
|
self.engine = create_engine(
|
|||
|
|
self.settings.database_url,
|
|||
|
|
poolclass=QueuePool,
|
|||
|
|
pool_size=5,
|
|||
|
|
max_overflow=10,
|
|||
|
|
pool_pre_ping=True, # 使用前验证连接
|
|||
|
|
pool_recycle=300, # 5 分钟后回收连接(之前是 3600)
|
|||
|
|
pool_reset_on_return="rollback", # 返回时重置连接
|
|||
|
|
pool_timeout=30, # 等待连接最多 30 秒
|
|||
|
|
connect_args={
|
|||
|
|
"connect_timeout": 10,
|
|||
|
|
"read_timeout": 30,
|
|||
|
|
"write_timeout": 30,
|
|||
|
|
"charset": "utf8mb4",
|
|||
|
|
"autocommit": False,
|
|||
|
|
},
|
|||
|
|
echo=self.settings.debug,
|
|||
|
|
isolation_level="READ COMMITTED",
|
|||
|
|
)
|
|||
|
|
self.SessionLocal = sessionmaker(
|
|||
|
|
autocommit=False,
|
|||
|
|
autoflush=False,
|
|||
|
|
bind=self.engine,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def get_db(self) -> Generator[Session, None, None]:
|
|||
|
|
"""
|
|||
|
|
用于数据库会话的 FastAPI 依赖项。
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
Session: SQLAlchemy 数据库会话
|
|||
|
|
"""
|
|||
|
|
db = self.SessionLocal()
|
|||
|
|
try:
|
|||
|
|
yield db
|
|||
|
|
except Exception:
|
|||
|
|
db.rollback()
|
|||
|
|
raise
|
|||
|
|
finally:
|
|||
|
|
try:
|
|||
|
|
db.close()
|
|||
|
|
except Exception:
|
|||
|
|
# 忽略关闭时的错误(连接可能已经关闭)
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
@contextmanager
|
|||
|
|
def session_scope(self) -> Generator[Session, None, None]:
|
|||
|
|
"""
|
|||
|
|
数据库会话的上下文管理器。
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
Session: SQLAlchemy 数据库会话
|
|||
|
|
|
|||
|
|
Example:
|
|||
|
|
with db_manager.session_scope() as session:
|
|||
|
|
session.query(Model).all()
|
|||
|
|
"""
|
|||
|
|
session = self.SessionLocal()
|
|||
|
|
try:
|
|||
|
|
yield session
|
|||
|
|
session.commit()
|
|||
|
|
except Exception:
|
|||
|
|
session.rollback()
|
|||
|
|
raise
|
|||
|
|
finally:
|
|||
|
|
session.close()
|
|||
|
|
|
|||
|
|
def create_all_tables(self) -> None:
|
|||
|
|
"""
|
|||
|
|
创建所有数据库表(用于测试/开发)。
|
|||
|
|
"""
|
|||
|
|
from app.db.models import Base
|
|||
|
|
|
|||
|
|
Base.metadata.create_all(bind=self.engine)
|
|||
|
|
|
|||
|
|
def drop_all_tables(self) -> None:
|
|||
|
|
"""
|
|||
|
|
删除所有数据库表(用于测试)。
|
|||
|
|
"""
|
|||
|
|
from app.db.models import Base
|
|||
|
|
|
|||
|
|
Base.metadata.drop_all(bind=self.engine)
|
|||
|
|
|
|||
|
|
def close(self) -> None:
|
|||
|
|
"""
|
|||
|
|
关闭数据库引擎。
|
|||
|
|
"""
|
|||
|
|
if self.engine:
|
|||
|
|
try:
|
|||
|
|
self.engine.dispose()
|
|||
|
|
except Exception:
|
|||
|
|
# 忽略处理时的错误
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局数据库管理器实例(在 main.py 中初始化)
|
|||
|
|
db_manager: DatabaseManager = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_db_manager() -> DatabaseManager:
|
|||
|
|
"""
|
|||
|
|
获取全局数据库管理器实例。
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
DatabaseManager: 数据库管理器实例
|
|||
|
|
"""
|
|||
|
|
return db_manager
|
|||
|
|
|
|||
|
|
|
|||
|
|
def init_db_manager(settings: Settings) -> DatabaseManager:
|
|||
|
|
"""
|
|||
|
|
初始化全局数据库管理器。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
settings: 应用程序设置
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
DatabaseManager: 初始化的数据库管理器
|
|||
|
|
"""
|
|||
|
|
global db_manager
|
|||
|
|
db_manager = DatabaseManager(settings)
|
|||
|
|
return db_manager
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_db() -> Generator[Session, None, None]:
|
|||
|
|
"""
|
|||
|
|
用于 FastAPI 依赖注入的全局 get_db 函数。
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
Session: SQLAlchemy 数据库会话
|
|||
|
|
"""
|
|||
|
|
if db_manager is None:
|
|||
|
|
raise RuntimeError("Database manager not initialized. Call init_db_manager() first.")
|
|||
|
|
yield from db_manager.get_db()
|