mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
Merge remote-tracking branch 'upstream/master' into triton
@@ -40,6 +40,11 @@ def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
xk_out = complex_mult(xk, c, d)
return xq_out.flatten(3), xk_out.flatten(3)

def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
bs, seqlen, n_kv_heads, head_dim = x.shape
if n_rep == 1: return x
return x[:, :, :, None, :].expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
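repeat_kv broadcasts each of the n_kv_heads key/value heads n_rep times so grouped-query attention can serve n_heads query heads from a smaller KV cache. A minimal shape check of the helper above, assuming only the tinygrad Tensor API this file already imports (the concrete sizes are illustrative):

from tinygrad.tensor import Tensor
bs, seqlen, n_kv_heads, head_dim, n_rep = 1, 4, 2, 8, 4
x = Tensor.ones(bs, seqlen, n_kv_heads, head_dim)
out = repeat_kv(x, n_rep)  # every kv head is duplicated n_rep times
assert out.shape == (bs, seqlen, n_kv_heads * n_rep, head_dim)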

class RMSNorm:
def __init__(self, dim, eps=1e-6):
self.eps = eps

@@ -50,14 +55,22 @@ class RMSNorm:
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight

class Attention:
def __init__(self, dim, n_heads, linear=Linear):
self.wq, self.wk, self.wv, self.wo = [linear(dim, dim, bias=False) for _ in range(4)]
def __init__(self, dim, n_heads, n_kv_heads, linear=Linear):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads

self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

def prepare_attention(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
return xq, xk, xv

@@ -74,6 +87,8 @@ class Attention:

# save the cache
self.cache_k, self.cache_v = keys.realize(), values.realize()

keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
return Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)

# NOTE: this is not called
@@ -98,8 +113,8 @@ class FeedForward:
return self.w2(self.w1(x).silu() * self.w3(x))

class TransformerBlock:
def __init__(self, dim, multiple_of, n_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
self.attention = Attention(dim, n_heads, linear)
def __init__(self, dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
self.attention = Attention(dim, n_heads, n_kv_heads, linear)
self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)

@@ -125,8 +140,8 @@ class TransformerBlock:
return self._post(x, output) if mask is None else self.post(x, output)

class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)

@@ -173,11 +188,10 @@ MODEL_PARAMS = {
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"files": 2,
},
# # 70B is disabled because we do not yet implement n_kv_heads argument
# "70B": {
# "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
# "files": 8,
# },
"70B": {
"args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
"files": 8,
},
},
}

@@ -277,7 +291,7 @@ if __name__ == "__main__":
parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2")
parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
@@ -1,61 +0,0 @@
import struct
from tinygrad.codegen.assembly import AssemblyCodegen
from tinygrad.codegen.linearizer import UOps
from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "u16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXCodegen(AssemblyCodegen):
#supports_constant_folding: bool = True

def specialize(self, asm):
ins = [".version 8.2", ".target " + arch(), ".address_size 64",
f".visible .entry test({', '.join(f'.param .u64 buf{i}' for i in range(len(self.bufs)))}) {{"]

alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", BinaryOps.CMPEQ: "setp.eq", UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn"}

for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
ins.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
elif uop == UOps.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == UOps.SPECIAL:
if arg.startswith('buf'):
ins.append(f"ld.param.u64 {out}, [{arg}];")
# TODO: is this needed?
#ins.append(f"cvta.to.global.u64 {out}, {out};")
elif arg.startswith('gid'):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == UOps.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
else:
otype = vin[0].dtype if arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT] else out.dtype
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
elif uop == UOps.LOAD:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[out.dtype]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
elif uop == UOps.STORE:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[vin[1].dtype]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
elif uop == UOps.CAST:
if vin[0].dtype == dtypes.bool:
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, 0f3F800000, 0f00000000, {vin[0]};")
else:
ins.append(f"cvt.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[vin[0].dtype]} {out}, {vin[0]};")
elif uop == UOps.CONST:
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else arg};")
elif uop == UOps.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

ins += ["ret;", "}"]
return "test", '\n'.join(ins)
@@ -21,7 +21,9 @@ def safe_numpy(t) -> np.ndarray:
if t not in numpy_cache:
if DEBUG >= 1:
print("numpy cache miss", t)
numpy_cache[t] = t.numpy()
tmp = t.numpy()
numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
assert len(numpy_cache[t].shape) > 0
return numpy_cache[t]

onnx_ops = importlib.import_module('extra.onnx_ops')

@@ -92,7 +94,7 @@ def get_run_onnx(onnx_model: ModelProto):
if inp.name in tensors: continue
tmp=inp.type.optional_type.elem_type.tensor_type if inp.type.HasField("optional_type") else (inp.type.sequence_type.elem_type.tensor_type if inp.type.HasField("sequence_type") else inp.type.tensor_type)
shape = shape_to_tuple(tmp.shape)
if len(shape) >= 1 and shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
if len(shape) >= 1: shape = tuple([x if x != 0 else 1 for x in shape]) # replace all dynamic dims with 1 for now
if inp.name in inputs:
if isinstance(inputs[inp.name], Tensor):
input_tensors[inp.name] = inputs[inp.name]

@@ -183,6 +185,8 @@ def get_run_onnx(onnx_model: ModelProto):
steps = safe_numpy(inp[4])[0] if len(inp) > 4 else 1
starts, ends = safe_numpy(starts.cast(dtypes.int32)).tolist(), safe_numpy(ends.cast(dtypes.int32)).tolist() # TODO: when indexing is added use that
for i,axis in enumerate(axes.tolist()):
assert axis % 1 == 0
axis = int(axis)
arg[axis] = (starts[i] if starts[i] >= 0 else inp[0].shape[axis]+starts[i], ends[i] if ends[i] >= 0 else inp[0].shape[axis]+ends[i])
ret = inp[0].slice(arg=arg)
elif n.op_type == "Shrink":
@@ -1,9 +1,10 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
import numpy as np
import functools
from typing import Union, Tuple
from typing import Union, Tuple, Optional
import math

def Unsqueeze(data, axes):

@@ -201,7 +202,7 @@ def Tile(input, repeats):
final_shape = [r*s for r,s in zip(repeats_, input.shape)]
return input.reshape(new_shape).expand(expand_shape).reshape(final_shape)

def Range(start, limit, delta): return Tensor.arange(safe_numpy(start)[0], safe_numpy(limit)[0], step=safe_numpy(delta)[0])
def Range(start, limit, delta): return Tensor.arange(*[safe_numpy(x)[0].item() for x in (start, limit, delta)])
def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)

def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)

@@ -218,10 +219,7 @@ def ConstantOfShape(input, value:Tensor=None):
shape = [int(x) for x in safe_numpy(input)]
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)

# this is obviously wrong, but since we don't have types, it's better than nothing
def Cast(input, to):
print(f"WARNING: attempting to cast to {to}")
return input
def Cast(input, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))

# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):

@@ -263,3 +261,70 @@ def OneHot(indices, depth, values, axis=-1):

def Floor(x:Tensor): return x.floor()
def Ceil(x:Tensor): return x.ceil()

def EmbedLayerNormalization(input_ids, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert (mask is None) is (mask_index_type is None)
assert mask is None, "functionality not supported yet" # TODO
input_shape = input_ids.shape
bsz, seq_length = input_shape[0], input_shape[1]
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)

def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight

# bert embedding layer
if epsilon is None: epsilon = 1e-12
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
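The embedding() helper above implements the lookup as a one-hot comparison followed by a matmul rather than a gather. The trick can be checked in isolation with NumPy; the names and sizes below are illustrative only:

import numpy as np
vocab_size, dim = 5, 3
weight = np.arange(vocab_size * dim, dtype=np.float32).reshape(vocab_size, dim)
ids = np.array([[1, 3, 0]])  # (bs, seq)
one_hot = (np.arange(vocab_size) == ids[..., None]).astype(np.float32)  # (bs, seq, vocab)
np.testing.assert_allclose(one_hot @ weight, weight[ids])  # one-hot matmul == row lookup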

def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = input.linear(weights, bias)
xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value

bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present

def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
if epsilon is None: epsilon=1e-12
x = input + skip + bias
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x

def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
x = x + bias
return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
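The constants in FastGelu are the usual tanh GELU approximation: 0.797885 ≈ sqrt(2/pi) and 0.035677 ≈ 0.044715*sqrt(2/pi), so the expression matches 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3))). A small stand-alone check (plain Python, illustrative only):

import math
for x in (-2.0, -0.5, 0.0, 0.7, 3.0):
    fast = 0.5 * x * (1 + math.tanh(x * 0.797885 + 0.035677 * x ** 3))
    ref = 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
    assert abs(fast - ref) < 1e-4  # the truncated constants agree to about five decimal places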

82 test/external/external_model_benchmark.py (vendored)
@@ -1,13 +1,14 @@
import csv
import pathlib
import time
import onnx
import csv, pathlib, time, numpy as np
from os import getenv
import torch
torch.set_num_threads(1)
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
import onnxruntime as ort
from onnx2torch import convert
from extra.utils import download_file
from extra.onnx import get_run_onnx
from tinygrad.helpers import OSX
from tinygrad.helpers import OSX, DEBUG
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device

@@ -16,6 +17,7 @@ MODELS = {
"openpilot": "https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx",
"efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
"shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
"commavq": "https://github.com/commaai/commavq/raw/master/models/gpt2m.onnx",

# broken in torch MPS
#"zfnet": "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx",

@@ -29,6 +31,7 @@ MODELS = {

CSV = {}
open_csv = None
torch.manual_seed(1)

def benchmark(mnm, nm, fxn):
tms = []

@@ -36,26 +39,31 @@ def benchmark(mnm, nm, fxn):
st = time.perf_counter_ns()
ret = fxn()
tms.append(time.perf_counter_ns() - st)
print(f"{m:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
CSV[nm] = min(tms)*1e-6
return min(tms), ret

#BASE = pathlib.Path(__file__).parent.parent.parent / "weights" / "onnx"
BASE = pathlib.Path("/tmp/onnx")
def benchmark_model(m):
def benchmark_model(m, validate_outs=False):
global open_csv, CSV
CSV = {"model": m}

fn = BASE / MODELS[m].split("/")[-1]
download_file(MODELS[m], fn)
onnx_model = onnx.load(fn)

output_names = [out.name for out in onnx_model.graph.output]
excluded = {inp.name for inp in onnx_model.graph.initializer}
input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded}
np_inputs = {k:torch.randn(shp).numpy() for k,shp in input_shapes.items()}
assert len(input_shapes) < 20
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
#input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}"

for device in ["METAL" if OSX else "GPU", "CLANG"]:
# print input names
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])

for device in ["METAL" if OSX else "GPU", "CLANG"]: # + (["CUDA"] if torch.cuda.is_available() else []):
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)

@@ -67,19 +75,55 @@ def benchmark_model(m):
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()})
del inputs, tinygrad_model, tinygrad_jitted_model

torch_model = convert(onnx_model)
torch_inputs = [torch.tensor(x) for x in np_inputs.values()]
benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs))
try:
torch_model = convert(onnx_model)
torch_inputs = [torch.tensor(x) for x in np_inputs.values()]
benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs))

torch_device = "mps" if OSX else "cuda"
torch_mps_model = torch_model.to(torch_device)
torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
torch_device = "mps" if OSX else "cuda"
torch_mps_model = torch_model.to(torch_device)
torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
except NotImplementedError:
print(f"{m:16s}onnx2torch doesn't support this model")

# bench onnxruntime
ort_options = ort.SessionOptions()
ort_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_options.log_severity_level = 3 # no warnings
for backend in ["CPU", "CUDA" if not OSX else "CoreML"]: # https://onnxruntime.ai/docs/execution-providers/
provider = backend+"ExecutionProvider"
if provider not in ort.get_available_providers(): continue
ort_sess = ort.InferenceSession(str(fn), ort_options, [provider])
benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs))
del ort_sess

if validate_outs:
rtol, atol = 8e-4, 8e-4 # tolerance for fp16 models
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
tinygrad_out = tinygrad_model(inputs)

ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"])
onnx_out = ort_sess.run(output_names, np_inputs)
onnx_out = dict([*[(name,x) for name, x in zip(output_names, onnx_out)]])

assert_allclose(tinygrad_out, onnx_out, rtol=rtol, atol=atol)
print(f"{m:16s}outputs validated with rtol={rtol:.1e}, atol={atol:.1e}")

if open_csv is None:
open_csv = csv.DictWriter(open('onnx_inference_speed.csv', 'w', newline=''), fieldnames=list(CSV.keys()))
open_csv.writeheader()
open_csv.writerow(CSV)

def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5):
assert len(tiny_out) == len(onnx_out) and tiny_out.keys() == onnx_out.keys()
for k in tiny_out.keys():
tiny_v, onnx_v = tiny_out[k], onnx_out[k]
if tiny_v is None: assert tiny_v == onnx_v
else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}")

if __name__ == "__main__":
for m in MODELS: benchmark_model(m)
if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), True)
else:
for m in MODELS: benchmark_model(m, True)
135 test/external/external_test_allocator_on_models.py (vendored, new file)
@@ -0,0 +1,135 @@
#!/usr/bin/env python
import unittest, gc
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.state import get_parameters, get_state_dict
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device

from examples.llama import Transformer

ALLOCATED_DEV_BUFS = 0
class FakeDeviceBuffer():
def __init__(self, sz, dt, device):
self.id = 1
self.size = sz
self.dtype = dt
self.device = device

global ALLOCATED_DEV_BUFS
ALLOCATED_DEV_BUFS += 1
class FakeAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
def _do_free(self, buf):
buf.id -= 1
assert buf.id == 0, f"Free should be called once, but {buf.id}"

FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
def __init__(self, size, dtype, device='0'):
global FAKE_GLOBAL_ALLOCATOR
super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass

def helper_test_correctness(gen, train):
from tinygrad.runtime.ops_gpu import CL, CLAllocator
old_alloc = CL.cl_allocator
CL.cl_allocator = CLAllocator(0)
no_alloc_result = train(*gen()).numpy()
Device[Device.DEFAULT].synchronize()
CL.cl_allocator = CLAllocator(512<<30) # Test cache correctness, so cache as much as possible, 512gb
for _ in range(4):
GlobalCounters.reset()
np.testing.assert_allclose(train(*gen()).numpy(), no_alloc_result, rtol=1e-3, atol=1e-5)
Device[Device.DEFAULT].synchronize()
assert len(CL.cl_allocator.cached_buffers) != 0, "Cache must be used"
CL.cl_allocator = old_alloc

def __helper_test_alloc_count(gen, train):
was_alloc = ALLOCATED_DEV_BUFS
for _ in range(2):
train(*gen())
return ALLOCATED_DEV_BUFS - was_alloc

def helper_test_alloc_count(mm, gen, train):
global FAKE_GLOBAL_ALLOCATOR
backup_program = Device[Device.DEFAULT].runtime
backup_buffer = Device[Device.DEFAULT].buffer
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].buffer = FakeBuffer
Device[Device.DEFAULT].method_cache.clear()
FAKE_GLOBAL_ALLOCATOR = FakeAllocator(16<<30)
new_allocs = __helper_test_alloc_count(gen, train)
Device[Device.DEFAULT].method_cache.clear()
FAKE_GLOBAL_ALLOCATOR = FakeAllocator(0)
old_allocs = __helper_test_alloc_count(gen, train)
print(f"{mm}: llama: old allocs count {old_allocs}, new allocs count {new_allocs}")
assert new_allocs < old_allocs, f"Hmm, doesn't cache work any more?"
Device[Device.DEFAULT].runtime = backup_program
Device[Device.DEFAULT].buffer = backup_buffer
FAKE_GLOBAL_ALLOCATOR = None

def check_gc():
if Device.DEFAULT == "GPU":
gc.collect() # Need to collect Tensors.
from extra.introspection import print_objects
assert print_objects() == 0

# for speed
def derandomize(x):
if isinstance(x, LazyOp):
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
x.src = [derandomize(s) for s in x.src]
else:
x.op = derandomize(x.op)
return x

def derandomize_model(model):
for p in get_parameters(model):
p.lazydata = derandomize(p.lazydata)
p.realize()

class TestAllocators(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
def test_lru_allocator_tiny_llama(self):
old_type = Tensor.default_type
Tensor.default_type = dtypes.float16

args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
def __test():
model = Transformer(**args_tiny)
derandomize_model(model)
def test(t): return model(t, 0).realize()
helper_test_correctness(lambda: (Tensor([[1,]]),), test)
__test()
Tensor.default_type = old_type
check_gc()

@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
def test_lru_allocator_tiny_llama_alloc_counts(self):
args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
def test_alloc_count(t):
model = Transformer(**args_tiny)
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
return model(t, 0).realize()
helper_test_alloc_count("llama", lambda: (Tensor([[2,]]),), test_alloc_count)
check_gc()

@unittest.skip("huge for CI")
def test_stable_diffusion(self):
from examples.stable_diffusion import UNetModel
model = UNetModel()
derandomize_model(model)
def test(t, t2): return model(t, 801, t2).realize()
helper_test_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test)

if __name__ == "__main__":
unittest.main()
106 test/test_allocators.py (new file)
@@ -0,0 +1,106 @@
#!/usr/bin/env python
import unittest
import numpy as np
from weakref import ref
from tinygrad.ops import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device

def check_gc():
if Device.DEFAULT == "GPU":
from extra.introspection import print_objects
assert print_objects() == 0

class FakeDeviceBuffer():
def __init__(self, sz, dt, device):
self.id = 1
self.size = sz
self.dtype = dt
self.device = device
def __del__(self):
assert self.id == 0, "Should called _do_free() before"

class FakeAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
def _do_free(self, buf):
buf.id -= 1
assert buf.id == 0, f"Free should be called once, but {buf.id}"

FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
def __init__(self, size, dtype, device='0'):
global FAKE_GLOBAL_ALLOCATOR
super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)

def alloc(allocator, size, dtype, **kwargs):
global FAKE_GLOBAL_ALLOCATOR
FAKE_GLOBAL_ALLOCATOR = allocator
buf = FakeBuffer(size, dtype, **kwargs)
assert buf.dtype == dtype and buf.size == size
FAKE_GLOBAL_ALLOCATOR = None
return buf

def alloc_free_trace(allocator, size, dtype, **kwargs):
buf = alloc(allocator, size, dtype, **kwargs)
return ref(buf._buf)

def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf

class TestAllocators(unittest.TestCase):
def test_lru_allocator_reusage(self):
def test():
lru_allocator = FakeAllocator(2048)
traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached"
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused"
__test()

usedbuf = alloc(lru_allocator, 16, dtypes.float32)
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
assert usedbuf != buf, "Nobody should get used buffer"
__test()
assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
test()
check_gc()

def test_lru_allocator_cache_free(self):
def test():
lru_allocator = FakeAllocator(128)
refs = []
for _ in range(32):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
for sz in range(32):
alloc_free_trace(lru_allocator, sz, dtypes.float32)
assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)"
for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
test()
check_gc()

def test_lru_allocator_multidevice(self):
def test():
lru_allocator = FakeAllocator(256)
refs=[]
for i in range(8):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i)))
for i in range(64):
def __test():
dev = str(i % 8)
buf = alloc(lru_allocator, 16, dtypes.float32, device=dev)
assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused"
__test()
for r in refs: assert r() is not None, "All refs should be cached"
test()
check_gc()

if __name__ == "__main__":
unittest.main()
@@ -91,6 +91,12 @@ class TestHalfDtype(unittest.TestCase):
def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)

@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"], "float64 is not supported by some backends")
class TestDoubleDtype(unittest.TestCase):
def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4])
def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64)
def test_upcast_to_float64_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float64, target_dtype=dtypes.float64)

@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])

@@ -107,8 +113,10 @@ class TestInt8Dtype(unittest.TestCase):
def test_int8_upcast_int64(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int64, target_dtype=dtypes.int64)

@unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
@unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap")
def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])

@unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap")
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])

@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")

@@ -1,5 +1,5 @@
import unittest
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import Context, ContextVar, merge_dicts

VARIABLE = ContextVar("VARIABLE", 0)

@@ -106,5 +106,17 @@ with Context(VARIABLE=1):
...
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."

class TestMergeDicts(unittest.TestCase):
def test_merge_dicts(self):
a = {"a": 1, "b": 2}
b = {"a": 1, "c": 3}
c = {}
d = {"a": 2, "b": 2}
assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
assert merge_dicts([a, c]) == a
assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
with self.assertRaises(AssertionError):
merge_dicts([a, d])

if __name__ == '__main__':
unittest.main()
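The new test pins down merge_dicts' contract: later dicts may add keys, repeated identical values are fine, but conflicting values for the same key must raise. A sketch consistent with that behaviour (not necessarily the actual tinygrad.helpers implementation):

from typing import Dict, Iterable
def merge_dicts(ds: Iterable[Dict]) -> Dict:
    out: Dict = {}
    for d in ds:
        for k, v in d.items():
            # identical repeats are allowed, conflicting values are not
            assert k not in out or out[k] == v, f"conflicting values for key {k}"
            out[k] = v
    return out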
@@ -4,7 +4,7 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
from tinygrad.lazy import Device

if CI:

@@ -1127,21 +1127,37 @@ class TestOps(unittest.TestCase):
n = (x < 0).where(x, 1).numpy()
assert np.all(n == 1.)

def test_slice_fancy_indexing(self):
# indices cannot have gradient
a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False)
b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False)
c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False)
e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False)
i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]]
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,b,c,d,e], lambda x: x[i,j,k,o,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[:,b,c,d,e], lambda x: x[:,j,k,o,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[:,b,c,d,:], lambda x: x[:,j,k,o,:])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,b,...], lambda x: x[i,j,...])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,...,e], lambda x: x[i,...,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[...,c,:,e], lambda x: x[...,k,:,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[1,:,10:11,d,0:2], lambda x: x[1,:,10:11,o,0:2])
helper_test_op([(2,5,15,5,3,4)], lambda x: x[1,4,c,d,2], lambda x: x[1,4,k,o,2])
helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)])
helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])])

def test_gather(self):
nda = np.random.randn(4,5,6,9,5).astype(np.float32)
ten = Tensor(nda, requires_grad=True)
tor = torch.tensor(nda, requires_grad=True)
c = np.random.randint(low=-4, high=4, size=[3,4,5]).astype(np.int32)
a = Tensor(c, requires_grad=False)
b = torch.tensor(c, requires_grad=False)
helper_test_op([], lambda: tor[b,:,:,:,:], lambda: ten.gather(a, dim=0))
helper_test_op([], lambda: tor[:,b,:,:,:], lambda: ten.gather(a, dim=1))
helper_test_op([], lambda: tor[:,:,b,:,:], lambda: ten.gather(a, dim=2))
helper_test_op([], lambda: tor[:,:,:,b,:], lambda: ten.gather(a, dim=3))
helper_test_op([], lambda: tor[:,:,:,:,b], lambda: ten.gather(a, dim=4))
ta = Tensor(c, requires_grad=True)
tb = torch.tensor(c, requires_grad=True, dtype=torch.float32)
self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError
# indices cannot have gradient
# indices cannot be negative (torch gather)
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1))
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2))
helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))

def test_scaled_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
113 test/test_symbolic_ops.py (new file)
@@ -0,0 +1,113 @@
import unittest
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv, CI
from tinygrad.tensor import Tensor, Device
import numpy as np

@unittest.skipIf(getenv("ARM64"), "ARM64 is not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG"], f"{Device.DEFAULT} is not supported")
class TestSymbolicOps(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi)).reshape(3, i).cpu().numpy()
expected = f(a).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_add(self):
def f(a, b): return (a+b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, i)
symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_matmul(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(i, 5)
symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_matmul_same_var_different_val(self):
def f(a, b): return (a@b).realize()
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4)
b = Tensor.rand(7, 5)
with self.assertRaises(AssertionError):
f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()

@unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI")
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
q = Tensor.rand(2, 1, 4, 8)
k = Tensor.rand(2, i, 4, 8)
v = Tensor.rand(2, i, 4, 8)
symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
expected = f(q, k, v).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(2, 3)
symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_cat_dim1(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
for i in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, 2)
symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_cat_dim0_two_vars(self):
def f(a, b): return a.cat(b, dim=0).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(j, 3)
symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_cat_dim1_two_vars(self):
def f(a, b): return a.cat(b, dim=1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(3, i)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_two_vars_plus1(self):
def f(a, b): return (a@b+1).realize()
vi = Variable("i", 1, 10)
vj = Variable("j", 1, 10)
for i in range(1, 5):
for j in range(1, 5):
a = Tensor.rand(i, 3)
b = Tensor.rand(3, j)
symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
expected = f(a, b).cpu().numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

@@ -26,17 +26,18 @@ class TestSymbolic(unittest.TestCase):
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t1.lazydata.st
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
assert st.shape == (i+j+k, 4)
assert st.real_strides() == (4, 1)
i = Variable("i", 1, 5)
j = Variable("j", 1, 5)
k = Variable("k", 1, 5)
t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t1.lazydata.st
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
assert st.real_strides() == (3, 1)

class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_simple(self):

@@ -80,14 +80,14 @@ class TestFloatUOps(TestUOps):
def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c)

# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some")
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
@unittest.skipIf(Device.DEFAULT == "CLANG", "broken in CLANG")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)

if __name__ == '__main__':

@@ -283,6 +283,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert i // i == 1

def test_node_mod_node(self):
i = Variable("i", 1, 10)
@@ -7,8 +7,8 @@ import functools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
|
||||
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
||||
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
||||
|
||||
class Register(NamedTuple):
|
||||
nm:str
|
||||
@@ -37,9 +37,10 @@ class AssemblyLanguage:
|
||||
tor: Dict[Any, Register] = {}
|
||||
ins: List[AssemblyInstruction] = []
|
||||
|
||||
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
def newreg(self, tok, dtype=dtypes.float32, scalar=False):
|
||||
if isinstance(tok, Token): dtype = tok.dtype # this
|
||||
self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
||||
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
||||
if dtype == dtypes._float4:
|
||||
for off in range(4):
|
||||
self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||
@@ -96,8 +97,8 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
#TODO: Do not use clear()
|
||||
lang.ins.clear()
|
||||
lang.tor.clear()
|
||||
lang.cnts.clear()
|
||||
buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
global_size, local_size = [], []
|
||||
skipload_branch = 0
|
||||
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
||||
@@ -117,7 +118,7 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
else:
|
||||
for var in args[0]:
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here?
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
||||
elif uop == UOps.ENDLOOP:
|
||||
if args[1] not in ["global", "local", "global+local"]:
|
||||
@@ -129,7 +130,9 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
elif args[1] == "global+local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
||||
|
||||
elif args[1] == 'local':
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
||||
elif uop == UOps.CAST and newvar is not None:
|
||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||
out = lang.newreg(newvar)
|
||||
@@ -155,10 +158,11 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
elif uop == UOps.LOAD and newvar is not None:
|
||||
if isinstance(args, ConstOp):
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value))
|
||||
reg = lang.newreg(newvar, dtype=newvar.dtype)
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args.invalid_value))
|
||||
pred = args.valid.render(lang.render_ops, lang)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args.value))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
skipload_branch += 1
|
||||
else:
|
||||
@@ -173,16 +177,16 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
if args.valid.max == 1:
|
||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None)))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
skipload_branch += 1
|
||||
elif uop == UOps.STORE:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None)))
|
||||
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
# define registers
|
||||
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
|
||||
lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, lang.type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins
|
||||
|
||||
if DEBUG >= 4:
|
||||
for tins in lang.ins: print(tins)
|
||||
return global_size, local_size
|
||||
|
||||
@@ -22,7 +22,7 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
ins = []
|
||||
x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
|
||||
s_regs = ['s' + str(i) for i in reversed(range(3,30))]
|
||||
type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
||||
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
|
||||
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
|
||||
@@ -137,12 +137,12 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
elif uop == UOps.STORE:
|
||||
shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
|
||||
#NOTE: if casting is needed, load the value into the s/h0 or x/w12 temp regs
|
||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
|
||||
tinygrad/codegen/assembly_ptx.py (new file, +98 lines)
@@ -0,0 +1,98 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
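As a quick sanity check of the two helpers above (a standalone sketch that restates them, plus dtypes from tinygrad.helpers): float_to_hex yields the big-endian hex digits PTX expects after the 0f prefix for float immediates, and ptx_needs_cast only requests a conversion when the value class or the float width changes.

import struct
from tinygrad.helpers import dtypes

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)

# 1.0f is 0x3F800000 in IEEE-754, which is why render_cast can emit "0f3F800000" for a boolean true
assert float_to_hex(1.0) == "3F800000"
assert ptx_needs_cast(dtypes.float32, dtypes.int32)        # int -> float needs a cvt
assert ptx_needs_cast(dtypes.float16, dtypes.float32)      # narrowing float needs a cvt
assert not ptx_needs_cast(dtypes.float32, dtypes.float32)  # same type passes through untouched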
|
||||
|
||||
def render_cast(ins, inp, out):
|
||||
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
|
||||
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
|
||||
elif out.dtype == dtypes.bool:
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
|
||||
else:
|
||||
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
|
||||
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
|
||||
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
|
||||
|
||||
class PTXLanguage(AssemblyLanguage):
|
||||
supports_constant_folding: bool = True
|
||||
|
||||
def specialize_to_ptx(lang, function_name):
|
||||
param_cnt = 0
|
||||
ins = []
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
|
||||
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
|
||||
for uop, out, vin, arg in lang.ins:
|
||||
if uop == UOps.ENDLOOP:
|
||||
ins.append("bar.sync 0;")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
param_cnt += 1
|
||||
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
||||
# TODO: we sometimes want this to be local; nvcc converts to global most of the time, and it is unclear when we would actually need it
|
||||
# ins.append(f"cvta.to.global.u64 {out}, {out};")
|
||||
elif arg.startswith('gid'):
|
||||
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
||||
elif arg.startswith('lid'):
|
||||
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
||||
elif uop == UOps.ALU:
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
|
||||
else:
|
||||
otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
|
||||
if arg == TernaryOps.WHERE:
|
||||
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
|
||||
vin = vin[1:] + [reg]
|
||||
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
|
||||
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
||||
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
|
||||
reg = lang.newreg((out, dt[0]), dtype=dt[1])
|
||||
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
render_cast(ins, reg, out)
|
||||
else:
|
||||
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
elif uop == UOps.STORE:
|
||||
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
|
||||
if arg[2] == dtypes.bool != vin[1].dtype:
|
||||
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
|
||||
render_cast(ins, vin[1], prereg)
|
||||
else: prereg = vin[1]
|
||||
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
|
||||
render_cast(ins, prereg, reg)
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
|
||||
else:
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
|
||||
elif uop == UOps.CAST:
|
||||
render_cast(ins, vin[0], out)
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
||||
|
||||
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
|
||||
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
|
||||
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
|
||||
ins = ins_prefix + ins
|
||||
ins += ["ret;", "}"]
|
||||
return '\n'.join(ins)
|
||||
|
||||
def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
|
||||
lang = PTXLanguage()
|
||||
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
|
||||
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
|
||||
@@ -9,7 +9,7 @@ from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape, View
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
|
||||
VariableOrNum = Union[Variable, NumNode, Node]
|
||||
|
||||
# bottom ones are asm only
|
||||
@@ -301,6 +301,9 @@ class Linearizer:
|
||||
# add global buffers
|
||||
for buf,name in self.arg_bufs.items():
|
||||
self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
|
||||
# add variables from symbolic shapes
|
||||
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
|
||||
self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))
|
||||
|
||||
# add a local buffer for multistage reduce
|
||||
if len(self.group_for_reduce):
|
||||
@@ -317,7 +320,7 @@ class Linearizer:
|
||||
if DEBUG >= 3: self.printbufs()
|
||||
|
||||
# kernel name (before late upcast)
|
||||
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
|
||||
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
|
||||
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
# parse AST
|
||||
@@ -548,7 +551,7 @@ class Linearizer:
|
||||
assert len(colors) == self.shape_len, "colors size mismatch"
|
||||
return colors
|
||||
|
||||
def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
|
||||
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
|
||||
def printbufs(self, prefix=""):
|
||||
for i in range(len(self.sts)):
|
||||
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
|
||||
|
||||
@@ -162,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer):
|
||||
# early exit
|
||||
return
|
||||
|
||||
if k.opts.has_local:
|
||||
if k.opts.has_local and all(isinstance(s, int) for s in k.sts[0].shape[:k.first_reduce]):
|
||||
# are we grouping? (requires local shape support)
|
||||
if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
@@ -204,8 +204,8 @@ def hand_coded_optimizations(k:Linearizer):
|
||||
while prod(k.sts[0].shape[:k.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(k.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
|
||||
if len(xb_choices):
|
||||
xb_choices = sorted(xb_choices)
|
||||
@@ -219,7 +219,7 @@ def hand_coded_optimizations(k:Linearizer):
|
||||
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||
if k.first_reduce < (k.shape_len-k.upcasted) and (len(list(k.shape_offsets(k.full_buf_index))) <= 4 or not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))):
|
||||
if (s:=k.full_unupcasted_shape[-1]) <= 32:
|
||||
if (s:=k.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
k.upcast()
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if k.first_reduce < (k.shape_len-k.upcasted) and s <= 3 and k.full_unupcasted_shape[-1] <= 3: k.upcast()
|
||||
|
||||
@@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib
|
||||
from weakref import KeyedRef, ref
|
||||
from _weakref import _remove_dead_weakref # type: ignore
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
|
||||
from math import prod # noqa: F401 # pylint:disable=unused-import
|
||||
|
||||
ShapeType = Tuple[int, ...]
|
||||
@@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
|
||||
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
|
||||
def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def merge_dicts(ds:Iterable[Dict]) -> Dict:
kvs = set([(k,v) for d in ds for k,v in d.items()])
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for k,v in kvs}
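merge_dicts only succeeds when every key maps to a single value across all inputs; a minimal sketch of that contract (assuming this revision of tinygrad.helpers), which also covers the mnum naming helper used for negative offsets:

from tinygrad.helpers import mnum, merge_dicts

assert mnum(3) == "3" and mnum(-3) == "m3"               # negative numbers become identifier-safe names
assert merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}
try:
  merge_dicts([{"a": 1}, {"a": 2}])                      # same key, conflicting values
  raise RuntimeError("expected merge_dicts to reject conflicting values")
except AssertionError:
  pass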
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def getenv(key, default=0): return type(default)(os.getenv(key, default))
|
||||
@@ -83,11 +87,11 @@ class ImageDType(DType):
|
||||
|
||||
class dtypes:
|
||||
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
|
||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
|
||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes._half4, dtypes._float4)
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
|
||||
@staticmethod
|
||||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
|
||||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
|
||||
@staticmethod
|
||||
@@ -115,6 +119,7 @@ class dtypes:
|
||||
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
|
||||
_float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
|
||||
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
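The widened predicates now recognize the 16/64-bit integers, float64, and the internal vector dtypes; a few spot checks (assuming every dtype named in the predicates is defined on the class, as the PTX type table above already requires):

from tinygrad.helpers import dtypes

assert dtypes.is_float(dtypes.float64) and dtypes.is_float(dtypes._float2)
assert dtypes.is_int(dtypes.uint16) and dtypes.is_unsigned(dtypes.uint16)
assert not dtypes.is_int(dtypes._arg_int32)  # the kernel-argument marker is deliberately left out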
|
||||
@@ -125,6 +130,7 @@ class GlobalCounters:
|
||||
time_sum_s: ClassVar[float] = 0.0
|
||||
kernel_count: ClassVar[int] = 0
|
||||
mem_used: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
|
||||
@staticmethod
|
||||
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, List, Tuple, Any, Dict, cast, Union
|
||||
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import DEBUG, DType
|
||||
|
||||
@@ -12,7 +12,7 @@ class TinyJit:
|
||||
def __init__(self, fxn:Callable):
|
||||
self.fxn: Callable = fxn
|
||||
self.cnt: int = 0
|
||||
self.jit_cache: List[Tuple[Callable, Any]] = [] # TODO: Any should be List[RawBuffer], but this fails
|
||||
self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]]]] = []
|
||||
self.ret: Any = None
|
||||
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], int, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_size, expected_type)
|
||||
|
||||
@@ -29,7 +29,8 @@ class TinyJit:
|
||||
for (j,i),(input_name, expected_size, expected_type) in self.input_replace.items():
|
||||
assert input_rawbuffers[input_name].size == expected_size and input_rawbuffers[input_name].dtype == expected_type, f"size or type mismatch in JIT, {input_rawbuffers[input_name]} != <{expected_size}, {expected_type}>"
|
||||
self.jit_cache[j][1][i] = input_rawbuffers[input_name]
|
||||
for prg, args in self.jit_cache: prg(args, jit=True)
|
||||
for prg, pargs in self.jit_cache: # type: Callable, List[Optional[RawBuffer]]
|
||||
prg(pargs, jit=True)
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
|
||||
elif self.cnt == 1:
|
||||
GlobalCounters.cache = []
|
||||
@@ -40,10 +41,10 @@ class TinyJit:
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
|
||||
# get the inputs for replacement
|
||||
for j,(prg,args) in enumerate(self.jit_cache): # pylint: disable=E1133
|
||||
for i,a in enumerate(args):
|
||||
for j_,(_,pargs_) in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]]]]
|
||||
for i,a in enumerate(pargs_):
|
||||
if a in input_rawbuffers.values():
|
||||
self.input_replace[(j,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0]
|
||||
self.input_replace[(j_,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0]
|
||||
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
|
||||
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
|
||||
|
||||
@@ -2,8 +2,9 @@ from __future__ import annotations
|
||||
import functools, time
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
|
||||
from tinygrad.shape.shapetracker import MovementOps
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
@@ -131,20 +132,24 @@ class ASTRunner:
|
||||
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
|
||||
return self
|
||||
|
||||
def exec(self, bufs, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
|
||||
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
|
||||
return self(rawbufs, force_wait=force_wait)
|
||||
return self(rawbufs, var_vals, force_wait=force_wait)
|
||||
|
||||
def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
|
||||
if et := self.clprg((self.global_size + [1]*(3-len(self.global_size))) if self.global_size is not None else None,
|
||||
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
|
||||
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
||||
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
|
||||
if var_vals is None: var_vals = {}
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
|
||||
if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None,
|
||||
(local_size + [1]*(3-len(local_size))) if local_size is not None else None,
|
||||
*rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
||||
op_estimate = sym_infer(self.op_estimate, var_vals)
|
||||
if DEBUG >= 2:
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += self.op_estimate
|
||||
GlobalCounters.global_ops += op_estimate
|
||||
GlobalCounters.global_mem += self.mem_estimate
|
||||
if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0)
|
||||
return et
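With symbolic shapes the stored global_size/local_size entries can be Nodes, so sym_infer substitutes the concrete variable values right before launch, and the same values are appended to the kernel arguments via *var_vals.values(). A small illustration of the substitution step (a sketch against tinygrad.shape.symbolic at this revision):

from tinygrad.shape.symbolic import Variable, sym_infer

i = Variable("i", 1, 10)                       # a symbolic axis bounded to [1, 10]
global_size = [i*4, 16, 1]                     # the kind of mixed list a symbolic kernel carries
assert [sym_infer(sz, {i: 3}) for sz in global_size] == [12, 16, 1]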
|
||||
@@ -178,9 +183,11 @@ class Compiled:
|
||||
output.realized = None
|
||||
break
|
||||
|
||||
# we don't have an output buffer, we have to create it
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if not output.realized:
|
||||
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
|
||||
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
|
||||
# update the output var_vals from src
|
||||
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
k = Linearizer(ast, output, self.linearizer_opts)
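When the output shape is symbolic, the buffer is sized for the maximum extent of each Variable, so a single allocation can back every concrete size the kernel is later launched with. The sizing rule reduces to:

from math import prod
from tinygrad.shape.symbolic import Variable

shape = (Variable("i", 1, 10), 4)              # hypothetical symbolic output shape
assert prod(s if isinstance(s, int) else s.max for s in shape) == 40  # allocate as if i == 10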
|
||||
@@ -200,5 +207,5 @@ class Compiled:
|
||||
|
||||
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
||||
|
||||
prg.exec(k.bufs)
|
||||
prg.exec(k.bufs, var_vals=output.st.var_vals)
|
||||
return output.realized
|
||||
|
||||
@@ -3,7 +3,7 @@ import math
|
||||
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
|
||||
|
||||
# div is different in cl than python
|
||||
render_cl = render_python.copy()
|
||||
@@ -17,6 +17,7 @@ class CStyleLanguage(NamedTuple):
|
||||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
smem_prefix: str = ""
|
||||
arg_int_prefix: str = ""
|
||||
barrier: str = ""
|
||||
gid: List[str] = []
|
||||
lid: List[str] = []
|
||||
@@ -52,7 +53,7 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_const(self, x:Union[float,int], var_dtype) -> str:
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
else: val = f"{x}" + ("f" if isinstance(x, float) else "")
|
||||
else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}"
|
||||
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
@@ -69,7 +70,7 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_local(self, name:str, size:int):
|
||||
return self.smem_prefix + f"float {name}[{size}];"
|
||||
|
||||
def render_for(self, expr: str, _min:int, _max:int) -> str:
|
||||
def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
|
||||
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
|
||||
|
||||
def render_conditional(self, cond: str, x:str, y:str) -> str:
|
||||
@@ -78,6 +79,7 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
||||
prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
@@ -128,7 +130,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
||||
else:
|
||||
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
|
||||
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
|
||||
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
|
||||
depth += 1
|
||||
elif uop == UOps.BARRIER:
|
||||
kk(lang.barrier)
|
||||
|
||||
@@ -22,17 +22,18 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)),
|
||||
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
|
||||
UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)),
|
||||
BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
|
||||
BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
|
||||
BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),
|
||||
BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)),
|
||||
BinaryOps.CMPLT: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()),
|
||||
BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=('fast',)),
|
||||
BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=('fast',)),
|
||||
BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=('fast',)),
|
||||
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=('fast',)),
|
||||
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()),
|
||||
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)),
|
||||
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
|
||||
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)),
|
||||
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
|
||||
}
|
||||
|
||||
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
|
||||
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
|
||||
|
||||
def cast(bb, val, input_type, output_type):
|
||||
if input_type == output_type: return val
|
||||
@@ -44,6 +45,8 @@ def cast(bb, val, input_type, output_type):
|
||||
val = bb[-1].sext(val, ir.IntType(32))
|
||||
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
|
||||
val = bb[-1].bitcast(val, ir.FloatType())
|
||||
elif input_type == dtypes.float64:
|
||||
val = bb[-1].fptrunc(val, ir.FloatType())
|
||||
else:
|
||||
val = bb[-1].fpext(val, ir.FloatType())
|
||||
return val
|
||||
@@ -55,6 +58,8 @@ def cast(bb, val, input_type, output_type):
|
||||
val = bb[-1].bitcast(val, ir.IntType(32))
|
||||
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
|
||||
val = bb[-1].trunc(val, ir.IntType(16))
|
||||
elif output_type == dtypes.float64:
|
||||
val = bb[-1].fpext(val, ir.DoubleType())
|
||||
else:
|
||||
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
|
||||
return val
|
||||
@@ -114,11 +119,11 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
||||
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
|
||||
valid = args.valid.render(render_llvm, bb[-1])
|
||||
if isinstance(args, ConstOp):
|
||||
assert newvar.dtype == dtypes.float, "newvar must be float"
|
||||
value, invalid_value = [int(args.value), int(args.invalid_value)] if dtypes.is_int(newvar.dtype) else ([bool(args.value), bool(args.invalid_value)] if newvar.dtype == dtypes.bool else [args.value, args.invalid_value]) # type: ignore
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
|
||||
val = bb[-1].select(valid, ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value), ir.Constant(dtype_to_llvm_dtype[newvar.dtype], invalid_value))
|
||||
else:
|
||||
val = ir.Constant(ir.FloatType(), args.value if args.valid.min == 1 else args.invalid_value)
|
||||
val = ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value if args.valid.min == 1 else invalid_value)
|
||||
# TODO: this is a hack. it shouldn't be const that signals this
|
||||
reduce_phis.append(newvar)
|
||||
else:
|
||||
|
||||
@@ -38,7 +38,7 @@ class WGSLLanguage(CStyleLanguage):
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
|
||||
return prg, global_size[::-1] if len(global_size) else [1], local_size
|
||||
|
||||
def render_for(self, expr:str, _min:int, _max:int) -> str:
|
||||
def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
|
||||
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
|
||||
|
||||
def render_conditional(self, cond:str, x:str, y:str) -> str:
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from typing import TypeVar, Type, Any
|
||||
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
|
||||
from collections import defaultdict, deque
|
||||
from typing import TypeVar, Type, Any, Dict, Deque, Tuple
|
||||
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
|
||||
|
||||
_T = TypeVar("_T")
|
||||
class RawBuffer: # pylint: disable=abstract-method
|
||||
def __init__(self, size:int, dtype:DType, buf:Any=None):
|
||||
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
|
||||
self.size: int = size
|
||||
self.dtype: DType = dtype
|
||||
self._buf = buf
|
||||
self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
|
||||
self._memsz: int = size*dtype.itemsize
|
||||
self._allocator = allocator
|
||||
GlobalCounters.mem_used += self._memsz
|
||||
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
|
||||
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
|
||||
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
|
||||
def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
|
||||
@property
|
||||
def key(self): return (self.size, self.dtype.key)
|
||||
@@ -39,7 +42,7 @@ class RawBufferMapped(RawBufferCopyIn):
|
||||
|
||||
# this one is simple enough that i moved it out of the runtimes
|
||||
class RawMallocBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
|
||||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
@@ -66,3 +69,43 @@ class RawConst(RawBuffer): # pylint: disable=abstract-method
|
||||
|
||||
def buf_is_kernel_arg(x) -> bool:
|
||||
return x.realized is not None and x.realized.__class__ is not RawConst
|
||||
|
||||
class LRUAllocator:
|
||||
def __init__(self, dev_memsz=(4<<30)):
|
||||
self.epoch = 0
|
||||
self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
|
||||
self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
|
||||
self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, split by type and size, newest first.
|
||||
self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
|
||||
def __del__(self):
|
||||
for v in self.cached_buffers.values():
|
||||
for buf, _ in v: self._free_buffer(buf)
|
||||
def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
|
||||
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
|
||||
return rawbufs.popleft()[0]
|
||||
def _alloc_buffer(self, size, dtype, device, **kwargs):
|
||||
self.free_space[device] -= size*dtype.itemsize
|
||||
while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers.
|
||||
bucket, epoch = self.aging_order[device].popleft()
|
||||
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
|
||||
newbuf = self._do_alloc(size, dtype, device, **kwargs)
|
||||
self.buffer_info[newbuf] = (size, dtype, device)
|
||||
return newbuf
|
||||
def _free_buffer(self, buf_to_free):
|
||||
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
|
||||
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
|
||||
self.buffer_info.pop(buf_to_free)
|
||||
self._do_free(buf_to_free)
|
||||
def alloc(self, size, dtype, device='0', **kwargs):
|
||||
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
|
||||
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
|
||||
def free(self, buf): # free() only caches the buffer; it may actually be released later if an allocation runs out of memory.
|
||||
self.epoch += 1
|
||||
size, dtype, device = self.buffer_info[buf]
|
||||
self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
|
||||
self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
|
||||
GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
|
||||
def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
|
||||
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # The key under which identical buffers are cached and later reused.
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
|
||||
def _do_free(self, buf): pass
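From a backend's point of view, free() parks the buffer in cached_buffers and the next alloc() with the same key hands it straight back; only memory pressure triggers a real release. A toy sketch with a hypothetical in-memory backend (ToyBuf and ToyDType are illustrative stand-ins, not tinygrad types):

from collections import namedtuple
from tinygrad.runtime.lib import LRUAllocator

ToyDType = namedtuple("ToyDType", ["itemsize", "name"])   # stand-in for DType, itemsize is all we need
f32 = ToyDType(4, "float")

class ToyBuf:                                              # identity-hashable stand-in for a device pointer
  def __init__(self, nbytes): self.nbytes = nbytes

class ToyAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs): return ToyBuf(size*dtype.itemsize)
  def _do_free(self, buf): pass

alloc = ToyAllocator(dev_memsz=1<<20)
a = alloc.alloc(256, f32)     # fresh allocation goes through _do_alloc
alloc.free(a)                 # not released, just cached under (device, size, dtype)
b = alloc.alloc(256, f32)     # same key -> the cached buffer is handed back
assert a is b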
|
||||
|
||||
@@ -74,8 +74,8 @@ class ClangProgram:
|
||||
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
|
||||
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
|
||||
else:
|
||||
self.fxn(*[x._buf for x in args])
|
||||
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
|
||||
if wait: return time.monotonic()-st
|
||||
|
||||
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict"))
|
||||
renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int"))
|
||||
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
|
||||
|
||||
@@ -2,9 +2,9 @@ import subprocess, time, re, hashlib, tempfile, os, functools
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from pycuda.compiler import compile as cuda_compile # type: ignore
|
||||
from tinygrad.helpers import DEBUG, getenv, colored
|
||||
from tinygrad.helpers import DEBUG, getenv, colored, fromimport
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
|
||||
@@ -47,8 +47,12 @@ if getenv("CUDACPU", 0) == 1:
|
||||
else:
|
||||
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
|
||||
import pycuda.driver as cuda # type: ignore
|
||||
class CUDAAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
|
||||
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
|
||||
CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory())
|
||||
class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
|
||||
def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore
|
||||
def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
|
||||
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
|
||||
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
|
||||
|
||||
@@ -91,5 +95,5 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
__device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
|
||||
__device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
|
||||
};
|
||||
"""))
|
||||
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
|
||||
""")) if not getenv("PTX") else fromimport("tinygrad.codegen.assembly_ptx", "uops_to_ptx_asm")
|
||||
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
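The new assembly path is opt-in: setting PTX=1 swaps the C-style CUDA renderer for uops_to_ptx_asm when ops_cuda is imported. A hedged sketch of how that would be driven (assumes a CUDA-capable setup and that CUDA=1 selects the device in the usual environment-variable way):

import os
os.environ["PTX"] = "1"                   # must be set before tinygrad.runtime.ops_cuda is imported
os.environ["CUDA"] = "1"                  # assumption: pick the CUDA backend via the usual env var
from tinygrad.tensor import Tensor

print((Tensor.randn(4, 4) + 1).numpy())   # kernels for this run are emitted as PTX text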
|
||||
|
||||
@@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore
|
||||
from typing import Optional, List
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
|
||||
@@ -17,30 +17,36 @@ ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
|
||||
if DEBUG >= 5:
|
||||
early_exec = fromimport("extra.helpers", "enable_early_exec")()
|
||||
|
||||
class CLAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs):
|
||||
if isinstance(dtype, ImageDType):
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
|
||||
return buf
|
||||
|
||||
class _CL:
|
||||
def __init__(self):
|
||||
cl_platforms = cl.get_platforms()
|
||||
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if len(y)]
|
||||
self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
|
||||
self.cl_platform = self.devices[0].platform
|
||||
def post_init(self, device=None):
|
||||
platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)]
|
||||
self.cl_platform = cl.get_platforms()[getenv('CL_PLATFORM', 0)]
|
||||
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])]
|
||||
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
|
||||
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
|
||||
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
|
||||
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
|
||||
def synchronize(self):
|
||||
for q in self.cl_queue: q.finish()
|
||||
CL = _CL()
|
||||
CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None
|
||||
if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
|
||||
|
||||
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
if isinstance(dtype, ImageDType):
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
|
||||
super().__init__(size, dtype, buf)
|
||||
|
||||
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
|
||||
@@ -80,7 +86,7 @@ class CLProgram:
|
||||
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
|
||||
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
|
||||
if wait:
|
||||
e.wait()
|
||||
@@ -91,9 +97,8 @@ class CLProgram:
|
||||
return None
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
|
||||
|
||||
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), renderer, CLProgram, CL.synchronize)
|
||||
|
||||
@@ -3,7 +3,7 @@ import ctypes, functools
|
||||
import extra.hip_wrapper as hip
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
|
||||
@@ -14,9 +14,14 @@ if DEBUG >= 5:
|
||||
|
||||
# The default HIP stream is used for everything.
|
||||
|
||||
class HIPAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return hip.hipMalloc(size * dtype.itemsize)
|
||||
def _do_free(self, buf): hip.hipFree(buf)
|
||||
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
|
||||
HIPAlloc = HIPAllocator(hip.hipGetDeviceProperties(hip.hipGetDevice()).totalGlobalMem)
|
||||
|
||||
class RawHIPBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype): super().__init__(size, dtype, hip.hipMalloc(size * dtype.itemsize))
|
||||
def __del__(self): hip.hipFree(self._buf)
|
||||
def __init__(self, size, dtype): super().__init__(size, dtype, allocator=HIPAlloc)
|
||||
def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, 0)
|
||||
def _copyout(self, x:np.ndarray): hip.hipMemcpy_dtoh(x.ctypes.data, self._buf, self.size * self.dtype.itemsize)
|
||||
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
|
||||
import os, subprocess, pathlib, functools
|
||||
import os, subprocess, pathlib, functools, ctypes
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
from typing import List, Any
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferMapped
|
||||
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
|
||||
|
||||
METAL_XCODE = getenv("METAL_XCODE")
|
||||
|
||||
class MetalAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
|
||||
def _do_free(self, buf): buf.release()
|
||||
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
|
||||
|
||||
class _METAL:
|
||||
def __init__(self):
|
||||
self.mtl_buffers_in_flight: List[Any] = []
|
||||
self.device = Metal.MTLCreateSystemDefaultDevice()
|
||||
self.mtl_queue = self.device.newCommandQueue()
|
||||
self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
|
||||
# TODO: is there a better way to do this?
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
||||
@@ -23,10 +29,8 @@ METAL = _METAL()
|
||||
|
||||
class RawMetalBuffer(RawBufferMapped):
|
||||
def __init__(self, size:int, dtype:DType):
|
||||
super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
|
||||
def __del__(self):
|
||||
self._buf.release()
|
||||
super().__del__()
|
||||
assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
|
||||
super().__init__(size, dtype, allocator=METAL.allocator)
|
||||
def _buffer(self):
|
||||
METAL.synchronize()
|
||||
return self._buf.contents().as_buffer(self._buf.length())
|
||||
@@ -64,7 +68,10 @@ class MetalProgram:
|
||||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(self.pipeline_state)
|
||||
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
|
||||
for i,a in enumerate(bufs):
|
||||
if isinstance(a, RawMetalBuffer): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
|
||||
elif isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i)
|
||||
else: raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}")
|
||||
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
@@ -74,7 +81,7 @@ class MetalProgram:
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
|
||||
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
|
||||
inverse_type_map = {v:k for k,v in type_map.items()}

torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{

@@ -1,7 +1,7 @@
import numpy as np
import functools
from wgpu.utils._device import get_default_device # type: ignore
from tinygrad.runtime.lib import RawBufferCopyIn
from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import Compiled
from tinygrad.codegen.linearizer import LinearizerOptions
@@ -9,32 +9,37 @@ from tinygrad.renderer.cstyle import uops_to_cstyle
from tinygrad.renderer.wgsl import WGSLLanguage
import wgpu # type: ignore

device = get_default_device()
wgpu_device = get_default_device()

class WebGPUProgram:
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg)
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
def __call__(self, global_size, local_size, *bufs, wait=False):
assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
command_encoder = device.create_command_encoder()
bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)
compute_pipeline = wgpu_device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
command_encoder = wgpu_device.create_command_encoder()
compute_pass = command_encoder.begin_compute_pass()
compute_pass.set_pipeline(compute_pipeline)
compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
compute_pass.dispatch_workgroups(*global_size) # x y z
compute_pass.end()
device.queue.submit([command_encoder.finish()])
wgpu_device.queue.submit([command_encoder.finish()])

class RawWebGPUAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return wgpu_device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
WebGPUAlloc = RawWebGPUAllocator(wgpu_device.limits['max_buffer_size'])

class RawWebGPUBuffer(RawBufferCopyIn):
def __init__(self, size:int, dtype:DType):
assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC))
def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64,dtypes.double], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, allocator=WebGPUAlloc)
def _copyin(self, x:np.ndarray): wgpu_device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore

renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram)
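A hedged sketch of what the allocator change buys: _cached_bufkey keys cached buffers on (device, byte size) only, so a freed buffer can be reused by a later allocation of the same byte size even with a different dtype, assuming the base LRUAllocator keeps freed buffers under that key.

a = RawWebGPUBuffer(256, dtypes.float32)  # a 1024-byte storage buffer
del a                                     # assumed to return the wgpu buffer to the LRU cache rather than destroy it
b = RawWebGPUBuffer(256, dtypes.int32)    # same key (device, 1024), so the cached buffer can be handed back out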

@@ -59,7 +59,7 @@ class View(ViewInternal):
# generate an expression if you have a single idx variable
def expr_node(self, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else []
acc = 1
for d,s in reversed(self.shape_strides):
ret.append(((idx//acc)%d)*s)
@@ -69,7 +69,7 @@ class View(ViewInternal):
# generate an expression if you have a variable or expression for each index
def expr_idxs(self, idxs) -> Node:
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
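With that change a View whose offset is itself a Node (rather than an int) flows straight into the sum. A minimal sketch of the expression this builds for shape (4, 4), strides (4, 1) and a symbolic offset, using only the Variable helpers already used above:

start = Variable("start", 0, 64)                   # a symbolic offset instead of a plain int
idxs = [Variable("i", 0, 3), Variable("j", 0, 3)]
expr = Variable.sum([start] + [idx*st for idx, st in zip(idxs, (4, 1))])
print(expr.render())                               # renders to something like (start+(i*4)+j)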

@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
@@ -162,7 +162,7 @@ class ShapeTracker:
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable):
ret[idxs.index(this_dim)] = 1

@@ -65,6 +65,7 @@ class Node:
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if self == b: return NumNode(1)
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
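The two new Node-by-Node cases can be exercised directly: a node divided by itself is 1, and a non-negative node divided by a provably larger node is 0. A small sketch:

a = Variable("a", 0, 10)
b = Variable("b", 16, 32)
a // a  # NumNode(1): identical nodes
a // b  # NumNode(0): (b > a).min > 0 and a.min >= 0, so the quotient is always 0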
@@ -262,6 +263,14 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)

def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
if isinstance(n, (int, NumNode)): return int(n)
if isinstance(n, Variable): return var_vals[n]
if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
raise NotImplementedError(n)
@functools.lru_cache(maxsize=None)
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
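sym_infer walks a symbolic expression and substitutes concrete values for its Variables; plain ints (and NumNodes) pass straight through. A minimal sketch:

i = Variable("i", 0, 10)
j = Variable("j", 0, 10)
expr = i*3 + j                        # a SumNode holding a MulNode and a Variable
print(sym_infer(expr, {i: 2, j: 5}))  # 3*2 + 5 = 11
print(sym_infer(4, {}))               # 4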

render_python: Dict[Type, Callable] = {

@@ -274,12 +274,20 @@ class Tensor:
# - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
# is possible.
# - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
# - Fancy indexing and combined indexing are supported
# - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t. the Tensors passed in -> fancy indexing
# - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
# - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
# - The dims are reduced at sum_dim for each Tensor passed in
# - There's a special case where a permute is needed at the end:
# - if the first Tensor passed in (expand dims) is not at dim 0
# - and the following Tensors do not follow consecutively to the end of fancy indexing's dims (see the usage sketch after __getitem__ below)
def __getitem__(self, val):
def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
val = list(val) if isinstance(val, tuple) else [val]
if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape):
raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
orig_slices = list(val)
ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
@@ -290,6 +298,8 @@ class Tensor:
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
else:
orig_slices += [slice(None)] * (len(self.shape) - num_slices)
tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)]
orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices]
valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
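The start/stop/strides triple comes straight from Python's slice.indices, which clamps a slice to a dimension of a given size; a tiny illustration:

print(slice(None, None, -2).indices(5))  # (4, -1, -2): visits 4, 2, 0 in a dim of size 5
print(slice(1, 100).indices(5))          # (1, 5, 1): stop is clamped to the dimension size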
@@ -315,21 +325,46 @@ class Tensor:
final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
sliced_tensor = reshaped_tensor.shrink(final_slice)
final_shape = []
sub = [0] * len(tensor_found)
it_shape = iter(new_shape)
for i in orig_slices:
if isinstance(i, (int, slice)):
for i,s in enumerate(orig_slices):
if isinstance(s, (int, slice)):
dim_shape = next(it_shape)
if isinstance(i, slice): final_shape.append(dim_shape)
else: # i is None
if isinstance(s, slice): final_shape.append(dim_shape)
elif tensor_found:
for i_ in range(len(tensor_found)):
if tensor_found[i_][0] > i: sub[i_] -= 1
else: # s is None
final_shape.append(1)
return sliced_tensor.reshape(tuple(final_shape)) # Reshape
ret = sliced_tensor.reshape(tuple(final_shape)) # Reshape
if tensor_found: # Fancy/tensor indexing
for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1])
dim = [i[0] for i in tensor_found]
idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
max_dim = max(idx, key=lambda i: i.ndim).ndim
idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))]
new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1))
arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1))
ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0])
for idx_,d in zip(idx[1:],sum_dim[1:]):
new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim))
arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
ret = ((new_idx == arange) * ret).sum(d)
if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
order = list(range(ret.ndim))
order = order[dim[0]:dim[0]+idx[0].ndim] + order[:dim[0]] + order[dim[0]+idx[0].ndim:]
ret = ret.permute(order=order)
return ret
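A hedged usage sketch of the tensor indexing this enables (the values are what the arange-compare-and-sum construction should produce; exact dtype and formatting may differ):

t = Tensor.arange(12).reshape(3, 4)
rows = Tensor([2, 0])
print(t[rows].numpy())               # rows 2 and 0 -> [[8, 9, 10, 11], [0, 1, 2, 3]]
print(t[:, Tensor([1, 3])].numpy())  # columns 1 and 3 of every row -> [[1, 3], [5, 7], [9, 11]]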

def gather(self, idx, dim):
idx = (idx < 0).where(idx+self.shape[dim], idx) # Turn neg idx pos
new_self = self.reshape(*self.shape[:dim+1], *[1]*idx.ndim, *self.shape[dim+1:])
arange = Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim, self.shape[dim], *[1]*idx.ndim, *[1]*(self.ndim-dim-1))
new_idx = idx.reshape(*[1]*dim, 1, *idx.shape, *[1]*(self.ndim-dim-1))
return (new_self * (arange == new_idx)).sum(dim)
def gather(self: Tensor, idx: Tensor, dim: int):
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dims of idx.shape must be no larger than the matching dims of self.shape"
if dim < 0: dim += self.ndim
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
permarg = list(range(self.ndim))
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
return ((idx == Tensor.arange(self.shape[dim])) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
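A small usage sketch of the rewritten gather, which mirrors torch.gather semantics (for dim=1, out[i][j] = self[i][idx[i][j]]); the values shown are the expected result, formatting may differ:

t = Tensor([[1, 2], [3, 4]])
idx = Tensor([[0, 0], [1, 0]])
print(t.gather(idx, dim=1).numpy())  # [[1, 1], [4, 3]]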

def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim