110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
"""
|
||
用于 LangChain 代理的计算器工具。
|
||
"""
|
||
import ast
import math
from typing import Type

from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from app.utils.logger import get_logger
|
||
|
||
# Module-level logger. Call sites in this module pass structured keyword
# fields (expression=..., error=..., result=...), so get_logger presumably
# returns a structlog-style logger — confirm in app.utils.logger.
logger = get_logger(__name__)
|
||
|
||
|
||
class CalculatorInput(BaseModel):
    """Argument schema for the calculator tool.

    Holds a single required ``expression`` field; ``CalculatorTool`` uses
    this model as its ``args_schema``.
    """

    # `...` marks the field as required, same as omitting a default.
    expression: str = Field(
        ...,
        description="要计算的数学表达式(例如,'2 + 2'、'10 * 5'、'sqrt(16)')",
    )
|
||
|
||
|
||
# Names an expression may reference. Built once at import time; every call
# evaluates against a fresh copy so nothing an expression does can leak into
# later calls.
_SAFE_NAMES = {
    "abs": abs,
    "round": round,
    "max": max,
    "min": min,
    "sum": sum,
    "pow": pow,
    # math module functions and constants
    "sqrt": math.sqrt,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "log": math.log,
    "log10": math.log10,
    "exp": math.exp,
    "ceil": math.ceil,
    "floor": math.floor,
    "pi": math.pi,
    "e": math.e,
}

# AST node types allowed without further checks. Constant, Name, Call and
# keyword nodes are also allowed but are inspected individually in
# _validate_expression.
_ALLOWED_NODE_TYPES = (
    ast.Expression,
    ast.BinOp,
    ast.UnaryOp,
    ast.Tuple,
    ast.List,
    ast.Load,
    ast.Add,
    ast.Sub,
    ast.Mult,
    ast.Div,
    ast.FloorDiv,
    ast.Mod,
    ast.Pow,
    ast.UAdd,
    ast.USub,
)


class _UnsupportedExpressionError(ValueError):
    """Raised when an expression uses syntax outside the calculator whitelist."""


def _validate_expression(tree: ast.Expression) -> None:
    """Reject any node that is not plain arithmetic on whitelisted names.

    ``eval`` with an empty ``__builtins__`` is not a sandbox: attribute access
    such as ``().__class__.__base__.__subclasses__()`` still reaches arbitrary
    objects. Whitelisting the parsed tree closes that hole while ordinary
    arithmetic and function calls evaluate exactly as before.

    Args:
        tree: Result of ``ast.parse(..., mode="eval")``.

    Raises:
        _UnsupportedExpressionError: On the first disallowed node found.
    """
    for node in ast.walk(tree):
        if isinstance(node, ast.Constant):
            # Numeric literals only (bool is an int subclass and passes too).
            if not isinstance(node.value, (int, float, complex)):
                raise _UnsupportedExpressionError(
                    f"unsupported literal {node.value!r}"
                )
        elif isinstance(node, ast.Name):
            if node.id not in _SAFE_NAMES:
                raise _UnsupportedExpressionError(f"unknown name '{node.id}'")
        elif isinstance(node, ast.Call):
            # Only direct calls such as sqrt(16); the callee itself is then
            # checked by the Name branch above.
            if not isinstance(node.func, ast.Name):
                raise _UnsupportedExpressionError(
                    "only named functions may be called"
                )
        elif isinstance(node, ast.keyword):
            # Allow round(x, ndigits=2) but not **mapping unpacking.
            if node.arg is None:
                raise _UnsupportedExpressionError(
                    "keyword unpacking is not allowed"
                )
        elif not isinstance(node, _ALLOWED_NODE_TYPES):
            raise _UnsupportedExpressionError(
                f"unsupported syntax '{type(node).__name__}'"
            )


class CalculatorTool(BaseTool):
    """Tool that evaluates mathematical expressions for an agent."""

    name: str = "calculator"
    description: str = (
        "Useful for performing mathematical calculations. "
        "Input should be a valid mathematical expression as a string. "
        "Supports basic arithmetic (+, -, *, /), exponents (**), and common math functions "
        "(sqrt, sin, cos, tan, log, exp, etc.)."
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(self, expression: str) -> str:
        """Evaluate a mathematical expression and return a readable result.

        Args:
            expression: Arithmetic expression such as ``"2 + 2"`` or
                ``"sqrt(16)"``. Surrounding whitespace is ignored.

        Returns:
            str: ``"<expression> = <result>"`` for int/float results (floats
            rounded to 10 decimal places), ``str(result)`` for any other
            result type (e.g. complex), or a message starting with
            ``"Error"``. Failures are reported in the return value rather
            than raised, so the calling agent can read them.
        """
        logger.info("calculator_tool_called", expression=expression)

        try:
            expression = expression.strip()

            # Parse once, vet the tree, then evaluate exactly the tree that
            # was vetted (never the raw string).
            tree = ast.parse(expression, mode="eval")
            _validate_expression(tree)
            code = compile(tree, "<calculator>", "eval")

            # NOTE(review): exponentiation is not bounded, so inputs such as
            # "9**9**9" can still exhaust CPU/memory; consider a timeout or an
            # exponent cap if this tool receives untrusted prompts.
            result = eval(code, {"__builtins__": {}}, dict(_SAFE_NAMES))

            if isinstance(result, (int, float)):
                # Trim float noise such as 0.1 + 0.2 -> 0.3.
                if isinstance(result, float):
                    result = round(result, 10)
                output = f"{expression} = {result}"
            else:
                output = str(result)

            logger.info("calculator_tool_success", expression=expression, result=result)
            return output

        except ZeroDivisionError:
            error_msg = "Error: Division by zero"
            logger.warning("calculator_tool_error", expression=expression, error="division_by_zero")
            return error_msg

        except SyntaxError:
            error_msg = f"Error: Invalid mathematical expression: {expression}"
            logger.warning("calculator_tool_error", expression=expression, error="syntax_error")
            return error_msg

        except _UnsupportedExpressionError as exc:
            # Well-formed Python, but outside the calculator whitelist.
            error_msg = f"Error: Unsupported expression ({exc}): {expression}"
            logger.warning(
                "calculator_tool_error",
                expression=expression,
                error=f"unsupported_syntax: {exc}",
            )
            return error_msg

        except Exception as e:
            error_msg = f"Error calculating expression: {str(e)}"
            logger.error("calculator_tool_failed", expression=expression, error=str(e))
            return error_msg

    async def _arun(self, expression: str) -> str:
        """Async entry point; delegates directly to ``_run``.

        The calculation runs synchronously on the calling thread, so it blocks
        the event loop for its duration (see the NOTE on unbounded
        exponentiation in ``_run``).
        """
        return self._run(expression)
|