"""
Asynchronous job manager for background task processing.
"""

import asyncio
import uuid
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, Optional

from app.utils.logger import get_logger

# Module-level logger; calls in this module pass an event name plus keyword
# fields such as job_id.
logger = get_logger(__name__)


class JobStatus(str, Enum):
    """Lifecycle states of a job.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value (for example ``JobStatus.PENDING == "pending"``).
    """

    # Initial state of a newly created job.
    PENDING = "pending"
    # The job function is currently executing.
    RUNNING = "running"
    # The job function returned and its result was stored.
    COMPLETED = "completed"
    # The job function raised, or the job was cancelled while pending.
    FAILED = "failed"


class Job:
    """A unit of background work together with its execution state.

    Attributes:
        id: Identifier of the job.
        job_fn: Callable to execute (sync or async).
        args: Positional arguments passed to ``job_fn``.
        kwargs: Keyword arguments passed to ``job_fn``.
        status: Current ``JobStatus``; starts as ``PENDING``.
        result: Value returned by ``job_fn`` once the job has completed.
        error: Exception recorded when the job failed, otherwise ``None``.
        created_at: Naive UTC timestamp taken when the job was created.
        started_at: Naive UTC timestamp of execution start, if started.
        completed_at: Naive UTC timestamp of completion/failure, if finished.
    """

    def __init__(self, job_id: str, job_fn: Callable, *args, **kwargs):
        """Record the callable and its arguments; the job starts as PENDING.

        Args:
            job_id: Identifier to store on the job.
            job_fn: Callable to run later.
            *args: Positional arguments for ``job_fn``.
            **kwargs: Keyword arguments for ``job_fn``.
        """
        self.id = job_id
        self.job_fn = job_fn
        self.args = args
        self.kwargs = kwargs
        self.status = JobStatus.PENDING
        self.result: Any = None
        self.error: Optional[BaseException] = None
        self.created_at = datetime.utcnow()
        self.started_at: Optional[datetime] = None
        self.completed_at: Optional[datetime] = None

    def __repr__(self) -> str:
        """Return a short debugging representation (id and status)."""
        return f"{type(self).__name__}(id={self.id!r}, status={self.status.value!r})"

    def to_dict(self) -> Dict[str, Any]:
        """Convert the job's metadata to a plain dictionary.

        Timestamps are rendered with ``isoformat()`` and the error with
        ``str()``. The job's result is not part of the returned dictionary.

        Returns:
            Dict[str, Any]: Keys ``id``, ``status``, ``created_at``,
            ``started_at``, ``completed_at`` and ``error``.
        """

        def _iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment is not None else None

        # Compare against None explicitly: an exception class may define
        # __bool__/__len__ and be falsy, which must still be reported.
        error_text = str(self.error) if self.error is not None else None

        return {
            "id": self.id,
            "status": self.status.value,
            "created_at": self.created_at.isoformat(),
            "started_at": _iso(self.started_at),
            "completed_at": _iso(self.completed_at),
            "error": error_text,
        }


class AsyncJobManager:
    """Manage asynchronous background jobs using asyncio.

    Note: this is a simplified implementation for educational purposes.
    In production, consider Celery + Redis or a similar solution.

    Jobs are kept in an in-memory dictionary keyed by job ID, and every
    submitted job runs in its own asyncio task on the running event loop.
    """

    # Seconds between status checks while wait_for_job() is polling.
    _POLL_INTERVAL = 0.1

    def __init__(self):
        """Initialize the job manager with an empty job registry."""
        self.jobs: Dict[str, Job] = {}
        self._lock = asyncio.Lock()
        # The event loop only keeps weak references to tasks, so strong
        # references are held here until each task finishes; otherwise a
        # job could be garbage-collected in the middle of its execution.
        self._tasks: set = set()

    def _lookup(self, job_id: str) -> Job:
        """Return the job registered under ``job_id``.

        Callers are expected to hold ``self._lock``.

        Raises:
            KeyError: If no job with that ID is registered.
        """
        try:
            return self.jobs[job_id]
        except KeyError:
            raise KeyError(f"Job not found: {job_id}") from None

    async def submit_job(
        self,
        job_fn: Callable,
        *args,
        **kwargs
    ) -> str:
        """Submit a job for asynchronous execution.

        Args:
            job_fn: Callable to execute (sync or async).
            *args: Positional arguments for the callable.
            **kwargs: Keyword arguments for the callable.

        Returns:
            str: The ID of the newly registered job.
        """
        job_id = str(uuid.uuid4())
        job = Job(job_id, job_fn, *args, **kwargs)

        async with self._lock:
            self.jobs[job_id] = job

        logger.info("job_submitted", job_id=job_id)

        # Start execution in the background and keep the task alive until
        # it is done (see the note in __init__).
        task = asyncio.create_task(self._execute_job(job))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

        return job_id

    async def _execute_job(self, job: Job) -> None:
        """Run a job in the background and record its outcome on the job.

        Args:
            job: The job instance to execute.
        """
        # cancel_job() may have marked the job FAILED before this task got
        # its first chance to run; honour that instead of running anyway.
        if job.status is not JobStatus.PENDING:
            return

        try:
            job.status = JobStatus.RUNNING
            job.started_at = datetime.utcnow()

            logger.info("job_started", job_id=job.id)

            if asyncio.iscoroutinefunction(job.job_fn):
                result = await job.job_fn(*job.args, **job.kwargs)
            else:
                # Run synchronous callables in the default executor so they
                # do not block the event loop.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    lambda: job.job_fn(*job.args, **job.kwargs)
                )
        except asyncio.CancelledError:
            # CancelledError is not an Exception subclass; without this
            # branch a cancelled task would leave the job RUNNING forever.
            job.error = RuntimeError("Job task was cancelled")
            job.status = JobStatus.FAILED
            job.completed_at = datetime.utcnow()
            logger.info("job_cancelled", job_id=job.id)
            raise
        except Exception as e:
            job.error = e
            job.status = JobStatus.FAILED
            job.completed_at = datetime.utcnow()

            logger.error("job_failed", job_id=job.id, error=str(e))
        else:
            job.result = result
            job.status = JobStatus.COMPLETED
            job.completed_at = datetime.utcnow()

            logger.info("job_completed", job_id=job.id)

    async def get_job_status(self, job_id: str) -> Dict[str, Any]:
        """Get a job's status and metadata.

        Args:
            job_id: ID of the job.

        Returns:
            Dict[str, Any]: The job's ``to_dict()`` representation.

        Raises:
            KeyError: If the job is not found.
        """
        async with self._lock:
            return self._lookup(job_id).to_dict()

    async def get_job_result(self, job_id: str) -> Any:
        """Get the result of a completed job.

        Args:
            job_id: ID of the job.

        Returns:
            Any: The value returned by the job's callable.

        Raises:
            KeyError: If the job is not found.
            RuntimeError: If the job failed or has not completed yet.
        """
        async with self._lock:
            job = self._lookup(job_id)

            if job.status == JobStatus.FAILED:
                raise RuntimeError(f"Job failed: {job.error}")

            if job.status != JobStatus.COMPLETED:
                # Use .value: formatting a (str, Enum) member directly in an
                # f-string yields different text on different Python versions.
                raise RuntimeError(
                    f"Job not completed yet. Status: {job.status.value}"
                )

            return job.result

    async def wait_for_job(
        self,
        job_id: str,
        timeout: Optional[float] = None
    ) -> Any:
        """Wait for a job to finish and return its result.

        The job's status is polled until it is COMPLETED or FAILED.

        Args:
            job_id: ID of the job.
            timeout: Optional timeout in seconds; ``None`` waits indefinitely.

        Returns:
            Any: The value returned by the job's callable.

        Raises:
            KeyError: If the job is not found.
            asyncio.TimeoutError: If the timeout elapses first.
            RuntimeError: If the job failed.
        """
        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout

        while True:
            async with self._lock:
                job = self._lookup(job_id)

                if job.status == JobStatus.COMPLETED:
                    return job.result

                if job.status == JobStatus.FAILED:
                    raise RuntimeError(f"Job failed: {job.error}")

            delay = self._POLL_INTERVAL
            if deadline is not None:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    raise asyncio.TimeoutError(f"Job timeout after {timeout}s")
                # Never sleep past the deadline.
                delay = min(delay, remaining)

            await asyncio.sleep(delay)

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a pending job (best effort).

        Only jobs that have not started running can be cancelled; such a job
        is marked FAILED and will be skipped when its task runs.

        Args:
            job_id: ID of the job.

        Returns:
            bool: True if the job was cancelled, False if it is already
            running or finished.

        Raises:
            KeyError: If the job is not found.
        """
        async with self._lock:
            job = self._lookup(job_id)

            if job.status != JobStatus.PENDING:
                return False

            job.status = JobStatus.FAILED
            job.error = Exception("Job cancelled by user")
            job.completed_at = datetime.utcnow()
            logger.info("job_cancelled", job_id=job_id)
            return True

    async def list_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 100
    ) -> list[Dict[str, Any]]:
        """List jobs, newest first, optionally filtered by status.

        Args:
            status: Optional status filter; ``None`` lists every job.
            limit: Maximum number of jobs to return; values below zero are
                treated as zero.

        Returns:
            list[Dict[str, Any]]: ``to_dict()`` snapshots of matching jobs.
        """
        async with self._lock:
            jobs = list(self.jobs.values())

        if status is not None:
            jobs = [j for j in jobs if j.status == status]

        # Sort by creation time, newest first.
        jobs.sort(key=lambda j: j.created_at, reverse=True)

        return [j.to_dict() for j in jobs[:max(limit, 0)]]

    async def cleanup_old_jobs(self, max_age_hours: int = 24) -> int:
        """Remove completed/failed jobs that finished long enough ago.

        Args:
            max_age_hours: Jobs whose ``completed_at`` is older than this
                many hours are removed.

        Returns:
            int: Number of jobs removed.
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
        finished = (JobStatus.COMPLETED, JobStatus.FAILED)

        async with self._lock:
            # Collect first, delete afterwards: a dict must not change size
            # while it is being iterated.
            stale_ids = [
                job_id
                for job_id, job in self.jobs.items()
                if job.status in finished
                and job.completed_at is not None
                and job.completed_at < cutoff_time
            ]
            for job_id in stale_ids:
                del self.jobs[job_id]

        cleaned = len(stale_ids)
        if cleaned > 0:
            logger.info("jobs_cleaned_up", count=cleaned)

        return cleaned


# Global job manager instance, created on first call to get_job_manager().
_job_manager: Optional[AsyncJobManager] = None


def get_job_manager() -> AsyncJobManager:
    """Get the global job manager, creating it on first use.

    Returns:
        AsyncJobManager: The module-level manager instance.
    """
    global _job_manager
    if _job_manager is not None:
        return _job_manager

    _job_manager = AsyncJobManager()
    return _job_manager