Float4 support for CLANG (#5915)

* float4 support on clang

* skip linearizer tests that require locals

* add aligned attribute
This commit is contained in:
ignaciosica
2024-08-06 11:50:12 -03:00
committed by GitHub
parent a7db4c3ee9
commit 81ae9fadc8
2 changed files with 12 additions and 3 deletions

View File

@@ -992,6 +992,7 @@ class TestLinearizer(unittest.TestCase):
# the global store doesn't change
assert stores[1].src[2].dtype == dtypes.float
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
Tensor.manual_seed(0)
@@ -1193,6 +1194,7 @@ class TestFloat4(unittest.TestCase):
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
assert count == expected, f"{count=}, {expected=}"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_float2_acc(self):
# from resnet
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501

View File

@@ -41,7 +41,7 @@ class CStyleLanguage(Renderer):
def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
assert self.float4 is not None, "vectorized cast is not supported on this platform"
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}" + (f"{{{','.join(x)}}}" if self.device == "CLANG" else f"({','.join(x)})")
# returns a str expression of the const with the given type
def render_const(self, x:ConstType, dtype:DType) -> str:
@@ -180,15 +180,18 @@ class CStyleLanguage(Renderer):
assert src[0].dtype is not None
from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
(f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) else f".{'xyzwabcd'[args]}")
(f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) or self.device == 'CLANG' else f".{'xyzwabcd'[args]}")
else: raise RuntimeError(f"failed to render {u}")
# NOTE: this relies on bufs dict preserving order
return self.render_kernel(name, kernel, list(bufs.values()), uops)
def _make_clang_dtype(self, dtype):
return f"typedef {self.render_dtype(dtype.scalar())} {self.render_dtype(dtype)} __attribute__((aligned({(sz:=dtype.itemsize)}),vector_size({sz})));"
class ClangRenderer(CStyleLanguage):
device = "CLANG"
supports_float4 = False
float4 = "(float4)"
has_local = False
global_max = None
@@ -197,6 +200,10 @@ class ClangRenderer(CStyleLanguage):
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = [_make_clang_dtype(self, dtype) for dtype in set(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count>1)]
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
class OpenCLRenderer(CStyleLanguage):
device = "GPU"