2025-10-02 17:22:19 +08:00
|
|
|
|
"""
|
|
|
|
|
|
用于 LangChain 代理的计算器工具。
|
|
|
|
|
|
"""
|
|
|
|
|
|
from typing import Type
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
2025-10-02 18:10:53 +08:00
|
|
|
|
from langchain_core.tools import BaseTool
|
2025-10-02 17:22:19 +08:00
|
|
|
|
|
|
|
|
|
|
from app.utils.logger import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger. Calls in this module pass an event name plus structured
# keyword fields, e.g. logger.info("calculator_tool_called", expression=...).
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CalculatorInput(BaseModel):
    """计算器的输入模式。"""

    # Input schema for the calculator tool.
    # NOTE: the docstring above is deliberately left untranslated: pydantic
    # uses a model's docstring as its JSON-schema description, so it is
    # runtime-visible text rather than a pure comment.

    # The one argument the agent must supply. ``...`` spells out "required,
    # no default", which is what the bare ``Field(description=...)`` form
    # already meant.
    expression: str = Field(
        ...,
        description="要计算的数学表达式(例如,'2 + 2'、'10 * 5'、'sqrt(16)')",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Ceiling on the size (in bits) of an exact integer power. Python integers are
# unbounded, so without this an input such as "9**9**9" would pin a CPU core
# and exhaust memory.
_MAX_POWER_RESULT_BITS = 100_000

# Ceiling on exponent-bits * modulus-bits for three-argument pow(); the cost of
# modular exponentiation grows with the size of both operands.
_MAX_MODULAR_POWER_WORK = 10_000_000


class _UnsupportedExpressionError(ValueError):
    """Raised when an expression uses syntax the calculator does not allow."""


def _checked_pow(base, exp, mod=None):
    """Drop-in replacement for the builtin ``pow`` that refuses oversized integer work.

    Parameter names match the builtin, so ``pow(base=2, exp=3)`` keeps working.

    Raises:
        OverflowError: if an exact integer result, or the operands of a
            modular power, exceed the module-level limits.
    """
    both_ints = isinstance(base, int) and isinstance(exp, int)
    if mod is None:
        # (bit_length - 1) * exp is a lower bound on the result's bit length,
        # so anything rejected here really is too large.
        if (
            both_ints
            and exp > 0
            and abs(base) > 1
            and (abs(base).bit_length() - 1) * exp > _MAX_POWER_RESULT_BITS
        ):
            raise OverflowError("integer power result is too large")
        return pow(base, exp)
    if (
        both_ints
        and isinstance(mod, int)
        and abs(exp).bit_length() * abs(mod).bit_length() > _MAX_MODULAR_POWER_WORK
    ):
        raise OverflowError("modular power operands are too large")
    return pow(base, exp, mod)


def _evaluate_expression(expression, names):
    """Evaluate an arithmetic expression by walking its AST.

    Accepted: numeric literals, names present in ``names``, the operators
    ``+ - * / // % **`` (binary) and ``+ -`` (unary) on numbers, comparison
    operators, list/tuple literals, and calls whose callee is a name from
    ``names``. Everything else (attribute access, subscripts, lambdas,
    comprehensions, string literals, ...) is rejected, so the escape routes
    that work against ``eval(..., {"__builtins__": {}})`` such as
    ``().__class__`` are unreachable.

    Args:
        expression: source text of a single Python expression.
        names: mapping of the only identifiers the expression may reference.

    Returns:
        The computed value.

    Raises:
        SyntaxError: if ``expression`` is not a valid Python expression.
        NameError: if it references an identifier not in ``names``.
        _UnsupportedExpressionError: if it uses syntax outside the whitelist.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: _checked_pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    compare_ops = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
    }
    # bool is a subclass of int, so True/False are accepted as numbers too.
    number_types = (int, float, complex)

    def require_number(value):
        # Arithmetic only on numbers: blocks e.g. [0] * 10**9 memory blowups.
        if not isinstance(value, number_types):
            raise _UnsupportedExpressionError(
                f"unsupported operand type: {type(value).__name__}"
            )
        return value

    def visit(node):
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant):
            if isinstance(node.value, number_types):
                return node.value
            raise _UnsupportedExpressionError(f"unsupported literal: {node.value!r}")
        if isinstance(node, ast.Name):
            if node.id in names:
                return names[node.id]
            # Same message eval() produced for unknown names.
            raise NameError(f"name '{node.id}' is not defined")
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = require_number(visit(node.left))
            right = require_number(visit(node.right))
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](require_number(visit(node.operand)))
        if isinstance(node, ast.Compare) and all(
            type(op) in compare_ops for op in node.ops
        ):
            # Chained comparisons (a < b < c) short-circuit like Python's.
            left = visit(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = visit(comparator)
                if not compare_ops[type(op)](left, right):
                    return False
                left = right
            return True
        if isinstance(node, (ast.List, ast.Tuple)):
            items = [visit(element) for element in node.elts]
            return items if isinstance(node, ast.List) else tuple(items)
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            func = visit(node.func)
            args = [visit(arg) for arg in node.args]
            kwargs = {}
            for keyword in node.keywords:
                if keyword.arg is None:
                    raise _UnsupportedExpressionError("unsupported syntax: **kwargs")
                kwargs[keyword.arg] = visit(keyword.value)
            return func(*args, **kwargs)
        raise _UnsupportedExpressionError(
            f"unsupported syntax: {type(node).__name__}"
        )

    return visit(ast.parse(expression, mode="eval"))


class CalculatorTool(BaseTool):
    """Tool that performs mathematical calculations for an agent."""

    name: str = "calculator"
    description: str = (
        "Useful for performing mathematical calculations. "
        "Input should be a valid mathematical expression as a string. "
        "Supports basic arithmetic (+, -, *, /), exponents (**), and common math functions "
        "(sqrt, sin, cos, tan, log, exp, etc.)."
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(self, expression: str) -> str:
        """Evaluate a mathematical expression and return the formatted result.

        The expression is parsed with ``ast`` and evaluated by a whitelist
        walker (``_evaluate_expression``) rather than ``eval``. Restricting
        ``__builtins__`` does not make ``eval`` safe: attribute chains such as
        ``().__class__.__base__.__subclasses__()`` still reach arbitrary
        classes, and the input here is produced by an LLM (potentially from
        prompt-injected content). Integer powers are size-limited via
        ``_checked_pow``.

        Args:
            expression: The mathematical expression to evaluate.

        Returns:
            str: ``"<expression> = <result>"`` for int/float/bool results
            (floats rounded to 10 decimal places), ``str(result)`` for other
            results (e.g. complex numbers or list literals), or an error
            message starting with ``"Error"``. Errors are returned as strings,
            never raised.
        """
        logger.info("calculator_tool_called", expression=expression)

        try:
            # Drop surrounding whitespace before parsing.
            expression = expression.strip()

            # The only functions/constants the expression may reference.
            import math

            safe_dict = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "sum": sum,
                # Guarded so huge integer powers fail fast instead of hanging.
                "pow": _checked_pow,
                # math module functions and constants
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "asin": math.asin,
                "acos": math.acos,
                "atan": math.atan,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "ceil": math.ceil,
                "floor": math.floor,
                "pi": math.pi,
                "e": math.e,
            }

            result = _evaluate_expression(expression, safe_dict)

            # Format the result.
            if isinstance(result, (int, float)):
                # Round floats to hide binary representation noise
                # (e.g. 0.1 + 0.2 -> 0.3).
                if isinstance(result, float):
                    result = round(result, 10)
                output = f"{expression} = {result}"
            else:
                output = str(result)

            logger.info("calculator_tool_success", expression=expression, result=result)
            return output

        except ZeroDivisionError:
            error_msg = "Error: Division by zero"
            logger.warning("calculator_tool_error", expression=expression, error="division_by_zero")
            return error_msg
        except SyntaxError:
            error_msg = f"Error: Invalid mathematical expression: {expression}"
            logger.warning("calculator_tool_error", expression=expression, error="syntax_error")
            return error_msg
        except Exception as e:
            # Covers unknown names, disallowed syntax, math domain errors and
            # oversized powers; the message is returned to the caller as text.
            error_msg = f"Error calculating expression: {str(e)}"
            logger.error("calculator_tool_failed", expression=expression, error=str(e))
            return error_msg

    async def _arun(self, expression: str) -> str:
        """Async variant of ``_run``.

        Note: this simply calls ``_run`` synchronously on the current thread.
        """
        return self._run(expression)
|