diff --git a/test/test_custom_kernel.py b/test/test_custom_kernel.py
new file mode 100644
index 0000000000..26a57325d0
--- /dev/null
+++ b/test/test_custom_kernel.py
@@ -0,0 +1,65 @@
+import unittest
+from typing import Callable
+from tinygrad import Tensor, UOp
+from tinygrad.uop.ops import KernelInfo
+
+def custom_arange_kernel(C:UOp):
+  i = UOp.range(C.size, 0)
+  return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
+
+def custom_add_one_kernel(B:UOp, A:UOp):
+  assert B.size == A.size
+  i = UOp.range(A.size, 0)
+  return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))
+
+def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp):
+  i = UOp.range(C.size, 0)
+  return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()
+
+def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp):
+  assert C.size == D.size
+  i = UOp.range(C.size, 0)
+  store_c = C[i].store(A[i]+B[i])
+  store_d = D[i].store(A[i]*B[i])
+  return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.size}")).simplify()
+
+def _kernel(tensors:list[Tensor], fxn:Callable) -> list[Tensor]: return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in tensors], fxn=fxn)]
+
+class TestCustomKernel(unittest.TestCase):
+  def test_simple(self):
+    a = Tensor.ones(16, 16).contiguous()
+    b = Tensor.ones(16, 16).contiguous()
+    c = Tensor.empty(16, 16)
+
+    c = _kernel([c,a,b], fxn=custom_elementwise_add_kernel)[0]
+
+    out = c.flatten().tolist()
+    assert all(x == 2 for x in out), "all 2"
+
+  def test_multioutput(self):
+    a = Tensor.full((16, 16), 3.).contiguous()
+    b = Tensor.full((16, 16), 3.).contiguous()
+    c = Tensor.empty(16, 16)
+    d = Tensor.empty(16, 16)
+
+    c,d = _kernel([c,d,a,b], custom_elementwise_addmul_kernel)[:2]
+    Tensor.realize(c,d)
+
+    assert all(x == 6 for x in c.flatten().tolist()), "all 6"
+    assert all(x == 9 for x in d.flatten().tolist()), "all 9"
+
+  def test_arange(self):
+    ref = Tensor.arange(100)
+    tst = Tensor.empty_like(ref)
+    tst = _kernel([tst], custom_arange_kernel)[0]
+    self.assertTrue((ref == tst).all().item())
+
+  def test_noncontig(self):
+    a = Tensor.ones(16, 16).contiguous()
+    tst = Tensor.empty_like(a)
+    b = a+1
+    b_p1 = _kernel([tst, b], custom_add_one_kernel)[0]
+    self.assertTrue((b_p1 == 3).all().item())
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 856a801062..7238bea5d1 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -711,7 +711,7 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(b.buffer.numpy(), [12])
 
   # unlike schedule, kernelize can be called multiple times on a Tensor
-  def test_double_kerenlize(self):
+  def test_double_kernelize(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
     c = (a+b)
@@ -2267,7 +2267,7 @@ class TestContiguous(unittest.TestCase):
   def test_double_contiguous_realizes_once(self):
    a = Tensor.empty(4, 1)
    b = a.expand((4, 4)).contiguous().contiguous()
-    check_schedule(b, 2) # TODO: should be 1?
+    check_schedule(b, 1)
 
   def test_view_does_not_realize(self):
     a = Tensor.empty(4)
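Note: for readers trying the new API outside the `_kernel` test harness, here is a minimal usage sketch. It assumes only the `UOp.custom_kernel` API added in tinygrad/uop/ops.py below; `custom_mul_kernel` is a hypothetical kernel written for illustration, following the same range/store/end/sink pattern as the tests above.

# a minimal sketch, assuming the UOp.custom_kernel API from this patch;
# custom_mul_kernel is hypothetical, not part of the diff
from tinygrad import Tensor, UOp
from tinygrad.uop.ops import KernelInfo

def custom_mul_kernel(C:UOp, A:UOp, B:UOp):
  i = UOp.range(C.size, 0)  # one loop over the flat output
  return C[i].store(A[i]*B[i]).end(i).sink(arg=KernelInfo(name=f"custom_mul_{C.size}"))

a = Tensor.full((8,), 2.).contiguous()
b = Tensor.full((8,), 5.).contiguous()
out = Tensor(UOp.custom_kernel(Tensor.empty(8).uop, a.uop, b.uop, fxn=custom_mul_kernel)[0])
assert out.tolist() == [10.0]*8
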
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 5b7ca601a2..a896c2b5d8 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -51,7 +51,7 @@ class IndexingContext:
     return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
 
 def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
-  if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.KERNEL}: return None
+  if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
   if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
   new_srcs = []
   for s in x.src:
@@ -155,6 +155,10 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
   ending_ranges: dict[UOp, list[UOp]] = {}
   for x in tsink_reverse_toposort:
     if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
+
+    # no ranges on kernels, they are internal
+    if x.op is Ops.KERNEL: continue
+
     if x.dtype.scalar() == dtypes.index: continue  # TODO: why do I need this?
     ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
 
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index b85e4a143a..f8905c6166 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -2,9 +2,9 @@ from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
-from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate
+from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags
 from tinygrad.uop.symbolic import symbolic_flat
-from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, DEBUG_RANGEIFY
+from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
 from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
@@ -355,6 +355,9 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
   # move RESHAPEs through MSELECT/MSTACK
   (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
    lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
+
+  # remove any RESHAPEs on KERNEL
+  (UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
 ])
 
 pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@@ -458,19 +461,13 @@ def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
   return x.replace(tag=None)
 
 pm_remove_tags = PatternMatcher([
-  # remove all the tags
   (UPat(GroupOp.All, name="x"), remove_metadata_tags),
 ])
 
 pm_add_range_tags = PatternMatcher([
-  (UPat(Ops.RANGE, name="x"), lambda x: x.rtag(()))
+  (UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())),
 ])
 
-@dataclass(frozen=True)
-class Kernel:
-  ast: UOp
-  metadata: tuple[Metadata, ...] = ()
-
 def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   if len(x.ranges): return None
 
@@ -507,7 +504,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
   return x.replace(tag=(len(ctx)-1,))
 add_tags = PatternMatcher([
   # don't tag BUFFERs, they are global
-  (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND,
+  (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL,
     Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
   (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
 ])
@@ -534,7 +531,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   uop_list: list[UOp] = []
   tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
 
-  tsink = graph_rewrite(tsink, earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
+  tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
 
   # convert movement ops to ranges
   tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
@@ -546,7 +543,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
   # MSTACK stacks multiple BUFFERIZEs in one tagged tensor
   # if it's not tagged by here, it's out
-  tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER} and \
+  tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER, Ops.AFTER} and \
     x.tag is not None and len(x.tag)])
 
   if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
@@ -571,10 +568,14 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
 
   if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
 
+  # TODO: we can probably get this earlier
+  sink_tags = [s.tag for s in tsink.src]
+  tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
+
   becomes_map: dict[UOp, UOp] = {}
-  for s in tsink.src:
-    assert s.tag is not None
-    for a in s.tag:
+  for tag, s in zip(sink_tags, tsink.src):
+    assert tag is not None
+    for a in tag:
       if a is None: continue
-      becomes_map[uop_list[int(a)]] = s.replace(tag=None)
+      becomes_map[uop_list[int(a)]] = s
   return becomes_map
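Note: the sink-tag handling in get_rangeify_map now saves the tags first and then strips them with `_remove_all_tags` (defined in tinygrad/uop/ops.py below). A minimal sketch of that mechanism, assuming only the UOp APIs this patch already uses; the const and tag values are made up for illustration.

# strip tags from a tiny graph the same way get_rangeify_map now does
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, graph_rewrite, _remove_all_tags

tagged = UOp.const(dtypes.int, 1).replace(tag=(0,)).sink()
untagged = graph_rewrite(tagged, _remove_all_tags, name="remove all tags")
assert all(u.tag is None for u in untagged.backward_slice)
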
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 9aec0a9912..dbbe1655b1 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -414,6 +414,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return self.op is Ops.BUFFER
 
   def contiguous(self, *args, **kwargs):
+    if self.op is Ops.CONTIGUOUS: return self
     if self.is_contiguous(): return self
     return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
   def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
@@ -773,6 +774,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]=()) -> UOp:
     return self.src[0].after(self.store(UOp.const(self.dtype, val) if not isinstance(val, UOp) else val).end(*argfix(end)))
 
+  def custom_kernel(*srcs:UOp, fxn:Callable) -> list[UOp]:
+    placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)]
+    base_srcs = tuple(x.contiguous().base for x in srcs)
+    kernel = UOp(Ops.KERNEL, src=base_srcs, arg=Kernel(fxn(*placeholders)))
+    return [s.after(kernel) for s in base_srcs]
+
 @dataclass(frozen=True)
 class KernelInfo:
   name: str = "test"  # name of the kernel
@@ -1248,6 +1255,7 @@ pm_lower_index_dtype = PatternMatcher([
 def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
 
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
+_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 
 def do_unbind(ctx:dict[Variable, int], x:UOp):
   v,i = x.unbind()
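Note: the one-line guard in contiguous() is what lets test_double_contiguous_realizes_once expect a single kernel above: a CONTIGUOUS of a CONTIGUOUS now folds to itself. A quick check, assuming nothing beyond the patched method.

# CONTIGUOUS is now idempotent at the UOp level; since UOps are hash-consed,
# the second call returns the very same object
from tinygrad import Tensor

c1 = Tensor.empty(4, 1).uop.contiguous()
assert c1.contiguous() is c1
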
"test" # name of the kernel @@ -1248,6 +1255,7 @@ pm_lower_index_dtype = PatternMatcher([ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) +_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) def do_unbind(ctx:dict[Variable, int], x:UOp): v,i = x.unbind() diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index b1b4b62013..177e0b5e37 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -1,6 +1,6 @@ import math from typing import cast, Any -from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender +from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata from tinygrad.uop.validate import validate_index @@ -57,7 +57,7 @@ movement_ops = PatternMatcher([ (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True), ]) -tensor_spec = PatternMatcher([ +_tensor_spec = PatternMatcher([ # buffer spec (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True), (UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d: @@ -69,7 +69,7 @@ tensor_spec = PatternMatcher([ (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True), # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER - (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), + (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND, Ops.CONTIGUOUS))), lambda: True), # ASSIGN has a target and a value. It can also optionally depend on other assigns (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), @@ -111,6 +111,11 @@ tensor_spec = PatternMatcher([ (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), ])+movement_ops+shared_spec +tensor_spec = PatternMatcher([ + # no tags allowed in tensor graph + (UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False), +])+_tensor_spec + # ***** UOp spec in codegen shared between kernel and program ***** shared_codegen_spec = PatternMatcher([ @@ -246,7 +251,7 @@ full_spec = PatternMatcher([ (UPat(Ops.DEFINE_VAR, dtype=dtypes.floats), lambda: True), # allow any AFTER (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True), -])+tensor_spec+kernel_spec+program_spec+shared_spec +])+_tensor_spec+kernel_spec+program_spec+shared_spec # ***** uop helpers ***** @@ -262,7 +267,7 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher): # late imports to avoid circular import from tinygrad.codegen.opt import Opt, OptOps -from tinygrad.schedule.rangeify import BufferizeOpts, Kernel +from tinygrad.schedule.rangeify import BufferizeOpts glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata, "UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace}