mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Support weird loads in Image (#2498)
* image support weird loads
* umm, that was always wrong
* openpilot compile fails with a weird error
* image test passes
* we have valids now
* clean that up
* no more required opts
* add fastvits test, fix bug
* minor cleanups
This commit is contained in:
@@ -14,10 +14,9 @@ from typing import Tuple, List
|
||||
from extra.onnx import get_run_onnx
|
||||
from tinygrad.graph import print_tree, log_schedule_item
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH
|
||||
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.ops import LoadOps, ScheduleItem
|
||||
from tinygrad.features.image import fix_schedule_for_images
|
||||
Device.DEFAULT = "GPU"
|
||||
|
||||
def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
|
||||
@@ -67,10 +66,6 @@ def schedule_to_thneed(schedule, output_fn):
|
||||
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
|
||||
setattr(prg.clprg, 'prg', prg.prg)
|
||||
|
||||
if getenv("VALIDTEST") == 1:
|
||||
src = re.search(r"=.*\?.*?read_image", prg.prg)
|
||||
if src is not None: raise Exception("Openpilot has valid checks!")
|
||||
|
||||
global_size = prg.global_size + [1]*(3-len(prg.global_size))
|
||||
local_size = prg.local_size + [1]*(3-len(prg.local_size))
|
||||
cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]]))
|
||||
@@ -146,8 +141,7 @@ if __name__ == "__main__":
|
||||
|
||||
run_schedule(schedule_independent, disable_logging=True)
|
||||
run_schedule(schedule_input)
|
||||
with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
|
||||
schedule = fix_schedule_for_images(schedule)
|
||||
with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
|
||||
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
|
||||
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user