diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 0b4d9a4e2b..38d452c636 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -180,12 +180,12 @@ jobs:
       run: |
         GPU=1 IMAGE=1 python3 test/test_ops.py
         FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
-    - name: Test openpilot model
+    - name: Test openpilot model compile and size
       run: |
-        ALLOWED_KERNEL_COUNT=199 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
+        ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
         python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
-        DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
-        VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
+    - name: Test openpilot model correctness (float32)
+      run: DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py

   testmetal:
     name: Metal Tests
diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py
index e5119acac6..7b565c467f 100644
--- a/test/external/fuzz_symbolic.py
+++ b/test/external/fuzz_symbolic.py
@@ -21,6 +21,14 @@ def add_num(expr, rng=None):
   if rng is None: rng = random.randint(-4,4)
   return expr + rng, rng

+def lt(expr, rng=None):
+  if rng is None: rng = random.randint(-4,4)
+  return expr < rng, rng
+
+def ge(expr, rng=None):
+  if rng is None: rng = random.randint(-4,4)
+  return expr >= rng, rng
+
 if __name__ == "__main__":
   ops = [add_v, div, mul, add_num]
   while 1:
@@ -29,6 +37,9 @@ if __name__ == "__main__":
     u3 = Variable("v3", 0, 4)
     v = [u1,u2,u3]
     tape = [random.choice(ops) for _ in range(20)]
+    # 10% of the time, add a less than or greater than
+    if random.random() < 0.05: tape.append(lt)
+    elif random.random() < 0.05: tape.append(ge)
     expr = Variable.num(0)
     rngs = []
     for t in tape:
diff --git a/test/test_ops.py b/test/test_ops.py
index 8ecda6248a..3f5c433a99 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -147,7 +147,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
     helper_test_op([(), ()], torch.minimum, Tensor.minimum)
   def test_add(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
+    helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)
   def test_add_number(self):
     helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
   def test_add3(self):
@@ -167,6 +167,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: -x)
   def test_mul(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
+  def test_mul_number(self):
     helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
   def test_mul_const(self):
     helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index 8bc0991566..b8c172783a 100644
--- a/test/unit/test_shapetracker.py
+++ b/test/unit/test_shapetracker.py
@@ -87,6 +87,15 @@ class TestRealDoesntSimplify(unittest.TestCase):
       View((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
     assert self.st.real_strides() == (None, 18, -3, -1)

+class TestRealStrides(unittest.TestCase):
+  def test_1(self):
+    self.st = ShapeTracker((16, 32, 4), views=[
+      View((2048,), (1,), 0, ((0, 512),)),
+      View((16, 32, 4), (128, 4, 1), 0, None)])
+    st = self.st.real_strides()
+    print(self.st, st)
+    assert st == (None, 4, 1)
+
 class TestRealSimplifies(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
@@ -105,7 +114,6 @@ class TestRealSimplifies(unittest.TestCase):
       View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
       View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])

-
 class TestSimplifyingShapeTracker(unittest.TestCase):
   def setUp(self):
     self.st = CheckingShapeTracker((1, 10))
diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py
index eed3fc2c0c..bd60f130e3 100644
--- a/test/unit/test_symbolic.py
+++ b/test/unit/test_symbolic.py
@@ -26,7 +26,7 @@ class TestSymbolic(unittest.TestCase):

   def test_ge_divides(self):
     expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
-    self.helper_test_variable(expr, 0, 1, "(((idx*4)+FLOAT4_INDEX)<512)")
+    self.helper_test_variable(expr, 0, 1, "((idx*4)<512)")
     self.helper_test_variable(expr//4, 0, 1, "(idx<128)")

   def test_ge_divides_and(self):
@@ -37,6 +37,10 @@ class TestSymbolic(unittest.TestCase):
                           (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
     self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and (idx1<128))")

+  def test_lt_factors(self):
+    expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
+    self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
+
   def test_div_becomes_num(self):
     assert isinstance(Variable("a", 2, 3)//2, NumNode)

diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index 963c260239..d2b5dc5b79 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -125,19 +125,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
         elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
         else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
       elif isinstance(bufs[args.i].dtype, ImageDType):
-        assert newvar.dtype == dtypes._float4, "image must be float4"
+        assert newvar.dtype == dtypes._float4, f"image must be float4 {newvar}"
         prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
         idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
         val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
       else:
         if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
           if newvar.dtype == dtypes._float4:
-            val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
+            val = f"vload_half4(0, {bufnames[args.i]}+{(args.idx).render(render_cl)})"
           else:
             val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
         else:
           if newvar.dtype == dtypes._float4:
-            val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
+            val = f"({newvar.dtype.name})(*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})))"
           else:
             val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
       # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
@@ -159,8 +159,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
       if isinstance(bufs[args[0]].dtype, ImageDType):
         idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
         kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
+      elif lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
+        kk(f"vstore_half4({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
       else:
-        kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
+        kk(f"*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})) = ({bufs[args.i].dtype.name}4){vin[0].render()};")
     elif uop == UOps.DEFINE_LOCAL:
       kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
     else:
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 8a97b9f18c..4145e88589 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict, Iterator, Union, Sequence
 import itertools, math
 from collections import defaultdict
 from enum import Enum, auto
@@ -9,7 +9,8 @@ from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
-from tinygrad.shape.symbolic import Variable
+from tinygrad.shape.symbolic import Variable, NumNode
+VariableOrNum = Union[Variable, NumNode]

 class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \
                   SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
@@ -74,6 +75,10 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
     return zip(new_idxs, new_values)
   return zip([[i] for i in range(len(values[0]))], zip(*values))

+def expand_idxs(idxs:Sequence[VariableOrNum]) -> Iterator[Tuple[VariableOrNum, ...]]:
+  for x in itertools.product(*[[idx] if not isinstance(idx, Variable) or idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)] for idx in idxs[::-1]]):
+    yield x[::-1]
+
 class MemOp(NamedTuple):
   i: int
   idx: Variable
@@ -157,67 +162,55 @@ class Linearizer:
     acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
     return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]

-  def _group_float4(self, i, store_offset):
-    store_offset_float4 = {}
-    float4_axis = (self.upcasted-1) - self.float4_axis(i)[0]
-    for uidxs, var in store_offset.items():
-      if uidxs[float4_axis]%4 == 0:
-        store_offset_float4[uidxs] = [var]
-      else:
-        uidxs2 = list(uidxs)
-        uidxs2[float4_axis] -= uidxs2[float4_axis]%4
-        store_offset_float4[tuple(uidxs2)].append(var)
-    return store_offset_float4
+  def get_upcast_dim(self, i, amt=4):
+    should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
+    return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] == amt]

