langchain-learning-kit/app/tools/calculator.py

110 lines
3.4 KiB
Python
Raw Normal View History

"""
用于 LangChain 代理的计算器工具
"""
from typing import Type
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
from app.utils.logger import get_logger
logger = get_logger(__name__)
class CalculatorInput(BaseModel):
    """Input schema for the calculator tool: a single expression string."""

    # The description text is attached to the field via pydantic's Field().
    # The example list previously ran the three samples together and never
    # closed its parenthesis; separators and the closing ")" are restored.
    expression: str = Field(
        description="要计算的数学表达式(例如,'2 + 2', '10 * 5', 'sqrt(16)')"
    )
def _calculator_namespace() -> dict:
    """Return the whitelist of names an expression is allowed to reference."""
    import math

    return {
        "abs": abs,
        "round": round,
        "max": max,
        "min": min,
        "sum": sum,
        "pow": pow,
        # Functions and constants taken from the math module.
        "sqrt": math.sqrt,
        "sin": math.sin,
        "cos": math.cos,
        "tan": math.tan,
        "asin": math.asin,
        "acos": math.acos,
        "atan": math.atan,
        "log": math.log,
        "log10": math.log10,
        "exp": math.exp,
        "ceil": math.ceil,
        "floor": math.floor,
        "pi": math.pi,
        "e": math.e,
    }


def _contains_attribute_access(expression_tree) -> bool:
    """Return True if the parsed expression performs any ``obj.attr`` lookup."""
    import ast

    return any(
        isinstance(node, ast.Attribute) for node in ast.walk(expression_tree)
    )


class CalculatorTool(BaseTool):
    """Tool that evaluates a mathematical expression given as a string.

    The expression is evaluated with ``eval`` against a fixed whitelist of
    arithmetic helpers and ``math`` functions. Every outcome, including
    failures, is reported as a string; no exception escapes ``_run``.
    """

    name: str = "calculator"
    description: str = (
        "Useful for performing mathematical calculations. "
        "Input should be a valid mathematical expression as a string. "
        "Supports basic arithmetic (+, -, *, /), exponents (**), and common math functions "
        "(sqrt, sin, cos, tan, log, exp, etc.)."
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(self, expression: str) -> str:
        """Evaluate ``expression`` and return the result as text.

        Args:
            expression: The mathematical expression to evaluate.

        Returns:
            ``"<expression> = <result>"`` for int/float results (floats are
            rounded to 10 decimal places), ``str(result)`` for any other
            result type, or a message starting with ``"Error"`` on failure.
        """
        logger.info("calculator_tool_called", expression=expression)
        try:
            import ast

            expression = expression.strip()

            # Parse once so the syntax tree can be inspected before anything
            # is executed; invalid syntax raises SyntaxError here.
            tree = ast.parse(expression, mode="eval")

            # SECURITY: emptying __builtins__ does not sandbox eval(). The
            # well-known escapes walk attribute chains (e.g. ``().__class__``
            # or a generator's ``gi_frame``) back to the real builtins. None
            # of the whitelisted names needs attribute access, so any such
            # lookup is refused up front.
            if _contains_attribute_access(tree):
                logger.warning(
                    "calculator_tool_error",
                    expression=expression,
                    error="attribute_access",
                )
                return (
                    "Error: Attribute access is not allowed in expressions: "
                    f"{expression}"
                )

            # SECURITY(review): eval() on agent-supplied text is still a
            # mitigation, not a sandbox. Resource exhaustion such as
            # ``9**9**9`` is not prevented here. Consider a dedicated
            # arithmetic evaluator if inputs are untrusted.
            result = eval(
                compile(tree, "<calculator>", "eval"),
                {"__builtins__": {}},
                _calculator_namespace(),
            )

            if isinstance(result, (int, float)):
                if isinstance(result, float):
                    # Trim floating-point noise to a readable precision.
                    result = round(result, 10)
                output = f"{expression} = {result}"
            else:
                output = str(result)

            logger.info("calculator_tool_success", expression=expression, result=result)
            return output
        except ZeroDivisionError:
            error_msg = "Error: Division by zero"
            logger.warning("calculator_tool_error", expression=expression, error="division_by_zero")
            return error_msg
        except SyntaxError:
            error_msg = f"Error: Invalid mathematical expression: {expression}"
            logger.warning("calculator_tool_error", expression=expression, error="syntax_error")
            return error_msg
        except Exception as e:
            # Anything else (unknown name, math domain error, wrong argument
            # type, ...) is reported back as text instead of being raised.
            error_msg = f"Error calculating expression: {str(e)}"
            logger.error("calculator_tool_failed", expression=expression, error=str(e))
            return error_msg

    async def _arun(self, expression: str) -> str:
        """Async entry point; delegates to the synchronous ``_run``.

        Note: the evaluation itself runs synchronously inside this coroutine.
        """
        return self._run(expression)