diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bcd2d531f3..9e6ab355ae 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -310,10 +310,6 @@ jobs: run: python test/external/fuzz_symbolic.py - name: Fuzz Test fast idiv run: python test/external/fuzz_fast_idiv.py - - name: Fuzz Test shapetracker - run: CNT=50 python test/external/fuzz_shapetracker.py - - name: Fuzz Test shapetracker math - run: CNT=200 python test/external/fuzz_shapetracker_math.py - name: Fuzz Test shape ops run: python test/external/fuzz_shape_ops.py diff --git a/setup.py b/setup.py index 39dd40da60..8d1bb7b789 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,6 @@ setup(name='tinygrad', 'tinygrad.runtime.support.am', 'tinygrad.runtime.support.nv', 'tinygrad.schedule', - 'tinygrad.shape', 'tinygrad.uop', 'tinygrad.viz', ], diff --git a/test/external/external_uop_gc.py b/test/external/external_uop_gc.py index 1155c068bf..4327b69a56 100644 --- a/test/external/external_uop_gc.py +++ b/test/external/external_uop_gc.py @@ -1,6 +1,5 @@ import gc from tinygrad import Tensor, UOp, Device, nn -from tinygrad.shape.shapetracker import views_to_valid_uop from tinygrad.engine.realize import method_cache, get_program from tinygrad.schedule.indexing import apply_movement_op from test.test_tiny import TestTiny @@ -69,7 +68,6 @@ if __name__ == "__main__": # these caches will keep uops alive method_cache.clear() - views_to_valid_uop.cache_clear() apply_movement_op.cache_clear() Tensor._device_seeds.clear() Tensor._device_rng_counters.clear() diff --git a/test/opt/test_gen_float4.py b/test/opt/test_gen_float4.py index 0b675eb469..357dccae6d 100644 --- a/test/opt/test_gen_float4.py +++ b/test/opt/test_gen_float4.py @@ -2,7 +2,6 @@ import unittest from tinygrad import Device, Tensor, dtypes from tinygrad.uop.ops import UOp, Ops from tinygrad.codegen.opt import Opt, OptOps -from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.engine.realize import get_program from tinygrad.helpers import AMX @@ -149,33 +148,5 @@ class TestFloat4(unittest.TestCase): assert TestFloat4.count_float4(uops) == (1, 1) - @unittest.skip("Ops.VIEW no longer exists") - def test_half4_load_unrolled(self): - # from llama 7B shard 4 gpus - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( - UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.MUL, dtypes.half, arg=None, src=( - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)), - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),)) - - # TODO: fix this, expected might change but should be positive - for expected, opts in [ - ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ]: - program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts) - - count = TestFloat4.count_half4(program.uops) - assert count == expected, f"{count=}, {expected=}" - if __name__ == '__main__': unittest.main() diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index ac56b04d31..ce6d5ec144 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -5,10 +5,8 @@ import unittest from tinygrad import Device, dtypes from tinygrad.uop.ops import UOp, Ops, AxisType, KernelInfo -from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.codegen.opt.search import Opt, OptOps from tinygrad.engine.realize import get_program -from tinygrad.renderer.ptx import PTXRenderer class TestLinearizerFailure(unittest.TestCase): @unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL") @@ -29,49 +27,5 @@ class TestLinearizerFailure(unittest.TestCase): ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None)) _ = get_program(ast, Device["METAL"].renderer) -class TestLinearizerDumb(unittest.TestCase): - @unittest.expectedFailure - @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4") - def test_unrolled_float4_align(self): - c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0, src=()) - c1 = c0.view(ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))) - c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(18), arg=1, src=()) - c3 = c2.view(ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),))) - c4 = c3.load() - c5 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()) - c6 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18), arg=2, src=()) - c7 = c6.view(ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),))) - c8 = c7.load() - c9 = c1.store(c4.alu(Ops.CMPNE, UOp.const(dtypes.long, -1, src=c5)).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c5)).where(UOp.const(dtypes.float, 0.0, src=c5), c8).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (0, 1)))) - ast = c9.sink() - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)] - prg = get_program(ast, Device[Device.DEFAULT].renderer, opts) - print(prg.src) - load_idxs = [x.src[1] for x in prg.uops if x.op is Ops.LOAD and x.src[0].arg == 2] - assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!" - - @unittest.expectedFailure - @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4") - @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "this is somehow correct in PTX") - def test_upcasted_stores_out_of_order(self): - c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9360), arg=0, src=()) - c1 = c0.view(ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),))) - c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(144), arg=1, src=()) - c3 = c2.view(ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),))) - c4 = c3.load() - c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1040), arg=2, src=()) - c6 = c5.view(ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))) - c7 = c6.load() - c8 = c1.store((c4*c7).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (6,)))) - ast = c8.sink() - opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)] - prg = get_program(ast, Device[Device.DEFAULT].renderer, opts) - print(prg.src) - store_idxs = [x.src[1] for x in prg.uops if x.op is Ops.STORE] - for i in range(len(store_idxs) - 1): - first_bounds = store_idxs[i].vmin+store_idxs[i].vmax - next_bounds = store_idxs[i+1].vmin+store_idxs[i+1].vmax - assert first_bounds < next_bounds, f"first stored (max) idx {first_bounds} then {next_bounds}!" - if __name__ == '__main__': unittest.main() diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 991a9dcc93..96139465cb 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -1,6 +1,5 @@ import unittest from tinygrad import Tensor, Variable, GlobalCounters -from tinygrad.shape.shapetracker import View from tinygrad.uop.ops import sym_infer from tinygrad.dtype import dtypes from tinygrad.device import is_dtype_supported @@ -64,14 +63,6 @@ class TestSymbolicOps(unittest.TestCase): self.test_attention(imin=4, imax=5, use_symbolic=False) self.test_attention(imin=4, imax=5, use_symbolic=True) - # until this works, symbolic single kernel softmax won't - @unittest.expectedFailure - def test_attention_simple_view(self): - i = Variable("i", 2, 10) - v1 = View.create((2,4,1,i,i), ((i*4),i,0,0,1)) - v2 = View.create((2,4,1,i,i,i), (((i*i)*4),(i*i),0,0,i,1)) - self.assertIsNotNone(v1+v2) - def test_attention_training(self): with Tensor.train(): self.test_attention(dropout_p=0.0) diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index 771ff344b5..32bda7a415 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -5,8 +5,6 @@ import numpy as np from tinygrad import Tensor, dtypes, Device, TinyJit from tinygrad.device import is_dtype_supported -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.view import View from tinygrad.helpers import CI, all_same, prod random.seed(42) @@ -22,11 +20,13 @@ def consec(shape, start=1): # creates strided tensor with base set to reference tensor's base, equivalent to torch.set_() def set_(reference: Tensor, shape, strides, offset): raise NotImplementedError("need to implement without calling uop.view") + """ if reference.uop.base.realized is None: reference.realize() assert reference.uop.base.realized, "base has to be realized before setting it to strided's base" strided = Tensor(reference.uop.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)))) assert strided.uop.st.real_strides() == strides, "real_strides should equal strides for strided" return strided + """ def clone(original:Tensor): return original.clone() def copy_(src:Tensor, other:Tensor) -> Tensor: return src.clone() diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py deleted file mode 100644 index 04b62a777c..0000000000 --- a/test/unit/test_shapetracker.py +++ /dev/null @@ -1,774 +0,0 @@ -#!/usr/bin/env python -import unittest -import numpy as np -from tinygrad.dtype import dtypes, Invalid -from tinygrad.helpers import prod -from tinygrad.shape.shapetracker import ShapeTracker, View, views_to_valid_uop -from tinygrad import Variable -from tinygrad.uop.ops import UOp, Ops, graph_rewrite -from tinygrad.codegen.late.devectorizer import sym -from itertools import product - -def shapetracker_getitem(st:ShapeTracker, val:int): - valid_idx = views_to_valid_uop(st.reshape((st.size,)).views, (UOp.const(dtypes.int, val),)) - idx, valid = valid_idx.get_idx(), valid_idx.get_valid() - idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym) - assert idx.op is Ops.CONST and valid.op is Ops.CONST - return idx.arg, valid.arg - -class CheckingShapeTracker: - def __init__(self, shape): - self.st = ShapeTracker.from_shape(shape) - self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape) - - @property - def shape(self): - return self.t.shape - - def simplify(self): - self.st = self.st.simplify() - return self - - def reshape(self, new_shape): - self.st = self.st.reshape(new_shape) - self.t = self.t.reshape(new_shape) - return self - - def permute(self, axis): - self.st = self.st.permute(axis) - self.t = np.transpose(self.t, axis) - return self - - def expand(self, new_shape): - self.st = self.st.expand(new_shape) - self.t = np.broadcast_to(self.t, new_shape) - return self - - def flip(self, arg): - self.st = self.st.flip(arg) - self.t = np.flip(self.t, tuple(i for i in range(len(arg)) if arg[i])) - return self - - def shrink(self, arg): - self.st = self.st.shrink(arg) - self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])] - return self - - def pad(self, arg): - self.st = self.st.pad(arg) - self.t = np.pad(self.t, arg, constant_values=-1) - return self - - def __getitem__(self, val): - return self.t.flatten()[val] - - @property - def views(self): return self.st.views - - @property - def contiguous(self): return self.st.contiguous - - def assert_same(self): - x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] and v[0] is not Invalid else -1) for i in range(prod(self.st.shape))] - y = [self[i] for i in range(prod(self.shape))] - assert self.st.shape == self.shape - assert x == y, f"mismatch shapetracker:{x} real:{y}" - -@unittest.skip("don't create shapetrackers with views") -class TestRealIssues(unittest.TestCase): - def test_reshape_doesnt_multiview(self): - self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),)) - self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2)) - assert len(self.st.views) == 1 - - def test_reshape_stable_diffusion(self): - # regression test for https://github.com/tinygrad/tinygrad/pull/2616 - st = ShapeTracker((View((2, 1920, 32, 32), (1310720, 1024, 32, 1), 0, ((0, 2), (0, 1280), (0, 32), (0, 32)), False),)) - st = st.reshape((2, 32, 240, 256)) - assert len(st.views) == 2 - - def test_reshape_trailing_invalid_ones(self): - st = ShapeTracker((View(shape=(1, 1, 5), strides=(0, 0, 1), offset=-5, mask=((1, 1), (0, 1), (0, 5)), contiguous=False),)) - st = st.reshape((5,)) - assert len(st.views) == 1 - assert st.views[0].mask == ((0,0),) - -class TestRealDoesntSimplify(unittest.TestCase): - def tearDown(self): - self.st = self.st.simplify() - assert len(self.st.views) != 1 - - def test_1(self): - self.st = ShapeTracker(( - View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None), - View.create((8, 6, 11), (66, 11, 1), 0, None))) - self.assertEqual(self.st.is_expanded(), (False, False, False)) - - def test_2(self): - self.st = ShapeTracker(( - View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None), - View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None))) - self.assertEqual(self.st.is_expanded(), (False, False, False, False)) - -class TestRealStrides(unittest.TestCase): - def test_1(self): - st = ShapeTracker(( - View.create((2048,), (1,), 0, ((0, 512),)), - View.create((16, 32, 4), (128, 4, 1), 0, None), - )) - self.assertEqual(st.is_expanded(), (False, False, False)) - - def test_2(self): - # test/test_ops.py::TestOps::test_simple_padding_conv1d - st = ShapeTracker(( - View.create((6, 2, 5, 14), (90, 45, 1, 5), 0, ((0, 6), (0, 2), (0, 5), (0, 9))), - View.create((6, 2, 78), (140, 70, 1), 0, ((0, 6), (0, 2), (0, 70))), - View.create((6, 2, 13, 6), (156, 78, 1, 13), 0, None), - )) - self.assertEqual(st.is_expanded(), (False, False, False, False)) - - def test_3(self): - # test/test_ops.py::TestOps::test_simple_cumsum - st = ShapeTracker(( - View.create((4, 256, 512), (256, 0, 1), 0, ((0, 4), (0, 256), (0, 256))), - View.create((4, 131327), (131072, 1), 0, ((0, 4), (0, 131072))), - View.create((4, 511, 257), (131327, 1, 511), 0, None), - )) - self.assertEqual(st.is_expanded(), (False, False, False)) - - def test_4(self): - # test/test_nn.py::TestNN::test_conv_transpose1d - st = ShapeTracker(( - View.create((4, 16, 56, 2), (896, 56, 1, 0), 0, ((0, 4), (0, 16), (0, 56), (0, 1))), - View.create((1, 4, 1, 16, 8, 121), (0, 1792, 0, 112, 0, 1), -5, ((0, 1), (0, 4), (0, 1), (0, 16), (0, 8), (5, 116))), - View.create((4, 64, 115, 16, 7), (15488, 0, 1, 968, 122), 0, None), - )) - self.assertEqual(st.is_expanded(), (False, True, False, False, False)) - - def test_5(self): - # test/test_ops.py::TestOps::test_conv2d - st = ShapeTracker(( - View.create((1, 3, 1, 12, 2, 8), (0, 132, 0, 12, 1, 2), 0, ((0, 1), (0, 3), (0, 1), (0, 11), (0, 2), (0, 6))), - View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))), - View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None), - )) - self.assertEqual(st.is_expanded(), (False, False, False, True, False)) - -class TestIndexExpressions2d(unittest.TestCase): - def setUp(self): - shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5 - offsets = [0, 1, 15, 28, 10000] - self.sts = [ShapeTracker.from_shape((prod(base_shape)+offset,)).shrink(((offset, offset+prod(base_shape)),)).\ - reshape(base_shape) for base_shape in shapes for offset in offsets] - self.offset = [offset for base_shape in shapes for offset in offsets] - self.shapes = [shape for shape in shapes for offset in offsets] - self.idxs_exprs = [] - - def tearDown(self): - for st, offset, shape, idxs_expr in zip(self.sts, self.offset, self.shapes, self.idxs_exprs): - numel = prod(shape) - self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel) - idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)] - idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)] - idx2s = [(0,0), (0, min(1, st.shape[2]-1)), (0, st.shape[2]-1), (min(3, st.shape[2]-1), min(6, st.shape[2]-1)), - (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s] - for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s): - idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None] - self.check_bounds(idxs_expr(idxs), offset, numel) - - def default_idx(self, shape): - return Variable("idx", 0, prod(shape)-1) - - def default_idxs(self, shape): - return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)] - - def check_bounds(self, expr, offset, numel): - assert expr.vmin >= offset - assert expr.vmax <= offset + numel - 1 - - def test_noop(self): - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset) - - def test_permute(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.permute((1, 0)) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset) - new_st.append(st) - self.sts = new_st - - def test_reshape(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.reshape((base_shape[0], 1, base_shape[1])) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset) - new_st.append(st) - self.sts = new_st - - def test_reshape_expand(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.reshape((base_shape[0], 1, base_shape[1])) - st = st.expand((base_shape[0], base_shape[1], base_shape[1])) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset) - new_st.append(st) - self.sts = new_st - - def test_permute_reshape_1(self): # This tests multiple views - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.permute((1, 0)) - st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5)) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + \ - (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset) - new_st.append(st) - self.sts = new_st - - def test_permute_reshape_2(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.permute((1, 0)) - st = st.reshape((1, base_shape[0]//5, base_shape[1]*5)) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + \ - (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset) - new_st.append(st) - self.sts = new_st - - def test_reshaping_splitting(self): - self.st = CheckingShapeTracker((5,10,5,10)) - self.st.permute((1, 0, 3, 2)) - self.st.pad(((0,0), (0,5), (0,0), (0,5))) - self.st.reshape((10,2,5,10,2,5)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_splitting_1(self): - self.st = CheckingShapeTracker((1,10,1)) - self.st.pad(((0,4),(0,0),(1,0))) - self.st.reshape((5,5,2,2)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_combining_1(self): - self.st = CheckingShapeTracker((2,1,10)) - self.st.pad(((2,6), (0,0), (0,0))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_combining_2(self): - self.st = CheckingShapeTracker((1,1,5)) - self.st.pad(((3,6), (0,0), (0,5))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_combining_3(self): - self.st = CheckingShapeTracker((1,1,4)) - self.st.pad(((3,6), (0,0), (1,5))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - assert self.st.views[0].mask[0] == (31, 35) - self.st.assert_same() - - def test_reshape_combining_4(self): - # interestingly this one is quite slow - self.st = CheckingShapeTracker((1,1,5,5,1,1,5)) - self.st.pad(((2,1), (0,0), (0,2), (0,0), (2,1), (0,0), (0,2))) - self.st.reshape((28,5,28)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_splitting_combining(self): - self.st = CheckingShapeTracker((1,5,5)) - self.st.pad(((0,4), (0,5), (0,0))) - self.st.reshape((10,25)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_only_1s(self): - self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1)) - self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0))) - self.st.reshape((5, 6, 3, 5)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_zero_mask_1(self): - self.st = CheckingShapeTracker((1, 3, 2)) - self.st.pad(((0,0), (0,3), (0,0))) - self.st.shrink(((0,1), (3,6), (0,2))) - self.st.reshape((3,2)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 3, 1, 2, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_zero_mask_2(self): - self.st = CheckingShapeTracker((1, 3, 2)) - self.st.pad(((0,2), (0,3), (0,0))) - self.st.shrink(((2,3), (3,6), (0,2))) - self.st.reshape((3,2)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 3, 1, 2, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_expanded_reshaped(self): - self.st = CheckingShapeTracker((1, 3, 2, 1)) - self.st.expand((5, 3, 2, 2)) - self.st.pad(((0,0), (0,3), (0,0), (0, 0))) - self.st.reshape((5, 2, 3, 2, 2)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_splitting_big(self): - self.st = CheckingShapeTracker((1, 5, 1, 15, 1)) - self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0))) - self.st.reshape((10, 1, 30)) - self.st.permute((2,1,0)) - self.st.reshape((2,3,5,2,5)) - assert len(self.st.views) == 1 - v = self.st.views[-1] - assert v.strides == (0, 5, 1, 0, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5)) - self.st.assert_same() - - def test_combining_big(self): - self.st = CheckingShapeTracker((1,3,1,5,3,1)) - self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0))) - self.st.reshape((1,1,1,105,1,1)) - assert len(self.st.views) == 1 - v = self.st.views[-1] - assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)) and v.offset == -30 - self.st.assert_same() - - def test_pad_reshape(self): - self.st = CheckingShapeTracker((4,)) - self.st.pad(((2,2),)) - self.st.reshape((4,2)) - assert len(self.st.views) == 1 - self.st.assert_same() - -class TestSimplifyingShapeTracker(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((1, 10)) - - def tearDown(self): - self.st.assert_same() - - # multiview simplify - def test_expand_contract_simple(self): - self.st = self.st.expand((10, 10)) - self.st = self.st.reshape((100,)) - print(self.st.views) - assert (len(self.st.views) == 2) - self.st = self.st.reshape((10, 10)) - print(self.st.views) - - self.st = self.st.simplify() - print(self.st.views) - assert (len(self.st.views) == 1) - - # multiview simplify - def test_expand_contract_different_shape(self): - self.st.expand((10, 10)) - self.st.reshape((100,)) - print(self.st.views) - assert (len(self.st.views) == 2) - self.st.reshape((2, 5, 2, 5)) - print(self.st.views) - - self.st = self.st.simplify() - print(self.st.views) - assert (len(self.st.views) == 1) - - # multiview simplify - def test_expand_contract_still_complex(self): - self.st.expand((10, 10)) - self.st.reshape((100,)) - print(self.st.views) - assert (len(self.st.views) == 2) - self.st.reshape((5, 20)) - - self.st = self.st.simplify() - print(self.st.views) - assert (len(self.st.views) == 2) - -# Tensor.zeros(2, 4).permute(1,0).reshape(2, 4) -# (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4 - -class TestComplexShapeTracker(unittest.TestCase): - def test_add_1s(self): - self.st = CheckingShapeTracker((4, 4)) - self.st.permute((1,0)) - self.st.reshape((1,4,1,4,1)) - assert not self.st.contiguous - self.st.permute((0,3,2,1,4)) - assert self.st.contiguous - - def test_permute_1s_simple(self): - self.st = CheckingShapeTracker((1, 16, 9,9)) - self.st.permute((1,0,2,3)) - assert self.st.contiguous - self.st = CheckingShapeTracker((2, 16, 9,9)) - self.st.permute((1,0,2,3)) - assert not self.st.contiguous - - def test_remove_1s_simple(self): - self.st = CheckingShapeTracker((1, 16, 1, 1)) - self.st.reshape((16,)) - assert self.st.contiguous - - def test_remove_1s(self): - self.st = CheckingShapeTracker((1, 4, 1, 4, 1)) - self.st.permute((0,3,2,1,4)) - self.st.reshape((4,4)) - assert not self.st.contiguous - self.st.permute((1,0)) - assert self.st.contiguous - - def test_permute_reshape(self): - self.st = CheckingShapeTracker((4, 4)) - self.st.permute((1,0)) - self.st.reshape((2, 2, 2, 2)) - # TODO: should also be tested by test_super_complex - assert len(self.st.views) == 1 - - def test_factorize_split(self): - self.st = CheckingShapeTracker((4, 4)) - self.st.permute((1,0)) - self.st.reshape((2, 2, 2, 2)) - self.st.permute((2,3,0,1)) - assert self.st.contiguous - - def test_factorize_combine(self): - self.st = CheckingShapeTracker((4, 4, 4)) - self.st.permute((2, 0, 1)) - self.st.reshape((4, 16)) - self.st.permute((1, 0)) - assert self.st.contiguous - - def test_factorize_combine_add_ones(self): - self.st = CheckingShapeTracker((4, 4, 4)) - self.st.permute((2, 0, 1)) - self.st.reshape((4, 16, 1, 1)) - self.st.permute((1, 0, 2, 3)) - assert self.st.contiguous - - def test_fancy_factorize(self): - self.st = CheckingShapeTracker((32, 3, 3, 1)) - self.st.reshape((8, 4, 3, 3)) - assert len(self.st.views) == 1 - - def test_super_complex_2_fail(self): - self.st = CheckingShapeTracker((4, 4, 4)) - self.st.permute((2, 0, 1)) - self.st.reshape((16, 4)) - assert len(self.st.views) != 1 - - def test_work(self): - self.st = CheckingShapeTracker((64, 1024, 4)) - self.st.reshape((1, 64, 128, 32)) - self.st.permute((0, 3, 1, 2)) - self.st.reshape((1, 32, 1, 64, 128)) - self.st.permute((0, 3, 4, 1, 2)) - assert self.st.contiguous - - def test_work2(self): - self.st = CheckingShapeTracker((64, 1024, 4)) - self.st.reshape((1, 64, 128, 32)) - self.st.permute((0, 3, 1, 2)) - self.st.reshape((1, 1, 32, 64, 128)) - self.st.permute((0, 3, 4, 1, 2)) - self.st.reshape((64, 1024, 4)) - print(self.st.views) - assert self.st.contiguous - -class TestShapeTrackerEquality(unittest.TestCase): - def test_simple_equals(self): - self.assertEqual(ShapeTracker.from_shape((10,10)), ShapeTracker.from_shape((10,10))) - def test_other_equals(self): - st1 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True))) - st2 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True))) - self.assertEqual(st1, st2) - -class TestSingleShapeTracker(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((7,4)) - - def tearDown(self): - self.st.assert_same() - - def test_reshape(self): - self.st.reshape((7,1,4)) - assert self.st.contiguous - - def test_permute(self): - self.st.permute((1,0)) - assert not self.st.contiguous - - def test_shrink(self): - self.st.shrink(((1,2), (0,4))) - assert not self.st.contiguous - - def test_double_permute(self): - self.st.permute((1,0)) - self.st.permute((1,0)) - assert self.st.contiguous - - def test_reshape_permute(self): - self.st.reshape((7,1,4)) - self.st.permute((0,1,2)) - assert self.st.contiguous - - def test_reshape_permute_yes(self): - self.st.reshape((7,1,4)) - self.st.permute((0,2,1)) - assert self.st.contiguous - - def test_reshape_permute_no(self): - self.st.reshape((4,7)) - self.st.permute((1,0)) - assert not self.st.contiguous - -class TestShapeTrackerFuzzFailures(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((3,3,3)) - def tearDown(self): - self.st.assert_same() - def test_case_1(self): - self.st.shrink(((1, 2), (1, 3), (1, 3))) - self.st.reshape((1, 4)) - self.st.shrink(((0, 1), (1, 3))) - self.st = self.st.simplify() - def test_case_2(self): - self.st.flip( (True, False, True) ) - self.st.reshape( (3, 9) ) - self.st.shrink( ((1, 2), (1, 5)) ) - self.st.flip( (True, True) ) - def test_case_3(self): - self.st.shrink( ((0, 2), (0, 2), (0, 1)) ) - self.st.permute( (1, 0, 2) ) - self.st.reshape( (4,) ) - self.st.shrink( ((0, 3),) ) - self.st.flip( (True, False) ) - def test_case_4(self): - self.st.reshape( (3, 3, 3, 1) ) - self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) ) - self.st.shrink( ((0, 2), (1, 2), (0, 2), (0, 1)) ) - self.st.expand( (2, 1, 2, 3) ) - -class TestMaskedShapeTracker(unittest.TestCase): - def test_pad_1x1(self): - self.st = CheckingShapeTracker((1,1)) - self.st.pad(((1,1), (1,1))) - self.st.assert_same() - - def test_pad_2x2(self): - self.st = CheckingShapeTracker((2,2)) - self.st.pad(((1,1), (1,1))) - self.st.assert_same() - - def test_pad_reshape(self): - st1 = CheckingShapeTracker((1, 2)) - st1.pad(((1, 0), (0, 1))) - st1.reshape((3, 2)) - st1.assert_same() - - st2 = CheckingShapeTracker((1, 2)) - st2.pad(((1, 1), (0, 2))) - st2.reshape((4, 3)) - st2.assert_same() - - st3 = CheckingShapeTracker((1, 1, 1, 2)) - st3.pad(((0, 2), (1, 2), (2, 2), (0, 4))) - st3.reshape((4, 3, 6, 5)) - st3.assert_same() - -class TestShapeTracker(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((7,4)) - self.apply = lambda fxn: [fxn(x) for x in [self.st]] - - def tearDown(self): - self.st.assert_same() - - def test_noop(self): - pass - - def test_simple_split(self): - self.test_permute() - self.apply(lambda x: x.reshape((prod(self.st.shape), ))) - - def test_simple_pad(self): - self.st.pad(((1,1), (1,1))) - - def test_pad_shrink(self): - self.st.pad(((1,1), (1,1))) - self.st.shrink(((0,4), (0,4))) - - def test_pad_one_sided(self): - self.st.pad(((0,1), (0,0))) - - def test_pad_reshape(self): - self.st.pad(((0,1), (0,0))) - self.st.reshape((8*4,)) - - def test_pad_pad(self): - self.st.pad(((1,1), (1,1))) - self.st.pad(((1,1), (1,1))) - - def test_pad_permute(self): - self.st.pad(((1,1), (2,2))) - self.st.permute((1,0)) - - def test_pad_expand(self): - self.st.reshape((7,4,1)) - self.st.pad(((1,1), (1,1), (0,0))) - self.st.expand((9,6,4)) - - def test_pad_expand_alt(self): - self.st.pad(((1,1), (1,1))) - self.st.reshape((9,6,1)) - self.st.expand((9,6,4)) - - def test_pad_flip(self): - self.st.pad(((1,4), (1,3))) - self.st.flip((True, False)) - - def test_pad_flip_int(self): - self.st.pad(((1,4), (1,3))) - self.st.flip((0, 1)) - - def test_reshape(self): - new_shape = self.st.shape[::-1] - self.apply(lambda x: x.reshape(new_shape)) - - def test_permute(self): - if len(self.st.shape) == 2: self.apply(lambda x: x.permute((1,0))) - elif len(self.st.shape) == 3: self.apply(lambda x: x.permute((2,0,1))) - - def test_reshape_with_1(self): - new_shape = (self.st.shape[0], 1, self.st.shape[1]) - self.apply(lambda x: x.reshape(new_shape)) - - def test_expand(self): - self.test_reshape_with_1() - new_shape = list(self.st.shape) - new_shape[1] = 2 - self.apply(lambda x: x.expand(tuple(new_shape))) - - def test_flip_0(self): - self.apply(lambda x: x.flip((True, False))) - - def test_flip_1(self): - self.apply(lambda x: x.flip((False, True))) - - def test_flip_01(self): - self.apply(lambda x: x.flip((True, True))) - - def test_slice_0(self): - self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1])))) - - def test_slice_1(self): - self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1])))) - - def test_slice_1c1(self): - self.apply(lambda x: x.shrink(((0, 1), (0, 1)))) - - def test_slice_1c2(self): - self.apply(lambda x: x.shrink(((1, 2), (1, 2)))) - - def test_double_permute(self): - self.apply(lambda x: x.permute((1, 0))) - self.apply(lambda x: x.permute((1, 0))) - - def test_slice_permute(self): - self.apply(lambda x: x.shrink(((0, 2), (2, 4)))) - self.apply(lambda x: x.permute((1, 0))) - - def test_slice_expand(self): - self.apply(lambda x: x.shrink(((0, 2), (3, 4)))) - self.apply(lambda x: x.expand((2, 10))) - - def test_double_flip(self): - self.apply(lambda x: x.flip((True, False))) - self.apply(lambda x: x.flip((True, False))) - - def test_flip(self): self.apply(lambda x: x.flip((True, False))) - def test_flip2(self): self.apply(lambda x: x.flip((False, True))) - def test_flip3(self): self.apply(lambda x: x.flip((True, True))) - - def test_reshape_then_permute(self): - self.test_reshape() - self.test_permute() - - def test_reshape_then_expand(self): - self.test_reshape() - self.test_expand() - - def test_permute_then_reshape(self): - self.test_permute() - self.test_reshape() - - def test_expand_then_reshape(self): - self.test_expand() - self.test_reshape() - - def test_combo(self): - self.test_permute() - self.test_reshape() - self.test_slice_1() - self.test_expand() - self.test_permute() - -class TestVariableShrink(unittest.TestCase): - def test_shrink(self): - st = ShapeTracker.from_shape((10,)) - st = st.shrink(((0, Variable("i", 1, 10)),)) - assert len(st.views) == 1 - - def test_shrink_bound(self): - st = ShapeTracker.from_shape((10,)) - st = st.shrink(((0, Variable("i", 1, 10).bind(3)),)) - assert len(st.views) == 1 - -class TestVariableMerge(unittest.TestCase): - def test_add_reshape(self): - vi = Variable("i", 1, 10) - st1 = ShapeTracker.from_shape((vi,)) - st2 = ShapeTracker.from_shape((1, vi,)) - st = st1+st2 - assert len(st.views) == 1 - - def test_add_stride_0(self): - st1 = ShapeTracker.from_shape((3,), (0,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),), (0,)) - st = st1+st2 - assert len(st.views) == 1, f"multiview {st}" - - def test_add_reshape_bound(self): - vi = Variable("i", 1, 10).bind(3) - st1 = ShapeTracker.from_shape((vi,)) - st2 = ShapeTracker.from_shape((1, vi,)) - st = st1+st2 - assert len(st.views) == 1 - - def test_simplify(self): - vi = Variable("i", 1, 10).bind(3) - st1 = ShapeTracker.from_shape((vi,)) - st2 = ShapeTracker.from_shape((1, vi,)) - st = ShapeTracker((st1.views[0], st2.views[0])) - st = st.simplify() - assert len(st.views) == 1 - -if __name__ == '__main__': - unittest.main() diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py deleted file mode 100644 index 38808c2d23..0000000000 --- a/test/unit/test_shapetracker_math.py +++ /dev/null @@ -1,108 +0,0 @@ -import unittest -from tinygrad.helpers import prod -from tinygrad.shape.view import View -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad import Variable -from test.unit.test_shapetracker import shapetracker_getitem - -class MultiShapeTracker: - def __init__(self, sts:list[ShapeTracker]): self.sts = sts - @property - def shape(self): return self.sts[0].shape - def reshape(self, arg): self.sts = [x.reshape(arg) for x in self.sts] - def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts] - def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts] - def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts] - def flip(self, arg): self.sts = [x.flip(arg) for x in self.sts] - def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts] - -def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool: - if st1.shape != st2.shape: return False - if st1 == st2: return True - for i in range(0, prod(st1.shape)): - st1_off, st1_v = shapetracker_getitem(st1, i) - st2_off, st2_v = shapetracker_getitem(st2, i) - if st1_v != st2_v or (st1_off != st2_off and st1_v): - print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}") - print(st1) - print(st2) - return False - return True - -class TestShapeTrackerBasics(unittest.TestCase): - def test_pad_shrink_removes_mask(self): - a = ShapeTracker.from_shape((10, 10)) - a = a.pad(((0,2), (0,2))) - a = a.shrink(((0,10), (0,10))) - assert len(a.views) == 1 and a.views[-1].mask is None - - def test_pad_shrink_leaves_mask(self): - a = ShapeTracker.from_shape((10, 10)) - a = a.pad(((0,2), (0,2))) - a = a.shrink(((0,10), (0,11))) - assert len(a.views) == 1 and a.views[-1].mask is not None - - def test_reshape_makes_same(self): - a = ShapeTracker.from_shape((2, 5)) - x = a.pad( ((2, 0), (0, 0)) ) - x = x.reshape( (2, 2, 5) ) - x1 = x.reshape( (4, 5) ) - x1 = x1.reshape( (2, 2, 5) ) - assert x == x1.simplify() - - def test_simplify_is_correct(self): - multiv = ShapeTracker(views=(View(shape=(15, 3), strides=(9, 1), offset=6, mask=None, contiguous=False), - View(shape=(4, 3), strides=(12, 4), offset=0, mask=None, contiguous=False))) - assert st_equal(multiv, multiv.simplify()) - -class TestShapeTrackerAdd(unittest.TestCase): - def test_simple_add_reshape(self): - a = ShapeTracker.from_shape((10, 10)) - a = a.reshape((100,)) - b = ShapeTracker.from_shape((100,)) - assert a+b == b - - @unittest.skip("no longer simplifies") - def test_simple_add_permute(self): - a = ShapeTracker.from_shape((10, 10)) - a = a.permute((1,0)) - b = ShapeTracker.from_shape((10, 10)) - b = b.permute((1,0)) - assert a+b == ShapeTracker.from_shape((10, 10)) - - def test_plus_real1(self): - st = MultiShapeTracker([ShapeTracker.from_shape((15, 9))]) - st.shrink( ((0, 15), (6, 9)) ) - backup = st.sts[0] - st.sts.append(ShapeTracker.from_shape(backup.shape)) - st.reshape( (45,) ) - st.flip( (True,) ) - st.reshape( (15, 3) ) - assert st_equal(backup + st.sts[1], st.sts[0]) - - def test_off_by_one(self): - st1 = ShapeTracker(views=(View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True), - View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True))) - st2 = ShapeTracker(views=(View(shape=(4,), strides=(1,), offset=0, mask=None, contiguous=True), - View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True))) - assert not (st_equal(st1, st2)) - -class TestShapeTrackerAddVariable(unittest.TestCase): - def test_merge_symbolic_views(self): - var_i = Variable('i', 1, 10) - var_j = Variable('i', 1, 10) - vm1 = View(shape=(var_i, var_j, 3), strides=(3, 0, 1), offset=0, mask=None, contiguous=False) - vm2 = View(shape=(var_i, var_j, 3), strides=(var_j*3, 3, 1), offset=0, mask=None, contiguous=True) - ShapeTracker((vm1,)) + ShapeTracker((vm2,)) - - def test_merge_symbolic_views_2(self): - var_i = Variable('i', 1, 10) - var_j = Variable('j', 1, 10) - vm1 = View(shape=(var_i, var_j), strides=(0, 0), offset=0, mask=None, contiguous=False) - vm2 = View(shape=(var_i, var_j), strides=(var_j, 1), offset=0, mask=None, contiguous=True) - ret = (ShapeTracker((vm1,)) + ShapeTracker((vm2,))).reshape((var_i, var_j, 1)) - ret_2 = ShapeTracker((vm1,)) + ShapeTracker((vm2,)).reshape((var_i, var_j, 1)) - assert ret == ret_2 - -if __name__ == '__main__': - unittest.main() diff --git a/test/unit/test_symbolic_shapetracker.py b/test/unit/test_symbolic_shapetracker.py index 4f0824b947..8d876c2a9f 100644 --- a/test/unit/test_symbolic_shapetracker.py +++ b/test/unit/test_symbolic_shapetracker.py @@ -1,5 +1,4 @@ import unittest -from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad import Variable from tinygrad.tensor import Tensor @@ -7,40 +6,6 @@ class TestSymbolic(unittest.TestCase): def assert_tuple_equal(self, x, y): for a,b in zip(x,y): self.assertFalse(a != b) - def test_symbolic_st(self): - x = Variable("x", 1, 100) - st = ShapeTracker.from_shape((x, 3)) - self.assert_tuple_equal(st.shape, (x, 3)) - self.assert_tuple_equal(st.is_expanded(), (False, False)) - - def test_is_expanded_0(self): - st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501 - self.assert_tuple_equal(st.is_expanded(), (False, False)) - - def test_is_expanded_1(self): - st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 - self.assert_tuple_equal(st.is_expanded(), (False, False)) - - def test_is_expanded_2(self): - st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 - self.assert_tuple_equal(st.is_expanded(), (False, False)) - - def test_merge_view_recursion_err(self): - vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False) - vm1 = View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True) - self.assertEqual(vm2+vm1, None) - - def test_merge_view_recursion_err2(self): - vm2 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(0,), offset=0, mask=None, contiguous=False) - # NOTE: vm1 is different from what create function would give, and this test vm2+vm1 halts - vm1 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(1,), offset=0, mask=((0, Variable('a', 1, 10).bind(4)),), contiguous=False) - self.assertEqual(vm2+vm1, None) - - vm3 = View.create(shape=(Variable('a', 1, 10).bind(4),)) - self.assertEqual(vm3.shape, vm1.shape) - self.assertEqual(vm3.strides, vm1.strides) - self.assertEqual(vm2+vm3, vm2) - def test_cat_dim0_is_expanded(self): i = Variable("i", 1, 5).bind(3) j = Variable("j", 1, 5).bind(3) @@ -59,46 +24,6 @@ class TestSymbolic(unittest.TestCase): class TestSymbolicVarVals(unittest.TestCase): def assert_equal(self, x, y): self.assertFalse(x != y) - def test_var_vals_empty(self): - assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {} - - def test_var_vals_shape(self): - x = Variable("x", 1, 100).bind(3) - assert ShapeTracker.from_shape((x, 3)).var_vals == {"x": 3} - - def test_var_vals_offset(self): - x = Variable("x", 1, 100).bind(3) - st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3))) - self.assert_equal(st.views[-1].offset, x * 3) - assert st.var_vals == {"x": 3} - - def test_var_vals_mask(self): - x = Variable("x", 1, 100).bind(3) - view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4))) - st = ShapeTracker(views=(view,)) - assert st.var_vals == {"x": 3} - - def test_var_vals_complex(self): - x = Variable("x", 1, 100).bind(3) - y = Variable("y", 1, 100).bind(4) - z = Variable("z", 1, 100).bind(5) - st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3))) - self.assert_equal(st.views[-1].offset, y * z) - assert st.var_vals == {"x": 3, "y": 4, "z": 5} - - def test_shrink_reshape(self): - x = Variable("x", 1, 100).bind(3) - st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5))) - st = st.reshape((3*4*3,)) - assert st.var_vals == {"x": 3} - -class TestShapeTrackerUnbind(unittest.TestCase): - def test_view_unbind(self): - v = Variable("v", 1, 100) - bv = Variable("v", 1, 100).bind(3) - unbound_view, var_val = View.create(shape=(bv, 4)).unbind() - assert unbound_view == View.create(shape=(v, 4)) - assert var_val == {v: 3} def test_shrink_unbind(self): v = Variable("v", 1, 100) @@ -137,17 +62,6 @@ class TestSymbolicReshape(unittest.TestCase): ret = ret.reshape(1, vi*vj) assert ret.shape == (1, vi*vj) - def test_symbolic_mask(self): - # taken from gpt2 single kvcache - # these two caused problems in gpt2 if reshape merged views - view = View(shape=(1, (Variable('start_pos', 1, 128).bind(2)+1), 16, 64), strides=(0, 0, 64, 1), offset=1024, mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (Variable('start_pos', 1, 128).bind(2)+1)), (0, 16), (0, 64)), contiguous=False) # noqa: E501 - new_shape = (1, 1, (Variable('start_pos', 1, 128).bind(2)+1), 16, 64) - assert view.reshape(new_shape) is None - - view = View(shape=(2, 1, (Variable('start_pos', 1, 128)+1), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (Variable('start_pos', 1, 128)+1)), (0, 16), (0, 64)), contiguous=False) # noqa: E501 - new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64) - assert view.reshape(new_shape) is None - class TestSymbolicExpand(unittest.TestCase): def test_expand_into_symbols(self): vi = Variable("i", 1, 5).bind(3) @@ -190,6 +104,5 @@ class TestSymbolicPad(unittest.TestCase): t = t[:9] assert t.tolist() == [0,0,0,0,1,1,1,1,1] - if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_view.py b/test/unit/test_view.py deleted file mode 100644 index 440755ceba..0000000000 --- a/test/unit/test_view.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -import unittest -from tinygrad.shape.view import View, merge_dims -# from tinygrad.shape.shapetracker import ShapeTracker - -class TestView(unittest.TestCase): - def test_canonicalize_empty_mask(self): - v = View.create(shape=(2,2,2), strides=(4,2,1), mask=((0,2),(0,2),(0,2))) - self.assertIsNone(v.mask) - v = View.create(shape=(4,3,2), strides=(1,4,10), mask=((0,4),(0,3),(0,2))) - self.assertIsNone(v.mask) - - def test_empty_mask_contiguous(self): - v1 = View.create(shape=(2,2,2), strides=(4,2,1), mask=None) - v2 = View.create(shape=(2,2,2), strides=(4,2,1), mask=((0,2),(0,2),(0,2))) - self.assertEqual(v1.contiguous, v2.contiguous) - v1 = View.create(shape=(1,1,1,4), strides=(0,0,0,1), offset=0, mask=None) - v2 = View.create(shape=(1,1,1,4), strides=(0,0,0,1), offset=0, mask=((0,1),(0,1),(0,1),(0,4))) - self.assertEqual(v1.contiguous, v2.contiguous) - v = View.create(shape=(2,3,4), mask=((0,2),(0,3),(0,4))) - self.assertTrue(v.contiguous) - - def test_reshape_all_invalid(self): - v = View.create((4,5), mask=((0,0), (0,0))).reshape((20,)) - self.assertIsNotNone(v) - self.assertEqual(v, View.create((20,), mask=((0,0),))) - - def test_add_0(self): - v1 = View.create((2,3,4)) - v2 = View.create((2,0,4)) - self.assertEqual(v2, v1+v2) - - def test_add_0_masked(self): - v1 = View.create((2,3,4), mask=((0, 0), (0, 0), (0, 0))) - v2 = View.create((2,0,4)) - self.assertEqual(v2, v1+v2) - -class TestMergeDims(unittest.TestCase): - def test_contiguous(self): - shape = (2, 3, 4) - strides = (12, 4, 1) #=strides_for_shape(shape) - m = merge_dims(shape, strides) - self.assertEqual(m, ((24, 1, 24),)) - - def test_0_in_strides(self): - shape = (2, 3, 4) - self.assertEqual(merge_dims(shape, (0, 4, 1)), ((2, 0, 0), (12, 1, 12))) - self.assertEqual(merge_dims(shape, (0, 0, 1)), ((6, 0, 0), (4, 1, 4))) - self.assertEqual(merge_dims(shape, (3, 1, 0)), ((6, 1, 6), (4, 0, 4))) - self.assertEqual(merge_dims(shape, (0, 0, 0)), ((24, 0, 0),)) - - def test_pad(self): - # print(ShapeTracker.from_shape((1, 2)).pad(((1, 0), (0, 1))).views[-1]) - self.assertEqual(merge_dims((2, 3), (0, 1), ((1, 2), (0, 2))), ((6, 1, 3),)) - - # print(f"{ShapeTracker.from_shape((1, 1, 2)).pad(((1, 0), (1, 0), (0, 1))).views[-1]}") - self.assertEqual(merge_dims((2, 2, 3), (0, 0, 1), ((1, 2), (1, 2), (0, 2))), ((12, 1, 3),)) - - # print(f"{ShapeTracker.from_shape((1, 1, 2, 2)).pad(((1, 0), (1, 0), (0, 1), (0, 1))).views[-1]}") - self.assertEqual(merge_dims((2, 2, 3, 3), (0, 0, 2, 1), ((1, 2), (1, 2), (0, 2), (0, 2))), ((12, 2, 3), (3, 1, 3))) - - # print(f"{ShapeTracker.from_shape((2, 1, 2)).pad(((0, 0), (1, 0), (0, 1))).views[-1]}") - self.assertEqual(merge_dims((2, 2, 3), (2, 0, 1), ((0, 2), (1, 2), (0, 2))), ((2, 2, 2), (6, 1, 3))) - - def test_different_1_pad(self): - # print(f"{ShapeTracker.from_shape((2, 2, 1)).pad(((0, 0), (0, 0), (0, 1))).views[-1]}") - self.assertEqual(merge_dims((2, 2, 2), (2, 1, 0), ((0, 2), (0, 2), (0, 1))), ((4, 1, 4), (2, 0, 2))) - - # print(f"{ShapeTracker.from_shape((2, 1, 1)).pad(((0, 0), (0, 1), (0, 1))).views[-1]}") - self.assertEqual(merge_dims((2, 2, 2), (1, 0, 0), ((0, 2), (0, 2), (0, 1))), ((2, 1, 2), (4, 0, 4))) - -if __name__ == '__main__': - unittest.main() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 6d38681dcf..7b3936b20d 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -2,7 +2,7 @@ from __future__ import annotations import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools from dataclasses import dataclass, field -from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator +from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast T = TypeVar("T") U = TypeVar("U") @@ -86,6 +86,16 @@ def word_wrap(x, wrap=80): return x[:i] + "\n" + word_wrap(x[i:], wrap) def pad_bytes(b:bytes, align:int) -> bytes: return b + b'\x00' * ((align - (len(b) % align)) % align) +@functools.cache +def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]: + return tuple(cast(T, 0) if s == 1 else st for s, st in zip(shape, strides)) + +@functools.cache +def strides_for_shape(shape:tuple[T, ...]) -> tuple[T, ...]: + if not shape: return () + strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1] + return canonicalize_strides(shape, strides) + # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape def get_contraction(old_shape:tuple[T, ...], new_shape:tuple[T, ...]) -> list[list[int]]|None: # T is sint acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul)) diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 110da5ecd7..51f811b863 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -3,8 +3,7 @@ from collections import OrderedDict from typing import Any, Callable, BinaryIO, Iterable from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes -from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T -from tinygrad.shape.view import strides_for_shape +from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T, strides_for_shape class TensorIO(io.RawIOBase, BinaryIO): def __init__(self, t: Tensor): diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py deleted file mode 100644 index b4e0b584f0..0000000000 --- a/tinygrad/shape/shapetracker.py +++ /dev/null @@ -1,81 +0,0 @@ -# ShapeTracker allows movement operations to a buffer that don't require a copy to be made. -from __future__ import annotations -from dataclasses import dataclass -import functools -from typing import Callable -from tinygrad.helpers import merge_dicts, getenv -from tinygrad.shape.view import View, unravel -from tinygrad.uop.symbolic import sym -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context - -@functools.cache -def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> UOp: - idx = views[-1].to_valid_uop(_idxs) - for view in reversed(views[0:-1]): - idx = view.to_valid_uop([sint_to_uop(i) for i in unravel(view.shape, idx)]) - with Context(TRACK_MATCH_STATS=0): - return graph_rewrite(idx, sym, name="indexing sym @ 1") - -@functools.cache -def views_to_is_expanded(views: tuple[View, ...]) -> tuple[bool, ...]: - # NOTE: return if each dim is expanded - if len(views) == 1 and views[-1].mask is None: return tuple([bool(st==0) for st in views[-1].strides]) - idx = views_to_valid_uop(views).get_idx() - used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE] - return tuple([i not in used_ranges for i in range(len(views[-1].shape))]) - -@dataclass(frozen=True, order=True) -class ShapeTracker: - views: tuple[View, ...] - - def __add__(self, st:ShapeTracker) -> ShapeTracker: - ret = self - for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification - return ret - - @staticmethod - def from_shape(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None) -> ShapeTracker: return ShapeTracker((View.create(shape, strides),)) - - @property - def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous - - @property - def shape(self) -> tuple[sint, ...]: return self.views[-1].shape - - @property - def size(self) -> int: return self.views[-1].size() - - def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views]) - - @property - def var_vals(self) -> dict[str, int]: return merge_dicts([{(vu:=v.unbind())[0].expr:vu[1]} for v in self.vars()]) - - def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]: - unbound_views, var_vals = zip(*[v.unbind() for v in self.views]) - if all(len(x) == 0 for x in var_vals): return self, {} - return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals) - - def is_expanded(self) -> tuple[bool, ...]: - with Context(TRACK_MATCH_STATS=0): return views_to_is_expanded(self.views) - - def simplify(self) -> ShapeTracker: - if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: - return ShapeTracker(self.views[:-2] + (new_view,)).simplify() - return self - - # *** under this line are the movement ops *** - - def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), )) - def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), )) - def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), )) - def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), )) - def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), )) - - def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker: - if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,)) - return ShapeTracker(self.views + (View.create(new_shape), )) - - def mop(self, op, arg): return mops[op](self, arg) - -mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand, - Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad} diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py deleted file mode 100644 index 9b5e489ee9..0000000000 --- a/tinygrad/shape/view.py +++ /dev/null @@ -1,261 +0,0 @@ -from __future__ import annotations -import functools, operator, itertools -from dataclasses import dataclass -from typing import cast, Sequence -from tinygrad.dtype import dtypes -from tinygrad.uop.ops import resolve, UOp, Variable, sint, smax, smin, sint_to_uop, Ops, ssimplify -from tinygrad.helpers import prod, all_int, flatten - -@functools.cache -def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]: - return tuple(0 if s == 1 else st for s, st in zip(shape, strides)) - -@functools.cache -def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]: - if not shape: return () - strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1] - return canonicalize_strides(shape, strides) - -@functools.cache -def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:tuple[tuple[int, int], ...]|None=None) -> tuple[tuple[int, int, int], ...]: - # merge contiguous sub-parts or zero strided dims - # any stride 0, masked from dim=1, or contiguous part is merged into next dim. - # stride != 0 to stride == 0 starts a new merging block - # ret = tuple[(merged_size, stride, merged size w/o zero stride), ...] - if not shape: return () - assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask)) - ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)] - # merge this dim to next dim if size is 1 - merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1 - for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1): - # always merge 1 - if s == 1: continue - last_s, last_st, last_pre_expand_s = ret[-1] - # merge last dim with this dim if merging or strides matched - if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s)) - else: ret.append((s, st, s)) - # merge this dim to next dim if size is 1 - merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1 - return tuple(ret) - -@functools.cache -def _reshape_mask(_mask:tuple[tuple[sint, sint], ...]|None, old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \ - -> tuple[tuple[sint, sint], ...]|None: - """Returns the new mask if reshape is possible, and None if not possible.""" - if _mask is None: return tuple((0, s) for s in new_shape) - if not all_int(flatten(_mask)): return None - - new_mask: list[tuple[int, int]] = [] - # _mask is all int here - r_masks, r_shape, r_new_shape = reversed(cast(tuple[tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape) - curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - - while len(new_mask) < len(new_shape): - (l, r), next_stride = mask, ssimplify(new_dim * curr_stride) - - # need to split mask - if old_dim == next_stride: # simply copy the mask and get next batch for merging - new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1)) - curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - elif old_dim > next_stride: # mask can only be splitted if reshape doesn't cut across the mask. - if old_dim % next_stride != 0: return None - if (l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride: return None - new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1)) - curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension - else: - next_mask = next(r_masks, (0, 1)) - # combine if the mask can unfold continuously - if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None - mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), ssimplify(old_dim * next(r_shape, 1)) - - return tuple(reversed(new_mask)) - -def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]: - # find the position of offset on each dimension based on shape - # similar to unravel_index in numpy/torch - acc, idxs = 1, [] - for d in reversed(shape): - idxs.append((offset//acc)%d) - acc *= d - return idxs[::-1] - -@dataclass(frozen=True) -class View: - shape:tuple[sint, ...] - strides:tuple[sint, ...] - offset:sint - mask:tuple[tuple[sint, sint], ...]|None - contiguous:bool - - def to_valid_uop(self, idxs:Sequence[UOp]|None=None) -> UOp: - """valid.where(idx, INVALID)""" - if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)] - iexpr = sint_to_uop(self.offset) - where = UOp.const(dtypes.bool, True) - for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)): - iexpr = iexpr + idx*sint_to_uop(st) - if m is not None: - if resolve(m[0] != 0): where &= (idx >= sint_to_uop(m[0])) - if resolve(m[1] != sh): where &= (idx < sint_to_uop(m[1])) - return where.where(iexpr, UOp.invalid()) - - @functools.cache # pylint: disable=method-cache-max-size-none - def size(self) -> int: - ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape]) - assert isinstance(ret, int), f"{ret=} is not int" - return ret - - @staticmethod - @functools.cache - def create(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None, offset:sint=0, mask:tuple[tuple[sint, sint], ...]|None=None): - # TODO: resolve shouldn't be needed here - if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}") - strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape) - # canonicalize 0 in shape - if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True) - # canonicalize no-op mask - if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None - # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked - # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset - if mask and any(elim := [not resolve(b+1 < e) for b,e in mask]): - if any(not resolve(b < e) for b,e in mask): - strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape) - offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim)) - strides = tuple(0 if e else st for st,e in zip(strides, elim)) - # simplify as we go - if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify()) - shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape) - # TODO: enabling stride simplification breaks symbolic jit - """ - strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides) - if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask) - """ - contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape) - return View(shape, strides, offset, mask, contiguous) - - @functools.cache # pylint: disable=method-cache-max-size-none - def vars(self) -> set[Variable]: - flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple() - return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set()) - - @functools.cache # pylint: disable=method-cache-max-size-none - def unbind(self) -> tuple[View, dict[Variable, int]]: - var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND] - unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val} - return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val) - - def substitute(self, dvars:dict[UOp, UOp]): - def _substitute(x:sint): return x if isinstance(x, int) else x.substitute(dvars) - new_shape = tuple(map(_substitute, self.shape)) - new_strides = tuple(map(_substitute, self.strides)) - new_offset = _substitute(self.offset) - new_mask = tuple((_substitute(x[0]), _substitute(x[1])) for x in self.mask) if self.mask is not None else None - return View.create(new_shape, new_strides, new_offset, new_mask) - - @functools.cache # pylint: disable=method-cache-max-size-none - def __add__(self, vm1:View) -> View|None: - vm2 = self - if vm2.contiguous or vm1.size() == 0: return vm1 - if vm1.contiguous and vm1.shape == vm2.shape: return vm2 - if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret - if vm1.mask: - if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None - return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape))) - if not all_int(vm1.shape): - # if all strides are 0 and vm2 is unmasked, return vm1 - if all(x == 0 for x in vm2.strides+vm1.strides) and vm2.mask is None: return vm1 - return None - - # Project vm1's offset and strides on to vm2. - origin = [ssimplify(o) for o in unravel(vm2.shape, vm1.offset)] - terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape] - strides: list[sint] = [0] * len(vm1.shape) - for d1, st in enumerate(vm1.strides): - if st == 0: continue - for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))): - if not resolve((s1 := s1 - o)!=0): continue # if s1 can possibly be 0 - terms[d2].append((d1, s1)) - strides[d1] += ssimplify(s1 * vm2.strides[d2]) - return None - - def __unsafe_resize(self, arg: tuple[tuple[sint, sint], ...], mask=None) -> View: - offset = sum([s * x[0] for s, x in zip(self.strides,arg)]) - if self.mask: - # move the old mask - nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)]) - # merge the masks if we have two - mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask - return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask) - - @functools.cache # pylint: disable=method-cache-max-size-none - def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View: - assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}" - # NOTE: not checking for symbolic arg - for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}" - if any(resolve(b!=0) or resolve(e!=0) for b, e in arg): - zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)]) - mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)]) - return self.__unsafe_resize(zvarg, mask=mask) - return self - - @functools.cache # pylint: disable=method-cache-max-size-none - def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View: - assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}" - # NOTE: not checking for symbolic arg - for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}" - return self.__unsafe_resize(arg) - - @functools.cache # pylint: disable=method-cache-max-size-none - def expand(self, new_shape: tuple[sint, ...]) -> View: - if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}") - # NOTE: does not check multiple of symbolic shape - assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}" - if 0 in self.shape: return View.create(new_shape) - # TODO: resolve may not be needed, but it's hard because vars need to be canonicalized - mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns) and resolve(s == 1, False) else m) \ - for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None - return View.create(new_shape, self.strides, self.offset, mask) - - @functools.cache # pylint: disable=method-cache-max-size-none - def permute(self, axis: tuple[int, ...]) -> View: - assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}" - return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset, - tuple(self.mask[a] for a in axis) if self.mask is not None else None) - - @functools.cache # pylint: disable=method-cache-max-size-none - def flip(self, arg: tuple[bool, ...]) -> View: - offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f) - mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None - return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask) - - @functools.cache # pylint: disable=method-cache-max-size-none - def reshape(self, new_shape: tuple[sint, ...]) -> View|None: - if self.shape == new_shape: return self - - if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}") - # check for the same size - if resolve(prod(self.shape) != prod(new_shape), True): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") - - if 0 in self.shape: return View.create(new_shape) - if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None - - # after the asserts, it's okay to check contiguous - if self.contiguous: return View.create(new_shape) - - r_strides, r_new_shape = [], reversed(new_shape) - for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)): - acc = 1 - # TODO: third resolve shouldn't be needed - while resolve(acc <= merged_size) and resolve(acc != merged_size) and resolve((new_dim := next(r_new_shape, 0)) > 0): - r_strides.append(new_stride * acc) - acc = acc * new_dim - if not resolve(acc < real_size): new_stride = 0 - if resolve(acc != merged_size): return None - new_strides = (0,) * (len(new_shape) - len(r_strides)) + tuple(r_strides[::-1]) - - if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None: - extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \ - (sum(m[0] * s for m,s in zip(new_mask, new_strides))) - return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask) - - return None