-  def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
-    load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(dtypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
-
-    # float4 grouping (optional)
-    should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
-    if should_upcast:
-      load_offset_new = {}
-      for k,out_tokens in self._group_float4(i, load_offset).items():
-        idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
-        valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
-        if any([idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))]) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
-          # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
-          for x in out_tokens: load_offset_new[x[1]] = x
-        else:
-          load_offset_new[k] = (dtypes._float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
-      load_offset = load_offset_new
-
-    # do loads
+  def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
+    upcast_dim = self.get_upcast_dim(i)
     cache: Dict[str, Token] = {}
-    loaded = {}
-    for uidxs, (localtype, uidx_list, idx, valid) in load_offset.items():
+    ret = []
+    for _idx in expand_idxs(idxs):
+      if len(upcast_dim) == 1:
+        idx, valid = self.sts[i].expr_idxs((_idx[:upcast_dim[0]] + (Variable.num(0),) + _idx[upcast_dim[0]+1:]))
+        localtype = dtypes._float4
+        # disallow unaligned access, fall back to float
+        if idx.render() != ((idx//4)*4).render():
+          idx, valid = self.sts[i].expr_idxs(_idx)
+          localtype = dtypes.float
+      else:
+        idx, valid = self.sts[i].expr_idxs(_idx)
+        localtype = dtypes.float
       key = f"{localtype}{idx.render()}{valid.render()}"
       if key not in cache:
-        cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
-      if localtype == dtypes._float4:
-        for j,uidx in enumerate(uidx_list):
-          loaded[uidx] = Token(cache[key].name, dtypes._float4, j)
-      else:
-        loaded[uidxs] = cache[key]
-    return [loaded[uidxs] for uidxs in self.shape_offsets(i)]
+        cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else \
+                     self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
+      ret.append(Token(cache[key].name, cache[key].dtype, _idx[upcast_dim[0]].b) if localtype == dtypes._float4 else cache[key])
+    return ret

-  def global_store(self, i, idxs:List[Variable], store:List[Token], ssa) -> None:
-    store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
+  def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None:
+    store_offset = dict(zip(expand_idxs(idxs), store))

-    # float4 grouping (optional)
-    # TODO: why does this not work for float16?
-    should_upcast = self.supports_float4 and (self.bufs[i].dtype == dtypes.float32 or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
-    if should_upcast:
+    # float4 grouping
+    upcast_dim = self.get_upcast_dim(i)
+    if len(upcast_dim) == 1:
+      grouped_store_offset = defaultdict(list)
+      for k in store_offset:
+        _idx = k[:upcast_dim[0]] + (Variable.num(0),) + k[upcast_dim[0]+1:]
+        grouped_store_offset[_idx].append(store_offset[k])
       store_offset_new = {}
-      for k,out_tokens in self._group_float4(i, store_offset).items():
+      for k,out_tokens in grouped_store_offset.items():
+        idx, valid = self.sts[i].expr_idxs(k)
+        assert idx.render() == ((idx//4)*4).render(), "float4 stores are always aligned"
+        assert valid.min == 1, "stores are always valid"
         if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
           store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4)
         else:
           store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4), out_tokens)
       store_offset = store_offset_new

-    # do stores
-    for uidxs, var in store_offset.items():
-      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]])))
+    for idx, var in store_offset.items():
+      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idx)))

   def linearize(self):
     # uops
