langchain-learning-kit/app/db/session.py

163 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库会话管理和依赖注入。
"""
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool
from app.config import Settings
class DatabaseManager:
"""
数据库连接和会话管理器。
"""
def __init__(self, settings: Settings):
self.settings = settings
self.engine = None
self.SessionLocal = None
self._initialize_engine()
def _initialize_engine(self) -> None:
"""
使用连接池初始化数据库引擎。
"""
self.engine = create_engine(
self.settings.database_url,
poolclass=QueuePool,
pool_size=5,
max_overflow=10,
pool_pre_ping=True, # 使用前验证连接
pool_recycle=300, # 5 分钟后回收连接(之前是 3600
pool_reset_on_return="rollback", # 返回时重置连接
pool_timeout=30, # 等待连接最多 30 秒
connect_args={
"connect_timeout": 10,
"read_timeout": 30,
"write_timeout": 30,
"charset": "utf8mb4",
"autocommit": False,
},
echo=self.settings.debug,
isolation_level="READ COMMITTED",
)
self.SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=self.engine,
)
def get_db(self) -> Generator[Session, None, None]:
"""
用于数据库会话的 FastAPI 依赖项。
Yields:
Session: SQLAlchemy 数据库会话
"""
db = self.SessionLocal()
try:
yield db
except Exception:
db.rollback()
raise
finally:
try:
db.close()
except Exception:
# 忽略关闭时的错误(连接可能已经关闭)
pass
@contextmanager
def session_scope(self) -> Generator[Session, None, None]:
"""
数据库会话的上下文管理器。
Yields:
Session: SQLAlchemy 数据库会话
Example:
with db_manager.session_scope() as session:
session.query(Model).all()
"""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def create_all_tables(self) -> None:
"""
创建所有数据库表(用于测试/开发)。
"""
from app.db.models import Base
Base.metadata.create_all(bind=self.engine)
def drop_all_tables(self) -> None:
"""
删除所有数据库表(用于测试)。
"""
from app.db.models import Base
Base.metadata.drop_all(bind=self.engine)
def close(self) -> None:
"""
关闭数据库引擎。
"""
if self.engine:
try:
self.engine.dispose()
except Exception:
# 忽略处理时的错误
pass
# 全局数据库管理器实例(在 main.py 中初始化)
db_manager: DatabaseManager = None
def get_db_manager() -> DatabaseManager:
"""
获取全局数据库管理器实例。
Returns:
DatabaseManager: 数据库管理器实例
"""
return db_manager
def init_db_manager(settings: Settings) -> DatabaseManager:
"""
初始化全局数据库管理器。
Args:
settings: 应用程序设置
Returns:
DatabaseManager: 初始化的数据库管理器
"""
global db_manager
db_manager = DatabaseManager(settings)
return db_manager
def get_db() -> Generator[Session, None, None]:
"""
用于 FastAPI 依赖注入的全局 get_db 函数。
Yields:
Session: SQLAlchemy 数据库会话
"""
if db_manager is None:
raise RuntimeError("Database manager not initialized. Call init_db_manager() first.")
yield from db_manager.get_db()