mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simpler can_pad (#10364)
* simpler can_pad [pr] * 3 kernels * tests * less kernels
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -436,7 +436,7 @@ jobs:
|
||||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2137 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot alt model correctness (float32)
|
||||
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot fastvits model correctness (float32)
|
||||
|
||||
@@ -93,12 +93,11 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(z, 2))
|
||||
self.assertEqual(z.item(), 32)
|
||||
|
||||
# TODO: same issue in precompute_freqs_cis
|
||||
def test_push_pads_contiguous(self):
|
||||
x = Tensor.full((4,1), 2.).contiguous()
|
||||
y = Tensor.full((4,4), 4.).contiguous()
|
||||
z = (x.reciprocal().expand(4,4)*y).pad((None, (0,1),)).sum()
|
||||
run_schedule(check_schedule(z, 3, [x,y]))
|
||||
run_schedule(check_schedule(z, 2, [x,y]))
|
||||
self.assertEqual(z.item(), 32)
|
||||
|
||||
def test_rand(self):
|
||||
@@ -1860,10 +1859,10 @@ class TestIndexing(unittest.TestCase):
|
||||
args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000}
|
||||
fused = precompute_freqs_cis(**args)
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(fused, 5)) # TODO: this is too many
|
||||
run_schedule(check_schedule(fused, 3))
|
||||
if getenv("CHECK", 1):
|
||||
ref = precompute_freqs_cis(**args)
|
||||
run_schedule(check_schedule(ref, 5))
|
||||
run_schedule(check_schedule(ref, 3))
|
||||
np.testing.assert_equal(fused.numpy(), ref.numpy())
|
||||
|
||||
def test_fuse_assign_contiguous(self):
|
||||
|
||||
@@ -413,7 +413,7 @@ class Kernel:
|
||||
check(not self.vars, "does not work with symbolic shape")
|
||||
check(axis < self.first_upcast, "cannot pad upcasted")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}, cache={}), f"cannot pad {r}")
|
||||
if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
|
||||
padded = False
|
||||
for i,st in enumerate(self.sts):
|
||||
if (s:=st.shape[axis]) == 1: continue # reduced
|
||||
|
||||
@@ -131,7 +131,7 @@ def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
|
||||
st = unwrap(view.st)
|
||||
# always realize unsafe pad ops before masked view
|
||||
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx, cache=dict()): return realize(ctx, tr)
|
||||
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
|
||||
# fold simple pads
|
||||
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
|
||||
# realize before expand
|
||||
|
||||
@@ -188,11 +188,8 @@ class GroupOp:
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
|
||||
if u.op in GroupOp.UnsafePad: return False
|
||||
if u in edges or u in cache: return True
|
||||
cache[u] = None
|
||||
return all(can_pad(x.base, edges, cache) for x in u.src)
|
||||
def can_pad(root:UOp, edges:dict[UOp, None]) -> bool:
|
||||
return all(u.op not in GroupOp.UnsafePad for u in root.toposort(gate=lambda x:x not in edges))
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
def resolve(x:UOp|bool, default:bool=True):
|
||||
|
||||
Reference in New Issue
Block a user