mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
UOps.VECTORIZE cleanups [run_process_replay] (#5314)
* still render_cast
* one extra line ok
* these are all just vectorize
* save space
* behavior change can go in a different diff
This commit is contained in:
@@ -733,7 +733,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# check that the float4 cast collapses
|
||||
store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
|
||||
for val in store_vals:
|
||||
assert val.dtype == dtypes.float.vec(4) and val.op not in {UOps.VECTORIZE, UOps.CAST}
|
||||
assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.VECTORIZE
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_grouped_store_values(self):
|
||||
@@ -741,7 +741,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
out = x.flip((0,1)).contiguous()
|
||||
k = helper_linearizer_opt(out)[-1]
|
||||
store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.op not in {UOps.VECTORIZE, UOps.CAST}
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.VECTORIZE
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
@@ -759,7 +759,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
|
||||
# check that the float4 cast collapses for all stores
|
||||
for store in local_stores+global_stores:
|
||||
assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op not in {UOps.VECTORIZE, UOps.CAST}
|
||||
assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.VECTORIZE
|
||||
# # check the children's vins
|
||||
# TODO: src ALU are not the same, should it?
|
||||
# assert barrier.src == tuple(local_stores)
|
||||
@@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
# the float4 value stores directly in lds and we skip upcast
|
||||
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
|
||||
assert stores[0].src[-1].op not in {UOps.VECTORIZE, UOps.CAST}
|
||||
assert stores[0].src[-1].op is not UOps.VECTORIZE
|
||||
|
||||
# the global store doesn't change
|
||||
assert stores[1].src[2].dtype == dtypes.float
|
||||
|
||||
@@ -205,8 +205,7 @@ class Linearizer(Kernel):
|
||||
idx, valid = self.sts[i].expr_idxs(_idx)
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), \
|
||||
tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
||||
rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
||||
else:
|
||||
rendered_idx = idx.render(render_ops, self.loop_uops)
|
||||
if self.late_gate is not None: valid *= self.late_gate
|
||||
|
||||
@@ -98,8 +98,9 @@ def type_verify(uops):
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
|
||||
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
|
||||
if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count
|
||||
if uop is UOps.VECTORIZE: assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
||||
if uop is UOps.VECTORIZE:
|
||||
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
|
||||
assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE:
|
||||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
|
||||
@@ -236,7 +236,7 @@ class MetalRenderer(CStyleLanguage):
|
||||
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
|
||||
|
||||
def render_cast(self, x: str, var_dtype: DType, bitcast=False) -> str:
|
||||
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
|
||||
return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
|
||||
@@ -100,20 +100,19 @@ class PythonProgram:
|
||||
del ul[i]
|
||||
i = loop_ends[i] + 1
|
||||
continue
|
||||
elif uop in (UOps.CAST, UOps.BITCAST, UOps.VECTORIZE):
|
||||
if dtype.count > 1: ul[i] = inp
|
||||
elif uop is UOps.VECTORIZE: ul[i] = inp
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert dtp[0].fmt and dtype.fmt
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
else:
|
||||
assert dtp[0].fmt and dtype.fmt
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
else:
|
||||
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
|
||||
if dtypes.is_int(dtype):
|
||||
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
|
||||
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
|
||||
elif dtypes.is_float(dtype):
|
||||
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
|
||||
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
|
||||
if dtypes.is_int(dtype):
|
||||
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
|
||||
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
|
||||
elif dtypes.is_float(dtype):
|
||||
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
|
||||
elif uop is UOps.LOAD:
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
assert dtype.count == 4
|
||||
|
||||
Reference in New Issue
Block a user