diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e56c330950..3d44f93917 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -346,7 +346,7 @@ jobs: strategy: fail-fast: false matrix: - backend: [llvm, clang, gpu, cuda, hip] #, triton] #, ptx] + backend: [llvm, clang, gpu, cuda, hip, ptx] #, triton] name: Tests on (${{ matrix.backend }}) runs-on: ubuntu-latest diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 3870c5581b..8231bac0fa 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -784,8 +784,9 @@ class TestLinearizerUOptimize(unittest.TestCase): assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST def test_grouped_store_locals_and_globals(self): - if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared: - self.skipTest("Only Compiled uses linearizer with locals and shared") + if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared or \ + not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4: + self.skipTest("Only Compiled uses linearizer with locals, shared, and float4") x, y = Tensor.rand(128, 128), Tensor.rand(128, 128) out = x@y @@ -808,8 +809,9 @@ class TestLinearizerUOptimize(unittest.TestCase): assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1 def test_grouped_store_local_only(self): - if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared: - self.skipTest("Only Compiled uses linearizer with locals and shared") + if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared or \ + not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4: + self.skipTest("Only Compiled uses linearizer with locals, shared, and float4") x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index da6292a935..6fdd7edf2f 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -1,6 +1,7 @@ # ruff: noqa: E501 import unittest -from tinygrad import dtypes, Device +from tinygrad import dtypes +from tinygrad.helpers import CI from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import Opt, OptOps from tinygrad.features.search import time_linearizer, bufs_from_lin @@ -63,7 +64,8 @@ class TestLinearizerOverflow(unittest.TestCase): opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) -@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals") +#@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals") +@unittest.skipIf(CI, "slow") class TestLinearizerOverflowAlt(unittest.TestCase): def test_overflow_1(self): BS = 2 diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index c3b7855026..96c61f5f2b 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -2,12 +2,10 @@ import unittest from test.helpers import assert_jit_cache_len from tinygrad.features.jit import TinyJit -from tinygrad.helpers import getenv from tinygrad.shape.symbolic import Variable from tinygrad.tensor import Tensor import numpy as np -@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported") class TestSymbolicJit(unittest.TestCase): def test_plus1(self): def f(a): return (a+1).realize() diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 9338ea0df5..beeebfac7f 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor from examples.gpt2 import Attention import numpy as np -@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported") class TestSymbolicOps(unittest.TestCase): def test_plus1(self): def f(a): return (a+1).realize() diff --git a/test/test_uops.py b/test/test_uops.py index 7cfffe85b4..fa7a18b545 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -2,12 +2,11 @@ from typing import Optional, Tuple, Any, List import unittest, math import numpy as np from tinygrad.dtype import dtypes, DType, PtrDType -from tinygrad.helpers import getenv from tinygrad.device import Buffer, Device from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.device import CompiledASTRunner, Compiled from tinygrad.codegen.linearizer import UOps, UOp -from tinygrad.codegen.uops import exec_alu +from tinygrad.codegen.uops import exec_alu, UOpGraph from test.test_dtype import is_dtype_supported def _uops_to_prg(uops): @@ -29,7 +28,7 @@ def _test_single_value(vals, op, dts): uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype) buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)] - prg = _uops_to_prg(uops) + prg = _uops_to_prg(UOpGraph(uops)) prg.exec([buf]+buf2) ret = np.empty(1, output_dtype.np) buf.copyout(ret.data) @@ -43,7 +42,7 @@ def _test_single_value_const(vals, op, dts): alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype) - prg = _uops_to_prg(uops) + prg = _uops_to_prg(UOpGraph(uops)) prg.exec([buf]) ret = np.empty(1, output_dtype.np) buf.copyout(ret.data) @@ -88,26 +87,51 @@ class TestFloatUOps(TestUOps): # MOD isn't tested on floats def test_where(self): - self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float), PtrDType(dtypes.float))) + self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float)) -# TODO: fix this on all the backends -@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some") class TestNonFloatUOps(TestUOps): - def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (PtrDType(dtypes.int32), )) - def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32))) - def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32))) - def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32))) + def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (dtypes.int32, )) + def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32)) + def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (dtypes.int32, dtypes.int32)) + def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32)) def test_div_int32(self): - self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)), no_b_zero=True) + self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True) def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, - lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (PtrDType(dtypes.int32), PtrDType(dtypes.int32)), no_b_zero=True) - def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a sint: def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0]) class UOpGraph: - def __init__(self): + def __init__(self, start_uops:Optional[List[UOp]]=None): # list of uops - self.uops: List[UOp] = [] + self.uops: List[UOp] = [] if start_uops is None else start_uops # global uop cache self.saved_exprs: Dict[Tuple, UOp] = dict() @@ -88,7 +88,8 @@ class UOpGraph: if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG: return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before) # constant folding - if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before) + if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: + return self.add(UOps.CONST, dtype, arg=-vin[0].arg if dtype != dtypes.bool else not vin[0].arg, insert_before=insert_before) if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2] if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype): diff --git a/tinygrad/device.py b/tinygrad/device.py index 31baa051a1..b7d7ae002c 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -238,7 +238,7 @@ class Compiled: ops, mem = k.uops.flops_mem() run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else [])) # NOTE: we use min here to ignore the indexing FLOPS - ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops.uops), self, k.global_size, k.local_size, + ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size, k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count)) return ret diff --git a/extra/backends/ptx.py b/tinygrad/renderer/assembly.py similarity index 78% rename from extra/backends/ptx.py rename to tinygrad/renderer/assembly.py index 3ae62ffa3a..fdd008e433 100644 --- a/extra/backends/ptx.py +++ b/tinygrad/renderer/assembly.py @@ -4,6 +4,7 @@ from collections import defaultdict from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT +from tinygrad.codegen.uops import UOpGraph def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1]) @@ -35,11 +36,38 @@ class AssemblyLanguage(NamedTuple): def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError() -def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str: +def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: local_size: List[int] = [] kernel:List[str] = [] bufs = [] + # here we do a pretransform on UOps to fix some shortcomings of PTX + # all uops must be a register + # TODO: uops class should make these rewrites easier + replace: Dict[UOp, UOp] = {} + for u in uops: + for o,n in replace.items(): + if o in u.vin and u is not n: + u.vin = tuple(n if x == o else x for x in u.vin) + if u.uop is UOps.LOAD and u.dtype is dtypes.bool: + # rewrite load bool + if len(u.vin) == 4: + new = uops.add(UOps.CAST, dtypes.uint8, (u.vin[3],), insert_before=uops.uops.index(u)) + u.vin = u.vin[0:3] + (new,) + u.dtype = dtypes.uint8 + new = uops.add(UOps.CAST, dtypes.bool, (u,), insert_before=uops.uops.index(u)+1) + replace[u] = new + if u.uop is UOps.ALU and u.arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT} and u.vin[0].dtype is dtypes.bool: + if u.arg == BinaryOps.CMPEQ: + u.arg = BinaryOps.XOR + new = uops.add(UOps.ALU, dtypes.bool, (u,), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)+1) + replace[u] = new + if u.arg == BinaryOps.CMPLT: + new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)) + u.vin = (new, u.vin[1]) + u.arg = BinaryOps.MUL + #uops.print() + def kk(*s: str): kernel.append("\n".join(s)) c: DefaultDict[str, int] = defaultdict(int) @@ -78,12 +106,12 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str assert vin[0].dtype is not None kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:") elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier) - elif uop == UOps.END: - if vin[0].uop == UOps.LOOP: - kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]), - lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int])) - kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:") - else: kk(f"{r_label[vin[0]]}:") + elif uop == UOps.ENDLOOP: + kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]), + lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int])) + kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:") + elif uop == UOps.ENDIF: + kk(f"{r_label[vin[0]]}:") elif uop == UOps.STORE: assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype)) @@ -97,13 +125,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str elif uop == UOps.ALU: assert vin[0].dtype is not None if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ: - regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin] - dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype - kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt])) - elif args == TernaryOps.MULACC: - assert vin[1].dtype is not None - kk(lang.asm_for_op[args](ssa(u, 'alu'), *[r[x] for x in vin], dtype, lang.types[vin[1].dtype])) - else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype])) + # pass in the other dtype here + kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype])) + else: + kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype])) elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};") elif uop == UOps.SPECIAL: if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};", @@ -133,12 +158,17 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str assert vin[0].dtype is not None cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u) elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype)) - elif uop == UOps.DEFINE_GLOBAL: - bufs.append((args, dtype)) - r[u] = f"%{args}" + elif uop is UOps.DEFINE_VAR: + bufs.append((args.expr, dtype)) + r[u] = f"%{args.expr}" + if lang.load_global: + kk(*lang.render_load(args.expr, ssa(u, 'dat', dtype=lang.types[dtype]), dtype, ss=".param")) + elif uop is UOps.DEFINE_GLOBAL: + bufs.append((args[1], dtype)) + r[u] = f"%{args[1]}" if lang.load_global: dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype - kk(*lang.render_load(args, ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param")) + kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param")) else: raise NotImplementedError(f"no code for {uop}") return lang.render_kernel(kernel, function_name, bufs, c.items()) @@ -156,24 +186,22 @@ class PTXLanguage(AssemblyLanguage): gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)] lid = [f'%tid.{chr(120+i)}' for i in range(3)] asm_for_op = { - UnaryOps.NEG: lambda d,a,dt,name: f"neg.{name} {d}, {a};", + UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};", UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};", UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};", BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};", BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};", BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};", - BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};", + BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};", BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};", BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};", BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};", - TernaryOps.MULACC: lambda d,a,b,c,dt,name: (('fma.rn' if dtypes.is_float(dt) else 'mad.lo' if a.split('_')[1]==c.split('_')[1] else 'mad.wide') + - f".{name} {d}, {a}, {b}, {c};"), - TernaryOps.WHERE: lambda d,a,b,c,dt,name: f"selp.{name} {d}, {b}, {c}, {a};" + TernaryOps.WHERE: lambda d,a,b,c,dt,name: + f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};" } - supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, - TernaryOps.MULACC, TernaryOps.WHERE] + supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] types = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64", dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64", @@ -206,14 +234,11 @@ class PTXLanguage(AssemblyLanguage): def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype] def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: + assert dtype is not dtypes.bool ret = [] - if (byte:=dtype.itemsize == 1): ret.append(f".reg .s8 {dest}_tmp;") - if (isbool:= dtype == dtypes.bool): ret.append(f".reg .s16 {dest}_bool;") if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];", - f"@!{gate} mov.b{'8' if byte else self.types[dtype][1:]} {dest + ('_tmp' if byte else '')}, {alt};"]) - else: ret.append(f"ld{ss}.{'s8' if byte else 'b16' if dtype==dtypes.float16 else self.types[dtype]} {dest + ('_tmp' if byte else '')}, [{loc}];") - if byte: ret.append(f"cvt.{'s16' if isbool else self.types[dtype]}.s8 {dest + ('_bool' if isbool else '')}, {dest}_tmp;") - if isbool: ret.append(f"setp.ne.s16 {dest}, {dest}_bool, {self.render_const(0, dtypes.int16)};") + f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]) + else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];") return ret def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 6b58121e12..27276f4109 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -5,6 +5,7 @@ from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.helpers import strip_parens, getenv from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType +from tinygrad.codegen.uops import UOpGraph class CStyleLanguage(NamedTuple): kernel_prefix: str = "" @@ -24,7 +25,7 @@ class CStyleLanguage(NamedTuple): uses_ptr_arithmetic: bool = False type_map: Dict[DType, str] = {} code_for_op: Dict = { - UnaryOps.NEG: lambda x,dtype: f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", + UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype is dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})", BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", @@ -61,7 +62,7 @@ class CStyleLanguage(NamedTuple): out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val - def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: + def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else @@ -86,7 +87,7 @@ class CStyleLanguage(NamedTuple): def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];" def render_dtype(self, var_dtype:DType) -> str: return self.type_map[var_dtype] if var_dtype in self.type_map else var_dtype.name -def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str: +def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str: kernel = [] bufs: List[Tuple[str, Tuple[DType, bool]]] = [] #pend_close = None diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 174e95ca92..746e8c7ff2 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -3,13 +3,15 @@ from llvmlite import ir from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.dtype import DType, PtrDType, dtypes from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps +from tinygrad.codegen.uops import UOpGraph MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype) code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.NEG: lambda builder, x, var_dtype: builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS), + UnaryOps.NEG: lambda builder, x, var_dtype: builder.neg(x) if dtypes.is_int(var_dtype) else \ + (builder.not_(x) if var_dtype is dtypes.bool else builder.fneg(x, flags=MFLAGS)), UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS), UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS), UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS), @@ -65,7 +67,7 @@ def const(args, dtype): # TODO: remove int from int(args) once const args conform with dtype return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args) -def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str: +def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str: # all llvm stuff goes into a module module = ir.Module(name=__file__) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 7f22d8864b..e0164e4d7c 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -7,6 +7,7 @@ from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, co from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, Compiler from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import CUDARenderer +from tinygrad.renderer.assembly import PTXRenderer def pretty_ptx(s): # all expressions match `` and replace it with `color()` @@ -33,6 +34,15 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes: sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))) return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value) +class PTXCompiler(Compiler): + linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False) + def __init__(self, arch:str): + self.arch = arch + PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80) + super().__init__(f"compile_ptx_{self.arch}") + def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch) + def compile(self, src:str) -> bytes: return src.encode() + class CUDACompiler(Compiler): linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]) def __init__(self, arch:str): @@ -100,7 +110,8 @@ class CUDADevice(Compiled): self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35" from tinygrad.runtime.graph.cuda import CUDAGraph - super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator, CUDACompiler(self.arch), + super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator, + PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None) def synchronize(self): if not CUDACPU: diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index bca0c58b4f..09486ebe34 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -6,7 +6,7 @@ import pickle, base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType from tinygrad.helpers import all_same, getenv, flatten from tinygrad.device import Compiled, Allocator, Compiler -from tinygrad.codegen.uops import UOp, UOps, exec_alu +from tinygrad.codegen.uops import UOpGraph, UOps, exec_alu from tinygrad.ops import BinaryOps, TernaryOps from tinygrad.codegen.kernel import LinearizerOptions @@ -188,8 +188,8 @@ class PythonCompiler(Compiler): linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \ (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \ (LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON"))) - def render(self, name:str, uops:List[UOp]) -> str: - lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops] + def render(self, name:str, uops:UOpGraph) -> str: + lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops] return base64.b64encode(pickle.dumps(lops)).decode() def compile(self, src:str) -> bytes: return base64.b64decode(src)