all consts have shapes (#14959)

* all consts have shapes

* vconst has shape too

* use normal schedule

* cast ptrdtype

* image

* bitcast issue + hack
This commit is contained in:
George Hotz
2026-02-23 10:26:50 +08:00
committed by GitHub
parent 1538960002
commit 677145b393
2 changed files with 10 additions and 9 deletions

View File

@@ -265,8 +265,6 @@ class TestCustomKernel(unittest.TestCase):
Expected schedule order: [A2, B2, E, custom_addmul, final_sum]
The custom_addmul kernel should be at index 3.
"""
from tinygrad.engine.schedule import create_schedule
from tinygrad.schedule.rangeify import get_rangeify
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
A2 = (A + 1).contiguous() # kernel 0: depends on A
@@ -275,10 +273,7 @@ class TestCustomKernel(unittest.TestCase):
C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2
E = (A2 * 3).contiguous() # kernel 2: depends only on A2
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
big_sink = result.uop.sink()
sched_sink = get_rangeify(big_sink)
schedule, _ = create_schedule(sched_sink)
schedule = result.schedule()
# Find the custom_addmul kernel position
custom_idx = next((i for i, item in enumerate(schedule)

View File

@@ -207,10 +207,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
return None
case Ops.CAST:
# when PTX casts from ptr to non ptr, remove the shape
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
return None
case Ops.INDEX:
# non pointer index doesn't have a shape
if not isinstance(self.dtype, PtrDType): return None
@@ -220,7 +225,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0].shape[len(self.src[1:]):]
# some ops init the shape
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.ENCDEC: return self.arg[0]
@@ -240,7 +245,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.BITCAST:
ps = self.src[0]._shape
if ps is None: return None
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize):
return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),) if len(ps) > 0 else ps
return ps
# TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score