Algorithm Visualization and Interactive Learning Platform
Tool-Calling TinyGPT: Train a Local Tiny Model That Emits Tool Calls
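Building on TinyCodeGPT's local training loop, this module moves to a second stage. Instead of asking the small model to write complete code directly, it trains the model to translate user requests into three kinds of structured tool instructions (CALL calculator, CALL plot, CALL search), then executes and validates them locally and evaluates tool-selection accuracy, JSON validity rate, and execution pass rate.
From generating code to generating tool calls
No.7 trained **task -> code**: given a task, the model continues with a piece of Python code.
This module enters the second stage: train a small model that behaves more like an agent core, so that it first emits a structured tool call, for example CALL calculator {"expression":"23 + 17 * 4"}.
The point of this step is not to make the small model "know everything" but to teach it three things:
1. When to call which tool.
2. How to turn natural-language parameters into stable JSON.
3. How to let the local runner execute the CALL and use the execution result to evaluate a checkpoint.
The full loop keeps the engineering shape of No.7: JSONL dataset, character-level tokenizer, local TinyGPT training, checkpoint, generation, execution, error analysis, and continued training.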
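Step 1: pin down the CALL protocol first
Tool-calling training is most vulnerable to output-format drift, so the first version allows only three single-line instructions. The model does not improvise; it completes the tool name and JSON arguments inside a fixed grammar.
| Tool | What the model should generate | What the local runner executes |
|---|---|---|
| calculator | CALL calculator {"expression":"23 + 17 * 4"} | Parses the arithmetic expression with an AST whitelist and returns the numeric result. |
| plot | CALL plot {"kind":"function","expression":"x**2","x_min":-10,"x_max":10} | Uses the Matplotlib Agg backend to generate a local PNG chart artifact. |
| search | CALL search {"query":"transformer self attention","top_k":3} | Searches the small local course corpus and returns ranked results. |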
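
To illustrate how strict this grammar is, the sketch below parses one CALL line with a regular expression and `json.loads`. The names `parse_call_line`, `CALL_PATTERN`, and `ALLOWED_TOOLS` are illustrative; this is a minimal sketch of the single-line shape in the table, not necessarily the runner's exact validator.

```python
import json
import re

# Assumed tool list, taken from the three rows of the table above.
ALLOWED_TOOLS = ("calculator", "plot", "search")
# One line: the literal CALL, a tool name, then a JSON object.
CALL_PATTERN = re.compile(r"^CALL\s+([A-Za-z_]\w*)\s+(\{.*\})$")


def parse_call_line(line: str) -> tuple[str, dict]:
    """Split a CALL line into (tool, args) or raise ValueError."""
    match = CALL_PATTERN.match(line.strip())
    if match is None:
        raise ValueError("expected: CALL <tool> {json_args}")
    tool, raw_args = match.group(1), match.group(2)
    if tool not in ALLOWED_TOOLS:
        raise ValueError(f"unsupported tool: {tool}")
    args = json.loads(raw_args)  # raises json.JSONDecodeError (a ValueError) on bad JSON
    if not isinstance(args, dict):
        raise ValueError("CALL args must be a JSON object")
    return tool, args


tool, args = parse_call_line('CALL calculator {"expression":"23 + 17 * 4"}')
print(tool, args)
```

Anything outside the three tool names, or any argument payload that is not a JSON object, is rejected before a tool ever runs.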
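Supervised learning objective: maximize the probability of the CALL sequence
The dataset consists of user requests u_i, tool names a_i, and target CALL texts g_i. Training is still next-token prediction; the predicted target is simply a structured tool instruction rather than Python code.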
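
One way to write this objective down, as a sketch under assumed notation: N is the number of samples, T_i is the character length of g_i, g_{i,t} is its t-th character, and θ denotes the model parameters. These symbols are introduced here for illustration and do not come from the source. The tool name a_i does not appear separately because it is part of the CALL text g_i.

$$
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{i=1}^{N}\sum_{t=1}^{T_i}\log p_\theta\left(g_{i,t}\mid u_i,\, g_{i,<t}\right)
$$

In the runner's character-level setup, cross entropy is taken over every position of a sampled window, so the user text is predicted as well; the expression above isolates the CALL part that matters for tool use.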
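Why not train on the final answer directly
The calculator's number, the plot's image, and the search results can all be produced by deterministic tools. The small model only has to handle "routing and parameterization", which makes it easier to train and makes failures easier to localize.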
JSONL data format: user to call
```jsonl
{"tool":"calculator","user":"calculate 23 plus 17 times 4","call":"CALL calculator {\"expression\":\"23 + 17 * 4\"}"}
{"tool":"plot","user":"plot y = x**2 from -10 to 10","call":"CALL plot {\"kind\":\"function\",\"expression\":\"x**2\",\"x_min\":-10,\"x_max\":10,\"points\":140}"}
{"tool":"search","user":"search for transformer self attention","call":"CALL search {\"query\":\"transformer self attention\",\"top_k\":3}"}
```
How one sample enters the training window
user = 'plot y = x**2 from -10 to 10'
Training is still next-token; only the target becomes a CALL
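No.8 does not change TinyGPT's training math. What changes is the meaning of the samples: after <call> the model continues with a tool instruction, so a falling loss means the CALL format, tool names, and JSON arguments increasingly resemble the training distribution.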
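
To make the training window concrete, the sketch below serializes one sample in the <user>/<call> layout used by this module, builds a character-level vocabulary, and cuts one input/target pair shifted by a single character. The helper names `serialize_sample` and `make_window` and the tiny `block_size` are illustrative assumptions; the runner draws many such windows at random from the full training text.

```python
def serialize_sample(user: str, call: str) -> str:
    """Wrap one sample in the <user>/<call> layout."""
    return f"<user>\n{user}\n</user>\n<call>\n{call}\n</call>\n"


def build_vocab(text: str) -> tuple[dict[str, int], dict[int, str]]:
    """Character-level vocabulary: one id per distinct character."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


def make_window(ids: list[int], start: int, block_size: int) -> tuple[list[int], list[int]]:
    """Input x and target y; y is x shifted one character to the right."""
    x = ids[start : start + block_size]
    y = ids[start + 1 : start + block_size + 1]
    return x, y


text = serialize_sample(
    "plot y = x**2 from -10 to 10",
    'CALL plot {"kind":"function","expression":"x**2","x_min":-10,"x_max":10,"points":140}',
)
stoi, itos = build_vocab(text)
ids = [stoi[ch] for ch in text]
x, y = make_window(ids, start=0, block_size=16)
print(repr("".join(itos[i] for i in x)))
print(repr("".join(itos[i] for i in y)))
```

At every position the model sees the characters of `x` up to that point and is scored on the matching character of `y`, which is why the same loss applies whether the text is code or a CALL.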
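Low loss does not mean the tool call is correct
Loss only measures character probabilities. Real quality depends on whether the tool name is correct, whether the JSON parses, whether the arguments can be executed, and whether the execution result matches the task.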
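
The sketch below shows how such checks can be scored separately from loss. It is a simplified, assumed scorer: `score_calls`, the record fields `expected_tool` and `generated`, and the `run_tool` callback are hypothetical names, a record counts toward tool accuracy only if its CALL parses, and the execution check is reduced to "the tool ran without raising".

```python
from __future__ import annotations

import json
import re
from typing import Callable

CALL_PATTERN = re.compile(r"^CALL\s+([A-Za-z_]\w*)\s+(\{.*\})$")


def score_calls(records: list[dict], run_tool: Callable[[str, dict], object]) -> dict[str, float]:
    """Return tool accuracy, valid CALL rate, and execution pass rate."""
    tool_hits = valid_calls = executed = 0
    for record in records:
        match = CALL_PATTERN.match(record["generated"].strip())
        if match is None:
            continue  # not even shaped like a CALL
        try:
            args = json.loads(match.group(2))
        except json.JSONDecodeError:
            continue  # CALL-shaped, but the JSON does not parse
        if not isinstance(args, dict):
            continue
        valid_calls += 1
        tool = match.group(1)
        if tool == record["expected_tool"]:
            tool_hits += 1
        try:
            run_tool(tool, args)
            executed += 1
        except Exception:
            pass  # parsed fine, but the tool rejected the arguments
    total = max(len(records), 1)
    return {
        "tool_accuracy": tool_hits / total,
        "valid_call_rate": valid_calls / total,
        "execution_pass_rate": executed / total,
    }


def demo_tool(tool: str, args: dict) -> str:
    """Stand-in executor used only for this example."""
    if tool == "calculator" and "expression" not in args:
        raise ValueError("calculator.expression is required")
    return "ok"


records = [
    {"expected_tool": "calculator", "generated": 'CALL calculator {"expression":"1 + 2"}'},
    {"expected_tool": "plot", "generated": 'CALL calculator {"kind":"line"}'},
    {"expected_tool": "search", "generated": 'CALL search {"query":"attention"'},
]
report = score_calls(records, demo_tool)
print(report)
```

The three demo records fail at different stages (wrong tool with unusable arguments, truncated JSON), which is exactly the kind of breakdown a single loss number hides.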
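Local runner architecture: training and tool execution both stay on your machine
Module 8 continues to use a local runner. The web page only sends HTTP requests and displays status; PyTorch training, checkpoints, tool execution, and chart artifacts all stay on the local machine. The runner's default --device auto policy prefers a CUDA GPU; if the current PyTorch build or driver environment has no CUDA, it falls back to CPU automatically and shows the reason in the status card.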
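
The fallback rule can be sketched as a small pure function. This is an illustrative reduction of the policy described above, not the runner's own code: `choose_device` and `detect_cuda` are hypothetical names, and the reason strings are simplified.

```python
from __future__ import annotations


def choose_device(policy: str | None, cuda_available: bool) -> tuple[str, str]:
    """Map a device policy to (device, human-readable reason)."""
    normalized = (policy or "auto").strip().lower()
    if normalized == "cpu":
        return "cpu", "CPU was selected explicitly."
    # Every other value is treated as "auto": prefer the GPU when it exists.
    if cuda_available:
        return "cuda", "GPU priority: using CUDA."
    return "cpu", "GPU priority: CUDA unavailable, falling back to CPU."


def detect_cuda() -> bool:
    """True only if PyTorch imports and reports a usable CUDA device."""
    try:
        import torch
    except Exception:
        return False
    return bool(torch.cuda.is_available())


device, reason = choose_device("auto", detect_cuda())
print(device, "-", reason)
```

Returning the reason alongside the device is what lets a status card explain why training ended up on the CPU.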
| Endpoint | Purpose |
|---|---|
| POST /train | Parses the JSONL and trains the local Tool-Calling TinyGPT. |
| POST /generate | Generates one CALL from a user prompt. |
| POST /execute_call | Parses a CALL and executes the local tool. |
| POST /evaluate | Runs a few built-in prompts and returns tool accuracy, valid CALL rate, and execution pass rate. |
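
A client for these endpoints needs little more than the standard library. The sketch below builds a JSON POST request with `urllib.request`; the host and port follow the launch command used in this module, the `user` field follows the runner's `generate_call`, and `build_request` and `post_json` are illustrative helper names. The exact response layout is an assumption and should be checked against a running runner.

```python
import json
import urllib.request

# Host and port as used in this module's launch command.
RUNNER_URL = "http://127.0.0.1:4888"


def build_request(path: str, payload: dict) -> urllib.request.Request:
    """Create a JSON POST request for one runner endpoint (nothing is sent yet)."""
    body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    return urllib.request.Request(
        RUNNER_URL + path,
        data=body,
        headers={"Content-Type": "application/json; charset=utf-8"},
        method="POST",
    )


def post_json(path: str, payload: dict, timeout: float = 30.0) -> dict:
    """Send the request and decode the JSON reply. Requires a running runner."""
    with urllib.request.urlopen(build_request(path, payload), timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


request = build_request("/generate", {"user": "calculate 23 plus 17 times 4"})
print(request.get_method(), request.full_url)
```

With the runner started, `post_json("/generate", {"user": "calculate 23 plus 17 times 4"})` would send the request; the example above only constructs it, so it runs without a server.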
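Local training console: Tool-Calling TinyGPT
This panel only handles control and visualization; the actual training happens locally in scripts/toolcalltinygpt_local_runner.py.
```powershell
# Windows / PowerShell
& 'C:\Users\richi\TI_richiebao\LLM\.venv\Scripts\python.exe' scripts\toolcalltinygpt_local_runner.py --host 127.0.0.1 --port 4888 --device auto
# Recommended: use the same GPU Python environment as No.7 TinyCodeGPT.
# --device auto = prefer a CUDA GPU; fall back to CPU automatically when CUDA is unavailable
```
Runner API: GET /status, POST /train, POST /generate, POST /execute_call, POST /evaluate.
Each sample supervises the user request into one complete CALL instruction rather than a final answer.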
<user>
calculate 23 plus 17 times 4
</user>
<call>
CALL calculator {"expression":"23 + 17 * 4"}
</call>
The training objective is still next-token cross entropy, but the context changes from task-code to user-call.
By default, the most recently trained checkpoint generates the CALL; the CALL is parsed first and then executed by the local tool.
Command to start the local runner
```powershell
& 'C:\Users\richi\TI_richiebao\LLM\.venv\Scripts\python.exe' scripts\toolcalltinygpt_local_runner.py --host 127.0.0.1 --port 4888 --device auto
```
Why the three tools are built for restricted execution
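The point of tool calling is to hand unstable language output to stable programs for execution. That in turn requires clear tool boundaries; in particular, model output must never be run as arbitrary code.
| Tool | Restriction policy | Failure signals |
|---|---|---|
| calculator | AST whitelist only: numbers, the variables pi/e, basic operators, and a small set of math functions. | Expression cannot be parsed, exponent too large, result is not finite. |
| plot | Only function / line / bar / scatter / histogram are allowed, with limits on array length and point count. | Invalid kind, mismatched x/y lengths, unsafe expression. |
| search | The first version searches only the small local course corpus, with no internet access, so results are reproducible. | Empty query, top_k out of range, overly broad search terms. |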
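
As an illustration of the search row, the sketch below ranks a tiny in-memory corpus by term counts and rejects an empty query or an out-of-range top_k before doing any work. It is an assumed design, not the runner's implementation: `search_local`, the 1-10 bound on top_k, and the scoring rule are all hypothetical. The two demo documents reuse titles and bodies from the runner's local corpus.

```python
def search_local(query: str, docs: list[dict], top_k: int = 3) -> list[dict]:
    """Rank docs by how often the query terms occur in title + body."""
    terms = query.lower().split()
    if not terms:
        raise ValueError("search.query is required")
    if not 1 <= top_k <= 10:  # assumed bound, for illustration only
        raise ValueError("search.top_k must be between 1 and 10")
    scored = []
    for index, doc in enumerate(docs):
        haystack = f"{doc['title']} {doc['body']}".lower()
        score = sum(haystack.count(term) for term in terms)
        if score > 0:
            # Negative score sorts best-first; index keeps ties deterministic.
            scored.append((-score, index, doc["title"]))
    scored.sort()
    return [{"title": title, "score": -negative} for negative, _index, title in scored[:top_k]]


DOCS = [
    {
        "title": "Gradient descent and learning rate",
        "body": "Gradient descent updates parameters in the negative gradient direction. The learning rate controls step size.",
    },
    {
        "title": "Transformer self attention",
        "body": "Self attention lets each token mix information from previous tokens through query, key, and value vectors.",
    },
]
results = search_local("transformer self attention", DOCS, top_k=3)
print(results)
```

Because the corpus is fixed and ties are broken by document order, the same query always returns the same ranking, which is what makes evaluation reproducible.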
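Source deep dive: a map of the runner's key functions
| Function | Responsibility | Why it matters |
|---|---|---|
| parse_jsonl_dataset | Validates {tool,user,call} and confirms that the tool name in the CALL matches. | Dirty data directly damages format learning. |
| format_sample | Writes a sample as <user>...<call>.... | Generation must reuse the same prefix. |
| run_training | Reuses No.7's PyTorch training loop. | The training budget is still determined by token count, batch, context, and epochs. |
| generate_call | Loads a checkpoint and samples a CALL. | Too high a temperature makes the JSON drift. |
| execute_tool_call | Parses the CALL and dispatches to the three local tools. | Turns "looks like a CALL" into "actually runs". |
| evaluate_checkpoint | Runs a small evaluation on fixed prompts. | Puts loss, format, and execution results on one scorecard. |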
Local runner source: toolcalltinygpt_local_runner.py
from __future__ import annotations

import argparse
import ast
import base64
import json
import math
import operator
import random
import re
import sys
import threading
import time
import traceback
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except Exception:  # pragma: no cover - status endpoint reports this clearly.
    torch = None

    class _NNStub:
        Module = object

    nn = _NNStub()
    F = None


def torch_no_grad():
    if torch is not None:
        return torch.no_grad()

    def decorator(function):
        return function

    return decorator


ROOT_DIR = Path(__file__).resolve().parents[1]
RUN_DIR = ROOT_DIR / ".tmp" / "toolcalltinygpt"
CHECKPOINT_DIR = RUN_DIR / "checkpoints"
RUN_OUTPUT_DIR = RUN_DIR / "tool_outputs"
DATASET_DIR = RUN_DIR / "datasets"
RUN_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
DATASET_DIR.mkdir(parents=True, exist_ok=True)

JOBS: dict[str, dict[str, Any]] = {}
STOP_EVENTS: dict[str, threading.Event] = {}
ACTIVE_JOB_ID: str | None = None
LOCK = threading.Lock()
TOOLS = ("calculator", "plot", "search")
CALL_RE = re.compile(r"^\s*CALL\s+([a-zA-Z_][a-zA-Z0-9_]*)\s+(\{.*\})\s*$", re.DOTALL)


@dataclass
class ModelConfig:
    n_layer: int
    n_head: int
    n_embd: int
    block_size: int


PRESETS = {
    "tiny": ModelConfig(n_layer=2, n_head=2, n_embd=96, block_size=160),
    "small": ModelConfig(n_layer=4, n_head=4, n_embd=192, block_size=256),
    "medium": ModelConfig(n_layer=6, n_head=6, n_embd=384, block_size=384),
}
DEVICE_POLICY = "auto"


def torch_cuda_version() -> str | None:
    if torch is None:
        return None
    return getattr(getattr(torch, "version", None), "cuda", None)


def is_cuda_available() -> bool:
    return bool(torch is not None and torch.cuda.is_available())


def cuda_unavailable_reason() -> str:
    if torch is None:
        return "PyTorch is not installed."
    version = str(getattr(torch, "__version__", ""))
    if version.endswith("+cpu") or "+cpu" in version:
        return "The current PyTorch wheel is CPU-only. Install a CUDA-enabled PyTorch wheel to use GPU."
    if torch_cuda_version() is None:
        return "The current PyTorch build does not report CUDA support."
    return "CUDA is not available to PyTorch in this environment."


def resolve_device(requested_policy: str | None = None) -> tuple[str, str]:
    policy = str(requested_policy or DEVICE_POLICY or "auto").strip().lower()
    if policy in {"gpu", "cuda"}:
        policy = "auto"
    if policy not in {"auto", "cpu"}:
        policy = "auto"
    if policy == "cpu":
        return "cpu", "CPU was selected explicitly."
    if is_cuda_available():
        name = torch.cuda.get_device_name(0)
        return "cuda", f"GPU priority: using CUDA device {name}."
    return "cpu", f"GPU priority: CUDA unavailable, falling back to CPU. {cuda_unavailable_reason()}"


LOCAL_SEARCH_DOCS = [
    {
        "title": "Gradient descent and learning rate",
        "body": "Gradient descent updates parameters in the negative gradient direction. The learning rate controls step size.",
    },
    {
        "title": "Transformer self attention",
        "body": "Self attention lets each token mix information from previous tokens through query, key, and value vectors.",
    },
    {
        "title": "TinyGPT next token training",
        "body": "A TinyGPT model learns by predicting the next token at every position with cross entropy loss.",
    },
    {
        "title": "Tool calling JSON schema",
        "body": "Tool calling separates language understanding from deterministic tool execution using structured JSON arguments.",
    },
    {
        "title": "Safe calculator expression parser",
        "body": "A calculator tool should parse expressions with a whitelist instead of using eval or shell execution.",
    },
    {
        "title": "Matplotlib plotting tool",
        "body": "A plotting tool can turn structured chart arguments into local PNG artifacts using a non-interactive backend.",
    },
    {
        "title": "VAE latent space",
        "body": "A variational autoencoder maps data into a probabilistic latent space and samples through reparameterization.",
    },
    {
        "title": "Diffusion denoising",
        "body": "Diffusion models learn to reverse a noise process step by step until an image-like sample appears.",
    },
    {
        "title": "MLP nonlinear feature space",
        "body": "An MLP stacks affine layers and nonlinear activations to fit relationships beyond a linear boundary.",
    },
    {
        "title": "Local checkpoint evaluation",
        "body": "A local model checkpoint should be evaluated by format validity, tool selection accuracy, and execution pass rate.",
    },
]


def utc_now() -> str:
    return datetime.now(timezone.utc).isoformat()


def jsonable(value: Any) -> Any:
    if isinstance(value, Path):
        return str(value)
    return value


def respond(handler: BaseHTTPRequestHandler, status: int, payload: dict[str, Any]) -> None:
    raw = json.dumps(payload, ensure_ascii=False, default=jsonable).encode("utf-8")
    handler.send_response(status)
    handler.send_header("Content-Type", "application/json; charset=utf-8")
    handler.send_header("Content-Length", str(len(raw)))
    handler.send_header("Access-Control-Allow-Origin", "*")
    handler.send_header("Access-Control-Allow-Headers", "Content-Type")
    handler.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
    handler.end_headers()
    handler.wfile.write(raw)


def read_body(handler: BaseHTTPRequestHandler) -> dict[str, Any]:
    length = int(handler.headers.get("Content-Length") or "0")
    if length <= 0:
        return {}
    raw = handler.rfile.read(length).decode("utf-8")
    return json.loads(raw or "{}")


def update_job(job_id: str, **patch: Any) -> None:
    with LOCK:
        job = JOBS[job_id]
        job.update(patch)


def append_log(job_id: str, line: str) -> None:
    with LOCK:
        job = JOBS[job_id]
        logs = job.setdefault("logs", [])
        logs.append(line)
        if len(logs) > 240:
            del logs[: len(logs) - 240]


def append_metric(job_id: str, metric: dict[str, Any]) -> None:
    with LOCK:
        job = JOBS[job_id]
        metrics = job.setdefault("metrics", [])
        metrics.append(metric)
        if len(metrics) > 240:
            del metrics[: len(metrics) - 240]


def update_training_progress(
    job_id: str,
    step: int,
    max_steps: int,
    steps_per_epoch: int,
    target_epochs: float,
) -> None:
    epoch = step / max(steps_per_epoch, 1)
    effective_target_epochs = target_epochs if target_epochs > 0 else max_steps / max(steps_per_epoch, 1)
    update_job(
        job_id,
        progress={
            "step": step,
            "max_steps": max_steps,
            "epoch": round(epoch, 4),
            "target_epochs": round(effective_target_epochs, 4),
            "percent": round(step / max(max_steps, 1), 6),
        },
    )


def clamp_int(value: Any, fallback: int, lower: int, upper: int) -> int:
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = fallback
    return max(lower, min(parsed, upper))


def clamp_float(value: Any, fallback: float, lower: float, upper: float) -> float:
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        parsed = fallback
    return max(lower, min(parsed, upper))


def resolve_model_config(request: dict[str, Any]) -> tuple[str, ModelConfig]:
    preset = str(request.get("preset") or "small")
    base_config = PRESETS.get(preset, PRESETS["small"])
    raw_config = request.get("model_config") if preset == "custom" else None
    if not isinstance(raw_config, dict):
        return preset if preset in PRESETS else "small", base_config
    config = ModelConfig(
        n_layer=clamp_int(raw_config.get("n_layer"), base_config.n_layer, 1, 12),
        n_head=clamp_int(raw_config.get("n_head"), base_config.n_head, 1, 12),
        n_embd=clamp_int(raw_config.get("n_embd"), base_config.n_embd, 32, 768),
        block_size=clamp_int(raw_config.get("block_size"), base_config.block_size, 64, 1024),
    )
    if config.n_embd % config.n_head != 0:
        raise ValueError("n_embd must be divisible by n_head for custom model config.")
    return "custom", config


def make_call(tool: str, args: dict[str, Any]) -> str:
    return f"CALL {tool} {json.dumps(args, ensure_ascii=False, separators=(',', ':'))}"


def add_sample(samples: list[dict[str, str]], tool: str, user: str, call: str) -> None:
    samples.append({"tool": tool, "user": user.strip(), "call": call.strip()})


def generate_dataset(samples_per_tool: int, seed: int = 11) -> list[dict[str, str]]:
    rng = random.Random(seed)
    samples: list[dict[str, str]] = []
    search_topics = [
        "gradient descent learning rate",
        "transformer self attention",
        "TinyGPT next token training",
        "tool calling JSON schema",
        "safe calculator expression parser",
        "matplotlib plotting tool",
        "local checkpoint evaluation",
    ]
    for _ in range(samples_per_tool):
        a, b, c = rng.randint(2, 90), rng.randint(2, 90), rng.randint(2, 20)
        add_sample(samples, "calculator", f"calculate {a} plus {b} times {c}", make_call("calculator", {"expression": f"{a} + {b} * {c}"}))
        add_sample(samples, "calculator", f"what is {a} percent of {b * c}", make_call("calculator", {"expression": f"{b * c} * {a} / 100"}))
        add_sample(samples, "calculator", f"compute area of circle with radius {c}", make_call("calculator", {"expression": f"pi * {c} ** 2"}))
        values = [rng.randint(1, 50) for _ in range(rng.randint(4, 8))]
        labels = [f"item_{index + 1}" for index in range(len(values))]
        expression = rng.choice(["x**2", "sin(x)", "cos(x)", "exp(-x**2)"])
        add_sample(samples, "plot", f"plot y = {expression} from -10 to 10", make_call("plot", {"kind": "function", "expression": expression, "x_min": -10, "x_max": 10, "points": 140}))
        add_sample(samples, "plot", f"draw a line chart for values {values}", make_call("plot", {"kind": "line", "x": list(range(1, len(values) + 1)), "y": values, "title": "Line chart"}))
        add_sample(samples, "plot", f"make a bar chart for values {values[:5]}", make_call("plot", {"kind": "bar", "labels": labels[:5], "values": values[:5], "title": "Bar chart"}))
        topic = rng.choice(search_topics)
        add_sample(samples, "search", f"search for {topic}", make_call("search", {"query": topic, "top_k": rng.randint(2, 5)}))
        add_sample(samples, "search", f"find beginner notes about {topic}", make_call("search", {"query": topic, "top_k": rng.randint(2, 5)}))
    rng.shuffle(samples)
    return samples


def strip_jsonl_fences(raw: str) -> str:
    text = raw.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        if lines and lines[0].lstrip().startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        text = "\n".join(lines).strip()
    return text


def parse_call(call_text: str) -> tuple[str, dict[str, Any]]:
    match = CALL_RE.match(call_text.strip())
    if not match:
        raise ValueError("CALL must match: CALL tool {json_args}")
    tool = match.group(1)
    if tool not in TOOLS:
        raise ValueError(f"Unsupported tool: {tool}")
    args = json.loads(match.group(2))
    if not isinstance(args, dict):
        raise ValueError("CALL args must be a JSON object")
    return tool, args


def normalize_dataset_sample(value: Any, line_number: int) -> dict[str, str]:
    if not isinstance(value, dict):
        raise ValueError(f"line {line_number}: expected a JSON object")
    tool = str(value.get("tool") or "").strip()
    user = str(value.get("user") or "").strip()
    call_text = str(value.get("call") or "").strip()
    if tool not in TOOLS:
        raise ValueError(f"line {line_number}: tool must be one of {', '.join(TOOLS)}")
    if not user:
        raise ValueError(f"line {line_number}: user is required")
    parsed_tool, _args = parse_call(call_text)
    if parsed_tool != tool:
        raise ValueError(f"line {line_number}: tool and CALL tool do not match")
    return {"tool": tool, "user": user, "call": call_text}


def parse_jsonl_dataset(raw: str) -> list[dict[str, str]]:
    text = strip_jsonl_fences(raw)
    if not text:
        raise ValueError("dataset_jsonl is empty")
    if text.startswith("["):
        parsed = json.loads(text)
        if not isinstance(parsed, list):
            raise ValueError("dataset JSON array is invalid")
        samples = [normalize_dataset_sample(item, index + 1) for index, item in enumerate(parsed)]
    else:
        samples = []
        for index, line in enumerate(text.splitlines(), start=1):
            stripped = line.strip()
            if not stripped:
                continue
            samples.append(normalize_dataset_sample(json.loads(stripped), index))
    if not samples:
        raise ValueError("dataset_jsonl does not contain any samples")
    return samples


def save_jsonl_dataset(job_id: str, samples: list[dict[str, str]]) -> Path:
    dataset_path = DATASET_DIR / f"{job_id}.jsonl"
    payload = "\n".join(json.dumps(sample, ensure_ascii=False) for sample in samples) + "\n"
    dataset_path.write_text(payload, encoding="utf-8")
    return dataset_path


def load_training_samples(
    job_id: str,
    request: dict[str, Any],
    samples_per_tool: int,
    seed: int,
) -> tuple[list[dict[str, str]], str, Path | None]:
    raw_jsonl = str(request.get("dataset_jsonl") or "").strip()
    if raw_jsonl:
        samples = parse_jsonl_dataset(raw_jsonl)
        dataset_path = save_jsonl_dataset(job_id, samples)
        return samples, "tool_call_jsonl", dataset_path
    return generate_dataset(samples_per_tool, seed=seed), "template_synthetic", None


def format_sample(sample: dict[str, str]) -> str:
    return f"<user>\n{sample['user']}\n</user>\n<call>\n{sample['call']}\n</call>\n"


def build_vocab(text: str) -> tuple[dict[str, int], dict[int, str]]:
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


def normalize_itos(raw_itos: dict[Any, str]) -> dict[int, str]:
    if not raw_itos:
        return {}
    return {int(key): value for key, value in raw_itos.items()} if isinstance(next(iter(raw_itos.keys())), str) else raw_itos


def count_parameters(model: nn.Module) -> int:
    return sum(parameter.numel() for parameter in model.parameters())


class CausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        batch, tokens, channels = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(channels, dim=2)
        q = q.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)
        attention = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention = attention.masked_fill(self.mask[:, :, :tokens, :tokens] == 0, float("-inf"))
        attention = F.softmax(attention, dim=-1)
        y = attention @ v
        y = y.transpose(1, 2).contiguous().view(batch, tokens, channels)
        return self.proj(y)


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class ToolCallTinyGPT(nn.Module):
    def __init__(self, vocab_size: int, config: ModelConfig):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        _batch, tokens = idx.shape
        positions = torch.arange(tokens, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        x = self.blocks(x)
        logits = self.head(self.ln_f(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch_no_grad()
    def generate(
        self,
        idx,
        max_new_tokens: int,
        temperature: float = 0.45,
        top_k: int = 30,
        stop_sequence: list[int] | None = None,
    ):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-4)
            if top_k:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
            if stop_sequence and idx.size(1) >= len(stop_sequence):
                tail = idx[0, -len(stop_sequence) :].detach().cpu().tolist()
                if tail == stop_sequence:
                    break
        return idx


def make_batch(data: Any, block_size: int, batch_size: int, device: str):
    starts = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in starts]).to(device)
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in starts]).to(device)
    return x, y


def evaluate_loss(model: nn.Module, data: Any, block_size: int, batch_size: int, device: str) -> float:
    model.eval()
    losses = []
    with torch.no_grad():
        for _ in range(6):
            x, y = make_batch(data, block_size, batch_size, device)
            _logits, loss = model(x, y)
            losses.append(float(loss.item()))
    model.train()
    return sum(losses) / len(losses)


def latest_checkpoint() -> Path | None:
    checkpoints = sorted(CHECKPOINT_DIR.glob("*.pt"), key=lambda item: item.stat().st_mtime, reverse=True)
    return checkpoints[0] if checkpoints else None


def load_resume_checkpoint(path_value: str | None) -> tuple[Path, dict[str, Any]]:
    checkpoint_path = Path(path_value) if path_value else latest_checkpoint()
    if checkpoint_path is None or not checkpoint_path.exists():
        raise RuntimeError("No checkpoint found to continue training from.")
    return checkpoint_path, torch.load(checkpoint_path, map_location="cpu")


def run_training(job_id: str, request: dict[str, Any]) -> None:
    global ACTIVE_JOB_ID
    if torch is None:
        update_job(job_id, status="failed", error="PyTorch is not installed in this Python environment.", finished_at=utc_now())
        return
    stop_event = STOP_EVENTS[job_id]
    try:
        continue_from_checkpoint = bool(request.get("continue_from_checkpoint"))
        resume_checkpoint_path = str(request.get("checkpoint_path") or "").strip() or None
        preset, config = resolve_model_config(request)
        samples_per_tool = max(10, min(int(request.get("samples_per_tool") or 120), 1000))
        fallback_max_steps = max(20, min(int(request.get("max_steps") or 50_000), 200_000))
        target_epochs = clamp_float(request.get("target_epochs"), 6.0, 0.0, 100.0)
        batch_size = max(4, min(int(request.get("batch_size") or 48), 128))
        learning_rate = float(request.get("learning_rate") or 8e-4)
        seed = int(request.get("seed") or 11)
        random.seed(seed)
        torch.manual_seed(seed)
        device, device_reason = resolve_device(request.get("device"))
        update_job(job_id, status="running", started_at=utc_now(), device=device, device_policy=DEVICE_POLICY, device_reason=device_reason)
        append_log(job_id, f"device_policy = {DEVICE_POLICY}")
        append_log(job_id, f"device = {device}")
        append_log(job_id, device_reason)
        append_log(job_id, f"preset = {preset}, config = {config}")
        samples, dataset_source, dataset_path = load_training_samples(job_id, request, samples_per_tool, seed)
        text = "\n".join(format_sample(sample) for sample in samples)
        resume_checkpoint: dict[str, Any] | None = None
        resumed_from_checkpoint: Path | None = None
        if continue_from_checkpoint:
            resumed_from_checkpoint, resume_checkpoint = load_resume_checkpoint(resume_checkpoint_path)
            config = ModelConfig(**resume_checkpoint["config"])
            preset = str(resume_checkpoint.get("metadata", {}).get("preset") or "continued")
            stoi = resume_checkpoint["stoi"]
            itos = normalize_itos(resume_checkpoint["itos"])
            missing_chars = sorted(set(text) - set(stoi))
            if missing_chars:
                preview = "".join(missing_chars[:20])
                raise RuntimeError(
                    "Cannot continue from this checkpoint because the new dataset contains characters "
                    f"that are not in the checkpoint vocabulary: {preview!r}"
                )
            append_log(job_id, f"continue_from_checkpoint = {resumed_from_checkpoint}")
        else:
            stoi, itos = build_vocab(text)
        ids = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)
        if len(ids) <= config.block_size + batch_size + 2:
            raise RuntimeError("Dataset is too small for the selected preset and batch size.")
        split = max(config.block_size + batch_size + 2, int(len(ids) * 0.92))
        split = min(split, len(ids) - config.block_size - batch_size - 2)
        train_ids = ids[:split]
        val_ids = ids[split:]
        if len(val_ids) <= config.block_size + batch_size + 2:
            val_ids = train_ids
        steps_per_epoch = max(1, math.ceil(len(train_ids) / (batch_size * config.block_size)))
        epoch_steps = math.ceil(steps_per_epoch * target_epochs) if target_epochs > 0 else fallback_max_steps
        max_steps = max(20, min(epoch_steps, 200_000))
        model = ToolCallTinyGPT(len(stoi), config).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
        if resume_checkpoint is not None:
            model.load_state_dict(resume_checkpoint["model_state_dict"])
            if resume_checkpoint.get("optimizer_state_dict"):
                optimizer.load_state_dict(resume_checkpoint["optimizer_state_dict"])
                for state in optimizer.state.values():
                    for key, value in state.items():
                        if hasattr(value, "to"):
                            state[key] = value.to(device)
        parameter_count = count_parameters(model)
        tool_counts = {tool: sum(1 for sample in samples if sample["tool"] == tool) for tool in TOOLS}
        update_job(
            job_id,
            dataset_size=len(samples),
            tool_counts=tool_counts,
            dataset_source=dataset_source,
            dataset_path=str(dataset_path) if dataset_path else None,
            vocab_size=len(stoi),
            parameter_count=parameter_count,
            preset=preset,
            model_config=config.__dict__,
            max_steps=max_steps,
            fallback_max_steps=fallback_max_steps,
            target_epochs=target_epochs,
            steps_per_epoch=steps_per_epoch,
            resumed_from_checkpoint=str(resumed_from_checkpoint) if resumed_from_checkpoint else None,
        )
        update_training_progress(job_id, 0, max_steps, steps_per_epoch, target_epochs)
        append_log(job_id, f"dataset_size = {len(samples)} samples")
        append_log(job_id, f"tool_counts = {tool_counts}")
        append_log(job_id, f"vocab_size = {len(stoi)} characters")
        append_log(job_id, f"parameters = {parameter_count:,}")
        append_log(job_id, f"target_epochs = {target_epochs:g}, steps_per_epoch ≈ {steps_per_epoch}, planned_steps = {max_steps}")
        log_every = max(10, max_steps // 18)
        progress_every = 1 if max_steps <= 10_000 else max(1, max_steps // 2_000)
        start_time = time.time()
        for step in range(1, max_steps + 1):
            if stop_event.is_set():
                update_job(job_id, status="stopped", finished_at=utc_now())
                append_log(job_id, "training stopped by user")
                return
            x, y = make_batch(train_ids, config.block_size, batch_size, device)
            _logits, loss = model(x, y)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if step == 1 or step % progress_every == 0 or step == max_steps:
                update_training_progress(job_id, step, max_steps, steps_per_epoch, target_epochs)
            if step == 1 or step % log_every == 0 or step == max_steps:
                val_loss = evaluate_loss(model, val_ids, config.block_size, min(batch_size, 32), device)
                elapsed = max(time.time() - start_time, 1e-6)
                tokens_per_second = int(step * batch_size * config.block_size / elapsed)
                metric = {
                    "step": step,
                    "train_loss": round(float(loss.item()), 6),
                    "val_loss": round(float(val_loss), 6),
                    "tokens_per_second": tokens_per_second,
                    "epoch": round(step / steps_per_epoch, 4),
                }
                append_metric(job_id, metric)
                append_log(job_id, f"step {step:5d}/{max_steps} | epoch {metric['epoch']:.2f}/{target_epochs:g} | train_loss {metric['train_loss']:.4f} | val_loss {metric['val_loss']:.4f} | {tokens_per_second:,} tok/s")
        checkpoint_path = CHECKPOINT_DIR / f"{job_id}.pt"
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": config.__dict__,
                "stoi": stoi,
                "itos": itos,
                "samples": samples[:120],
                "metadata": {
                    "preset": preset,
                    "model_config": config.__dict__,
                    "samples_per_tool": samples_per_tool,
                    "max_steps": max_steps,
                    "fallback_max_steps": fallback_max_steps,
                    "target_epochs": target_epochs,
                    "steps_per_epoch": steps_per_epoch,
                    "batch_size": batch_size,
                    "learning_rate": learning_rate,
                    "device": device,
                    "device_policy": DEVICE_POLICY,
                    "device_reason": device_reason,
                    "dataset_source": dataset_source,
                    "dataset_path": str(dataset_path) if dataset_path else None,
                    "dataset_size": len(samples),
                    "tool_counts": tool_counts,
                    "resumed_from_checkpoint": str(resumed_from_checkpoint) if resumed_from_checkpoint else None,
                    "created_at": utc_now(),
                },
            },
            checkpoint_path,
        )
        update_job(job_id, status="completed", checkpoint_path=str(checkpoint_path), finished_at=utc_now())
        append_log(job_id, f"checkpoint saved: {checkpoint_path}")
    except Exception as error:
        update_job(job_id, status="failed", error=str(error), finished_at=utc_now())
        append_log(job_id, traceback.format_exc())
    finally:
        with LOCK:
            if ACTIVE_JOB_ID == job_id:
                ACTIVE_JOB_ID = None


def load_checkpoint(path_value: str | None):
    if torch is None:
        raise RuntimeError("PyTorch is not installed in this Python environment.")
    checkpoint_path = Path(path_value) if path_value else latest_checkpoint()
    if checkpoint_path is None or not checkpoint_path.exists():
        raise RuntimeError("No Tool-Calling TinyGPT checkpoint found. Train a model first.")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    config = ModelConfig(**checkpoint["config"])
    stoi = checkpoint["stoi"]
    itos = normalize_itos(checkpoint["itos"])
    model = ToolCallTinyGPT(len(stoi), config)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return checkpoint_path, model, config, stoi, itos


def extract_call(text: str) -> str:
    marker = "<call>\n"
    start = text.find(marker)
    if start >= 0:
        text = text[start + len(marker) :]
    end = text.find("</call>")
    if end >= 0:
        text = text[:end]
    return text.strip()


def generate_call(request: dict[str, Any]) -> dict[str, Any]:
    checkpoint_path, model, config, stoi, itos = load_checkpoint(request.get("checkpoint_path"))
    device, device_reason = resolve_device(request.get("device"))
    model.to(device)
    user = str(request.get("user") or "").strip()
    if not user:
        raise RuntimeError("User prompt cannot be empty.")
    prompt = f"<user>\n{user}\n</user>\n<call>\n"
    encoded = [stoi[ch] for ch in prompt if ch in stoi]
    if not encoded:
        raise RuntimeError("Prompt contains no known characters from the tokenizer vocabulary.")
    idx = torch.tensor([encoded], dtype=torch.long, device=device)
    max_new_tokens = max(40, min(int(request.get("max_new_tokens") or 240), 800))
    temperature = max(0.1, min(float(request.get("temperature") or 0.4), 1.2))
    top_k = max(0, min(int(request.get("top_k") or 30), 100))
    stop_sequence = [stoi[ch] for ch in "\n</call>" if ch in stoi]
    with torch.no_grad():
        generated = model.generate(
            idx,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            stop_sequence=stop_sequence if stop_sequence else None,
        )
    text = "".join(itos[int(token)] for token in generated[0].detach().cpu().tolist())
    call_text = extract_call(text)
    parsed: dict[str, Any] | None = None
    parse_error = None
    try:
        tool, args = parse_call(call_text)
        parsed = {"tool": tool, "args": args}
    except Exception as error:
        parse_error = str(error)
    return {
        "user": user,
        "call": call_text,
        "parsed": parsed,
        "parse_error": parse_error,
        "truncated": "</call>" not in text,
        "checkpoint_path": str(checkpoint_path),
        "device": device,
        "device_reason": device_reason,
    }
ALLOWED_BINOPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
}
ALLOWED_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
ALLOWED_NAMES = {
"pi": math.pi,
"e": math.e,
}
ALLOWED_FUNCS = {
"sqrt": math.sqrt,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"log": math.log,
"log10": math.log10,
"exp": math.exp,
"abs": abs,
"round": round,
}


def safe_eval_expression(expression: str, variable_values: dict[str, float] | None = None) -> float:
    tree = ast.parse(expression, mode="eval")
    variables = {**ALLOWED_NAMES, **(variable_values or {})}

    def visit(node: ast.AST):
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return float(node.value)
        if isinstance(node, ast.Name) and node.id in variables:
            return variables[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in ALLOWED_BINOPS:
            left = visit(node.left)
            right = visit(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > 8:
                raise ValueError("Power exponent is too large.")
            return ALLOWED_BINOPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in ALLOWED_UNARYOPS:
            return ALLOWED_UNARYOPS[type(node.op)](visit(node.operand))
        # Keyword arguments are rejected rather than silently ignored.
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in ALLOWED_FUNCS
            and not node.keywords
        ):
            args = [visit(arg) for arg in node.args]
            if len(args) > 3:
                raise ValueError("Too many function arguments.")
            # Constants are converted to float above, but round() requires an
            # integer digit count, so round(3.14159, 2) would raise TypeError.
            if node.func.id == "round" and len(args) == 2:
                args[1] = int(args[1])
            return ALLOWED_FUNCS[node.func.id](*args)
        raise ValueError(f"Unsupported expression node: {type(node).__name__}")

    result = visit(tree)
    if not isinstance(result, (int, float)) or not math.isfinite(result):
        raise ValueError("Expression result is not finite.")
    return float(result)


def execute_calculator(args: dict[str, Any]) -> dict[str, Any]:
    expression = str(args.get("expression") or "").strip()
    if not expression:
        raise ValueError("calculator.expression is required")
    if len(expression) > 240:
        raise ValueError("calculator.expression is too long")
    value = safe_eval_expression(expression)
    return {
        "tool": "calculator",
        "output": {"expression": expression, "value": value, "rounded": round(value, 8)},
        "artifacts": [],
    }


def sanitize_float_list(value: Any, name: str, max_items: int = 200) -> list[float]:
    if not isinstance(value, list):
        raise ValueError(f"{name} must be a list")
    if not 1 <= len(value) <= max_items:
        raise ValueError(f"{name} must contain 1-{max_items} items")
    parsed = [float(item) for item in value]
    if not all(math.isfinite(item) for item in parsed):
        raise ValueError(f"{name} contains non-finite values")
    return parsed


def encode_png_artifact(path: Path) -> dict[str, str]:
    # Inline the PNG as a data URL so the web page can display it directly.
    payload = base64.b64encode(path.read_bytes()).decode("ascii")
    return {
        "filename": path.name,
        "path": str(path),
        "mime_type": "image/png",
        "data_url": f"data:image/png;base64,{payload}",
    }


def execute_plot(args: dict[str, Any]) -> dict[str, Any]:
    import matplotlib

    matplotlib.use("Agg")  # headless backend: render to a file, no display needed
    import matplotlib.pyplot as plt

    kind = str(args.get("kind") or "").strip().lower()
    title = str(args.get("title") or kind.title())[:80]
    run_id = uuid.uuid4().hex[:10]
    output_path = RUN_OUTPUT_DIR / f"{run_id}_plot.png"
    figure = plt.figure(figsize=(6, 4))
    try:
        if kind == "function":
            expression = str(args.get("expression") or "").strip()
            x_min = float(args.get("x_min", -10))
            x_max = float(args.get("x_max", 10))
            points = max(20, min(int(args.get("points") or 120), 300))
            if not expression:
                raise ValueError("plot.expression is required for function plots")
            if x_min >= x_max:
                raise ValueError("x_min must be smaller than x_max")
            xs = [x_min + (x_max - x_min) * index / (points - 1) for index in range(points)]
            ys = [safe_eval_expression(expression, {"x": x}) for x in xs]
            plt.plot(xs, ys)
        elif kind == "line":
            y = sanitize_float_list(args.get("y"), "plot.y")
            if isinstance(args.get("x"), list):
                x = sanitize_float_list(args.get("x"), "plot.x")
            else:
                x = list(range(len(y)))
            if len(x) != len(y):
                raise ValueError("plot.x and plot.y must have the same length")
            plt.plot(x, y, marker="o")
        elif kind == "bar":
            values = sanitize_float_list(args.get("values"), "plot.values", max_items=40)
            labels_raw = args.get("labels")
            if isinstance(labels_raw, list):
                labels = [str(item)[:24] for item in labels_raw]
            else:
                labels = [str(index + 1) for index in range(len(values))]
            if len(labels) != len(values):
                raise ValueError("plot.labels and plot.values must have the same length")
            plt.bar(labels, values)
            plt.xticks(rotation=20)
        elif kind == "scatter":
            x = sanitize_float_list(args.get("x"), "plot.x")
            y = sanitize_float_list(args.get("y"), "plot.y")
            if len(x) != len(y):
                raise ValueError("plot.x and plot.y must have the same length")
            plt.scatter(x, y)
        elif kind == "histogram":
            values = sanitize_float_list(args.get("values"), "plot.values")
            bins = max(2, min(int(args.get("bins") or 8), 40))
            plt.hist(values, bins=bins)
        else:
            raise ValueError("plot.kind must be function, line, bar, scatter, or histogram")
        plt.title(title)
        plt.tight_layout()
        plt.savefig(output_path, bbox_inches="tight")
    finally:
        # Close the figure even when validation fails, so a long-running
        # runner does not accumulate open figures.
        plt.close(figure)
    return {
        "tool": "plot",
        "output": {"kind": kind, "title": title, "path": str(output_path)},
        "artifacts": [encode_png_artifact(output_path)],
    }


def execute_search(args: dict[str, Any]) -> dict[str, Any]:
    query = str(args.get("query") or "").strip()
    if not query:
        raise ValueError("search.query is required")
    top_k = max(1, min(int(args.get("top_k") or 3), 8))
    query_terms = {term for term in re.split(r"\W+", query.lower()) if term}
    scored = []
    for doc in LOCAL_SEARCH_DOCS:
        haystack = f"{doc['title']} {doc['body']}".lower()
        # A term found in the title scores 2, a term found only in the body scores 1.
        score = sum(2 if term in doc["title"].lower() else 1 for term in query_terms if term in haystack)
        scored.append({**doc, "score": score})
    ranked = sorted(scored, key=lambda item: item["score"], reverse=True)[:top_k]
    return {
        "tool": "search",
        "output": {"query": query, "top_k": top_k, "results": ranked},
        "artifacts": [],
    }


def execute_tool_call(request: dict[str, Any]) -> dict[str, Any]:
    call_text = str(request.get("call") or "").strip()
    tool, args = parse_call(call_text)
    if tool == "calculator":
        result = execute_calculator(args)
    elif tool == "plot":
        result = execute_plot(args)
    elif tool == "search":
        result = execute_search(args)
    else:
        raise ValueError(f"Unsupported tool: {tool}")
    return {"call": call_text, "parsed": {"tool": tool, "args": args}, **result}


def evaluate_checkpoint(request: dict[str, Any]) -> dict[str, Any]:
    # Fixed probe set: two prompts per tool, each paired with the expected tool.
    prompts = [
        ("calculator", "calculate 23 plus 17 times 4"),
        ("calculator", "what is 18 percent of 250"),
        ("plot", "plot y = x**2 from -10 to 10"),
        ("plot", "draw a line chart for values 3, 8, 5, 13"),
        ("search", "search for transformer self attention"),
        ("search", "find beginner notes about tool calling JSON schema"),
    ]
    max_items = max(1, min(int(request.get("max_items") or len(prompts)), len(prompts)))
    results = []
    for expected_tool, user in prompts[:max_items]:
        generated = generate_call({**request, "user": user, "temperature": request.get("temperature", 0.25)})
        parsed = generated.get("parsed")
        tool_ok = bool(parsed and parsed.get("tool") == expected_tool)
        execution_ok = False
        error = None
        execution = None
        if parsed:
            try:
                execution = execute_tool_call({"call": generated["call"]})
                execution_ok = True
            except Exception as exec_error:
                error = str(exec_error)
        else:
            error = generated.get("parse_error")
        results.append({
            "user": user,
            "expected_tool": expected_tool,
            "call": generated["call"],
            "tool_ok": tool_ok,
            "valid_call": bool(parsed),
            "execution_ok": execution_ok,
            "error": error,
            "execution_output": execution.get("output") if execution else None,
        })
    total = len(results) or 1
    return {
        "items": results,
        "summary": {
            "count": len(results),
            "tool_accuracy": sum(1 for item in results if item["tool_ok"]) / total,
            "valid_call_rate": sum(1 for item in results if item["valid_call"]) / total,
            "execution_pass_rate": sum(1 for item in results if item["execution_ok"]) / total,
        },
    }


def status_payload() -> dict[str, Any]:
    # Newest checkpoints first.
    checkpoints = sorted(CHECKPOINT_DIR.glob("*.pt"), key=lambda item: item.stat().st_mtime, reverse=True)
    torch_available = torch is not None
    cuda_available = is_cuda_available()
    device, device_reason = resolve_device()
    cuda_device_count = torch.cuda.device_count() if torch_available else 0
    return {
        "ok": True,
        "python": sys.version.split()[0],
        "torch_available": torch_available,
        "torch_version": getattr(torch, "__version__", None) if torch_available else None,
        "torch_cuda_version": torch_cuda_version(),
        "cuda_available": cuda_available,
        "cuda_device": torch.cuda.get_device_name(0) if cuda_available else None,
        "cuda_device_count": cuda_device_count,
        "gpu_preferred": DEVICE_POLICY == "auto",
        "device_policy": DEVICE_POLICY,
        "device": device,
        "device_reason": device_reason,
        "cuda_unavailable_reason": None if cuda_available else cuda_unavailable_reason(),
        "working_dir": str(RUN_DIR),
        "active_job_id": ACTIVE_JOB_ID,
        "checkpoints": [str(item) for item in checkpoints[:10]],
        "tools": list(TOOLS),
    }


class ToolCallTinyGPTHandler(BaseHTTPRequestHandler):
    def log_message(self, format: str, *args: Any) -> None:
        print(f"[toolcalltinygpt-runner] {self.address_string()} - {format % args}")

    def do_OPTIONS(self) -> None:
        respond(self, 200, {"ok": True})

    def do_GET(self) -> None:
        parsed = urlparse(self.path)
        path = parsed.path.rstrip("/") or "/"
        if path in {"/", "/status", "/health"}:
            respond(self, 200, status_payload())
            return
        if path.startswith("/jobs/"):
            job_id = path.split("/")[-1]
            with LOCK:
                job = JOBS.get(job_id)
            if not job:
                respond(self, 404, {"message": "Job not found."})
                return
            respond(self, 200, job)
            return
        respond(self, 404, {"message": "Unknown route."})

    def do_POST(self) -> None:
        global ACTIVE_JOB_ID
        parsed = urlparse(self.path)
        path = parsed.path.rstrip("/") or "/"
        try:
            if path == "/train":
                body = read_body(self)
                # Check and register under the lock so only one training job runs at a time.
                with LOCK:
                    if ACTIVE_JOB_ID and JOBS.get(ACTIVE_JOB_ID, {}).get("status") in {"queued", "running"}:
                        respond(self, 409, {"message": f"Training job already running: {ACTIVE_JOB_ID}"})
                        return
                    job_id = uuid.uuid4().hex[:12]
                    job = {"id": job_id, "status": "queued", "created_at": utc_now(), "logs": [], "metrics": []}
                    JOBS[job_id] = job
                    STOP_EVENTS[job_id] = threading.Event()
                    ACTIVE_JOB_ID = job_id
                thread = threading.Thread(target=run_training, args=(job_id, body), daemon=True)
                thread.start()
                respond(self, 200, job)
                return
            if path.startswith("/jobs/") and path.endswith("/stop"):
                # Path shape: /jobs/<job_id>/stop
                parts = path.split("/")
                job_id = parts[2] if len(parts) >= 3 else ""
                if job_id in STOP_EVENTS:
                    STOP_EVENTS[job_id].set()
                with LOCK:
                    job = JOBS.get(job_id)
                if not job:
                    respond(self, 404, {"message": "Job not found."})
                    return
                respond(self, 200, job)
                return
            if path == "/generate":
                respond(self, 200, generate_call(read_body(self)))
                return
            if path == "/execute_call":
                respond(self, 200, execute_tool_call(read_body(self)))
                return
            if path == "/evaluate":
                respond(self, 200, evaluate_checkpoint(read_body(self)))
                return
            respond(self, 404, {"message": "Unknown route."})
        except Exception as error:
            respond(self, 500, {"message": str(error), "traceback": traceback.format_exc()})


def main() -> None:
    global DEVICE_POLICY
    parser = argparse.ArgumentParser(description="Local Tool-Calling TinyGPT training runner for AlgoLab.")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=4888)
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "gpu", "cpu"],
        default="auto",
        help="auto/cuda/gpu prioritizes CUDA GPU and falls back to CPU when CUDA is unavailable; cpu forces CPU.",
    )
    args = parser.parse_args()
    # cuda and gpu are aliases for auto: prefer CUDA, fall back to CPU.
    DEVICE_POLICY = "auto" if args.device in {"auto", "cuda", "gpu"} else "cpu"
    server = ThreadingHTTPServer((args.host, args.port), ToolCallTinyGPTHandler)
    device, device_reason = resolve_device()
    print(f"Tool-Calling TinyGPT local runner listening on http://{args.host}:{args.port}")
    print(f"Working directory: {RUN_DIR}")
    print(f"Device policy: {DEVICE_POLICY}")
    print(f"Resolved device: {device} ({device_reason})")
    print("Press Ctrl+C to stop.")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nStopping Tool-Calling TinyGPT local runner.")
    finally:
        server.server_close()


if __name__ == "__main__":
    main()
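
The routes in the handler above can be exercised from any HTTP client. The following is a minimal client sketch using only the standard library. It assumes the runner is listening on the default `127.0.0.1:4888` and that `read_body` (not shown in this section) expects a JSON request body; `build_request`, `call_runner`, and `demo` are illustrative names, not part of the runner.

```python
import json
import urllib.request
from typing import Optional

RUNNER_URL = "http://127.0.0.1:4888"  # default --host/--port from main()


def build_request(path: str, payload: Optional[dict] = None) -> urllib.request.Request:
    """Build a GET request (no payload) or a JSON POST request for the runner."""
    url = f"{RUNNER_URL}{path}"
    if payload is None:
        return urllib.request.Request(url, method="GET")
    data = json.dumps(payload).encode("utf-8")
    headers = {"Content-Type": "application/json"}  # assumption: body is JSON
    return urllib.request.Request(url, data=data, headers=headers, method="POST")


def call_runner(path: str, payload: Optional[dict] = None, timeout: float = 30.0) -> dict:
    """Send the request and decode the JSON response (needs a running runner)."""
    with urllib.request.urlopen(build_request(path, payload), timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


def demo() -> None:
    """Generate a CALL for one prompt, then execute it if it parsed."""
    print(call_runner("/status")["device"])
    generated = call_runner("/generate", {"user": "calculate 23 plus 17 times 4"})
    if generated["parsed"]:
        executed = call_runner("/execute_call", {"call": generated["call"]})
        print(executed["output"])
    else:
        print("parse error:", generated["parse_error"])
```

Calling `demo()` requires the runner to be started first and a trained checkpoint to be available; `build_request` alone can be inspected without any server.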
Next step: from a single CALL to CALL -> OBSERVATION -> FINAL
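This module deliberately handles only a single CALL. Multi-segment trajectories become worth training only once tool selection and JSON arguments are stable; each trajectory would then chain a CALL, the tool's OBSERVATION, and a FINAL answer.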
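
One way such a trajectory might be serialized as training text is sketched below. This is a hypothetical format, not the course's actual one: only the `<user>` and `<call>` tags appear in this module's prompt template, while the `<observation>` and `<final>` tags and the `build_trajectory` helper are assumptions made for illustration.

```python
import json


def build_trajectory(user: str, call: str, observation: dict, final: str) -> str:
    """Serialize one CALL -> OBSERVATION -> FINAL sample as a single text.

    The <observation> and <final> tags are hypothetical extensions of the
    <user>/<call> framing used by the single-CALL prompt.
    """
    observation_json = json.dumps(observation, ensure_ascii=False, separators=(",", ":"))
    return (
        f"<user>\n{user}\n</user>\n"
        f"<call>\n{call}\n</call>\n"
        f"<observation>\n{observation_json}\n</observation>\n"
        f"<final>\n{final}\n</final>\n"
    )


sample = build_trajectory(
    user="calculate 23 plus 17 times 4",
    call='CALL calculator {"expression":"23 + 17 * 4"}',
    observation={"value": 91.0},  # what the local calculator would return
    final="23 + 17 * 4 = 91",
)
print(sample)
```

In a format like this, the model would be trained to write the `<call>` and `<final>` segments, while the `<observation>` segment is filled in by the runner at inference time.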
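This keeps the course storyline clear: No.6 explains next-token prediction, No.7 trains task-to-code, No.8 trains tool calling, and later modules move on to ReAct, multi-tool orchestration, and feedback-driven agents.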
Why next-token prediction can give rise to code and tool-calling abilities
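Stated precisely, a decoder-only GPT does not learn an explicit function f(x)=y; it estimates the conditional distribution $p_\theta(t_i \mid t_{<i})$. Training factorizes the probability of the whole text into a next-token probability at each position. A falling cross entropy means the model ranks the token that should follow the current context higher.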
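
Written out, the factorization and the resulting loss take the standard autoregressive form below. The sequence length $n$ and the symbol $\mathcal{L}$ for the average cross entropy are notation introduced here for illustration.

```latex
\log p_\theta(t_1, \dots, t_n) = \sum_{i=1}^{n} \log p_\theta(t_i \mid t_{<i}),
\qquad
\mathcal{L}(\theta) = -\frac{1}{n} \sum_{i=1}^{n} \log p_\theta(t_i \mid t_{<i})
```

Minimizing $\mathcal{L}(\theta)$ is the same as maximizing the log-probability of the training text, which is why a lower loss means CALL-shaped continuations become more likely after `<call>`.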
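| Layer | What next-token prediction learns | What it corresponds to in this module |
|---|---|---|
| Representation | Transformer self-attention encodes the preceding tokens into context-dependent representations, so the output distribution depends on the full prompt. | `<user>...</user><call>` tells the model that it should now enter the tool-call space. |
| Objective | Code, JSON, and CALL instructions are all just token sequences; as long as they appear consistently in the training set, the next-token loss rewards these structures. | plot ... is followed with high probability by CALL plot {...}, and search ... by CALL search {...}. |
| Conditioning | The model parameters do not need to change; the task is specified by the text prefix. This is the text interaction / few-shot conditioning that the GPT-3 paper emphasizes. | At inference time the prompt stops at `<call>` instead of giving the full answer, and the model unrolls the learned conditional distribution into a CALL. |
| Action | The language model itself does not execute tools; it emits a symbolic action that an external system parses and executes. | CALL calculator goes to the calculator, CALL plot to the plotter, and CALL search to the retriever. |
| Limits | Next-token training only guarantees output that is distributionally similar to the training data; it does not guarantee exact arithmetic, exact copying of numbers, or JSON that is always valid. | So evaluation cannot rely on loss alone; it must also track tool accuracy, valid CALL rate, and execution pass rate. |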
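
The last row can be made concrete with a small format check. The sketch below is a simplified stand-in, not the runner's own `parse_call` (which is defined outside this section): `check_call`, `CALL_PATTERN`, and `KNOWN_TOOLS` are illustrative names, and the single-line `CALL <tool> {json}` shape is taken from the protocol table earlier in the module.

```python
import json
import re

KNOWN_TOOLS = {"calculator", "plot", "search"}
# One line: the literal CALL, a tool name, then a JSON object.
CALL_PATTERN = re.compile(r"^CALL (\w+) (\{.*\})$")


def check_call(call_text: str, expected_tool: str) -> dict:
    """Report format-level checks for one generated CALL line."""
    report = {"tool_ok": False, "valid_call": False, "error": None}
    match = CALL_PATTERN.match(call_text.strip())
    if not match:
        report["error"] = "not a single-line CALL <tool> {json}"
        return report
    tool, raw_args = match.groups()
    try:
        json.loads(raw_args)  # raises on malformed or trailing content
    except json.JSONDecodeError as error:
        report["error"] = f"invalid JSON: {error.msg}"
        return report
    if tool not in KNOWN_TOOLS:
        report["error"] = f"unknown tool: {tool}"
        return report
    report["valid_call"] = True
    report["tool_ok"] = tool == expected_tool
    return report


good = check_call('CALL calculator {"expression":"23 + 17 * 4"}', "calculator")
cut_off = check_call('CALL calculator {"expression":"23 + 17 * 4"', "calculator")
bad_json = check_call('CALL plot {"kind":"function",}', "plot")
wrong_tool = check_call('CALL search {"query":"x"}', "calculator")
print(good, cut_off, bad_json, wrong_tool, sep="\n")
```

The truncated and trailing-comma outputs differ from a valid CALL by a single character, so they can score nearly the same loss while failing every downstream metric. A call that passes this check can still fail at execution time, which is why execution pass rate is tracked separately.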
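This is also the idea shared by work such as Toolformer, ReAct, and MRKL: expose external capabilities as text actions or API calls, let the language model choose the action and its arguments, and let deterministic modules execute it. The TinyGPT version simply scales that idea down to three local tools.
References: Attention Is All You Need; Language Models are Few-Shot Learners; Scaling Laws for Neural Language Models; Toolformer; ReAct; MRKL Systems.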