Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
new upcast works (#1066)
* new upcast works
* float4 try
* fix unaligned float4
* disallow unaligned access
* upcast dim
* maybe good now
* fix gpu half
* vstore_half4
* fix deep image bugs
* improve symbolic to fix issues
* fix symbolic
* cl test
* this maybe
* gcd of 1 is 1
* real fix for old python
* improve fuzzer
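At a high level, the upcast change replaces four scalar loads with one aligned float4 load, and "disallow unaligned access" means falling back to scalars whenever 4-alignment cannot be proven. A toy Python sketch of that policy (illustrative only, not tinygrad's API):

def load4(buf, base):
  # aligned base: one 4-wide vector load (what the float4 upcast emits)
  if base % 4 == 0: return buf[base:base+4]
  # unaligned base: fall back to four scalar loads ("disallow unaligned access")
  return [buf[base+j] for j in range(4)]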
.github/workflows/test.yml (8 changed lines)
@@ -180,12 +180,12 @@ jobs:
       run: |
         GPU=1 IMAGE=1 python3 test/test_ops.py
         FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
-    - name: Test openpilot model
+    - name: Test openpilot model compile and size
       run: |
-        ALLOWED_KERNEL_COUNT=199 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
-        DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
-        VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
+        ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
+        python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
+    - name: Test openpilot model correctness (float32)
+      run: DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
 
   testmetal:
     name: Metal Tests
test/external/fuzz_symbolic.py (11 changed lines)
@@ -21,6 +21,14 @@ def add_num(expr, rng=None):
   if rng is None: rng = random.randint(-4,4)
   return expr + rng, rng
 
+def lt(expr, rng=None):
+  if rng is None: rng = random.randint(-4,4)
+  return expr < rng, rng
+
+def ge(expr, rng=None):
+  if rng is None: rng = random.randint(-4,4)
+  return expr >= rng, rng
+
 if __name__ == "__main__":
   ops = [add_v, div, mul, add_num]
   while 1:
@@ -29,6 +37,9 @@ if __name__ == "__main__":
     u3 = Variable("v3", 0, 4)
     v = [u1,u2,u3]
    tape = [random.choice(ops) for _ in range(20)]
+    # 10% of the time, add a less than or greater than
+    if random.random() < 0.05: tape.append(lt)
+    elif random.random() < 0.05: tape.append(ge)
     expr = Variable.num(0)
     rngs = []
     for t in tape:
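For context, a minimal sketch of the tape-style check fuzz_symbolic.py performs, assuming tinygrad is importable; add_v and add_num mirror the script's helpers, and replaying the tape on plain ints verifies the symbolic bounds:

import random
from tinygrad.shape.symbolic import Variable

v = [Variable(f"v{i}", 0, 3) for i in range(3)]

def add_v(expr, rng=None):
  if rng is None: rng = random.randint(0, 2)
  return expr + v[rng], rng

def add_num(expr, rng=None):
  if rng is None: rng = random.randint(-4, 4)
  return expr + rng, rng

tape = [random.choice([add_v, add_num]) for _ in range(10)]
expr, rngs = Variable.num(0), []
for t in tape:
  expr, rng = t(expr)
  rngs.append(rng)

# replay the same tape on plain ints at a random point in range,
# then check the symbolic bounds contain the concrete result
vals = [random.randint(0, 3) for _ in range(3)]
num = 0
for t, r in zip(tape, rngs):
  num = num + (vals[r] if t is add_v else r)
assert expr.min <= num <= expr.max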
test/test_ops.py

@@ -147,7 +147,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
     helper_test_op([(), ()], torch.minimum, Tensor.minimum)
   def test_add(self):
-    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
+    helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)
   def test_add_number(self):
     helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
   def test_add3(self):
@@ -167,6 +167,7 @@ class TestOps(unittest.TestCase):
+    helper_test_op([()], lambda x: -x)
   def test_mul(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
   def test_mul_number(self):
     helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
   def test_mul_const(self):
     helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
test/unit/test_shapetracker.py

@@ -87,6 +87,15 @@ class TestRealDoesntSimplify(unittest.TestCase):
       View((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
     assert self.st.real_strides() == (None, 18, -3, -1)
 
+class TestRealStrides(unittest.TestCase):
+  def test_1(self):
+    self.st = ShapeTracker((16, 32, 4), views=[
+      View((2048,), (1,), 0, ((0, 512),)),
+      View((16, 32, 4), (128, 4, 1), 0, None)])
+    st = self.st.real_strides()
+    print(self.st, st)
+    assert st == (None, 4, 1)
+
 class TestRealSimplifies(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
@@ -105,7 +114,6 @@ class TestRealSimplifies(unittest.TestCase):
       View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
       View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
 
-
 class TestSimplifyingShapeTracker(unittest.TestCase):
   def setUp(self):
     self.st = CheckingShapeTracker((1, 10))
test/unit/test_symbolic.py

@@ -26,7 +26,7 @@ class TestSymbolic(unittest.TestCase):
 
   def test_ge_divides(self):
     expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
-    self.helper_test_variable(expr, 0, 1, "(((idx*4)+FLOAT4_INDEX)<512)")
+    self.helper_test_variable(expr, 0, 1, "((idx*4)<512)")
     self.helper_test_variable(expr//4, 0, 1, "(idx<128)")
 
   def test_ge_divides_and(self):
@@ -37,6 +37,10 @@ class TestSymbolic(unittest.TestCase):
       (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
     self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and (idx1<128))")
 
+  def test_lt_factors(self):
+    expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
+    self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
+
   def test_div_becomes_num(self):
     assert isinstance(Variable("a", 2, 3)//2, NumNode)
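The simplification these tests pin down can be brute-force checked in plain Python: a remainder bounded below the shared factor can never flip the comparison, while a wide remainder (the test_lt_factors case) can:

# (idx*4 + F) < 512 with 0 <= F <= 3 agrees with (idx*4) < 512 everywhere:
# 4 divides 512, and F < 4 can never push idx*4 across a multiple of 4
assert all(((idx*4 + f) < 512) == ((idx*4) < 512)
           for idx in range(512) for f in range(4))
# but with F up to 256 the remainder can cross the boundary, so the
# expression in test_lt_factors must stay unsimplified
assert any(((idx*4 + f) < 512) != ((idx*4) < 512)
           for idx in range(512) for f in range(257))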
tinygrad/codegen/cstyle.py

@@ -125,19 +125,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
       elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
       else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
     elif isinstance(bufs[args.i].dtype, ImageDType):
-      assert newvar.dtype == dtypes._float4, "image must be float4"
+      assert newvar.dtype == dtypes._float4, f"image must be float4 {newvar}"
       prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
       idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
       val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
     else:
       if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
         if newvar.dtype == dtypes._float4:
-          val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
+          val = f"vload_half4(0, {bufnames[args.i]}+{(args.idx).render(render_cl)})"
         else:
           val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
       else:
         if newvar.dtype == dtypes._float4:
-          val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
+          val = f"({newvar.dtype.name})(*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})))"
         else:
           val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
     # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
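The load rewrite changes how the float4 address is rendered: the old form indexed a float4* with idx//4, relying on the symbolic division folding cleanly, while the new form adds the element offset to the scalar base pointer and dereferences. A toy rendering of both forms (illustrative strings, not the real codegen):

base, elem_idx = "data0", "(gidx0*4)"   # illustrative buffer name and rendered index
old = f"(float4)(((__global float4*){base})[gidx0])"         # idx//4 folded to gidx0
new = f"(float4)(*((__global float4*)({base}+{elem_idx})))"  # pointer add, element offset
# both address element gidx0*4; the new form keeps the element offset visible,
# so alignment is the Linearizer's job rather than an implicit assumption here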
@@ -159,8 +159,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
     if isinstance(bufs[args[0]].dtype, ImageDType):
       idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
       kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
+    elif lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
+      kk(f"vstore_half4({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
     else:
-      kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
+      kk(f"*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})) = ({bufs[args.i].dtype.name}4){vin[0].render()};")
   elif uop == UOps.DEFINE_LOCAL:
     kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
   else:
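In OpenCL, vload_half4(offset, p) reads the four halfs starting at p + offset*4, so the old idx//4 offset silently rounded an unaligned index down. A pure-Python model of the addressing (the base argument stands in for the pointer add):

def vload_half4(offset, p, base=0):
  # models OpenCL: reads the 4 halfs at (p + base) + offset*4
  return p[base + offset*4 : base + offset*4 + 4]

buf = list(range(12))
idx = 6                                             # element index, not 4-aligned
assert vload_half4(idx//4, buf) == buf[4:8]         # old form: rounds down to element 4
assert vload_half4(0, buf, base=idx) == buf[6:10]   # new form: honest element offset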
tinygrad/codegen/linearizer.py

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict, Iterator, Union, Sequence
 import itertools, math
 from collections import defaultdict
 from enum import Enum, auto
@@ -9,7 +9,8 @@ from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
-from tinygrad.shape.symbolic import Variable
+from tinygrad.shape.symbolic import Variable, NumNode
+VariableOrNum = Union[Variable, NumNode]
 
 class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \
   SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
@@ -74,6 +75,10 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
     return zip(new_idxs, new_values)
   return zip([[i] for i in range(len(values[0]))], zip(*values))
 
+def expand_idxs(idxs:Sequence[VariableOrNum]) -> Iterator[Tuple[VariableOrNum, ...]]:
+  for x in itertools.product(*[[idx] if not isinstance(idx, Variable) or idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)] for idx in idxs[::-1]]):
+    yield x[::-1]
+
 class MemOp(NamedTuple):
   i: int
   idx: Variable
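expand_idxs enumerates every concrete value of the unnamed (expr=None) upcast placeholders while leaving named loop indices symbolic. A small sketch, assuming this commit's tinygrad is importable:

from tinygrad.shape.symbolic import Variable
from tinygrad.codegen.linearizer import expand_idxs   # as added above

gidx = Variable("gidx0", 0, 15)          # named: stays symbolic
u = Variable(None, 0, 3)                 # unnamed: an upcast placeholder
combos = list(expand_idxs([gidx, u]))
assert len(combos) == 4                  # u expanded to the constants 0..3
assert all(t[0] is gidx for t in combos) # gidx passes through untouched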
@@ -157,67 +162,55 @@ class Linearizer:
     acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
     return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
 
-  def _group_float4(self, i, store_offset):
-    store_offset_float4 = {}
-    float4_axis = (self.upcasted-1) - self.float4_axis(i)[0]
-    for uidxs, var in store_offset.items():
-      if uidxs[float4_axis]%4 == 0:
-        store_offset_float4[uidxs] = [var]
-      else:
-        uidxs2 = list(uidxs)
-        uidxs2[float4_axis] -= uidxs2[float4_axis]%4
-        store_offset_float4[tuple(uidxs2)].append(var)
-    return store_offset_float4
+  def get_upcast_dim(self, i, amt=4):
+    should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
+    return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] == amt]
 
-  def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
-    load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(dtypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
-
-    # float4 grouping (optional)
-    should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
-    if should_upcast:
-      load_offset_new = {}
-      for k,out_tokens in self._group_float4(i, load_offset).items():
-        idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
-        valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
-        if any([idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))]) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
-          # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
-          for x in out_tokens: load_offset_new[x[1]] = x
-        else:
-          load_offset_new[k] = (dtypes._float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
-      load_offset = load_offset_new
-
-    # do loads
+  def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
+    upcast_dim = self.get_upcast_dim(i)
     cache: Dict[str, Token] = {}
-    loaded = {}
-    for uidxs, (localtype, uidx_list, idx, valid) in load_offset.items():
+    ret = []
+    for _idx in expand_idxs(idxs):
+      if len(upcast_dim) == 1:
+        idx, valid = self.sts[i].expr_idxs((_idx[:upcast_dim[0]] + (Variable.num(0),) + _idx[upcast_dim[0]+1:]))
+        localtype = dtypes._float4
+        # disallow unaligned access, fall back to float
+        if idx.render() != ((idx//4)*4).render():
+          idx, valid = self.sts[i].expr_idxs(_idx)
+          localtype = dtypes.float
+      else:
+        idx, valid = self.sts[i].expr_idxs(_idx)
+        localtype = dtypes.float
       key = f"{localtype}{idx.render()}{valid.render()}"
       if key not in cache:
-        cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
-      if localtype == dtypes._float4:
-        for j,uidx in enumerate(uidx_list):
-          loaded[uidx] = Token(cache[key].name, dtypes._float4, j)
-      else:
-        loaded[uidxs] = cache[key]
-    return [loaded[uidxs] for uidxs in self.shape_offsets(i)]
+        cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else \
+                     self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
+      ret.append(Token(cache[key].name, cache[key].dtype, _idx[upcast_dim[0]].b) if localtype == dtypes._float4 else cache[key])
+    return ret
 
-  def global_store(self, i, idxs:List[Variable], store:List[Token], ssa) -> None:
-    store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
+  def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None:
+    store_offset = dict(zip(expand_idxs(idxs), store))
 
-    # float4 grouping (optional)
-    # TODO: why does this not work for float16?
-    should_upcast = self.supports_float4 and (self.bufs[i].dtype == dtypes.float32 or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
-    if should_upcast:
+    # float4 grouping
+    upcast_dim = self.get_upcast_dim(i)
+    if len(upcast_dim) == 1:
+      grouped_store_offset = defaultdict(list)
+      for k in store_offset:
+        _idx = k[:upcast_dim[0]] + (Variable.num(0),) + k[upcast_dim[0]+1:]
+        grouped_store_offset[_idx].append(store_offset[k])
       store_offset_new = {}
-      for k,out_tokens in self._group_float4(i, store_offset).items():
+      for k,out_tokens in grouped_store_offset.items():
+        idx, valid = self.sts[i].expr_idxs(k)
+        assert idx.render() == ((idx//4)*4).render(), "float4 stores are always aligned"
+        assert valid.min == 1, "stores are always valid"
         if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
          store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4)
         else:
          store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4), out_tokens)
       store_offset = store_offset_new
 
     # do stores
-    for uidxs, var in store_offset.items():
-      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]])))
+    for idx, var in store_offset.items():
+      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idx)))
 
   def linearize(self):
     # uops
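The alignment test in global_load, idx.render() != ((idx//4)*4).render(), asks whether floor-dividing by 4 and multiplying back reproduces the same expression; that holds exactly when every value of idx is a multiple of 4. A sketch, assuming tinygrad is importable:

from tinygrad.shape.symbolic import Variable

aligned = Variable("g", 0, 127)*4          # g*4 is always a multiple of 4
unaligned = Variable("g", 0, 127)*4 + 2    # g*4+2 never is
assert ((aligned//4)*4).render() == aligned.render()      # float4 load allowed
assert ((unaligned//4)*4).render() != unaligned.render()  # falls back to float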
@@ -254,7 +247,10 @@ class Linearizer:
     # local loop
     local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))]
     self.uop(UOps.LOOP, None, [], (local_idxs, "local"))
-    gl_idxs = global_idxs + local_idxs
+
+    # upcast indexes
+    full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]]
+    upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
 
     # reduce op
     fake_reduce_idxs = []
@@ -264,13 +260,13 @@ class Linearizer:
       fake_reduce_idxs = [x*0 for x in reduce_idxs]
 
       # define accumulator
-      acc = self.global_load(0, gl_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
+      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
       # reduce loop
       self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
 
       # load earlybufs
-      loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
+      loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})
 
       # run early AST (with reduce)
       self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)
@@ -281,7 +277,7 @@ class Linearizer:
       # end the local loop, do the local reduce
       if self.group_for_reduce:
         fake_global_idxs = [x*0 for x in global_idxs]
-        self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, acc, ssa) # store accumulators
+        self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc, ssa) # store accumulators
         self.uop(UOps.BARRIER, None, [], ())
         self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
 
@@ -294,18 +290,20 @@ class Linearizer:
           self.upcast()
           self.group_for_reduce.pop()
           local_idxs = local_idxs[:-1]
+          # regenerate upcast_idxs
+          upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
 
         # NOTE: this structure is the same as the reduce op above
 
         # define late accumulator
-        acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
+        acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
         # late reduce loop
         end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
         self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
 
         # load localbufs
-        loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)
+        loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
 
         # there's no AST here (and there's no shape for the reduce LazyOp)
         self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore
@@ -314,13 +312,13 @@ class Linearizer:
         self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
 
     # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
 
     # run late AST
     val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
 
     # store
-    self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs, val, ssa)
+    self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val, ssa)
 
     if not self.group_for_reduce:
       # end the global+local loop
@@ -364,6 +362,9 @@ class Linearizer:
   @property
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
 
+  @property
+  def output_shape(self) -> Tuple[int, ...]: return self.sts[0].shape
+
   @property
   def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape
 
@@ -443,7 +444,7 @@ class Linearizer:
 
   def simplify_merge_adjacent(self):
     if self.shape_len == 0: return
-    shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]
+    shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
 
     # merge dimensions if we can, multi get_shape_strides
     # TODO: does this always preserve the reduce dimension, NO
@@ -453,7 +454,7 @@ class Linearizer:
       can_merge = []
       for j in range(len(shapes)):
         # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
-        can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0))
+        can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0)))
       # more can merge than this
       mergeable = all(can_merge) and i != self.first_reduce
       for j in range(len(shapes)):
@@ -467,7 +468,7 @@ class Linearizer:
 
   def required_optimizations(self, early_only=False):
     for buf_index,buf in enumerate(self.bufs):
-      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
+      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
       if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
         assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
         if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
tinygrad/shape/shapetracker.py

@@ -4,7 +4,7 @@ from enum import Enum, auto
 import functools
 from typing import Dict, Tuple, Union, List, Optional, Callable, cast
 from tinygrad.helpers import prod, DEBUG
-from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
+from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
 
 # these ops live here
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
@@ -65,7 +65,7 @@ class View:
   # generate an expression if you have a single idx variable
   def expr_node(self, idx=None) -> Node:
     if idx is None: idx = Variable('idx', 0, prod(self.shape))
-    ret = [Variable.num(self.offset)]
+    ret: List[Node] = [Variable.num(self.offset)]
     acc = 1
     for d,s in reversed(self.shape_strides):
       ret.append(((idx//acc)%d)*s)
@@ -157,26 +157,23 @@ class ShapeTracker:
     assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
     return real_offset.b
 
-  def real_strides(self) -> Tuple[Optional[int], ...]:
-    if len(self.views) == 1: return self.views[-1].strides
-    ret: List[Optional[int]] = []
-    acc, real_offset = 1, self.real_offset()
-    for s in reversed(self.shape):
-      if s == 1: # fast path, all shape 1 have stride 0
-        ret.append(0)
-        continue
-      var = Variable('idx', 0, s-1)
-      this_dim, _ = self.expr_node(var*acc)
-      this_dim -= real_offset
-      acc *= s
-      # TODO: sometimes a mod here is okay if you are say, reading a float4, since you only care %4
-      # if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
-      if this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: ret.append(this_dim.b)
-      elif this_dim.__class__ is NumNode and this_dim.b == 0: ret.append(0)
-      elif this_dim.__class__ is Variable: ret.append(1)
-      else: ret.append(None)
-    return tuple(ret[::-1])
-  def unit_stride_axes(self) -> List[int]: return [i for i,st in enumerate(self.real_strides()) if st == 1]
+  # NOTE: if a stride is not always valid, it will be None
+  def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
+    if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
+    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
+    idx, valid = self.expr_idxs(idxs)
+    ret: List[Optional[int]] = [None for _ in self.views[-1].shape]
+    for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
+      if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
+        ret[idxs.index(this_dim.a)] = this_dim.b
+      elif isinstance(this_dim, Variable):
+        ret[idxs.index(this_dim)] = 1
+    render_idx, render_valid = idx.render(), valid.render()
+    for i in range(len(self.shape)):
+      if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
+      elif f'idx{i}' not in render_idx: ret[i] = 0
+    return tuple(ret)
+  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
 
   def _expr_idx(self, idx, valid):
     for v in reversed(self.views[0:-1]):
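The rewritten real_strides reads a stride per dimension straight off the SumNode terms of the symbolic index, and a dimension that feeds the validity mask becomes None unless ignore_valid is set. A sketch mirroring the TestRealStrides case above, assuming tinygrad is importable (the ignore_valid output is an expectation, not asserted by the tests):

from tinygrad.shape.shapetracker import ShapeTracker, View

st = ShapeTracker((16, 32, 4), views=[
  View((2048,), (1,), 0, ((0, 512),)),       # masked flat view
  View((16, 32, 4), (128, 4, 1), 0, None)])
assert st.real_strides() == (None, 4, 1)     # dim 0 appears in the mask, so it's None
print(st.real_strides(ignore_valid=True))    # expected (128, 4, 1) with the mask ignored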
tinygrad/shape/symbolic.py

@@ -2,7 +2,8 @@ from __future__ import annotations
 from abc import abstractmethod
 import functools
+from math import gcd
-from typing import List, Dict, Callable, Tuple, Type, Union
+from typing import List, Dict, Callable, Tuple, Type, Union, Optional
 from tinygrad.helpers import partition
 
 # NOTE: Python has different behavior for negative mod and floor div than c
 # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -25,8 +26,22 @@ class Node:
   def __neg__(self): return self*-1
   def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
   def __sub__(self, b:Union[Node, int]): return self+-b
-  def __ge__(self, b:int): return create_node(LtNode(-self, -b+1))
-  def __lt__(self, b:int): return create_node(LtNode(self, b))
+  def __ge__(self, b:int): return (-self) < (-b+1)
+  def __lt__(self, b:int):
+    lhs = self
+    if isinstance(lhs, SumNode):
+      muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
+      if len(muls):
+        # NOTE: gcd in python 3.8 takes exactly 2 args
+        mul_gcd = muls[0].b
+        for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
+        if b%mul_gcd == 0:
+          all_others = Variable.sum(others)
+          #print(mul_gcd, muls, all_others)
+          if all_others.min >= 0 and all_others.max < mul_gcd:
+            # TODO: should we divide both by mul_gcd here?
+            lhs = Variable.sum(muls)
+    return create_node(LtNode(lhs, b))
   def __mul__(self, b:int):
     if b == 0: return NumNode(0)
     elif b == 1: return self
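A quick demonstration of the new __lt__ factoring, using the expressions from test_ge_divides and test_lt_factors (assuming tinygrad is importable):

from tinygrad.shape.symbolic import Variable

idx1 = Variable("idx1", 0, 511)
f4 = Variable("FLOAT4_INDEX", 0, 3)
big = Variable("FLOAT4_INDEX", 0, 256)
assert ((idx1*4 + f4) < 512).render() == "((idx1*4)<512)"                 # remainder < gcd: dropped
assert ((idx1*4 + big) < 512).render() == "(((idx1*4)+FLOAT4_INDEX)<512)" # too wide: kept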
@@ -54,7 +69,7 @@ class Node:
     return create_node(ModNode(self, b))
 
   @staticmethod
-  def num(num:int) -> Node: return NumNode(num)
+  def num(num:int) -> NumNode: return NumNode(num)
 
   @staticmethod
   def factorize(nodes:List[Node]):
@@ -100,12 +115,12 @@ class Node:
 # 4 basic node types
 
 class Variable(Node):
-  def __new__(cls, expr:str, nmin:int, nmax:int):
+  def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
     assert nmin >= 0 and nmin <= nmax
     if nmin == nmax: return NumNode(nmin)
     return super().__new__(cls)
 
-  def __init__(self, expr:str, nmin:int, nmax:int):
+  def __init__(self, expr:Optional[str], nmin:int, nmax:int):
     self.expr, self.min, self.max = expr, nmin, nmax
 
 class NumNode(Node):
@@ -128,6 +143,7 @@ class LtNode(OpNode):
   def __mul__(self, b: int): return (self.a*b) < (self.b*b)
+  def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
   def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
 
 class MulNode(OpNode):
   def __mul__(self, b: int): return self.a*(self.b*b) # two muls in one mul
   def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
@@ -139,11 +155,13 @@ class MulNode(OpNode):
     return Node.__mod__(a, b)
   def get_bounds(self) -> Tuple[int, int]:
     return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
 
 class DivNode(OpNode):
+  def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
   def get_bounds(self) -> Tuple[int, int]:
+    assert self.a.min >= 0
     return self.a.min//self.b, self.a.max//self.b
 
 class ModNode(OpNode):
   def __floordiv__(self, b: int, factoring_allowed=True):
     if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
@@ -184,7 +202,7 @@ class SumNode(RedNode):
     return Node.__floordiv__(self, b, factoring_allowed)
 
   def __mod__(self, b: int):
-    new_nodes = []
+    new_nodes: List[Node] = []
     for x in self.nodes:
       if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
       elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
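SumNode.__mod__ reduces each term's coefficient modulo b before applying the outer mod. A small sketch, assuming tinygrad is importable (the rendered string is an expectation based on the folding above):

from tinygrad.shape.symbolic import Variable

x = Variable("x", 0, 7)
# 6 % 4 == 2 and 7 % 4 == 3, so (x*6 + 7) % 4 folds to ((x*2) + 3) % 4
print(((x*6 + 7) % 4).render())   # expected: (((x*2)+3)%4)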