@@ -574,7 +574,7 @@ class TestUOpGraph(unittest.TestCase):
   def test_in_out_bounds_access_with_mask(self):
     with Context(IGNORE_OOB=0):
       glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
-      gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
+      gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
       ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
       ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
       to_uops_list([ld0, ld1])
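Note: the rewritten test drives gidx0 over a range of 42, which by itself would run past the 16-element buffer; the access is only legal because the mask bounds it. A minimal sketch of that bounds argument in plain Python (illustrative only, not tinygrad API):

    # every index that can pass the mask (5<gidx0)&(gidx0<16) lies inside the buffer
    buf_size = 16
    assert all(0 <= i < buf_size for i in range(42) if 5 < i < 16)

ld1's mask gidx0<16 works the same way, clipping the 42-wide range down to the buffer.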
@@ -598,7 +598,7 @@ class TestUOpGraph(unittest.TestCase):
     with Context(IGNORE_OOB=0):
       glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
       glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
-      gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0")
+      gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
       ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
       ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
       to_uops_list([ld1])
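Note: here the index into glbl1 is data-dependent (the loaded value ld0), so only the mask makes the access provably in bounds: with (ld0>=0)&(ld0<32), the address ld0*2 peaks at 62, inside the 64-element buffer. A quick plain-Python check of that arithmetic (illustrative only):

    # the mask keeps ld0 in [0, 32), so ld0*2 stays in [0, 62], below ptr(64)
    assert max(v*2 for v in range(0, 32)) == 62 < 64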
@@ -834,7 +834,7 @@ class TestIFUOps(unittest.TestCase):
     valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "gidx0")<1
     lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
     gate = valid&(lidx.ne(2))
-    st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
+    st = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.float, 42)))
     barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
     lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
     stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
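Note: the STORE no longer takes (buffer, index, value) as three loose sources; the address is a single INDEX uop (sbuf.index(lidx)), matching how the gated stores below already address gbuf. The semantics the IF lowering has to preserve are just a conditional write; a plain-Python sketch (gated_store is a hypothetical helper, not tinygrad API):

    def gated_store(mem, i, val, gate):
      # a STORE whose INDEX carries a gate only writes when the gate holds
      if gate: mem[i] = val

    mem = [0.0]*4
    gated_store(mem, 1, 42.0, gate=True)
    gated_store(mem, 2, 42.0, gate=False)
    assert mem == [0.0, 42.0, 0.0, 0.0]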
@@ -1,11 +1,12 @@
 import unittest, math
 from tinygrad import dtypes
-from tinygrad.helpers import all_same
+from tinygrad.helpers import all_same, Context
 from tinygrad.uop.ops import GroupOp, UOp, Ops, exec_alu, PatternMatcher, TrackedPatternMatcher, UPat
 from tinygrad.codegen import full_rewrite_to_sink
 from hypothesis import given, strategies as strat
 
 # Helper function to apply the graph rewrite
+@Context(SPEC=0)
 def apply_rewrite(expr):
   return full_rewrite_to_sink(expr.sink()).src[0]
 
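Note: apply_rewrite now runs with SPEC disabled via @Context(SPEC=0), using Context as a decorator rather than a with-block. That works if Context follows the contextlib.ContextDecorator pattern, which applies the overrides around each wrapped call and restores the previous values on exit; a self-contained sketch of that pattern (Ctx and SETTINGS are hypothetical stand-ins, not tinygrad's implementation):

    from contextlib import ContextDecorator

    SETTINGS = {"SPEC": 1, "NOOPT": 0}  # stand-in for global context variables

    class Ctx(ContextDecorator):
      def __init__(self, **kwargs): self.kwargs, self.saved = kwargs, {}
      def __enter__(self):
        # save the old values, then apply the overrides for the block/call
        self.saved = {k: SETTINGS[k] for k in self.kwargs}
        SETTINGS.update(self.kwargs)
      def __exit__(self, *exc):
        SETTINGS.update(self.saved)  # always restore on the way out
        return False

    @Ctx(SPEC=0)
    def run(): return SETTINGS["SPEC"]

    assert run() == 0 and SETTINGS["SPEC"] == 1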
@@ -47,7 +47,7 @@ class TestHelpers(unittest.TestCase):
 
 class TestValidIdxSimplification(unittest.TestCase):
   def check(self, load, sidx, svalid):
-    with Context(NOOPT=1):
+    with Context(NOOPT=1, SPEC=0):
       load = full_rewrite_to_sink(load.sink()).src[0]
       idx, valid = load.src[0].src[1], load.src[0].src[2]
       check_uop_against_string(self, idx, sidx)
@@ -213,7 +213,7 @@ class TestValidIdxSimplification(unittest.TestCase):
 
 class TestImageSimplification(unittest.TestCase):
   def check(self, load, svalid, sidx0, sidx1):
-    with Context(NOOPT=1):
+    with Context(NOOPT=1, SPEC=0):
       load = full_rewrite_to_sink(load.sink()).src[0]
       idx = load.src[0].src[1]
       self.assertEqual(idx.op, Ops.VECTORIZE)
@@ -283,7 +283,8 @@ class TestImageSimplification(unittest.TestCase):
 
     # empty -> invalid
     load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
-    load = full_rewrite_to_sink(load.sink()).src[0]
+    with Context(NOOPT=1, SPEC=0):
+      load = full_rewrite_to_sink(load.sink()).src[0]
     self.assertEqual(load.op, Ops.VECTORIZE)
     self.assertEqual(load.dtype.count, 4)
 
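Note: the "empty -> invalid" case gates the load with (gidx0<8) & (gidx0<8).ne(True), a contradiction no gidx0 can satisfy, so the simplifier can drop the load entirely and emit a constant 4-wide vector. The unsatisfiability is easy to confirm in plain Python:

    # the gate (x<8) AND NOT (x<8) is false for every possible index
    assert all(not ((g < 8) and not (g < 8)) for g in range(100))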
@@ -1,6 +1,6 @@
 from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
 from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
-from tinygrad.uop.spec import type_verify, program_spec
+from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
 from tinygrad.renderer import Renderer
 
 # import all pattern matchers here
@@ -19,6 +19,8 @@ from tinygrad.codegen.late.control_flow import CFGContext, pm_split_ends, pm_add
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()
 
+  if SPEC: type_verify(list(sink.toposort()), kernel_spec)
+
   # first we optimize
   if optimize:
     if QUANTIZE and ren.device in {"CPU", "DSP"}: sink = graph_rewrite(sink, pm_quant, name="quantize")
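Note: this is the behavioral change in codegen: when SPEC is set, the incoming kernel graph is checked against kernel_spec before any optimization rewrite runs, so a malformed graph fails loudly at the boundary instead of deep inside a pass. A schematic of what such a spec check does, assuming first-matching-pattern semantics (verify and the tuple-based patterns are hypothetical, not type_verify's real signature):

    def verify(uops, patterns):
      for u in uops:                      # uops arrive in toposort order
        for matches, check in patterns:
          if not matches(u): continue
          ok = check(u)                   # True: valid, False: reject, None: defer
          if ok is None: continue
          if not ok: raise RuntimeError(f"uop failed spec: {u}")
          break

    spec = [(lambda u: u[0] == "SENTINEL", lambda u: False),
            (lambda u: True, lambda u: True)]
    verify([("CONST", 1)], spec)          # passes; a SENTINEL uop would raise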
@@ -1,12 +1,13 @@
 from typing import cast
 from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
-from tinygrad.helpers import DEBUG, Context
+from tinygrad.helpers import DEBUG, Context, prod
 from tinygrad.uop.validate import validate_index
 
 # four specs:
 # shared_spec -- usable anywhere
 # tensor_spec -- usable in tensor graph
+# kernel_spec -- usable in kernel passed into codegen
 # program_spec -- usable in linearized program
 # full_spec -- all uops ever created
 
@@ -15,6 +16,9 @@ from tinygrad.uop.validate import validate_index
 shared_spec = PatternMatcher([
   (UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything
 
+  # SENTINEL should never be anywhere
+  (UPat(Ops.SENTINEL), lambda: False),
+
   # CONST/DEFINE_VAR are everywhere
   (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
   (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
@@ -146,15 +150,33 @@ program_spec = PatternMatcher([
   (UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
 ])+shared_spec
 
+# ***** UOp spec in kernel graph *****
+
+kernel_spec = PatternMatcher([
+  # index is allowed here
+  (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
+
+  # UNROLL/CONTRACT is used here for WMMA
+  (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
+  (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
+
+  # END can end multiple axes here
+  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), allow_any_len=True, dtype=dtypes.void), lambda: True),
+
+  # bufferize (must be on ranges)
+  (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])),
+  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
+
+  # intermediate index
+  (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
+])+program_spec+shared_spec
+
 # *** this spec should match all UOps ever created ***
 
 full_spec = PatternMatcher([
   # any END
   (UPat(Ops.END), lambda: True),
 
-  # SENTINEL should never be in the graph
-  (UPat(Ops.SENTINEL), lambda: False),
-
   # Invalid must have type Index
   (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
   # where on index in rhs position is fine
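Note on the new kernel_spec: each pattern's lambda answers True (valid), False (invalid), or None (no opinion), and the trailing `])+program_spec+shared_spec` chains the fallback specs, so e.g. the INDEX rule's `... or None` lets a non-matching INDEX still be judged by program_spec instead of being rejected outright. The CONTRACT/UNROLL rules encode an arity invariant: the vector width must equal the product of the unrolled axis sizes. Both idioms in miniature (the arg value is a made-up example, and math.prod stands in for tinygrad.helpers.prod):

    from math import prod

    # CONTRACT arity rule: arg of ((axis, size), ...) = ((0,4),(1,2)) forces an 8-wide vector
    arg = ((0, 4), (1, 2))
    assert prod(y[1] for y in arg) == 8

    # the `or None` idiom: True means valid, None means "defer to the next spec"
    check = lambda srcs: all(s == "index" for s in srcs) or None
    assert check(["index"]) is True and check(["int"]) is None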
@@ -165,19 +187,12 @@ full_spec = PatternMatcher([
 
   # rangeify: buffer view with index or load is okay
   (UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True),
-  # bufferize (must be on ranges)
-  (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])),
-  # intermediate index
-  (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
-  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
   # copy on index
   (UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True),
   # assign on index. the third op is the shape
   (UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
 
   # expander: unroll/contract/gep/ptrcat/cat
-  #(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
-  #(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
   (UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
   # GEP multi is supported here
   (UPat(Ops.GEP, name="gep"), lambda gep: gep.dtype is dtypes.void or gep.dtype.vcount == len(gep.arg)),
@@ -211,7 +226,7 @@ full_spec = PatternMatcher([
   (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
   # allow any AFTER
   (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
-])+tensor_spec+program_spec
+])+tensor_spec+kernel_spec+program_spec+shared_spec
 
 # ***** uop helpers *****
 
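Note: full_spec now composes all four layers explicitly (`+tensor_spec+kernel_spec+program_spec+shared_spec`), and the patterns deleted above survive inside kernel_spec, so full_spec's coverage is unchanged. If PatternMatcher's `+` concatenates pattern lists in order, earlier specs win ties on a match; a sketch of that composition rule (Matcher is a stand-in, not tinygrad's PatternMatcher):

    class Matcher:
      def __init__(self, patterns): self.patterns = list(patterns)
      def __add__(self, other): return Matcher(self.patterns + other.patterns)

    a = Matcher([("END", "always ok")])
    b = Matcher([("END", "reject"), ("LOAD", "ok")])
    assert (a + b).patterns[0] == ("END", "always ok")  # a's rule is consulted first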