@@ -254,7 +247,10 @@ class Linearizer:
     # local loop
     local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))]
     self.uop(UOps.LOOP, None, [], (local_idxs, "local"))
-    gl_idxs = global_idxs + local_idxs
+
+    # upcast indexes
+    full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
+    upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]

     # reduce op
     fake_reduce_idxs = []
@@ -264,13 +260,13 @@ class Linearizer:
       fake_reduce_idxs = [x*0 for x in reduce_idxs]

       # define accumulator
-      acc = self.global_load(0, gl_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
+      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])

       # reduce loop
       self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))

       # load earlybufs
-      loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
+      loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})

       # run early AST (with reduce)
       self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
@@ -281,7 +277,7 @@ class Linearizer:
     # end the local loop, do the local reduce
     if self.group_for_reduce:
       fake_global_idxs = [x*0 for x in global_idxs]
-      self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, acc, ssa)  # store accumulators
+      self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, ssa)  # store accumulators
       self.uop(UOps.BARRIER, None, [], ())
       self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))

@@ -294,18 +290,20 @@ class Linearizer:
         self.upcast()
         self.group_for_reduce.pop()
         local_idxs = local_idxs[:-1]
+        # regenerate upcast_idxs
+        upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]

       # NOTE: this structure is the same as the reduce op above
       # define late accumulator
-      acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
+      acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])

       # late reduce loop
       end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
       self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))

       # load localbufs
-      loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)
+      loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)

       # there's no AST here (and there's no shape for the reduce LazyOp)
       self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)  # type: ignore

@@ -314,13 +312,13 @@ class Linearizer:
       self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))

     # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})

     # run late AST
     val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)

     # store
-    self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs, val, ssa)
+    self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val, ssa)

     if not self.group_for_reduce:
       # end the global+local loop
@@ -364,6 +362,9 @@ class Linearizer:
   @property
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)

+  @property
+  def output_shape(self) -> Tuple[int, ...]: return self.sts[0].shape
+
   @property
   def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape

@@ -443,7 +444,7 @@ class Linearizer:

   def simplify_merge_adjacent(self):
     if self.shape_len == 0: return
-    shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]
+    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]

     # merge dimensions if we can, multi get_shape_strides
     # TODO: does this always preserve the reduce dimension, NO
@@ -453,7 +454,7 @@ class Linearizer:
       can_merge = []
      for j in range(len(shapes)):
         # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-        can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0))
+        can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
       # more can merge than this
       mergeable = all(can_merge) and i != self.first_reduce
       for j in range(len(shapes)):
@@ -467,7 +468,7 @@ class Linearizer:

   def required_optimizations(self, early_only=False):
     for buf_index,buf in enumerate(self.bufs):
-      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
+      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
       if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
         assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
         if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 50ee975624..b02c150e2a 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -4,7 +4,7 @@ from enum import Enum, auto
 import functools
 from typing import Dict, Tuple, Union, List, Optional, Callable, cast
 from tinygrad.helpers import prod, DEBUG
-from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
+from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode

 # these ops live here
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
@@ -65,7 +65,7 @@ class View:
   # generate an expression if you have a single idx variable
   def expr_node(self, idx=None) -> Node:
     if idx is None: idx = Variable('idx', 0, prod(self.shape))
