Add UOps.VECTORIZE [run_process_replay] (#5289)

* Add UOps.VECTORIZE to core

* Update vectorized cast tests

* Address code review comments

- Remove VECTORIZE from LLVMRenderer
- Add line breaks to unduly long lines
- Add noop CAST rule back
- Update asserts and add render_vectorize to the
  CStyleLanguage renderer (sketched below)
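
The split, trimmed from the CStyleLanguage hunk further down, is roughly: render_cast now takes a single scalar expression, while the new render_vectorize hook builds the vector constructor:

    # render_vectorize joins the per-lane expressions into one vector
    # constructor via the backend's float4 template (e.g. "(float2)(a,b)" on OpenCL)
    def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
      assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
      assert self.float4 is not None, "vectorized cast is not supported on this platform"
      return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"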

* Add missing const folding rule for VECTORIZE

Also adds a corresponding test; the rule and test are sketched below.
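
In PatternMatcher terms (both the rule and the test appear in the hunks below), a VECTORIZE whose lanes all bind to the same CONST collapses into a single vector CONST:

    # constant_folder rule: every source matches the same CONST "c",
    # so the whole VECTORIZE is replaced by one CONST of the vector dtype
    (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")),
     lambda root, c: UOp.const(root.dtype, c.arg)),

    # the test builds VECTORIZE(c0, c0) and expects the graph to fold to one CONST
    c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
    out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))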

* Fix test_const_vectorize_fold and add an assert

- Use sane types with VECTORIZE in test_const_vectorize_fold
- Add an assert that sanity-checks the types for VECTORIZE (sketched below)
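
Concretely, the new type_verify checks (mirrored in the uops hunk below) pin down both the width and the element type of a VECTORIZE:

    # a VECTORIZE must output a vector dtype with one source per lane, and that
    # dtype must equal the source's scalar dtype widened to len(src)
    if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count
    if uop is UOps.VECTORIZE: assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"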

* Rename test_cast_vectorized_fold

Renames test_cast_vectorized_fold to test_noop_vectorize_fold
because the test targets a very specific rule and there are
other tests for VECTORIZE.

* Revert unrelated changes

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: qazal <qazal.software@gmail.com>
Author: greg-niemeyer
Date: 2024-07-06 23:59:57 -07:00
Committed by: GitHub
Parent: 2a7282c1e1
Commit: 77b2ce9fc9
7 changed files with 67 additions and 44 deletions

View File

@@ -733,7 +733,7 @@ class TestLinearizer(unittest.TestCase):
# check that the float4 cast collapses
store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
for val in store_vals:
- assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.CAST
+ assert val.dtype == dtypes.float.vec(4) and val.op not in {UOps.VECTORIZE, UOps.CAST}
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_values(self):
@@ -741,7 +741,7 @@ class TestLinearizer(unittest.TestCase):
out = x.flip((0,1)).contiguous()
k = helper_linearizer_opt(out)[-1]
store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
- assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.CAST
+ assert store_val.dtype == dtypes.float.vec(4) and store_val.op not in {UOps.VECTORIZE, UOps.CAST}
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -759,7 +759,7 @@ class TestLinearizer(unittest.TestCase):
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
- assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
+ assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op not in {UOps.VECTORIZE, UOps.CAST}
# # check the children's vins
# TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores)
@@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase):
# the float4 value stores directly in lds and we skip upcast
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
- assert stores[0].src[-1].op is not UOps.CAST
+ assert stores[0].src[-1].op not in {UOps.VECTORIZE, UOps.CAST}
# the global store doesn't change
assert stores[1].src[2].dtype == dtypes.float
@@ -791,7 +791,7 @@ class TestLinearizer(unittest.TestCase):
]
k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
out = [u for u in k.uops if u.op is UOps.STORE][0]
- assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(4)
+ assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -803,7 +803,7 @@ class TestLinearizer(unittest.TestCase):
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
out = [u for u in k.uops if u.op is UOps.STORE][0]
- assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(2)
+ assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):

View File

@@ -119,16 +119,25 @@ class TestUOpGraph(TestUOps):
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 0)
- def test_cast_vectorized_fold(self):
+ def test_const_vectorize_fold(self):
+ c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
+ out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
+ g = UOpGraph([out])
+ self.assertEqual(len(g.uops), 1)
+ out = g.uops[-1]
+ self.assertEqual(out.op, UOps.CONST)
+ self.assertEqual(out.arg, 0.0)
+ def test_noop_vectorize_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
- cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
- x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
+ vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
+ x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, None, (d0, idx, alu))
g = UOpGraph([out])
- self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
+ self.assertEqual(len([x for x in g.uops if x.op is UOps.VECTORIZE]), 0)
def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=(0, True))

View File

@@ -156,7 +156,7 @@ class Linearizer(Kernel):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
- rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
+ rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier)
if localtype == localtype.scalar():
@@ -197,7 +197,7 @@ class Linearizer(Kernel):
amt = len(grouped)
idx, valid = self.sts[i].expr_idxs(k)
assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
- store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
+ store_offset_new[k] = UOp(UOps.VECTORIZE, buf.dtype.vec(amt), tuple(grouped))
store_offset = store_offset_new
stores = []
@@ -205,7 +205,7 @@ class Linearizer(Kernel):
idx, valid = self.sts[i].expr_idxs(_idx)
if isinstance(buf.dtype, ImageDType):
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
- rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
+ rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), \
tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
else:
rendered_idx = idx.render(render_ops, self.loop_uops)
@@ -287,13 +287,13 @@ class Linearizer(Kernel):
next_ *= 1 if stride == 0 else sz
return strides
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
- # cast initial accs
- wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
+ # vectorize initial accs
+ wmmas = [UOp(UOps.VECTORIZE, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
for x in range(0, len(accs[reduceop]), wmma_sz[2])]
for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
- ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
- UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+ ops = (UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+ UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
# TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))

View File

@@ -17,7 +17,7 @@ class UOps(Enum):
CONST = auto(); SPECIAL = auto() # noqa: E702
NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
# math ops
- CAST = auto(); BITCAST = auto() # noqa: E702
+ CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
ALU = auto(); WMMA = auto() # noqa: E702
# memory/assignment ops
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
@@ -96,8 +96,10 @@ def type_verify(uops):
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
- if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
- if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
+ if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
+ if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
+ if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count
+ if uop is UOps.VECTORIZE: assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
@@ -231,6 +233,7 @@ constant_folder = PatternMatcher([
# const rules
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
@@ -300,18 +303,23 @@ constant_folder = PatternMatcher([
# TODO: can do the invert of this (flip alt/load) when we fix double ops
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
- # store float4/float2 directly (remove CAST/GEP)
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
- # CAST-PHI-GEP -> PHI-CAST
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
- lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
- lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
+ # store float4/float2 directly (remove VECTORIZE/GEP)
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
+ UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
+ UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
+ # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
+ (UPat(UOps.VECTORIZE, name="root", src=tuple(
+ UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
+ lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
+ (UPat(UOps.VECTORIZE, name="root", src=tuple(
+ UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
+ lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UPat(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
# fold gated LOAD/STORE
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),

View File

@@ -197,10 +197,12 @@ class PTXRenderer(Renderer):
else:
kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
r[u] = r[src[0]]
+ elif uop in {UOps.VECTORIZE}:
+ assert src[0].dtype is not None and dtype.count > 1
+ r[u] = [r[x] for x in src] # type: ignore
elif uop in {UOps.CAST, UOps.BITCAST}:
- assert src[0].dtype is not None
- if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
- else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+ assert src[0].dtype is not None and dtype.count == 1
+ _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"

View File

@@ -33,9 +33,12 @@ class CStyleLanguage(Renderer):
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
# returns a str expression of the casted xs with the given type
- def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
- if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
- if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
+ def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+ if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
+ return f"({self.render_dtype(var_dtype)})({x})"
+ # returns a str expression of the vectorized xs with the given type
+ def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
assert self.float4 is not None, "vectorized cast is not supported on this platform"
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
@@ -47,7 +50,8 @@ class CStyleLanguage(Renderer):
elif dtype.scalar() == dtypes.bool: val = "1" if x else "0"
elif dtype.scalar() == dtypes.float: val = f"{x}f"
else: val = str(x)
- return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
+ if dtype.count > 1: return self.render_vectorize([val] * dtype.count, dtype)
+ return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
@@ -144,14 +148,14 @@ class CStyleLanguage(Renderer):
elif uop is UOps.PHI:
kk(f"{r[src[0]]} = {r[src[1]]};")
r[u] = r[src[0]]
- elif uop in {UOps.CAST, UOps.BITCAST}:
+ elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
+ assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
if uop is UOps.BITCAST:
- assert len(src) == 1
precast = ssa('precast')
kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
- val = self.render_cast([precast], dtype, bitcast=True)
- else:
- val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
+ val = self.render_cast(precast, dtype, bitcast=True)
+ elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
+ else: val = self.render_vectorize([r[x] for x in src], dtype)
if child_count[u] <= 1: r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
@@ -201,7 +205,7 @@ class OpenCLRenderer(CStyleLanguage):
uses_vload = True
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
def render_cast(self, x, var_dtype, bitcast=False) -> str:
return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
@@ -232,8 +236,8 @@ class MetalRenderer(CStyleLanguage):
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
- def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
- return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ def render_cast(self, x: str, var_dtype: DType, bitcast=False) -> str:
+ return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])

View File

@@ -100,7 +100,7 @@ class PythonProgram:
del ul[i]
i = loop_ends[i] + 1
continue
- elif uop in (UOps.CAST, UOps.BITCAST):
+ elif uop in (UOps.CAST, UOps.BITCAST, UOps.VECTORIZE):
if dtype.count > 1: ul[i] = inp
else:
assert dtp[0].fmt and dtype.fmt