mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move functions to view and update docs [pr] (#10904)
* move functions to view and update docs [pr] * move quantize
This commit is contained in:
47
docs/developer/layout.md
Normal file
47
docs/developer/layout.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# tinygrad directory layout
|
||||
|
||||
Listed in the order in which they are processed.
|
||||
|
||||
---
|
||||
|
||||
## tinygrad/kernelize
|
||||
|
||||
Group UOps into kernels.
|
||||
|
||||
::: tinygrad.kernelize.kernelize.get_kernelize_map
|
||||
options:
|
||||
members: false
|
||||
show_labels: false
|
||||
show_source: false
|
||||
|
||||
---
|
||||
|
||||
## tinygrad/opt
|
||||
|
||||
Transforms the ast into an optimized ast. This is where BEAM search and heuristics live.
|
||||
|
||||
When finished, this will just have a function that takes in the ast and returns the optimized ast.
|
||||
|
||||
---
|
||||
|
||||
## tinygrad/codegen
|
||||
|
||||
Transform the optimized ast into a linearized list of UOps.
|
||||
|
||||
::: tinygrad.codegen.full_rewrite
|
||||
options:
|
||||
members: false
|
||||
show_labels: false
|
||||
show_source: false
|
||||
|
||||
---
|
||||
|
||||
## tinygrad/renderer
|
||||
|
||||
Transform the linearized list of UOps into a program.
|
||||
|
||||
---
|
||||
|
||||
## tinygrad/engine
|
||||
|
||||
Abstracted high level interface to the runtimes.
|
||||
@@ -22,6 +22,7 @@ nav:
|
||||
- Runtime: runtime.md
|
||||
- Developer:
|
||||
- Intro: developer/developer.md
|
||||
- Layout: developer/layout.md
|
||||
- Speed: developer/speed.md
|
||||
- UOp: developer/uop.md
|
||||
- Grouper:
|
||||
|
||||
@@ -3,7 +3,7 @@ from tinygrad import Variable
|
||||
from tinygrad.helpers import Context, ContextVar, argfix
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
|
||||
from tinygrad.tensor import get_shape
|
||||
from tinygrad.codegen.lowerer import get_contraction, get_contraction_with_reduce
|
||||
from tinygrad.shape.view import get_contraction, get_contraction_with_reduce
|
||||
import numpy as np
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
@@ -7,7 +7,8 @@ from tinygrad.uop.spec import type_verify
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
|
||||
from tinygrad.codegen.lowerer import pm_lowerer, get_index
|
||||
from tinygrad.codegen.quantize import pm_quant
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
|
||||
from tinygrad.codegen.expander import migrate_indexing, expander
|
||||
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
|
||||
@@ -75,6 +76,17 @@ def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=Fals
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer))
|
||||
|
||||
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
  """
  Transform the Kernel UOp graph into a linearized program.

  Args:
    sink: The Ops.SINK rooting the Kernel graph.
    opts: The Renderer; can change how things are processed (TODO: tighten this contract).

  Returns:
    Linear program in UOps.
  """
  linearized = full_rewrite_to_sink(sink, opts, linearizer=True)
  uops = list(linearized.arg.lst)
  # verification is elided when Python runs with -O (__debug__ is False)
  if __debug__: type_verify(uops)
  return uops
|
||||
|
||||
@@ -1,37 +1,12 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
import itertools, operator, math
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
  """For each axis of new_shape, list the consecutive old_shape axes that merge into it; None if impossible."""
  prefix_old = list(itertools.accumulate(old_shape, operator.mul))
  prefix_new = list(itertools.accumulate(new_shape, operator.mul))
  # each new prefix product must appear among the old prefix products (size-1 prefixes map to boundary 0)
  boundaries = []
  for p in prefix_new:
    if p == 1: boundaries.append(0)
    elif p in prefix_old: boundaries.append(prefix_old.index(p)+1)
    else: return None
  starts = [0]+boundaries[:-1]
  ends = boundaries[:-1]+[len(old_shape)]
  return [list(range(lo, hi)) for lo, hi in zip(starts, ends)]
|
||||
|
||||
def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
  """Like get_contraction, but shifts axes so every reduce dim keeps at least one old axis; None if impossible."""
  contraction = get_contraction(old_shape, new_shape)
  if contraction is None: return None
  # get_contraction right-justifies the 1s; that can leave a reduce dim empty. borrow an axis from the next
  # dim that has any (all skipped dims must be size-1), which is still a valid contraction
  for axis, group in enumerate(contraction):
    if axis not in reduce_axis or group: continue
    donor = axis+1
    while donor < len(contraction) and not contraction[donor]:
      assert new_shape[donor] == 1
      donor += 1
    if donor == len(contraction) or new_shape[donor] != 1: return None  # nothing to take
    # shift one axis leftward along the chain from the donor back to this reduce dim
    for j in range(donor, axis, -1):
      assert contraction[j]
      contraction[j-1], contraction[j] = contraction[j][:-1], contraction[j][-1:]
  return contraction
