mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Float4 support for CLANG (#5915)
* float4 support on clang
* skip linearizer tests that require locals
* add aligned attribute
This commit is contained in:
@@ -992,6 +992,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# the global store doesn't change
|
||||
assert stores[1].src[2].dtype == dtypes.float
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts(self):
|
||||
Tensor.manual_seed(0)
|
||||
@@ -1193,6 +1194,7 @@ class TestFloat4(unittest.TestCase):
|
||||
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_float2_acc(self):
|
||||
# from resnet
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
|
||||
@@ -41,7 +41,7 @@ class CStyleLanguage(Renderer):
|
||||
def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
  """Render a vector value of type `var_dtype` built from the element expressions in `x`.

  `self.float4` acts as a template: the literal text 'float4' inside it is replaced by the
  rendered name of `var_dtype`. The element expressions are joined with commas and appended
  as a brace-enclosed initializer list when `self.device` is "CLANG", and as a parenthesized
  argument list for every other device.

  Raises AssertionError if `len(x)` differs from `var_dtype.count` or if `self.float4` is None.
  """
  assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
  assert self.float4 is not None, "vectorized cast is not supported on this platform"
  elements = ','.join(x)
  # CLANG takes a brace initializer list; the other devices take a parenthesized argument list
  return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}" + (f"{{{elements}}}" if self.device == "CLANG" else f"({elements})")
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
def render_const(self, x:ConstType, dtype:DType) -> str:
|
||||
@@ -180,15 +180,18 @@ class CStyleLanguage(Renderer):
|
||||
assert src[0].dtype is not None
|
||||
from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
|
||||
(f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) else f".{'xyzwabcd'[args]}")
|
||||
(f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) or self.device == 'CLANG' else f".{'xyzwabcd'[args]}")
|
||||
else: raise RuntimeError(f"failed to render {u}")
|
||||
|
||||
# NOTE: this relies on bufs dict preserving order
|
||||
return self.render_kernel(name, kernel, list(bufs.values()), uops)
|
||||
|
||||
def _make_clang_dtype(self, dtype):
|
||||
return f"typedef {self.render_dtype(dtype.scalar())} {self.render_dtype(dtype)} __attribute__((aligned({(sz:=dtype.itemsize)}),vector_size({sz})));"
|
||||
|
||||
class ClangRenderer(CStyleLanguage):
|
||||
device = "CLANG"
|
||||
supports_float4 = False
|
||||
float4 = "(float4)"
|
||||
has_local = False
|
||||
global_max = None
|
||||
|
||||
@@ -197,6 +200,10 @@ class ClangRenderer(CStyleLanguage):
|
||||
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  """Render the kernel, prepending a typedef for every vector dtype that appears in `uops`.

  A dtype is treated as a vector dtype when it is not None and its `count` is greater than 1.
  One typedef is emitted per distinct dtype, in order of first appearance in `uops`, so the
  order of the typedef lines does not depend on set iteration order. Any `prefix` lines
  supplied by the caller are kept and placed ahead of the typedefs; the combined list is
  passed to the parent class's `render_kernel`.
  """
  # dict.fromkeys de-duplicates while preserving insertion order
  vector_dtypes = dict.fromkeys(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count > 1)
  typedefs = [_make_clang_dtype(self, dtype) for dtype in vector_dtypes]
  return super().render_kernel(function_name, kernel, bufs, uops, [*(prefix or []), *typedefs])
|
||||
|
||||
class OpenCLRenderer(CStyleLanguage):
|
||||
device = "GPU"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user