diff --git a/examples/llama.py b/examples/llama.py
index 36cd874998..0211c0b274 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -40,6 +40,11 @@ def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
+def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+  bs, seqlen, n_kv_heads, head_dim = x.shape
+  if n_rep == 1: return x
+  return x[:, :, :, None, :].expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+
 class RMSNorm:
   def __init__(self, dim, eps=1e-6):
     self.eps = eps
@@ -50,14 +55,22 @@ class RMSNorm:
     return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
 
 class Attention:
-  def __init__(self, dim, n_heads, linear=Linear):
-    self.wq, self.wk, self.wv, self.wo = [linear(dim, dim, bias=False) for _ in range(4)]
+  def __init__(self, dim, n_heads, n_kv_heads, linear=Linear):
     self.n_heads = n_heads
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
     self.head_dim = dim // n_heads
+    self.n_rep = self.n_heads // self.n_kv_heads
+
+    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
+    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
   def prepare_attention(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-    xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)]
+    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
+    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
+    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
     return xq, xk, xv
 
@@ -74,6 +87,8 @@ class Attention:
 
     # save the cache
     self.cache_k, self.cache_v = keys.realize(), values.realize()
+
+    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     return Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
 
 # NOTE: this is not called
@@ -98,8 +113,8 @@ class FeedForward:
     return self.w2(self.w1(x).silu() * self.w3(x))
 
 class TransformerBlock:
-  def __init__(self, dim, multiple_of, n_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
-    self.attention = Attention(dim, n_heads, linear)
+  def __init__(self, dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
+    self.attention = Attention(dim, n_heads, n_kv_heads, linear)
     self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
     self.attention_norm = RMSNorm(dim, norm_eps)
     self.ffn_norm = RMSNorm(dim, norm_eps)
@@ -125,8 +140,8 @@ class TransformerBlock:
     return self._post(x, output) if mask is None else self.post(x, output)
 
 class Transformer:
-  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
-    self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
+  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
+    self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
     self.norm = RMSNorm(dim, norm_eps)
     self.tok_embeddings = Embedding(vocab_size, dim)
     self.output = linear(dim, vocab_size, bias=False)
@@ -173,11 +188,10 @@ MODEL_PARAMS = {
       "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
       "files": 2,
     },
-#    # 70B is disabled because we do not yet implement n_kv_heads argument
-#    "70B": {
-#      "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
-#      "files": 8,
-#    },
+    "70B": {
+      "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "files": 8,
+    },
   },
 }
 
@@ -277,7 +291,7 @@ if __name__ == "__main__":
   parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
   parser.add_argument('--timing', action='store_true', help="Print timing per token")
   parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
-  parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2")
+  parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2")
   parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
   parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
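# Illustration (not part of the patch): a numpy sketch of the grouped-query attention expansion
# that the new repeat_kv performs -- each of the n_kv_heads key/value heads is repeated n_rep
# times so the cached keys/values line up with the n_heads query heads (70B: n_heads=64,
# n_kv_heads=8, so n_rep=8). Shapes below are made up for illustration.
import numpy as np
bs, seqlen, n_kv_heads, head_dim, n_rep = 1, 3, 2, 4, 8
x = np.arange(bs*seqlen*n_kv_heads*head_dim, dtype=np.float32).reshape(bs, seqlen, n_kv_heads, head_dim)
out = np.broadcast_to(x[:, :, :, None, :], (bs, seqlen, n_kv_heads, n_rep, head_dim)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
assert out.shape == (bs, seqlen, n_kv_heads * n_rep, head_dim)
assert (out[:, :, 0] == out[:, :, n_rep - 1]).all()  # every copy of kv head 0 is identical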
diff --git a/extra/assembly/assembly_ptx.py b/extra/assembly/assembly_ptx.py
deleted file mode 100644
index 8781b701f5..0000000000
--- a/extra/assembly/assembly_ptx.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import struct
-from tinygrad.codegen.assembly import AssemblyCodegen
-from tinygrad.codegen.linearizer import UOps
-from tinygrad.helpers import dtypes
-from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.runtime.ops_cuda import arch
-
-dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "u16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32"}
-def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
-
-# https://docs.nvidia.com/cuda/parallel-thread-execution/#
-class PTXCodegen(AssemblyCodegen):
-  #supports_constant_folding: bool = True
-
-  def specialize(self, asm):
-    ins = [".version 8.2", ".target " + arch(), ".address_size 64",
-           f".visible .entry test({', '.join(f'.param .u64 buf{i}' for i in range(len(self.bufs)))}) {{"]
-
-    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
-           BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", BinaryOps.CMPEQ: "setp.eq", UnaryOps.SQRT: "sqrt.approx",
-           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
-           TernaryOps.MULACC: "fma.rn"}
-
-    for uop, out, vin, arg in asm:
-      if uop == UOps.DEFINE_REGISTER:
-        ins.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
-      elif uop == UOps.DEFINE_LOCAL:
-        ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
-      elif uop == UOps.SPECIAL:
-        if arg.startswith('buf'):
-          ins.append(f"ld.param.u64 {out}, [{arg}];")
-          # TODO: is this needed?
- #ins.append(f"cvta.to.global.u64 {out}, {out};") - elif arg.startswith('gid'): - ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};") - elif arg.startswith('lid'): - ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};") - elif uop == UOps.ALU: - if arg == BinaryOps.MUL and out.dtype == dtypes.bool: - ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};") - else: - otype = vin[0].dtype if arg in [BinaryOps.CMPEQ, BinaryOps.CMPLT] else out.dtype - ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};") - elif uop == UOps.LOAD: - ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[out.dtype]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") - elif uop == UOps.STORE: - ins.append(f"st.{arg[1]}.{dtype_to_nvtype[vin[1].dtype]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};") - elif uop == UOps.CAST: - if vin[0].dtype == dtypes.bool: - ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, 0f3F800000, 0f00000000, {vin[0]};") - else: - ins.append(f"cvt.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[vin[0].dtype]} {out}, {vin[0]};") - elif uop == UOps.CONST: - ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else arg};") - elif uop == UOps.LABEL: - ins.append(f"{arg}:") - elif uop == UOps.COND_BRANCH: - ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};") - - ins += ["ret;", "}"] - return "test", '\n'.join(ins) diff --git a/extra/onnx.py b/extra/onnx.py index 7477771570..ddb5d0fd29 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -21,7 +21,9 @@ def safe_numpy(t) -> np.ndarray: if t not in numpy_cache: if DEBUG >= 1: print("numpy cache miss", t) - numpy_cache[t] = t.numpy() + tmp = t.numpy() + numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1) + assert len(numpy_cache[t].shape) > 0 return numpy_cache[t] onnx_ops = importlib.import_module('extra.onnx_ops') @@ -92,7 +94,7 @@ def get_run_onnx(onnx_model: ModelProto): if inp.name in tensors: continue tmp=inp.type.optional_type.elem_type.tensor_type if inp.type.HasField("optional_type") else (inp.type.sequence_type.elem_type.tensor_type if inp.type.HasField("sequence_type") else inp.type.tensor_type) shape = shape_to_tuple(tmp.shape) - if len(shape) >= 1 and shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size + if len(shape) >= 1: shape = tuple([x if x != 0 else 1 for x in shape]) # replace all dynamic dims with 1 for now if inp.name in inputs: if isinstance(inputs[inp.name], Tensor): input_tensors[inp.name] = inputs[inp.name] @@ -183,6 +185,8 @@ def get_run_onnx(onnx_model: ModelProto): steps = safe_numpy(inp[4])[0] if len(inp) > 4 else 1 starts, ends = safe_numpy(starts.cast(dtypes.int32)).tolist(), safe_numpy(ends.cast(dtypes.int32)).tolist() # TODO: when indexing is added use that for i,axis in enumerate(axes.tolist()): + assert axis % 1 == 0 + axis = int(axis) arg[axis] = (starts[i] if starts[i] >= 0 else inp[0].shape[axis]+starts[i], ends[i] if ends[i] >= 0 else inp[0].shape[axis]+ends[i]) ret = inp[0].slice(arg=arg) elif n.op_type == "Shrink": diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 0129279d8f..571ae103aa 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -1,9 +1,10 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import prod, dtypes from extra.onnx import safe_numpy +from onnx.helper import 
diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index 0129279d8f..571ae103aa 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -1,9 +1,10 @@
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod, dtypes
 from extra.onnx import safe_numpy
+from onnx.helper import tensor_dtype_to_np_dtype
 import numpy as np
 import functools
-from typing import Union, Tuple
+from typing import Union, Tuple, Optional
 import math
 
 def Unsqueeze(data, axes):
@@ -201,7 +202,7 @@ def Tile(input, repeats):
   final_shape = [r*s for r,s in zip(repeats_, input.shape)]
   return input.reshape(new_shape).expand(expand_shape).reshape(final_shape)
 
-def Range(start, limit, delta): return Tensor.arange(safe_numpy(start)[0], safe_numpy(limit)[0], step=safe_numpy(delta)[0])
+def Range(start, limit, delta): return Tensor.arange(*[safe_numpy(x)[0].item() for x in (start, limit, delta)])
 
 def Where(condition:Tensor,X:Tensor,Y:Tensor): return condition.where(X, Y).cast(X.dtype)
 def And(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
@@ -218,10 +219,7 @@ def ConstantOfShape(input, value:Tensor=None):
   shape = [int(x) for x in safe_numpy(input)]
   return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
 
-# this is obviously wrong, but since we don't have types, it's better than nothing
-def Cast(input, to):
-  print(f"WARNING: attempting to cast to {to}")
-  return input
+def Cast(input, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
 
 # NOTE: since we only have one type, this is valid!
 def CastLike(input, target_type):
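# Illustration (not part of the patch), assuming onnx>=1.13 for onnx.helper.tensor_dtype_to_np_dtype:
# the ONNX Cast attribute `to` is a TensorProto dtype enum, and the new Cast resolves it to a
# tinygrad dtype by going through the corresponding numpy dtype.
from onnx import TensorProto
from onnx.helper import tensor_dtype_to_np_dtype
from tinygrad.helpers import dtypes
assert dtypes.from_np(tensor_dtype_to_np_dtype(TensorProto.FLOAT16)) == dtypes.float16
assert dtypes.from_np(tensor_dtype_to_np_dtype(TensorProto.INT32)) == dtypes.int32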
@@ -263,3 +261,70 @@ def OneHot(indices, depth, values, axis=-1):
 
 def Floor(x:Tensor): return x.floor()
 def Ceil(x:Tensor): return x.ceil()
+
+def EmbedLayerNormalization(input_ids, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
+  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
+  assert (segment_ids is None) is (segment_embedding is None)
+  assert (mask is None) is (mask_index_type is None)
+  assert mask is None, "functionality not supported yet" # TODO
+  input_shape = input_ids.shape
+  bsz, seq_length = input_shape[0], input_shape[1]
+  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
+  vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
+
+  def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
+    vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
+    return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
+
+  # bert embedding layer
+  if epsilon is None: epsilon = 1e-12
+  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
+  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
+  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
+  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
+
+  embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
+  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
+  return out, None, embedding_sum
+
+def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
+  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
+  assert num_heads is not None # required
+  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
+  assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params
+  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
+
+  if unidirectional: # gpt-style
+    assert hidden_size == v_hidden_size
+    xqkv = input.linear(weights, bias)
+    xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
+  else: # bert-style
+    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
+    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else None
+    xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
+  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
+
+  if past is not None:
+    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
+  present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
+
+  def attn(query, key, value, attn_mask):
+    query_length, key_length = query.shape[-2], key.shape[-2]
+    cdim = max(query_length, key_length) + 1
+    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
+    # This is where Tensor.scaled_dot_product_attention differs:
+    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
+    return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value
+
+  bsz, _, seq_len, _ = xq.shape
+  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
+  return out, present
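# Illustration (not part of the patch): numpy sketch of the causal_mask slice inside attn().
# With query_length=2 new tokens against key_length=5 total keys, rows 3:5 of a 6x6
# lower-triangular matrix are taken, so each new token attends to every earlier key plus itself.
import numpy as np
query_length, key_length = 2, 5
cdim = max(query_length, key_length) + 1
causal_mask = np.tril(np.ones((cdim, cdim), dtype=bool))[key_length - query_length:key_length, :key_length]
print(causal_mask.astype(int))
# [[1 1 1 1 0]
#  [1 1 1 1 1]]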
+
+def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
+  if epsilon is None: epsilon=1e-12
+  x = input + skip + bias
+  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
+
+def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
+  x = x + bias
+  return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py
index 0020610af2..1c4aadd70a 100644
--- a/test/external/external_model_benchmark.py
+++ b/test/external/external_model_benchmark.py
@@ -1,13 +1,14 @@
-import csv
-import pathlib
-import time
-import onnx
+import csv, pathlib, time, numpy as np
+from os import getenv
 import torch
 torch.set_num_threads(1)
+import onnx
+from onnx.helper import tensor_dtype_to_np_dtype
+import onnxruntime as ort
 from onnx2torch import convert
 from extra.utils import download_file
 from extra.onnx import get_run_onnx
-from tinygrad.helpers import OSX
+from tinygrad.helpers import OSX, DEBUG
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import Device
 
@@ -16,6 +17,7 @@ MODELS = {
   "openpilot": "https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx",
   "efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
   "shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
+  "commavq": "https://github.com/commaai/commavq/raw/master/models/gpt2m.onnx",
 
   # broken in torch MPS
   #"zfnet": "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx",
@@ -29,6 +31,7 @@ MODELS = {
 
 CSV = {}
 open_csv = None
+torch.manual_seed(1)
 
 def benchmark(mnm, nm, fxn):
   tms = []
@@ -36,26 +39,31 @@ def benchmark(mnm, nm, fxn):
     st = time.perf_counter_ns()
     ret = fxn()
     tms.append(time.perf_counter_ns() - st)
-  print(f"{m:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
+  print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms")
   CSV[nm] = min(tms)*1e-6
   return min(tms), ret
 
 #BASE = pathlib.Path(__file__).parent.parent.parent / "weights" / "onnx"
 BASE = pathlib.Path("/tmp/onnx")
-def benchmark_model(m):
+def benchmark_model(m, validate_outs=False):
   global open_csv, CSV
   CSV = {"model": m}
 
   fn = BASE / MODELS[m].split("/")[-1]
   download_file(MODELS[m], fn)
   onnx_model = onnx.load(fn)
-
+  output_names = [out.name for out in onnx_model.graph.output]
   excluded = {inp.name for inp in onnx_model.graph.initializer}
   input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded}
-  np_inputs = {k:torch.randn(shp).numpy() for k,shp in input_shapes.items()}
-  assert len(input_shapes) < 20
+  input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
+  #input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
+  np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
+  assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}"
 
-  for device in ["METAL" if OSX else "GPU", "CLANG"]:
+  # print input names
+  if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
+
+  for device in ["METAL" if OSX else "GPU", "CLANG"]: # + (["CUDA"] if torch.cuda.is_available() else []):
     Device.DEFAULT = device
     inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
     tinygrad_model = get_run_onnx(onnx_model)
@@ -67,19 +75,55 @@ def benchmark_model(m):
     benchmark(m, f"tinygrad_{device.lower()}_jit", lambda:
{k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) del inputs, tinygrad_model, tinygrad_jitted_model - torch_model = convert(onnx_model) - torch_inputs = [torch.tensor(x) for x in np_inputs.values()] - benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs)) + try: + torch_model = convert(onnx_model) + torch_inputs = [torch.tensor(x) for x in np_inputs.values()] + benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs)) - torch_device = "mps" if OSX else "cuda" - torch_mps_model = torch_model.to(torch_device) - torch_mps_inputs = [x.to(torch_device) for x in torch_inputs] - benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs)) + torch_device = "mps" if OSX else "cuda" + torch_mps_model = torch_model.to(torch_device) + torch_mps_inputs = [x.to(torch_device) for x in torch_inputs] + benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs)) + except NotImplementedError: + print(f"{m:16s}onnx2torch doesn't support this model") + + # bench onnxruntime + ort_options = ort.SessionOptions() + ort_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + ort_options.log_severity_level = 3 # no warnings + for backend in ["CPU", "CUDA" if not OSX else "CoreML"]: # https://onnxruntime.ai/docs/execution-providers/ + provider = backend+"ExecutionProvider" + if provider not in ort.get_available_providers(): continue + ort_sess = ort.InferenceSession(str(fn), ort_options, [provider]) + benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs)) + del ort_sess + + if validate_outs: + rtol, atol = 8e-4, 8e-4 # tolerance for fp16 models + inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} + tinygrad_model = get_run_onnx(onnx_model) + tinygrad_out = tinygrad_model(inputs) + + ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"]) + onnx_out = ort_sess.run(output_names, np_inputs) + onnx_out = dict([*[(name,x) for name, x in zip(output_names, onnx_out)]]) + + assert_allclose(tinygrad_out, onnx_out, rtol=rtol, atol=atol) + print(f"{m:16s}outputs validated with rtol={rtol:.1e}, atol={atol:.1e}") if open_csv is None: open_csv = csv.DictWriter(open('onnx_inference_speed.csv', 'w', newline=''), fieldnames=list(CSV.keys())) open_csv.writeheader() open_csv.writerow(CSV) +def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5): + assert len(tiny_out) == len(onnx_out) and tiny_out.keys() == onnx_out.keys() + for k in tiny_out.keys(): + tiny_v, onnx_v = tiny_out[k], onnx_out[k] + if tiny_v is None: assert tiny_v == onnx_v + else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}") + if __name__ == "__main__": - for m in MODELS: benchmark_model(m) + if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), True) + else: + for m in MODELS: benchmark_model(m, True) diff --git a/test/external/external_test_allocator_on_models.py b/test/external/external_test_allocator_on_models.py new file mode 100644 index 0000000000..36717945d3 --- /dev/null +++ b/test/external/external_test_allocator_on_models.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +import unittest, gc +import numpy as np +from tinygrad.tensor import Tensor +from tinygrad.state import get_parameters, get_state_dict +from tinygrad.ops import GlobalCounters, LazyOp, LoadOps +from tinygrad.runtime.lib import RawBuffer, LRUAllocator +from tinygrad.helpers import dtypes, prod +from tinygrad.lazy import Device + +from 
examples.llama import Transformer + +ALLOCATED_DEV_BUFS = 0 +class FakeDeviceBuffer(): + def __init__(self, sz, dt, device): + self.id = 1 + self.size = sz + self.dtype = dt + self.device = device + + global ALLOCATED_DEV_BUFS + ALLOCATED_DEV_BUFS += 1 +class FakeAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device) + def _do_free(self, buf): + buf.id -= 1 + assert buf.id == 0, f"Free should be called once, but {buf.id}" + +FAKE_GLOBAL_ALLOCATOR = None +class FakeBuffer(RawBuffer): + def __init__(self, size, dtype, device='0'): + global FAKE_GLOBAL_ALLOCATOR + super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device}) + assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size." + @classmethod + def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs) + def toCPU(self): return np.empty(self.size, dtype=self.dtype.np) +class FakeProgram: + def __init__(self, name:str, prg:str): pass + def __call__(self, global_size, local_size, *bufs, wait=False): pass + +def helper_test_correctness(gen, train): + from tinygrad.runtime.ops_gpu import CL, CLAllocator + old_alloc = CL.cl_allocator + CL.cl_allocator = CLAllocator(0) + no_alloc_result = train(*gen()).numpy() + Device[Device.DEFAULT].synchronize() + CL.cl_allocator = CLAllocator(512<<30) # Test cache correctness, so cache as much as possible, 512gb + for _ in range(4): + GlobalCounters.reset() + np.testing.assert_allclose(train(*gen()).numpy(), no_alloc_result, rtol=1e-3, atol=1e-5) + Device[Device.DEFAULT].synchronize() + assert len(CL.cl_allocator.cached_buffers) != 0, "Cache must be used" + CL.cl_allocator = old_alloc + +def __helper_test_alloc_count(gen, train): + was_alloc = ALLOCATED_DEV_BUFS + for _ in range(2): + train(*gen()) + return ALLOCATED_DEV_BUFS - was_alloc + +def helper_test_alloc_count(mm, gen, train): + global FAKE_GLOBAL_ALLOCATOR + backup_program = Device[Device.DEFAULT].runtime + backup_buffer = Device[Device.DEFAULT].buffer + Device[Device.DEFAULT].runtime = FakeProgram + Device[Device.DEFAULT].buffer = FakeBuffer + Device[Device.DEFAULT].method_cache.clear() + FAKE_GLOBAL_ALLOCATOR = FakeAllocator(16<<30) + new_allocs = __helper_test_alloc_count(gen, train) + Device[Device.DEFAULT].method_cache.clear() + FAKE_GLOBAL_ALLOCATOR = FakeAllocator(0) + old_allocs = __helper_test_alloc_count(gen, train) + print(f"{mm}: llama: old allocs count {old_allocs}, new allocs count {new_allocs}") + assert new_allocs < old_allocs, f"Hmm, doesn't cache work any more?" + Device[Device.DEFAULT].runtime = backup_program + Device[Device.DEFAULT].buffer = backup_buffer + FAKE_GLOBAL_ALLOCATOR = None + +def check_gc(): + if Device.DEFAULT == "GPU": + gc.collect() # Need to collect Tensors. 
+ from extra.introspection import print_objects + assert print_objects() == 0 + +# for speed +def derandomize(x): + if isinstance(x, LazyOp): + if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY + x.src = [derandomize(s) for s in x.src] + else: + x.op = derandomize(x.op) + return x + +def derandomize_model(model): + for p in get_parameters(model): + p.lazydata = derandomize(p.lazydata) + p.realize() + +class TestAllocators(unittest.TestCase): + @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") + def test_lru_allocator_tiny_llama(self): + old_type = Tensor.default_type + Tensor.default_type = dtypes.float16 + + args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000} + def __test(): + model = Transformer(**args_tiny) + derandomize_model(model) + def test(t): return model(t, 0).realize() + helper_test_correctness(lambda: (Tensor([[1,]]),), test) + __test() + Tensor.default_type = old_type + check_gc() + + @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") + def test_lru_allocator_tiny_llama_alloc_counts(self): + args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000} + def test_alloc_count(t): + model = Transformer(**args_tiny) + for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype)) + return model(t, 0).realize() + helper_test_alloc_count("llama", lambda: (Tensor([[2,]]),), test_alloc_count) + check_gc() + + @unittest.skip("huge for CI") + def test_stable_diffusion(self): + from examples.stable_diffusion import UNetModel + model = UNetModel() + derandomize_model(model) + def test(t, t2): return model(t, 801, t2).realize() + helper_test_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_allocators.py b/test/test_allocators.py new file mode 100644 index 0000000000..5480debbc4 --- /dev/null +++ b/test/test_allocators.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +import unittest +import numpy as np +from weakref import ref +from tinygrad.ops import GlobalCounters +from tinygrad.runtime.lib import RawBuffer, LRUAllocator +from tinygrad.helpers import dtypes, prod +from tinygrad.lazy import Device + +def check_gc(): + if Device.DEFAULT == "GPU": + from extra.introspection import print_objects + assert print_objects() == 0 + +class FakeDeviceBuffer(): + def __init__(self, sz, dt, device): + self.id = 1 + self.size = sz + self.dtype = dt + self.device = device + def __del__(self): + assert self.id == 0, "Should called _do_free() before" + +class FakeAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device) + def _do_free(self, buf): + buf.id -= 1 + assert buf.id == 0, f"Free should be called once, but {buf.id}" + +FAKE_GLOBAL_ALLOCATOR = None +class FakeBuffer(RawBuffer): + def __init__(self, size, dtype, device='0'): + global FAKE_GLOBAL_ALLOCATOR + super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device}) + assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size." 
+ @classmethod + def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs) + def toCPU(self): return np.empty(self.size, dtype=self.dtype.np) + +def alloc(allocator, size, dtype, **kwargs): + global FAKE_GLOBAL_ALLOCATOR + FAKE_GLOBAL_ALLOCATOR = allocator + buf = FakeBuffer(size, dtype, **kwargs) + assert buf.dtype == dtype and buf.size == size + FAKE_GLOBAL_ALLOCATOR = None + return buf + +def alloc_free_trace(allocator, size, dtype, **kwargs): + buf = alloc(allocator, size, dtype, **kwargs) + return ref(buf._buf) + +def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf + +class TestAllocators(unittest.TestCase): + def test_lru_allocator_reusage(self): + def test(): + lru_allocator = FakeAllocator(2048) + traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32) + assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached" + for _ in range(32): + def __test(): + buf = alloc(lru_allocator, 16, dtypes.float32) + assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused" + __test() + + usedbuf = alloc(lru_allocator, 16, dtypes.float32) + for _ in range(32): + def __test(): + buf = alloc(lru_allocator, 16, dtypes.float32) + assert usedbuf != buf, "Nobody should get used buffer" + __test() + assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated." + test() + check_gc() + + def test_lru_allocator_cache_free(self): + def test(): + lru_allocator = FakeAllocator(128) + refs = [] + for _ in range(32): + refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32)) + for sz in range(32): + alloc_free_trace(lru_allocator, sz, dtypes.float32) + assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)" + for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache" + test() + check_gc() + + def test_lru_allocator_multidevice(self): + def test(): + lru_allocator = FakeAllocator(256) + refs=[] + for i in range(8): + refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i))) + for i in range(64): + def __test(): + dev = str(i % 8) + buf = alloc(lru_allocator, 16, dtypes.float32, device=dev) + assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused" + __test() + for r in refs: assert r() is not None, "All refs should be cached" + test() + check_gc() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_dtype.py b/test/test_dtype.py index 8849f835d8..d451252104 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -91,6 +91,12 @@ class TestHalfDtype(unittest.TestCase): def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32) def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16) +@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"], "float64 is not supported by some backends") +class TestDoubleDtype(unittest.TestCase): + def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4]) + def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64) + def test_upcast_to_float64_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float64, target_dtype=dtypes.float64) + @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support 
int8") class TestInt8Dtype(unittest.TestCase): def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4]) @@ -107,8 +113,10 @@ class TestInt8Dtype(unittest.TestCase): def test_int8_upcast_int64(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int64, target_dtype=dtypes.int64) @unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently") + @unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap") def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) + @unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap") def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) @unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH") diff --git a/test/test_helpers.py b/test/test_helpers.py index deacd97a6a..18f02f06c0 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -1,5 +1,5 @@ import unittest -from tinygrad.helpers import Context, ContextVar +from tinygrad.helpers import Context, ContextVar, merge_dicts VARIABLE = ContextVar("VARIABLE", 0) @@ -106,5 +106,17 @@ with Context(VARIABLE=1): ... assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value." +class TestMergeDicts(unittest.TestCase): + def test_merge_dicts(self): + a = {"a": 1, "b": 2} + b = {"a": 1, "c": 3} + c = {} + d = {"a": 2, "b": 2} + assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3} + assert merge_dicts([a, c]) == a + assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3} + with self.assertRaises(AssertionError): + merge_dicts([a, d]) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/test_ops.py b/test/test_ops.py index db2690a7fa..c346b4cdbd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,7 +4,7 @@ import math import numpy as np import unittest from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv, IMAGE, DEBUG, CI +from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes from tinygrad.lazy import Device if CI: @@ -1127,21 +1127,37 @@ class TestOps(unittest.TestCase): n = (x < 0).where(x, 1).numpy() assert np.all(n == 1.) 
+ def test_slice_fancy_indexing(self): + # indices cannot have gradient + a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False) + b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False) + c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False) + d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False) + e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False) + i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]] + helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,b,c,d,e], lambda x: x[i,j,k,o,p]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[:,b,c,d,e], lambda x: x[:,j,k,o,p]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[:,b,c,d,:], lambda x: x[:,j,k,o,:]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,b,...], lambda x: x[i,j,...]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,...,e], lambda x: x[i,...,p]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[...,c,:,e], lambda x: x[...,k,:,p]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[1,:,10:11,d,0:2], lambda x: x[1,:,10:11,o,0:2]) + helper_test_op([(2,5,15,5,3,4)], lambda x: x[1,4,c,d,2], lambda x: x[1,4,k,o,2]) + helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)]) + helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])]) + def test_gather(self): - nda = np.random.randn(4,5,6,9,5).astype(np.float32) - ten = Tensor(nda, requires_grad=True) - tor = torch.tensor(nda, requires_grad=True) - c = np.random.randint(low=-4, high=4, size=[3,4,5]).astype(np.int32) - a = Tensor(c, requires_grad=False) - b = torch.tensor(c, requires_grad=False) - helper_test_op([], lambda: tor[b,:,:,:,:], lambda: ten.gather(a, dim=0)) - helper_test_op([], lambda: tor[:,b,:,:,:], lambda: ten.gather(a, dim=1)) - helper_test_op([], lambda: tor[:,:,b,:,:], lambda: ten.gather(a, dim=2)) - helper_test_op([], lambda: tor[:,:,:,b,:], lambda: ten.gather(a, dim=3)) - helper_test_op([], lambda: tor[:,:,:,:,b], lambda: ten.gather(a, dim=4)) - ta = Tensor(c, requires_grad=True) - tb = torch.tensor(c, requires_grad=True, dtype=torch.float32) - self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError + # indices cannot have gradient + # indices cannot be negative (torch gather) + b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False) + a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) + helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) + helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1)) + helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2)) + helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) + self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, 
AssertionError)) + self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError)) def test_scaled_product_attention(self): helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z)) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py new file mode 100644 index 0000000000..b6be67bcd6 --- /dev/null +++ b/test/test_symbolic_ops.py @@ -0,0 +1,113 @@ +import unittest +from tinygrad.shape.symbolic import Variable +from tinygrad.helpers import getenv, CI +from tinygrad.tensor import Tensor, Device +import numpy as np + +@unittest.skipIf(getenv("ARM64"), "ARM64 is not supported") +@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG"], f"{Device.DEFAULT} is not supported") +class TestSymbolicOps(unittest.TestCase): + def test_plus1(self): + def f(a): return (a+1).realize() + vi = Variable("i", 1, 10) + for i in range(1, 5): + a = Tensor.rand(3, i) + symbolic = f(a.reshape(3, vi)).reshape(3, i).cpu().numpy() + expected = f(a).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_add(self): + def f(a, b): return (a+b).realize() + vi = Variable("i", 1, 10) + for i in range(1, 5): + a = Tensor.rand(3, i) + b = Tensor.rand(3, i) + symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_matmul(self): + def f(a, b): return (a@b).realize() + vi = Variable("i", 1, 10) + for i in range(1, 5): + a = Tensor.rand(3, i) + b = Tensor.rand(i, 5) + symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_matmul_same_var_different_val(self): + def f(a, b): return (a@b).realize() + vi = Variable("i", 1, 10) + a = Tensor.rand(3, 4) + b = Tensor.rand(7, 5) + with self.assertRaises(AssertionError): + f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy() + + @unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI") + def test_attention(self): + def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize() + vi = Variable("i", 1, 10) + for i in range(1, 5): + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, i, 4, 8) + v = Tensor.rand(2, i, 4, 8) + symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy() + expected = f(q, k, v).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_cat_dim0(self): + def f(a, b): return a.cat(b, dim=0).realize() + vi = Variable("i", 1, 10) + for i in range(1, 5): + a = Tensor.rand(i, 3) + b = Tensor.rand(2, 3) + symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_cat_dim1(self): + def f(a, b): return a.cat(b, dim=1).realize() + vi = Variable("i", 1, 10) + for i in range(1, 5): + a = Tensor.rand(3, i) + b = Tensor.rand(3, 2) + symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_cat_dim0_two_vars(self): + def f(a, b): return a.cat(b, 
dim=0).realize() + vi = Variable("i", 1, 10) + vj = Variable("j", 1, 10) + for i in range(1, 5): + for j in range(1, 5): + a = Tensor.rand(i, 3) + b = Tensor.rand(j, 3) + symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_cat_dim1_two_vars(self): + def f(a, b): return a.cat(b, dim=1).realize() + vi = Variable("i", 1, 10) + vj = Variable("j", 1, 10) + for i in range(1, 5): + for j in range(1, 5): + a = Tensor.rand(3, i) + b = Tensor.rand(3, j) + symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_two_vars_plus1(self): + def f(a, b): return (a@b+1).realize() + vi = Variable("i", 1, 10) + vj = Variable("j", 1, 10) + for i in range(1, 5): + for j in range(1, 5): + a = Tensor.rand(i, 3) + b = Tensor.rand(3, j) + symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy() + expected = f(a, b).cpu().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) \ No newline at end of file diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 8a8857b97b..6e3c650670 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -26,17 +26,18 @@ class TestSymbolic(unittest.TestCase): i = Variable("i", 1, 5) j = Variable("j", 1, 5) k = Variable("k", 1, 5) - t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) - st = t1.lazydata.st + t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) + st = t.lazydata.st assert st.shape == (i+j+k, 4) assert st.real_strides() == (4, 1) - i = Variable("i", 1, 5) - j = Variable("j", 1, 5) - k = Variable("k", 1, 5) - t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) - st = t1.lazydata.st + t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) + st = t.lazydata.st assert st.shape == (3, i+j+k) assert st.real_strides() == (i+j+k, 1) + t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0) + st = t.lazydata.st + assert st.shape == (2*i+3, 3) + assert st.real_strides() == (3, 1) class TestSymbolicReshape(unittest.TestCase): def test_reshape_into_symbols_simple(self): diff --git a/test/test_uops.py b/test/test_uops.py index 4f8bcf8d54..c019cb5288 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -80,14 +80,14 @@ class TestFloatUOps(TestUOps): def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c) # TODO: fix this on all the backends -@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some") +@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some") class TestNonFloatUOps(TestUOps): def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32) + def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32) def test_mul_int32(self): 
self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32) def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True) def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True) def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a= 4: for tins in lang.ins: print(tins) - return global_size, local_size \ No newline at end of file + return global_size, local_size diff --git a/tinygrad/codegen/assembly_arm64.py b/tinygrad/codegen/assembly_arm64.py index 8f4e824215..5e4df3aea2 100644 --- a/tinygrad/codegen/assembly_arm64.py +++ b/tinygrad/codegen/assembly_arm64.py @@ -22,7 +22,7 @@ def specialize_to_arm64(fn_nm, asm): ins = [] x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)] s_regs = ['s' + str(i) for i in reversed(range(3,30))] - type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'} + type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'} alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"), @@ -137,12 +137,12 @@ def specialize_to_arm64(fn_nm, asm): mov_imm(arg[0], "x15") ins.append(f"add x15, {rtor[vin[0].nm]}, x15") ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]") - if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}") + if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}") elif uop == UOps.STORE: shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"} #NOTE: if need casting load var in s/h0 or x/w12 temp regs reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm]) - if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}") + if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}") ins.append(f"mov x15, #{arg[0]}") ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]") elif uop == UOps.COND_BRANCH: diff --git a/tinygrad/codegen/assembly_ptx.py b/tinygrad/codegen/assembly_ptx.py new file mode 100644 index 0000000000..b620175684 --- /dev/null +++ b/tinygrad/codegen/assembly_ptx.py @@ -0,0 +1,98 @@ +from typing import List +import struct +from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage +from tinygrad.codegen.linearizer import UOps, UOp +from tinygrad.helpers import dtypes +from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps +from tinygrad.runtime.ops_cuda import arch + +dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", 
dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"} +def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) + +def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize) + +def render_cast(ins, inp, out): + if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)): + ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};") + elif out.dtype == dtypes.bool: + ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};") + else: + round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else '' + ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};") + +# https://docs.nvidia.com/cuda/parallel-thread-execution/# + +class PTXLanguage(AssemblyLanguage): + supports_constant_folding: bool = True + +def specialize_to_ptx(lang, function_name): + param_cnt = 0 + ins = [] + alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", + BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx", + UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz", + TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"} + for uop, out, vin, arg in lang.ins: + if uop == UOps.ENDLOOP: + ins.append("bar.sync 0;") + elif uop == UOps.DEFINE_LOCAL: + ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];") + elif uop == UOps.SPECIAL: + if arg.startswith('data'): + param_cnt += 1 + ins.append(f"ld.param.u64 {out}, [{arg}];") + # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to? 
+ # ins.append(f"cvta.to.global.u64 {out}, {out};") + elif arg.startswith('gid'): + ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};") + elif arg.startswith('lid'): + ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};") + elif uop == UOps.ALU: + if arg == BinaryOps.MUL and out.dtype == dtypes.bool: + ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};") + else: + otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype + if arg == TernaryOps.WHERE: + reg = lang.newreg((vin[0], 'bool'), dtypes.bool) + ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};") + vin = vin[1:] + [reg] + ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};") + elif uop == UOps.LOAD: + if arg.__class__ in (int, float): + ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};") + elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype): + dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2]) + reg = lang.newreg((out, dt[0]), dtype=dt[1]) + ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") + render_cast(ins, reg, out) + else: + ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") + elif uop == UOps.STORE: + if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool: + if arg[2] == dtypes.bool != vin[1].dtype: + prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool) + render_cast(ins, vin[1], prereg) + else: prereg = vin[1] + reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]) + render_cast(ins, prereg, reg) + ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};") + else: + ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};") + elif uop == UOps.CAST: + render_cast(ins, vin[0], out) + elif uop == UOps.LABEL: + ins.append(f"{arg}:") + elif uop == UOps.COND_BRANCH: + ins.append(f"@{'!' 
if not arg[1] else ''}{vin[0]} bra {arg[0]};") + + ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64", + f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"] + for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",) + ins = ins_prefix + ins + ins += ["ret;", "}"] + return '\n'.join(ins) + +def uops_to_ptx_asm(function_name:str, uops:List[UOp]): + lang = PTXLanguage() + global_size, local_size = uops_to_asmstyle(lang, function_name, uops) + return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 679b70a55c..02dc5b161e 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -9,7 +9,7 @@ from tinygrad.lazy import LazyBuffer from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape, View -from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode +from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename VariableOrNum = Union[Variable, NumNode, Node] # bottom ones are asm only @@ -301,6 +301,9 @@ class Linearizer: # add global buffers for buf,name in self.arg_bufs.items(): self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype)) + # add variables from symbolic shapes + for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key): + self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32)) # add a local buffer for multistage reduce if len(self.group_for_reduce): @@ -317,7 +320,7 @@ class Linearizer: if DEBUG >= 3: self.printbufs() # kernel name (before late upcast) - self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape]) + self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape]) self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) # parse AST @@ -548,7 +551,7 @@ class Linearizer: assert len(colors) == self.shape_len, "colors size mismatch" return colors - def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors())) + def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors())) def printbufs(self, prefix=""): for i in range(len(self.sts)): print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views) diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py index b4a7b2925c..a6efbe699f 100644 --- a/tinygrad/codegen/optimizer.py +++ b/tinygrad/codegen/optimizer.py @@ -162,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer): # early exit return - if k.opts.has_local: + if k.opts.has_local and all(isinstance(s, int) for s in k.sts[0].shape[:k.first_reduce]): # are we grouping? 
(requires local shape support) if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048: # TODO: use 1024 if it's allowed in a smarter way @@ -204,8 +204,8 @@ def hand_coded_optimizations(k:Linearizer): while prod(k.sts[0].shape[:k.first_reduce]) >= 1024: xb_choices = [] for axis, upcast_amount in itertools.product(range(k.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce - # if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already - if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))): + # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already + if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))): xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount)) if len(xb_choices): xb_choices = sorted(xb_choices) @@ -219,7 +219,7 @@ def hand_coded_optimizations(k:Linearizer): # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS if k.first_reduce < (k.shape_len-k.upcasted) and (len(list(k.shape_offsets(k.full_buf_index))) <= 4 or not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))): - if (s:=k.full_unupcasted_shape[-1]) <= 32: + if (s:=k.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis k.upcast() # if it's small, upcast a second reduce dimension too if k.first_reduce < (k.shape_len-k.upcasted) and s <= 3 and k.full_unupcasted_shape[-1] <= 3: k.upcast() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e46f97eb2f..1329f5af09 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib from weakref import KeyedRef, ref from _weakref import _remove_dead_weakref # type: ignore import numpy as np -from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any +from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable from math import prod # noqa: F401 # pylint:disable=unused-import ShapeType = Tuple[int, ...] 
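
Note: the hand_coded_optimizations guards above exist because a shape axis may now be a symbolic node whose exact value is unknown at codegen time, so integer-only transforms (grouping, upcasting, loop unrolling by a fixed factor) are skipped for such axes. A minimal standalone sketch of the idea, using a toy Var class rather than tinygrad's Variable:

# Toy stand-in for a symbolic dimension: only its bounds are known at compile time.
class Var:
    def __init__(self, expr, vmin, vmax): self.expr, self.min, self.max = expr, vmin, vmax
    def __repr__(self): return self.expr

full_shape = (4, Var("i", 1, 512), 16)   # one axis is symbolic

for axis, s in enumerate(full_shape):
    # upcasting/unrolling by a fixed factor only makes sense for concrete sizes
    if isinstance(s, int) and s % 4 == 0:
        print(f"axis {axis}: size {s}, can upcast by 4")
    else:
        print(f"axis {axis}: size {s}, symbolic or not divisible, left alone")
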
@@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return ( def flatten(l:Iterator): return [item for sublist in l for item in sublist] def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}" def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm) +def merge_dicts(ds:Iterable[Dict]) -> Dict: + kvs = set([(k,v) for d in ds for k,v in d.items()]) + assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" + return {k:v for k,v in kvs} @functools.lru_cache(maxsize=None) def getenv(key, default=0): return type(default)(os.getenv(key, default)) @@ -83,11 +87,11 @@ class ImageDType(DType): class dtypes: @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool - def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64) + def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @staticmethod - def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes._half4, dtypes._float4) + def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4) @staticmethod - def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64) + def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @staticmethod def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name] @staticmethod @@ -115,6 +119,7 @@ class dtypes: _half4: Final[DType] = DType(0, 2*4, "half4", None, 4) _float2: Final[DType] = DType(4, 4*2, "float2", None, 2) _float4: Final[DType] = DType(4, 4*4, "float4", None, 4) + _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None) # HACK: staticmethods are not callable in 3.8 so we have to compare the class DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod} @@ -125,6 +130,7 @@ class GlobalCounters: time_sum_s: ClassVar[float] = 0.0 kernel_count: ClassVar[int] = 0 mem_used: ClassVar[int] = 0 # NOTE: this is not reset + mem_cached: ClassVar[int] = 0 # NOTE: this is not reset cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None @staticmethod def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None diff --git a/tinygrad/jit.py b/tinygrad/jit.py index e9da4a8f8a..24110963f2 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple, Any, Dict, cast, Union +from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional import functools, itertools from tinygrad.helpers import DEBUG, DType @@ -12,7 +12,7 @@ class TinyJit: def __init__(self, fxn:Callable): self.fxn: Callable = fxn self.cnt: int = 0 - self.jit_cache: List[Tuple[Callable, Any]] = [] # TODO: Any should be List[RawBuffer], but this fails + self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]]]] = [] self.ret: Any = None self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], int, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_size, expected_type) @@ -29,7 +29,8 @@ class TinyJit: for (j,i),(input_name, expected_size, expected_type) in self.input_replace.items(): assert 
input_rawbuffers[input_name].size == expected_size and input_rawbuffers[input_name].dtype == expected_type, f"size or type mismatch in JIT, {input_rawbuffers[input_name]} != <{expected_size}, {expected_type}>" self.jit_cache[j][1][i] = input_rawbuffers[input_name] - for prg, args in self.jit_cache: prg(args, jit=True) + for prg, pargs in self.jit_cache: # type: Callable, List[Optional[RawBuffer]] + prg(pargs, jit=True) for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None elif self.cnt == 1: GlobalCounters.cache = [] @@ -40,10 +41,10 @@ class TinyJit: if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") # get the inputs for replacement - for j,(prg,args) in enumerate(self.jit_cache): # pylint: disable=E1133 - for i,a in enumerate(args): + for j_,(_,pargs_) in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]]]] + for i,a in enumerate(pargs_): if a in input_rawbuffers.values(): - self.input_replace[(j,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0] + self.input_replace[(j_,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0] #if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found" for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 9f47828705..2f4995af25 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -2,8 +2,9 @@ from __future__ import annotations import functools, time from enum import Enum, auto from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast -from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup +from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts from tinygrad.shape.shapetracker import MovementOps +from tinygrad.shape.symbolic import Variable, sym_infer from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer @@ -131,20 +132,24 @@ class ASTRunner: self.clprg = runtime(self.name, self.prg, **self.runtime_args) return self - def exec(self, bufs, force_wait=False, optimizing=False) -> Optional[float]: + def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]: rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)]) if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs)) - return self(rawbufs, force_wait=force_wait) + return self(rawbufs, var_vals, force_wait=force_wait) - def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]: - if et := self.clprg((self.global_size + [1]*(3-len(self.global_size))) if self.global_size is not None else None, - (self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None, - *rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et + def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]: + if var_vals is None: var_vals = {} + global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size + 
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size + if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None, + (local_size + [1]*(3-len(local_size))) if local_size is not None else None, + *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et + op_estimate = sym_infer(self.op_estimate, var_vals) if DEBUG >= 2: - print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + - (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) + print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + + (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) GlobalCounters.kernel_count += 1 - GlobalCounters.global_ops += self.op_estimate + GlobalCounters.global_ops += op_estimate GlobalCounters.global_mem += self.mem_estimate if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0) return et @@ -178,9 +183,11 @@ class Compiled: output.realized = None break - # we don't have an output buffer, we have to create it + # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape if not output.realized: - output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs) + output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs) + # update the output var_vals from src + output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) from tinygrad.codegen.linearizer import Linearizer k = Linearizer(ast, output, self.linearizer_opts) @@ -200,5 +207,5 @@ class Compiled: if prg.name == getenv("PRINT_PRG", ''): print(prg.prg) - prg.exec(k.bufs) + prg.exec(k.bufs, var_vals=output.st.var_vals) return output.realized diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 1eaa0b433d..2ba1c4f130 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -3,7 +3,7 @@ import math from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType -from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable +from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render # div is different in cl than python render_cl = render_python.copy() @@ -17,6 +17,7 @@ class CStyleLanguage(NamedTuple): buffer_prefix: str = "" 
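
Note: in the ASTRunner.exec/__call__ changes earlier in this hunk set, global_size and local_size may now contain symbolic nodes; they are resolved to plain ints with sym_infer and the kernel's var_vals right before dispatch, then padded to three dimensions. A rough standalone sketch of that step (toy Var and sym_infer, illustrative only):

# Toy symbolic variable; tinygrad's sym_infer walks a full expression tree,
# here a variable simply looks itself up in var_vals.
class Var:
    def __init__(self, expr): self.expr = expr

def sym_infer(n, var_vals):
    return var_vals[n] if isinstance(n, Var) else int(n)

i = Var("i")                      # e.g. a sequence length only known at runtime
global_size = [i, 32]             # one launch dimension is symbolic
var_vals = {i: 60}

concrete = [sym_infer(sz, var_vals) for sz in global_size]
concrete = concrete + [1] * (3 - len(concrete))   # pad to 3 dims, as __call__ does
print(concrete)                   # -> [60, 32, 1]
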
buffer_suffix: str = "" smem_prefix: str = "" + arg_int_prefix: str = "" barrier: str = "" gid: List[str] = [] lid: List[str] = [] @@ -52,7 +53,7 @@ class CStyleLanguage(NamedTuple): def render_const(self, x:Union[float,int], var_dtype) -> str: if math.isnan(x): val = "NAN" elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" - else: val = f"{x}" + ("f" if isinstance(x, float) else "") + else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}" return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val # returns a str expression of the loaded value with the output type @@ -69,7 +70,7 @@ class CStyleLanguage(NamedTuple): def render_local(self, name:str, size:int): return self.smem_prefix + f"float {name}[{size}];" - def render_for(self, expr: str, _min:int, _max:int) -> str: + def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str: return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{" def render_conditional(self, cond: str, x:str, y:str) -> str: @@ -78,6 +79,7 @@ class CStyleLanguage(NamedTuple): def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else + self.arg_int_prefix if dtype == dtypes._arg_int32 else ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)] prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + @@ -128,7 +130,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid)) else: if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling - kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max)) + kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max))) depth += 1 elif uop == UOps.BARRIER: kk(lang.barrier) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 5c00734645..f6c2cf29c7 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -22,17 +22,18 @@ code_for_op: Final[Dict[Op, Callable]] = { UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=('fast',)), UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)), UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=('fast',)), - BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)), - BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)), - BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)), - BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y, flags=('fast',)), - BinaryOps.CMPLT: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()), + BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) 
if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=('fast',)), + BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=('fast',)), + BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=('fast',)), + BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=('fast',)), + BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=('fast',)), ir.FloatType()), BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)), + BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y), TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)), TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)), } -dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)} +dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)} def cast(bb, val, input_type, output_type): if input_type == output_type: return val @@ -44,6 +45,8 @@ def cast(bb, val, input_type, output_type): val = bb[-1].sext(val, ir.IntType(32)) val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16)) val = bb[-1].bitcast(val, ir.FloatType()) + elif input_type == dtypes.float64: + val = bb[-1].fptrunc(val, ir.FloatType()) else: val = bb[-1].fpext(val, ir.FloatType()) return val @@ -55,6 +58,8 @@ def cast(bb, val, input_type, output_type): val = bb[-1].bitcast(val, ir.IntType(32)) val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16)) val = bb[-1].trunc(val, ir.IntType(16)) + elif output_type == dtypes.float64: + val = bb[-1].fpext(val, ir.DoubleType()) else: val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type]) return val @@ -114,11 +119,11 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li assert newvar is not None and isinstance(args, (MemOp, ConstOp)) valid = args.valid.render(render_llvm, bb[-1]) if isinstance(args, ConstOp): - assert newvar.dtype == dtypes.float, "newvar must be float" + value, invalid_value = [int(args.value), int(args.invalid_value)] if dtypes.is_int(newvar.dtype) else ([bool(args.value), bool(args.invalid_value)] if newvar.dtype == dtypes.bool else [args.value, args.invalid_value]) # type: ignore if args.valid.min == 0 and args.valid.max == 1: - val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value)) + val = bb[-1].select(valid, ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value), ir.Constant(dtype_to_llvm_dtype[newvar.dtype], invalid_value)) else: - val = ir.Constant(ir.FloatType(), args.value if args.valid.min == 1 else args.invalid_value) + val = ir.Constant(dtype_to_llvm_dtype[newvar.dtype], value if args.valid.min == 1 else invalid_value) # TODO: this is a hack. 
it shouldn't be const that signals this reduce_phis.append(newvar) else: diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 180cb1ad7b..84a2261d1c 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -38,7 +38,7 @@ class WGSLLanguage(CStyleLanguage): prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3, @builtin(local_invocation_id) lindex: vec3) {{\n" + "\n".join(kernel) + "\n}" return prg, global_size[::-1] if len(global_size) else [1], local_size - def render_for(self, expr:str, _min:int, _max:int) -> str: + def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str: return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{" def render_conditional(self, cond:str, x:str, y:str) -> str: diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index f930d5b5a5..00dff3cb25 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -1,18 +1,21 @@ import ctypes import numpy as np -from typing import TypeVar, Type, Any -from tinygrad.helpers import DType, dtypes, prod, GlobalCounters +from collections import defaultdict, deque +from typing import TypeVar, Type, Any, Dict, Deque, Tuple +from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType _T = TypeVar("_T") class RawBuffer: # pylint: disable=abstract-method - def __init__(self, size:int, dtype:DType, buf:Any=None): + def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs): self.size: int = size self.dtype: DType = dtype - self._buf = buf + self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator. 
self._memsz: int = size*dtype.itemsize + self._allocator = allocator GlobalCounters.mem_used += self._memsz def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz + if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf) def __repr__(self): return f"buffer<{self.size}, {self.dtype}>" @property def key(self): return (self.size, self.dtype.key) @@ -39,7 +42,7 @@ class RawBufferMapped(RawBufferCopyIn): # this one is simple enough that i moved it out of the runtimes class RawMallocBuffer(RawBufferMapped): - def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)()) + def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)()) def _buffer(self): return memoryview(self._buf) class RawBufferCopyInOut(RawBufferCopyIn): @@ -66,3 +69,43 @@ class RawConst(RawBuffer): # pylint: disable=abstract-method def buf_is_kernel_arg(x) -> bool: return x.realized is not None and x.realized.__class__ is not RawConst + +class LRUAllocator: + def __init__(self, dev_memsz=(4<<30)): + self.epoch = 0 + self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz) + self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict() + self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, splitted by type and size, newest first. + self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates. + def __del__(self): + for v in self.cached_buffers.values(): + for buf, _ in v: self._free_buffer(buf) + def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused. + GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0]) + return rawbufs.popleft()[0] + def _alloc_buffer(self, size, dtype, device, **kwargs): + self.free_space[device] -= size*dtype.itemsize + while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers. + bucket, epoch = self.aging_order[device].popleft() + if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache. 
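
Note: the OOM loop above walks aging_order from oldest to newest and only frees a cached buffer when the recorded epoch still matches the oldest entry in its bucket, so records made stale by a reuse are skipped. A toy sketch of that epoch check (illustrative names, outside tinygrad):

from collections import deque, defaultdict

# cached buffers per bucket, newest first; aging queue of (bucket, epoch), oldest first
cached = defaultdict(deque)
aging = deque()

cached["k"].appendleft(("bufA", 1)); aging.append(("k", 1))   # bufA freed at epoch 1
cached["k"].popleft()                                         # ...but then reused
cached["k"].appendleft(("bufA", 3)); aging.append(("k", 3))   # and freed again at epoch 3

bucket, epoch = aging.popleft()   # OOM path pops the oldest record: ("k", 1)
if cached[bucket] and cached[bucket][-1][1] == epoch:
    print("freeing", cached[bucket].pop()[0])
else:
    print("stale record (buffer was reused and re-cached), nothing freed")
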
+ newbuf = self._do_alloc(size, dtype, device, **kwargs) + self.buffer_info[newbuf] = (size, dtype, device) + return newbuf + def _free_buffer(self, buf_to_free): + self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free) + GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free) + self.buffer_info.pop(buf_to_free) + self._do_free(buf_to_free) + def alloc(self, size, dtype, device='0', **kwargs): + rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None) + return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs) + def free(self, buf): # free() just caches buffer. It might be freed later when OOM during allocation. + self.epoch += 1 + size, dtype, device = self.buffer_info[buf] + self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch)) + self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch)) + GlobalCounters.mem_cached += self._underlying_buf_memsz(buf) + def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize + def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys. + def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented") + def _do_free(self, buf): pass diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index f2cbe5af23..de81116414 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -74,8 +74,8 @@ class ClangProgram: mu.emu_start(ADDRESS, ADDRESS + len(self.prg)) args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize) else: - self.fxn(*[x._buf for x in args]) + self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args]) if wait: return time.monotonic()-st -renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict")) +renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int")) ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 942d41b1b4..53f5ed13b6 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -2,9 +2,9 @@ import subprocess, time, re, hashlib, tempfile, os, functools from typing import Optional import numpy as np from pycuda.compiler import compile as cuda_compile # type: ignore -from tinygrad.helpers import DEBUG, getenv, colored +from tinygrad.helpers import DEBUG, getenv, colored, fromimport from tinygrad.ops import Compiled -from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer +from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -47,8 +47,12 @@ if getenv("CUDACPU", 0) == 1: else: import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: 
F401 import pycuda.driver as cuda # type: ignore + class CUDAAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore + def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype. + CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory()) class RawCUDABuffer(RawBufferCopyInOut): # type: ignore - def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore + def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc) def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore @@ -91,5 +95,5 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage( __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {} __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); } }; - """)) -CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize) + """)) if not getenv("PTX") else fromimport("tinygrad.codegen.assembly_ptx", "uops_to_ptx_asm") +CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 90eb237e14..9182d6f735 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore from typing import Optional, List from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport from tinygrad.ops import Compiled -from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer +from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -17,30 +17,36 @@ ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin") if DEBUG >= 5: early_exec = fromimport("extra.helpers", "enable_early_exec")() +class CLAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): + if isinstance(dtype, ImageDType): + # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize + assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}" + fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize]) + buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0])) + else: + buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize) + setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer + return buf + class _CL: + def __init__(self): + cl_platforms = cl.get_platforms() + platform_devices: List[List[cl.Device]] = [y for 
y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if len(y)] + self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")] + self.cl_platform = self.devices[0].platform def post_init(self, device=None): - platforms: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()]) if len(y)] - self.cl_platform = cl.get_platforms()[getenv('CL_PLATFORM', 0)] - self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in platforms[getenv('CL_PLATFORM', 0)] if x.name not in getenv('CL_EXCLUDE', "").split(",")] if device is None else [cl.Context(devices=[platforms[getenv('CL_PLATFORM', 0)][device]])] + self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])] if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}") self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs] + self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE)) def synchronize(self): for q in self.cl_queue: q.finish() CL = _CL() -CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None +if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init() class CLBuffer(RawBufferCopyInOut, RawBufferTransfer): - def __init__(self, size, dtype, device='0'): - if isinstance(dtype, ImageDType): - fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize]) - buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0])) - assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}" - # NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize - else: - buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize) - setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer - super().__init__(size, dtype, buf) - + def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device}) def _copyin(self, x:np.ndarray): assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}" self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False) @@ -80,7 +86,7 @@ class CLProgram: def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]: - cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs] + cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs] e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")]) if wait: e.wait() @@ -91,9 +97,8 @@ class CLProgram: return None renderer = functools.partial(uops_to_cstyle, 
CStyleLanguage( - kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", + kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int", half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable", barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)", gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)) - GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), renderer, CLProgram, CL.synchronize) diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 86f9d33fbc..07369e08a1 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -3,7 +3,7 @@ import ctypes, functools import extra.hip_wrapper as hip from tinygrad.helpers import DEBUG from tinygrad.ops import Compiled -from tinygrad.runtime.lib import RawBufferCopyInOut +from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -14,9 +14,14 @@ if DEBUG >= 5: # The default HIP stream is used for everything. +class HIPAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): return hip.hipMalloc(size * dtype.itemsize) + def _do_free(self, buf): hip.hipFree(buf) + def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype. +HIPAlloc = HIPAllocator(hip.hipGetDeviceProperties(hip.hipGetDevice()).totalGlobalMem) + class RawHIPBuffer(RawBufferCopyInOut): - def __init__(self, size, dtype): super().__init__(size, dtype, hip.hipMalloc(size * dtype.itemsize)) - def __del__(self): hip.hipFree(self._buf) + def __init__(self, size, dtype): super().__init__(size, dtype, allocator=HIPAlloc) def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, 0) def _copyout(self, x:np.ndarray): hip.hipMemcpy_dtoh(x.ctypes.data, self._buf, self.size * self.dtype.itemsize) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 4c5e564d17..2165a87d8f 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -1,20 +1,26 @@ # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch -import os, subprocess, pathlib, functools +import os, subprocess, pathlib, functools, ctypes import Metal, Cocoa, libdispatch # type: ignore from typing import List, Any from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage -from tinygrad.helpers import prod, getenv, DEBUG, DType +from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes from tinygrad.ops import Compiled -from tinygrad.runtime.lib import RawBufferMapped +from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator METAL_XCODE = getenv("METAL_XCODE") +class MetalAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared) + def _do_free(self, buf): buf.release() + def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype. 
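
Note: each backend plugs into the shared cache by subclassing LRUAllocator and overriding only _do_alloc, _do_free and (optionally) _cached_bufkey, as the CUDA/HIP/Metal allocators above do. A hedged sketch of the same pattern with a dummy backend, assuming the LRUAllocator introduced in this diff; DummyAllocator and DummyBuf are made up for illustration:

from tinygrad.helpers import dtypes
from tinygrad.runtime.lib import LRUAllocator

class DummyBuf:  # stand-in for a real device allocation
    def __init__(self, nbytes): self.data = bytearray(nbytes)

class DummyAllocator(LRUAllocator):
    def _do_alloc(self, size, dtype, device, **kwargs): return DummyBuf(size * dtype.itemsize)
    def _do_free(self, buf): pass  # nothing to release for a host-side dummy buffer
    def _cached_bufkey(self, size, dtype, device): return (device, size * dtype.itemsize)

alloc = DummyAllocator(dev_memsz=1 << 20)
a = alloc.alloc(16, dtypes.float32)   # fresh allocation goes through _do_alloc
alloc.free(a)                         # cached, not actually released
b = alloc.alloc(16, dtypes.float32)   # same key -> the cached buffer is reused
print(a is b)                         # -> True
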
+ class _METAL: def __init__(self): self.mtl_buffers_in_flight: List[Any] = [] self.device = Metal.MTLCreateSystemDefaultDevice() self.mtl_queue = self.device.newCommandQueue() + self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize()) # TODO: is there a better way to do this? def synchronize(self): for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted() @@ -23,10 +29,8 @@ METAL = _METAL() class RawMetalBuffer(RawBufferMapped): def __init__(self, size:int, dtype:DType): - super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)) - def __del__(self): - self._buf.release() - super().__del__() + assert dtype != dtypes.double, f"METAL does not support {dtype.name}" + super().__init__(size, dtype, allocator=METAL.allocator) def _buffer(self): METAL.synchronize() return self._buf.contents().as_buffer(self._buf.length()) @@ -64,7 +68,10 @@ class MetalProgram: command_buffer = METAL.mtl_queue.commandBuffer() encoder = command_buffer.computeCommandEncoder() encoder.setComputePipelineState_(self.pipeline_state) - for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i) + for i,a in enumerate(bufs): + if isinstance(a, RawMetalBuffer): encoder.setBuffer_offset_atIndex_(a._buf, 0, i) + elif isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i) + else: raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}") encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) encoder.endEncoding() command_buffer.commit() @@ -74,7 +81,7 @@ class MetalProgram: METAL.mtl_buffers_in_flight.append(command_buffer) renderer = functools.partial(uops_to_cstyle, CStyleLanguage( - kernel_prefix = "#include \nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", + kernel_prefix = "#include \nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&", barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True, gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)], extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])) diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index ba871b77e7..daf7953d18 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc from tinygrad.runtime.lib import RawBuffer device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")) -type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool} +type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool} inverse_type_map = {v:k for k,v in type_map.items()} torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{ diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 55574f6ef4..2124f83621 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ 
b/tinygrad/runtime/ops_webgpu.py @@ -1,7 +1,7 @@ import numpy as np import functools from wgpu.utils._device import get_default_device # type: ignore -from tinygrad.runtime.lib import RawBufferCopyIn +from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator from tinygrad.helpers import dtypes, DType from tinygrad.ops import Compiled from tinygrad.codegen.linearizer import LinearizerOptions @@ -9,32 +9,37 @@ from tinygrad.renderer.cstyle import uops_to_cstyle from tinygrad.renderer.wgsl import WGSLLanguage import wgpu # type: ignore -device = get_default_device() +wgpu_device = get_default_device() class WebGPUProgram: - def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg) + def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg) def __call__(self, global_size, local_size, *bufs, wait=False): assert len(bufs) <= 8, "WEBGPU only supports 8 buffers" binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))] bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)] - bind_group_layout = device.create_bind_group_layout(entries=binding_layouts) - pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout]) - bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings) - compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},) - command_encoder = device.create_command_encoder() + bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts) + pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout]) + bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings) + compute_pipeline = wgpu_device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},) + command_encoder = wgpu_device.create_command_encoder() compute_pass = command_encoder.begin_compute_pass() compute_pass.set_pipeline(compute_pipeline) compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used compute_pass.dispatch_workgroups(*global_size) # x y z compute_pass.end() - device.queue.submit([command_encoder.finish()]) + wgpu_device.queue.submit([command_encoder.finish()]) + +class RawWebGPUAllocator(LRUAllocator): + def _do_alloc(self, size, dtype, device, **kwargs): return wgpu_device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC) + def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype. 
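
Note: the _cached_bufkey overrides above key the cache on raw byte size rather than on (size, dtype), so a freed buffer can be reused by a request with a different dtype but the same footprint. A small illustration of the two keying schemes (plain Python; numpy dtypes stand in for tinygrad DTypes):

import numpy as np

def default_key(device, size, dtype):      # base LRUAllocator behaviour
    return (device, size, dtype)

def bytesize_key(device, size, dtype):     # CUDA/HIP/Metal/WebGPU override
    return (device, size * dtype.itemsize)

f32, f16 = np.dtype(np.float32), np.dtype(np.float16)
print(default_key('0', 256, f32) == default_key('0', 512, f16))    # False: never shared
print(bytesize_key('0', 256, f32) == bytesize_key('0', 512, f16))  # True: both 1024 bytes
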
+WebGPUAlloc = RawWebGPUAllocator(wgpu_device.limits['max_buffer_size']) class RawWebGPUBuffer(RawBufferCopyIn): def __init__(self, size:int, dtype:DType): - assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU" - super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)) - def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x)) - def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore + assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64,dtypes.double], f"dtype {dtype} not supported on WEBGPU" + super().__init__(size, dtype, allocator=WebGPUAlloc) + def _copyin(self, x:np.ndarray): wgpu_device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x)) + def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore renderer = functools.partial(uops_to_cstyle, WGSLLanguage()) WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index beba5c4f17..0ecd5d6720 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -59,7 +59,7 @@ class View(ViewInternal): # generate an expression if you have a single idx variable def expr_node(self, idx=None) -> Node: if idx is None: idx = Variable('idx', 0, prod(self.shape)-1) - ret: List[Node] = [Variable.num(self.offset)] if self.offset else [] + ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else [] acc = 1 for d,s in reversed(self.shape_strides): ret.append(((idx//acc)%d)*s) @@ -69,7 +69,7 @@ class View(ViewInternal): # generate an expression if you have a variable or expression for each index def expr_idxs(self, idxs) -> Node: assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}" - return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0]) + return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0]) @functools.lru_cache(maxsize=None) def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node: @@ -162,7 +162,7 @@ class ShapeTracker: idx, valid = self.expr_idxs(idxs) ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape) for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]): - if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable): + if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs: ret[idxs.index(this_dim.a)] = this_dim.b elif isinstance(this_dim, Variable): ret[idxs.index(this_dim)] = 1 diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 5d5e866015..8a5c6f7d51 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -65,6 +65,7 @@ class Node: def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}") def __floordiv__(self, b:Union[Node,int], 
factoring_allowed=True): if isinstance(b, Node): + if self == b: return NumNode(1) if (b > self).min > 0 and self.min >= 0: return NumNode(0) raise RuntimeError(f"not supported: {self} // {b}") assert b != 0 @@ -262,6 +263,14 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]): elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes])) return create_node(ret) +def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int: + if isinstance(n, (int, NumNode)): return int(n) + if isinstance(n, Variable): return var_vals[n] + if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals) + if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes) + raise NotImplementedError(n) +@functools.lru_cache(maxsize=None) +def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}" def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx) render_python: Dict[Type, Callable] = { diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2e9007494c..b26c354606 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -274,12 +274,20 @@ class Tensor: # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s] # is possible. # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s]. + # - Fancy indexing and combined indexing is supported + # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing + # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively + # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim + # - The dims are reduced at sum_dim for each Tensor passed in + # - There's a special case where a permute is needed at the end: + # - if first Tensor passed in (expand dims) is not at dim 0 + # - and following Tensors does not follow consecutively to the end of fancy indexing's dims def __getitem__(self, val): def normalize_int(e, i, dim_sz): if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1 raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}") val = list(val) if isinstance(val, tuple) else [val] - if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape): + if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}") orig_slices = list(val) ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis] @@ -290,6 +298,8 @@ class Tensor: orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices) else: orig_slices += [slice(None)] * (len(self.shape) - num_slices) + tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)] + orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices] valid_slices = list(filterfalse(lambda x: x is None, orig_slices)) valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))] start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ()) @@ -315,21 +325,46 @@ class Tensor: final_slice = 
reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ()) sliced_tensor = reshaped_tensor.shrink(final_slice) final_shape = [] + sub = [0] * len(tensor_found) it_shape = iter(new_shape) - for i in orig_slices: - if isinstance(i, (int, slice)): + for i,s in enumerate(orig_slices): + if isinstance(s, (int, slice)): dim_shape = next(it_shape) - if isinstance(i, slice): final_shape.append(dim_shape) - else: # i is None + if isinstance(s, slice): final_shape.append(dim_shape) + elif tensor_found: + for i_ in range(len(tensor_found)): + if tensor_found[i_][0] > i: sub[i_] -= 1 + else: # s is None final_shape.append(1) - return sliced_tensor.reshape(tuple(final_shape)) # Reshape + ret = sliced_tensor.reshape(tuple(final_shape)) # Reshape + if tensor_found: # Fancy/tensor indexing + for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1]) + dim = [i[0] for i in tensor_found] + idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm + max_dim = max(idx, key=lambda i: i.ndim).ndim + idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx] + sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))] + new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1)) + arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1)) + ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0]) + for idx_,d in zip(idx[1:],sum_dim[1:]): + new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim)) + arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1)) + ret = ((new_idx == arange) * ret).sum(d) + if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case + order = list(range(ret.ndim)) + order = order[dim[0]:dim[0]+idx[0].ndim] + order[:dim[0]] + order[dim[0]+idx[0].ndim:] + ret = ret.permute(order=order) + return ret - def gather(self, idx, dim): - idx = (idx < 0).where(idx+self.shape[dim], idx) # Turn neg idx pos - new_self = self.reshape(*self.shape[:dim+1], *[1]*idx.ndim, *self.shape[dim+1:]) - arange = Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim, self.shape[dim], *[1]*idx.ndim, *[1]*(self.ndim-dim-1)) - new_idx = idx.reshape(*[1]*dim, 1, *idx.shape, *[1]*(self.ndim-dim-1)) - return (new_self * (arange == new_idx)).sum(dim) + def gather(self: Tensor, idx: Tensor, dim: int): + assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim" + assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape" + if dim < 0: dim += self.ndim + idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1) + permarg = list(range(self.ndim)) + permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]] + return ((idx == Tensor.arange(self.shape[dim])) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) def cat(self, *args, dim=0): dim = (dim + len(self.shape)) if dim < 0 else dim
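
Note: both the fancy-indexing path and the rewritten gather above avoid a native gather primitive: the integer index is compared against an arange to build a one-hot mask, the mask is multiplied into the source, and the indexed axis is sum-reduced away. A numpy sketch of that trick along one axis (illustrative, not tinygrad code):

import numpy as np

src = np.arange(12, dtype=np.float32).reshape(3, 4)   # shape (3, 4)
idx = np.array([2, 0, 3])                             # one index per row, gather on dim=1

# one-hot mask: compare idx against an arange over the gathered axis
arange = np.arange(src.shape[1])                      # shape (4,)
onehot = (idx[:, None] == arange[None, :])            # shape (3, 4), one True per row

gathered = (src * onehot).sum(axis=1)                 # mask, then reduce the axis away
print(gathered)                                       # [ 2.  4. 11.]
print(np.take_along_axis(src, idx[:, None], axis=1).squeeze(1))  # same result
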