In this tutorial, we build an advanced Tree-of-Thoughts (ToT) multi-branch reasoning agent from scratch. Instead of relying on linear chain-of-thought reasoning, we design a system that generates multiple reasoning branches, scores each branch with a heuristic evaluation function, prunes weak candidates, and continues expanding only the strongest paths. We combine an instruction-tuned Transformer model with a custom tree structure and implement beam-search-style selection with depth-limited search. By grounding the system in the 24 game, we create a clear, objective benchmark for reasoning where we can observe branch expansion, pruning, scoring, and goal detection in action.
!pip -q install -U transformers accelerate sentencepiece
import re
import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MODEL_NAME = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print("Device:", device)
print("Model loaded:", MODEL_NAME)
@dataclass
class Node:
    depth: int
    numbers: List[float]
    exprs: List[str]
    thought: str = ""
    score: float = -1e9
    is_goal: bool = False
    parent: Optional["Node"] = None
    meta: Dict[str, Any] = field(default_factory=dict)

def pretty_state(nums: List[float], exprs: List[str]) -> str:
    pairs = [f"{e}={n:g}" for e, n in zip(exprs, nums)]
    return " | ".join(pairs)
We install the required libraries and load the FLAN-T5 model with the correct Seq2Seq architecture. We define the basic node data structure that represents each reasoning state in the tree-of-thoughts search, and we initialize the device configuration along with a helper utility that lets us clearly print and inspect the current state.
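As a quick sanity check, the `pretty_state` helper can be exercised on its own. This sketch repeats the helper inline so it runs standalone, and prints a fresh 24-game state where each number starts as its own expression:

```python
from typing import List

# Same helper as above, repeated here so the snippet runs on its own.
def pretty_state(nums: List[float], exprs: List[str]) -> str:
    pairs = [f"{e}={n:g}" for e, n in zip(exprs, nums)]
    return " | ".join(pairs)

# A fresh 24-game state: each number starts as its own expression.
print(pretty_state([4.0, 1.0, 8.0, 7.0], ["4", "1", "8", "7"]))
# -> 4=4 | 1=1 | 8=8 | 7=7
```

As branches combine numbers, the expressions grow (e.g. `(4 * 7)=28`), so this one-line view makes each beam state easy to scan in the logs.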
OPS = ("+", "-", "*", "/")

def safe_apply(a: float, b: float, op: str) -> Optional[float]:
    if op == "+": return a + b
    if op == "-": return a - b
    if op == "*": return a * b
    if op == "/":
        if abs(b) < 1e-12:
            return None
        return a / b
    return None

def combine_expr(ea: str, eb: str, op: str) -> str:
    return f"({ea} {op} {eb})"

def is_24(x: float, tol: float = 1e-6) -> bool:
    return abs(x - 24.0) <= tol

def one_step_closeness(nums: List[float]) -> float:
    if len(nums) == 1:
        return abs(nums[0] - 24.0)
    best = float("inf")
    n = len(nums)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            a, b = nums[i], nums[j]
            for op in OPS:
                r = safe_apply(a, b, op)
                if r is None:
                    continue
                best = min(best, abs(r - 24.0))
    return best if best != float("inf") else 1e9

def heuristic_score(node: Node) -> float:
    nums = node.numbers
    base = -one_step_closeness(nums)
    depth_penalty = 0.05 * node.depth
    exact_bonus = 2.0 if any(is_24(x) for x in nums) else 0.0
    return base - depth_penalty + exact_bonus
We implement the arithmetic core of the 24-game domain. We define safe operator execution, expression construction, goal checking, and a heuristic scoring function that estimates how close a state is to the goal value of 24. We design the heuristic to guide the search intelligently while penalizing deep branches.
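To see the distance signal the heuristic relies on, we can compute the one-step closeness for a couple of states. This standalone sketch repeats the relevant helpers so it runs on its own:

```python
from typing import List, Optional

OPS = ("+", "-", "*", "/")

def safe_apply(a: float, b: float, op: str) -> Optional[float]:
    # Apply one arithmetic op, refusing division by (near-)zero.
    if op == "+": return a + b
    if op == "-": return a - b
    if op == "*": return a * b
    if op == "/": return a / b if abs(b) > 1e-12 else None
    return None

def one_step_closeness(nums: List[float]) -> float:
    # Best |result - 24| reachable by combining any ordered pair once.
    if len(nums) == 1:
        return abs(nums[0] - 24.0)
    best = float("inf")
    for i in range(len(nums)):
        for j in range(len(nums)):
            if i == j:
                continue
            for op in OPS:
                r = safe_apply(nums[i], nums[j], op)
                if r is not None:
                    best = min(best, abs(r - 24.0))
    return best if best != float("inf") else 1e9

print(one_step_closeness([4.0, 6.0]))            # 4 * 6 = 24, so closeness is 0.0
print(one_step_closeness([4.0, 1.0, 8.0, 7.0]))  # best single move is 4 * 7 = 28, closeness 4.0
```

States whose best single combination lands near 24 score higher (the heuristic negates this distance), which is exactly what steers the beam toward promising branches.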
PROPOSER_PROMPT = """You are helping solve the 24 game.
We have current items, each item has an expression and its numeric value.
Pick TWO items and combine them with one operation from + - * / to create a new item.
Return between {k} and {k2} suggestions as lines using EXACT format:
i,j,op
Where i and j are 0-based indices into the list. Use i != j. Prefer moves that help reach 24.
Current items:
{items}
"""
def llm_generate_suggestions(items: str, k_min: int, k_max: int, max_new_tokens: int = 160) -> str:
    prompt = PROPOSER_PROMPT.format(k=k_min, k2=k_max, items=items)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.92,
            num_return_sequences=1,
        )
    txt = tokenizer.decode(out[0], skip_special_tokens=True)
    return txt.strip()
def parse_moves(text: str, n_items: int) -> List[Tuple[int, int, str]]:
    moves = []
    for line in text.splitlines():
        line = line.strip()
        m = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*,\s*([+\-*/])\s*$", line)
        if not m:
            continue
        i, j, op = int(m.group(1)), int(m.group(2)), m.group(3)
        if 0 <= i < n_items and 0 <= j < n_items and i != j:
            moves.append((i, j, op))
    seen = set()
    uniq = []
    for mv in moves:
        if mv not in seen:
            uniq.append(mv)
            seen.add(mv)
    return uniq
def fallback_moves(nums: List[float], limit: int = 24) -> List[Tuple[int, int, str]]:
    scored = []
    n = len(nums)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            for op in OPS:
                r = safe_apply(nums[i], nums[j], op)
                if r is None:
                    continue
                scored.append((abs(r - 24.0), i, j, op))
    scored.sort(key=lambda x: x[0])
    out = [(i, j, op) for _, i, j, op in scored[:limit]]
    seen, uniq = set(), []
    for mv in out:
        if mv not in seen:
            uniq.append(mv)
            seen.add(mv)
    return uniq
We build an LLM proposer that generates multiple reasoning branches. We format the prompt carefully so the model returns structured combination operations, and we parse those outputs into executable moves. We also implement a deterministic fallback strategy so the search remains robust even when the model output is noisy.
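The parser's behavior on noisy model output is easy to check in isolation. This sketch repeats `parse_moves` inline so it runs standalone, and feeds it a mix of valid, malformed, out-of-range, and duplicate lines:

```python
import re
from typing import List, Tuple

def parse_moves(text: str, n_items: int) -> List[Tuple[int, int, str]]:
    # Keep only well-formed "i,j,op" lines with valid, distinct indices; de-duplicate.
    moves = []
    for line in text.splitlines():
        m = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*,\s*([+\-*/])\s*$", line.strip())
        if not m:
            continue
        i, j, op = int(m.group(1)), int(m.group(2)), m.group(3)
        if 0 <= i < n_items and 0 <= j < n_items and i != j:
            moves.append((i, j, op))
    seen, uniq = set(), []
    for mv in moves:
        if mv not in seen:
            uniq.append(mv)
            seen.add(mv)
    return uniq

raw = "0,3,*\nsome chatter\n1,1,+\n9,0,-\n0,3,*\n2, 1, /"
print(parse_moves(raw, n_items=4))
# -> [(0, 3, '*'), (2, 1, '/')]
```

The chatter line fails the regex, `1,1,+` is rejected because the indices are equal, `9,0,-` is out of range, and the repeated `0,3,*` is de-duplicated, so only two executable moves survive.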
def apply_move(node: Node, i: int, j: int, op: str) -> Optional[Node]:
    nums = node.numbers[:]
    exprs = node.exprs[:]
    a, b = nums[i], nums[j]
    r = safe_apply(a, b, op)
    if r is None:
        return None
    ea, eb = exprs[i], exprs[j]
    new_expr = combine_expr(ea, eb, op)
    for idx in sorted((i, j), reverse=True):
        nums.pop(idx)
        exprs.pop(idx)
    nums.append(r)
    exprs.append(new_expr)
    child = Node(
        depth=node.depth + 1,
        numbers=nums,
        exprs=exprs,
        parent=node,
        thought=f"Combine item {i} and {j} with '{op}' -> {new_expr} = {r:g}",
    )
    child.is_goal = (len(nums) == 1 and is_24(nums[0]))
    child.score = heuristic_score(child)
    return child

def expand(node: Node, branch_factor: int, proposer_kmin: int = 8, proposer_kmax: int = 14) -> List[Node]:
    items_str = "\n".join(f"{idx}: {node.exprs[idx]} = {node.numbers[idx]:g}" for idx in range(len(node.numbers)))
    raw = llm_generate_suggestions(items_str, proposer_kmin, proposer_kmax)
    moves = parse_moves(raw, len(node.numbers))
    if not moves:
        moves = fallback_moves(node.numbers, limit=30)
    moves = moves[:max(branch_factor * 2, branch_factor)]
    children = []
    for (i, j, op) in moves:
        ch = apply_move(node, i, j, op)
        if ch is not None:
            children.append(ch)
    children.sort(key=lambda x: x.score, reverse=True)
    return children[:branch_factor]
We implement the branch expansion mechanism of the Tree-of-Thoughts algorithm. We apply the proposed moves to create new child nodes and compute their heuristic scores. We then prune weak branches locally, keeping only the strongest candidates for further exploration.
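The state transition at the heart of expansion can be illustrated without the Node class. In this simplified sketch, `combine_state` is a hypothetical stand-in for `apply_move`'s bookkeeping: it removes the two chosen items and appends the combined result to both the number list and the expression list:

```python
from typing import List, Optional, Tuple

def combine_state(nums: List[float], exprs: List[str], i: int, j: int,
                  op: str) -> Optional[Tuple[List[float], List[str]]]:
    # Mirror apply_move: drop items i and j, append their combination.
    ops = {"+": lambda a, b: a + b, "-": lambda a, b: a - b,
           "*": lambda a, b: a * b,
           "/": lambda a, b: a / b if abs(b) > 1e-12 else None}
    r = ops[op](nums[i], nums[j])
    if r is None:
        return None  # illegal move (division by zero)
    new_nums = [n for k, n in enumerate(nums) if k not in (i, j)] + [r]
    new_exprs = [e for k, e in enumerate(exprs) if k not in (i, j)] + [f"({exprs[i]} {op} {exprs[j]})"]
    return new_nums, new_exprs

# Combine items 0 and 3 of [4, 1, 8, 7] with '*': 4 * 7 = 28.
print(combine_state([4.0, 1.0, 8.0, 7.0], ["4", "1", "8", "7"], 0, 3, "*"))
# -> ([1.0, 8.0, 28.0], ['1', '8', '(4 * 7)'])
```

Each move shrinks the state by one item, so a four-number 24-game puzzle is solved after exactly three combinations, which is why `max_depth=3` suffices in the search below.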
def reconstruct_solution(goal: Node) -> List[str]:
    path = []
    cur = goal
    while cur is not None:
        if cur.thought:
            path.append(cur.thought)
        cur = cur.parent
    return list(reversed(path))

def tot_solve_24(
    start_nums: List[int],
    beam_width: int = 10,
    branch_factor: int = 8,
    max_depth: int = 3,
    prune_threshold: float = -10.0,
    verbose: bool = True
) -> Dict[str, Any]:
    root = Node(
        depth=0,
        numbers=[float(x) for x in start_nums],
        exprs=[str(x) for x in start_nums],
    )
    root.score = heuristic_score(root)
    beam = [root]
    best_seen = root
    if verbose:
        print("\n=== ToT Search Start ===")
        print("Start:", pretty_state(root.numbers, root.exprs))
        print("Root score:", root.score)
    for d in range(max_depth):
        candidates: List[Node] = []
        if verbose:
            print(f"\n--- Depth {d} -> {d+1} expansion ---")
            print("Beam states:")
            for bidx, b in enumerate(beam[:min(len(beam), 6)]):
                print(f"  [{bidx}] score={b.score:.3f} | {pretty_state(b.numbers, b.exprs)}")
        for b in beam:
            kids = expand(b, branch_factor=branch_factor)
            candidates.extend(kids)
        if not candidates:
            break
        candidates = [c for c in candidates if c.score >= prune_threshold]
        goals = [c for c in candidates if c.is_goal]
        if goals:
            goals.sort(key=lambda x: x.score, reverse=True)
            sol = goals[0]
            steps = reconstruct_solution(sol)
            return {
                "solved": True,
                "start": start_nums,
                "expression": sol.exprs[0],
                "value": sol.numbers[0],
                "steps": steps,
                "final_score": sol.score
            }
        candidates.sort(key=lambda x: x.score, reverse=True)
        beam = candidates[:beam_width]
        if beam and beam[0].score > best_seen.score:
            best_seen = beam[0]
        if verbose:
            print("Top candidates after pruning/beam:")
            for cidx, c in enumerate(beam[:min(len(beam), 6)]):
                print(f"  [{cidx}] score={c.score:.3f} | {pretty_state(c.numbers, c.exprs)}")
    best_expr = best_seen.exprs[0] if len(best_seen.exprs) == 1 else " ; ".join(best_seen.exprs)
    best_val = best_seen.numbers[0] if len(best_seen.numbers) == 1 else None
    return {
        "solved": False,
        "start": start_nums,
        "best_state": pretty_state(best_seen.numbers, best_seen.exprs),
        "best_expression": best_expr,
        "best_value": best_val,
        "final_score": best_seen.score,
        "note": "Not solved within depth/beam limits; increase beam_width/branch_factor or adjust pruning."
    }
tests = [
    [4, 1, 8, 7],
    [3, 3, 8, 8],
    [6, 6, 6, 6],
    [9, 9, 4, 4],
]

for nums in tests:
    result = tot_solve_24(
        nums,
        beam_width=12,
        branch_factor=10,
        max_depth=3,
        prune_threshold=-12.0,
        verbose=True
    )
    print("\n=== RESULT ===")
    for k, v in result.items():
        if k == "steps":
            print("steps:")
            for s in v:
                print("  -", s)
        else:
            print(f"{k}: {v}")
    print("\n" + "=" * 80 + "\n")
print("""
To adapt this ToT agent beyond the 24 game:
1) Define a STATE representation (like numbers/exprs here).
2) Define a PROPOSER that generates candidate next steps (LLM tool or rule-based).
3) Define a HEURISTIC / SCORER:
- for checkable tasks, use objective scoring
- for open-ended tasks, use an LLM-critic scoring rubric
4) Run the same ToT loop:
expand -> score -> prune -> keep top beam -> repeat until goal or depth limit.
""")
We implement a full tree-of-thoughts search loop using beam search and depth limits. We expand, score, prune, and select the top branches at each depth until we reach a solution or the search budget is exhausted. Finally, we reconstruct the reasoning path and demonstrate how the agent solves several 24-game examples step by step.
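The adaptation recipe printed above can be condensed into a domain-agnostic loop. In this sketch, `tot_search`, `expand_fn`, `score_fn`, and `is_goal_fn` are hypothetical names, not part of the tutorial's code; the toy example searches for 24 starting from 1 using only +1 and ×2 moves:

```python
from typing import Callable, List, Optional, TypeVar

S = TypeVar("S")  # any state representation your domain needs

def tot_search(root: S,
               expand_fn: Callable[[S], List[S]],
               score_fn: Callable[[S], float],
               is_goal_fn: Callable[[S], bool],
               beam_width: int = 4,
               max_depth: int = 6) -> Optional[S]:
    # The same loop as tot_solve_24: expand -> check goals -> score -> keep top beam.
    beam = [root]
    for _ in range(max_depth):
        candidates = [child for state in beam for child in expand_fn(state)]
        goals = [c for c in candidates if is_goal_fn(c)]
        if goals:
            return goals[0]
        beam = sorted(candidates, key=score_fn, reverse=True)[:beam_width]
        if not beam:
            break
    return None

# Toy domain: states are ints, moves are +1 and *2, goal is exactly 24.
result = tot_search(1, lambda x: [x + 1, x * 2], lambda x: -abs(x - 24), lambda x: x == 24)
print(result)
# -> 24
```

Swapping in an LLM-based `expand_fn` and an LLM-critic `score_fn` recovers the agent built in this tutorial; the loop itself never changes.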
Finally, we built a complete multi-branch reasoning agent that demonstrates how Tree-of-Thoughts transforms LLM reasoning from a single linear path into a structured search process. We implemented branch generation, heuristic scoring, pruning, beam selection, and depth control in a modular architecture that can easily be adapted to other logic problems. Along the way, we saw how combining language models with search algorithms significantly improves structured decision making. We now have a reusable ToT framework that we can extend to mathematical reasoning, planning tasks, symbolic search, or even LLM-critic-based evaluation systems.
