Algorithm Visualization and Interactive Learning Platform

Tool-Calling TinyGPT: Train a Local Tiny Model That Emits Tool Calls

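Building on TinyCodeGPT's local training loop, this module moves to a second stage. Instead of having the small model write complete code directly, it trains the model to translate user requests into three kinds of structured tool instructions (CALL calculator, CALL plot, CALL search). The instructions are then executed and validated locally, and the model is evaluated on tool-selection accuracy, JSON validity rate, and execution pass rate.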

Tags: LLM, Intermediate, Free, Kernel, GPU. Module No. 8.

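From Generating Code to Generating Tool Calls

No.7 trained **task -> code**: given a task, the model continues with a piece of Python code.

This module enters the second stage: train a small model that behaves more like an agent kernel and first emits a structured tool call:

user request -> one of: CALL calculator {...}, CALL plot {...}, CALL search {...}

The point of this step is not for the small model to "know everything" but for it to learn three things:

1. When to call which tool.
2. How to turn natural-language parameters into stable JSON.
3. How to have the local runner execute the CALL and use the execution result to evaluate a checkpoint.

The full loop keeps the engineering shape of No.7: a JSONL dataset, a character-level tokenizer, local TinyGPT training, checkpoints, generation, execution, error analysis, and continued training.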

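Protocol

Step 1: Pin Down the CALL Protocol First

The biggest risk in tool-calling training is output-format drift, so the first version allows only three single-line instructions. The model does not improvise; it fills in the tool name and JSON arguments inside a fixed grammar.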

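| Tool | What the model should generate | What the local runner executes |
|---|---|---|
| calculator | CALL calculator {"expression":"23 + 17 * 4"} | Parses the arithmetic expression with an AST whitelist and returns the numeric result. |
| plot | CALL plot {"kind":"function","expression":"x**2","x_min":-10,"x_max":10} | Uses Matplotlib's Agg backend to generate a PNG chart artifact locally. |
| search | CALL search {"query":"transformer self attention","top_k":3} | Searches the small local course corpus and returns ranked results. |

Benefits of a fixed format:
natural language is unstable -> the CALL format is stable
tool selection can be evaluated -> JSON validity can be evaluated -> execution results can be evaluated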
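To make the fixed grammar concrete, here is a minimal parsing sketch condensed from the `parse_call` function and `CALL_RE` pattern in the runner source listed later on this page. It only shows the protocol check: regex match, tool whitelist, then JSON object validation.

```python
import json
import re

TOOLS = ("calculator", "plot", "search")
# One line: the literal CALL, a tool identifier, then a JSON object.
CALL_RE = re.compile(r"^\s*CALL\s+([a-zA-Z_][a-zA-Z0-9_]*)\s+(\{.*\})\s*$", re.DOTALL)


def parse_call(call_text):
    """Split a CALL line into (tool, args) or raise ValueError."""
    match = CALL_RE.match(call_text.strip())
    if not match:
        raise ValueError("CALL must match: CALL tool {json_args}")
    tool = match.group(1)
    if tool not in TOOLS:
        raise ValueError(f"Unsupported tool: {tool}")
    # json.JSONDecodeError is a subclass of ValueError, so callers can catch one type.
    args = json.loads(match.group(2))
    if not isinstance(args, dict):
        raise ValueError("CALL args must be a JSON object")
    return tool, args


tool, args = parse_call('CALL plot {"kind":"function","expression":"x**2","x_min":-10,"x_max":10}')
print(tool, args["expression"])
```

Because every failure surfaces as a `ValueError`, a caller can count "valid CALL" and "invalid CALL" outputs without inspecting error types.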
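Formula

Supervised Learning Objective: Maximize the Probability of the CALL Sequence

The dataset consists of user requests u_i, tool names a_i, and target CALL texts g_i. Training is still next-token prediction; the predicted target is simply no longer Python code but a structured tool instruction.

u_i: the natural-language user request
a_i: the target tool name (calculator, plot, or search)
g_i: the complete CALL instruction text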
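The objective named in the heading can be sketched as a maximum-likelihood problem. The model parameters θ, the sample count N, and the character index t are notation assumed here for illustration. The tool name a_i does not appear separately because it is already part of the text g_i.

```latex
\max_{\theta} \sum_{i=1}^{N} \log p_{\theta}(g_i \mid u_i)
  = \max_{\theta} \sum_{i=1}^{N} \sum_{t=1}^{|g_i|}
    \log p_{\theta}\!\left(g_{i,t} \mid u_i,\, g_{i,<t}\right)
```

This is the idealized conditional form. The runner's training loop applies next-token cross entropy to every character of the formatted sample, including the user part, so the trained loss also contains terms for predicting the request text.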

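Why Not Train the Final Answer Directly

The calculator's numeric value, the plot's image, and the search results can all be produced by deterministic tools. The small model only has to handle "routing and parameterization", which makes it easier to train and makes failures easier to localize.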

user -> tool + args -> deterministic execution
JSONL

JSONL Data Format: From user to call

jsonl
{"tool":"calculator","user":"calculate 23 plus 17 times 4","call":"CALL calculator {\"expression\":\"23 + 17 * 4\"}"}
{"tool":"plot","user":"plot y = x**2 from -10 to 10","call":"CALL plot {\"kind\":\"function\",\"expression\":\"x**2\",\"x_min\":-10,\"x_max\":10,\"points\":140}"}
{"tool":"search","user":"search for transformer self attention","call":"CALL search {\"query\":\"transformer self attention\",\"top_k\":3}"}
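Records in this shape can be produced programmatically. The sketch below reuses the `make_call` serialization from the runner source (compact separators, so the target text is byte-for-byte stable); `make_record` is a hypothetical helper added for illustration.

```python
import json

TOOLS = ("calculator", "plot", "search")


def make_call(tool, args):
    # Compact separators keep the CALL target short and free of optional spaces.
    return f"CALL {tool} {json.dumps(args, ensure_ascii=False, separators=(',', ':'))}"


def make_record(tool, user, args):
    # Hypothetical helper: one dataset record with the three required keys.
    if tool not in TOOLS:
        raise ValueError(f"tool must be one of {', '.join(TOOLS)}")
    return {"tool": tool, "user": user.strip(), "call": make_call(tool, args)}


records = [
    make_record("calculator", "calculate 23 plus 17 times 4", {"expression": "23 + 17 * 4"}),
    make_record("search", "search for transformer self attention",
                {"query": "transformer self attention", "top_k": 3}),
]
# JSONL: one JSON object per line, newline-terminated.
jsonl = "\n".join(json.dumps(record, ensure_ascii=False) for record in records) + "\n"
print(jsonl)
```

Generating the `call` field from a dict, rather than typing it by hand, avoids escaping mistakes in the nested JSON string.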
Sample

How One Sample Enters the Training Window

Step 1 of 4: the user request is normalized.
user = 'plot y = x**2 from -10 to 10'
tool = 'plot'
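As a rough illustration of how a sample turns into token ids, the sketch below copies `format_sample` and `build_vocab` from the runner source and encodes one record at character level. The `prompt` prefix ending in `<call>` is an assumption about how generation is primed; the generation code is not part of this excerpt.

```python
def format_sample(sample):
    # Same template as the runner: user block, then call block.
    return f"<user>\n{sample['user']}\n</user>\n<call>\n{sample['call']}\n</call>\n"


def build_vocab(text):
    # Character-level vocabulary: every distinct character gets one id.
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


sample = {
    "tool": "plot",
    "user": "plot y = x**2 from -10 to 10",
    "call": 'CALL plot {"kind":"function","expression":"x**2","x_min":-10,"x_max":10,"points":140}',
}
text = format_sample(sample)
stoi, itos = build_vocab(text)
ids = [stoi[ch] for ch in text]
decoded = "".join(itos[i] for i in ids)

# Assumed generation prefix: everything up to and including the opening <call> line.
prompt = f"<user>\n{sample['user']}\n</user>\n<call>\n"
print(len(text), len(stoi), decoded == text)
```

Keeping the training template and the generation prefix identical matters because a character-level model has no other signal for where the CALL should start.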
Loss

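Training Is Still Next-Token; Only the Target Becomes a CALL

No.8 does not change TinyGPT's training math. What changes is the meaning of the samples: after <call> the model continues with a tool instruction, so a falling loss means the CALL format, tool names, and JSON arguments look more and more like the training distribution.

Terms in the loss formula:
the target CALL instruction
the batch size and context length
the character-level next-token cross entropy

Low Loss Does Not Mean the Tool Call Is Correct

The loss only measures character probabilities. Real quality depends on whether the tool name is correct, whether the JSON parses, whether the arguments can be executed, and whether the execution result matches the task.

loss -> format likelihood
execute_call -> real usability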
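The quality signals listed above can be aggregated separately from the loss. The following is a minimal sketch. The outcome record layout and the `summarize` helper are hypothetical. Dividing every rate by the total number of prompts is also an assumption; the runner's own `evaluate_checkpoint` is not shown in this excerpt and may count differently.

```python
def summarize(outcomes):
    """Aggregate per-prompt outcomes into three rates over all prompts."""
    total = len(outcomes)
    if total == 0:
        return {"tool_accuracy": 0.0, "valid_call_rate": 0.0, "execution_pass_rate": 0.0}
    tool_hits = sum(1 for o in outcomes if o["parsed_tool"] == o["expected_tool"])
    valid = sum(1 for o in outcomes if o["valid_call"])
    passed = sum(1 for o in outcomes if o["executed_ok"])
    return {
        "tool_accuracy": tool_hits / total,
        "valid_call_rate": valid / total,
        "execution_pass_rate": passed / total,
    }


# Hypothetical outcomes for four evaluation prompts.
outcomes = [
    {"expected_tool": "calculator", "parsed_tool": "calculator", "valid_call": True, "executed_ok": True},
    {"expected_tool": "plot", "parsed_tool": "plot", "valid_call": False, "executed_ok": False},
    {"expected_tool": "search", "parsed_tool": "calculator", "valid_call": True, "executed_ok": False},
    {"expected_tool": "plot", "parsed_tool": None, "valid_call": False, "executed_ok": False},
]
report = summarize(outcomes)
print(report)
```

Reporting the three rates side by side shows where a checkpoint fails: the second prompt picked the right tool but emitted broken JSON, while the third emitted a valid CALL for the wrong tool.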
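Local

Local Runner Architecture: Training and Tool Execution Both Stay on Your Machine

Module 8 keeps using the local runner. The web page only sends HTTP requests and displays status; PyTorch training, checkpoints, tool execution, and chart artifacts all stay on the local machine. The runner's default --device auto policy prefers a CUDA GPU; if the current PyTorch build or driver environment has no CUDA, it automatically falls back to the CPU and shows the reason in the status card.

AlgoLab Web UI -> http://127.0.0.1:4888 -> toolcalltinygpt_local_runner.py --device auto -> prefer CUDA GPU, fallback CPU -> PyTorch TinyGPT training -> checkpoint -> generate CALL -> execute calculator / plot / search

| Endpoint | Purpose |
|---|---|
| POST /train | Parses the JSONL and trains the local Tool-Calling TinyGPT. |
| POST /generate | Generates one CALL from a user prompt. |
| POST /execute_call | Parses a CALL and executes the local tool. |
| POST /evaluate | Runs a small set of built-in prompts and returns tool accuracy, valid CALL rate, and execution pass rate. |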
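For readers who want to drive the runner without the web UI, here is a hedged client sketch using only the standard library. `build_post` and `send` are hypothetical helpers. The payload keys (`preset`, `target_epochs`, `batch_size`, `device`) are the ones `run_training` reads from its request dict; it is an assumption that the /train handler passes the JSON body through unchanged.

```python
import json
import urllib.request

RUNNER_URL = "http://127.0.0.1:4888"


def build_post(path, payload, base_url=RUNNER_URL):
    # Build (but do not send) a JSON POST request for one runner endpoint.
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        base_url + path,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=10.0):
    # Sends the request; requires a runner listening on RUNNER_URL.
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


train_request = build_post(
    "/train",
    {"preset": "tiny", "target_epochs": 6, "batch_size": 48, "device": "auto"},
)
print(train_request.get_method(), train_request.full_url)
```

Calling `send(train_request)` only makes sense once the runner is started; building the request separately keeps the payload easy to inspect first.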
Training

Local Training Console: Tool-Calling TinyGPT

This panel only handles control and visualization; the actual training happens in the local scripts/toolcalltinygpt_local_runner.py.

# Windows / PowerShell
& 'C:\Users\richi\TI_richiebao\LLM\.venv\Scripts\python.exe' scripts\toolcalltinygpt_local_runner.py --host 127.0.0.1 --port 4888 --device auto

# Using the same GPU Python environment as No.7 TinyCodeGPT is recommended.
# --device auto = prefer a CUDA GPU; fall back to the CPU automatically when CUDA is unavailable

# Runner API
GET  /status
POST /train
POST /generate
POST /execute_call
POST /evaluate
Runner status panel (before a runner connects): Connected: no; Policy: auto; CUDA Runtime: no. Device, PyTorch, CUDA Build, CUDA Devices, Checkpoints, and Active Job are not yet reported.
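1. Build the tool-call JSONL dataset

Each sample supervises the user request to a complete CALL instruction rather than to the final answer.

Dataset summary: 3 valid lines, 0 invalid lines (calculator 1, plot 1, search 1), 421 characters.

Sample format trace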
<user>
calculate 23 plus 17 times 4
</user>
<call>
CALL calculator {"expression":"23 + 17 * 4"}
</call>
x/y offset preview (positions 0 to 27 of the sample above): at every position t, the target y[t] is the character that follows the input x[t].
t=0: x '<', y 'u'
t=1: x 'u', y 's'
t=2: x 's', y 'e'
...
t=27: x ' ', y 't'
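The one-character offset between inputs and targets can be reproduced on plain strings. This sketch mirrors the slicing in the runner's `make_batch` (`data[i : i + block_size]` versus `data[i + 1 : i + block_size + 1]`) but works on characters instead of tensors; `make_window` and the small `block_size` are chosen here only for illustration.

```python
text = "<user>\ncalculate 23 plus 17 times 4\n</user>\n"
block_size = 8  # illustrative; the runner presets use 160 to 384


def make_window(text, start, block_size):
    # x is the context; y is the same window shifted one character to the right.
    x = text[start : start + block_size]
    y = text[start + 1 : start + block_size + 1]
    return x, y


x, y = make_window(text, 0, block_size)
pairs = list(zip(x, y))
for position, (current_char, next_char) in enumerate(pairs):
    print(position, repr(current_char), "->", repr(next_char))
```

Each printed pair is one training signal: given everything up to `x[t]`, the model is asked to predict `y[t]`.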
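2. Train Tool-Calling TinyGPT locally

The training objective is still next-token cross entropy, but the context changes from task-code to user-call.

Training status panel (fields: Status, Dataset, Params, Vocab, Steps/Epoch, Max Steps, Train Loss, Val Loss). Waiting for training logs...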
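3. Generate and execute a CALL instruction

By default the currently trained checkpoint generates the CALL; the CALL is parsed first and then executed by the local tool.

Generation panel (fields: Selected Path, Current Model, Parsed Tool, Truncated, Loaded Model). The JSON args are shown here after generation.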
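Safety

Why the Three Tools Are Built as Restricted Executors

The point of tool calling is to hand unstable language output to stable programs for execution. That in turn requires clear tool boundaries; in particular, model output must never be run as arbitrary code.

| Tool | Restriction policy | Failure signals |
|---|---|---|
| calculator | Only an AST whitelist is allowed: numbers, the variables pi/e, basic operators, and a few math functions. | The expression cannot be parsed, the exponent is too large, or the result is not a finite number. |
| plot | Only function / line / bar / scatter / histogram are allowed, with limits on array length and point count. | kind is invalid, x/y lengths differ, or the expression is unsafe. |
| search | The first version searches only the small local course corpus, with no internet access, so results are reproducible. | query is empty, top_k is out of range, or the search terms are too broad. |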
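The calculator row describes an AST whitelist. The runner's actual calculator code is not part of this excerpt, so the following is only a sketch of the general technique. The names `safe_eval` and `MAX_EXPONENT`, the exponent limit of 64, and the exact set of allowed functions are all assumptions.

```python
import ast
import math
import operator

_BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
            ast.Div: operator.truediv, ast.Mod: operator.mod, ast.Pow: operator.pow}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
_NAMES = {"pi": math.pi, "e": math.e}
_FUNCS = {"sqrt": math.sqrt, "sin": math.sin, "cos": math.cos, "log": math.log}
MAX_EXPONENT = 64  # assumed limit, for illustration only


def safe_eval(expression):
    """Evaluate an arithmetic expression by walking its AST; never calls eval()."""
    tree = ast.parse(expression, mode="eval")
    try:
        result = _eval(tree.body)
        finite = math.isfinite(result)
    except OverflowError as error:
        raise ValueError("result is out of range") from error
    if not finite:
        raise ValueError("result is not a finite number")
    return result


def _eval(node):
    # Numeric literals only; bool is excluded even though it subclasses int.
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)) \
            and not isinstance(node.value, bool):
        return node.value
    if isinstance(node, ast.Name) and node.id in _NAMES:
        return _NAMES[node.id]
    if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
        left, right = _eval(node.left), _eval(node.right)
        if isinstance(node.op, ast.Pow) and abs(right) > MAX_EXPONENT:
            raise ValueError("exponent too large")
        value = _BIN_OPS[type(node.op)](left, right)
        if isinstance(value, complex):
            raise ValueError("complex results are not supported")
        return value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) \
            and node.func.id in _FUNCS and not node.keywords:
        return _FUNCS[node.func.id](*[_eval(arg) for arg in node.args])
    # Anything else (attributes, subscripts, lambdas, unknown names) is rejected.
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


print(safe_eval("23 + 17 * 4"))
```

Only node types listed explicitly are evaluated, so an input such as `__import__('os')` fails at the final `raise` instead of being executed. Malformed input still raises `SyntaxError` from `ast.parse`, and division by zero raises `ZeroDivisionError`; a real tool would map both to a failure signal.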
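Deep Dive

Source Deep Dive: A Map of the Runner's Key Functions

| Function | What it does | Why it matters |
|---|---|---|
| parse_jsonl_dataset | Validates {tool,user,call} and confirms that the CALL tool name matches. | Dirty data directly damages format learning. |
| format_sample | Writes a sample as <user>...<call>... | Generation must reuse the same prefix. |
| run_training | Reuses No.7's PyTorch training loop. | The training budget is still determined by token count, batch, context, and epochs. |
| generate_call | Loads a checkpoint and samples a CALL. | Too high a temperature makes the JSON drift. |
| execute_tool_call | Parses a CALL and dispatches to the three local tools. | Turns "looks like a CALL" into "actually runs". |
| evaluate_checkpoint | Runs a small evaluation on fixed prompts. | Puts loss, format, and execution results on one report card. |

Order for diagnosing a failed sample:
1. Was </call> generated, or was the output truncated?
2. Is the CALL tool one of the three whitelisted tools?
3. Does the JSON parse?
4. Can the tool execute the arguments?
5. Should you add data or lower the temperature?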
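The first four checks in that order are mechanical and can be sketched as a function that reports the first failing stage. `diagnose` and its stage labels are hypothetical, and the sketch assumes `generated` is the text the model produced after the `<call>` prompt line. The optional `execute` callback stands in for whatever runs the tool.

```python
import json

TOOLS = ("calculator", "plot", "search")


def diagnose(generated, execute=None):
    """Return the first failing stage name, or "ok" if every check passes."""
    # 1. Truncation: the training template closes every call with </call>.
    if "</call>" not in generated:
        return "truncated"
    call_line = generated.split("</call>", 1)[0].strip()
    head, _, rest = call_line.partition(" ")
    tool, _, raw_args = rest.strip().partition(" ")
    # 2. Tool whitelist.
    if head != "CALL" or tool not in TOOLS:
        return "unknown_tool"
    # 3. JSON validity: arguments must parse to a JSON object.
    try:
        args = json.loads(raw_args)
    except json.JSONDecodeError:
        return "invalid_json"
    if not isinstance(args, dict):
        return "invalid_json"
    # 4. Execution: any exception from the tool counts as a failure.
    if execute is not None:
        try:
            execute(tool, args)
        except Exception:
            return "execution_failed"
    return "ok"


print(diagnose('CALL calculator {"expression":"1 + 2"}\n</call>\n'))
print(diagnose('CALL plot {"kind":"function",}\n</call>\n'))
```

Tallying the returned stage over many prompts informs the fifth question: a cluster of `invalid_json` at high temperature suggests lowering the temperature, while consistent `unknown_tool` or `execution_failed` results point toward missing training data.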
Source

Local Runner Source: toolcalltinygpt_local_runner.py

python
from __future__ import annotations

import argparse
import ast
import base64
import json
import math
import operator
import random
import re
import sys
import threading
import time
import traceback
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except Exception:  # pragma: no cover - status endpoint reports this clearly.
    torch = None

    class _NNStub:
        Module = object

    nn = _NNStub()
    F = None


def torch_no_grad():
    if torch is not None:
        return torch.no_grad()

    def decorator(function):
        return function

    return decorator


ROOT_DIR = Path(__file__).resolve().parents[1]
RUN_DIR = ROOT_DIR / ".tmp" / "toolcalltinygpt"
CHECKPOINT_DIR = RUN_DIR / "checkpoints"
RUN_OUTPUT_DIR = RUN_DIR / "tool_outputs"
DATASET_DIR = RUN_DIR / "datasets"
RUN_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
DATASET_DIR.mkdir(parents=True, exist_ok=True)

JOBS: dict[str, dict[str, Any]] = {}
STOP_EVENTS: dict[str, threading.Event] = {}
ACTIVE_JOB_ID: str | None = None
LOCK = threading.Lock()

TOOLS = ("calculator", "plot", "search")
CALL_RE = re.compile(r"^\s*CALL\s+([a-zA-Z_][a-zA-Z0-9_]*)\s+(\{.*\})\s*$", re.DOTALL)


@dataclass
class ModelConfig:
    n_layer: int
    n_head: int
    n_embd: int
    block_size: int


PRESETS = {
    "tiny": ModelConfig(n_layer=2, n_head=2, n_embd=96, block_size=160),
    "small": ModelConfig(n_layer=4, n_head=4, n_embd=192, block_size=256),
    "medium": ModelConfig(n_layer=6, n_head=6, n_embd=384, block_size=384),
}

DEVICE_POLICY = "auto"


def torch_cuda_version() -> str | None:
    if torch is None:
        return None
    return getattr(getattr(torch, "version", None), "cuda", None)


def is_cuda_available() -> bool:
    return bool(torch is not None and torch.cuda.is_available())


def cuda_unavailable_reason() -> str:
    if torch is None:
        return "PyTorch is not installed."
    version = str(getattr(torch, "__version__", ""))
    if "+cpu" in version:
        return "The current PyTorch wheel is CPU-only. Install a CUDA-enabled PyTorch wheel to use GPU."
    if torch_cuda_version() is None:
        return "The current PyTorch build does not report CUDA support."
    return "CUDA is not available to PyTorch in this environment."


def resolve_device(requested_policy: str | None = None) -> tuple[str, str]:
    policy = str(requested_policy or DEVICE_POLICY or "auto").strip().lower()
    if policy in {"gpu", "cuda"}:
        policy = "auto"
    if policy not in {"auto", "cpu"}:
        policy = "auto"
    if policy == "cpu":
        return "cpu", "CPU was selected explicitly."
    if is_cuda_available():
        name = torch.cuda.get_device_name(0)
        return "cuda", f"GPU priority: using CUDA device {name}."
    return "cpu", f"GPU priority: CUDA unavailable, falling back to CPU. {cuda_unavailable_reason()}"


LOCAL_SEARCH_DOCS = [
    {
        "title": "Gradient descent and learning rate",
        "body": "Gradient descent updates parameters in the negative gradient direction. The learning rate controls step size.",
    },
    {
        "title": "Transformer self attention",
        "body": "Self attention lets each token mix information from previous tokens through query, key, and value vectors.",
    },
    {
        "title": "TinyGPT next token training",
        "body": "A TinyGPT model learns by predicting the next token at every position with cross entropy loss.",
    },
    {
        "title": "Tool calling JSON schema",
        "body": "Tool calling separates language understanding from deterministic tool execution using structured JSON arguments.",
    },
    {
        "title": "Safe calculator expression parser",
        "body": "A calculator tool should parse expressions with a whitelist instead of using eval or shell execution.",
    },
    {
        "title": "Matplotlib plotting tool",
        "body": "A plotting tool can turn structured chart arguments into local PNG artifacts using a non-interactive backend.",
    },
    {
        "title": "VAE latent space",
        "body": "A variational autoencoder maps data into a probabilistic latent space and samples through reparameterization.",
    },
    {
        "title": "Diffusion denoising",
        "body": "Diffusion models learn to reverse a noise process step by step until an image-like sample appears.",
    },
    {
        "title": "MLP nonlinear feature space",
        "body": "An MLP stacks affine layers and nonlinear activations to fit relationships beyond a linear boundary.",
    },
    {
        "title": "Local checkpoint evaluation",
        "body": "A local model checkpoint should be evaluated by format validity, tool selection accuracy, and execution pass rate.",
    },
]


def utc_now() -> str:
    return datetime.now(timezone.utc).isoformat()


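def jsonable(value: Any) -> Any:
    # json.dumps calls this for objects it cannot serialize. It must return a
    # serializable value or raise TypeError; returning the object unchanged ends
    # in a misleading "Circular reference detected" error.
    if isinstance(value, Path):
        return str(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")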


def respond(handler: BaseHTTPRequestHandler, status: int, payload: dict[str, Any]) -> None:
    raw = json.dumps(payload, ensure_ascii=False, default=jsonable).encode("utf-8")
    handler.send_response(status)
    handler.send_header("Content-Type", "application/json; charset=utf-8")
    handler.send_header("Content-Length", str(len(raw)))
    handler.send_header("Access-Control-Allow-Origin", "*")
    handler.send_header("Access-Control-Allow-Headers", "Content-Type")
    handler.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
    handler.end_headers()
    handler.wfile.write(raw)


def read_body(handler: BaseHTTPRequestHandler) -> dict[str, Any]:
    length = int(handler.headers.get("Content-Length") or "0")
    if length <= 0:
        return {}
    raw = handler.rfile.read(length).decode("utf-8")
    return json.loads(raw or "{}")


def update_job(job_id: str, **patch: Any) -> None:
    with LOCK:
        job = JOBS[job_id]
        job.update(patch)


def append_log(job_id: str, line: str) -> None:
    with LOCK:
        job = JOBS[job_id]
        logs = job.setdefault("logs", [])
        logs.append(line)
        if len(logs) > 240:
            del logs[: len(logs) - 240]


def append_metric(job_id: str, metric: dict[str, Any]) -> None:
    with LOCK:
        job = JOBS[job_id]
        metrics = job.setdefault("metrics", [])
        metrics.append(metric)
        if len(metrics) > 240:
            del metrics[: len(metrics) - 240]


def update_training_progress(
    job_id: str,
    step: int,
    max_steps: int,
    steps_per_epoch: int,
    target_epochs: float,
) -> None:
    epoch = step / max(steps_per_epoch, 1)
    effective_target_epochs = target_epochs if target_epochs > 0 else max_steps / max(steps_per_epoch, 1)
    update_job(
        job_id,
        progress={
            "step": step,
            "max_steps": max_steps,
            "epoch": round(epoch, 4),
            "target_epochs": round(effective_target_epochs, 4),
            "percent": round(step / max(max_steps, 1), 6),
        },
    )


def clamp_int(value: Any, fallback: int, lower: int, upper: int) -> int:
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = fallback
    return max(lower, min(parsed, upper))


def clamp_float(value: Any, fallback: float, lower: float, upper: float) -> float:
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        parsed = fallback
    return max(lower, min(parsed, upper))


def resolve_model_config(request: dict[str, Any]) -> tuple[str, ModelConfig]:
    preset = str(request.get("preset") or "small")
    base_config = PRESETS.get(preset, PRESETS["small"])
    raw_config = request.get("model_config") if preset == "custom" else None
    if not isinstance(raw_config, dict):
        return preset if preset in PRESETS else "small", base_config
    config = ModelConfig(
        n_layer=clamp_int(raw_config.get("n_layer"), base_config.n_layer, 1, 12),
        n_head=clamp_int(raw_config.get("n_head"), base_config.n_head, 1, 12),
        n_embd=clamp_int(raw_config.get("n_embd"), base_config.n_embd, 32, 768),
        block_size=clamp_int(raw_config.get("block_size"), base_config.block_size, 64, 1024),
    )
    if config.n_embd % config.n_head != 0:
        raise ValueError("n_embd must be divisible by n_head for custom model config.")
    return "custom", config


def make_call(tool: str, args: dict[str, Any]) -> str:
    return f"CALL {tool} {json.dumps(args, ensure_ascii=False, separators=(',', ':'))}"


def add_sample(samples: list[dict[str, str]], tool: str, user: str, call: str) -> None:
    samples.append({"tool": tool, "user": user.strip(), "call": call.strip()})


def generate_dataset(samples_per_tool: int, seed: int = 11) -> list[dict[str, str]]:
    rng = random.Random(seed)
    samples: list[dict[str, str]] = []
    search_topics = [
        "gradient descent learning rate",
        "transformer self attention",
        "TinyGPT next token training",
        "tool calling JSON schema",
        "safe calculator expression parser",
        "matplotlib plotting tool",
        "local checkpoint evaluation",
    ]

    for _ in range(samples_per_tool):
        a, b, c = rng.randint(2, 90), rng.randint(2, 90), rng.randint(2, 20)
        add_sample(samples, "calculator", f"calculate {a} plus {b} times {c}", make_call("calculator", {"expression": f"{a} + {b} * {c}"}))
        add_sample(samples, "calculator", f"what is {a} percent of {b * c}", make_call("calculator", {"expression": f"{b * c} * {a} / 100"}))
        add_sample(samples, "calculator", f"compute area of circle with radius {c}", make_call("calculator", {"expression": f"pi * {c} ** 2"}))

        values = [rng.randint(1, 50) for _ in range(rng.randint(4, 8))]
        labels = [f"item_{index + 1}" for index in range(len(values))]
        expression = rng.choice(["x**2", "sin(x)", "cos(x)", "exp(-x**2)"])
        add_sample(samples, "plot", f"plot y = {expression} from -10 to 10", make_call("plot", {"kind": "function", "expression": expression, "x_min": -10, "x_max": 10, "points": 140}))
        add_sample(samples, "plot", f"draw a line chart for values {values}", make_call("plot", {"kind": "line", "x": list(range(1, len(values) + 1)), "y": values, "title": "Line chart"}))
        add_sample(samples, "plot", f"make a bar chart for values {values[:5]}", make_call("plot", {"kind": "bar", "labels": labels[:5], "values": values[:5], "title": "Bar chart"}))

        topic = rng.choice(search_topics)
        add_sample(samples, "search", f"search for {topic}", make_call("search", {"query": topic, "top_k": rng.randint(2, 5)}))
        add_sample(samples, "search", f"find beginner notes about {topic}", make_call("search", {"query": topic, "top_k": rng.randint(2, 5)}))

    rng.shuffle(samples)
    return samples


def strip_jsonl_fences(raw: str) -> str:
    text = raw.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        if lines and lines[0].lstrip().startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        text = "\n".join(lines).strip()
    return text


def parse_call(call_text: str) -> tuple[str, dict[str, Any]]:
    match = CALL_RE.match(call_text.strip())
    if not match:
        raise ValueError("CALL must match: CALL tool {json_args}")
    tool = match.group(1)
    if tool not in TOOLS:
        raise ValueError(f"Unsupported tool: {tool}")
    args = json.loads(match.group(2))
    if not isinstance(args, dict):
        raise ValueError("CALL args must be a JSON object")
    return tool, args


def normalize_dataset_sample(value: Any, line_number: int) -> dict[str, str]:
    if not isinstance(value, dict):
        raise ValueError(f"line {line_number}: expected a JSON object")
    tool = str(value.get("tool") or "").strip()
    user = str(value.get("user") or "").strip()
    call_text = str(value.get("call") or "").strip()
    if tool not in TOOLS:
        raise ValueError(f"line {line_number}: tool must be one of {', '.join(TOOLS)}")
    if not user:
        raise ValueError(f"line {line_number}: user is required")
    parsed_tool, _args = parse_call(call_text)
    if parsed_tool != tool:
        raise ValueError(f"line {line_number}: tool and CALL tool do not match")
    return {"tool": tool, "user": user, "call": call_text}


def parse_jsonl_dataset(raw: str) -> list[dict[str, str]]:
    text = strip_jsonl_fences(raw)
    if not text:
        raise ValueError("dataset_jsonl is empty")
    if text.startswith("["):
        parsed = json.loads(text)
        if not isinstance(parsed, list):
            raise ValueError("dataset JSON array is invalid")
        samples = [normalize_dataset_sample(item, index + 1) for index, item in enumerate(parsed)]
    else:
        samples = []
        for index, line in enumerate(text.splitlines(), start=1):
            stripped = line.strip()
            if not stripped:
                continue
            samples.append(normalize_dataset_sample(json.loads(stripped), index))
    if not samples:
        raise ValueError("dataset_jsonl does not contain any samples")
    return samples


def save_jsonl_dataset(job_id: str, samples: list[dict[str, str]]) -> Path:
    dataset_path = DATASET_DIR / f"{job_id}.jsonl"
    payload = "\n".join(json.dumps(sample, ensure_ascii=False) for sample in samples) + "\n"
    dataset_path.write_text(payload, encoding="utf-8")
    return dataset_path


def load_training_samples(
    job_id: str,
    request: dict[str, Any],
    samples_per_tool: int,
    seed: int,
) -> tuple[list[dict[str, str]], str, Path | None]:
    raw_jsonl = str(request.get("dataset_jsonl") or "").strip()
    if raw_jsonl:
        samples = parse_jsonl_dataset(raw_jsonl)
        dataset_path = save_jsonl_dataset(job_id, samples)
        return samples, "tool_call_jsonl", dataset_path
    return generate_dataset(samples_per_tool, seed=seed), "template_synthetic", None


def format_sample(sample: dict[str, str]) -> str:
    return f"<user>\n{sample['user']}\n</user>\n<call>\n{sample['call']}\n</call>\n"


def build_vocab(text: str) -> tuple[dict[str, int], dict[int, str]]:
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


def normalize_itos(raw_itos: dict[Any, str]) -> dict[int, str]:
    if not raw_itos:
        return {}
    return {int(key): value for key, value in raw_itos.items()} if isinstance(next(iter(raw_itos.keys())), str) else raw_itos


def count_parameters(model: nn.Module) -> int:
    return sum(parameter.numel() for parameter in model.parameters())


class CausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        batch, tokens, channels = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(channels, dim=2)
        q = q.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch, tokens, self.n_head, self.head_dim).transpose(1, 2)
        attention = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention = attention.masked_fill(self.mask[:, :, :tokens, :tokens] == 0, float("-inf"))
        attention = F.softmax(attention, dim=-1)
        y = attention @ v
        y = y.transpose(1, 2).contiguous().view(batch, tokens, channels)
        return self.proj(y)


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ffn = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class ToolCallTinyGPT(nn.Module):
    def __init__(self, vocab_size: int, config: ModelConfig):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        _batch, tokens = idx.shape
        positions = torch.arange(tokens, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        x = self.blocks(x)
        logits = self.head(self.ln_f(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch_no_grad()
    def generate(
        self,
        idx,
        max_new_tokens: int,
        temperature: float = 0.45,
        top_k: int = 30,
        stop_sequence: list[int] | None = None,
    ):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-4)
            if top_k:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
            if stop_sequence and idx.size(1) >= len(stop_sequence):
                tail = idx[0, -len(stop_sequence) :].detach().cpu().tolist()
                if tail == stop_sequence:
                    break
        return idx


def make_batch(data: Any, block_size: int, batch_size: int, device: str):
    starts = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in starts]).to(device)
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in starts]).to(device)
    return x, y


def evaluate_loss(model: nn.Module, data: Any, block_size: int, batch_size: int, device: str) -> float:
    model.eval()
    losses = []
    with torch.no_grad():
        for _ in range(6):
            x, y = make_batch(data, block_size, batch_size, device)
            _logits, loss = model(x, y)
            losses.append(float(loss.item()))
    model.train()
    return sum(losses) / len(losses)


def latest_checkpoint() -> Path | None:
    checkpoints = sorted(CHECKPOINT_DIR.glob("*.pt"), key=lambda item: item.stat().st_mtime, reverse=True)
    return checkpoints[0] if checkpoints else None


def load_resume_checkpoint(path_value: str | None) -> tuple[Path, dict[str, Any]]:
    checkpoint_path = Path(path_value) if path_value else latest_checkpoint()
    if checkpoint_path is None or not checkpoint_path.exists():
        raise RuntimeError("No checkpoint found to continue training from.")
    return checkpoint_path, torch.load(checkpoint_path, map_location="cpu")


def run_training(job_id: str, request: dict[str, Any]) -> None:
    global ACTIVE_JOB_ID
    if torch is None:
        update_job(job_id, status="failed", error="PyTorch is not installed in this Python environment.", finished_at=utc_now())
        return

    stop_event = STOP_EVENTS[job_id]
    try:
        continue_from_checkpoint = bool(request.get("continue_from_checkpoint"))
        resume_checkpoint_path = str(request.get("checkpoint_path") or "").strip() or None
        preset, config = resolve_model_config(request)
        samples_per_tool = max(10, min(int(request.get("samples_per_tool") or 120), 1000))
        fallback_max_steps = max(20, min(int(request.get("max_steps") or 50_000), 200_000))
        target_epochs = clamp_float(request.get("target_epochs"), 6.0, 0.0, 100.0)
        batch_size = max(4, min(int(request.get("batch_size") or 48), 128))
        learning_rate = float(request.get("learning_rate") or 8e-4)
        seed = int(request.get("seed") or 11)

        random.seed(seed)
        torch.manual_seed(seed)
        device, device_reason = resolve_device(request.get("device"))
        update_job(job_id, status="running", started_at=utc_now(), device=device, device_policy=DEVICE_POLICY, device_reason=device_reason)
        append_log(job_id, f"device_policy = {DEVICE_POLICY}")
        append_log(job_id, f"device = {device}")
        append_log(job_id, device_reason)
        append_log(job_id, f"preset = {preset}, config = {config}")

        samples, dataset_source, dataset_path = load_training_samples(job_id, request, samples_per_tool, seed)
        text = "\n".join(format_sample(sample) for sample in samples)
        resume_checkpoint: dict[str, Any] | None = None
        resumed_from_checkpoint: Path | None = None
        if continue_from_checkpoint:
            resumed_from_checkpoint, resume_checkpoint = load_resume_checkpoint(resume_checkpoint_path)
            config = ModelConfig(**resume_checkpoint["config"])
            preset = str(resume_checkpoint.get("metadata", {}).get("preset") or "continued")
            stoi = resume_checkpoint["stoi"]
            itos = normalize_itos(resume_checkpoint["itos"])
            missing_chars = sorted(set(text) - set(stoi))
            if missing_chars:
                preview = "".join(missing_chars[:20])
                raise RuntimeError(
                    "Cannot continue from this checkpoint because the new dataset contains characters "
                    f"that are not in the checkpoint vocabulary: {preview!r}"
                )
            append_log(job_id, f"continue_from_checkpoint = {resumed_from_checkpoint}")
        else:
            stoi, itos = build_vocab(text)

        ids = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)
        if len(ids) <= config.block_size + batch_size + 2:
            raise RuntimeError("Dataset is too small for the selected preset and batch size.")
        split = max(config.block_size + batch_size + 2, int(len(ids) * 0.92))
        split = min(split, len(ids) - config.block_size - batch_size - 2)
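        train_ids = ids[:split]
        val_ids = ids[split:]
        # make_batch samples start offsets from [0, len - block_size - 1), so the
        # training split needs at least block_size + 2 tokens; fail with a clear message.
        if len(train_ids) < config.block_size + 2:
            raise RuntimeError("Dataset is too small to reserve a validation split for the selected preset and batch size.")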
        if len(val_ids) <= config.block_size + batch_size + 2:
            val_ids = train_ids

        steps_per_epoch = max(1, math.ceil(len(train_ids) / (batch_size * config.block_size)))
        epoch_steps = math.ceil(steps_per_epoch * target_epochs) if target_epochs > 0 else fallback_max_steps
        max_steps = max(20, min(epoch_steps, 200_000))
        model = ToolCallTinyGPT(len(stoi), config).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
        if resume_checkpoint is not None:
            model.load_state_dict(resume_checkpoint["model_state_dict"])
            if resume_checkpoint.get("optimizer_state_dict"):
                optimizer.load_state_dict(resume_checkpoint["optimizer_state_dict"])
                for state in optimizer.state.values():
                    for key, value in state.items():
                        if hasattr(value, "to"):
                            state[key] = value.to(device)

        parameter_count = count_parameters(model)
        tool_counts = {tool: sum(1 for sample in samples if sample["tool"] == tool) for tool in TOOLS}
        update_job(
            job_id,
            dataset_size=len(samples),
            tool_counts=tool_counts,
            dataset_source=dataset_source,
            dataset_path=str(dataset_path) if dataset_path else None,
            vocab_size=len(stoi),
            parameter_count=parameter_count,
            preset=preset,
            model_config=config.__dict__,
            max_steps=max_steps,
            fallback_max_steps=fallback_max_steps,
            target_epochs=target_epochs,
            steps_per_epoch=steps_per_epoch,
            resumed_from_checkpoint=str(resumed_from_checkpoint) if resumed_from_checkpoint else None,
        )
        update_training_progress(job_id, 0, max_steps, steps_per_epoch, target_epochs)
        append_log(job_id, f"dataset_size = {len(samples)} samples")
        append_log(job_id, f"tool_counts = {tool_counts}")
        append_log(job_id, f"vocab_size = {len(stoi)} characters")
        append_log(job_id, f"parameters = {parameter_count:,}")
        append_log(job_id, f"target_epochs = {target_epochs:g}, steps_per_epoch ≈ {steps_per_epoch}, planned_steps = {max_steps}")

        log_every = max(10, max_steps // 18)
        progress_every = 1 if max_steps <= 10_000 else max(1, max_steps // 2_000)
        start_time = time.time()
        for step in range(1, max_steps + 1):
            if stop_event.is_set():
                update_job(job_id, status="stopped", finished_at=utc_now())
                append_log(job_id, "training stopped by user")
                return
            x, y = make_batch(train_ids, config.block_size, batch_size, device)
            _logits, loss = model(x, y)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if step == 1 or step % progress_every == 0 or step == max_steps:
                update_training_progress(job_id, step, max_steps, steps_per_epoch, target_epochs)

            if step == 1 or step % log_every == 0 or step == max_steps:
                val_loss = evaluate_loss(model, val_ids, config.block_size, min(batch_size, 32), device)
                elapsed = max(time.time() - start_time, 1e-6)
                tokens_per_second = int(step * batch_size * config.block_size / elapsed)
                metric = {
                    "step": step,
                    "train_loss": round(float(loss.item()), 6),
                    "val_loss": round(float(val_loss), 6),
                    "tokens_per_second": tokens_per_second,
                    "epoch": round(step / steps_per_epoch, 4),
                }
                append_metric(job_id, metric)
                append_log(job_id, f"step {step:5d}/{max_steps} | epoch {metric['epoch']:.2f}/{target_epochs:g} | train_loss {metric['train_loss']:.4f} | val_loss {metric['val_loss']:.4f} | {tokens_per_second:,} tok/s")

        checkpoint_path = CHECKPOINT_DIR / f"{job_id}.pt"
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": config.__dict__,
                "stoi": stoi,
                "itos": itos,
                "samples": samples[:120],
                "metadata": {
                    "preset": preset,
                    "model_config": config.__dict__,
                    "samples_per_tool": samples_per_tool,
                    "max_steps": max_steps,
                    "fallback_max_steps": fallback_max_steps,
                    "target_epochs": target_epochs,
                    "steps_per_epoch": steps_per_epoch,
                    "batch_size": batch_size,
                    "learning_rate": learning_rate,
                    "device": device,
                    "device_policy": DEVICE_POLICY,
                    "device_reason": device_reason,
                    "dataset_source": dataset_source,
                    "dataset_path": str(dataset_path) if dataset_path else None,
                    "dataset_size": len(samples),
                    "tool_counts": tool_counts,
                    "resumed_from_checkpoint": str(resumed_from_checkpoint) if resumed_from_checkpoint else None,
                    "created_at": utc_now(),
                },
            },
            checkpoint_path,
        )
        update_job(job_id, status="completed", checkpoint_path=str(checkpoint_path), finished_at=utc_now())
        append_log(job_id, f"checkpoint saved: {checkpoint_path}")
    except Exception as error:
        update_job(job_id, status="failed", error=str(error), finished_at=utc_now())
        append_log(job_id, traceback.format_exc())
    finally:
        with LOCK:
            if ACTIVE_JOB_ID == job_id:
                ACTIVE_JOB_ID = None


def load_checkpoint(path_value: str | None):
    if torch is None:
        raise RuntimeError("PyTorch is not installed in this Python environment.")
    checkpoint_path = Path(path_value) if path_value else latest_checkpoint()
    if checkpoint_path is None or not checkpoint_path.exists():
        raise RuntimeError("No Tool-Calling TinyGPT checkpoint found. Train a model first.")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    config = ModelConfig(**checkpoint["config"])
    stoi = checkpoint["stoi"]
    itos = normalize_itos(checkpoint["itos"])
    model = ToolCallTinyGPT(len(stoi), config)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return checkpoint_path, model, config, stoi, itos


def extract_call(text: str) -> str:
    marker = "<call>\n"
    start = text.find(marker)
    if start >= 0:
        text = text[start + len(marker) :]
    end = text.find("</call>")
    if end >= 0:
        text = text[:end]
    return text.strip()


def generate_call(request: dict[str, Any]) -> dict[str, Any]:
    checkpoint_path, model, config, stoi, itos = load_checkpoint(request.get("checkpoint_path"))
    device, device_reason = resolve_device(request.get("device"))
    model.to(device)
    user = str(request.get("user") or "").strip()
    if not user:
        raise RuntimeError("User prompt cannot be empty.")
    prompt = f"<user>\n{user}\n</user>\n<call>\n"
    encoded = [stoi[ch] for ch in prompt if ch in stoi]
    if not encoded:
        raise RuntimeError("Prompt contains no known characters from the tokenizer vocabulary.")
    idx = torch.tensor([encoded], dtype=torch.long, device=device)
    max_new_tokens = max(40, min(int(request.get("max_new_tokens") or 240), 800))
    temperature = max(0.1, min(float(request.get("temperature") or 0.4), 1.2))
    top_k = max(0, min(int(request.get("top_k") or 30), 100))
    stop_sequence = [stoi[ch] for ch in "\n</call>" if ch in stoi]
    with torch.no_grad():
        generated = model.generate(
            idx,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            stop_sequence=stop_sequence if stop_sequence else None,
        )
    text = "".join(itos[int(token)] for token in generated[0].detach().cpu().tolist())
    call_text = extract_call(text)
    parsed: dict[str, Any] | None = None
    parse_error = None
    try:
        tool, args = parse_call(call_text)
        parsed = {"tool": tool, "args": args}
    except Exception as error:
        parse_error = str(error)
    return {
        "user": user,
        "call": call_text,
        "parsed": parsed,
        "parse_error": parse_error,
        "truncated": "</call>" not in text,
        "checkpoint_path": str(checkpoint_path),
        "device": device,
        "device_reason": device_reason,
    }


ALLOWED_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
ALLOWED_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
ALLOWED_NAMES = {
    "pi": math.pi,
    "e": math.e,
}
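ALLOWED_FUNCS = {
    "sqrt": math.sqrt,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "log": math.log,
    "log10": math.log10,
    "exp": math.exp,
    "abs": abs,
    # safe_eval_expression turns every constant into a float, so round(x, 2) would
    # call round(x, 2.0) and raise TypeError; cast the optional digits back to int.
    "round": lambda value, digits=0: round(value, int(digits)),
}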


def safe_eval_expression(expression: str, variable_values: dict[str, float] | None = None) -> float:
    tree = ast.parse(expression, mode="eval")
    variables = {**ALLOWED_NAMES, **(variable_values or {})}

    def visit(node: ast.AST):
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return float(node.value)
        if isinstance(node, ast.Name) and node.id in variables:
            return variables[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in ALLOWED_BINOPS:
            left = visit(node.left)
            right = visit(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > 8:
                raise ValueError("Power exponent is too large.")
            return ALLOWED_BINOPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in ALLOWED_UNARYOPS:
            return ALLOWED_UNARYOPS[type(node.op)](visit(node.operand))
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in ALLOWED_FUNCS:
            args = [visit(arg) for arg in node.args]
            if len(args) > 3:
                raise ValueError("Too many function arguments.")
            return ALLOWED_FUNCS[node.func.id](*args)
        raise ValueError(f"Unsupported expression node: {type(node).__name__}")

    result = visit(tree)
    if not isinstance(result, (int, float)) or not math.isfinite(result):
        raise ValueError("Expression result is not finite.")
    return float(result)


def execute_calculator(args: dict[str, Any]) -> dict[str, Any]:
    expression = str(args.get("expression") or "").strip()
    if not expression:
        raise ValueError("calculator.expression is required")
    if len(expression) > 240:
        raise ValueError("calculator.expression is too long")
    value = safe_eval_expression(expression)
    return {
        "tool": "calculator",
        "output": {"expression": expression, "value": value, "rounded": round(value, 8)},
        "artifacts": [],
    }


def sanitize_float_list(value: Any, name: str, max_items: int = 200) -> list[float]:
    if not isinstance(value, list):
        raise ValueError(f"{name} must be a list")
    if not 1 <= len(value) <= max_items:
        raise ValueError(f"{name} must contain 1-{max_items} items")
    parsed = [float(item) for item in value]
    if not all(math.isfinite(item) for item in parsed):
        raise ValueError(f"{name} contains non-finite values")
    return parsed


def encode_png_artifact(path: Path) -> dict[str, str]:
    payload = base64.b64encode(path.read_bytes()).decode("ascii")
    return {
        "filename": path.name,
        "path": str(path),
        "mime_type": "image/png",
        "data_url": f"data:image/png;base64,{payload}",
    }


def execute_plot(args: dict[str, Any]) -> dict[str, Any]:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    kind = str(args.get("kind") or "").strip().lower()
    title = str(args.get("title") or kind.title())[:80]
    run_id = uuid.uuid4().hex[:10]
    output_path = RUN_OUTPUT_DIR / f"{run_id}_plot.png"
    fig = plt.figure(figsize=(6, 4))
    # Close the figure even when validation or rendering raises, so failed calls
    # do not leave open figures behind in the long-running server process.
    try:
        if kind == "function":
            expression = str(args.get("expression") or "").strip()
            x_min = float(args.get("x_min", -10))
            x_max = float(args.get("x_max", 10))
            points = max(20, min(int(args.get("points") or 120), 300))
            if not expression:
                raise ValueError("plot.expression is required for function plots")
            if x_min >= x_max:
                raise ValueError("x_min must be smaller than x_max")
            xs = [x_min + (x_max - x_min) * index / (points - 1) for index in range(points)]
            ys = [safe_eval_expression(expression, {"x": x}) for x in xs]
            plt.plot(xs, ys)
        elif kind == "line":
            y = sanitize_float_list(args.get("y"), "plot.y")
            x = sanitize_float_list(args.get("x"), "plot.x") if isinstance(args.get("x"), list) else list(range(len(y)))
            if len(x) != len(y):
                raise ValueError("plot.x and plot.y must have the same length")
            plt.plot(x, y, marker="o")
        elif kind == "bar":
            values = sanitize_float_list(args.get("values"), "plot.values", max_items=40)
            labels_raw = args.get("labels")
            labels = [str(item)[:24] for item in labels_raw] if isinstance(labels_raw, list) else [str(index + 1) for index in range(len(values))]
            if len(labels) != len(values):
                raise ValueError("plot.labels and plot.values must have the same length")
            plt.bar(labels, values)
            plt.xticks(rotation=20)
        elif kind == "scatter":
            x = sanitize_float_list(args.get("x"), "plot.x")
            y = sanitize_float_list(args.get("y"), "plot.y")
            if len(x) != len(y):
                raise ValueError("plot.x and plot.y must have the same length")
            plt.scatter(x, y)
        elif kind == "histogram":
            values = sanitize_float_list(args.get("values"), "plot.values")
            bins = max(2, min(int(args.get("bins") or 8), 40))
            plt.hist(values, bins=bins)
        else:
            raise ValueError("plot.kind must be function, line, bar, scatter, or histogram")

        plt.title(title)
        plt.tight_layout()
        plt.savefig(output_path, bbox_inches="tight")
    finally:
        plt.close(fig)
    return {
        "tool": "plot",
        "output": {"kind": kind, "title": title, "path": str(output_path)},
        "artifacts": [encode_png_artifact(output_path)],
    }


def execute_search(args: dict[str, Any]) -> dict[str, Any]:
    query = str(args.get("query") or "").strip()
    if not query:
        raise ValueError("search.query is required")
    top_k = max(1, min(int(args.get("top_k") or 3), 8))
    query_terms = {term for term in re.split(r"\W+", query.lower()) if term}
    scored = []
    for doc in LOCAL_SEARCH_DOCS:
        haystack = f"{doc['title']} {doc['body']}".lower()
        score = sum(2 if term in doc["title"].lower() else 1 for term in query_terms if term in haystack)
        scored.append({**doc, "score": score})
    ranked = sorted(scored, key=lambda item: item["score"], reverse=True)[:top_k]
    return {
        "tool": "search",
        "output": {"query": query, "top_k": top_k, "results": ranked},
        "artifacts": [],
    }


def execute_tool_call(request: dict[str, Any]) -> dict[str, Any]:
    call_text = str(request.get("call") or "").strip()
    tool, args = parse_call(call_text)
    if tool == "calculator":
        result = execute_calculator(args)
    elif tool == "plot":
        result = execute_plot(args)
    elif tool == "search":
        result = execute_search(args)
    else:
        raise ValueError(f"Unsupported tool: {tool}")
    return {"call": call_text, "parsed": {"tool": tool, "args": args}, **result}


def evaluate_checkpoint(request: dict[str, Any]) -> dict[str, Any]:
    prompts = [
        ("calculator", "calculate 23 plus 17 times 4"),
        ("calculator", "what is 18 percent of 250"),
        ("plot", "plot y = x**2 from -10 to 10"),
        ("plot", "draw a line chart for values 3, 8, 5, 13"),
        ("search", "search for transformer self attention"),
        ("search", "find beginner notes about tool calling JSON schema"),
    ]
    max_items = max(1, min(int(request.get("max_items") or len(prompts)), len(prompts)))
    results = []
    for expected_tool, user in prompts[:max_items]:
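        # `or` also covers an explicit null temperature; .get(key, default) would pass None through
        # and generate_call would then fall back to its own 0.4 default instead of 0.25.
        generated = generate_call({**request, "user": user, "temperature": request.get("temperature") or 0.25})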
        parsed = generated.get("parsed")
        tool_ok = bool(parsed and parsed.get("tool") == expected_tool)
        execution_ok = False
        error = None
        execution = None
        if parsed:
            try:
                execution = execute_tool_call({"call": generated["call"]})
                execution_ok = True
            except Exception as exec_error:
                error = str(exec_error)
        else:
            error = generated.get("parse_error")
        results.append({
            "user": user,
            "expected_tool": expected_tool,
            "call": generated["call"],
            "tool_ok": tool_ok,
            "valid_call": bool(parsed),
            "execution_ok": execution_ok,
            "error": error,
            "execution_output": execution.get("output") if execution else None,
        })
    total = len(results) or 1
    return {
        "items": results,
        "summary": {
            "count": len(results),
            "tool_accuracy": sum(1 for item in results if item["tool_ok"]) / total,
            "valid_call_rate": sum(1 for item in results if item["valid_call"]) / total,
            "execution_pass_rate": sum(1 for item in results if item["execution_ok"]) / total,
        },
    }


def status_payload() -> dict[str, Any]:
    checkpoints = sorted(CHECKPOINT_DIR.glob("*.pt"), key=lambda item: item.stat().st_mtime, reverse=True)
    torch_available = torch is not None
    cuda_available = is_cuda_available()
    device, device_reason = resolve_device()
    cuda_device_count = torch.cuda.device_count() if torch_available else 0
    return {
        "ok": True,
        "python": sys.version.split()[0],
        "torch_available": torch_available,
        "torch_version": getattr(torch, "__version__", None) if torch_available else None,
        "torch_cuda_version": torch_cuda_version(),
        "cuda_available": cuda_available,
        "cuda_device": torch.cuda.get_device_name(0) if cuda_available else None,
        "cuda_device_count": cuda_device_count,
        "gpu_preferred": DEVICE_POLICY == "auto",
        "device_policy": DEVICE_POLICY,
        "device": device,
        "device_reason": device_reason,
        "cuda_unavailable_reason": None if cuda_available else cuda_unavailable_reason(),
        "working_dir": str(RUN_DIR),
        "active_job_id": ACTIVE_JOB_ID,
        "checkpoints": [str(item) for item in checkpoints[:10]],
        "tools": list(TOOLS),
    }


class ToolCallTinyGPTHandler(BaseHTTPRequestHandler):
    def log_message(self, format: str, *args: Any) -> None:
        print(f"[toolcalltinygpt-runner] {self.address_string()} - {format % args}")

    def do_OPTIONS(self) -> None:
        respond(self, 200, {"ok": True})

    def do_GET(self) -> None:
        parsed = urlparse(self.path)
        path = parsed.path.rstrip("/") or "/"
        if path in {"/", "/status", "/health"}:
            respond(self, 200, status_payload())
            return
        if path.startswith("/jobs/"):
            job_id = path.split("/")[-1]
            with LOCK:
                job = JOBS.get(job_id)
            if not job:
                respond(self, 404, {"message": "Job not found."})
                return
            respond(self, 200, job)
            return
        respond(self, 404, {"message": "Unknown route."})

    def do_POST(self) -> None:
        global ACTIVE_JOB_ID
        parsed = urlparse(self.path)
        path = parsed.path.rstrip("/") or "/"
        try:
            if path == "/train":
                body = read_body(self)
                with LOCK:
                    if ACTIVE_JOB_ID and JOBS.get(ACTIVE_JOB_ID, {}).get("status") in {"queued", "running"}:
                        respond(self, 409, {"message": f"Training job already running: {ACTIVE_JOB_ID}"})
                        return
                    job_id = uuid.uuid4().hex[:12]
                    job = {"id": job_id, "status": "queued", "created_at": utc_now(), "logs": [], "metrics": []}
                    JOBS[job_id] = job
                    STOP_EVENTS[job_id] = threading.Event()
                    ACTIVE_JOB_ID = job_id
                thread = threading.Thread(target=run_training, args=(job_id, body), daemon=True)
                thread.start()
                respond(self, 200, job)
                return

            if path.startswith("/jobs/") and path.endswith("/stop"):
                parts = path.split("/")
                job_id = parts[2] if len(parts) >= 3 else ""
                if job_id in STOP_EVENTS:
                    STOP_EVENTS[job_id].set()
                with LOCK:
                    job = JOBS.get(job_id)
                if not job:
                    respond(self, 404, {"message": "Job not found."})
                    return
                respond(self, 200, job)
                return

            if path == "/generate":
                respond(self, 200, generate_call(read_body(self)))
                return

            if path == "/execute_call":
                respond(self, 200, execute_tool_call(read_body(self)))
                return

            if path == "/evaluate":
                respond(self, 200, evaluate_checkpoint(read_body(self)))
                return

            respond(self, 404, {"message": "Unknown route."})
        except Exception as error:
            respond(self, 500, {"message": str(error), "traceback": traceback.format_exc()})


def main() -> None:
    global DEVICE_POLICY
    parser = argparse.ArgumentParser(description="Local Tool-Calling TinyGPT training runner for AlgoLab.")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=4888)
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "gpu", "cpu"],
        default="auto",
        help="auto/cuda/gpu prioritizes CUDA GPU and falls back to CPU when CUDA is unavailable; cpu forces CPU.",
    )
    args = parser.parse_args()
    DEVICE_POLICY = "auto" if args.device in {"auto", "cuda", "gpu"} else "cpu"
    server = ThreadingHTTPServer((args.host, args.port), ToolCallTinyGPTHandler)
    print(f"Tool-Calling TinyGPT local runner listening on http://{args.host}:{args.port}")
    print(f"Working directory: {RUN_DIR}")
    print(f"Device policy: {DEVICE_POLICY}")
    device, device_reason = resolve_device()
    print(f"Resolved device: {device} ({device_reason})")
    print("Press Ctrl+C to stop.")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nStopping Tool-Calling TinyGPT local runner.")
    finally:
        server.server_close()


if __name__ == "__main__":
    main()
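
The runner above serves its routes over plain HTTP, by default on 127.0.0.1:4888. The following is a minimal client sketch, not part of the runner itself. It assumes the runner is already started locally and that the `read_body` and `respond` helpers (not shown in this section) exchange JSON. The names `build_request`, `call_runner`, and `demo` are hypothetical.

```python
import json
import urllib.request
from typing import Optional

BASE_URL = "http://127.0.0.1:4888"  # the runner's default --host and --port


def build_request(route: str, payload: Optional[dict] = None) -> urllib.request.Request:
    """Build a GET request when payload is None, otherwise a JSON POST."""
    url = f"{BASE_URL}{route}"
    if payload is None:
        return urllib.request.Request(url, method="GET")
    data = json.dumps(payload).encode("utf-8")
    headers = {"Content-Type": "application/json"}
    return urllib.request.Request(url, data=data, headers=headers, method="POST")


def call_runner(route: str, payload: Optional[dict] = None, timeout: float = 30.0) -> dict:
    """Send the request and decode the JSON response (needs a live runner)."""
    with urllib.request.urlopen(build_request(route, payload), timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


def demo() -> None:
    """status -> generate -> execute -> evaluate; needs a live runner and a checkpoint."""
    print(call_runner("/status")["device"])
    generated = call_runner("/generate", {"user": "calculate 23 plus 17 times 4"})
    if generated["parsed"]:
        print(call_runner("/execute_call", {"call": generated["call"]})["output"])
    print(call_runner("/evaluate", {"max_items": 6})["summary"])
```

Error responses from the handler (404, 409, 500) surface in this client as `urllib.error.HTTPError`, so a caller would typically wrap `call_runner` in a `try`/`except` when no checkpoint has been trained yet.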
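Next step

Next step: from a single CALL to CALL -> OBSERVATION -> FINAL

This module deliberately handles only a single CALL. Once tool selection and JSON arguments are stable, the next stage can train multi-segment trajectories: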

USER: compare x^2 and sin(x)
CALL plot {...}
OBSERVATION: plot artifact path + summary
FINAL: explain what the plot shows
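
One way such a trajectory could be turned into a single training string is sketched below. This is only an illustration of the idea: the `<observation>` and `<final>` tags are hypothetical and simply mirror the `<user>` / `<call>` tags that `generate_call` uses, and neither `format_trajectory` nor `split_trajectory` exists in the current runner.

```python
def format_trajectory(user: str, call: str, observation: str, final: str) -> str:
    """Serialize one USER -> CALL -> OBSERVATION -> FINAL trajectory (hypothetical format)."""
    segments = [
        ("user", user),
        ("call", call),
        ("observation", observation),  # produced by the tool, not by the model
        ("final", final),
    ]
    return "".join(f"<{tag}>\n{text.strip()}\n</{tag}>\n" for tag, text in segments)


def split_trajectory(text: str) -> dict[str, str]:
    """Recover the tagged segments from a serialized trajectory."""
    parts: dict[str, str] = {}
    for tag in ("user", "call", "observation", "final"):
        start_marker, end_marker = f"<{tag}>\n", f"\n</{tag}>"
        start = text.find(start_marker)
        end = text.find(end_marker, start)
        if start >= 0 and end >= 0:
            parts[tag] = text[start + len(start_marker):end]
    return parts
```

In a setup like this, the observation segment would be filled in by the local runner after executing the CALL, so the model would only need to generate the call and the final answer.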

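This keeps the course storyline clear: No.6 explains next-token prediction, No.7 trains task-to-code, No.8 trains tool calling, and later modules move on to ReAct, multi-tool orchestration, and feedback-driven agents.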

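Principle

Why next-token prediction can grow code and tool-calling abilities

Technically, a decoder-only GPT does not learn an explicit function f(x)=y; it estimates the conditional distribution pθ(tᵢ | t<i). Training factorizes the probability of the whole text into a next-token probability at each position. A falling cross entropy means the model assigns higher probability to the token that actually follows each context.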

Training string: <user>u</user> <call>g</call>
Objective: maximize pθ(t1:T) = ∏ pθ(tᵢ | t<i)
Inference: given prefix = <user>u</user><call>, predict the next token in a loop -> g = CALL tool {...}
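
To make the factorization concrete, the sketch below builds one training string, lists the (context, next character) pairs a character-level model is trained on, and scores the string under a toy model. The string layout is inferred from the prompt in `generate_call` and the markers in `extract_call`, so the real dataset builder may differ. The uniform `toy_next_char_prob` is a stand-in assumption, not the TinyGPT model.

```python
import math


def build_training_string(user: str, call: str) -> str:
    """Layout inferred from generate_call/extract_call; the dataset builder may differ."""
    return f"<user>\n{user}\n</user>\n<call>\n{call}\n</call>\n"


def next_token_pairs(text: str) -> list[tuple[str, str]]:
    """Every position i contributes one example: context t_<i -> target t_i."""
    return [(text[:i], text[i]) for i in range(1, len(text))]


def sequence_log_prob(text: str, next_char_prob) -> float:
    """log p(t_1:T) = sum_i log p(t_i | t_<i); the first character is taken as given."""
    return sum(math.log(next_char_prob(context, target)) for context, target in next_token_pairs(text))


text = build_training_string(
    "calculate 23 plus 17 times 4",
    'CALL calculator {"expression":"23 + 17 * 4"}',
)
vocab = sorted(set(text))  # character-level vocabulary of this one sample


def toy_next_char_prob(context: str, target: str) -> float:
    # Stand-in model: uniform over the vocabulary, ignoring the context.
    return 1.0 / len(vocab)


pairs = next_token_pairs(text)
log_prob = sequence_log_prob(text, toy_next_char_prob)
mean_cross_entropy = -log_prob / len(pairs)  # equals log(len(vocab)) for a uniform model
```

A model that knows nothing sits at roughly log(vocab_size) cross entropy. Training lowers the loss by moving probability mass toward the characters that actually follow each context, which for this data means the CALL syntax.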
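| Level | What next-token prediction learns | What it maps to in this module |
| --- | --- | --- |
| Representation | Transformer self-attention encodes the preceding tokens into context-dependent representations, so the output distribution depends on the full prompt. | <user>...</user><call> tells the model that the tool-calling space comes next. |
| Objective | Code, JSON, and CALL instructions are all just token sequences; as long as they appear consistently in the training set, the next-token loss rewards these structures. | plot ... is followed with high probability by CALL plot {...}; search ... is followed with high probability by CALL search {...}. |
| Conditioning | The model parameters need not change; the task is specified by the text prefix. This is the text interaction / few-shot conditioning emphasized in the GPT-3 paper. | At inference the full answer is not given, only the prefix up to <call>; the model unrolls the learned conditional distribution into a CALL. |
| Action | The language model itself executes no tool; it generates a symbolic action that an external system parses and executes. | CALL calculator goes to the calculator, CALL plot to the plotter, CALL search to the retriever. |
| Limits | Next-token training only guarantees output that is distributionally similar to the training data; it does not guarantee exact arithmetic, exact copying of numbers, or always-valid JSON. | So evaluation cannot rely on loss alone; it must also track tool accuracy, valid CALL rate, and execution pass rate. |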
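
The valid CALL rate in the last row can be illustrated with a small structural check. This is a hedged sketch: the runner's own `parse_call` is defined elsewhere and may apply different rules, and `check_call_line`, `valid_call_rate`, and `KNOWN_TOOLS` are hypothetical names used only here.

```python
import json

KNOWN_TOOLS = ("calculator", "plot", "search")


def check_call_line(line: str) -> tuple[bool, str]:
    """Return (is_valid, reason); a hypothetical stand-in for the runner's parse_call."""
    head, _, rest = line.strip().partition(" ")
    if head != "CALL":
        return False, "missing CALL prefix"
    tool, _, payload = rest.partition(" ")
    if tool not in KNOWN_TOOLS:
        return False, f"unknown tool: {tool}"
    try:
        args = json.loads(payload)
    except json.JSONDecodeError:
        return False, "arguments are not valid JSON"
    if not isinstance(args, dict):
        return False, "arguments must be a JSON object"
    return True, "ok"


def valid_call_rate(lines: list[str]) -> float:
    """Fraction of generated lines that pass the structural check."""
    return sum(check_call_line(line)[0] for line in lines) / max(len(lines), 1)
```

Keeping the failure reason separate from the boolean makes it easier to tell a routing error (wrong or unknown tool) from a formatting error (truncated or malformed JSON) when analyzing a checkpoint.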

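This is the shared idea behind work such as Toolformer, ReAct, and MRKL: expose external capabilities as text actions or API calls, let the language model choose the action and its arguments, and let deterministic modules execute it. The TinyGPT version simply scales the idea down to three local tools.

References: Attention Is All You Need; Language Models are Few-Shot Learners; Scaling Laws for Neural Language Models; Toolformer; ReAct; MRKL Systems

Ask the LLM: how to turn TinyGPT into a tool-calling model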