Merge pull request #60 from manmohan659/fix/wire-calculator-force

fix: auto-inject calculator force-path
This commit is contained in:
Manmohan 2026-04-22 19:04:33 -04:00 committed by GitHub
commit 64067a5edd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 66 additions and 15 deletions

View File

@ -386,23 +386,49 @@ def _rewrite_query(text: str) -> str:
# Calculator triggers (cheap, local)
# ---------------------------------------------------------------------------
# Keyword-gated trigger: a cue word ("calculate", "compute", "what is",
# "what's") followed, anywhere later in the text, by basic arithmetic, a
# percentage phrase, or a finance/percentage keyword with a digit after it.
# NOTE(review): this looks like the older trigger that the extractor patterns
# below replace; it is kept defined so existing references keep resolving.
# Confirm and delete once nothing uses it.
_CALC_RX = re.compile(
    r"""
    \b(?:calculate|compute|what\s+is|what's)\b.*?
    (?:
        \d[\d,\.\s]*\s*[+\-\*/x×÷]\s*\d              # basic arithmetic
      | \d+\s*%\s+(?:of|tip|tax|discount)\s+\d      # percentage
      | \b(?:emi|cagr|compound\s+interest|tip|discount|percent(?:age)?)\b.*\d
    )
    """,
    re.IGNORECASE | re.VERBOSE,
)

# Bare arithmetic: two or more numeric operands (digits with optional commas
# and dots, optional leading minus) joined by + - * / × ÷.  The whole chain
# is captured as group 1.
# NOTE(review): this also matches digit groups separated by "-" or "/" that
# are not arithmetic, e.g. an ISO date such as "2026-04-22".
_BARE_EXPR_RX = re.compile(
    r"(-?\d[\d,\.]*\s*[+\-*/×÷]\s*-?\d[\d,\.]*(?:\s*[+\-*/×÷]\s*-?\d[\d,\.]*)*)"
)

# Percentage phrasing: "<rate>% of <base>", "<rate> percent off <base>",
# "<rate>% tip on $<base>", ...  Group 1 is the rate, group 2 is the base.
_PERCENT_RX = re.compile(
    r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(?:of|tip|tax|discount|off)\s+(?:on\s+)?\$?(\d+(?:\.\d+)?)",
    re.IGNORECASE,
)

# Verbal arithmetic: "<a> plus <b>", "<a> divided by <b>", ...
# Group 1 is the left operand, group 2 the operator word(s), group 3 the
# right operand.
_VERBAL_RX = re.compile(
    r"(\d+(?:\.\d+)?)\s+(plus|minus|times|divided\s+by|multiplied\s+by|over)\s+(\d+(?:\.\d+)?)",
    re.IGNORECASE,
)

# Operator word(s) -> arithmetic symbol.  Keys are lower-case and use a
# single space between words, so a match from _VERBAL_RX must be lower-cased
# and whitespace-collapsed before it is looked up here.
_WORD_OP = {
    "plus": "+",
    "minus": "-",
    "times": "*",
    "multiplied by": "*",
    "divided by": "/",
    "over": "/",
}
def _normalize_expr(expr: str) -> str:
    """Return *expr* rewritten in plain ASCII arithmetic with no whitespace.

    Commas (thousands separators) are dropped, ``×`` becomes ``*``, ``÷``
    becomes ``/``, and every run of whitespace is removed.
    """
    symbol_map = {",": "", "×": "*", "÷": "/"}
    cleaned = "".join(symbol_map.get(ch, ch) for ch in expr)
    # Whitespace carries no meaning in the expression, so drop all of it.
    return re.sub(r"\s+", "", cleaned)
def needs_calculator(text: str) -> Tuple[bool, str]:
    """Detect arithmetic in *text* and extract a calculator-ready expression.

    Three extractors are tried in order; the first one that matches wins:

    1. percentage phrasing ("15% of 200")  -> ``percent(200,15)``
    2. verbal arithmetic ("10 divided by 4") -> ``10/4``
    3. bare arithmetic ("1,000 × 2")        -> ``1000*2``

    Returns:
        ``(True, expression)`` when arithmetic is found, where ``expression``
        is the extracted and normalized expression rather than the raw text.
        It is meant for the sandboxed evaluator, which accepts ``+ - * /`` on
        numbers plus helpers such as ``percent(base,rate)``.
        ``(False, "")`` when *text* is empty or nothing matches.
    """
    if not text:
        return False, ""

    # 1. percentage phrasing: group 1 is the rate, group 2 is the base
    m = _PERCENT_RX.search(text)
    if m:
        return True, f"percent({m.group(2)},{m.group(1)})"

    # 2. verbal arithmetic
    m = _VERBAL_RX.search(text)
    if m:
        # The pattern accepts any whitespace run inside "divided by" and
        # "multiplied by", while _WORD_OP keys use exactly one space, so
        # collapse the whitespace before the lookup to avoid a KeyError.
        word = " ".join(m.group(2).lower().split())
        return True, f"{m.group(1)}{_WORD_OP[word]}{m.group(3)}"

    # 3. bare arithmetic expression
    m = _BARE_EXPR_RX.search(text)
    if m:
        return True, _normalize_expr(m.group(1))

    return False, ""

View File

@ -199,11 +199,12 @@ class Inference:
import sys as _sys
if '/root' not in _sys.path: _sys.path.insert(0, '/root')
from _tools import build_default_tool_registry, parse_tool_call_payload
from _query_classifier import needs_web_search, needs_web_search_contextual
from _query_classifier import needs_web_search, needs_web_search_contextual, needs_calculator
self.tool_registry = build_default_tool_registry()
self._parse_tool_call = parse_tool_call_payload
self._needs_web_search = needs_web_search
self._needs_web_search_contextual = needs_web_search_contextual
self._needs_calculator = needs_calculator
# Marker tokens for tool state machine
self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
@ -316,6 +317,30 @@ class Inference:
+ "<|output_start|>" + result_text + "<|output_end|>\n"
)
tokens.extend(self.tokenizer.encode(forced_prefix_text))
else:
# Try calculator force-inject: arithmetic in the user message?
try:
needs_calc, calc_expr = self._needs_calculator(query_for_classify)
except Exception:
needs_calc, calc_expr = False, ""
if needs_calc and calc_expr:
preface = "Let me calculate that. "
calc_call_json = json.dumps(
{"arguments": {"expression": calc_expr}, "tool": "calculator"},
separators=(",", ":"),
)
try:
invocation = self._parse_tool_call(calc_call_json)
calc_result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
calc_result_text = calc_result.to_payload()[:2048]
except Exception as exc:
calc_result_text = json.dumps({"error": str(exc)[:500]})
forced_prefix_text = (
preface
+ "<|python_start|>" + calc_call_json + "<|python_end|>"
+ "<|output_start|>" + calc_result_text + "<|output_end|>\n"
)
tokens.extend(self.tokenizer.encode(forced_prefix_text))
# Truncate to fit context
max_context = self.config.sequence_len - max_tokens