Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 21:38:10 -05:00)
PTX assembly support (#977)

* ptx assembly
* all ops tests pass
* fix tests
tinygrad/codegen/assembly.py (new file, 166 lines)
@@ -0,0 +1,166 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import ASTRunner, FusedOps, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'I', dtypes.uint64: 'A'}

class Register(NamedTuple):
  nm:str
  dtype:DType
  def __repr__(self): return self.nm

class AssemblyInstruction(NamedTuple):
  op: UOps
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
  supports_load3: bool = False

  def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
    raise NotImplementedError("must be implemented")

  # s registers are the addresses and non local indexes
  def codegen(self):
    self.process()
    self.hand_coded_optimizations()
    self.limit_global_dims(3)  # all GPU asms have 3 (for now)
    self.linearize()

    cnts:DefaultDict[DType, int] = defaultdict(int)
    tor: Dict[Any, Register] = {}
    def newreg(tok, dtype=dtypes.float32):
      nonlocal cnts, tor
      tor[tok] = ret = Register(f"%{type_to_letter[dtype]}{cnts[dtype]}", dtype)
      cnts[dtype] += 1
      return ret

    def render_numnode(b):
      key = ("num", b)
      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, dtype=dtypes.int32), [], b))
      return tor[key]

    def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
      key = (op, a, b)
      if key not in tor:
        #if not isinstance(b, Register): b = render_numnode(b)
        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype), [a, b], op))
      return tor[key]

    def render_cast(a:Register, new_dtype:DType) -> Register:
      if a.dtype == new_dtype: return a
      key = (a, new_dtype)
      if key not in tor:
        ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
      return tor[key]

    render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
                   MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                   DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                   ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                   LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

    def addr_w_offset(args):
      idx = args.idx*self.bufs[args.i].dtype.itemsize
      off = 0  # TODO: should this be None?
      if isinstance(idx, SumNode) and not self.supports_load3:
        nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
        if len(nums) > 0:
          idx -= nums[0]
          off = nums[0]
      reg = idx.render(render_ops)
      if self.supports_load3:
        return tor[f"buf{args.i}"], reg
      else:
        reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
        return reg, off

    ins = []
    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64), [], f"buf{i}") for i in range(len(self.bufs))]
    global_size, local_size = [], []
    skipload_branch = 0
    for uop,newvar,vin,args in self.uops:
      if uop == UOps.CONST and newvar is not None:
        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar), [], args))
      elif uop == UOps.DEFINE_LOCAL:
        ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
        ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
      elif uop == UOps.LOOP:
        if args[1] == "global":
          for i,var in enumerate(args[0]):
            global_size.append(var.max+1)
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
        elif args[1] == "local":
          for i,var in enumerate(args[0]):
            local_size.append(var.max+1)
            global_size[i] *= local_size[i]
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
        else:
          for var in args[0]:
            if not isinstance(var, NumNode):  # TODO: why is this coming through?
              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32), [], 0))
              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
      elif uop == UOps.ENDLOOP:
        if args[1] not in ["global", "local"]:
          for var in reversed(args[0]):
            if not isinstance(var, NumNode):  # TODO: why is this coming through?
              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max, dtypes.bool)
              ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
              ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif uop == UOps.ALU and newvar is not None:
        if args == FusedOps.MULACC: vin = [vin[1], vin[2], vin[0]]  # TODO: reorder MULACC everywhere
        # this is the only thing that can violate SSA
        if args in [BinaryOps.CMPEQ, BinaryOps.CMPLT]:
          pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
          ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
          ins.append(AssemblyInstruction(UOps.CAST, newreg(newvar), [pred_reg], args))
        elif args == BinaryOps.POW:
          # TODO: add UnaryOps.SQRT
          tmp = newreg((newvar, "exp_a"))
          tmp2 = newreg((newvar, "exp_a_times_b"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]]], UnaryOps.LOG2))
          ins.append(AssemblyInstruction(UOps.ALU, tmp2, [tmp, tor[vin[1]]], BinaryOps.MUL))
          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar), [tmp2], UnaryOps.EXP2))
        elif args == UnaryOps.SIN and hasattr(self, 'sin_is_sin2pi'):
          tmp = newreg((newvar, "2pi"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tmp], args))
        else:
          ins.append(AssemblyInstruction(UOps.ALU, newreg(newvar) if newvar not in tor else tor[newvar], [tor[x] for x in vin], args))
      elif uop == UOps.LOAD and newvar is not None:
        idx, off = addr_w_offset(args)
        reg = newreg(newvar)
        if args.valid.min == 0:
          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
          if args.valid.max == 1:
            pred = args.valid.render(render_ops)
            ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
        if args.valid.max == 1:
          # NOTE: you can't compute the index in here, because it assumes it's all available later
          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx], (off, 'global' if args.i != -1 else 'shared')))
        if args.valid.min == 0 and args.valid.max == 1:
          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
      elif uop == UOps.STORE:
        idx, off = addr_w_offset(args)
        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]], (off, 'global' if args.i != -1 else 'shared')))

    # define registers
    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter[dtype], c)) for dtype,c in cnts.items()] + ins

    if DEBUG >= 4:
      for tins in ins: print(tins)
    name, asm = self.specialize(ins)

    return ASTRunner(name, asm,
      global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
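For intuition, here is roughly the instruction list codegen() builds for a toy 1-D kernel out[i] = exp2(a[i]). This is hand-written for illustration, not captured output; register names follow the type_to_letter scheme above (A for uint64 addresses, i for int32 indexes, f for float32 values), and exact ordering and register counts depend on the linearized uops.

# illustrative only: a plausible instruction stream for out[i] = exp2(a[i])
from tinygrad.codegen.assembly import AssemblyInstruction, Register
from tinygrad.codegen.linearizer import UOps
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.helpers import dtypes

A = lambda n: Register(f"%A{n}", dtypes.uint64)
i = lambda n: Register(f"%i{n}", dtypes.int32)
f = lambda n: Register(f"%f{n}", dtypes.float32)

ins = [
  AssemblyInstruction(UOps.SPECIAL, A(0), [], "buf0"),               # output pointer
  AssemblyInstruction(UOps.SPECIAL, A(1), [], "buf1"),               # input pointer
  AssemblyInstruction(UOps.SPECIAL, i(0), [], "gid0"),               # flat global index
  AssemblyInstruction(UOps.ALU, i(1), [i(0), 4], BinaryOps.MUL),     # byte offset (float32 itemsize)
  AssemblyInstruction(UOps.CAST, A(2), [i(1)]),                      # widen index for address math
  AssemblyInstruction(UOps.ALU, A(3), [A(2), A(1)], BinaryOps.ADD),  # input address
  AssemblyInstruction(UOps.LOAD, f(0), [A(3)], (0, 'global')),       # load a[i]
  AssemblyInstruction(UOps.ALU, f(1), [f(0)], UnaryOps.EXP2),        # exp2
  AssemblyInstruction(UOps.ALU, A(4), [A(2), A(0)], BinaryOps.ADD),  # output address
  AssemblyInstruction(UOps.STORE, None, [A(4), f(1)], (0, 'global')),  # store out[i]
]
# codegen() would prepend the DEFINE_REGISTER headers, e.g. (dtypes.uint64, 'A', 5)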
tinygrad/codegen/assembly_ptx.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import struct
from tinygrad.codegen.assembly import AssemblyCodegen
from tinygrad.ops import BinaryOps, UnaryOps, FusedOps
from tinygrad.codegen.linearizer import UOps
from tinygrad.helpers import dtypes

dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "u16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXCodegen(AssemblyCodegen):
  #supports_constant_folding: bool = True

  def specialize(self, asm):
    ins = [".version 7.8", ".target sm_86", ".address_size 64",
           f".visible .entry test({', '.join(f'.param .u64 buf{i}' for i in range(len(self.bufs)))}) {{"]

    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
           BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", BinaryOps.CMPEQ: "setp.eq",
           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
           FusedOps.MULACC: "fma.rn"}

    for uop, out, vin, arg in asm:
      if uop == UOps.DEFINE_REGISTER:
        ins.append(f".reg .{dtype_to_nvtype[arg[0]]} %{arg[1]}<{arg[2]}>;",)
      elif uop == UOps.DEFINE_LOCAL:
        ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
      elif uop == UOps.SPECIAL:
        if arg.startswith('buf'):
          ins.append(f"ld.param.u64 {out}, [{arg}];")
          # TODO: is this needed?
          #ins.append(f"cvta.to.global.u64 {out}, {out};")
        elif arg.startswith('gid'):
          #ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
          ins.append("{ .reg .b32 %tmp<3>;")
          l = 'xyz'[int(arg[3:])]
          ins.append(f"mov.u32 %tmp0, %ctaid.{l};")
          ins.append(f"mov.u32 %tmp1, %ntid.{l};")
          ins.append(f"mov.u32 %tmp2, %tid.{l};")
          ins.append(f"mad.lo.s32 {out}, %tmp0, %tmp1, %tmp2; }}")
        elif arg.startswith('lid'):
          ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
      elif uop == UOps.ALU:
        if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
          ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
        else:
          otype = vin[0].dtype if arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT] else out.dtype
          ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
      elif uop == UOps.LOAD:
        ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[out.dtype]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
      elif uop == UOps.STORE:
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype[vin[1].dtype]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
      elif uop == UOps.CAST:
        if vin[0].dtype == dtypes.bool:
          ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, 0f3F800000, 0f00000000, {vin[0]};")
        else:
          ins.append(f"cvt.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[vin[0].dtype]} {out}, {vin[0]};")
      elif uop == UOps.CONST:
        ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else arg};")
      elif uop == UOps.LABEL:
        ins.append(f"{arg}:")
      elif uop == UOps.COND_BRANCH:
        ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

    ins += ["ret;", "}"]
    return "test", '\n'.join(ins)
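As a rough sketch of the output, specialize() would turn the toy exp2 instruction stream above into PTX shaped like this. Hand-written for illustration, not compiler output; the directives match the .version/.target header emitted above, and the gid0 block mirrors the mad.lo.s32 sequence in the SPECIAL case.

.version 7.8
.target sm_86
.address_size 64
.visible .entry test(.param .u64 buf0, .param .u64 buf1) {
.reg .u64 %A<5>;
.reg .s32 %i<2>;
.reg .f32 %f<2>;
ld.param.u64 %A0, [buf0];            // SPECIAL buf0: output pointer
ld.param.u64 %A1, [buf1];            // SPECIAL buf1: input pointer
{ .reg .b32 %tmp<3>;                 // SPECIAL gid0: flat global index
mov.u32 %tmp0, %ctaid.x;
mov.u32 %tmp1, %ntid.x;
mov.u32 %tmp2, %tid.x;
mad.lo.s32 %i0, %tmp0, %tmp1, %tmp2; }
mul.lo.s32 %i1, %i0, 4;              // ALU MUL: byte offset ('.lo' since not float)
cvt.u64.s32 %A2, %i1;                // CAST: widen for address math
add.u64 %A3, %A2, %A1;               // ALU ADD: input address
ld.global.f32 %f0, [%A3+0];          // LOAD
ex2.approx.ftz.f32 %f1, %f0;         // ALU EXP2
add.u64 %A4, %A2, %A0;               // ALU ADD: output address
st.global.f32 [%A4+0], %f1;          // STORE
ret;
}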
tinygrad/codegen/cstyle.py

@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
 import math, collections
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
 from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored, prod
+from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
 from tinygrad.lazy import LazyBuffer

The global-dimension compaction that lived inline here moves into the linearizer as limit_global_dims, so both the C-style and assembly backends can use it (see the sketch after the linearizer hunks):

@@ -188,15 +188,7 @@ class CStyleCodegen(Linearizer):
   def codegen(self):
     self.process()
     self.hand_coded_optimizations()
-
-    # sometimes, there's more dimensions than len(self.lang.gid).
-    # compact all the dimensions into the first
-    # NOTE: this might make multiview shapetrackers
-    if len(self.lang.gid) and self.first_reduce > len(self.lang.gid):
-      num_to_merge = (self.first_reduce - len(self.lang.gid))+1
-      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
-      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
-
+    self.limit_global_dims(len(self.lang.gid))
     self.linearize()

     prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)
tinygrad/codegen/linearizer.py

@@ -10,7 +10,9 @@ from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
 from tinygrad.shape.symbolic import Variable

-class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto() # noqa: E702
+# bottom ones are asm only
+class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \
+                  SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702

 class LocalBuffer(NamedTuple):
   dtype: DType = dtypes.float32

@@ -453,6 +455,15 @@ class Linearizer:
       self.shift_to(unit_stride_axes_mul_4[0], 4)
       self.upcast()

+  def limit_global_dims(self, limit):
+    # sometimes, there's more dimensions than len(self.lang.gid).
+    # compact all the dimensions into the first
+    # NOTE: this might make multiview shapetrackers
+    if limit and self.first_reduce > limit:
+      num_to_merge = (self.first_reduce - limit)+1
+      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
+      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
+
   def hand_coded_optimizations(self):
     # if there's images in the earlybufs, we have to make an axis the 4 loading one
     self.required_optimizations(early_only=True)

@@ -523,3 +534,10 @@ class Linearizer:
         self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
         self.upcast()
         break
+
+    # if nothing at all is upcasted and it's easy to, do an upcast
+    # TODO: this is breaking the tests
+    #for splits in [4]:
+    #  if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
+    #    self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
+    #    self.upcast()
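To see what limit_global_dims does to a shape, here is the merge arithmetic pulled out into a standalone sketch, with plain tuples standing in for the kernel's shape state:

from math import prod

def merge_dims(shape, first_reduce, limit):
  # collapse the leading num_to_merge dims into one, exactly like the reshape lambda above
  if limit and first_reduce > limit:
    num_to_merge = (first_reduce - limit)+1
    return (prod(shape[0:num_to_merge]),) + shape[num_to_merge:]
  return shape

# 4 global dims, but the entry point only gets 3: fold the first two together
assert merge_dims((2, 3, 4, 5), first_reduce=4, limit=3) == (6, 4, 5)
# already within the limit: untouched
assert merge_dims((2, 3, 4, 5), first_reduce=3, limit=3) == (2, 3, 4, 5)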
tinygrad/helpers.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass, asdict
-import os, math, functools, time
+import os, math, functools, time, re
 import numpy as np
 from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
 ShapeType = Tuple[int, ...]

@@ -12,6 +12,7 @@ def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
 def colored(st, color, background=False, bright=False): return f"\u001b[{10*background+60*bright+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
+def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]

@@ -34,6 +35,7 @@ class ContextVar:
   def __bool__(self): return self.value != 0
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
+  def __lt__(self, x): return self.value < x
   @property
   def value(self): return ContextVar.ctx_stack[-1][self.key] if self.key in ContextVar.ctx_stack[-1] else self.initial_value

@@ -71,7 +73,7 @@ class dtypes:
   @staticmethod
   def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32)
   @staticmethod
-  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8)
+  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
   @staticmethod
   def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
   bool: Final[DType] = DType(0, 1, "bool", bool)

@@ -81,6 +83,8 @@ class dtypes:
   int32: Final[DType] = DType(1, 4, "int", np.int32)
   int64: Final[DType] = DType(2, 8, "int64", np.int64)
   uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
+  uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
+  uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)


 class GlobalCounters:
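A quick check of what ansilen fixes: len counts the ANSI escape bytes that colored adds, while ansilen counts only the printable characters, which is what the kernel-timing printout in ops.py (next section) needs for column alignment:

from tinygrad.helpers import colored, ansilen

s = colored("r_256_16", "magenta")
print(repr(s))             # '\x1b[35mr_256_16\x1b[0m'
print(len(s), ansilen(s))  # 17 8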
tinygrad/ops.py

@@ -2,14 +2,15 @@ from __future__ import annotations
 import functools, itertools, operator, random, time
 from enum import Enum, auto
 from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable, ClassVar
-from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored
+from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored, ansilen
 from tinygrad.shape.shapetracker import MovementOps
 from tinygrad.runtime.lib import RawBuffer, RawConst

 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
+# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto() # noqa: E702
-class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
+class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702

@@ -79,12 +80,12 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 # **************** for Compiled Buffers ****************

 class ASTRunner:
-  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None):
-    if DEBUG >= 4: print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name
+  def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
+    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
+    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}

   def build(self, runtime):
-    self.clprg = runtime(self.name, self.prg)
+    self.clprg = runtime(self.name, self.prg, **self.runtime_args)
     return self

   def exec(self, bufs) -> Optional[float]:

@@ -96,7 +97,7 @@ class ASTRunner:
     if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
     if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
     if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
         (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
     GlobalCounters.kernel_count += 1
     GlobalCounters.global_ops += self.op_estimate
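The new runtime_args dict is simply forwarded to the runtime's program constructor at build time; the PTX path passes binary=True so the runtime treats the program text as already compiled. A minimal sketch of the flow, with a hypothetical FakeProgram standing in for a real runtime like CUDAProgram:

from tinygrad.ops import ASTRunner

class FakeProgram:
  # stand-in: a real runtime would skip its source-compile step when binary=True
  def __init__(self, name, prg, binary=False): self.name, self.prg, self.binary = name, prg, binary

runner = ASTRunner("test", "...ptx text...", global_size=[1], runtime_args={"binary": True})
runner.build(FakeProgram)  # calls FakeProgram("test", "...ptx text...", binary=True)
assert runner.clprg.binary is True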
tinygrad/runtime/ops_cuda.py

@@ -4,10 +4,11 @@ import numpy as np
 import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG
+from tinygrad.helpers import DEBUG, getenv
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
+from tinygrad.codegen.assembly_ptx import PTXCodegen

 class RawCUDABuffer(RawBufferCopyInOut):
   def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))

@@ -59,4 +60,4 @@ class CUDACodegen(CStyleCodegen):
   typedef long long int64;
   """)
   supports_float4_alu = False
-CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)
+CUDABuffer = Compiled(RawCUDABuffer, PTXCodegen if getenv("PTX") else CUDACodegen, CUDAProgram, cuda.Context.synchronize)
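The backend switch is opt-in via the PTX environment variable; something like this should route a computation through PTXCodegen (assuming a working CUDA setup, and that both env vars are set before tinygrad picks a device):

import os
os.environ["PTX"] = "1"   # select PTXCodegen in ops_cuda
os.environ["CUDA"] = "1"  # assumption: make CUDA the default device for this run

from tinygrad.tensor import Tensor
print((Tensor.ones(4) + Tensor.ones(4)).numpy())  # compiled via the assembly path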
tinygrad/shape/symbolic.py

@@ -19,6 +19,7 @@ class Node:
   @functools.cached_property
   def key(self) -> str: return self.render(ctx="DEBUG")
   def __repr__(self): return "<"+self.key+">"
+  def __hash__(self): return hash(self.__repr__())
   def __eq__(self, other:object) -> bool:
     if not isinstance(other, Node): return NotImplemented
     return self.key == other.key
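Why Node needs the explicit __hash__: a class that defines __eq__ gets its __hash__ set to None in Python, and the assembly codegen above keys its tor dict with Variable nodes. The rule demonstrated on a plain class:

class Eq:
  def __eq__(self, other): return str(self) == str(other)

try: hash(Eq())
except TypeError as e: print(e)  # unhashable type: 'Eq'

class EqAndHash(Eq):
  def __hash__(self): return hash(repr(self))  # same recipe as Node.__hash__

d = {EqAndHash(): "ok"}  # usable as a dict key again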
tinygrad/tensor.py

@@ -471,7 +471,7 @@ class Tensor:
   def cumsum(self, axis=0):
     x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
     return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))

   # ***** mlops (unary) *****

   def contiguous(self): return mlops.Contiguous.apply(self)

@@ -481,12 +481,12 @@ class Tensor:
   def sin(self): return mlops.Sin.apply(self)
   def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()

   @staticmethod
   def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
   def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)

   # ***** math functions (unary) *****

   def __neg__(self): return 0.0-self

@@ -527,7 +527,12 @@ class Tensor:
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
   def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
   def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
-  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
+  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+    if not isinstance(x, Tensor) and not reverse:
+      # simple pow identities
+      if x == 2.0: return self*self
+      if x == -1.0: return 1/self
+    return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
   def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
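The pow special cases keep x**2 and x**-1 out of the generic Pow lowering, which the assembly backend implements as exp2(log2(a)*b) with approximate PTX instructions (see assembly.py above). A quick numeric check of the identities:

from tinygrad.tensor import Tensor
import numpy as np

t = Tensor([1.0, 2.0, 3.0])
np.testing.assert_allclose((t ** 2).numpy(), [1.0, 4.0, 9.0])              # routed to t*t
np.testing.assert_allclose((t ** -1).numpy(), [1.0, 0.5, 1/3], rtol=1e-6)  # routed to 1/t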