From 677145b39369607b1091708bf6829eaef0498954 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 23 Feb 2026 10:26:50 +0800
Subject: [PATCH] all consts have shapes (#14959)

* all consts have shapes

* vconst has shape too

* use normal schedule

* cast ptrdtype

* image

* bitcast issue + hack
---
 test/backend/test_custom_kernel.py |  7 +------
 tinygrad/uop/ops.py                | 12 +++++++++---
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/test/backend/test_custom_kernel.py b/test/backend/test_custom_kernel.py
index f59fc35266..ee4d7bc8a4 100644
--- a/test/backend/test_custom_kernel.py
+++ b/test/backend/test_custom_kernel.py
@@ -265,8 +265,6 @@ class TestCustomKernel(unittest.TestCase):
     Expected schedule order: [A2, B2, E, custom_addmul, final_sum]
     The custom_addmul kernel should be at index 3.
     """
-    from tinygrad.engine.schedule import create_schedule
-    from tinygrad.schedule.rangeify import get_rangeify
     A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
     A2 = (A + 1).contiguous() # kernel 0: depends on A
     B2 = (B + 2).contiguous() # kernel 1: depends on B
@@ -275,10 +273,7 @@ class TestCustomKernel(unittest.TestCase):
     C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2
     E = (A2 * 3).contiguous() # kernel 2: depends only on A2
     result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
-
-    big_sink = result.uop.sink()
-    sched_sink = get_rangeify(big_sink)
-    schedule, _ = create_schedule(sched_sink)
+    schedule = result.schedule()
 
     # Find the custom_addmul kernel position
     custom_idx = next((i for i, item in enumerate(schedule)
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index a981a6ea8f..fa835be5c7 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -207,10 +207,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     match self.op:
       # late ops don't have shape
       case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
-           Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
+           Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
            Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
         return None
 
+      case Ops.CAST:
+        # when PTX cases from ptr to non ptr, remove the shape
+        if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
+          return None
+
       case Ops.INDEX:
         # non pointer index doesn't have a shape
         if not isinstance(self.dtype, PtrDType): return None
@@ -220,7 +225,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
         return self.src[0].shape[len(self.src[1:]):]
 
       # some ops init the shape
-      case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
+      case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
       case Ops.BUFFER: return (self.arg,)
       case Ops.BUFFER_VIEW: return (self.arg[0],)
       case Ops.ENCDEC: return self.arg[0]
@@ -240,7 +245,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
      case Ops.BITCAST:
        ps = self.src[0]._shape
        if ps is None: return None
-        if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
+        if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize):
+          return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),) if len(ps) > 0 else ps
        return ps
 
      # TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score