mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-15 04:07:32 +00:00
Merge pull request #60 from manmohan659/fix/wire-calculator-force
fix: auto-inject calculator force-path
This commit is contained in:
commit
64067a5edd
|
|
@ -386,23 +386,49 @@ def _rewrite_query(text: str) -> str:
|
|||
# Calculator triggers (cheap, local)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CALC_RX = re.compile(
|
||||
r"""
|
||||
\b(?:calculate|compute|what\s+is|what's)\b.*?
|
||||
(?:
|
||||
\d[\d,\.\s]*\s*[+\-\*/x×÷]\s*\d # basic arithmetic
|
||||
| \d+\s*%\s+(?:of|tip|tax|discount)\s+\d # percentage
|
||||
| \b(?:emi|cagr|compound\s+interest|tip|discount|percent(?:age)?)\b.*\d
|
||||
)
|
||||
""",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
_BARE_EXPR_RX = re.compile(
|
||||
r"(-?\d[\d,\.]*\s*[+\-*/×÷]\s*-?\d[\d,\.]*(?:\s*[+\-*/×÷]\s*-?\d[\d,\.]*)*)"
|
||||
)
|
||||
_PERCENT_RX = re.compile(
|
||||
r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(?:of|tip|tax|discount|off)\s+(?:on\s+)?\$?(\d+(?:\.\d+)?)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_VERBAL_RX = re.compile(
|
||||
r"(\d+(?:\.\d+)?)\s+(plus|minus|times|divided\s+by|multiplied\s+by|over)\s+(\d+(?:\.\d+)?)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_WORD_OP = {
|
||||
"plus": "+", "minus": "-", "times": "*",
|
||||
"multiplied by": "*", "divided by": "/", "over": "/",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_expr(expr: str) -> str:
|
||||
e = expr.replace(",", "").replace("×", "*").replace("÷", "/")
|
||||
e = re.sub(r"\s+", "", e) # strip all internal whitespace
|
||||
return e
|
||||
|
||||
|
||||
def needs_calculator(text: str) -> Tuple[bool, str]:
    """Decide whether `text` contains arithmetic for the calculator tool.

    Three extractors are tried in order and the first hit wins:

    1. percentage phrasing ("15% of 200")  -> "percent(200,15)"
    2. verbal arithmetic ("7 divided by 2") -> "7/2"
    3. a bare expression ("1,000 + 2 × 3")  -> "1000+2*3"

    Args:
        text: The user message to inspect. May be empty or None.

    Returns:
        (True, expression) when arithmetic was found, where `expression`
        is a compact string for the calculator tool (plain +-*/ on numbers
        or a percent(base,rate) call); otherwise (False, "").
    """
    if not text:
        return False, ""

    # 1. percentage phrasing: group 1 is the rate, group 2 the base.
    m = _PERCENT_RX.search(text)
    if m:
        return True, f"percent({m.group(2)},{m.group(1)})"

    # 2. verbal arithmetic
    m = _VERBAL_RX.search(text)
    if m:
        # The pattern accepts any run of whitespace inside two-word
        # operators ("divided   by"); collapse it so the table lookup
        # cannot raise KeyError.
        phrase = " ".join(m.group(2).lower().split())
        op = _WORD_OP[phrase]
        return True, f"{m.group(1)}{op}{m.group(3)}"

    # 3. bare arithmetic expression
    m = _BARE_EXPR_RX.search(text)
    if m:
        return True, _normalize_expr(m.group(1))

    return False, ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -199,11 +199,12 @@ class Inference:
|
|||
import sys as _sys
|
||||
if '/root' not in _sys.path: _sys.path.insert(0, '/root')
|
||||
from _tools import build_default_tool_registry, parse_tool_call_payload
|
||||
from _query_classifier import needs_web_search, needs_web_search_contextual
|
||||
from _query_classifier import needs_web_search, needs_web_search_contextual, needs_calculator
|
||||
self.tool_registry = build_default_tool_registry()
|
||||
self._parse_tool_call = parse_tool_call_payload
|
||||
self._needs_web_search = needs_web_search
|
||||
self._needs_web_search_contextual = needs_web_search_contextual
|
||||
self._needs_calculator = needs_calculator
|
||||
# Marker tokens for tool state machine
|
||||
self.python_start_id = self.tokenizer.encode_special("<|python_start|>")[0]
|
||||
self.python_end_id = self.tokenizer.encode_special("<|python_end|>")[0]
|
||||
|
|
@ -316,6 +317,30 @@ class Inference:
|
|||
+ "<|output_start|>" + result_text + "<|output_end|>\n"
|
||||
)
|
||||
tokens.extend(self.tokenizer.encode(forced_prefix_text))
|
||||
else:
|
||||
# Try calculator force-inject: arithmetic in the user message?
|
||||
try:
|
||||
needs_calc, calc_expr = self._needs_calculator(query_for_classify)
|
||||
except Exception:
|
||||
needs_calc, calc_expr = False, ""
|
||||
if needs_calc and calc_expr:
|
||||
preface = "Let me calculate that. "
|
||||
calc_call_json = json.dumps(
|
||||
{"arguments": {"expression": calc_expr}, "tool": "calculator"},
|
||||
separators=(",", ":"),
|
||||
)
|
||||
try:
|
||||
invocation = self._parse_tool_call(calc_call_json)
|
||||
calc_result = self.tool_registry.execute(invocation.tool_name, invocation.arguments)
|
||||
calc_result_text = calc_result.to_payload()[:2048]
|
||||
except Exception as exc:
|
||||
calc_result_text = json.dumps({"error": str(exc)[:500]})
|
||||
forced_prefix_text = (
|
||||
preface
|
||||
+ "<|python_start|>" + calc_call_json + "<|python_end|>"
|
||||
+ "<|output_start|>" + calc_result_text + "<|output_end|>\n"
|
||||
)
|
||||
tokens.extend(self.tokenizer.encode(forced_prefix_text))
|
||||
|
||||
# Truncate to fit context
|
||||
max_context = self.config.sequence_len - max_tokens
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user