mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
all consts have shapes (#14959)
* all consts have shapes * vconst has shape too * use normal schedule * cast ptrdtype * image * bitcast issue + hack
This commit is contained in:
@@ -265,8 +265,6 @@ class TestCustomKernel(unittest.TestCase):
|
||||
Expected schedule order: [A2, B2, E, custom_addmul, final_sum]
|
||||
The custom_addmul kernel should be at index 3.
|
||||
"""
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.schedule.rangeify import get_rangeify
|
||||
|
||||
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
|
||||
A2 = (A + 1).contiguous() # kernel 0: depends on A
|
||||
@@ -275,10 +273,7 @@ class TestCustomKernel(unittest.TestCase):
|
||||
C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2
|
||||
E = (A2 * 3).contiguous() # kernel 2: depends only on A2
|
||||
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
|
||||
|
||||
big_sink = result.uop.sink()
|
||||
sched_sink = get_rangeify(big_sink)
|
||||
schedule, _ = create_schedule(sched_sink)
|
||||
schedule = result.schedule()
|
||||
|
||||
# Find the custom_addmul kernel position
|
||||
custom_idx = next((i for i, item in enumerate(schedule)
|
||||
|
||||
@@ -207,10 +207,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
|
||||
return None
|
||||
|
||||
case Ops.CAST:
|
||||
# when PTX cases from ptr to non ptr, remove the shape
|
||||
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
# non pointer index doesn't have a shape
|
||||
if not isinstance(self.dtype, PtrDType): return None
|
||||
@@ -220,7 +225,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
|
||||
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.ENCDEC: return self.arg[0]
|
||||
@@ -240,7 +245,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.BITCAST:
|
||||
ps = self.src[0]._shape
|
||||
if ps is None: return None
|
||||
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
|
||||
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize):
|
||||
return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),) if len(ps) > 0 else ps
|
||||
return ps
|
||||
|
||||
# TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score
|
||||
|
||||
Reference in New Issue
Block a user