mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 15:38:29 -05:00)
This reverts commit e9cef13f0b.
.github/workflows/test.yml
@@ -45,7 +45,6 @@ jobs:
|
||||
run: |
|
||||
DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
|
||||
DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
|
||||
DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
|
||||
- name: Test dtype with Python emulator
|
||||
run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py
|
||||
- name: Test ops with Python emulator
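For context on the matrix above: the Python emulator runs the real linearizer tests under EMULATE_* flags, and this revert drops the CUDA row. A minimal driver equivalent to the remaining entries (a sketch; the workflow simply lists the commands inline):

import os, subprocess

# run the tensor core linearizer test under each remaining emulated backend
for backend in ["METAL", "HIP"]:
    env = dict(os.environ, DEBUG="2", FORWARD_ONLY="1", PYTHON="1")
    env[f"EMULATE_{backend}"] = "1"
    subprocess.run(["python3", "./test/test_linearizer.py", "TestLinearizer.test_tensor_cores"], env=env, check=True)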
@@ -30,7 +30,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
   dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
   thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+  thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
   wmma_func: str # name of wmma function to call
   def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
   def num_threads(self): return len(self.threads)
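The one-line type change above is the heart of the revert. The CUDA tensor core needed a different register shape per fragment (A, B, and C/D), so thread_local_sizes had been widened to List[List[int]] and consumers reduced each entry with prod(); going back to List[int] leaves one element count per TC dim. A minimal sketch of the difference (values taken from the entries below):

from math import prod

nested = [[2,2,2],[2,2],[2,2]]   # CUDA-era List[List[int]]: per-fragment register shapes
flat = [2,2,2]                   # restored List[int]: one count per TC dim

print([prod(l) for l in nested]) # [8, 4, 4] -- what the removed wmma_sz line computed
print(flat)                      # [2, 2, 2] -- used directly after the revert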
@@ -38,17 +38,14 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
 
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
   ],
   "HIP": [
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[[16],[16],[8]], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-  ],
-  "CUDA": [
-    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,4),(1,4),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[0],[-1,1,-3],[2,3,-2]], [[0],[0],[2,3],[-1,1,-2],[0]], [[0],[0],[-1,1],[0],[2,3,-2]] ]), # noqa: E501
-  ],
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+  ]
 }
 
 class LocalBuffer(NamedTuple):
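A quick sanity check on the threads fields above: each (TC dim, amt) pair splits one output dimension across lanes, and the amounts multiply out to the warp/simdgroup size, 32 on both remaining backends:

from math import prod

metal_threads = [(0,2),(1,4),(0,2),(1,2)]  # from the METAL entries
hip_threads = [(0,16),(1,2)]               # from the HIP entries
assert prod(amt for _, amt in metal_threads) == 32
assert prod(amt for _, amt in hip_threads) == 32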
@@ -252,27 +252,30 @@ class Linearizer(Kernel):
     # define accumulator
     acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
 
-    # reduce loop
-    loop_ctx = render_loop(reduce_idxs)
-
-    # barrier for fast GEMM
     if (tc:=self.tensor_core):
-      self.uop(UOps.BARRIER, None, (), cachable=False)
-
-      def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
-        replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
-        for s in local_sizes:
-          thread_idxs.append(thread_idx % s)
-          thread_idx //= s
+      def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
+        replace_idxs = []
         for alias in aliases:
           full_var, full_var_sz = NumNode(0), 1
           if alias[0] != 0:
             for i in alias:
-              next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
+              next_var = local_idxs[-i] if i > 0 else Variable("_uidx_tc", 0, local_size-1)
               full_var += next_var * full_var_sz
               full_var_sz *= next_var.max+1
           replace_idxs.append(full_var)
         return replace_idxs
+      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
+      for n in range(len(tc.threads)):
+        local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
+      for n in range(len(replace_acc_idxs)-len(tc.threads)):
+        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
+      if DEBUG >= 3: print("store alias: idxs=", global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
 
+    # reduce loop
+    loop_ctx = render_loop(reduce_idxs)
+
+    # barrier for fast GEMM
+    if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False)
 
     # compute local aliases
     locals_to_store = []
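The restored calc_tc_idxs packs a mixed-radix index for each alias list: positive entries pull a warp-local axis (counted 1-based from the end of local_idxs), -1 pulls the per-thread _uidx_tc upcast variable, and a leading 0 means the TC dim is not aliased. A concrete, non-symbolic sketch of the same arithmetic (plain ints standing in for tinygrad's symbolic Variables; values made up):

def calc_tc_idx(alias, local_vals, local_maxes, uidx, uidx_max):
    if alias[0] == 0: return 0                # dim not aliased
    full_var, full_var_sz = 0, 1
    for i in alias:
        val, mx = (local_vals[-i], local_maxes[-i]) if i > 0 else (uidx, uidx_max)
        full_var += val * full_var_sz         # pack this source into the next radix position
        full_var_sz *= mx + 1
    return full_var

# alias [2, -1]: local axis -2 runs fastest, then the thread-local upcast
print(calc_tc_idx([2, -1], local_vals=[3, 1], local_maxes=[15, 1], uidx=5, uidx_max=7))  # 3 + 5*16 = 83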
@@ -286,20 +289,13 @@ class Linearizer(Kernel):
           buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
         for n in range(tc.num_upcasts()):
           buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
-        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs=", buf_idxs)
+        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
         ll = self.global_load(i, buf_idxs)
         locals_to_store.append((localbuf_idx, buf_idxs, ll))
 
     # copy in any global buffers
     if (tc:=self.tensor_core):
-      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
-      for n in range(len(tc.threads)):
-        local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
-      for n in range(len(replace_acc_idxs)-len(tc.threads)):
-        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
-      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
-
-      wmma_sz = [prod(l) for l in tc.thread_local_sizes]
+      wmma_sz = tc.thread_local_sizes
       def upcast_strides(buf:int):
         strides, next = [], 1
         for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
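The hunk ends inside upcast_strides, which packs cumulative strides over the upcast axes past the tensor core's own: broadcast axes (stride 0) keep stride 0 and do not advance the running size. A standalone sketch of that packing (not the real method, which reads self.upcasted_axis):

def upcast_strides(upcasted_axes):  # [(size, stride, is_reduce), ...]
    strides, nxt = [], 1
    for (sz, stride, _reduce) in upcasted_axes:
        strides.append((0 if stride == 0 else nxt, sz))
        nxt *= 1 if stride == 0 else sz
    return strides

print(upcast_strides([(2, 1, False), (2, 0, False), (2, 4, True)]))  # [(1, 2), (0, 2), (2, 2)]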
@@ -115,6 +115,5 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint]:
     if u.uop is UOps.WMMA:
       if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
       elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-      elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
       else: raise Exception("not implemented")
   return flops, mem
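The counting rule above: one WMMA performs 2*M*N*K useful flops (a multiply and an add per output element per K step) spread across a 32-lane warp, so each emulated lane-level uop accounts for 2*M*N*K/32 flops times the loop multiplier:

def wmma_flops(M, N, K, mults, warp=32):
    return 2*(M*N*K)//warp * mults

assert wmma_flops(8, 8, 8, 1) == 32      # __metal_wmma (8x8x8)
assert wmma_flops(16, 16, 16, 1) == 256  # HIP wmma variants (16x16x16)
assert wmma_flops(8, 16, 16, 1) == 128   # the removed __cuda_mma_m16n8k16_f16_f32 (8x16x16)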
@@ -236,14 +236,8 @@ class CUDALanguage(CStyleLanguage):
   def render_kernel(self, function_name, kernel, bufs, local_size, uops, prefix=None):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
     if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
-        "__device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }",
-        "__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }",
-        """__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
-  asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
-    : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
-  return c;}""",
-      ]
+      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };",
+        "__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }"]
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
     return super().render_kernel(function_name, kernel, bufs, local_size, uops, prefix=prefix)
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
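The removed inline PTX hands the half8/half4 fragments to mma.sync as packed 32-bit registers (the "r"(a_pk[...]) operands): two fp16 values per register, four registers for A and two for B. An illustration of that packing (Python standing in for the C pointer cast; pack_half2 is a made-up helper name):

import struct

def pack_half2(lo: float, hi: float) -> int:
    # two fp16 values viewed as one 32-bit register image, little-endian
    return int.from_bytes(struct.pack("<2e", lo, hi), "little")

a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]               # a half8 fragment
a_pk = [pack_half2(a[i], a[i+1]) for i in range(0, 8, 2)]  # the four "r" operands for A
print([hex(x) for x in a_pk])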
@@ -33,7 +33,6 @@ class CUDACompiler(Compiler):
   linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024])
   def __init__(self, arch:str):
     self.arch = arch
-    CUDACompiler.linearizer_opts = CUDACompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
     super().__init__(f"compile_cuda_{self.arch}")
   def render(self, name:str, uops) -> str: return CUDARenderer(name, uops)
   def compile(self, src:str) -> bytes:
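The removed line above is what had switched tensor cores on for Ampere and newer: CUDA arch strings look like "sm_80", so arch[3:] is the compute capability digits. For example:

for arch in ["sm_70", "sm_75", "sm_80", "sm_86"]:
    print(arch, int(arch[3:]) >= 80)  # False, False, True, True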
@@ -189,11 +189,6 @@ class PythonProgram:
             return a_elem(x, j, i, goff)
           def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
           ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-        elif arg == '__cuda_mma_m16n8k16_f16_f32':
-          def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
-          def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
-          def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
-          ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
         else:
           raise Exception(f"unimplemented tensor core {arg}")
       elif uop is UOps.ALU:
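Each emulated tensor core above is just three lane/element mappings handed to wmma_helper: a_elem/b_elem say which lane and register slot holds a given A or B element, and c_map says which (i, j) output a lane's element is; the helper then accumulates the K-length dot product per output element. A self-contained sketch of that pattern with trivial identity mappings (signature assumed, not the real wmma_helper):

def emulate_wmma(K, lanes, num_c, a, b, c, a_elem, b_elem, c_map):
    out = [row[:] for row in c]  # D starts as a copy of C
    for lane in range(lanes):
        for elem in range(num_c):
            i, j = c_map(lane, elem)
            out[i][j] += sum(a_elem(a, i, k) * b_elem(b, k, j) for k in range(K))
    return out

A = [[1, 2], [3, 4]]; B = [[5, 6], [7, 8]]; C = [[0, 0], [0, 0]]
print(emulate_wmma(2, 2, 2, A, B, C,
                   a_elem=lambda a, i, k: a[i][k], b_elem=lambda b, k, j: b[k][j],
                   c_map=lambda lane, elem: (lane, elem)))  # [[19, 22], [43, 50]]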
@@ -206,8 +201,7 @@ class PythonProgram:
 
 class PythonCompiler(Compiler):
   linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
-    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
-    (LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions()))
+    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions())
   def render(self, name:str, uops:List[UOp]) -> str:
     lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()
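render above serializes the linearized program as pickled (uop, dtype, source indices, arg) tuples over base64, so the "binary" the Python device executes is just text; decoding is the exact inverse. A round-trip sketch with plain tuples standing in for real UOps:

import base64, pickle

lops = [("DEFINE_GLOBAL", "float", [], 0), ("LOAD", "float", [0], None)]
src = base64.b64encode(pickle.dumps(lops)).decode()
assert pickle.loads(base64.b64decode(src)) == lops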