remove arbitrary multiplication case (#7033)

adds the wrongly simplified kernel in test_linearizer_failures
#7019
This commit is contained in:
chenyu
2024-10-13 15:06:05 -04:00
committed by GitHub
parent 13575f080a
commit 1a27417262
2 changed files with 41 additions and 3 deletions

View File

@@ -1198,5 +1198,46 @@ class TestLinearizerFailures(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_51(self):
# regression test for #7019, training bert on tinybox red
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.half, arg=UnaryOps.RECIP, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
x6:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.CONST, dtypes.half, arg=1.0, src=()),
x9:=UOp(UOps.CONST, dtypes.half, arg=0.0, src=()),)),
UOp(UOps.ALU, dtypes.half, arg=UnaryOps.EXP2, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
x6,
UOp(UOps.CONST, dtypes.half, arg=2.0, src=()),
x9,)),
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
UOp(UOps.CAST, dtypes.half, arg=None, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
x6,
UOp(UOps.CONST, dtypes.half, arg=-1.4426950408889634, src=()),
x9,)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
if __name__ == '__main__':
unittest.main()

View File

@@ -365,9 +365,6 @@ class UOp(MathTrait):
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
return Lmin*Rmin, Lmax*Rmax
# arbitrary
products = [s0.vmin * s1.vmin, s0.vmin * s1.vmax, s0.vmax * s1.vmin, s0.vmax * s1.vmax]
return min(products), max(products)
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg