Support weird loads in Image (#2498)

* image support weird loads

* umm, that was always wrong

* openpilot compile fails with a weird error

* image test passes

* we have valids now

* clean that up

* no more required opts

* add fastvits test, fix bug

* minor cleanups
George Hotz committed 2023-11-29 08:30:46 -08:00 (via GitHub)
parent e333672675
commit 889acefe85
11 changed files with 79 additions and 105 deletions

.github/workflows/test.yml

@@ -182,7 +182,7 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model compile and size
       run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py
+        DEBUG=2 ALLOWED_KERNEL_COUNT=207 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py
         python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model correctness (float32)
@@ -190,6 +190,9 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot alt model correctness (float32)
       run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+    - if: ${{ matrix.task == 'openpilot' }}
+      name: Test openpilot fastvits model correctness (float32)
+      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test tensor core ops
       run: GPU=1 TC=2 python -m pytest -n=auto test/test_ops.py
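
Note: the VALIDTEST=1 guard is dropped because image loads are now allowed to carry valid checks (the "we have valids now" bullet above), and the extra fastvits supercombo run presumably exercises exactly those loads. A minimal sketch of the env-flag pattern these one-liners rely on, assuming tinygrad's getenv helper (simplified; the real helper also caches):

import os

def getenv(key: str, default=0):
  # cast the env var to the type of the default
  return type(default)(os.getenv(key, default))

IMAGE = getenv("IMAGE")      # 2 turns conv weights/activations into image dtypes
FLOAT16 = getenv("FLOAT16")  # 1 runs the model in half precision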

openpilot/compile2.py

@@ -14,10 +14,9 @@ from typing import Tuple, List
 from extra.onnx import get_run_onnx
 from tinygrad.graph import print_tree, log_schedule_item
 from tinygrad import Tensor, Device
-from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH
+from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
 from tinygrad.realize import run_schedule
 from tinygrad.ops import LoadOps, ScheduleItem
-from tinygrad.features.image import fix_schedule_for_images
 Device.DEFAULT = "GPU"

 def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
@@ -67,10 +66,6 @@ def schedule_to_thneed(schedule, output_fn):
     setattr(prg.clprg, 'op_estimate', prg.op_estimate)
     setattr(prg.clprg, 'prg', prg.prg)
-    if getenv("VALIDTEST") == 1:
-      src = re.search(r"=.*\?.*?read_image", prg.prg)
-      if src is not None: raise Exception("Openpilot has valid checks!")
-
     global_size = prg.global_size + [1]*(3-len(prg.global_size))
     local_size = prg.local_size + [1]*(3-len(prg.local_size))
     cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]]))
@@ -146,8 +141,7 @@ if __name__ == "__main__":
   run_schedule(schedule_independent, disable_logging=True)
   run_schedule(schedule_input)
-  with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
-    schedule = fix_schedule_for_images(schedule)
+  with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
     image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
     print(f"**** running real kernels {image_count}/{len(schedule)} images ****")

test/test_image_dtype.py (new file, +29)

@@ -0,0 +1,29 @@
+import unittest
+import numpy as np
+from tinygrad import Device, dtypes, Tensor
+from tinygrad.helpers import ImageDType
+
+@unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")
+class TestImageDType(unittest.TestCase):
+  def test_shrink_load_float(self):
+    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
+    imgv = it.numpy()
+    np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
+
+  def test_mul_stays_image(self):
+    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
+    out = (it*2).realize()
+    assert isinstance(out.lazydata.realized.dtype, ImageDType)
+
+  def test_shrink_max(self):
+    it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
+    imgv = it.numpy()
+    np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy())
+
+  def test_shrink_to_float(self):
+    it = Tensor.randn(4, 4).cast(dtypes.imagef((1,4,4))).realize()
+    imgv = it.numpy()
+    np.testing.assert_equal(np.maximum(imgv[:, 0], 0), it[:, 0].relu().numpy())
+
+if __name__ == '__main__':
+  unittest.main()
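
These tests pin down the "weird loads": slices that do not line up with whole float4 texels. A concrete model of the addressing they rely on, assuming imagef((H, W, 4)) stores H*W*4 floats as an HxW grid of 4-float texels (so Tensor.randn(8).cast(dtypes.imagef((1,2,4))) is one row of two texels):

# scalar load from an image: component idx%4 of texel idx//4
H, W = 1, 2
flat = list(range(H * W * 4))  # stand-in for the 8 floats
texels = [flat[i:i+4] for i in range(0, len(flat), 4)]
def load_scalar(idx): return texels[idx // 4][idx % 4]
# a shrink like it[0:3] reads at offsets that cross component boundaries,
# so a scalar load must pick a component out of a float4 read
assert [load_scalar(i) for i in range(3)] == flat[0:3]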

tinygrad/codegen/kernel.py

@@ -468,21 +468,7 @@ class Kernel:
     assert padded, "nothing was padded"
     return self.simplify_ones()

-  def required_optimizations(self, early_only=False):
-    for buf_index,buf in enumerate(self.bufs):
-      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
-      if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
-        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
-          if unit_stride_axes_mul_4[0] < self.first_reduce:
-            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
-          else:
-            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
-
   def hand_coded_optimizations(self):
-    # if there's images in the earlybufs, we have to make an axis the 4 loading one
-    self.required_optimizations(early_only=True)
-
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
@@ -522,8 +508,16 @@
         if self.sts[0].shape[axes[0]]%4 == 0:
           self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))

-    # now do everything required
-    self.required_optimizations()
+    # upcast float4 images
+    for buf_index,buf in enumerate(self.bufs):
+      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
+      if buf.dtype.__class__ is ImageDType:
+        #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
+        if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
+          if unit_stride_axes_mul_4[0] < self.first_reduce:
+            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+          else:
+            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

     # no more opt if we are grouping
     if self.group_for_reduce: return
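
required_optimizations used to assert that every image buffer had a unit-stride axis divisible by 4 and then forced the float4 upcast; the inlined loop keeps the upcast when such an axis exists and otherwise falls through to the linearizer's new scalar load path. A standalone sketch of the relaxed condition (hypothetical helper, not in the tree):

def float4_axes(shape, strides):
  # unit-stride axes whose extent is a multiple of 4
  return [i for i, (s, st) in enumerate(zip(shape, strides)) if st == 1 and s % 4 == 0]

assert float4_axes((16, 8), (8, 1)) == [1]  # upcastable: axis 1 feeds float4 loads
assert float4_axes((16, 3), (3, 1)) == []   # used to assert, now takes the scalar path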

tinygrad/codegen/linearizer.py

@@ -92,20 +92,28 @@ class Linearizer(Kernel):
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
             self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE)
+        elif isinstance(buf.dtype, ImageDType):
+          buf_uop = self.buf_uops[i]
+          assert buf_uop is not None, f"buffer {i} wasn't UOped"
+          image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
+          rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (image_idx[0].render(self.render_ops, self), image_idx[1].render(self.render_ops, self)))
+          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, dtypes.float32.vec(4))) if valid.min == 0 else tuple()
+          self.load_cache[key] = self.uop(UOps.LOAD, dtypes.float32.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
+          idx_small = idx%4
+          res = idx_small.render(self.render_ops, self)
+          if localtype == localtype.scalar():
+            out = self.uop(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
+            for ix in range(idx_small.max, idx_small.min, -1):
+              rvv = self.uop(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
+              sel = self.uop(UOps.ALU, res.dtype, (res, self.const(ix)), BinaryOps.CMPLT)
+              out = self.uop(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE)
+            self.load_cache[key] = out
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
-          if isinstance(buf.dtype, ImageDType):
-            idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-            rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
-          else:
-            rendered_idx = idx.render(self.render_ops, self)
-          if valid.min == 0:
-            valid_rendered = valid.render(self.render_ops, self)
-            self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)) + ((barrier,) if barrier else ()))
-          else:
-            self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + ((barrier,) if barrier else ()))
+          rendered_idx = idx.render(self.render_ops, self)
+          valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple()
+          self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
       ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret
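
The new elif branch always loads a whole float4 texel and then, for a scalar result, selects the wanted component with a chain of GEP/CMPLT/WHERE uops instead of branching. A plain-Python model of that select chain (list indexing stands in for GEP, if/else for WHERE; this mirrors the uop construction, not the real renderer):

def select_component(vec4, idx_small, mn=0, mx=3):
  out = vec4[mx]                # GEP at idx_small.max
  for ix in range(mx, mn, -1):  # same loop bounds as the uop builder
    rvv = vec4[ix - 1]          # GEP at ix-1
    sel = idx_small < ix        # CMPLT
    out = rvv if sel else out   # WHERE
  return out

assert [select_component([10, 11, 12, 13], i) for i in range(4)] == [10, 11, 12, 13]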
@@ -383,7 +391,7 @@ class Linearizer(Kernel):
     for u in self.uops:
       if not loop_stack[-1]: loop_stack[-1].append(u)
       elif u.uop == UOps.LOOP: loop_stack.append([u])
-      elif u.uop not in [UOps.CONST, UOps.ALU]: loop_stack[-1].append(u)
+      elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST]: loop_stack[-1].append(u)
       else:
         parents = get_recursive_parents(u)
         for i in reversed(range(len(loop_stack))):

tinygrad/device.py

@@ -186,7 +186,6 @@ def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) ->
     if BEAM >= 1:
       lins = [(("tc" if used_tensor_cores else "hc"), k)]
       kb = Linearizer(ast, linearizer_opts)
-      kb.required_optimizations()
       from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
       # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
       test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
@@ -197,6 +196,4 @@ def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) ->
       timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
       if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
       k = timed[0][1]
-  else:
-    k.required_optimizations()
   return k

tinygrad/features/image.py

@@ -1,5 +1,5 @@
-from typing import List, Tuple, Dict, Any
-from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, flatten
+from typing import Tuple, Dict, Any
+from tinygrad.helpers import prod, IMAGE, getenv, dtypes, DEBUG

 # *** image Tensor function replacements ***

@@ -95,60 +95,6 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
   ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
   return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))

-# *** schedules with images need to be fixed to be valid ***
-
-import dataclasses
-from tinygrad.ops import ScheduleItem, BufferOps, LazyOp, UnaryOps, LoadOps, MemBuffer, get_lazyop_info
-
-def fix_schedule_for_images(schedule:List[ScheduleItem]):
-  # this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
-  replace_inputs = {}
-  for i, si in enumerate(schedule):
-    if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
-      if DEBUG >= 1: print(f"{i:3d}: rewrite output, output shape {prod(si.out.shape)}, image dtype {si.out.dtype} prod {prod(si.out.dtype.shape)}")
-      si.out.dtype = dtypes.float32
-    for b in si.ast.get_lazyops():
-      if b.op != BufferOps.LOAD: continue
-      if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes()):
-        if DEBUG >= 1: print(f"{i:3d}: rewrite input, image dtype {si.inputs[b.arg.idx-1].dtype}, {b.arg.st.views}")
-        if si.inputs[b.arg.idx-1].realized:
-          # have to copy it
-          replace_inputs[si.inputs[b.arg.idx-1]] = si.inputs[b.arg.idx-1].cast(dtypes.float32)
-        else:
-          # change it before it's created
-          si.inputs[b.arg.idx-1].dtype = dtypes.float32
-
-  # now fix up the schedule to reflect the new dtypes
-  fixed_schedule:List[ScheduleItem] = []
-  for i,si in enumerate(schedule):
-    ast = si.ast
-    inputs = si.inputs
-    # replace inputs with casted versions
-    if any(x in replace_inputs for x in inputs):
-      fixed_schedule += flatten([replace_inputs[x].schedule() for x in inputs if x in replace_inputs])
-      inputs = tuple(replace_inputs.get(x, x) for x in inputs)
-    # fix input dtypes to match what they actually are
-    replacements = {}
-    for b in si.ast.get_lazyops():
-      if b.op != BufferOps.LOAD: continue
-      if b.arg.dtype != inputs[b.arg.idx-1].dtype:
-        replacements[b] = LazyOp(BufferOps.LOAD, (), MemBuffer(b.arg.idx, inputs[b.arg.idx-1].dtype, b.arg.st))
-    if replacements: ast = ast.map_buffers(replacements)
-    # fix the ops to create the output dtype
-    if ast.op not in LoadOps:
-      info = get_lazyop_info(ast)
-      if info.dtype != si.out.dtype:
-        if DEBUG >= 3: print(f"{i:3d}: info.dtype {info.dtype} != {si.out.dtype} -> {si.out.dtype}")
-        ast_cast = LazyOp(UnaryOps.CAST, (ast.src[0],), (si.out.dtype, False))
-        ast = LazyOp(BufferOps.STORE, (ast_cast,), MemBuffer(0, si.out.dtype, ast.arg.st))
-    # put this in the fixed schedule
-    fixed_schedule.append(dataclasses.replace(si, ast=ast, inputs=inputs))
-  return fixed_schedule
-
 # *** images have weird indexing requirements ***

 from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
@@ -178,6 +124,7 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup
   fakes = {}
   for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
+    if mnn > mxn: return (idx, idy), valid  # TODO: why is this happening?
     fake_var = Variable("fake_" + str(cnt), mnn, mxn)
     fakes[fake_var] = key_node
     idxy += multip*(fake_var - key_node)
tinygrad/lazy.py

@@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
 from weakref import ref, WeakSet, WeakValueDictionary

 import numpy as np
-from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int, ImageDType
+from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int, ImageDType, DEBUG
 from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
@@ -172,9 +172,15 @@ class LazyBuffer:
     if op.op not in LoadOps:
       info = get_lazyop_info(op)
       assert info.dtype == self.dtype or isinstance(self.dtype, ImageDType), f"dtype mismatch {info.dtype=} != {self.dtype=}"
+      if isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
+        if DEBUG >= 3: print(f"forcing image {self.dtype} to float32")
+        self.dtype = dtypes.float32  # NOTE: this is what makes the dtype above not match
+        op = LazyOp(UnaryOps.CAST, (op, ), (dtypes.float32, False))
+
       # TODO: why doesn't this match?
       #assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}"
-      op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, info.dtype, ShapeTracker.from_shape(info.shape)))
+      op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)))
     return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
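
This is the load-bearing change: instead of post-processing the schedule (the deleted fix_schedule_for_images), a buffer that cannot live in the image layout is demoted to float32 when its ScheduleItem is created, with a CAST spliced in. The demotion predicate extracted as a standalone sketch (hypothetical helper, not in the tree):

from math import prod

def image_ok(shape, dtype_shape, unit_stride_axes):
  # element counts must match and some unit-stride axis must be a multiple of 4
  return prod(shape) == prod(dtype_shape) and any(shape[x] % 4 == 0 for x in unit_stride_axes)

assert image_ok((1, 2, 4), (1, 2, 4), [2])  # fits the texture layout
assert not image_ok((3,), (1, 1, 4), [0])   # 3 elements can't fill a 4-float texel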

tinygrad/realize.py

@@ -4,13 +4,9 @@ from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, BufferOps
 from tinygrad.device import Device
 from tinygrad.graph import log_schedule_item, print_tree
 from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import DEBUG, prod, all_int, IMAGE, getenv
-from tinygrad.features.image import fix_schedule_for_images
+from tinygrad.helpers import DEBUG, prod, all_int, getenv

 def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
-  # HACK: images can be not usable due to shape
-  if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
   # NOTE: if you for loop the schedule it's slow because nothing frees
   while len(schedule):
     si = schedule.pop(0)
tinygrad/renderer/cstyle.py

@@ -52,7 +52,7 @@ class CStyleLanguage(NamedTuple):
   def render_const(self, x:Union[float,int], var_dtype) -> str:
     if math.isnan(x): val = "NAN"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-    else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}"
+    else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}"
     return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val

   # returns a str expression of the loaded value with the output type
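
Previously an int constant carried with a float dtype rendered as a bare integer literal; with the image path now emitting float4 vector constants for invalid_value, that mixed int literals into float constructors (plausibly the "weird error" from the commit message). The new branch always renders a proper float literal. A before/after check, assuming dtypes.is_float(var_dtype) is true:

x = 2  # an int constant arriving with a float32 dtype
old = f"{x}f" if isinstance(x, float) else f"{int(x)}"  # -> "2", an int literal
new = f"{float(x)}f"                                    # -> "2.0f", a float literal
assert (old, new) == ("2", "2.0f")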

tinygrad/shape/symbolic.py

@@ -132,7 +132,7 @@ class Node:

 class Variable(Node):
   def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
-    assert nmin >= 0 and nmin <= nmax
+    assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
     if nmin == nmax: return NumNode(nmin)
     return super().__new__(cls)
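
With the early return in to_image_idx sidestepping empty ranges, this assert should no longer fire from image indexing, and when it does fire it now names the offending Variable. A quick sketch of what the message buys you:

from tinygrad.shape.symbolic import Variable

try:
  Variable("fake_0", 2, 1)  # the empty range the to_image_idx guard now avoids
except AssertionError as e:
  print(e)  # invalid Variable expr='fake_0' nmin=2 nmax=1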