mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
remove arbitrary multiplication case (#7033)
adds the wrongly simplified kernel in test_linearizer_failures #7019
This commit is contained in:
@@ -1198,5 +1198,46 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
def test_failure_51(self):
|
||||
# regression test for #7019, training bert on tinybox red
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.half, arg=UnaryOps.RECIP, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
|
||||
x6:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(UOps.CONST, dtypes.half, arg=1.0, src=()),
|
||||
x9:=UOp(UOps.CONST, dtypes.half, arg=0.0, src=()),)),
|
||||
UOp(UOps.ALU, dtypes.half, arg=UnaryOps.EXP2, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
|
||||
x6,
|
||||
UOp(UOps.CONST, dtypes.half, arg=2.0, src=()),
|
||||
x9,)),
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=3, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
|
||||
x6,
|
||||
UOp(UOps.CONST, dtypes.half, arg=-1.4426950408889634, src=()),
|
||||
x9,)),)),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.TC, axis=0, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -365,9 +365,6 @@ class UOp(MathTrait):
|
||||
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
|
||||
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
|
||||
return Lmin*Rmin, Lmax*Rmax
|
||||
# arbitrary
|
||||
products = [s0.vmin * s1.vmin, s0.vmin * s1.vmax, s0.vmax * s1.vmin, s0.vmax * s1.vmax]
|
||||
return min(products), max(products)
|
||||
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
|
||||
|
||||
Reference in New Issue
Block a user