219 lines
6.1 KiB
Python
219 lines
6.1 KiB
Python
"""
|
|
环境验证脚本 - 检查所有依赖和配置是否正确
|
|
"""
|
|
import sys
|
|
import os
|
|
|
|
def check_python_version():
|
|
"""检查 Python 版本"""
|
|
print("=" * 50)
|
|
print("检查 Python 版本...")
|
|
print("=" * 50)
|
|
version = sys.version_info
|
|
print(f"Python 版本: {version.major}.{version.minor}.{version.micro}")
|
|
|
|
if version.major == 3 and version.minor >= 11:
|
|
print("✓ Python 版本符合要求 (3.11+)")
|
|
return True
|
|
else:
|
|
print("✗ Python 版本过低,需要 3.11+")
|
|
return False
|
|
|
|
def check_dependencies():
|
|
"""检查依赖包"""
|
|
print("\n" + "=" * 50)
|
|
print("检查依赖包...")
|
|
print("=" * 50)
|
|
|
|
packages = [
|
|
("fastapi", "FastAPI"),
|
|
("sqlalchemy", "SQLAlchemy"),
|
|
("langchain", "LangChain"),
|
|
("faiss", "FAISS"),
|
|
("pydantic", "Pydantic"),
|
|
("structlog", "Structlog"),
|
|
("alembic", "Alembic"),
|
|
("pymysql", "PyMySQL"),
|
|
]
|
|
|
|
all_ok = True
|
|
for module_name, display_name in packages:
|
|
try:
|
|
module = __import__(module_name)
|
|
version = getattr(module, "__version__", "unknown")
|
|
print(f"✓ {display_name}: {version}")
|
|
except ImportError:
|
|
print(f"✗ {display_name}: 未安装")
|
|
all_ok = False
|
|
|
|
return all_ok
|
|
|
|
def check_env_file():
|
|
"""检查环境变量文件"""
|
|
print("\n" + "=" * 50)
|
|
print("检查配置文件...")
|
|
print("=" * 50)
|
|
|
|
if os.path.exists(".env"):
|
|
print("✓ .env 文件存在")
|
|
|
|
# 尝试加载配置
|
|
try:
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
# 检查关键配置
|
|
required_vars = [
|
|
"DATABASE_URL",
|
|
"OPENAI_API_KEY",
|
|
]
|
|
|
|
all_ok = True
|
|
for var in required_vars:
|
|
value = os.getenv(var)
|
|
if value:
|
|
masked_value = value[:10] + "..." if len(value) > 10 else value
|
|
print(f"✓ {var}: {masked_value}")
|
|
else:
|
|
print(f"✗ {var}: 未配置")
|
|
all_ok = False
|
|
|
|
return all_ok
|
|
except Exception as e:
|
|
print(f"✗ 加载 .env 失败: {e}")
|
|
return False
|
|
else:
|
|
print("✗ .env 文件不存在")
|
|
print(" 请从 .env.example 复制并配置")
|
|
return False
|
|
|
|
def check_database_connection():
|
|
"""检查数据库连接"""
|
|
print("\n" + "=" * 50)
|
|
print("检查数据库连接...")
|
|
print("=" * 50)
|
|
|
|
try:
|
|
sys.path.insert(0, "src")
|
|
from app.config import get_settings
|
|
from sqlalchemy import create_engine
|
|
|
|
settings = get_settings()
|
|
engine = create_engine(settings.database_url)
|
|
|
|
# 尝试连接
|
|
with engine.connect() as conn:
|
|
result = conn.execute("SELECT 1").fetchone()
|
|
if result:
|
|
print(f"✓ 数据库连接成功")
|
|
print(f" URL: {settings.database_url.split('@')[1] if '@' in settings.database_url else 'localhost'}")
|
|
return True
|
|
except Exception as e:
|
|
print(f"✗ 数据库连接失败: {e}")
|
|
print(" 请检查:")
|
|
print(" 1. MySQL 服务是否运行")
|
|
print(" 2. 数据库是否已创建")
|
|
print(" 3. DATABASE_URL 配置是否正确")
|
|
return False
|
|
|
|
def check_project_structure():
|
|
"""检查项目结构"""
|
|
print("\n" + "=" * 50)
|
|
print("检查项目结构...")
|
|
print("=" * 50)
|
|
|
|
required_dirs = [
|
|
"src/app",
|
|
"src/app/db",
|
|
"src/app/services",
|
|
"src/app/api/v1",
|
|
"src/app/utils",
|
|
"data/faiss",
|
|
"docs",
|
|
]
|
|
|
|
required_files = [
|
|
"src/app/config.py",
|
|
"src/app/db/models.py",
|
|
"src/app/db/session.py",
|
|
"requirements.txt",
|
|
"environment.yml",
|
|
]
|
|
|
|
all_ok = True
|
|
|
|
for dir_path in required_dirs:
|
|
if os.path.isdir(dir_path):
|
|
print(f"✓ {dir_path}/")
|
|
else:
|
|
print(f"✗ {dir_path}/ (不存在)")
|
|
all_ok = False
|
|
|
|
for file_path in required_files:
|
|
if os.path.isfile(file_path):
|
|
print(f"✓ {file_path}")
|
|
else:
|
|
print(f"✗ {file_path} (不存在)")
|
|
all_ok = False
|
|
|
|
return all_ok
|
|
|
|
def main():
|
|
"""主函数"""
|
|
print("\n")
|
|
print("╔" + "=" * 48 + "╗")
|
|
print("║ LangChain Learning Kit - 环境验证工具 ║")
|
|
print("╚" + "=" * 48 + "╝")
|
|
print()
|
|
|
|
checks = [
|
|
("Python 版本", check_python_version),
|
|
("依赖包", check_dependencies),
|
|
("配置文件", check_env_file),
|
|
("项目结构", check_project_structure),
|
|
("数据库连接", check_database_connection),
|
|
]
|
|
|
|
results = {}
|
|
for name, check_func in checks:
|
|
try:
|
|
results[name] = check_func()
|
|
except Exception as e:
|
|
print(f"\n✗ {name} 检查失败: {e}")
|
|
results[name] = False
|
|
|
|
# 总结
|
|
print("\n" + "=" * 50)
|
|
print("检查总结")
|
|
print("=" * 50)
|
|
|
|
passed = sum(1 for v in results.values() if v)
|
|
total = len(results)
|
|
|
|
for name, status in results.items():
|
|
status_str = "✓ 通过" if status else "✗ 失败"
|
|
print(f"{name}: {status_str}")
|
|
|
|
print("\n" + "-" * 50)
|
|
print(f"通过: {passed}/{total}")
|
|
|
|
if passed == total:
|
|
print("\n🎉 所有检查通过!环境配置完成。")
|
|
print("\n下一步操作:")
|
|
print("1. 初始化数据库: cd src && alembic upgrade head")
|
|
print("2. 启动服务: cd src && uvicorn app.main:app --reload")
|
|
else:
|
|
print("\n⚠️ 部分检查未通过,请根据上述提示修复问题。")
|
|
print("\n常见问题解决:")
|
|
print("- 依赖未安装: pip install -r requirements.txt")
|
|
print("- .env 未配置: cp .env.example .env (然后编辑)")
|
|
print("- 数据库未启动: 启动 MySQL 服务")
|
|
print("- 数据库未创建: CREATE DATABASE langchain_learning;")
|
|
|
|
print("\n" + "=" * 50)
|
|
return 0 if passed == total else 1
|
|
|
|
if __name__ == "__main__":
|
|
exit_code = main()
|
|
sys.exit(exit_code)
|