use membufs in ImageDType checks [run_process_replay] (#6136)

* use membufs in ImageDType checks

* set by key [run_process_replay]
This commit is contained in:
qazal
2024-08-17 21:17:46 +08:00
committed by GitHub
parent 41ac8bdd63
commit d1d41130cd
2 changed files with 16 additions and 6 deletions

View File

@@ -26,7 +26,17 @@ class TestTimeLinearizer(unittest.TestCase):
def test_bufs_from_lin(self):
  """bufs_from_lin yields one non-None, positively sized Buffer per membuf of the `Tensor([1,2,3,4]).add(1)` kernel."""
  schedule = create_schedule([Tensor([1,2,3,4]).add(1).lazydata])
  # the kernel under test is the first schedule item whose AST root is a SINK
  sink_item = [item for item in schedule if item.ast.op is UOps.SINK][0]
  lin = Kernel(sink_item.ast)
  rawbufs = bufs_from_lin(lin)
  # the first check is implied by the stricter one right after it, which also pins the expected count to 2
  assert len(rawbufs) == len(lin.membufs)
  assert len(rawbufs) == len(lin.membufs) == 2
  assert all(buf is not None for buf in rawbufs)
  assert all(isinstance(buf, Buffer) for buf in rawbufs)
  assert all(buf.size > 0 for buf in rawbufs)
def test_bufs_from_lin_alt(self):
  """Same checks as test_bufs_from_lin, for a kernel where both operands of the add come from one tensor."""
  base = Tensor.randn(4, 4)
  summed = base + base[0]
  # pick the first schedule item whose AST root is a SINK
  sink_item = [item for item in summed.schedule() if item.ast.op is UOps.SINK][0]
  kernel = Kernel(sink_item.ast)
  rawbufs = bufs_from_lin(kernel)
  # exactly two membufs are expected, and one raw buffer for each of them
  assert len(rawbufs) == len(kernel.membufs) == 2
  assert all(buf is not None for buf in rawbufs)
  assert all(isinstance(buf, Buffer) for buf in rawbufs)
  assert all(buf.size > 0 for buf in rawbufs)

View File

@@ -128,7 +128,7 @@ class Kernel:
return ret
# Unique `src[0]` UOps of every LOAD/STORE entry in `self.bufs`.
# NOTE(review): two definitions of `membufs` appear back to back here. This looks like the before/after pair of a
# diff hunk (the `dedup(...)` form replaced by the keyed-dict form) rather than two live definitions; read literally,
# only the first one sits under `@property`. Confirm against the real file.
@property
def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}])
# Same selection, but uniqueness is decided by `.key`: the dict holds one UOp per distinct key (a later duplicate
# overwrites the stored value while keeping the first-seen position), and the values are returned as a list.
def membufs(self) -> List[UOp]: return list({x.src[0].key:x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}}.values())
# TODO: these need more tests or it might silently be no-op
def float4_axis(self, i:int):
  """Return, relative to `self.first_upcast`, the unit-stride axes of `self.sts[i]` whose size is a multiple of 4.

  Axes located before `self.first_upcast` are ignored; every returned value is `axis - self.first_upcast`.
  """
  st = self.sts[i]
  candidates = st.unit_stride_axes()
  base = self.first_upcast
  picked = []
  for axis in candidates:
    if axis < base: continue
    if st.shape[axis] % 4 != 0: continue
    picked.append(axis - base)
  return picked
@@ -242,8 +242,8 @@ class Kernel:
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
if isinstance(self.bufs[0].src[0].dtype, ImageDType):
base_shape = self.bufs[0].src[0].dtype.shape
if isinstance(self.membufs[0].dtype, ImageDType):
base_shape = self.membufs[0].dtype.shape
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
special_strides: Tuple[sint, ...] = tuple()
for i,g in enumerate(shape_idx_groups):
@@ -407,7 +407,7 @@ class Kernel:
else: amt = -1
if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
acc_sz = dt.base.itemsize if isinstance((dt:=cast(DType, self.reduceop.dtype)), ImageDType) else dt.itemsize
acc_sz = cast(DType, self.reduceop.dtype).itemsize
upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
smem_sz = amt*acc_sz*upcast_sz*local_sz
@@ -529,7 +529,7 @@ class Kernel:
# upcast float4 images
for buf_index,buf in enumerate(self.bufs):
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
if buf.dtype.__class__ is ImageDType:
if buf.src[0].dtype.__class__ is ImageDType:
#assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
if unit_stride_axes_mul_4[0] < self.first_reduce: