index in cstyle (#7328)

* index only in cstyle

* fix prefix dtypes

* fix tests

* global indexing

* Revert "global indexing"

This reverts commit 4d507e8abb.

* fix image

* fix image

* ptx tests

* fix CUDA dtype rendering
Author:       George Hotz
Date:         2024-10-29 12:06:26 +07:00
Committed by: GitHub
Parent:       f55c3dcff8
Commit:       4cb236a495

6 changed files with 74 additions and 53 deletions

---- file 1 of 6 ----

@@ -550,7 +550,7 @@ class TestTypePromotion(unittest.TestCase):
   @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
   def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
     result = least_upper_dtype(dtype1, dtype2)
-    assert result >= dtype1 and result >= dtype2
+    assert not (result < dtype1) and not (result < dtype2)
 
   def test_dtype_promo(self):
     assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8

---- file 2 of 6 ----

@@ -942,10 +942,7 @@ class TestLinearizer(unittest.TestCase):
     sink = UOp(UOps.SINK, src=(store,))
     lin = Kernel(sink)
     lin.linearize()
-    assert len(lin.uops) <= 7, "too many uops"
-    a_bufs = [u.op for u in lin.uops[-1].src[2].src]
-    assert a_bufs == [UOps.LOAD, UOps.CONST]
+    assert len(lin.uops) <= 9, "too many uops"
 
   def test_upcast_cse(self):
     # when upcasting, within a subtree, there may be common expressions.
@@ -989,10 +986,10 @@ class TestLinearizer(unittest.TestCase):
     # the first store is to lds and can be upcasted
     assert stores[0].src[-1].dtype == dtypes.float.vec(4)
-    assert stores[0].src[0].op is UOps.DEFINE_LOCAL
+    assert any(x.op is UOps.DEFINE_LOCAL for x in stores[0].sparents)
     # the second store is to gds with no upcasts
-    assert stores[1].src[2].dtype == dtypes.float
-    assert stores[1].src[0].op is UOps.DEFINE_GLOBAL
+    assert stores[1].src[-1].dtype == dtypes.float
+    assert any(x.op is UOps.DEFINE_GLOBAL for x in stores[1].sparents)
 
   def test_zero_fold(self):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
@@ -1340,7 +1337,7 @@ class TestLinearizer(unittest.TestCase):
     barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
     # check that the float4 cast collapses for all stores
     for store in local_stores+global_stores:
-      assert store.src[2].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
+      assert store.src[-1].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
     # # check the children's vins
     # TODO: src ALU are not the same, should it?
     # assert barrier.src == tuple(local_stores)
@@ -1360,7 +1357,7 @@ class TestLinearizer(unittest.TestCase):
     #assert stores[0].src[-1].op is not UOps.VECTORIZE
     # the global store doesn't change
-    assert stores[1].src[2].dtype == dtypes.float
+    assert stores[1].src[-1].dtype == dtypes.float
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -1404,11 +1401,11 @@ class TestFloat4(unittest.TestCase):
   @staticmethod
   def count_float4(k, n=4):
     return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(n)]),
-            len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(n)]))
+            len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.float.vec(n)]))
 
   @staticmethod
   def count_half4(k):
     return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.half.vec(4)]),
-            len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.half.vec(4)]))
+            len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
 
   # TODO: express opts below as auto opts

---- file 3 of 6 ----

@@ -19,7 +19,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
   uops: List[UOp] = []
   def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
   uops = dedup(flatten(_recursive_add(st) for st in stores))
-  outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[2].dtype), \
+  outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
              initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
   inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render("test", uops)
@@ -42,7 +42,7 @@ class TestCStyleFailures(unittest.TestCase):
     ret = _test_uop_result([Tensor([1])], uops)[0]
     self.assertEqual(ret[0], 1)
 
-@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
+@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
 class TestPTXFailures(unittest.TestCase):
   def test_gated_store_with_alu(self):
     a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)

---- file 4 of 6 ----

@@ -6,7 +6,7 @@ from tinygrad.helpers import getenv
 ConstType = Union[float, int, bool]
 
-@dataclass(frozen=True, order=True)
+@dataclass(frozen=True)
 class DType:
   priority: int # this determines when things get upcasted
   itemsize: int
@@ -14,6 +14,7 @@ class DType:
   fmt: Optional[str]
   count: int
   def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
+  def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
   def vec(self, sz:int):
     assert self.count == 1, f"can't vectorize {self} with size {sz}"
     if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
@@ -22,7 +23,7 @@ class DType:
     return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
   def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
 
-@dataclass(frozen=True, repr=False)
+@dataclass(frozen=True)
 class ImageDType(DType):
   shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
   base: DType
@@ -32,7 +33,7 @@ class ImageDType(DType):
   def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
   def __repr__(self): return f"dtypes.{self.name}({self.shape})"
 
-@dataclass(frozen=True, repr=False)
+@dataclass(frozen=True)
 class PtrDType(DType):
   base: DType
   local: bool
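
Note on the ordering change above: dropping order=True means dataclasses no longer derive <=, >, or >=; only the hand-written __lt__ (plus the generated __eq__) remains. That is why the type-promotion test at the top switched from "result >= dtype1" to "not (result < dtype1)". A minimal standalone sketch of the behavior (MiniDType is a hypothetical stand-in, not the real DType):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MiniDType:  # hypothetical stand-in for DType, with fewer fields
  priority: int
  itemsize: int
  name: str
  def __lt__(self, o:"MiniDType"): return (self.priority, self.itemsize, self.name) < (o.priority, o.itemsize, o.name)

a, b = MiniDType(0, 1, "bool"), MiniDType(1, 1, "char")
assert a < b and not (b < a)  # __lt__ still works
try:
  b >= a  # TypeError: >= is not derived from __lt__ once order=True is gone
except TypeError:
  pass
```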

---- file 5 of 6 ----

@@ -142,6 +142,9 @@ class UOps(FastEnum):
   ASSIGN = auto()
   BIND = auto()
 
+  # late INDEX
+  INDEX = auto()
+
   # control flow ops
   BARRIER = auto()
   IF = auto()
@@ -274,6 +277,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
     return ret
   def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
+  def index(self, idx:UOp): return UOp(UOps.INDEX, self.dtype, (self,idx))
   def view(self, st:ShapeTracker): return UOp(UOps.VIEW, self.dtype, (self,), st)
   def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b)
   def broadcast(self, count:int):
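
The index helper above seeds the late-INDEX style: it fuses a buffer and an address into one uop that LOAD/STORE consume directly. A hand-built sketch of the intended shapes, mirroring the spec patterns added below (assumes UOp, UOps, dtypes from tinygrad.ops; the concrete values are illustrative):

```python
buf  = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
bidx = buf.index(UOp.const(dtypes.int, 0))       # UOps.INDEX, keeps the buffer's pointer dtype
ld   = UOp(UOps.LOAD, dtypes.int, (bidx,))       # new-style LOAD: <bufidx>
st   = UOp(UOps.STORE, dtypes.void, (bidx, ld))  # new-style STORE: <bufidx, val>
```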
@@ -457,12 +461,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
   dont_count: Set[UOp] = set()
   if ignore_indexing:
     for u in uops:
-      if u.op is UOps.LOAD:
-        dont_count = dont_count.union(u.src[1].sparents)
-        if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
-      elif u.op is UOps.STORE:
-        dont_count = dont_count.union(u.src[1].sparents)
-        if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
+      if u.op in {UOps.LOAD, UOps.STORE}:
+        offset = 0 if u.src[0].op not in {UOps.INDEX, UOps.CAST} else -1
+        dont_count = dont_count.union(u.src[offset+1].sparents)
+        if len(u.src) > offset+3: dont_count = dont_count.union(u.src[offset+3].sparents)
       elif u.op is UOps.IF:
         dont_count = dont_count.union(u.src[0].sparents)
   for u in uops:
@@ -476,7 +478,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
     elif u.op is UOps.LOAD:
       mem += u.dtype.itemsize * mults
     elif u.op is UOps.STORE:
-      mem += u.src[2].dtype.itemsize * mults
+      mem += u.src[2 if u.src[0].op not in {UOps.INDEX, UOps.CAST} else 1].dtype.itemsize * mults
     elif u.op is UOps.ALU and u not in dont_count:
       flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
     elif u.op is UOps.WMMA and u not in dont_count:
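
Both flops_mem fixes encode the same layout rule; restated as a standalone helper (hypothetical name, not part of this diff):

```python
# where the stored value lives, per STORE flavor:
#   old style: (buf, idx, val, gate?)                   -> val at src[2]
#   new style: (INDEX(buf, idx) or CAST(...), val, ...) -> val at src[1]
def store_val(u:UOp) -> UOp:
  return u.src[2] if u.src[0].op not in {UOps.INDEX, UOps.CAST} else u.src[1]
```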
@@ -535,6 +537,7 @@ class UPat(MathTrait):
   def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
 
   # copied from UOp
+  def index(self, idx:UPat): return UPat(UOps.INDEX, self.dtype, (self,idx))
   def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
   def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
   def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
@@ -733,6 +736,9 @@ spec = PatternMatcher([
   (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW))), lambda: True),
   (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat(UOps.STORE))), lambda: True),
+
+  # early STORE has a <buf, shapetracker, val>
   (UPat(UOps.STORE, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat())), lambda: True),
 
+  # LOAD takes a <buf, idx, alt?, gate?, barrier?>
   (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
   (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
@@ -744,6 +750,21 @@ spec = PatternMatcher([
   (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
   (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(UOps.IF))), lambda: True),
 
+  # **** new style load/store ****
+
+  # INDEX is used in new style load/store
+  (UPat(UOps.INDEX, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
+
+  # LOAD takes a <bufidx, alt?, gate?, barrier?>
+  (UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)),)), lambda: True),
+  (UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
+  (UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
+
+  # STORE takes a <bufidx, val, gate?>
+  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat())), lambda: True),
+  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(UOps.IF))), lambda: True),
+
   # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
   (UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
    lambda w,x,y: w.dtype == x.dtype == y.dtype),
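
The new rows compose with the hand-built uops from the sketch above; a hedged sanity check:

```python
assert bidx.op is UOps.INDEX and bidx.src[0].op is UOps.DEFINE_GLOBAL        # INDEX row
assert st.op is UOps.STORE and st.dtype == dtypes.void and st.src[0] is bidx # <bufidx, val> row
pat = UPat.var("buf").index(UPat.var("idx"))  # UPat.index composes the same way UOp.index does
```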
@@ -771,8 +792,8 @@ spec = PatternMatcher([
   (UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
   (UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
   (UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
-  (UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None and x.dtype.count == 1),
-  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(UOps.DEFINE_LOCAL),), allow_any_len=True)), lambda: True),
+  (UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
+  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
 
   # NOTE: for testing, we let sinks be anything
   #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),

---- file 6 of 6 ----

@@ -7,12 +7,6 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
 
-def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType) -> str:
-  sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
-  if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
-    return f"(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
-  return f"({r[buf]}+{sidx})"
-
 base_rewrite = PatternMatcher([
   (UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
   (UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
@@ -40,13 +34,12 @@ base_rewrite = PatternMatcher([
   (UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda r,x: f"{x.arg}u"),
   (UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda r,x: "1" if x.arg else "0"),
   (UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
-  # load/store
-  (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
-   lambda r,buf,idx,load,var,gate: f"({r[gate]}?*{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
-  (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
-   lambda r,buf,idx,load: f"*{_render_index(r, buf, idx, load.dtype)}"),
-  (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
-   lambda r,buf,idx,var: f"*{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
+  # new load/store
+  (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
+   lambda r,buf,idx: f"({r[buf]}+{strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]})"),
+  (UPat(UOps.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda r,bidx,var,gate: f"({r[gate]}?*{r[bidx]}:{r[var]})"),
+  (UPat(UOps.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda r,bidx: f"*{r[bidx]}"),
+  (UPat(UOps.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda r,bidx,var: f"*{r[bidx]} = {r[var]};"),
   # alu/gep
   (UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
    *([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
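
Reading the f-strings above, the new rows render roughly as follows (illustrative register names, with an ADD index so strip_parens applies; INDEX is rendered inline rather than materialized):

```python
# INDEX(buf, idx)        -> "(data0+alu0)"
# LOAD(bidx)             -> "*(data0+alu0)"
# LOAD(bidx, var, gate)  -> "(alu1?*(data0+alu0):0.0f)"
# STORE(bidx, var)       -> "*(data0+alu0) = val0;"
```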
@@ -54,6 +47,12 @@ base_rewrite = PatternMatcher([
    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
 ])
 
+def idx_load_store(x:UOp):
+  idx = x.src[0].index(x.src[1])
+  v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
+  if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
+  return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
+
 extra_pm = PatternMatcher([
   # consts are rendered to larger type and casted
   (UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)),
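
idx_load_store is the bridge into the new style. A hedged before/after check for a float4 global load (assumes UOp and UOps from tinygrad.ops alongside the imports above; the buffer and index are illustrative):

```python
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
old  = UOp(UOps.LOAD, dtypes.float.vec(4), (gbuf, UOp.const(dtypes.int, 0)))  # old style
new  = idx_load_store(old)
assert new.src[0].op is UOps.CAST                    # INDEX wrapped in a pointer-retyping CAST
assert new.src[0].dtype.base == dtypes.float.vec(4)  # so the renderer can emit *((float4*)(data0+0))
```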
@@ -62,9 +61,11 @@ extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(UOps.BITCAST, name="x"),
    lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
+  # use indexing for LOAD/STORE
+  (UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
   # gate any stores that aren't gated with ifs
-  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-   lambda store: UOp(UOps.STORE, src=store.src[:3]+(UOp(UOps.IF, src=(store.src[3],)),))),
+  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+   lambda store: UOp(UOps.STORE, src=store.src[:2]+(UOp(UOps.IF, src=(store.src[2],)),))),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(UOps.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
 ])
@@ -113,7 +114,8 @@ class CStyleLanguage(Renderer):
     if isinstance(dt, ImageDType):
       return f"{'write_only' if mutable else 'read_only'} image2d_t"
     if isinstance(dt, PtrDType):
-      return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
+      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + \
+        self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
     return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
 
   def __getitem__(self, key): return self.r[key] # hacky helper
@@ -144,15 +146,15 @@ class CStyleLanguage(Renderer):
       else:
         prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const",
                   UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast", UOps.NOOP: "precast",
-                  UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
+                  UOps.INDEX: "bidx", UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
         r[u] = f"{prefix}{c[prefix]}"
 
       l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
 
       if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
-      if u.op in {UOps.CONST, UOps.GEP} or (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST}
-                                            and child_count[u] == 1 and not getenv("EXPAND_SSA")):
+      if u.op in {UOps.CONST, UOps.GEP, UOps.INDEX} or (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST}
+                                                        and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
         if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void:
@@ -189,10 +191,11 @@ class ClangRenderer(CStyleLanguage):
     return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    prefix, macros = [self.render_vector_prefix(dt) for dt in dedup(uop.dtype for uop in uops if uop.dtype.count>1)], []
+    # TODO: copied into AMD
+    prefix = [self.render_vector_prefix(dt) for dt in dedup(u.dtype for u in uops if u.dtype.count > 1 and not isinstance(u.dtype, PtrDType))]
     # https://github.com/corsix/amx
     for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
-      macros = [
+      prefix += [
         '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
         '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
       ]
@@ -200,7 +203,7 @@ class ClangRenderer(CStyleLanguage):
   AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
   AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
   for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
-    return super().render_kernel(function_name, kernel, bufs, uops, macros + prefix)
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 class OpenCLRenderer(CStyleLanguage):
   device = "GPU"
@@ -218,11 +221,11 @@ class OpenCLRenderer(CStyleLanguage):
   string_rewrite = PatternMatcher([
     (UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_{r.render_dtype(x.dtype)}({r[x.src[0]]})"),
     # load/store image (OpenCL)
-    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
+    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
      lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
-    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)))),
+    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
      lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
-    (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
+    (UPat(UOps.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
      lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
   ]) + base_rewrite
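
Images keep their special rendering: these patterns match the bare INDEX because idx_load_store skips the vector-pointer CAST for ImageDType buffers. Illustratively (hypothetical uop names):

```python
# LOAD(float4, (INDEX(img, idx_int2),))           -> "read_imagef(data0, smp, idx0)"
# LOAD(float4, (INDEX(img, idx_int2), var, gate)) -> "(gate0?read_imagef(data0, smp, idx0):var0)"
# STORE((INDEX(img, idx_int2), val_float4))       -> "write_imagef(data0, idx0, val0);"
```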
@@ -422,8 +425,7 @@ class AMDRenderer(CStyleLanguage):
     # TODO: add BF16 vec dts
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("struct hip_bfloat16 { unsigned short data; };")
-    for dtype in dedup(uop.dtype for uop in uops if uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
+    prefix += [self.render_vector_prefix(dt) for dt in dedup(u.dtype for u in uops if u.dtype.count > 1 and not isinstance(u.dtype, PtrDType))]
 
     for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")