new upcast works (#1066)

* new upcast works

* float4 try

* fix unaligned float4

* disallow unaligned access

* upcast dim

* maybe good now

* fix gpu half

* vstore_half4

* fix deep image bugs

* improve symbolic to fix issues

* fix symbolic

* cl test

* this maybe

* gcd of 1 is 1

* real fix for old python

* improve fuzzer
George Hotz authored on 2023-06-27 19:34:53 -07:00, committed by GitHub
parent 4d703be6d7, commit d16c16ec28
9 changed files with 143 additions and 101 deletions


@@ -180,12 +180,12 @@ jobs:
run: |
GPU=1 IMAGE=1 python3 test/test_ops.py
FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
- name: Test openpilot model
- name: Test openpilot model compile and size
run: |
ALLOWED_KERNEL_COUNT=199 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
- name: Test openpilot model correctness (float32)
run: DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
testmetal:
name: Metal Tests


@@ -21,6 +21,14 @@ def add_num(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr + rng, rng
def lt(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr < rng, rng
def ge(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr >= rng, rng
if __name__ == "__main__":
ops = [add_v, div, mul, add_num]
while 1:
@@ -29,6 +37,9 @@ if __name__ == "__main__":
u3 = Variable("v3", 0, 4)
v = [u1,u2,u3]
tape = [random.choice(ops) for _ in range(20)]
# 10% of the time, add a less than or greater than
if random.random() < 0.05: tape.append(lt)
elif random.random() < 0.05: tape.append(ge)
expr = Variable.num(0)
rngs = []
for t in tape:

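The new lt/ge ops push fuzzed expressions through the gcd factoring added to SumNode.__lt__ in symbolic.py below. A minimal sketch of the invariant the fuzzer guards, using tinygrad's symbolic module as of this commit; the variable name and bounds here are illustrative, not from the fuzzer:

from tinygrad.shape.symbolic import Variable

v1 = Variable("v1", 0, 6)
expr = (v1 * 4 + 1) < 8        # hits the new SumNode.__lt__ factoring path
print(expr.render())           # "((v1*4)<8)": the +1 term is dropped
# the dropped term can never flip the comparison over v1's whole range
for val in range(v1.min, v1.max + 1):
    assert (val * 4 + 1 < 8) == (val * 4 < 8)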

@@ -147,7 +147,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
helper_test_op([(), ()], torch.minimum, Tensor.minimum)
def test_add(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)
def test_add_number(self):
helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
def test_add3(self):
@@ -167,6 +167,7 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: -x)
def test_mul(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
def test_mul_number(self):
helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
def test_mul_const(self):
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)

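A guess at the intent of the (45,65) to (45,68) change, flagged as my reading rather than anything the commit states: 68 is a multiple of 4 while 65 is not, so test_add's contiguous last axis can now take the float4 upcast path this PR reworks.

# 65 % 4 == 1 forces the scalar-float fallback; 68 % 4 == 0 allows float4
assert 65 % 4 != 0 and 68 % 4 == 0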

@@ -87,6 +87,15 @@ class TestRealDoesntSimplify(unittest.TestCase):
View((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
assert self.st.real_strides() == (None, 18, -3, -1)
class TestRealStrides(unittest.TestCase):
def test_1(self):
self.st = ShapeTracker((16, 32, 4), views=[
View((2048,), (1,), 0, ((0, 512),)),
View((16, 32, 4), (128, 4, 1), 0, None)])
st = self.st.real_strides()
print(self.st, st)
assert st == (None, 4, 1)
class TestRealSimplifies(unittest.TestCase):
def tearDown(self):
st = self.st.real_strides()
@@ -105,7 +114,6 @@ class TestRealSimplifies(unittest.TestCase):
View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
class TestSimplifyingShapeTracker(unittest.TestCase):
def setUp(self):
self.st = CheckingShapeTracker((1, 10))


@@ -26,7 +26,7 @@ class TestSymbolic(unittest.TestCase):
def test_ge_divides(self):
expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
self.helper_test_variable(expr, 0, 1, "(((idx*4)+FLOAT4_INDEX)<512)")
self.helper_test_variable(expr, 0, 1, "((idx*4)<512)")
self.helper_test_variable(expr//4, 0, 1, "(idx<128)")
def test_ge_divides_and(self):
@@ -37,6 +37,10 @@ class TestSymbolic(unittest.TestCase):
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and (idx1<128))")
def test_lt_factors(self):
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
def test_div_becomes_num(self):
assert isinstance(Variable("a", 2, 3)//2, NumNode)

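Why test_ge_divides may drop FLOAT4_INDEX entirely: every multiplied term shares the factor 4, 512 % 4 == 0, and the leftover terms sum to at most 3 < 4, so they can never move the sum across the 512 boundary. A quick brute-force check in plain Python (not from the repo):

for idx in range(512):
    for f in range(4):  # FLOAT4_INDEX in [0, 3]
        assert (idx * 4 + f < 512) == (idx * 4 < 512)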

@@ -125,19 +125,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
elif isinstance(bufs[args.i].dtype, ImageDType):
assert newvar.dtype == dtypes._float4, "image must be float4"
assert newvar.dtype == dtypes._float4, f"image must be float4 {newvar}"
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
else:
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
if newvar.dtype == dtypes._float4:
val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
val = f"vload_half4(0, {bufnames[args.i]}+{(args.idx).render(render_cl)})"
else:
val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
else:
if newvar.dtype == dtypes._float4:
val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
val = f"({newvar.dtype.name})(*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})))"
else:
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
@@ -159,8 +159,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
if isinstance(bufs[args[0]].dtype, ImageDType):
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
elif lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
kk(f"vstore_half4({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
else:
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
kk(f"*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})) = ({bufs[args.i].dtype.name}4){vin[0].render()};")
elif uop == UOps.DEFINE_LOCAL:
kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
else:

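The renderer changes swap ((float4*)buf)[idx/4] style indexing for *((float4*)(buf+idx)), and vload_half4(idx//4, buf) for vload_half4(0, buf+idx); since vloadn offsets are in units of n elements, the two forms only agree when idx is a multiple of 4, which the linearizer now guarantees by rejecting any index that does not survive a (idx//4)*4 round trip. A sketch of that alignment test with illustrative symbolic indexes:

from tinygrad.shape.symbolic import Variable

gidx = Variable("gidx", 0, 127)
aligned, unaligned = gidx * 4, gidx * 4 + 2
# the linearizer's check: a float4 index must survive (idx//4)*4 unchanged
assert aligned.render() == ((aligned // 4) * 4).render()
assert unaligned.render() != ((unaligned // 4) * 4).render()  # falls back to float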

@@ -1,4 +1,4 @@
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict, Iterator, Union, Sequence
import itertools, math
from collections import defaultdict
from enum import Enum, auto
@@ -9,7 +9,8 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.shape.symbolic import Variable
from tinygrad.shape.symbolic import Variable, NumNode
VariableOrNum = Union[Variable, NumNode]
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \
SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
@@ -74,6 +75,10 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
return zip(new_idxs, new_values)
return zip([[i] for i in range(len(values[0]))], zip(*values))
def expand_idxs(idxs:Sequence[VariableOrNum]) -> Iterator[Tuple[VariableOrNum, ...]]:
for x in itertools.product(*[[idx] if not isinstance(idx, Variable) or idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)] for idx in idxs[::-1]]):
yield x[::-1]
class MemOp(NamedTuple):
i: int
idx: Variable
@@ -157,67 +162,55 @@ class Linearizer:
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
def _group_float4(self, i, store_offset):
store_offset_float4 = {}
float4_axis = (self.upcasted-1) - self.float4_axis(i)[0]
for uidxs, var in store_offset.items():
if uidxs[float4_axis]%4 == 0:
store_offset_float4[uidxs] = [var]
else:
uidxs2 = list(uidxs)
uidxs2[float4_axis] -= uidxs2[float4_axis]%4
store_offset_float4[tuple(uidxs2)].append(var)
return store_offset_float4
def get_upcast_dim(self, i, amt=4):
should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] == amt]
def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(dtypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
# float4 grouping (optional)
should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
if should_upcast:
load_offset_new = {}
for k,out_tokens in self._group_float4(i, load_offset).items():
idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
if any([idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))]) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
# idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
for x in out_tokens: load_offset_new[x[1]] = x
else:
load_offset_new[k] = (dtypes._float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
load_offset = load_offset_new
# do loads
def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
upcast_dim = self.get_upcast_dim(i)
cache: Dict[str, Token] = {}
loaded = {}
for uidxs, (localtype, uidx_list, idx, valid) in load_offset.items():
ret = []
for _idx in expand_idxs(idxs):
if len(upcast_dim) == 1:
idx, valid = self.sts[i].expr_idxs((_idx[:upcast_dim[0]] + (Variable.num(0),) + _idx[upcast_dim[0]+1:]))
localtype = dtypes._float4
# disallow unaligned access, fall back to float
if idx.render() != ((idx//4)*4).render():
idx, valid = self.sts[i].expr_idxs(_idx)
localtype = dtypes.float
else:
idx, valid = self.sts[i].expr_idxs(_idx)
localtype = dtypes.float
key = f"{localtype}{idx.render()}{valid.render()}"
if key not in cache:
cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
if localtype == dtypes._float4:
for j,uidx in enumerate(uidx_list):
loaded[uidx] = Token(cache[key].name, dtypes._float4, j)
else:
loaded[uidxs] = cache[key]
return [loaded[uidxs] for uidxs in self.shape_offsets(i)]
cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else \
self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
ret.append(Token(cache[key].name, cache[key].dtype, _idx[upcast_dim[0]].b) if localtype == dtypes._float4 else cache[key])
return ret
def global_store(self, i, idxs:List[Variable], store:List[Token], ssa) -> None:
store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None:
store_offset = dict(zip(expand_idxs(idxs), store))
# float4 grouping (optional)
# TODO: why does this not work for float16?
should_upcast = self.supports_float4 and (self.bufs[i].dtype == dtypes.float32 or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
if should_upcast:
# float4 grouping
upcast_dim = self.get_upcast_dim(i)
if len(upcast_dim) == 1:
grouped_store_offset = defaultdict(list)
for k in store_offset:
_idx = k[:upcast_dim[0]] + (Variable.num(0),) + k[upcast_dim[0]+1:]
grouped_store_offset[_idx].append(store_offset[k])
store_offset_new = {}
for k,out_tokens in self._group_float4(i, store_offset).items():
for k,out_tokens in grouped_store_offset.items():
idx, valid = self.sts[i].expr_idxs(k)
assert idx.render() == ((idx//4)*4).render(), "float4 stores are always aligned"
assert valid.min == 1, "stores are always valid"
if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4)
else:
store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4), out_tokens)
store_offset = store_offset_new
# do stores
for uidxs, var in store_offset.items():
self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]])))
for idx, var in store_offset.items():
self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idx)))
def linearize(self):
# uops
@@ -254,7 +247,10 @@ class Linearizer:
# local loop
local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))]
self.uop(UOps.LOOP, None, [], (local_idxs, "local"))
gl_idxs = global_idxs + local_idxs
# upcast indexes
full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
# reduce op
fake_reduce_idxs = []
@@ -264,13 +260,13 @@ class Linearizer:
fake_reduce_idxs = [x*0 for x in reduce_idxs]
# define accumulator
acc = self.global_load(0, gl_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# reduce loop
self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
# load earlybufs
loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
# run early AST (with reduce)
self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
@@ -281,7 +277,7 @@ class Linearizer:
# end the local loop, do the local reduce
if self.group_for_reduce:
fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, acc, ssa) # store accumulators
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, ssa) # store accumulators
self.uop(UOps.BARRIER, None, [], ())
self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
@@ -294,18 +290,20 @@ class Linearizer:
self.upcast()
self.group_for_reduce.pop()
local_idxs = local_idxs[:-1]
# regenerate upcast_idxs
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# late reduce loop
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
# load localbufs
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore
@@ -314,13 +312,13 @@ class Linearizer:
self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
# run late AST
val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs, val, ssa)
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val, ssa)
if not self.group_for_reduce:
# end the global+local loop
@@ -364,6 +362,9 @@ class Linearizer:
@property
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
@property
def output_shape(self) -> Tuple[int, ...]: return self.sts[0].shape
@property
def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape
@@ -443,7 +444,7 @@ class Linearizer:
def simplify_merge_adjacent(self):
if self.shape_len == 0: return
shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO
@@ -453,7 +454,7 @@ class Linearizer:
can_merge = []
for j in range(len(shapes)):
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0))
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
# more can merge than this
mergeable = all(can_merge) and i != self.first_reduce
for j in range(len(shapes)):
@@ -467,7 +468,7 @@ class Linearizer:
def required_optimizations(self, early_only=False):
for buf_index,buf in enumerate(self.bufs):
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:

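The rewritten global_load/global_store stop carrying explicit uidx tuples: upcast dimensions become anonymous Variables (expr=None) and expand_idxs enumerates their concrete values while named indexes stay symbolic. A small illustration, assuming expand_idxs is importable from tinygrad.codegen.linearizer as defined above:

from tinygrad.codegen.linearizer import expand_idxs
from tinygrad.shape.symbolic import Variable

gidx = Variable("gidx", 0, 7)   # named: left symbolic
up = Variable(None, 0, 3)       # unnamed upcast index: expanded
for t in expand_idxs([gidx, up]):
    print(tuple(x.render() for x in t))
# prints ('gidx', '0') through ('gidx', '3'), one tuple per line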

@@ -4,7 +4,7 @@ from enum import Enum, auto
import functools
from typing import Dict, Tuple, Union, List, Optional, Callable, cast
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
# these ops live here
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
@@ -65,7 +65,7 @@ class View:
# generate an expression if you have a single idx variable
def expr_node(self, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(self.shape))
ret = [Variable.num(self.offset)]
ret: List[Node] = [Variable.num(self.offset)]
acc = 1
for d,s in reversed(self.shape_strides):
ret.append(((idx//acc)%d)*s)
@@ -157,26 +157,23 @@ class ShapeTracker:
assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
return real_offset.b
def real_strides(self) -> Tuple[Optional[int], ...]:
if len(self.views) == 1: return self.views[-1].strides
ret: List[Optional[int]] = []
acc, real_offset = 1, self.real_offset()
for s in reversed(self.shape):
if s == 1: # fast path, all shape 1 have stride 0
ret.append(0)
continue
var = Variable('idx', 0, s-1)
this_dim, _ = self.expr_node(var*acc)
this_dim -= real_offset
acc *= s
# TODO: sometimes a mod here is okay if you are say, reading a float4, since you only care %4
# if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
if this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: ret.append(this_dim.b)
elif this_dim.__class__ is NumNode and this_dim.b == 0: ret.append(0)
elif this_dim.__class__ is Variable: ret.append(1)
else: ret.append(None)
return tuple(ret[::-1])
def unit_stride_axes(self) -> List[int]: return [i for i,st in enumerate(self.real_strides()) if st == 1]
# NOTE: if a stride is not always valid, it will be None
def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[int]] = [None for _ in self.views[-1].shape]
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable):
ret[idxs.index(this_dim)] = 1
render_idx, render_valid = idx.render(), valid.render()
for i in range(len(self.shape)):
if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
elif f'idx{i}' not in render_idx: ret[i] = 0
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def _expr_idx(self, idx, valid):
for v in reversed(self.views[0:-1]):

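The new real_strides walks the rendered index and valid expressions instead of probing one dimension at a time; a stride becomes None whenever its index variable appears in the valid. Re-running the setup from TestRealStrides.test_1 above; the ignore_valid result is my expectation, not something this diff asserts:

from tinygrad.shape.shapetracker import ShapeTracker, View

st = ShapeTracker((16, 32, 4), views=[
    View((2048,), (1,), 0, ((0, 512),)),       # flat view, only idx < 512 valid
    View((16, 32, 4), (128, 4, 1), 0, None)])
assert st.real_strides() == (None, 4, 1)       # dim 0 leaks into the valid
print(st.real_strides(ignore_valid=True))      # expect (128, 4, 1)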

@@ -2,7 +2,8 @@ from __future__ import annotations
from abc import abstractmethod
import functools
from math import gcd
from typing import List, Dict, Callable, Tuple, Type, Union
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -25,8 +26,22 @@ class Node:
def __neg__(self): return self*-1
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __sub__(self, b:Union[Node, int]): return self+-b
def __ge__(self, b:int): return create_node(LtNode(-self, -b+1))
def __lt__(self, b:int): return create_node(LtNode(self, b))
def __ge__(self, b:int): return (-self) < (-b+1)
def __lt__(self, b:int):
lhs = self
if isinstance(lhs, SumNode):
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if len(muls):
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = muls[0].b
for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
if b%mul_gcd == 0:
all_others = Variable.sum(others)
#print(mul_gcd, muls, all_others)
if all_others.min >= 0 and all_others.max < mul_gcd:
# TODO: should we divide both by mul_gcd here?
lhs = Variable.sum(muls)
return create_node(LtNode(lhs, b))
def __mul__(self, b:int):
if b == 0: return NumNode(0)
elif b == 1: return self
@@ -54,7 +69,7 @@ class Node:
return create_node(ModNode(self, b))
@staticmethod
def num(num:int) -> Node: return NumNode(num)
def num(num:int) -> NumNode: return NumNode(num)
@staticmethod
def factorize(nodes:List[Node]):
@@ -100,12 +115,12 @@ class Node:
# 4 basic node types
class Variable(Node):
def __new__(cls, expr:str, nmin:int, nmax:int):
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
assert nmin >= 0 and nmin <= nmax
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __init__(self, expr:str, nmin:int, nmax:int):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
class NumNode(Node):
@@ -128,6 +143,7 @@ class LtNode(OpNode):
def __mul__(self, b: int): return (self.a*b) < (self.b*b)
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
class MulNode(OpNode):
def __mul__(self, b: int): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
@@ -139,11 +155,13 @@ class MulNode(OpNode):
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
class DivNode(OpNode):
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0
return self.a.min//self.b, self.a.max//self.b
class ModNode(OpNode):
def __floordiv__(self, b: int, factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
@@ -184,7 +202,7 @@ class SumNode(RedNode):
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
new_nodes = []
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
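On the "gcd in python 3.8 takes exactly 2 args" NOTE: math.gcd only grew varargs in Python 3.9, so the pairwise fold is the portable spelling. functools.reduce would be an equivalent (hypothetical) refactor; the coefficients here are illustrative:

import functools
from math import gcd

coeffs = [4, 8, 12]                      # e.g. the MulNode coefficients in a sum
mul_gcd = functools.reduce(gcd, coeffs)  # works on 3.8; gcd(*coeffs) needs 3.9
assert mul_gcd == 4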