-    ret = [Variable.num(self.offset)]
+    ret: List[Node] = [Variable.num(self.offset)]
     acc = 1
     for d,s in reversed(self.shape_strides):
       ret.append(((idx//acc)%d)*s)
@@ -157,26 +157,23 @@ class ShapeTracker:
     assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
     return real_offset.b

-  def real_strides(self) -> Tuple[Optional[int], ...]:
-    if len(self.views) == 1: return self.views[-1].strides
-    ret: List[Optional[int]] = []
-    acc, real_offset = 1, self.real_offset()
-    for s in reversed(self.shape):
-      if s == 1: # fast path, all shape 1 have stride 0
-        ret.append(0)
-        continue
-      var = Variable('idx', 0, s-1)
-      this_dim, _ = self.expr_node(var*acc)
-      this_dim -= real_offset
-      acc *= s
-      # TODO: sometimes a mod here is okay if you are say, reading a float4, since you only care %4
-      # if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
-      if this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: ret.append(this_dim.b)
-      elif this_dim.__class__ is NumNode and this_dim.b == 0: ret.append(0)
-      elif this_dim.__class__ is Variable: ret.append(1)
-      else: ret.append(None)
-    return tuple(ret[::-1])
-  def unit_stride_axes(self) -> List[int]: return [i for i,st in enumerate(self.real_strides()) if st == 1]
+  # NOTE: if a stride is not always valid, it will be None
+  def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
+    if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
+    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
+    idx, valid = self.expr_idxs(idxs)
+    ret: List[Optional[int]] = [None for _ in self.views[-1].shape]
+    for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
+      if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
+        ret[idxs.index(this_dim.a)] = this_dim.b
+      elif isinstance(this_dim, Variable):
+        ret[idxs.index(this_dim)] = 1
+    render_idx, render_valid = idx.render(), valid.render()
+    for i in range(len(self.shape)):
+      if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
+      elif f'idx{i}' not in render_idx: ret[i] = 0
+    return tuple(ret)
+  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

   def _expr_idx(self, idx, valid):
     for v in reversed(self.views[0:-1]):
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 969eb482c5..69b70278d9 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -2,7 +2,8 @@ from __future__ import annotations
 from abc import abstractmethod
 import functools
 from math import gcd
-from typing import List, Dict, Callable, Tuple, Type, Union
+from tinygrad.helpers import partition
+from typing import List, Dict, Callable, Tuple, Type, Union, Optional

 # NOTE: Python has different behavior for negative mod and floor div than c
 # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -25,8 +26,22 @@ class Node:
   def __neg__(self): return self*-1
   def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
   def __sub__(self, b:Union[Node, int]): return self+-b
-  def __ge__(self, b:int): return create_node(LtNode(-self, -b+1))
-  def __lt__(self, b:int): return create_node(LtNode(self, b))
+  def __ge__(self, b:int): return (-self) < (-b+1)
+  def __lt__(self, b:int):
+    lhs = self
+    if isinstance(lhs, SumNode):
+      muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
+      if len(muls):
+        # NOTE: gcd in python 3.8 takes exactly 2 args
+        mul_gcd = muls[0].b
+        for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
+        if b%mul_gcd == 0:
+          all_others = Variable.sum(others)
+          #print(mul_gcd, muls, all_others)
+          if all_others.min >= 0 and all_others.max < mul_gcd:
+            # TODO: should we divide both by mul_gcd here?
+            lhs = Variable.sum(muls)
+    return create_node(LtNode(lhs, b))
   def __mul__(self, b:int):
     if b == 0: return NumNode(0)
     elif b == 1: return self
@@ -54,7 +69,7 @@ class Node:
     return create_node(ModNode(self, b))

   @staticmethod
-  def num(num:int) -> Node: return NumNode(num)
+  def num(num:int) -> NumNode: return NumNode(num)

   @staticmethod
   def factorize(nodes:List[Node]):
@@ -100,12 +115,12 @@ class Node:

 # 4 basic node types
 class Variable(Node):
-  def __new__(cls, expr:str, nmin:int, nmax:int):
+  def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
     assert nmin >= 0 and nmin <= nmax
     if nmin == nmax: return NumNode(nmin)
     return super().__new__(cls)
-  def __init__(self, expr:str, nmin:int, nmax:int):
+  def __init__(self, expr:Optional[str], nmin:int, nmax:int):
     self.expr, self.min, self.max = expr, nmin, nmax

 class NumNode(Node):
@@ -128,6 +143,7 @@ class LtNode(OpNode):
   def __mul__(self, b: int): return (self.a*b) < (self.b*b)
   def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
   def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
+
 class MulNode(OpNode):
   def __mul__(self, b: int): return self.a*(self.b*b) # two muls in one mul
   def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
@@ -139,11 +155,13 @@ class MulNode(OpNode):
     return Node.__mod__(a, b)
   def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
+
 class DivNode(OpNode):
   def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
   def get_bounds(self) -> Tuple[int, int]:
     assert self.a.min >= 0
     return self.a.min//self.b, self.a.max//self.b
+
 class ModNode(OpNode):
   def __floordiv__(self, b: int, factoring_allowed=True):
     if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
@@ -184,7 +202,7 @@ class SumNode(RedNode):
     return Node.__floordiv__(self, b, factoring_allowed)

   def __mod__(self, b: int):
-    new_nodes = []
+    new_nodes: List[Node] = []
     for x in self.nodes:
       if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
       elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))