Add UOps.VECTORIZE [run_process_replay] (#5289)
* Add UOps.VECTORIZE to core
* Update vectorized cast tests
* Address code review comments
  - Remove VECTORIZE from LLVMRenderer
  - Add line breaks to unduly long lines
  - Add noop CAST rule back
  - Update asserts and add render_vectorize to the CStyleLanguage renderer
* Add missing const folding rule for VECTORIZE, with a corresponding test
* Fix test_const_vectorize_fold and add an assert
  - Use sane types with VECTORIZE in test_const_vectorize_fold
  - Add an assert that sanity-checks the types for VECTORIZE
* Rename test_cast_vectorized_fold to test_noop_vectorize_fold, because that test targets a very specific rule and there are other tests for VECTORIZE
* Revert unrelated changes

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: qazal <qazal.software@gmail.com>
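For orientation, here is a minimal sketch (not part of the commit) of what the new op looks like at the UOp level, using the same UOp/UOpGraph/dtypes API exercised by the tests in this diff; the import paths are assumed for this PR's tree. A VECTORIZE groups scalar sources into one vector-typed value, and the new const-folding rule collapses a VECTORIZE of identical CONSTs into a single vector CONST.

```python
# minimal sketch, assuming the UOp/UOpGraph API as used in the tests below
from tinygrad.codegen.uops import UOp, UOps, UOpGraph  # import paths assumed
from tinygrad.dtype import dtypes

c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))  # 2-wide vector built from scalar sources

g = UOpGraph([vec])
# the new VECTORIZE-of-CONST rule in constant_folder should leave a single CONST behind
assert len(g.uops) == 1 and g.uops[-1].op is UOps.CONST and g.uops[-1].arg == 0.0
```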
@@ -733,7 +733,7 @@ class TestLinearizer(unittest.TestCase):
# check that the float4 cast collapses
store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
for val in store_vals:
- assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.CAST
+ assert val.dtype == dtypes.float.vec(4) and val.op not in {UOps.VECTORIZE, UOps.CAST}

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_values(self):

@@ -741,7 +741,7 @@ class TestLinearizer(unittest.TestCase):
out = x.flip((0,1)).contiguous()
k = helper_linearizer_opt(out)[-1]
store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
- assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.CAST
+ assert store_val.dtype == dtypes.float.vec(4) and store_val.op not in {UOps.VECTORIZE, UOps.CAST}

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")

@@ -759,7 +759,7 @@ class TestLinearizer(unittest.TestCase):
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
- assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.CAST
+ assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op not in {UOps.VECTORIZE, UOps.CAST}
# # check the children's vins
# TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores)

@@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase):

# the float4 value stores directly in lds and we skip upcast
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
- assert stores[0].src[-1].op is not UOps.CAST
+ assert stores[0].src[-1].op not in {UOps.VECTORIZE, UOps.CAST}

# the global store doesn't change
assert stores[1].src[2].dtype == dtypes.float

@@ -791,7 +791,7 @@ class TestLinearizer(unittest.TestCase):
]
k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
out = [u for u in k.uops if u.op is UOps.STORE][0]
- assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(4)
+ assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")

@@ -803,7 +803,7 @@ class TestLinearizer(unittest.TestCase):
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
out = [u for u in k.uops if u.op is UOps.STORE][0]
- assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(2)
+ assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(2)

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@@ -119,16 +119,25 @@ class TestUOpGraph(TestUOps):
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 0)

- def test_cast_vectorized_fold(self):
+ def test_const_vectorize_fold(self):
+ c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
+ out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
+ g = UOpGraph([out])
+ self.assertEqual(len(g.uops), 1)
+ out = g.uops[-1]
+ self.assertEqual(out.op, UOps.CONST)
+ self.assertEqual(out.arg, 0.0)
+
+ def test_noop_vectorize_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
- cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
- x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
+ vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
+ x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, None, (d0, idx, alu))
g = UOpGraph([out])
- self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
+ self.assertEqual(len([x for x in g.uops if x.op is UOps.VECTORIZE]), 0)

def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=(0, True))
@@ -156,7 +156,7 @@ class Linearizer(Kernel):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
- rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
+ rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
valid_tuple = (valid_uop, UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + barrier)
if localtype == localtype.scalar():

@@ -197,7 +197,7 @@ class Linearizer(Kernel):
amt = len(grouped)
idx, valid = self.sts[i].expr_idxs(k)
assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
- store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
+ store_offset_new[k] = UOp(UOps.VECTORIZE, buf.dtype.vec(amt), tuple(grouped))
store_offset = store_offset_new

stores = []

@@ -205,7 +205,7 @@ class Linearizer(Kernel):
idx, valid = self.sts[i].expr_idxs(_idx)
if isinstance(buf.dtype, ImageDType):
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
- rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
+ rendered_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), \
tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
else:
rendered_idx = idx.render(render_ops, self.loop_uops)

@@ -287,13 +287,13 @@ class Linearizer(Kernel):
next_ *= 1 if stride == 0 else sz
return strides
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
- # cast initial accs
- wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
+ # vectorize initial accs
+ wmmas = [UOp(UOps.VECTORIZE, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]]))
for x in range(0, len(accs[reduceop]), wmma_sz[2])]
for it in [x[::-1] for x in itertools.product(*list([range(sz) for _,sz in upcasts[0]][::-1]))]:
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
- ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
- UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
+ ops = (UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
+ UOp(UOps.VECTORIZE, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
wmmas[(wmma_idx:=offs[2]//wmma_sz[2])])
# TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev))
@@ -17,7 +17,7 @@ class UOps(Enum):
CONST = auto(); SPECIAL = auto() # noqa: E702
NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
# math ops
- CAST = auto(); BITCAST = auto() # noqa: E702
+ CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
ALU = auto(); WMMA = auto() # noqa: E702
# memory/assignment ops
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
@@ -96,8 +96,10 @@ def type_verify(uops):
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
- if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
- if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
+ if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
+ if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
+ if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count
+ if uop is UOps.VECTORIZE: assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
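To make the new type_verify checks concrete, here is a small illustrative sketch (not from the PR); it assumes the UOp constructor, dtypes API, and import paths shown in the tests above.

```python
# illustrative sketch of the new type_verify rules for VECTORIZE (import paths assumed)
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.dtype import dtypes

c = UOp(UOps.CONST, dtypes.float, arg=1.0)

# passes: arg is None, dtype.count > 1, one source per lane, dtype == src[0].dtype.vec(len(src))
ok = UOp(UOps.VECTORIZE, dtypes.float.vec(4), (c, c, c, c))

# would trip the last assert above: float.vec(2) is not half.vec(2), yet the sources are half
h = UOp(UOps.CONST, dtypes.half, arg=1.0)
bad = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (h, h))  # type_verify would raise AssertionError
```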
@@ -231,6 +233,7 @@ constant_folder = PatternMatcher([
# const rules
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
+ (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
@@ -300,18 +303,23 @@ constant_folder = PatternMatcher([
# TODO: can do the invert of this (flip alt/load) when we fix double ops
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
- # store float4/float2 directly (remove CAST/GEP)
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
- # CAST-PHI-GEP -> PHI-CAST
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
- lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
- lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
+ # store float4/float2 directly (remove VECTORIZE/GEP)
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
+ UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
+ UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
+ # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
+ (UPat(UOps.VECTORIZE, name="root", src=tuple(
+ UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
+ lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
+ (UPat(UOps.VECTORIZE, name="root", src=tuple(
+ UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
+ lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
+ (UPat(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
# fold gated LOAD/STORE
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
@@ -197,10 +197,12 @@ class PTXRenderer(Renderer):
else:
kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
r[u] = r[src[0]]
+ elif uop in {UOps.VECTORIZE}:
+ assert src[0].dtype is not None and dtype.count > 1
+ r[u] = [r[x] for x in src] # type: ignore
elif uop in {UOps.CAST, UOps.BITCAST}:
- assert src[0].dtype is not None
- if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
- else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+ assert src[0].dtype is not None and dtype.count == 1
+ _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -33,9 +33,12 @@ class CStyleLanguage(Renderer):
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

# returns a str expression of the casted xs with the given type
- def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
- if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
- if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
+ def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+ if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
+ return f"({self.render_dtype(var_dtype)})({x})"
+
+ # returns a str expression of the vectorized xs with the given type
+ def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
assert self.float4 is not None, "vectorized cast is not supported on this platform"
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
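For a sense of what render_vectorize emits, the standalone sketch below reproduces the format string from the hunk above; the concrete float4 template "(float4)" is an assumption modeled on the OpenCL-style renderers, and val0/val1 are placeholder operand names.

```python
# standalone sketch of the render_vectorize format string (no tinygrad imports needed)
float4_template = "(float4)"   # assumed value of self.float4 for an OpenCL-style backend
rendered_dtype = "float2"      # stand-in for self.render_dtype(dtypes.float.vec(2))
xs = ["val0", "val1"]          # hypothetical per-lane expressions
print(f"{float4_template.replace('float4', rendered_dtype)}({','.join(xs)})")
# -> (float2)(val0,val1)
```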
@@ -47,7 +50,8 @@ class CStyleLanguage(Renderer):
elif dtype.scalar() == dtypes.bool: val = "1" if x else "0"
elif dtype.scalar() == dtypes.float: val = f"{x}f"
else: val = str(x)
- return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
+ if dtype.count > 1: return self.render_vectorize([val] * dtype.count, dtype)
+ return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)

# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
@@ -144,14 +148,14 @@ class CStyleLanguage(Renderer):
elif uop is UOps.PHI:
kk(f"{r[src[0]]} = {r[src[1]]};")
r[u] = r[src[0]]
- elif uop in {UOps.CAST, UOps.BITCAST}:
+ elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
+ assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
if uop is UOps.BITCAST:
- assert len(src) == 1
precast = ssa('precast')
kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
- val = self.render_cast([precast], dtype, bitcast=True)
- else:
- val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
+ val = self.render_cast(precast, dtype, bitcast=True)
+ elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
+ else: val = self.render_vectorize([r[x] for x in src], dtype)
if child_count[u] <= 1: r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
@@ -201,7 +205,7 @@ class OpenCLRenderer(CStyleLanguage):
uses_vload = True
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
def render_cast(self, x, var_dtype, bitcast=False) -> str:
- return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)

def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
@@ -232,8 +236,8 @@ class MetalRenderer(CStyleLanguage):
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}

- def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
- return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ def render_cast(self, x: str, var_dtype: DType, bitcast=False) -> str:
+ return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)

def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
@@ -100,7 +100,7 @@ class PythonProgram:
del ul[i]
i = loop_ends[i] + 1
continue
- elif uop in (UOps.CAST, UOps.BITCAST):
+ elif uop in (UOps.CAST, UOps.BITCAST, UOps.VECTORIZE):
if dtype.count > 1: ul[i] = inp
else:
assert dtp[0].fmt and dtype.fmt