diff --git a/tinygrad/codegen/assembly.py b/tinygrad/codegen/assembly.py
new file mode 100644
index 0000000000..2080219d99
--- /dev/null
+++ b/tinygrad/codegen/assembly.py
@@ -0,0 +1,166 @@
+from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
+from tinygrad.codegen.linearizer import Linearizer, UOps
+from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
+from tinygrad.helpers import DType, dtypes, DEBUG
+from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
+import functools
+import math
+from collections import defaultdict
+
+type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'I', dtypes.uint64: 'A'}
+
+class Register(NamedTuple):
+  nm:str
+  dtype:DType
+  def __repr__(self): return self.nm
+
+class AssemblyInstruction(NamedTuple):
+  op: UOps
+  out: Optional[Register]
+  vin: List[Union[Register, int, float]]
+  arg: Any = None
+
+# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
+class AssemblyCodegen(Linearizer):
+  supports_load3: bool = False
+
+  def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
+    raise NotImplementedError("must be implemented")
+
+  # s registers are the addresses and non local indexes
+  def codegen(self):
+    self.process()
+    self.hand_coded_optimizations()
+    self.limit_global_dims(3)  # all GPU asms have 3 (for now)
+    self.linearize()
+
+    cnts:DefaultDict[DType, int] = defaultdict(int)
+    tor: Dict[Any, Register] = {}
+    def newreg(tok, dtype=dtypes.float32):
+      nonlocal cnts, tor
+      tor[tok] = ret = Register(f"%{type_to_letter[dtype]}{cnts[dtype]}", dtype)
+      cnts[dtype] += 1
+      return ret
+
+    def render_numnode(b):
+      key = ("num", b)
+      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, dtype=dtypes.int32), [], b))
+      return tor[key]
+
+    def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
+      key = (op, a, b)
+      if key not in tor:
+        #if not isinstance(b, Register): b = render_numnode(b)
+        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype), [a, b], op))
+      return tor[key]
+
+    def render_cast(a:Register, new_dtype:DType) -> Register:
+      if a.dtype == new_dtype: return a
+      key = (a, new_dtype)
+      if key not in tor:
+        ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
+      return tor[key]
+
+    render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
+                   MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
+                   DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
+                   ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
+                   LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
+                   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
+                   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
+
+    def addr_w_offset(args):
+      idx = args.idx*self.bufs[args.i].dtype.itemsize
+      off = 0  # TODO: should this be None?
+      if isinstance(idx, SumNode) and not self.supports_load3:
+        nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
+        if len(nums) > 0:
+          idx -= nums[0]
+          off = nums[0]
+      reg = idx.render(render_ops)
+      if self.supports_load3:
+        return tor[f"buf{args.i}"], reg
+      else:
+        reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
+        return reg, off
+
+    ins = []
+    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64), [], f"buf{i}") for i in range(len(self.bufs))]
+    global_size, local_size = [], []
+    skipload_branch = 0
+    for uop,newvar,vin,args in self.uops:
+      if uop == UOps.CONST and newvar is not None:
+        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar), [], args))
+      elif uop == UOps.DEFINE_LOCAL:
+        ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
+        ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
+      elif uop == UOps.LOOP:
+        if args[1] == "global":
+          for i,var in enumerate(args[0]):
+            global_size.append(var.max+1)
+            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
+        elif args[1] == "local":
+          for i,var in enumerate(args[0]):
+            local_size.append(var.max+1)
+            global_size[i] *= local_size[i]
+            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
+        else:
+          for var in args[0]:
+            if not isinstance(var, NumNode):  # TODO: why is this coming through?
+              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32), [], 0))
+              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
+      elif uop == UOps.ENDLOOP:
+        if args[1] not in ["global", "local"]:
+          for var in reversed(args[0]):
+            if not isinstance(var, NumNode):  # TODO: why is this coming through?
+              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max, dtypes.bool)
+              ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
+              ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
+      elif uop == UOps.ALU and newvar is not None:
+        if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]]  # TODO: reorder MULACC everywhere
+        # this is the only thing that can violate SSA
+        if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
+          pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
+          ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
+          ins.append(AssemblyInstruction(UOps.CAST, newreg(newvar), [pred_reg], args))
+        elif args == BinaryOps.POW:
+          # TODO: add UnaryOps.SQRT
+          tmp = newreg((newvar, "exp_a"))
+          tmp2 = newreg((newvar, "exp_a_times_b"))
+          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
+          ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
+          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar), [tmp2], UnaryOps.EXP2))
+        elif args == UnaryOps.SIN and hasattr(self, 'sin_is_sin2pi'):
+          tmp = newreg((newvar, "2pi"))
+          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
+          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tmp], args))
+        else:
+          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tor[x] for x in vin], args))
+      elif uop == UOps.LOAD and newvar is not None:
+        idx, off = addr_w_offset(args)
+        reg = newreg(newvar)
+        if args.valid.min == 0:
+          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
+          if args.valid.max == 1:
+            pred = args.valid.render(render_ops)
+            ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
+        if args.valid.max == 1:
+          # NOTE: you can't compute the index in here, because it assumes it's all available later
+          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx], (off, 'global' if args.i != -1 else 'shared')))
+        if args.valid.min == 0 and args.valid.max == 1:
+          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
+          skipload_branch += 1
+      elif uop == UOps.STORE:
+        idx, off = addr_w_offset(args)
+        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]], (off, 'global' if args.i != -1 else 'shared')))
+
+    # define registers
+    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter[dtype], c)) for dtype,c in cnts.items()] + ins
+
+    if DEBUG >= 4:
+      for tins in ins: print(tins)
+    name, asm = self.specialize(ins)
+
+    return ASTRunner(name, asm,
+      global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
+      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
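The backend's contract in one picture: `codegen` lowers the linearized uops to a flat, mostly-SSA list of `AssemblyInstruction`s, and each subclass's `specialize` renders that list to a target ISA. A hand-built sketch of what such a list might look like for a one-element-per-thread copy kernel (illustrative only; registers are named by hand and offsets are unscaled, real lists come out of `AssemblyCodegen.codegen`):

```python
from tinygrad.codegen.assembly import AssemblyInstruction, Register
from tinygrad.codegen.linearizer import UOps
from tinygrad.ops import BinaryOps
from tinygrad.helpers import dtypes

# hand-built illustration of the IR shape, not actual codegen output
buf0, buf1 = Register("%A0", dtypes.uint64), Register("%A1", dtypes.uint64)
gid0 = Register("%i0", dtypes.int32)   # global thread index
idx  = Register("%A2", dtypes.uint64)  # index widened to 64-bit
src, dst = Register("%A3", dtypes.uint64), Register("%A4", dtypes.uint64)
val  = Register("%f0", dtypes.float32)
prog = [
  AssemblyInstruction(UOps.SPECIAL, buf0, [], "buf0"),             # output pointer
  AssemblyInstruction(UOps.SPECIAL, buf1, [], "buf1"),             # input pointer
  AssemblyInstruction(UOps.SPECIAL, gid0, [], "gid0"),             # thread id
  AssemblyInstruction(UOps.CAST, idx, [gid0]),                     # widen to u64
  AssemblyInstruction(UOps.ALU, src, [idx, buf1], BinaryOps.ADD),  # src = buf1 + idx
  AssemblyInstruction(UOps.ALU, dst, [idx, buf0], BinaryOps.ADD),  # dst = buf0 + idx
  AssemblyInstruction(UOps.LOAD, val, [src], (0, 'global')),
  AssemblyInstruction(UOps.STORE, None, [dst, val], (0, 'global')),
]
for tins in prog: print(tins)  # the same dump DEBUG=4 produces
```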
"%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) + +# https://docs.nvidia.com/cuda/parallel-thread-execution/# +class PTXCodegen(AssemblyCodegen): + #supports_constant_folding: bool = True + + def specialize(self, asm): + ins = [".version 7.8", ".target sm_86", ".address_size 64", + f".visible .entry test({', '.join(f'.param .u64 buf{i}' for i in range(len(self.bufs)))}) {{"] + + alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", + BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", BinaryOps.CMPEQ: "setp.eq", + UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz", + FusedOps.MULACC: "fma.rn"} + + for uop, out, vin, arg in asm: + if uop == UOps.DEFINE_REGISTER: + ins.append(f".reg .{dtype_to_nvtype[arg[0]]} %{arg[1]}<{arg[2]}>;",) + elif uop == UOps.DEFINE_LOCAL: + ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];") + elif uop == UOps.SPECIAL: + if arg.startswith('buf'): + ins.append(f"ld.param.u64 {out}, [{arg}];") + # TODO: is this needed? + #ins.append(f"cvta.to.global.u64 {out}, {out};") + elif arg.startswith('gid'): + #ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};") + ins.append("{ .reg .b32 %tmp<3>;") + l = 'xyz'[int(arg[3:])] + ins.append(f"mov.u32 %tmp0, %ctaid.{l};") + ins.append(f"mov.u32 %tmp1, %ntid.{l};") + ins.append(f"mov.u32 %tmp2, %tid.{l};") + ins.append(f"mad.lo.s32 {out}, %tmp0, %tmp1, %tmp2; }}") + elif arg.startswith('lid'): + ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};") + elif uop == UOps.ALU: + if arg == BinaryOps.MUL and out.dtype == dtypes.bool: + ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};") + else: + otype = vin[0].dtype if arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT] else out.dtype + ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};") + elif uop == UOps.LOAD: + ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[out.dtype]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") + elif uop == UOps.STORE: + ins.append(f"st.{arg[1]}.{dtype_to_nvtype[vin[1].dtype]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};") + elif uop == UOps.CAST: + if vin[0].dtype == dtypes.bool: + ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, 0f3F800000, 0f00000000, {vin[0]};") + else: + ins.append(f"cvt.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[vin[0].dtype]} {out}, {vin[0]};") + elif uop == UOps.CONST: + ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else arg};") + elif uop == UOps.LABEL: + ins.append(f"{arg}:") + elif uop == UOps.COND_BRANCH: + ins.append(f"@{'!' 
diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index 36aca36125..f73ab839e9 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
 import math, collections
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
 from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored, prod
+from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
 from tinygrad.lazy import LazyBuffer
@@ -188,15 +188,7 @@ class CStyleCodegen(Linearizer):
   def codegen(self):
     self.process()
     self.hand_coded_optimizations()
-
-    # sometimes, there's more dimensions than len(self.lang.gid).
-    # compact all the dimensions into the first
-    # NOTE: this might make multiview shapetrackers
-    if len(self.lang.gid) and self.first_reduce > len(self.lang.gid):
-      num_to_merge = (self.first_reduce - len(self.lang.gid))+1
-      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
-      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
-
+    self.limit_global_dims(len(self.lang.gid))
     self.linearize()
 
     prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index f532340489..7c47b39b3b 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -10,7 +10,9 @@ from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
 from tinygrad.shape.symbolic import Variable
 
-class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto() # noqa: E702
+# bottom ones are asm only
+class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \
+  SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
 
 class LocalBuffer(NamedTuple):
   dtype: DType = dtypes.float32
@@ -453,6 +455,15 @@ class Linearizer:
           self.shift_to(unit_stride_axes_mul_4[0], 4)
           self.upcast()
 
+  def limit_global_dims(self, limit):
+    # sometimes, there's more dimensions than len(self.lang.gid).
+    # compact all the dimensions into the first
+    # NOTE: this might make multiview shapetrackers
+    if limit and self.first_reduce > limit:
+      num_to_merge = (self.first_reduce - limit)+1
+      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
+      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
+
   def hand_coded_optimizations(self):
     # if there's images in the earlybufs, we have to make an axis the 4 loading one
     self.required_optimizations(early_only=True)
@@ -523,3 +534,10 @@ class Linearizer:
         self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
         self.upcast()
         break
+
+    # if nothing at all is upcasted and it's easy to, do an upcast
+    # TODO: this is breaking the tests
+    #for splits in [4]:
+    #  if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
+    #    self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
+    #    self.upcast()
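`limit_global_dims` is the old cstyle-only reshape hoisted into the `Linearizer`, so the assembly backend can cap itself at 3 global dims too. The shape arithmetic, sketched standalone with made-up numbers:

```python
from tinygrad.helpers import prod

# what limit_global_dims(3) does to a shape with 5 dims before the first reduce
shape, first_reduce, limit = (2, 3, 4, 5, 6), 5, 3
num_to_merge = (first_reduce - limit) + 1
merged = (prod(shape[0:num_to_merge]),) + shape[num_to_merge:]
assert merged == (24, 5, 6)  # 2*3*4 folded into the first dim; 3 global dims remain
```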
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 6201b82d4d..1585b5268f 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass, asdict
-import os, math, functools, time
+import os, math, functools, time, re
 import numpy as np
 from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
 ShapeType = Tuple[int, ...]
@@ -12,6 +12,7 @@ def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
 def colored(st, color, background=False, bright=False): return f"\u001b[{10*background+60*bright+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
+def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
@@ -34,6 +35,7 @@ class ContextVar:
   def __bool__(self): return self.value != 0
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
+  def __lt__(self, x): return self.value < x
   @property
   def value(self): return ContextVar.ctx_stack[-1][self.key] if self.key in ContextVar.ctx_stack[-1] else self.initial_value
 
@@ -71,7 +73,7 @@ class dtypes:
   @staticmethod
   def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32)
   @staticmethod
-  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8)
+  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
   @staticmethod
   def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
   bool: Final[DType] = DType(0, 1, "bool", bool)
@@ -81,6 +83,8 @@ class dtypes:
   int32: Final[DType] = DType(1, 4, "int", np.int32)
   int64: Final[DType] = DType(2, 8, "int64", np.int64)
   uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
+  uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
+  uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)
 
 class GlobalCounters:
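`ansilen` measures a string's printable width by stripping ANSI escape sequences. It matters for the DEBUG=2 print in ops.py below: `display_name` can contain `colored` output, and padding it by `len(self.name)` misaligned the columns. For instance:

```python
from tinygrad.helpers import colored, ansilen

s = colored("r_256_16", "magenta")
assert len(s) > 8       # escape codes inflate len()
assert ansilen(s) == 8  # printable width only
```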
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 7dbd675088..a98f396b69 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -2,14 +2,15 @@ from __future__ import annotations
 import functools, itertools, operator, random, time
 from enum import Enum, auto
 from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable, ClassVar
-from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored
+from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored, ansilen
 from tinygrad.shape.shapetracker import MovementOps
 from tinygrad.runtime.lib import RawBuffer, RawConst
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
+# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
@@ -79,12 +80,12 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 
 # **************** for Compiled Buffers ****************
 
 class ASTRunner:
-  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None):
-    if DEBUG >= 4: print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name
+  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
+    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
+    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
 
   def build(self, runtime):
-    self.clprg = runtime(self.name, self.prg)
+    self.clprg = runtime(self.name, self.prg, **self.runtime_args)
     return self
 
   def exec(self, bufs) -> Optional[float]:
@@ -96,7 +97,7 @@ class ASTRunner:
     if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
     if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
     if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
             (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
     GlobalCounters.kernel_count += 1
     GlobalCounters.global_ops += self.op_estimate
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 3c1ce004ae..e6c032d1b7 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -4,10 +4,11 @@ import numpy as np
 import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG
+from tinygrad.helpers import DEBUG, getenv
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
+from tinygrad.codegen.assembly_ptx import PTXCodegen
 
 class RawCUDABuffer(RawBufferCopyInOut):
   def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
@@ -59,4 +60,4 @@ class CUDACodegen(CStyleCodegen):
       typedef long long int64;
     """)
   supports_float4_alu = False
-CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)
+CUDABuffer = Compiled(RawCUDABuffer, PTXCodegen if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
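With the ops_cuda.py switch, the PTX path is opt-in via the `PTX` environment variable. A smoke test of the sort that exercises it (a sketch assuming a working pycuda setup and env-var-based device selection; env vars must be set before the CUDA runtime is imported):

```python
import os
os.environ["PTX"] = "1"   # selects PTXCodegen in ops_cuda.py at import time
os.environ["CUDA"] = "1"  # assumed device selection; adjust to your setup

from tinygrad.tensor import Tensor
a, b = Tensor.ones(64), Tensor.ones(64)
print((a + b).numpy())    # with DEBUG=4, dumps the AssemblyInstructions too
```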
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index f0178791e9..8382e96fd9 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -19,6 +19,7 @@ class Node:
   @functools.cached_property
   def key(self) -> str: return self.render(ctx="DEBUG")
   def __repr__(self): return "<"+self.key+">"
+  def __hash__(self): return hash(self.__repr__())
   def __eq__(self, other:object) -> bool:
     if not isinstance(other, Node): return NotImplemented
     return self.key == other.key
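`Node` already defines a key-based `__eq__`, which (in Python 3) suppresses the default `__hash__`; adding a repr-based `__hash__` keeps hashing consistent with equality, and is what lets assembly.py use symbolic nodes such as loop `Variable`s as keys in its `tor` register map. Roughly:

```python
from tinygrad.shape.symbolic import Variable

v = Variable("gidx0", 0, 127)
regs = {v: "%i0"}  # tor-style mapping from symbolic node to register name
assert regs[Variable("gidx0", 0, 127)] == "%i0"  # equal nodes now hash equally
```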
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d5100a1fcd..ab3664a6c3 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -471,7 +471,7 @@ class Tensor:
   def cumsum(self, axis=0):
     x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
     return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))
-  
+
   # ***** mlops (unary) *****
 
   def contiguous(self): return mlops.Contiguous.apply(self)
@@ -481,12 +481,12 @@ class Tensor:
   def sin(self): return mlops.Sin.apply(self)
   def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
-  
+
   @staticmethod
   def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
   def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
-  
+
   # ***** math functions (unary) *****
 
   def __neg__(self): return 0.0-self
@@ -527,7 +527,12 @@ class Tensor:
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
   def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
   def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
-  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
+  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+    if not isinstance(x, Tensor) and not reverse:
+      # simple pow identities
+      if x == 2.0: return self*self
+      if x == -1.0: return 1/self
+    return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
   def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
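The pow fast paths only fire for scalar, non-reversed exponents: `x**2.0` lowers to a single MUL and `x**-1.0` to a reciprocal, instead of the LOG2/MUL/EXP2 chain the assembly backend otherwise emits for `BinaryOps.POW`. For example:

```python
from tinygrad.tensor import Tensor

x = Tensor([1.0, 2.0, 4.0])
print((x ** 2.0).numpy())   # lowered to x*x, no POW op
print((x ** -1.0).numpy())  # lowered to 1/x
print((2.0 ** x).numpy())   # reverse pow still takes the mlops.Pow path
```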