langchain-learning-kit/app/tools/calculator.py

110 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Calculator tool for LangChain agents.
"""
from typing import Type
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
from app.utils.logger import get_logger

# Module-level logger. Calls in this module pass an event name plus keyword
# fields (e.g. ``expression=...``), so get_logger presumably returns a
# structlog-style logger -- NOTE(review): confirm in app.utils.logger.
logger = get_logger(__name__)
class CalculatorInput(BaseModel):
    """Input schema for the calculator tool."""

    # This description is exposed to the LLM as the argument documentation.
    # The example list previously had no separators and an unclosed
    # parenthesis; ASCII punctuation is used to avoid confusable characters.
    expression: str = Field(
        description="要计算的数学表达式(例如: '2 + 2', '10 * 5', 'sqrt(16)')"
    )
class _UnsupportedExpressionError(ValueError):
    """Raised when an expression uses syntax the calculator does not allow."""


def _parse_arithmetic(expression: str, allowed_names):
    """Parse ``expression`` and verify that it is plain arithmetic.

    ``eval`` with ``{"__builtins__": {}}`` is NOT a sandbox on its own:
    attribute access such as ``().__class__.__base__.__subclasses__()``
    still reaches arbitrary objects. The calculator's input is produced by an
    LLM and may be steered by untrusted text, so the parsed tree is checked
    against a whitelist of node types before anything is evaluated.

    Args:
        expression: the expression text (already stripped by the caller).
        allowed_names: the names the expression may reference.

    Returns:
        ast.Expression: the validated syntax tree.

    Raises:
        SyntaxError: if ``expression`` is not a valid Python expression.
        _UnsupportedExpressionError: if it contains a construct outside the
            whitelist (attribute access, subscripts, lambdas, comprehensions,
            non-numeric literals, ...) or references an unknown name.
    """
    import ast

    allowed_nodes = (
        ast.Expression,
        ast.Constant,
        ast.Name,
        ast.Load,
        ast.BinOp,
        ast.UnaryOp,
        ast.Call,
        ast.keyword,
        # Tuples/lists let sum/max/min take a literal sequence.
        ast.Tuple,
        ast.List,
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.FloorDiv,
        ast.Mod,
        ast.Pow,
        ast.UAdd,
        ast.USub,
    )
    tree = ast.parse(expression, mode="eval")
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            raise _UnsupportedExpressionError(
                f"unsupported syntax '{type(node).__name__}'"
            )
        if isinstance(node, ast.Name) and node.id not in allowed_names:
            raise _UnsupportedExpressionError(f"unknown name '{node.id}'")
        if isinstance(node, ast.Constant) and not isinstance(
            node.value, (int, float, complex)
        ):
            raise _UnsupportedExpressionError("only numeric literals are allowed")
    return tree


class CalculatorTool(BaseTool):
    """
    Tool that performs mathematical calculations for LangChain agents.
    """

    name: str = "calculator"
    description: str = (
        "Useful for performing mathematical calculations. "
        "Input should be a valid mathematical expression as a string. "
        "Supports basic arithmetic (+, -, *, /), exponents (**), and common math functions "
        "(sqrt, sin, cos, tan, log, exp, etc.)."
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(self, expression: str) -> str:
        """
        Evaluate ``expression`` and return the result as text.

        Failures are not raised to the agent. They are returned as an
        ``"Error..."`` string so the LLM can read the message and retry.

        Args:
            expression: the mathematical expression to evaluate.

        Returns:
            str: ``"<expression> = <result>"`` for int/float results (floats
            rounded to 10 decimal places), ``str(result)`` for any other
            result type, or an error message.
        """
        logger.info("calculator_tool_called", expression=expression)
        try:
            # Normalize surrounding whitespace.
            expression = expression.strip()

            import math

            # The only names an expression may reference.
            safe_dict = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "sum": sum,
                "pow": pow,
                # Functions and constants from the math module.
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "asin": math.asin,
                "acos": math.acos,
                "atan": math.atan,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "ceil": math.ceil,
                "floor": math.floor,
                "pi": math.pi,
                "e": math.e,
            }

            # Validate the syntax tree first, then evaluate exactly the tree
            # that was checked. Emptying __builtins__ alone is not enough
            # (see _parse_arithmetic).
            # NOTE(review): exponentiation is still unbounded, so input like
            # "9**9**9" can tie up the worker; consider a timeout or an
            # exponent limit.
            tree = _parse_arithmetic(expression, safe_dict)
            result = eval(
                compile(tree, "<calculator>", "eval"),
                {"__builtins__": {}},
                safe_dict,
            )

            # Format the result.
            if isinstance(result, (int, float)):
                # Round floats to hide representation noise such as 0.1 + 0.2.
                if isinstance(result, float):
                    result = round(result, 10)
                output = f"{expression} = {result}"
            else:
                output = str(result)
            logger.info("calculator_tool_success", expression=expression, result=result)
            return output
        except ZeroDivisionError:
            error_msg = "Error: Division by zero"
            logger.warning("calculator_tool_error", expression=expression, error="division_by_zero")
            return error_msg
        except SyntaxError:
            error_msg = f"Error: Invalid mathematical expression: {expression}"
            logger.warning("calculator_tool_error", expression=expression, error="syntax_error")
            return error_msg
        except _UnsupportedExpressionError as e:
            # Handled before the generic branch so that ValueErrors raised by
            # math functions (e.g. sqrt(-1)) keep their existing message.
            error_msg = f"Error: Unsupported expression: {e}"
            logger.warning("calculator_tool_error", expression=expression, error="unsupported_syntax")
            return error_msg
        except Exception as e:
            error_msg = f"Error calculating expression: {e}"
            logger.error("calculator_tool_failed", expression=expression, error=str(e))
            return error_msg

    async def _arun(self, expression: str) -> str:
        """
        Async version of ``_run``.

        Note: this delegates to the synchronous implementation, so the
        evaluation runs on the calling event-loop thread.
        """
        return self._run(expression)