diff --git a/extra/helpers.py b/extra/helpers.py
index 92376e0a6a..d6c544e2e4 100644
--- a/extra/helpers.py
+++ b/extra/helpers.py
@@ -5,7 +5,10 @@ import multiprocessing
 def _early_exec_process(qin, qout):
   while True:
     path, inp = qin.get()
-    qout.put(subprocess.check_output(path, input=inp))
+    try:
+      qout.put(subprocess.check_output(path, input=inp))
+    except subprocess.CalledProcessError as e:
+      qout.put(e)
 
 def enable_early_exec():
   qin: multiprocessing.Queue = multiprocessing.Queue()
@@ -15,7 +18,9 @@ def enable_early_exec():
   p.start()
   def early_exec(x):
     qin.put(x)
-    return qout.get()
+    ret = qout.get()
+    if isinstance(ret, Exception): raise ret
+    else: return ret
   return early_exec
 
 def proc(itermaker, q) -> None:
diff --git a/extra/rocm/rdna3/asm.py b/extra/rocm/rdna3/asm.py
index aa58c87a48..f980e10fa9 100644
--- a/extra/rocm/rdna3/asm.py
+++ b/extra/rocm/rdna3/asm.py
@@ -24,8 +24,10 @@ code = open(pathlib.Path(__file__).parent / "prog.s", "r").read()
 
 gen = []
 FLOPS = 0
-for j in range(4):
-  for i in range(0, 251, 6):
+#MAX_REG = 251
+MAX_REG = 32
+for j in range(1):
+  for i in range(0, MAX_REG, 6):
     #gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
     #FLOPS += 4
     gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
@@ -48,9 +50,10 @@ print(colored("creating CLProgram", "green"))
 prg = CLProgram("code", asm, binary=True)
 
 print(colored("running program", "green"))
-FLOPS *= 100000*1024*1024 # loop * global_size
+G = 256
+FLOPS *= 100000*G*G # loop * global_size
 for i in range(3):
-  tm = prg([1024, 1024], [256, 1], buf, wait=True)
+  tm = prg([G, G], [256, 1], buf, wait=True)
   print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
 
 print(colored("transferring buffer", "green"))
diff --git a/setup.py b/setup.py
index 8bee1ca3fc..c16580ed9c 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ setup(name='tinygrad',
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
       ],
-      install_requires=['numpy', 'requests', 'pillow', 'tqdm', 'networkx', 'pyopencl'],
+      install_requires=['numpy', 'requests', 'pillow', 'tqdm', 'networkx', 'pyopencl', 'PyYAML'],
       python_requires='>=3.8',
       extras_require={
         'llvm': ["llvmlite"],
@@ -41,6 +41,7 @@ setup(name='tinygrad',
           "opencv-python",
           "tabulate",
           "safetensors",
+          "types-PyYAML",
         ],
       },
      include_package_data=True)
diff --git a/test/test_ops.py b/test/test_ops.py
index 48e616dcc7..686105e398 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -309,7 +309,7 @@ class TestOps(unittest.TestCase):
   def test_sum_full(self):
     helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum())
   def test_sum_small_full(self):
-    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
+    helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum)
   def test_sum_relu(self):
     helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu())
   def test_sum(self):
@@ -877,7 +877,7 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
       lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
 
-  @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this")
+  @unittest.skipIf(Device.DEFAULT in ["CUDA", "PTX"], "CUDA fails on this")
   def test_maxpool2d_unit_stride(self):
     helper_test_op([(32,2,110,28)],
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),
diff --git a/tinygrad/codegen/assembly.py b/tinygrad/codegen/assembly.py
index 2080219d99..3003e3bac1 100644
--- a/tinygrad/codegen/assembly.py
+++ b/tinygrad/codegen/assembly.py
@@ -1,5 +1,5 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
-from tinygrad.codegen.linearizer import Linearizer, UOps
+from tinygrad.codegen.linearizer import Linearizer, UOps, Token
 from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
 from tinygrad.helpers import DType, dtypes, DEBUG
 from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
@@ -7,13 +7,19 @@ import functools
 import math
 from collections import defaultdict
 
-type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'I', dtypes.uint64: 'A'}
+_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
+def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
 
 class Register(NamedTuple):
   nm:str
   dtype:DType
-  def __repr__(self): return self.nm
-
+  scalar:bool
+  off:Optional[int] = None
+  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
+  def subregs(self):
+    if self.dtype == dtypes._float4:
+      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
+    return []
 class AssemblyInstruction(NamedTuple):
   op: UOps
   out: Optional[Register]
@@ -23,6 +29,8 @@ class AssemblyInstruction(NamedTuple):
 
 # warp size of 32, s registers are shared across the warp, v are 32-wide vectors
 class AssemblyCodegen(Linearizer):
   supports_load3: bool = False
+  sin_is_sin2pi: bool = False
+  no_div: bool = False
 
   def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]: raise NotImplementedError("must be implemented")
@@ -34,24 +42,28 @@ class AssemblyCodegen(Linearizer):
     self.limit_global_dims(3) # all GPU asms have 3 (for now)
     self.linearize()
 
-    cnts:DefaultDict[DType, int] = defaultdict(int)
+    cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
     tor: Dict[Any, Register] = {}
-    def newreg(tok, dtype=dtypes.float32):
+    def newreg(tok, dtype=dtypes.float32, scalar=False):
       nonlocal cnts, tor
-      tor[tok] = ret = Register(f"%{type_to_letter[dtype]}{cnts[dtype]}", dtype)
-      cnts[dtype] += 1
+      if isinstance(tok, Token): dtype = tok.dtype # this
+      tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
+      if dtype == dtypes._float4:
+        for off in range(4):
+          tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
+      cnts[(dtype, scalar)] += 1
       return ret
 
     def render_numnode(b):
      key = ("num", b)
-      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, dtype=dtypes.int32), [], b))
+      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
      return tor[key]
 
     def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
       key = (op, a, b)
       if key not in tor:
-        #if not isinstance(b, Register): b = render_numnode(b)
-        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype), [a, b], op))
+        if not isinstance(b, Register): b = render_numnode(b)
+        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
       return tor[key]
 
     def render_cast(a:Register, new_dtype:DType) -> Register:
@@ -72,25 +84,29 @@ class AssemblyCodegen(Linearizer):
     def addr_w_offset(args):
       idx = args.idx*self.bufs[args.i].dtype.itemsize
       off = 0  # TODO: should this be None?
-      if isinstance(idx, SumNode) and not self.supports_load3:
+      if isinstance(idx, SumNode):
         nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
-        if len(nums) > 0:
+        if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0:  # TODO: different for each GPU?
          idx -= nums[0]
          off = nums[0]
       reg = idx.render(render_ops)
       if self.supports_load3:
-        return tor[f"buf{args.i}"], reg
+        if reg.scalar:
+          new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
+          ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
+          reg = new_reg
+        return tor[f"buf{args.i}"], reg, off
       else:
         reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
-        return reg, off
+        return reg, None, off
 
     ins = []
-    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64), [], f"buf{i}") for i in range(len(self.bufs))]
+    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
     global_size, local_size = [], []
     skipload_branch = 0
     for uop,newvar,vin,args in self.uops:
       if uop == UOps.CONST and newvar is not None:
-        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar), [], args))
+        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
       elif uop == UOps.DEFINE_LOCAL:
         ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
         ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
@@ -107,38 +123,48 @@ class AssemblyCodegen(Linearizer):
 
       else:
          for var in args[0]:
            if not isinstance(var, NumNode): # TODO: why is this coming through?
-              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32), [], 0))
+              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
       elif uop == UOps.ENDLOOP:
         if args[1] not in ["global", "local"]:
           for var in reversed(args[0]):
             if not isinstance(var, NumNode): # TODO: why is this coming through?
-              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max, dtypes.bool)
               ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
+              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
               ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
+      elif uop == UOps.CAST and newvar is not None:
+        # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
+        out = newreg(newvar)
+        for i,sr in enumerate(out.subregs()):
+          ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
       elif uop == UOps.ALU and newvar is not None:
         if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
+        out = newreg(newvar) if newvar not in tor else tor[newvar] # this is the only thing that can violate SSA
         if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
           pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
           ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
-          ins.append(AssemblyInstruction(UOps.CAST, newreg(newvar), [pred_reg], args))
+          ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
         elif args == BinaryOps.POW:
           # TODO: add UnaryOps.SQRT
           tmp = newreg((newvar, "exp_a"))
           tmp2 = newreg((newvar, "exp_a_times_b"))
           ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
           ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
-          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar), [tmp2], UnaryOps.EXP2))
-        elif args == UnaryOps.SIN and hasattr(self, 'sin_is_sin2pi'):
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp2], UnaryOps.EXP2))
+        elif args == BinaryOps.DIV and self.no_div:
+          tmp = newreg((newvar, "rcp"))
+          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
+        elif args == UnaryOps.SIN and self.sin_is_sin2pi:
           tmp = newreg((newvar, "2pi"))
           ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
-          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tmp], args))
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
         else:
-          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tor[x] for x in vin], args))
+          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
       elif uop == UOps.LOAD and newvar is not None:
-        idx, off = addr_w_offset(args)
-        reg = newreg(newvar)
+        idx, treg, off = addr_w_offset(args)
+        reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
         if args.valid.min == 0:
           ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
           if args.valid.max == 1:
@@ -146,16 +172,16 @@ class AssemblyCodegen(Linearizer):
             ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
         if args.valid.max == 1:
           # NOTE: you can't compute the index in here, because it assumes it's all available later
-          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx], (off, 'global' if args.i != -1 else 'shared')))
+          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
         if args.valid.min == 0 and args.valid.max == 1:
           ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
           skipload_branch += 1
       elif uop == UOps.STORE:
-        idx, off = addr_w_offset(args)
-        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]], (off, 'global' if args.i != -1 else 'shared')))
+        idx, treg, off = addr_w_offset(args)
+        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
 
     # define registers
-    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter[dtype], c)) for dtype,c in cnts.items()] + ins
+    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
 
     if DEBUG >= 4:
       for tins in ins: print(tins)
diff --git a/tinygrad/codegen/assembly_rdna.py b/tinygrad/codegen/assembly_rdna.py
new file mode 100644
index 0000000000..fe78f9e272
--- /dev/null
+++ b/tinygrad/codegen/assembly_rdna.py
@@ -0,0 +1,181 @@
+import yaml
+from typing import Tuple, Set, Dict
+from tinygrad.helpers import dtypes
+from tinygrad.codegen.assembly import AssemblyCodegen, Register
+from tinygrad.codegen.linearizer import UOps
+from tinygrad.ops import BinaryOps, UnaryOps, FusedOps
+from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
+
+# ugh, is this really needed?
+from extra.helpers import enable_early_exec
+early_exec = enable_early_exec()
+
+boilerplate_start = """
+.global _start
+_start:
+.rodata
+.align 0x10
+.global code.kd
+.type code.kd,STT_OBJECT
+.amdhsa_kernel code"""
+
+code_start = """.end_amdhsa_kernel
+.text
+code:
+"""
+
+# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
+# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
+# RDNA3 is actually a SIMD machine!
+class RDNACodegen(AssemblyCodegen):
+  supports_float4: bool = True
+  supports_float4_alu: bool = False
+  supports_load3: bool = True
+  sin_is_sin2pi: bool = True
+  no_div: bool = True
+
+  def specialize(self, asm) -> Tuple[str, str]:
+    args = []
+    for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
+    ins = []
+
+    v_cnt = 3  # v[0:2] is local_xyz
+    s_cnt = 5  # s[0:1] is the address, s[2:4] is global_xyz
+
+    dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
+    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", FusedOps.MULACC: "fma",
+           BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
+           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
+           BinaryOps.CMPEQ: "cmp_eq", BinaryOps.CMPLT: "cmp_lt"}
+
+    pend_regs:Set[Register] = set()
+    rtor:Dict[Register, str] = {}
+    def reg_in(x):
+      nonlocal pend_regs
+      #print("reg_in", x, rtor[x], pend_regs)
+      if x in pend_regs:
+        #print("clear")
+        ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
+        pend_regs.clear()
+      return rtor[x]
+    def reg_out(x):
+      return rtor[x]
+    for uop, out, vin, arg in asm:
+      if uop == UOps.DEFINE_REGISTER:
+        if arg[0][0] == dtypes.uint64 and arg[0][1]:
+          # assuming these are scalar
+          s_cnt += s_cnt%2  # aligned(2)
+          for i in range(arg[2]):
+            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s[{s_cnt}:{s_cnt+1}]"
+            s_cnt += 2
+        elif arg[0][0] == dtypes._float4 and not arg[0][1]:
+          v_cnt += (4-v_cnt%4) if v_cnt%4 != 0 else 0
+          for i in range(arg[2]):
+            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v[{v_cnt}:{v_cnt+3}]"
+            for off in range(4): rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = f"v{v_cnt+off}"
+            v_cnt += 4
+        elif arg[0][0] in [dtypes.int32, dtypes.float32]:
+          for i in range(arg[2]):
+            if arg[0][1]:
+              rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s{s_cnt}"
+              s_cnt += 1
+            else:
+              rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v{v_cnt}"
+              v_cnt += 1
+        elif arg[0][0] == dtypes.bool and arg[0][1]:
+          for i in range(arg[2]):
+            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = "scc" if arg[0][1] else "vcc"
+        else:
+          raise NotImplementedError(arg)
+      elif uop == UOps.SPECIAL:
+        if arg.startswith('buf'):
+          i = int(arg[3:])
+          ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
+          pend_regs.add(out)
+          for r in out.subregs(): pend_regs.add(r)
+        elif arg.startswith('gid'):
+          ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
+          # the docs lied, this is actually y
+          if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
+          if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
+          elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
+          # get local size
+          offset = len(args)*8
+          args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
+          ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
+          ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
+          pend_regs.clear()
+          ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
+          ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
+      elif uop == UOps.CONST:
+        if arg == float('inf'): arg = "0x7f800000"
+        elif arg == float('-inf'): arg = "0xff800000"
+        if out.dtype == dtypes._float4:
+          for off in range(4):
+            ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
+        else:
+          ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
+      elif uop == UOps.ALU:
+        if arg == BinaryOps.CMPLT:
+          if out.scalar:
+            ins.append(f"s_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
+          else:
+            ins.append(f"v_cmp_lt_{dtype_to_rdnatype[out.dtype]} vcc, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
+        else:
+          alu_arg = alu[arg]
+          if arg == FusedOps.MULACC and out == vin[2]:
+            alu_arg = "fmac"
+            vin = vin[0:2]
+          if out.dtype == dtypes._float4:
+            tins = []
+            for rr in zip(*[x.subregs() if x.dtype == dtypes._float4 else [x,x,x,x] for x in [out]+vin]):
+              tins.append(f"{'s_' if rr[0].scalar else 'v_'}dual_{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
+            ins.append(tins[0] + " :: " + tins[1])
+            ins.append(tins[2] + " :: " + tins[3])
+          else:
+            ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
+      elif uop == UOps.LOAD:
+        if out.scalar:
+          # swap arg order
+          ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
+        else:
+          ins.append(f'global_load_{"b128" if out.dtype == dtypes._float4 else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
+        pend_regs.add(out)
+        for r in out.subregs(): pend_regs.add(r)
+      elif uop == UOps.STORE:
+        ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes._float4 else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
+      elif uop == UOps.LABEL:
+        ins.append(f"{arg}:")
+      elif uop == UOps.COND_BRANCH:
+        ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
+      else:
+        raise NotImplementedError(uop)
+
+    ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
+    return 'code', self.assemble(args, ins, v_cnt, s_cnt)
+
+  def assemble(self, args, ins, v_cnt, s_cnt):
+    kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
+                   '.amdhsa_next_free_vgpr': v_cnt,  # this matters!
+                   '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
+                   '.amdhsa_next_free_sgpr': s_cnt,
+                   '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
+                   '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
+                   '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
+                   '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2,  # is amdhsa_system_vgpr_workitem_id real?
+                   '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
+                   '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
+                   '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
+
+    metadata = {'amdhsa.kernels': [{'.args': args,
+      '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
+      '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
+      '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
+      '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
+      '.wavefront_size': 32}],
+      'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
+
+    code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
+    obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
+    asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
+    return asm
diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index f16802f0f8..b514b90fb1 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -172,8 +172,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
     [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
     [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
 
-  if lang.half_prekernel: prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
-  if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
+  if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
+  if lang.double_prekernel and any(x.dtype == dtypes.float64 for x in bufs): prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
   return prg, global_size, local_size
 
 class CStyleCodegen(Linearizer):
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 4505ee0a16..3cde401ebb 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -19,6 +19,7 @@ def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if n
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
+def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 
 @functools.lru_cache(maxsize=None)
 def getenv(key, default=0): return type(default)(os.getenv(key, default))
@@ -74,7 +75,7 @@ class dtypes:
   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
   def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
   @staticmethod
-  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
+  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float4)
   @staticmethod
   def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
   @staticmethod
@@ -87,10 +88,10 @@ class dtypes:
   float64: Final[DType] = DType(5, 8, "double", np.float64)
   int8: Final[DType] = DType(0, 1, "char", np.int8)
   int32: Final[DType] = DType(1, 4, "int", np.int32)
-  int64: Final[DType] = DType(2, 8, "int64", np.int64)
+  int64: Final[DType] = DType(2, 8, "long", np.int64)
   uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
   uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)
+  uint64: Final[DType] = DType(2, 8, "ulong", np.uint64)
 
   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index a98f396b69..ff52183028 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -9,7 +9,8 @@ from tinygrad.runtime.lib import RawBuffer, RawConst
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
-class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
+# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
+class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); RECIP = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index e6c032d1b7..4ef1de5eb7 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -4,11 +4,10 @@ import numpy as np
 import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG, getenv
+from tinygrad.helpers import DEBUG, getenv, fromimport
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
-from tinygrad.codegen.assembly_ptx import PTXCodegen
 
 class RawCUDABuffer(RawBufferCopyInOut):
   def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
@@ -60,4 +59,5 @@ class CUDACodegen(CStyleCodegen):
       typedef long long int64;
     """)
   supports_float4_alu = False
-CUDABuffer = Compiled(RawCUDABuffer, PTXCodegen if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
+
+CUDABuffer = Compiled(RawCUDABuffer, fromimport("tinygrad.codegen.assembly_ptx", "PTXCodegen") if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 1e4110d461..6ce8d26fef 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -3,7 +3,7 @@ import pathlib
 import numpy as np
 import pyopencl as cl # type: ignore
 from typing import Optional, List
-from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes
+from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes, fromimport
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
@@ -61,7 +61,7 @@ class CLProgram:
       if 'Adreno' in CL.cl_ctx.devices[0].name:
         from disassemblers.adreno import disasm
         disasm(self.binary())
-      elif 'gfx1100' in CL.cl_ctx.devices[0].name:
+      elif CL.cl_ctx.devices[0].name.startswith('gfx'):
         asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
         print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
       else:
@@ -87,11 +87,12 @@ class CLProgram:
 
 class CLCodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "#define int64 long\n__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
+    kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
     double_prekernel="#ifdef cl_khr_fp64\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#endif",
     half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
     barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
     gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
   supports_float4_alu = True
   supports_float4 = True
-GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.synchronize)
+
+GPUBuffer = Compiled(CLBuffer, fromimport("tinygrad.codegen.assembly_rdna", "RDNACodegen") if getenv("RDNA") else CLCodegen, CLProgram, CL.synchronize)
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 379281243d..c6703141c4 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -80,7 +80,7 @@ class MetalProgram:
 
 class MetalCodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "#include <metal_stdlib>;\n#define int64 long\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
+    kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
     barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4",
     gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
     extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
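
Not part of the patch: a minimal smoke-test sketch of how the new backend gets exercised, assuming a gfx1100 OpenCL device with GPU as the default tinygrad device. RDNA=1 selects RDNACodegen through the getenv("RDNA") check added in tinygrad/runtime/ops_gpu.py, and DEBUG=4 prints the generated AssemblyInstructions via the DEBUG block in assembly.py. The file name smoke.py is hypothetical.

    # run as: RDNA=1 DEBUG=4 python3 smoke.py
    from tinygrad.tensor import Tensor

    # a small matmul is enough to hit LOAD/STORE, MULACC, and the loop uops
    a, b = Tensor.randn(64, 64), Tensor.randn(64, 64)
    print((a @ b).numpy())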