diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index d8acb420d4..4d4c8b5e62 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1198,5 +1198,46 @@ class TestLinearizerFailures(unittest.TestCase): opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) + def test_failure_51(self): + # regression test for #7019, training bert on tinybox red + ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( + UOp(UOps.STORE, dtypes.void, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(UOps.ALU, dtypes.half, arg=UnaryOps.RECIP, src=( + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( + x6:=UOp(UOps.VALID, dtypes.bool, arg=None, src=( + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(UOps.CONST, dtypes.half, arg=1.0, src=()), + x9:=UOp(UOps.CONST, dtypes.half, arg=0.0, src=()),)), + UOp(UOps.ALU, dtypes.half, arg=UnaryOps.EXP2, src=( + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( + x6, + UOp(UOps.CONST, dtypes.half, arg=2.0, src=()), + x9,)), + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(UOps.CAST, dtypes.half, arg=None, src=( + UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(UOps.CAST, dtypes.float, arg=None, src=( + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(UOps.LOAD, dtypes.half, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(UOps.LOAD, dtypes.half, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), + UOp(UOps.LOAD, dtypes.half, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=3, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( + x6, + UOp(UOps.CONST, dtypes.half, arg=-1.4426950408889634, src=()), + x9,)),)),)),)),)),)),)) + opts = [Opt(op=OptOps.TC, axis=0, amt=2)] + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 729601c5c1..6d6d51e67a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -365,9 +365,6 @@ class UOp(MathTrait): Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin) Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin) return Lmin*Rmin, Lmax*Rmax - # arbitrary - products = [s0.vmin * s1.vmin, s0.vmin * s1.vmax, s0.vmax * s1.vmin, s0.vmax * s1.vmax] - return min(products), max(products) if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1 if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg