From bc95b7e42264a391da5e689145e01f52ea1ad668 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:11:23 +0300 Subject: [PATCH] actually use UOps.CONTIGUOUS (#7049) --- tinygrad/codegen/uopgraph.py | 4 ++-- tinygrad/engine/schedule.py | 6 ++---- tinygrad/ops.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index ef9977c5c6..3515580655 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -274,8 +274,8 @@ def no_vectorized_wmma(wmma:UOp): sym = symbolic_flat+PatternMatcher([ # self ASSIGN is just self (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x), - # ASSIGN to global is just self - (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x), + # ASSIGN or CONTIGUOUS to global is just self + (UPat((UOps.ASSIGN, UOps.CONTIGUOUS), src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x), # VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal (UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None), diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 0f3925d0a8..2974f60a9f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -115,7 +115,7 @@ reduceop_fusor = PatternMatcher([ # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes) (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes) - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) @@ -164,9 +164,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs] if buf.op in ReduceOps: ret = UOp(UOps.REDUCE_AXIS, dtype, tuple(src), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).view(st) - elif buf.op is MetaOps.CONTIGUOUS: - assert buf in outputs, f"{buf.op} must be writable" - ret = src[0] + elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0])) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1])) elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype) elif buf.op is UnaryOps.BITCAST: ret = src[0].bitcast(dtype) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0f90e57fc7..6af00ddecf 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -722,7 +722,7 @@ spec = PatternMatcher([ (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), - (UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True), + (UPat((UOps.ASSIGN, UOps.CONTIGUOUS), src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True), (UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True), # all WMMA has 3 args,