PTX assembly support (#977)

* ptx assembly

* all ops tests pass

* fix tests
Author: George Hotz
Date: 2023-06-13 12:31:42 -07:00
Committed by: GitHub
Parent: 727416201f
Commit: ba4eadb04c
9 changed files with 280 additions and 26 deletions

tinygrad/codegen/assembly.py (new file)

@@ -0,0 +1,166 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'I', dtypes.uint64: 'A'}

class Register(NamedTuple):
  nm:str
  dtype:DType
  def __repr__(self): return self.nm

class AssemblyInstruction(NamedTuple):
  op: UOps
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
  supports_load3: bool = False

  def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
    raise NotImplementedError("must be implemented")

  # s registers are the addresses and non local indexes
  def codegen(self):
    self.process()
    self.hand_coded_optimizations()
    self.limit_global_dims(3) # all GPU asms have 3 (for now)
    self.linearize()

    cnts:DefaultDict[DType, int] = defaultdict(int)
    tor: Dict[Any, Register] = {}
    def newreg(tok, dtype=dtypes.float32):
      nonlocal cnts, tor
      tor[tok] = ret = Register(f"%{type_to_letter[dtype]}{cnts[dtype]}", dtype)
      cnts[dtype] += 1
      return ret

    def render_numnode(b):
      key = ("num", b)
      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, dtype=dtypes.int32), [], b))
      return tor[key]

    def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
      key = (op, a, b)
      if key not in tor:
        #if not isinstance(b, Register): b = render_numnode(b)
        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype), [a, b], op))
      return tor[key]

    def render_cast(a:Register, new_dtype:DType) -> Register:
      if a.dtype == new_dtype: return a
      key = (a, new_dtype)
      if key not in tor:
        ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
      return tor[key]

    render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
                   MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                   DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                   ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                   LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

    def addr_w_offset(args):
      idx = args.idx*self.bufs[args.i].dtype.itemsize
      off = 0 # TODO: should this be None?
      if isinstance(idx, SumNode) and not self.supports_load3:
        nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
        if len(nums) > 0:
          idx -= nums[0]
          off = nums[0]
      reg = idx.render(render_ops)
      if self.supports_load3:
        return tor[f"buf{args.i}"], reg
      else:
        reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
      return reg, off

    ins = []
    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64), [], f"buf{i}") for i in range(len(self.bufs))]
    global_size, local_size = [], []
    skipload_branch = 0
    for uop,newvar,vin,args in self.uops:
      if uop == UOps.CONST and newvar is not None:
        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar), [], args))
      elif uop == UOps.DEFINE_LOCAL:
        ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
        ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
      elif uop == UOps.LOOP:
        if args[1] == "global":
          for i,var in enumerate(args[0]):
            global_size.append(var.max+1)
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
        elif args[1] == "local":
          for i,var in enumerate(args[0]):
            local_size.append(var.max+1)
            global_size[i] *= local_size[i]
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
        else:
          for var in args[0]:
            if not isinstance(var, NumNode): # TODO: why is this coming through?
              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32), [], 0))
              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
      elif uop == UOps.ENDLOOP:
        if args[1] not in ["global", "local"]:
          for var in reversed(args[0]):
            if not isinstance(var, NumNode): # TODO: why is this coming through?
              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max, dtypes.bool)
              ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
              ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif uop == UOps.ALU and newvar is not None:
        if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]] # TODO: reorder MULACC everywhere
        # this is the only thing that can violate SSA
        if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
          pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
          ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
          ins.append(AssemblyInstruction(UOps.CAST, newreg(newvar), [pred_reg], args))
        elif args == BinaryOps.POW:
          # TODO: add UnaryOps.SQRT
          tmp = newreg((newvar, "exp_a"))
          tmp2 = newreg((newvar, "exp_a_times_b"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
          ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar), [tmp2], UnaryOps.EXP2))
        elif args == UnaryOps.SIN and hasattr(self, 'sin_is_sin2pi'):
          tmp = newreg((newvar, "2pi"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tmp], args))
        else:
          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tor[x] for x in vin], args))
      elif uop == UOps.LOAD and newvar is not None:
        idx, off = addr_w_offset(args)
        reg = newreg(newvar)
        if args.valid.min == 0:
          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
          if args.valid.max == 1:
            pred = args.valid.render(render_ops)
            ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
        if args.valid.max == 1:
          # NOTE: you can't compute the index in here, because it assumes it's all available later
          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx], (off, 'global' if args.i != -1 else 'shared')))
        if args.valid.min == 0 and args.valid.max == 1:
          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
      elif uop == UOps.STORE:
        idx, off = addr_w_offset(args)
        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]], (off, 'global' if args.i != -1 else 'shared')))

    # define registers
    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter[dtype], c)) for dtype,c in cnts.items()] + ins

    if DEBUG >= 4:
      for tins in ins: print(tins)
    name, asm = self.specialize(ins)
    return ASTRunner(name, asm,
      global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
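The SSA register naming above is compact enough to check by hand: newreg keeps one counter per dtype and names registers '%' + type letter + count, and codegen later prepends one DEFINE_REGISTER per dtype with the final counts. A standalone sketch of that rule (abridged table, string dtypes for brevity; not part of the diff):

from collections import defaultdict

type_to_letter = {'float32': 'f', 'bool': 'p', 'int32': 'i'}  # abridged copy of the table above
cnts = defaultdict(int)
def newreg_name(dtype):
  # mirrors newreg(): one counter per dtype, so names stay dense per register class
  name = f"%{type_to_letter[dtype]}{cnts[dtype]}"
  cnts[dtype] += 1
  return name

assert newreg_name('float32') == '%f0'
assert newreg_name('float32') == '%f1'
assert newreg_name('int32') == '%i0'  # counters are per-dtype
# the final counts become DEFINE_REGISTER args, e.g. (dtypes.float32, 'f', 2),
# which the PTX backend below renders as ".reg .f32 %f<2>;"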

tinygrad/codegen/assembly_ptx.py (new file)

@@ -0,0 +1,66 @@
import struct
from tinygrad.codegen.assembly import AssemblyCodegen
from tinygrad.ops import BinaryOps, UnaryOps, FusedOps
from tinygrad.codegen.linearizer import UOps
from tinygrad.helpers import dtypes

dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "u16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXCodegen(AssemblyCodegen):
  #supports_constant_folding: bool = True

  def specialize(self, asm):
    ins = [".version 7.8", ".target sm_86", ".address_size 64",
           f".visible .entry test({', '.join(f'.param .u64 buf{i}' for i in range(len(self.bufs)))}) {{"]

    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
           BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", BinaryOps.CMPEQ: "setp.eq",
           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
           FusedOps.MULACC: "fma.rn"}

    for uop, out, vin, arg in asm:
      if uop == UOps.DEFINE_REGISTER:
        ins.append(f".reg .{dtype_to_nvtype[arg[0]]} %{arg[1]}<{arg[2]}>;",)
      elif uop == UOps.DEFINE_LOCAL:
        ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
      elif uop == UOps.SPECIAL:
        if arg.startswith('buf'):
          ins.append(f"ld.param.u64 {out}, [{arg}];")
          # TODO: is this needed?
          #ins.append(f"cvta.to.global.u64 {out}, {out};")
        elif arg.startswith('gid'):
          #ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
          ins.append("{ .reg .b32 %tmp<3>;")
          l = 'xyz'[int(arg[3:])]
          ins.append(f"mov.u32 %tmp0, %ctaid.{l};")
          ins.append(f"mov.u32 %tmp1, %ntid.{l};")
          ins.append(f"mov.u32 %tmp2, %tid.{l};")
          ins.append(f"mad.lo.s32 {out}, %tmp0, %tmp1, %tmp2; }}")
        elif arg.startswith('lid'):
          ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
      elif uop == UOps.ALU:
        if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
          ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
        else:
          otype = vin[0].dtype if arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT] else out.dtype
          ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
      elif uop == UOps.LOAD:
        ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[out.dtype]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
      elif uop == UOps.STORE:
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype[vin[1].dtype]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
      elif uop == UOps.CAST:
        if vin[0].dtype == dtypes.bool:
          ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, 0f3F800000, 0f00000000, {vin[0]};")
        else:
          ins.append(f"cvt.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[vin[0].dtype]} {out}, {vin[0]};")
      elif uop == UOps.CONST:
        ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else arg};")
      elif uop == UOps.LABEL:
        ins.append(f"{arg}:")
      elif uop == UOps.COND_BRANCH:
        ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

    ins += ["ret;", "}"]
    return "test", '\n'.join(ins)
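Two PTX conventions in specialize are worth a worked example: float immediates are spelled 0f followed by the big-endian IEEE-754 hex of the value, which is exactly what float_to_hex produces, and the gid branch computes a flat global index as ctaid*ntid + tid with a single mad.lo.s32. A quick standalone check (not part of the diff):

import struct
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])

# PTX float immediates: "mov.f32 %f0, 0f3F800000;" loads 1.0
assert float_to_hex(1.0) == "3F800000"
assert float_to_hex(-2.0) == "C0000000"

# global thread id for one axis, as emitted by the SPECIAL('gid*') branch:
#   mad.lo.s32 out, %ctaid.x, %ntid.x, %tid.x   ->   out = block_id*block_dim + thread_id
ctaid, ntid, tid = 3, 128, 5
assert ctaid*ntid + tid == 389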

tinygrad/codegen/cstyle.py

@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored, prod
+from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -188,15 +188,7 @@ class CStyleCodegen(Linearizer):
  def codegen(self):
    self.process()
    self.hand_coded_optimizations()
-    # sometimes, there's more dimensions than len(self.lang.gid).
-    # compact all the dimensions into the first
-    # NOTE: this might make multiview shapetrackers
-    if len(self.lang.gid) and self.first_reduce > len(self.lang.gid):
-      num_to_merge = (self.first_reduce - len(self.lang.gid))+1
-      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
-      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
+    self.limit_global_dims(len(self.lang.gid))
    self.linearize()
    prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)

tinygrad/codegen/linearizer.py

@@ -10,7 +10,9 @@ from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.shape.symbolic import Variable
-class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto() # noqa: E702
+# bottom ones are asm only
+class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \
+                  SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702

class LocalBuffer(NamedTuple):
  dtype: DType = dtypes.float32
@@ -453,6 +455,15 @@ class Linearizer:
self.shift_to(unit_stride_axes_mul_4[0], 4)
self.upcast()
+  def limit_global_dims(self, limit):
+    # sometimes, there's more dimensions than len(self.lang.gid).
+    # compact all the dimensions into the first
+    # NOTE: this might make multiview shapetrackers
+    if limit and self.first_reduce > limit:
+      num_to_merge = (self.first_reduce - limit)+1
+      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
+      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
def hand_coded_optimizations(self):
# if there's images in the earlybufs, we have to make an axis the 4 loading one
self.required_optimizations(early_only=True)
@@ -523,3 +534,10 @@ class Linearizer:
self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
self.upcast()
break
+    # if nothing at all is upcasted and it's easy to, do an upcast
+    # TODO: this is breaking the tests
+    #for splits in [4]:
+    #  if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
+    #    self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
+    #    self.upcast()
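The merge rule in limit_global_dims is easy to verify in isolation: it folds the first (first_reduce - limit) + 1 axes into one, so that at most `limit` global dimensions remain in front of the first reduce. A standalone check (prod from tinygrad.helpers, as the import change above suggests):

from tinygrad.helpers import prod

# e.g. first_reduce=5, limit=3: merge the first (5-3)+1 = 3 axes into one,
# taking 5 global dimensions down to exactly 3
first_reduce, limit = 5, 3
num_to_merge = (first_reduce - limit) + 1
reshape = lambda x: (prod(x[0:num_to_merge]),) + x[num_to_merge:]
assert reshape((5, 4, 3, 2, 7)) == (60, 2, 7)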

tinygrad/helpers.py

@@ -1,6 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, asdict
-import os, math, functools, time
+import os, math, functools, time, re
import numpy as np
from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
ShapeType = Tuple[int, ...]
@@ -12,6 +12,7 @@ def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
def colored(st, color, background=False, bright=False): return f"\u001b[{10*background+60*bright+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
+def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
@@ -34,6 +35,7 @@ class ContextVar:
  def __bool__(self): return self.value != 0
  def __ge__(self, x): return self.value >= x
  def __gt__(self, x): return self.value > x
+  def __lt__(self, x): return self.value < x
  @property
  def value(self): return ContextVar.ctx_stack[-1][self.key] if self.key in ContextVar.ctx_stack[-1] else self.initial_value
@@ -71,7 +73,7 @@ class dtypes:
  @staticmethod
  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32)
  @staticmethod
-  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8)
+  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
  @staticmethod
  def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
  bool: Final[DType] = DType(0, 1, "bool", bool)
@@ -81,6 +83,8 @@ class dtypes:
  int32: Final[DType] = DType(1, 4, "int", np.int32)
  int64: Final[DType] = DType(2, 8, "int64", np.int64)
  uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
+  uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
+  uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)

class GlobalCounters:
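
ansilen exists because colored() wraps its argument in ANSI escape sequences, so len() overcounts the visible width and misaligns the DEBUG=2 kernel table; that is also why the print in ops.py below switches from len(self.name) to ansilen(self.display_name). A standalone check (colored abridged here to the foreground-only case):

import re
def colored(st, color): return f"\u001b[{30+['black','red','green','yellow','blue','magenta','cyan','white'].index(color)}m{st}\u001b[0m"
def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))

s = colored("r_256_16", "red")
assert len(s) > 8         # the escape codes count toward len()
assert ansilen(s) == 8    # ...but not toward the visible width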

tinygrad/ops.py

@@ -2,14 +2,15 @@ from __future__ import annotations
import functools, itertools, operator, random, time
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable, ClassVar
-from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored
+from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored, ansilen
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.runtime.lib import RawBuffer, RawConst
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
+# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
@@ -79,12 +80,12 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
# **************** for Compiled Buffers ****************
class ASTRunner:
-  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None):
-    if DEBUG >= 4: print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name
+  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
+    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
+    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
  def build(self, runtime):
-    self.clprg = runtime(self.name, self.prg)
+    self.clprg = runtime(self.name, self.prg, **self.runtime_args)
    return self
def exec(self, bufs) -> Optional[float]:
@@ -96,7 +97,7 @@ class ASTRunner:
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
        (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate
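The runtime_args dict is deliberately opaque: ASTRunner just forwards it to the runtime constructor in build(), which is how the PTX path hands {'binary': True} to CUDAProgram (and why the DEBUG>=4 source print is suppressed for binaries). A sketch of the flow with a stand-in runtime (FakeProgram is hypothetical, not a tinygrad class):

from tinygrad.ops import ASTRunner

class FakeProgram:
  # stands in for a runtime like CUDAProgram; **runtime_args arrives as kwargs
  def __init__(self, name, prg, binary=False):
    self.src_is_binary = binary

runner = ASTRunner("test", "<ptx here>", runtime_args={"binary": True})
runner.build(FakeProgram)        # calls FakeProgram("test", "<ptx here>", binary=True)
assert runner.clprg.src_is_binary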

tinygrad/runtime/ops_cuda.py

@@ -4,10 +4,11 @@ import numpy as np
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG
+from tinygrad.helpers import DEBUG, getenv
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
+from tinygrad.codegen.assembly_ptx import PTXCodegen
class RawCUDABuffer(RawBufferCopyInOut):
def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
@@ -59,4 +60,4 @@ class CUDACodegen(CStyleCodegen):
typedef long long int64;
""")
supports_float4_alu = False
-CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)
+CUDABuffer = Compiled(RawCUDABuffer, PTXCodegen if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
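The backend switch is evaluated once, at import time, keyed off the PTX environment variable. A usage sketch (assuming getenv's usual coercion to the default's type, int here):

# e.g. `PTX=1 python3 test/test_ops.py` routes CUDA kernels through PTXCodegen
import os
os.environ["PTX"] = "1"  # must be set before tinygrad.runtime.ops_cuda is imported
from tinygrad.helpers import getenv
assert getenv("PTX") == 1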

tinygrad/shape/symbolic.py

@@ -19,6 +19,7 @@ class Node:
  @functools.cached_property
  def key(self) -> str: return self.render(ctx="DEBUG")
  def __repr__(self): return "<"+self.key+">"
+  def __hash__(self): return hash(self.__repr__())
  def __eq__(self, other:object) -> bool:
    if not isinstance(other, Node): return NotImplemented
    return self.key == other.key
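Node needs an explicit __hash__ because a class that defines __eq__ gets __hash__ set to None, and the assembly codegen above uses Variables as keys in its tor register map. A minimal illustration of the same pattern (toy class, not tinygrad code):

class Key:
  def __init__(self, key): self.key = key
  def __repr__(self): return "<"+self.key+">"
  def __hash__(self): return hash(self.__repr__())  # same trick as Node: hash the rendering
  def __eq__(self, other): return isinstance(other, Key) and self.key == other.key

tor = {Key("gidx0"): "%i0"}
assert tor[Key("gidx0")] == "%i0"  # equal keys hash alike, so dict lookup works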

tinygrad/tensor.py

@@ -471,7 +471,7 @@ class Tensor:
  def cumsum(self, axis=0):
    x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
    return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))

  # ***** mlops (unary) *****
  def contiguous(self): return mlops.Contiguous.apply(self)
@@ -481,12 +481,12 @@ class Tensor:
  def sin(self): return mlops.Sin.apply(self)
  def cos(self): return ((math.pi/2)-self).sin()
  def tan(self): return self.sin() / self.cos()

  @staticmethod
  def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)

  # ***** math functions (unary) *****
  def __neg__(self): return 0.0-self
@@ -527,7 +527,12 @@ class Tensor:
  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
-  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
+  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+    if not isinstance(x, Tensor) and not reverse:
+      # simple pow identities
+      if x == 2.0: return self*self
+      if x == -1.0: return 1/self
+    return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
  def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
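The pow fast paths matter for the new backend because the generic path lowers BinaryOps.POW to exp2(log2(a)*b) (see the POW branch in the assembly codegen above), which is slower and less exact for the common x**2 and x**-1 cases. A quick check of the rewrites (assuming a working default backend):

from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])
# with the fast paths, these no longer go through mlops.Pow at all:
assert (t**2).numpy().tolist() == (t*t).numpy().tolist()      # x**2  -> x*x
assert (t**-1.0).numpy().tolist() == (1/t).numpy().tolist()   # x**-1 -> 1/x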