|
||||
from tinygrad.shape.view import get_contraction
|
||||
|
||||
# ***** indexing *****
|
||||
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
|
||||
@@ -166,66 +141,3 @@ pm_lowerer = PatternMatcher([
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_load_store),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
||||
])
|
||||
|
||||
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
|
||||
|
||||
FP = (1 << 15)
|
||||
pm_quant = symbolic+PatternMatcher([
|
||||
# cast after add/mul
|
||||
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
|
||||
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
|
||||
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
|
||||
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
|
||||
|
||||
# masked MUL after masked ADD
|
||||
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
|
||||
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
|
||||
|
||||
# MUL after reduce
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
|
||||
# CAST after reduce (doesn't work if it's a size change)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
|
||||
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
|
||||
|
||||
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
|
||||
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
|
||||
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
|
||||
# mul 0 * c1 is 0
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
|
||||
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
|
||||
# mul (with plus) 0 * c1 is 0
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
|
||||
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
|
||||
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
|
||||
lambda ld,v,c1: ld*c1),
|
||||
|
||||
# const push through add
|
||||
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
|
||||
|
||||
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
|
||||
lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
|
||||
# fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
|
||||
lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
|
||||
|
||||
# where move
|
||||
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
|
||||
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
|
||||
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
|
||||
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
|
||||
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
|
||||
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
|
||||
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
|
||||
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
|
||||
|
||||
# where on two adds
|
||||
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
|
||||
lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
|
||||
|
||||
# split REDUCE into multiple reduces (who remembers FOIL?)
|
||||
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
|
||||
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
|
||||
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
|
||||
])
|
||||
|
||||
67
tinygrad/codegen/quantize.py
Normal file
67
tinygrad/codegen/quantize.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from tinygrad.dtype import dtypes, least_upper_dtype
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
|
||||
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
|
||||
# this is badly tested and low quality. remove it?
|
||||
|
||||
# fixed-point scale: 2**15, i.e. 15 fractional bits, used by the "fixed point mult" rewrites below
FP = (1 << 15)
|
||||
# layered on top of symbolic; each entry is a (pattern, replacement-lambda) rewrite rule
pm_quant = symbolic+PatternMatcher([
|
||||
# cast after add/mul
|
||||
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
|
||||
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
|
||||
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
|
||||
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
|
||||
|
||||
# masked MUL after masked ADD
|
||||
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
|
||||
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
|
||||
|
||||
# MUL after reduce
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
|
||||
# CAST after reduce (doesn't work if it's a size change)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
|
||||
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
|
||||
|
||||
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
|
||||
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
|
||||
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
|
||||
# mul 0 * c1 is 0
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
|
||||
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
|
||||
# mul (with plus) 0 * c1 is 0
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
|
||||
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
|
||||
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
|
||||
lambda ld,v,c1: ld*c1),
|
||||
|
||||
# const push through add
|
||||
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
|
||||
|
||||
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
|
||||
lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
|
||||
# fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
|
||||
lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
|
||||
|
||||
# where move
|
||||
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
|
||||
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
|
||||
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
|
||||
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
|
||||
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
|
||||
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
|
||||
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
|
||||
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
|
||||
|
||||
# where on two adds
|
||||
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
|
||||
lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
|
||||
|
||||
# split REDUCE into multiple reduces (who remembers FOIL?)
|
||||
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
|
||||
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
|
||||
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
|
||||
])
|
||||
@@ -2,13 +2,12 @@ from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, sint
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.codegen.lowerer import get_contraction_with_reduce
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.kernelize.multi import multi_pm
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.kernelize.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
|
||||
# creation can recurse a lot
|
||||
@@ -422,6 +421,16 @@ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(
|
||||
|
||||
@track_rewrites(name=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
Function to transform the Tensor UOp graph into a version with Ops.KERNEL
|
||||
|
||||
Args:
|
||||
big_sink: The Ops.SINK rooting the Tensor graph.
|
||||
|
||||
Returns:
|
||||
Map transforming each UOp in the big_sink to the Ops.KERNEL graph.
|
||||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(big_sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
|
||||
@@ -13,8 +13,7 @@ from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, unwrap
|
||||
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
from tinygrad.kernelize.kernelize import view_left
|
||||
from tinygrad.codegen import full_rewrite
|
||||
|
||||
|
||||
@@ -6,6 +6,31 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
|
||||
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
  """Partition old_shape's axes into consecutive groups whose products give new_shape; None if no such grouping."""
  cum_old = list(itertools.accumulate(old_shape, operator.mul))
  cum_new = list(itertools.accumulate(new_shape, operator.mul))
  try:
    # boundary after the old axes covered by each new axis; size-1 prefixes get boundary 0
    cut = [0 if c == 1 else cum_old.index(c)+1 for c in cum_new]
  except ValueError:
    return None  # some prefix product of new_shape never occurs among old_shape's prefix products
  return [list(range(lo, hi)) for lo, hi in zip([0]+cut[:-1], cut[:-1]+[len(old_shape)])]
|
||||
|
||||
def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
  """get_contraction, adjusted so every reduce dim ends up with at least one source axis; None if impossible."""
  groups = get_contraction(old_shape, new_shape)
  if groups is None: return None
  # the base contraction right-justifies the 1s, which can leave a reduce dim with no axes.
  # steal one axis from the next dim that has any, provided every skipped dim is size-1;
  # the result is still a valid contraction
  for i in range(len(groups)):
    if i not in reduce_axis: continue
    if groups[i]: continue
    src = i+1
    while src < len(groups) and not groups[src]:
      assert new_shape[src] == 1
      src += 1
    if src == len(groups) or new_shape[src] != 1: return None  # nothing to take
    # ripple one axis leftward from the donor dim back to this reduce dim
    for k in range(src, i, -1):
      assert groups[k]
      groups[k-1] = groups[k][:-1]
      groups[k] = groups[k][-1:]
  return groups
|
||||
|
||||
@functools.cache
def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
  """Zero the stride of every size-1 axis (its stride never matters), giving a canonical stride tuple."""
  out = []
  for size, stride in zip(shape, strides):
    out.append(0 if size == 1 else stride)
  return tuple(out)
|
||||
|
||||
Reference in New Issue
Block a user