diff --git a/test/test_search.py b/test/test_search.py index cb946fd51d..5ca03c43d3 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -26,7 +26,17 @@ class TestTimeLinearizer(unittest.TestCase): def test_bufs_from_lin(self): si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0] rawbufs = bufs_from_lin(lin:=Kernel(si.ast)) - assert len(rawbufs) == len(lin.membufs) + assert len(rawbufs) == len(lin.membufs) == 2 + assert all(r is not None for r in rawbufs) + assert all(isinstance(r, Buffer) for r in rawbufs) + assert all(r.size > 0 for r in rawbufs) + + def test_bufs_from_lin_alt(self): + a = Tensor.randn(4, 4) + b = a+a[0] + si = [si for si in b.schedule() if si.ast.op is UOps.SINK][0] + rawbufs = bufs_from_lin(k:=Kernel(si.ast)) + assert len(rawbufs) == len(k.membufs) == 2 assert all(r is not None for r in rawbufs) assert all(isinstance(r, Buffer) for r in rawbufs) assert all(r.size > 0 for r in rawbufs) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index f44f5ffad9..cca0a3ef74 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -128,7 +128,7 @@ class Kernel: return ret @property - def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}]) + def membufs(self) -> List[UOp]: return list({x.src[0].key:x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}}.values()) # TODO: these need more tests or it might silently be no-op def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501 @@ -242,8 +242,8 @@ class Kernel: shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - if isinstance(self.bufs[0].src[0].dtype, ImageDType): - base_shape = self.bufs[0].src[0].dtype.shape + if isinstance(self.membufs[0].dtype, ImageDType): + base_shape = self.membufs[0].dtype.shape if shape_idx_groups := get_contraction(self.output_shape, base_shape): special_strides: Tuple[sint, ...] = tuple() for i,g in enumerate(shape_idx_groups): @@ -407,7 +407,7 @@ class Kernel: else: amt = -1 if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): - acc_sz = dt.base.itemsize if isinstance((dt:=cast(DType, self.reduceop.dtype)), ImageDType) else dt.itemsize + acc_sz = cast(DType, self.reduceop.dtype).itemsize upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b]) local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces]) smem_sz = amt*acc_sz*upcast_sz*local_sz @@ -529,7 +529,7 @@ class Kernel: # upcast float4 images for buf_index,buf in enumerate(self.bufs): unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0] - if buf.dtype.__class__ is ImageDType: + if buf.src[0].dtype.__class__ is ImageDType: #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501 if unit_stride_axes_mul_4[0] < self.first_reduce: