mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
set OPENPILOT_HACKS=1 to enable replace assign (#15123)
This commit is contained in:
committed by
GitHub
parent
df23057984
commit
592f9bf6c6
@@ -31,7 +31,7 @@ def compile(onnx_file):
|
||||
for i in range(3):
|
||||
GlobalCounters.reset()
|
||||
print(f"run {i}")
|
||||
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
|
||||
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1), OPENPILOT_HACKS=1):
|
||||
ret = run_onnx_jit(**inputs).numpy()
|
||||
# copy i == 1 so use of JITBEAM is okay
|
||||
if i == 1: test_val = np.copy(ret)
|
||||
|
||||
@@ -797,7 +797,7 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
|
||||
def test_image_dot_f16_fusion(self):
|
||||
with Context(FLOAT16=1):
|
||||
with Context(FLOAT16=1, OPENPILOT_HACKS=1):
|
||||
def cnt():
|
||||
x, y, z = Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float')
|
||||
a = (x @ y).relu()
|
||||
@@ -814,18 +814,19 @@ class TestSchedule(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
|
||||
@unittest.expectedFailure
|
||||
def test_image_conv_fusion(self):
|
||||
def cnt():
|
||||
x, y, z = Tensor.empty((1, 4, 3, 3)), Tensor.empty((4, 1, 3, 3)), Tensor.empty((4, 1, 7, 7))
|
||||
a = x.conv2d(y, Tensor.empty(4), groups=4, padding=1)
|
||||
b = a.conv2d(z, groups=4, padding=3)
|
||||
sched = (a + b).schedule()
|
||||
for si in sched: si.lower()
|
||||
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
|
||||
with Context(OPENPILOT_HACKS=1):
|
||||
def cnt():
|
||||
x, y, z = Tensor.empty((1, 4, 3, 3)), Tensor.empty((4, 1, 3, 3)), Tensor.empty((4, 1, 7, 7))
|
||||
a = x.conv2d(y, Tensor.empty(4), groups=4, padding=1)
|
||||
b = a.conv2d(z, groups=4, padding=3)
|
||||
sched = (a + b).schedule()
|
||||
for si in sched: si.lower()
|
||||
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
|
||||
|
||||
with Context(IMAGE=1): cnt1 = cnt()
|
||||
with Context(IMAGE=2): cnt2 = cnt()
|
||||
with Context(IMAGE=1): cnt1 = cnt()
|
||||
with Context(IMAGE=2): cnt2 = cnt()
|
||||
|
||||
self.assertEqual(cnt1, cnt2)
|
||||
self.assertEqual(cnt1, cnt2)
|
||||
|
||||
def _test_fusion(self, shapes, f, cnt):
|
||||
with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]
|
||||
|
||||
@@ -580,21 +580,22 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
# this is the failing case in openpilot...it's very simple like this
|
||||
def test_image_conv_fusion(self):
|
||||
w1 = Tensor.empty(16, 16, 1, 1)
|
||||
b1 = Tensor.empty(16)
|
||||
w2 = Tensor.empty(16, 16, 1, 1)
|
||||
b2 = Tensor.empty(16)
|
||||
w3 = Tensor.empty(16, 16, 1, 1)
|
||||
b3 = Tensor.empty(16)
|
||||
with Context(OPENPILOT_HACKS=1):
|
||||
w1 = Tensor.empty(16, 16, 1, 1)
|
||||
b1 = Tensor.empty(16)
|
||||
w2 = Tensor.empty(16, 16, 1, 1)
|
||||
b2 = Tensor.empty(16)
|
||||
w3 = Tensor.empty(16, 16, 1, 1)
|
||||
b3 = Tensor.empty(16)
|
||||
|
||||
x = Tensor.empty(1, 16, 32, 32)
|
||||
x = base = x.image_conv2d(w1, b1)
|
||||
x = x.image_conv2d(w2, b2) + base
|
||||
x = x.image_conv2d(w3, b3)
|
||||
x = Tensor.empty(1, 16, 32, 32)
|
||||
x = base = x.image_conv2d(w1, b1)
|
||||
x = x.image_conv2d(w2, b2) + base
|
||||
x = x.image_conv2d(w3, b3)
|
||||
|
||||
# NOOP, 3 convs, contiguous
|
||||
#check_schedule(x, 5)
|
||||
check_schedule(x, 7)
|
||||
# NOOP, 3 convs, contiguous
|
||||
#check_schedule(x, 5)
|
||||
check_schedule(x, 7)
|
||||
|
||||
def test_image_conv_fusion_minimal(self):
|
||||
b1 = Tensor.empty(16)
|
||||
|
||||
@@ -174,7 +174,7 @@ class ContextVar(Generic[T]):
|
||||
return [getattr(obj, x) if obj else x for x in self.value.split(',') if x]
|
||||
|
||||
DEBUG, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
|
||||
IMAGE, FLOAT16 = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0)
|
||||
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
|
||||
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
|
||||
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
|
||||
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, FLOAT16, argsort, partition, get_single_element
|
||||
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
@@ -31,8 +31,8 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp):
|
||||
x = x.src[0]
|
||||
ctx[x] = assign
|
||||
|
||||
pm_openpilot_hacks = PatternMatcher([
|
||||
# *** ASSIGN replacement hack for openpilot ***
|
||||
# *** fold moved ASSIGNs (hack for openpilot) ***
|
||||
pm_fold_moved_assign = PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat((*GroupOp.Movement, Ops.CAST), name="src")), name="assign"), found_assign),
|
||||
# replace ALU sources with assign versions found above
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
@@ -535,8 +535,8 @@ split_kernels = PatternMatcher([
|
||||
|
||||
@profile_matches
|
||||
def get_kernel_graph(sink:UOp) -> UOp:
|
||||
tsink = graph_rewrite(sink, pm_openpilot_hacks, ctx={}, name="openpilot hacks")
|
||||
tsink = graph_rewrite(tsink, multi_pm, name="multi_pm")
|
||||
tsink = graph_rewrite(sink, multi_pm, name="multi_pm")
|
||||
if OPENPILOT_HACKS: tsink = graph_rewrite(tsink, pm_fold_moved_assign, ctx={}, name="fold moved assigns")
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
|
||||
Reference in New Issue
Block a user