dumb linearizer example that max is not simplified (#5644)
* dumb linearizer example that max is not simplified; this might just get fixed once basic mod simplification is done
* need local
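For context, the "basic mod simplification" the commit message refers to is the kind of rewrite asserted by test_div_into_mod in the hunks below. A quick plain-Python sanity check of that identity (a standalone sketch, not part of the diff):

# check ((idx*4) % 8) // 4 == idx % 2 over the variable's range [0, 16]
for idx in range(17):
  assert ((idx * 4) % 8) // 4 == idx % 2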
@@ -32,5 +32,28 @@ class TestLinearizerDumb(unittest.TestCase):
     prg.uops.print()
     print(prg.src)

+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
+  def test_max_simplify_and_cancel(self):
+    ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
+      LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=(
+        LazyOp(BinaryOps.MUL, arg=None, src=(
+          LazyOp(UnaryOps.CAST, arg=dtypes.int, src=(
+            LazyOp(BinaryOps.CMPNE, arg=None, src=(
+              LazyOp(BinaryOps.CMPNE, arg=None, src=(
+                LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=()),
+                LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
+              LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),
+          LazyOp(BinaryOps.ADD, arg=None, src=(
+            LazyOp(ReduceOps.SUM, arg=(1,), src=(
+              LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False)))), src=()),)),
+            LazyOp(BufferOps.CONST, arg=ConstBuffer(val=1000, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),))
+    opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)]
+    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
+    k.required_optimizations()
+    for opt in opts: k.apply_opt(opt)
+    prg = k.to_program()
+    prg.uops.print()
+    print(prg.src)
+
 if __name__ == '__main__':
   unittest.main()
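The hunk above elides the test file's import block. A minimal sketch of what the test needs in scope, assuming module paths from tinygrad around this revision (exact paths may differ between commits):

# sketch only: import paths are assumptions, not part of the diff
import unittest
from tinygrad import dtypes
from tinygrad.device import Device
from tinygrad.ops import LazyOp, MetaOps, BufferOps, BinaryOps, UnaryOps, ReduceOps, MemBuffer, ConstBuffer
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View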
@@ -251,6 +251,11 @@ class TestSymbolic(unittest.TestCase):
   def test_div_into_mod(self):
     self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")

+  def test_div_neg_cancel(self):
+    self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((1+idx)//4)")
+    self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
+    self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
+
 class TestSymbolicNumeric(unittest.TestCase):
   def helper_test_numeric(self, f):
     # TODO: why are the negative tests broken? (even if we did support negative variables)
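From its usage, helper_test_variable appears to take (expr, expected_min, expected_max, expected_rendered_string). The three new assertions encode genuine floor-division identities; a brute-force check in plain Python, whose // also floors toward negative infinity (a standalone sketch, not part of the diff):

# verify the three test_div_neg_cancel identities over the variable's range [0, 100]
for idx in range(101):
  assert (-idx + 199) // -4 + 50 == (1 + idx) // 4
  assert (-idx + 200) // -4 + 50 == idx // 4
  assert (-idx + 201) // -4 + 50 == (3 + idx) // 4 - 1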
@@ -330,6 +330,12 @@ class TestSymbolic(unittest.TestCase):
   def test_div_into_mod(self):
     self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")

+  @unittest.expectedFailure
+  def test_div_neg_cancel(self):
+    self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((1+idx)//4)")
+    self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
+    self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
+
 @unittest.skip("not supported on uops yet")
 class TestSymbolicNumeric(unittest.TestCase):
   def helper_test_numeric(self, f):
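This second hunk duplicates the same assertions for the uop-based symbolic backend, marked @unittest.expectedFailure since div/neg cancellation is not implemented there yet. A minimal standalone illustration of that decorator's behavior (hypothetical example, not from the diff):

import unittest

class Demo(unittest.TestCase):
  @unittest.expectedFailure
  def test_negative_floor_div(self):
    # fails as expected: Python floors toward -inf, so (-5)//4 == -2, not -1
    self.assertEqual((-5) // 4, -1)

if __name__ == "__main__":
  unittest.main()  # reported as "expected failure"; flips to "unexpected success" if it ever passes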