diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index d0304d7ab0..78ad3c2b49 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -770,6 +770,7 @@ class TestIdxUpcast(unittest.TestCase): # Offset brings final value within int32, but calculation has to be done on int64 # ((gidx0+((gidx1*129)+(gidx2*528384)))+-1073741824) where gidx.max = 128, gidx1.max = 4095, gidx2.max = 4095. # Intermediate sum is 2147487743, bigger than 2**31 (2147483647) + @unittest.skip("Cast back is an optimization to be implemented") def test_overflow_neg_offset_upper_bound(self): dim1, dim2, dim3, offset = 2**12, 2**12, 2**7+1, -2**30 store, _ = self._assert((dim1, dim2, dim3), dtypes.int, offset=offset) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index a3c1c3a2d9..86307cf626 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -6,7 +6,6 @@ from tinygrad.dtype import dtypes, PtrDType from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop from tinygrad.renderer import Renderer from tinygrad.helpers import all_int, prod, partition, flatten -from tinygrad.codegen.uopgraph import sym # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None: @@ -105,26 +104,9 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp): ctx.acc_num += 1 return acc.assign(acc.alu(alu_op, ret)) -def overflow(u, dtype): return u.vmax > dtypes.max(dtype) or u.vmin < dtypes.min(dtype) - -# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation, -# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace` -def upcast(u: UOp): - srcs = [upcast(_src) for _src in u.src] - if u.dtype.scalar() is dtypes.int: - dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64 - ret = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs])) - if overflow(u, u.dtype): - return ret - # Check the original src, new srcs has Ops.CAST whose vmin, vmax changes the real bounds - if any((overflow(src, u.dtype) for src in u.src)): - # Optionally cast down - return ret.cast(u.dtype) - return u.replace(src=tuple(srcs)) def lower_load_store(ctx: IndexContext, x: UOp): idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs) - idx, valid = upcast(graph_rewrite(idx, sym, {})), upcast(graph_rewrite(valid, sym, {})) buf = x.src[0] if x.op is Ops.LOAD: barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else () @@ -137,19 +119,14 @@ def lower_load_store(ctx: IndexContext, x: UOp): # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs]) - idx = upcast(graph_rewrite(idx, sym, {})) if (not cast(PtrDType, x.src[0].dtype).local) or store_back: for oidx, ridx in zip(ctx.idxs, ctx.ridxs): if oidx is not ridx: valid = valid * oidx.eq(0) return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2])) -def valid_view_to_uop(ctx: IndexContext, x:UOp): - _, valid = x.st_arg.to_indexed_uops(ctx.idxs) - return upcast(graph_rewrite(valid, sym, {})) - pm_lowerer = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis), - (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), valid_view_to_uop), + (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx, x: x.st_arg.to_indexed_uops(ctx.idxs)[1]), # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index d86688fa0f..3777c2f3cb 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -8,6 +8,18 @@ from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid +def overflow(u): return u.vmax > dtypes.max(dtypes.int32) or u.vmin < dtypes.min(dtypes.int32) + +# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation, +# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace` +def upcast(u: UOp): + srcs = [upcast(_src) for _src in u.src] + if u.dtype.scalar() is dtypes.int: + if overflow(u) or any((overflow(src) for src in u.src)): # Check original src to exclude Ops.CAST which may obscure vmin and vmax + dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64 + return u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs])) + return u.replace(src=tuple(srcs)) + @functools.lru_cache(None) def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]: idx, valid = views[-1].to_indexed_uops(_idxs) @@ -18,7 +30,7 @@ def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ... idxs.append((idx//acc)%d) acc *= d idx, valid = view.to_indexed_uops(idxs[::-1], valid) - return idx, valid + return upcast(idx), upcast(valid) @functools.lru_cache(None) def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]: