move functions to view and update docs [pr] (#10904)

* move functions to view and update docs [pr]

* move quantize
This commit is contained in:
George Hotz
2025-06-20 16:47:58 -07:00
committed by GitHub
parent b41e0563a3
commit 1ce63f8d04
9 changed files with 169 additions and 97 deletions

47
docs/developer/layout.md Normal file
View File

@@ -0,0 +1,47 @@
# tinygrad directory layout
Listed in order of how they are processed
---
## tinygrad/kernelize
Group UOps into kernels.
::: tinygrad.kernelize.kernelize.get_kernelize_map
options:
members: false
show_labels: false
show_source: false
---
## tinygrad/opt
Transforms the ast into an optimized ast. This is where BEAM search and heuristics live.
When finished, this will just have a function that takes in the ast and returns the optimized ast.
---
## tinygrad/codegen
Transform the optimized ast into a linearized list of UOps.
::: tinygrad.codegen.full_rewrite
options:
members: false
show_labels: false
show_source: false
---
## tinygrad/renderer
Transform the linearized list of UOps into a program.
---
## tinygrad/engine
Abstracted high level interface to the runtimes.

View File

@@ -22,6 +22,7 @@ nav:
- Runtime: runtime.md
- Developer:
- Intro: developer/developer.md
- Layout: developer/layout.md
- Speed: developer/speed.md
- UOp: developer/uop.md
- Grouper:

View File

@@ -3,7 +3,7 @@ from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar, argfix
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
from tinygrad.tensor import get_shape
from tinygrad.codegen.lowerer import get_contraction, get_contraction_with_reduce
from tinygrad.shape.view import get_contraction, get_contraction_with_reduce
import numpy as np
VARIABLE = ContextVar("VARIABLE", 0)

View File

@@ -7,7 +7,8 @@ from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.codegen.expander import migrate_indexing, expander
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
@@ -75,6 +76,17 @@ def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=Fals
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
"""
Function to transform the Kernel UOp graph into a linearized program.
Args:
sink: The Ops.SINK rooting the Kernel graph.
opts: The Renderer (can change how things are processed, fix this).
Returns:
Linear program in UOps.
"""
lst = list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst)
if __debug__: type_verify(lst)
return lst

View File

@@ -1,37 +1,12 @@
# the job of the lowerer is to do indexing
import itertools, operator, math
import math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
from tinygrad.uop.symbolic import symbolic
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
except ValueError: return None
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
if (contraction:=get_contraction(old_shape, new_shape)) is None: return None
# contraction returns the 1s as right justified as possible
# normally this contraction is good, but sometimes the reduce dim is empty. borrow from the next one, leaving one
# this ensures there's always ones available in the reduce dimension. this is also a valid contraction
for i in range(len(contraction)):
if i in reduce_axis and len(contraction[i]) == 0:
take_from = i+1
while take_from < len(contraction) and len(contraction[take_from]) == 0:
assert new_shape[take_from] == 1
take_from += 1
if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
for j in range(take_from, i, -1):
assert len(contraction[j]) > 0
contraction[j-1] = contraction[j][:-1]
contraction[j] = contraction[j][-1:]
return contraction
from tinygrad.shape.view import get_contraction
# ***** indexing *****
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
@@ -166,66 +141,3 @@ pm_lowerer = PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_load_store),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
])
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
FP = (1 << 15)
pm_quant = symbolic+PatternMatcher([
# cast after add/mul
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
# masked MUL after masked ADD
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
# MUL after reduce
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
# CAST after reduce (doesn't work if it's a size change)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
# mul 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
# mul (with plus) 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
# const push through add
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# where move
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
# where on two adds
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
# split REDUCE into multiple reduces (who remembers FOIL?)
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])

View File

@@ -0,0 +1,67 @@
from tinygrad.dtype import dtypes, least_upper_dtype
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.uop.symbolic import symbolic
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
# this is badly tested and low quality. remove it?
FP = (1 << 15)
pm_quant = symbolic+PatternMatcher([
# cast after add/mul
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
# masked MUL after masked ADD
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
# MUL after reduce
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
# CAST after reduce (doesn't work if it's a size change)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
# mul 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
# mul (with plus) 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
# const push through add
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
# where move
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
# where on two adds
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
# split REDUCE into multiple reduces (who remembers FOIL?)
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])

View File

@@ -2,13 +2,12 @@ from dataclasses import dataclass
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, sint
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.codegen.lowerer import get_contraction_with_reduce
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.kernelize.multi import multi_pm
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
from tinygrad.kernelize.grouper import group_realizes, ALWAYS_CONTIGUOUS
# creation can recurse a lot
@@ -422,6 +421,16 @@ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(
@track_rewrites(name=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
"""
Function to transform the Tensor UOp graph into a version with Ops.KERNEL
Args:
big_sink: The Ops.SINK rooting the Tensor graph.
Returns:
Map transforming each UOp in the big_sink to the Ops.KERNEL graph.
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(big_sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")

View File

@@ -13,8 +13,7 @@ from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, unwrap
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.kernelize.kernelize import view_left
from tinygrad.codegen import full_rewrite

View File

@@ -6,6 +6,31 @@ from tinygrad.dtype import dtypes
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
except ValueError: return None
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
if (contraction:=get_contraction(old_shape, new_shape)) is None: return None
# contraction returns the 1s as right justified as possible
# normally this contraction is good, but sometimes the reduce dim is empty. borrow from the next one, leaving one
# this ensures there's always ones available in the reduce dimension. this is also a valid contraction
for i in range(len(contraction)):
if i in reduce_axis and len(contraction[i]) == 0:
take_from = i+1
while take_from < len(contraction) and len(contraction[take_from]) == 0:
assert new_shape[take_from] == 1
take_from += 1
if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
for j in range(take_from, i, -1):
assert len(contraction[j]) > 0
contraction[j-1] = contraction[j][:-1]
contraction[j] = contraction[j][-1:]
return contraction
@functools.cache
def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
return tuple(0 if s == 1 else st for s, st in zip(shape, strides))