Claude Code Guide for tinygrad

Architecture Overview

tinygrad compiles tensor operations into optimized kernels. The pipeline:

  1. Tensor (tensor.py) - User-facing API, creates UOp graph
  2. UOp (uop/ops.py) - Unified IR for all operations (both tensor and kernel level)
  3. Schedule (engine/schedule.py, schedule/) - Converts tensor UOps to kernel UOps
  4. Codegen (codegen/) - Converts kernel UOps to device code
  5. Runtime (runtime/) - Device-specific execution
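
A minimal sketch of this flow from the user side (illustrative only; the intermediate stages run automatically when the tensor is realized):

from tinygrad import Tensor

t = (Tensor.ones(4) + 2).sum()  # step 1: builds a tensor-level UOp graph; nothing executes yet
print(t.item())                 # steps 2-5: schedule, codegen, and run on the default device -> 12.0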

Key Concepts

UOp (Universal Operation)

Everything is a UOp - tensors, operations, buffers, kernels. Key properties:

  • op: The operation type (Ops enum)
  • dtype: Data type
  • src: Tuple of source UOps
  • arg: Operation-specific argument
  • tag: Optional tag for graph transformations

UOps are immutable and cached - creating the same UOp twice returns the same object (ucache).
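
For example, these fields can be inspected on any tensor's underlying UOp (output varies with the op):

from tinygrad import Tensor

u = (Tensor.ones(4) + 1).uop     # the Tensor's underlying UOp
print(u.op, u.dtype, u.arg)      # operation type, dtype, op-specific argument
print([s.op for s in u.src])     # ops of the source UOps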

PatternMatcher

Used extensively for graph transformations:

pm = PatternMatcher([
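  # the repeated name "x" means both sources must bind to the same UOp (cvar additionally requires a constant); rewrites x+x -> x*2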
  (UPat(Ops.ADD, src=(UPat.cvar("x"), UPat.cvar("x"))), lambda x: x * 2),
])
result = graph_rewrite(uop, pm)

Schedule Cache

Schedules are cached by graph structure. BIND nodes (variables with bound values) are unbound before cache key computation so different values hit the same cache.

Testing

# Run specific test
python -m pytest test/unit/test_schedule_cache.py -xvs

# Run with timeout
python -m pytest test/test_symbolic_ops.py -x --timeout=60

# Debug with print
DEBUG=2 python -m pytest test/test_schedule.py::test_name -xvs

# Visualize UOp graphs
VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"

Common Environment Variables

  • DEBUG=1-7 - Increasing verbosity (7 shows assembly output)
  • VIZ=1 - Enable graph visualization
  • SPEC=1 - Enable UOp spec verification
  • NOOPT=1 - Disable optimizations
  • DEVICE=CPU/CUDA/AMD/METAL - Set default device

Debugging Tips

  1. Print UOp graphs: print(tensor.uop) or print(tensor.uop.sink())
  2. Check schedule: tensor.schedule() returns list of ExecItems
  3. Trace graph rewrites: Use VIZ=1 or add print in PatternMatcher callbacks
  4. Find UOps by type: [u for u in uop.toposort() if u.op is Ops.SOMETHING]
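
Pulling tips 1, 2, and 4 together (REDUCE_AXIS is just an example op to search for):

from tinygrad import Tensor
from tinygrad.uop.ops import Ops

t = Tensor.ones(8).sum()
print(t.uop)                                                     # tip 1: dump the tensor-level graph
print([u for u in t.uop.toposort() if u.op is Ops.REDUCE_AXIS])  # tip 4: find UOps by type
print(len(t.schedule()))                                         # tip 2: number of scheduled items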

Workflow Rules

  • NEVER commit without explicit user approval - always show the diff and wait for approval
  • NEVER amend commits - always create a new commit instead
  • Run pre-commit run --all-files before committing to catch linting/type errors
  • Run tests before proposing commits
  • Test with SPEC=2 when modifying UOp-related code

Auto-generated Files (DO NOT EDIT)

The following files are auto-generated and should never be edited manually:

  • extra/assembly/amd/autogen/{arch}/__init__.py - Generated by python -m extra.assembly.amd.dsl --arch {arch}
  • extra/assembly/amd/autogen/{arch}/gen_pcode.py - Generated by python -m extra.assembly.amd.pcode --arch {arch}

Where {arch} is one of: rdna3, rdna4, cdna

To add missing instruction implementations, add them to extra/assembly/amd/emu.py instead.

Style Notes

  • 2-space indentation, 150 char line limit
  • PatternMatchers should be defined at module level (slow to construct)
  • Prefer graph_rewrite over manual graph traversal
  • UOp methods like .replace() preserve tags unless explicitly changed
  • Use .rtag(value) to add tags to UOps

Lessons Learned

UOp ucache Behavior

UOps are cached by their contents - creating a UOp with identical (op, dtype, src, arg) returns the same object. This means:

  • uop.replace(tag=None) on a tagged UOp returns the original untagged UOp if it exists in cache
  • Two UOps with same structure are identical (is comparison works)
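
A quick check of this behavior, using the replace/tag semantics noted in Style Notes:

from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

a = UOp(Ops.CONST, dtypes.int, arg=1)
b = UOp(Ops.CONST, dtypes.int, arg=1)
assert a is b                                   # same (op, dtype, src, arg) -> same cached object
assert a.replace(tag=1).replace(tag=None) is a  # untagging returns the cached original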

Spec Validation

When adding new UOp patterns, update tinygrad/uop/spec.py. Test with:

SPEC=2 python3 test/unit/test_something.py

Spec issues appear as RuntimeError: SPEC ISSUE None: UOp(...).
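
Spec rules are PatternMatcher entries that should return True for a well-formed UOp; a hypothetical new rule (the BIND arity check here is only an illustration, not an actual spec entry) looks like:

from tinygrad.uop.ops import UPat, PatternMatcher, Ops

# hypothetical spec rule: a BIND must have a DEFINE_VAR first source and at most one value source
extra_spec = PatternMatcher([
  (UPat(Ops.BIND, name="b"), lambda b: len(b.src) in (1, 2) and b.src[0].op is Ops.DEFINE_VAR),
])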

Schedule Cache Key Normalization

The schedule cache strips values from BIND nodes so different bound values (e.g., KV cache positions) hit the same cache entry:

  • pm_pre_sched_cache: BIND(DEFINE_VAR, CONST) → BIND(DEFINE_VAR) for the cache key (sketched after this list)
  • pm_post_sched_cache: restores original BIND from context
  • When accessing bind.src[1], check len(bind.src) > 1 first (might be stripped)
  • Extract var_vals from input_buffers dict after graph_rewrite (avoids extra toposort)
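
A minimal sketch of the stripping side (illustrative names and pattern; not the actual pm_pre_sched_cache source):

from tinygrad.uop.ops import UPat, PatternMatcher, Ops

# drop the bound value from BIND so the cache key depends only on the variable, not its current value
strip_bind = PatternMatcher([
  (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR, name="v"), UPat(Ops.CONST)), name="b"),
   lambda b, v: b.replace(src=(v,))),
])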

Avoiding Extra Work

  • Use the ctx dict from graph_rewrite to collect info during traversal instead of a separate toposort (see the sketch after this list)
  • Only extract var_vals when schedule is non-empty (no kernels = no vars needed)
  • PatternMatchers are slow to construct - define at module level, not in functions
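
A sketch of the ctx-collection idea from the first bullet (the BUFFER scan is illustrative, not an actual tinygrad pass):

from tinygrad import Tensor
from tinygrad.uop.ops import UPat, PatternMatcher, Ops, graph_rewrite

# collect every BUFFER UOp while the rewrite walks the graph; returning None means "no rewrite"
collect_bufs = PatternMatcher([
  (UPat(Ops.BUFFER, name="b"), lambda ctx, b: ctx.append(b)),
])
found = []
graph_rewrite(Tensor.ones(4).realize().uop.sink(), collect_bufs, ctx=found)
print(len(found))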

Readability Over Speed

Don't add complexity for marginal performance gains. Simpler code that's slightly slower is often better:

# BAD: "optimized" with extra complexity
if has_afters:  # skip toposort if no AFTERs
  after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]

# GOOD: simple, always works
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]

The conditional check adds complexity, potential bugs, and often negligible speedup. Only optimize when profiling shows a real bottleneck.

Testing LLM Changes

# Quick smoke test
echo "Hello" | DEBUG=1 python tinygrad/apps/llm.py --model "llama3.2:1b"

# Check cache hits (should see "cache hit" after warmup)
echo "Hello world" | DEBUG=1 python tinygrad/apps/llm.py --model "llama3.2:1b" 2>&1 | grep cache

# Test with beam search
echo "Hello" | BEAM=2 python tinygrad/apps/llm.py --model "llama3.2:1b"

Common Patterns

Graph Transformation

def my_transform(ctx, x):
  # Return new UOp or None to skip
  return x.replace(arg=new_arg)

pm = PatternMatcher([
  (UPat(Ops.SOMETHING, name="x"), my_transform),
])
result = graph_rewrite(input_uop, pm, ctx={})

Finding Variables

# Get all variables in a UOp graph
variables = uop.variables()

# Get bound variable values
var, val = bind_uop.unbind()

Shape Handling

# Shapes can be symbolic (contain UOps)
shape = tensor.shape  # tuple[sint, ...] where sint = int | UOp
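
For example, binding a Variable and reshaping with it produces a symbolic shape (a small sketch; KV cache positions bound to variables work the same way):

from tinygrad import Tensor, Variable

v = Variable("i", 1, 10).bind(3)   # DEFINE_VAR bound to a concrete value -> a BIND UOp
t = Tensor.ones(3).reshape(v)      # shape is now symbolic
print(t.shape)                     # contains a UOp rather than a plain int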

Performance Optimization

When optimizing tinygrad internals:

  1. Measure wall time, not just call counts - Reducing graph_rewrite calls doesn't always improve wall time. The overhead of conditional checks can exceed the cost of the operation being skipped.

  2. Profile each optimization individually - Run benchmarks with and without each change to measure actual impact. Use test/external/external_benchmark_schedule.py for schedule/rewrite timing.

  3. Early exits in hot paths are effective - Simple checks like if self.op is Ops.CONST: return self in simplify() can eliminate many unnecessary graph_rewrite calls (sketched after this list).

  4. graph_rewrite is expensive - Each call has overhead even for small graphs. Avoid calling it when the result is trivially known (e.g., simplifying a CONST returns itself).

  5. Beware iterator overhead - Checks like all(x.op is Ops.CONST for x in self.src) can be slower than just running the operation, especially for small sequences.

  6. Verify cache hit rates before adding/keeping caches - Measure actual hit rates with real workloads. A cache with 0% hit rate is pure overhead (e.g., pm_cache was removed because the algorithm guarantees each UOp is only passed to pm_rewrite once).

  7. Use TRACK_MATCH_STATS=2 to profile pattern matching - This shows match rates and time per pattern. Look for patterns with 0% match rate that still cost significant time - these are pure overhead for that workload.

  8. Cached properties beat manual traversal - backward_slice uses @functools.cached_property. A DFS with early-exit sounds faster but is actually slower because it doesn't benefit from caching. The cache hit benefit often outweighs algorithmic improvements.

  9. Avoid creating intermediate objects in hot paths - For example, any(x.op in ops for x in self.backward_slice) is faster than any(x.op in ops for x in {self:None, **self.backward_slice}) because it avoids dict creation.
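
For item 3, the shape of such an early exit (a sketch; it assumes the symbolic PatternMatcher lives in tinygrad.uop.symbolic):

from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.uop.symbolic import symbolic

def my_simplify(u: UOp) -> UOp:
  # early exit: a lone CONST cannot simplify further, so skip graph_rewrite entirely
  if u.op is Ops.CONST: return u
  return graph_rewrite(u, symbolic)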

Pattern Matching Profiling

Use TRACK_MATCH_STATS=2 to identify expensive patterns:

TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py

Output format: matches / attempts -- match_time / total_time ms -- location

Key patterns to watch (from ResNet50 benchmark):

  • split_load_store: ~146ms, 31% match rate - does real work
  • simplify_valid: ~75ms, 0% match rate in this workload - checks AND ops for INDEX in backward slice
  • vmin==vmax folding: ~55ms, 0.33% match rate - checks 52K ops but rarely matches

Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose.

AMD Performance Counter Profiling

Set VIZ=-2 to save performance counter traces for the AMD backend.

Use the CLI in ./extra/sqtt/roc.py to explore the trace.