mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)

Merge branch 'master' into retinanet_mlperf

examples/benchmark_onnx.py
@@ -1,6 +1,5 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
-from tinygrad.tensor import _from_np_dtype
from extra.onnx import OnnxRunner

def load_onnx_model(fn):
@@ -18,19 +17,25 @@ def load_onnx_model(fn):
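  # JIT-wrap the runner: every input is moved to Device.DEFAULT and only the model's first output is returned
  # (prune=True additionally lets TinyJit drop captured kernels that don't feed that output)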
  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
  return run_onnx_jit, input_shapes, input_types

+def get_new_inputs(input_shapes):
+  #from tinygrad.tensor import _from_np_dtype
+  #return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
+  import numpy as np
+  return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())}

if __name__ == "__main__":
  run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
  print("loaded model")

  for i in range(3):
-    new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
+    new_inputs = get_new_inputs(input_shapes)
    GlobalCounters.reset()
    print(f"run {i}")
    run_onnx_jit(**new_inputs)

  # run 20 times
  for _ in range(20):
-    new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
+    new_inputs = get_new_inputs(input_shapes)
    GlobalCounters.reset()
    st = time.perf_counter()
    out = run_onnx_jit(**new_inputs)

examples/test_onnx_imagenet.py
@@ -17,6 +17,10 @@ from tinygrad.helpers import fetch, getenv
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
# ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx

# QUANT=1 python3 examples/test_onnx_imagenet.py
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx

def imagenet_dataloader(cnt=0):
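  # the canonical ImageNet normalization constants: per-channel RGB mean and std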
  input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
  input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
@@ -61,7 +65,7 @@ if __name__ == "__main__":
  assert shape[1:] == (3,224,224), f"shape is {shape}"

  hit = 0
-  for i,(img,y) in enumerate(imagenet_dataloader()):
+  for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):
    p = run_onnx_jit(**{t_name:img})
    assert p.shape == (1,1000)
    t = p.argmax().item()

@@ -37,7 +37,7 @@ if __name__ == "__main__":
  print("mmapped", hex(res))
  to_mv(res, 0x10)[1] = 0xaa

-  from tinygrad.runtime.ops_clang import ClangCompiler
+  from tinygrad.runtime.ops_dsp import ClangCompiler
  cc = ClangCompiler(args=["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib"])

  obj = cc.compile("""

extra/dsp/opt.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from tinygrad.runtime.ops_dsp import DSPCompiler

# PATH=/opt/homebrew/opt/llvm/bin:$PATH python3 extra/dsp/opt.py

if __name__ == "__main__":
  compiler = DSPCompiler()

  lib = compiler.compile("""
typedef long HVX_Vector __attribute__((__vector_size__(128))) __attribute__ ((aligned(128)));
typedef long HVX_VectorPair __attribute__((__vector_size__(256))) __attribute__ ((aligned(256)));

void test(unsigned char *c, unsigned char *a, unsigned char *b) {
  HVX_Vector t0 = *(HVX_Vector*)a;
  //HVX_VectorPair t1 = *((HVX_VectorPair*)b);
  HVX_Vector acc = __builtin_HEXAGON_V6_vd0_128B();
  for (int i = 0; i < 128; i++) {
    //__builtin_HEXAGON_V6_lvsplatb_128B(t0[i])
    //acc += __builtin_HEXAGON_V6_lvsplatb_128B(t0[i]) * t1;
    //acc += t0[i] * t1;
    unsigned int t1 = ((unsigned int *)b)[i];
    //acc = __builtin_HEXAGON_V6_vrmpyub_acc_128B(acc, t0, t1);
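    // vrmpybus_acc: per 32-bit lane, dot product of 4 unsigned bytes from t0 with the 4 signed bytes of scalar t1, accumulated into acc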
    acc = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc, t0, t1);
  }
  *((HVX_Vector*)c) = acc;
}""")

  compiler.disassemble(lib)

extra/onnx.py (651 lines)
@@ -1,8 +1,8 @@
-from typing import Callable, Any, Sequence
-import importlib, functools, dataclasses
-from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv, DEBUG, all_same
-from tinygrad.dtype import DType, ConstType, dtypes
+from typing import Any, Sequence, cast, Literal, Callable
+import dataclasses, functools, io, math, types
+from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
+from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple
+from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
+from tinygrad.device import is_dtype_supported

# ***** protobuf parsing ******
@@ -111,11 +111,11 @@ limit = int(getenv("ONNXLIMIT", "-1"))
class OnnxRunner:
  def __init__(self, model: ModelProto):
    # parse model protobuf
-    self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
+    self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
    self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
    Tensor.training = True if self.is_training else False
    Tensor.no_grad = False if self.is_training else True
-    self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
+    self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
    self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
    self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
    self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
@@ -123,14 +123,7 @@ class OnnxRunner:
    self.opset_version = model.opset_import[0].version
    self.variable_dims: dict[str, int] = {}

-    # TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
+    # TODO: clean up opset stuff after moving extra.onnx_ops here
-    self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
-    self.onnx_ops = {
-      **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
-        "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
-        "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
-    }
+    self.onnx_ops = onnx_ops

  def _parse_input(self, name: str, value: Any, spec: OnnxValue):
    if spec.is_optional and value is None: return None
@@ -148,9 +141,8 @@
    return tensor

  def _dispatch_op(self, op, inps, opts):
-    if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
-    if hasattr(self.onnx_ops_module, op):
-      fxn = getattr(self.onnx_ops_module, op)
+    if op in self.onnx_ops:
+      fxn = self.onnx_ops[op]
      if isinstance(fxn, dict):
        for k in sorted(fxn.keys()):
          if k <= self.opset_version:
@@ -165,7 +157,7 @@
      self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)

    for node in self.graph_nodes:
-      inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
+      inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)]
      opts = node.opts

      # provide additional opts
@@ -184,4 +176,623 @@ class OnnxRunner:
-        Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
-        return {name:self.graph_values[name] for name in node.outputs}
+        Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
+        return {name:self.graph_values[name] for name in self.graph_outputs}
    return {name:self.graph_values[name] for name in self.graph_outputs}

####################
##### ONNX OPS #####
####################
def get_onnx_ops():
  # ***** helper functions *****
  def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)

  # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
  def _onnx_pads_to_tiny_pads(pads): return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
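  # e.g. 2-D ONNX pads (top, left, bottom, right) = (1, 2, 3, 4) become tinygrad pads (left, right, top, bottom) = (2, 4, 1, 3)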

  AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
  # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
  def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
    if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
    return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]

  def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
    i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
    if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
    o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
    return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))

  def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
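  # e.g. _clamp_cast(x, dtypes.uint8) clips to [0, 255] first, so out-of-range values saturate instead of wrapping on the cast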

  def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
    if axis < 0: axis += x.ndim
    if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
    if block_size == 0:
      shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
      return scale.reshape(shape), zero_point.reshape(shape)
    return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)

  def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
    adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
    return op(*adjusted_inputs, **opts)

  def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
    # op execution is done in quantized int
    out = _op_integer(op, inputs, zero_points, **opts)
    assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
    out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
    return _clamp_cast(out_quantized, out_zero_point.dtype)

  def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
    # op execution is done in float32
    dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
    out = op(*dequantized_inputs, **opts)
    assert dtypes.is_float(out.dtype), "op should've done math in float"
    out_quantized = (out / out_scale).round() + out_zero_point
    return _clamp_cast(out_quantized, out_zero_point.dtype)

  def _onnx_training(input_group_size):
    def __decorator(func):
      def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
        R = R.detach()
        groups = len(inputs) // input_group_size
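        # inputs arrive grouped by kind, e.g. (X1..Xn, G1..Gn, H1..Hn): inputs[i::groups] gathers the i-th tensor of each kind,
        # func runs once per optimized tensor, and zip/flatten re-packs the outputs into kind-major order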
        ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
        return tuple(flatten(zip(*ret)))
      return ___wrapper
    return __decorator

  # ***** Property/Graph Ops *****
  def Identity(x:Tensor): return x
  def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
               value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
    if value is not None: return value
    if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
    if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
    if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
    if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
    if value_string is not None or value_strings is not None or sparse_value is not None:
      raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')

  def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)

  def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
    try: import PIL.Image
    except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
    img = PIL.Image.open(io.BytesIO(encoded_stream))
    if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
    if pixel_format == "RGB": return Tensor(np.array(img))
    if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
    raise ValueError(f"pixel_format={pixel_format!r} is not supported.")

  def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
    ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
    return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))

  def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
  def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
  def ConstantOfShape(shape:list[int], value:Tensor|None=None):
    if value is None: value = Tensor(0, dtype=dtypes.float32)
    return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)

  def Size(data:Tensor): return data.numel()
  def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)

  # ***** Unary Ops (math) *****
  def Not(x:Tensor): return x.logical_not()
  def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
    return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)

  # ***** Unary Ops (activation) *****
  def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
  def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
  Softmax = {1:Softmax_1, 13:Softmax_13}
  def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
  def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
  def FastGelu(x:Tensor, bias:Tensor|None=None):
    # this is tanh approximated
    return (x + bias).gelu() if bias is not None else x.gelu()
  # TODO: fix this
  def PRelu(X:Tensor, slope:Tensor):
    slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
    return (X > 0).where(X, X * slope)
  def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
  def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
  def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
  def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()

  # ***** Binary Ops (broadcasted) *****
  def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
  def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
  def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
  def Less(x:Tensor,y:Tensor): return x < y
  def LessOrEqual(x:Tensor,y:Tensor): return x <= y
  def Greater(x:Tensor,y:Tensor): return x > y
  def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
  def Equal(x:Tensor,y:Tensor): return x == y
  def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
  def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
  def BitwiseAnd(x:Tensor,y:Tensor): return x & y
  def BitwiseOr(x:Tensor,y:Tensor): return x | y
  def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
  def BitwiseNot(x:Tensor): return ~x

  # ***** Casting Ops *****
  # TODO: saturate
  def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
  def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)

  # ***** Reduce Ops *****
  def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
  def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
  def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
  def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
  def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
  def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
  def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
  def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
  def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
    return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
  def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
    if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
    return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
  def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
    return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

  # ***** Movement Ops *****
  def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
    return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
  def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
  def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
  def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
  def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)

  # TODO: add test for when axes is None
  def Squeeze(data:Tensor, axes:list[int]|None=None):
    return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
  def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)

  def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
  def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
  def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
    axes = axes or list(range(data.ndim))
    steps = steps or [1]*data.ndim
    slices = [slice(0,x,1) for x in data.shape]
    for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
    return data[tuple(slices)]

  def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
    sz = data.shape[axis]
    if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
    return data.split(split, axis)

  def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
          mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
    value = constant_value or value
    axes = axes or list(range(x.ndim))
    real_pads = [0] * (x.ndim*2)
    for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
    return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)

  def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
    shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
    pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
    for s, x in zip(shape, axes or range(t.ndim)):
      tx = t.shape[x]
      if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
      elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
    return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

  # ***** Processing Ops *****
  def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
                  dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
    return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
                        ceil_mode=ceil_mode, count_include_pad=count_include_pad)

  def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
              storage_order:int=0, strides:list[int]|int=1):
    ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
    # tests expect indices with int64 dtype
    # TODO: if there are repeated values, this is wrong
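    # indices are recovered by equality-matching each pooled output against every flattened input position (hence the repeated-values caveat)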
    indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
    return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices

  def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
           kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
    return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
                    padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))

  def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
                    kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
                    strides:list[int]|int=1):
    input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
    strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
    if output_shape is not None: # we pad according to output_shape
      pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
                        zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
    if pads is None: # we generate pads
      output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
      pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
      pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
    pads = _onnx_pads_to_tiny_pads(pads)
    return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)

  def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
    pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
    out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
    ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
    if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
    return ret.pad(_onnx_pads_to_tiny_pads(pads))

  def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
  def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)

  def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
    ret = alpha * (A.transpose(transA) @ B.transpose(transB))
    if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
    return ret

  def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)

  def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
    axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
    if reverse: X = X.flip(axis)
    if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                       .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
    return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)

  def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)

  def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
             axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
             extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
    def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
      if mode == "round_prefer_floor": index = (index - 0.5).ceil()
      elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
      elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
      else: raise ValueError(f"invalid {nearest_mode=}")
      return index.cast(dtypes.int32).clip(0, input_dim-1)
    def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode):
      # TODO: needs more testing, not confident in this
      # NOTE: their reference implementation differs from the implementation in their reference docs
      # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
      # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
      output_dim = scale_dim * input_dim
      if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
      elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
      elif mode == "asymmetric": index = index / scale_dim
      elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
      elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
      elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))
      else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
      return index.clip(0, input_dim-1)

    scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
    # we pre permute the axes and permute back after resize
    axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
    perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
    X = X.permute(*perm)

    if sizes is not None:
      if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
        scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
        scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
        sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
      else:
        scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
    else:
      sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
    regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2)

    # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
    # in Tensor.interpolate, we use indexes without any transformation
    indexes = []
    for shape, size, scale, region in zip(input_shape, sizes, scales, regions):
      indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode))

    if mode == "nearest":
      indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
      X = X[(..., *Tensor.meshgrid(*indexes))]
    if mode == "linear":
      expand = list(X.shape)
      for i in range(-len(sizes), 0):
        reshape, index = [1] * X.ndim, indexes[i]
        reshape[i] = expand[i] = sizes[i]
        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
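        # linear interpolation along axis i: blend the floor and ceil neighbors by the fractional part of the index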
        X = X.gather(i, low).lerp(X.gather(i, high), perc)
    if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
    return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
  def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated

  # ***** Neural Network Ops *****
  # TODO: try to factor out common implementations for these ops
  # https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
  def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                         training_mode:int=0, spatial=1, is_test=0):
    if training_mode:
      x_detached = X.detach()
      current_mean = x_detached.mean(axis=(0,2,3))
      y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
      current_var = (y*y).mean(axis=(0,2,3))
      current_invstd = current_var.add(epsilon).rsqrt()

      running_mean = input_mean * momentum + current_mean * (1 - momentum)
      running_var = input_var * momentum + current_var * (1 - momentum)

      return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
    invstd = (input_var + epsilon).rsqrt()
    return X.batchnorm(scale, B, input_mean, invstd)
  def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
    axis = tuple(range(2, x.ndim))
    mean = x.mean(axis=axis, keepdim=True)
    invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
    return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
  def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
    assert stash_type == 1, "only float32 is supported"
    axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
    mean = x.mean(axis=axes, keepdim=True)
    return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
    return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
  def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
    return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
  def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
    x = x + skip + bias
    return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
  def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
                              segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
                              position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
    # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
    assert (segment_ids is None) is (segment_embedding is None)
    assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
    input_shape = input_ids.shape
    seq_length = input_shape[1]
    compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
    vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
    type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)

    def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
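      # embedding lookup expressed as one-hot(ids) @ weight, avoiding a gather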
      return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight

    # bert embedding layer
    if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
    wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
    pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
    seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

    embedding_sum = wrd_embedding_res + pos_embedding_res
    if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
    out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
    return out, None, embedding_sum

  def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
    # Scalar or Rank 1 tensor containing exactly one element
    depth = int(depth[0] if isinstance(depth, list) else depth)
    indices = (indices < 0).where(indices+depth, indices)
    return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])

  def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
    return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
  def SpaceToDepth(X:Tensor, blocksize:int):
    return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)

  # Reimplemented here because you need legacy RNG for passing ONNX tests.
  def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
    if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
    mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
    return data * mask * (1/(1.0 - ratio)), mask
  # 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
  def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
  Dropout = {6:Dropout_6, 7:Dropout_7}

  def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
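    # cross-channel windowed mean: channels are moved into a spatial axis so a (size, 1) avg_pool2d averages over neighboring channels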
    pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
    return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)

  def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
    return x.nll_loss(target, weight, ignore_index, reduction)
  def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
    log_probs = scores.log_softmax(1)
    return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs

  def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
    N, _, *spatial_dims = size
    def generate_grid(steps):
      return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
    grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
    base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
    base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
    return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)

  def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
                relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
                mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
                qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
    # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
    assert num_heads is not None # required
    assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
    assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
      "functionality not supported yet" # TODO strange params
    hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

    if unidirectional: # gpt-style
      assert hidden_size == v_hidden_size
      xqkv = x.linear(weights, bias)
      xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
    else: # bert-style
      wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
      bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else (None, None, None)
      xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
    xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

    if past is not None:
      xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

    def attn(query, key, value, attn_mask):
      query_length, key_length = query.shape[-2], key.shape[-2]
      cdim = max(query_length, key_length) + 1
      attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
      # This is where Tensor.scaled_dot_product_attention differs:
      causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
      masked = Tensor.where(causal_mask, attn_weights, -math.inf)
      if attn_mask is not None: masked = masked + attn_mask
      return masked.softmax(-1) @ value

    bsz, _, seq_len, _ = xq.shape
    out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
    return out, present if past is not None else out

  # ***** Indexing Ops *****
  def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]

  def Gather(x:Tensor, indices:Tensor, axis:int=0):
    if indices.numel() < 9: # NOTE: fewer kernels for smaller indices, but kernel count grows with the size of indices
      x_sh = list(x.shape)
      ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
      if indices.ndim > 1: indices = indices.flatten()
      indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
      args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
      return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
    # NOTE: faster gather with a fixed number of kernels, but exceeds the kernel limit for openpilot
    return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
  def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated

  def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
    if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
    x_shape, i_shape = x.shape, indices.shape
    b = math.prod(x.shape[dim] for dim in range(batch_dims))
    # NOTE: each batched dim of both input and indices are equal
    x = x.reshape(b, *x.shape[batch_dims:])
    indices = indices.reshape(b, *indices.shape[batch_dims:])
    b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
    ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
    return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
  def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
    assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
    x = x.contiguous()
    for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
      i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
      u = u.squeeze(0)
      if reduction == "none": x[i] = u
      elif reduction == "add": x[i] += u
      elif reduction == "mul": x[i] *= u
      else: raise NotImplementedError("reduction doesn't support max or min")
    return x

  def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
    indices = (indices < 0).where(x.shape[axis], 0) + indices
    return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
  def GatherElements(x:Tensor, indices:Tensor, axis:int):
    indices = (indices < 0).where(x.shape[axis], 0) + indices
    return x.gather(axis, indices)

  def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
    if axis is None:
      inp = inp.flatten()
      axis = 0
    if axis < 0: axis += inp.ndim
    con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
    return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]

  # ***** Quantization Ops *****
  def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
    out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
    y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
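    # affine quantization: q = clamp(round(x / scale) + zero_point) in the output dtype's range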
    return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()

  def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
    x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
    return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)

  def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
                  y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
    return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})

  def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
                    y_zero_point:Tensor|int) -> Tensor:
    return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)

  def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
    return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)

  def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
    assert channels_last == 0, "unsure what this does"
    return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)

  def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
    return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})

  def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
    return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])

  # ***** Training Ops *****
  # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
  # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code
  @_onnx_training(3)
  def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
    X, G, H = (i.detach() for i in inputs)
    grad = norm_coefficient * X + G
    H.assign(H + grad.square())
    up = grad / (H.sqrt() + epsilon)
    r = R / (1 + T * decay_factor)
    X.assign(X.detach() - r * up)
    return [X, H]

  @_onnx_training(4)
  def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
           norm_coefficient_post:float=0.0):
    from tinygrad.nn.optim import Adam as TinyAdam
    X, G, V, H = inputs
    G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
    X.grad = norm_coefficient * X.detach() + G
    opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
    opt.m, opt.v, opt.lr = [V], [H], R
    # need no-op for m_hat and v_hat if T == 0
    if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
    else:
      # `T-1` since it's applied again at the start of `_step`
      opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
      opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
    opt.step()
    X = (1 - norm_coefficient_post) * X
    return [X, V, H]

  @_onnx_training(3)
  def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
    from tinygrad.nn.optim import SGD
    X, G, V = inputs
    G, V = G.detach(), V.detach()
    X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
    opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
    opt.b, opt.lr = [V], R
    opt.step()
    return [X, V]

  def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
    intermediate_tensors[y].backward()
    return tuple([t.grad for t in inputs])

  return {
    # Tensor ops
    **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
      "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
      "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
    # Implemented ops
    **{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()},
    # Version ops
    **{name:obj for name,obj in locals().items() if isinstance(obj, dict)},
  }

onnx_ops = get_onnx_ops()

extra/onnx_ops.py (deleted, 606 lines)
@@ -1,606 +0,0 @@
-import functools, io, math
-from typing import cast, Literal
-from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
-from tinygrad.dtype import ImageDType, dtypes, DType
-from tinygrad.helpers import prod, flatten, make_tuple
-from extra.onnx import dtype_parse, _cached_to_python_const
-import numpy as np
# ***** Property/Graph Ops *****
|
||||
def Identity(x:Tensor): return x
|
||||
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
|
||||
value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
|
||||
if value is not None: return value
|
||||
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
|
||||
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
|
||||
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
|
||||
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
|
||||
if value_string is not None or value_strings is not None and sparse_value is not None:
|
||||
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
|
||||
|
||||
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
|
||||
|
||||
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
|
||||
try: import PIL.Image
|
||||
except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
|
||||
img = PIL.Image.open(io.BytesIO(encoded_stream))
|
||||
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
|
||||
if pixel_format == "RGB": return Tensor(np.array(img))
|
||||
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
|
||||
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
|
||||
|
||||
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
|
||||
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
|
||||
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
|
||||
|
||||
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
|
||||
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
|
||||
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
|
||||
if value is None: value = Tensor(0, dtype=dtypes.float32)
|
||||
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
|
||||
|
||||
def Size(data:Tensor): return data.numel()
|
||||
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
|
||||
|
||||
# ***** Unary Ops (math) *****
|
||||
def Not(x:Tensor): return x.logical_not()
|
||||
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
|
||||
return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
|
||||
|
||||
# ***** Unary Ops (activation) *****
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1:Softmax_1, 13:Softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def FastGelu(x:Tensor, bias:Tensor|None=None):
  # this is the tanh approximation
  return (x + bias).gelu() if bias is not None else x.gelu()
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
  return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()

# ***** Unary Ops (broadcasted) *****
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x

# ***** Casting Ops *****
# TODO: saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)

# ***** Reduce Ops *****
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
  return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

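# a quick example of select_last_index semantics (illustrative, not part of the op set): with
# repeated maxima, the op flips the axis, takes the first argmax, then mirrors the index back:
#   ArgMax(Tensor([1, 3, 3])).item()                       # -> 1 (first max)
#   ArgMax(Tensor([1, 3, 3]), select_last_index=1).item()  # -> 2 == (3-1) - argmax(flip(x))
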
# ***** Movement Ops *****
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
  return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)

# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)

def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
  axes = axes or list(range(data.ndim))
  steps = steps or [1]*data.ndim
  slices = [slice(0,x,1) for x in data.shape]
  for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
  return data[tuple(slices)]

def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
  sz = data.shape[axis]
  if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
  return data.split(split, axis)

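# a quick example of the default even split (illustrative, not part of the op set): a size-7
# axis with num_outputs=3 gives sizes [3, 2, 2], the remainder going to the first chunks:
#   [t.shape for t in Split(Tensor.arange(7), num_outputs=3)]  # -> [(3,), (2,), (2,)]
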
def _onnx_pads_to_tiny_pads(pads):
  # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
  return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
        mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
  value = constant_value or value
  axes = axes or list(range(x.ndim))
  real_pads = [0] * (x.ndim*2)
  for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
  return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)

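# a worked example of the pad-order conversion (illustrative, not part of the op set): ONNX
# packs all the begins, then all the ends; tinygrad wants per-axis (before, after) pairs with
# the last axis first. For a 2D tensor with pads (top, left, bottom, right) = (1, 2, 3, 4):
#   _onnx_pads_to_tiny_pads([1, 2, 3, 4])  # -> (2, 4, 1, 3), i.e. (left, right, top, bottom)
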
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
  shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
  pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

# ***** Processing Ops *****
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
  # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
  if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
  return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
  i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
  if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
  o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
  return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))

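# a quick example of _auto_pad (illustrative, not part of the op set): an odd total pad of 3 on
# one axis splits as 1 before / 2 after for SAME_UPPER, and 2 before / 1 after for SAME_LOWER:
#   _auto_pad([3], "SAME_UPPER")  # -> [1, 2]
#   _auto_pad([3], "SAME_LOWER")  # -> [2, 1]
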
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
                dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
  return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
                      ceil_mode=ceil_mode, count_include_pad=count_include_pad)

def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
            storage_order:int=0, strides:list[int]|int=1):
  ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
  # tests expect indices with int64 dtype
  # TODO: if there are repeated values, this is wrong
  indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
  return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices

def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
                  padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))

def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
                  kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
                  strides:list[int]|int=1):
  input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
  strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
  if output_shape is not None: # we pad according to output_shape
    pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
                      zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
  if pads is None: # we generate pads
    output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
    pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
    pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
  pads = _onnx_pads_to_tiny_pads(pads)
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)

def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
  if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
  return ret.pad(_onnx_pads_to_tiny_pads(pads))

def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)

def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret

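# a quick numeric check of Gemm (illustrative, not part of the op set): Y = alpha*A@B + beta*C
#   A, B, C = Tensor.ones(2, 3), Tensor.ones(3, 4), Tensor.ones(2, 4)
#   Gemm(A, B, C, alpha=2.0, beta=0.5).mean().item()  # -> 6.5 == 2.0*3 + 0.5*1
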
def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)

def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
  axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
  if reverse: X = X.flip(axis)
  if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                     .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)

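# a quick example of the CumSum flags (illustrative, not part of the op set), for x = [1, 2, 3]:
#   CumSum(Tensor([1, 2, 3]), 0).tolist()               # -> [1, 3, 6]
#   CumSum(Tensor([1, 2, 3]), 0, exclusive=1).tolist()  # -> [0, 1, 3] (shifted right, 0 first)
#   CumSum(Tensor([1, 2, 3]), 0, reverse=1).tolist()    # -> [6, 5, 3] (suffix sums)
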
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)

def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
           axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
  def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
    if mode == "round_prefer_floor": index = (index - 0.5).ceil()
    elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
    elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
    else: raise ValueError(f"invalid {nearest_mode=}")
    return index.cast(dtypes.int32).clip(0, input_dim-1)
  def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, mode):
    # TODO: needs more testing, not confident in this
    # NOTE: their reference implementation differs from the implementation in their reference docs
    # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
    # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
    output_dim = scale_dim * input_dim
    if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
    elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
    elif mode == "asymmetric": index = index / scale_dim
    elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
    elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
    elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))
    else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
    return index.clip(0, input_dim-1)

  scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
  # we pre permute the axes and permute back after resize
  axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
  perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
  X = X.permute(*perm)

  if sizes is not None:
    if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
      scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
      scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
      sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
    else:
      scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
  else:
    sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
  regions = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) and roi else [[0.0, 0.0]] * (X.ndim-2)

  # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
  # in Tensor.interpolate, we use indexes without any transformation
  indexes = []
  for shape, size, scale, region in zip(input_shape, sizes, scales, regions):
    indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, coordinate_transformation_mode))

  if mode == "nearest":
    indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
    X = X[(..., *Tensor.meshgrid(*indexes))]
  if mode == "linear":
    expand = list(X.shape)
    for i in range(-len(sizes), 0):
      reshape, index = [1] * X.ndim, indexes[i]
      reshape[i] = expand[i] = sizes[i]
      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
      X = X.gather(i, low).lerp(X.gather(i, high), perc)
  if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
  return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated

# ***** Neural Network Ops *****
# TODO: try to factor out common implementations for these ops
# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                       training_mode:int=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).rsqrt()

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  invstd = (input_var + epsilon).rsqrt()
  return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  axis = tuple(range(2, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  assert stash_type == 1, "only float32 is supported"
  axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axes, keepdim=True)
  return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
  x = x + skip + bias
  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
                            segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
                            position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
  input_shape = input_ids.shape
  seq_length = input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
  type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)

  def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
    return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight

  # bert embedding layer
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

  embedding_sum = wrd_embedding_res + pos_embedding_res
  if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum

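# a minimal shape check for the op above (illustrative, not from the source): the one-hot matmul
# in the embedding helper selects the same rows a fancy index would, i.e. weight[x], without gather:
#   ids, W = Tensor([[1, 3]]), Tensor.randn(10, 4)
#   out, _, emb_sum = EmbedLayerNormalization(ids, None, W, Tensor.randn(8, 4), None,
#                                             gamma=Tensor.ones(4), beta=Tensor.zeros(4))
#   out.shape  # -> (1, 2, 4)
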
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
  # Scalar or Rank 1 tensor containing exactly one element
  depth = int(depth[0] if isinstance(depth, list) else depth)
  indices = (indices < 0).where(indices+depth, indices)
  return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])

def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
  return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)

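# a quick shape example (illustrative, not part of the op set): DepthToSpace with blocksize=2
# moves a factor of 4 from channels into the spatial dims, and SpaceToDepth inverts it:
#   DepthToSpace(Tensor.randn(1, 8, 3, 3), 2).shape  # -> (1, 2, 6, 6)
#   SpaceToDepth(Tensor.randn(1, 2, 6, 6), 2).shape  # -> (1, 8, 3, 3)
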
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
  mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
  return data * mask * (1/(1.0 - ratio)), mask
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}

def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)

def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  log_probs = scores.log_softmax(1)
  return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs

def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
  N, _, *spatial_dims = size
  def generate_grid(steps):
    return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
  base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)

def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
              relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
              mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
              qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  assert num_heads is not None # required
  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
    "functionality not supported yet" # TODO strange params
  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

  if unidirectional: # gpt-style
    assert hidden_size == v_hidden_size
    xqkv = x.linear(weights, bias)
    xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
  else: # bert-style
    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else None
    xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

  if past is not None:
    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

  def attn(query, key, value, attn_mask):
    query_length, key_length = query.shape[-2], key.shape[-2]
    cdim = max(query_length, key_length) + 1
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
    masked = Tensor.where(causal_mask, attn_weights, -math.inf)
    if attn_mask is not None: masked = masked + attn_mask
    return masked.softmax(-1) @ value

  bsz, _, seq_len, _ = xq.shape
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present if past is not None else out

# ***** Indexing Ops *****
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]

def Gather(x:Tensor, indices:Tensor, axis:int=0):
  if indices.numel() < 9: # NOTE: fewer kernels for smaller indices, but kernel count grows with the size of indices
    x_sh = list(x.shape)
    ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  # NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated

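# a quick example of Gather (illustrative, not part of the op set): for small index tensors the
# shrink+cat path above is chosen, but the result matches plain fancy indexing either way:
#   Gather(Tensor([[1, 2], [3, 4]]), Tensor([1]), axis=1).tolist()  # -> [[2], [4]]
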
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
  if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  x_shape, i_shape = x.shape, indices.shape
  b = math.prod(x.shape[dim] for dim in range(batch_dims))
  # NOTE: each batched dim of both input and indices are equal
  x = x.reshape(b, *x.shape[batch_dims:])
  indices = indices.reshape(b, *indices.shape[batch_dims:])
  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
  ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
  assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
  x = x.contiguous()
  for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
    i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
    u = u.squeeze(0)
    if reduction == "none": x[i] = u
    elif reduction == "add": x[i] += u
    elif reduction == "mul": x[i] *= u
    else: raise NotImplementedError("reduction doesn't support max or min")
  return x

def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.gather(axis, indices)

def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
  if axis is None:
    inp = inp.flatten()
    axis = 0
  if axis < 0: axis += inp.ndim
  con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
  return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]

# ***** Quantization Ops *****
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)

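# a quick example of _clamp_cast (illustrative, not part of the op set): values are saturated
# to the target dtype's range before the cast, so 300.0 lands on uint8's max:
#   _clamp_cast(Tensor([-5.0, 300.0]), dtypes.uint8).tolist()  # -> [0, 255]
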
def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
  if axis < 0: axis += x.ndim
  if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
  if block_size == 0:
    shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
    return scale.reshape(shape), zero_point.reshape(shape)
  return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)

def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
  y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
  return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()

def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
  x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
  return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)

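# a quick round-trip check (illustrative, not part of the op set): quantize then dequantize
# recovers the input up to one quantization step (a 0-d scale broadcasts over the axis):
#   x = Tensor([0.0, 0.1, 0.2], dtype=dtypes.float32).reshape(1, 3)
#   q = QuantizeLinear(x, Tensor(0.1), axis=0)  # -> uint8 [[0, 1, 2]]
#   DequantizeLinear(q, Tensor(0.1), axis=0)    # -> float32 [[0.0, 0.1, 0.2]] (up to rounding)
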
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
  adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
  return op(*adjusted_inputs, **opts)

def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  # op execution is done in quantized int
  out = _op_integer(op, inputs, zero_points, **opts)
  assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
  out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
  return _clamp_cast(out_quantized, out_zero_point.dtype)

def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  # op execution is done in float32
  dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
  out = op(*dequantized_inputs, **opts)
  assert dtypes.is_float(out.dtype), "op should've done math in float"
  out_quantized = (out / out_scale).round() + out_zero_point
  return _clamp_cast(out_quantized, out_zero_point.dtype)

def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
                y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
  return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})

def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
                  y_zero_point:Tensor|int) -> Tensor:
  return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)

def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
  return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)

def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
  assert channels_last == 0, "unsure what this does"
  return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)

def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
  return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})

def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
  return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])

# ***** Training Ops *****
# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
# NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code
def _onnx_training(input_group_size):
  def __decorator(func):
    def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
      R = R.detach()
      groups = len(inputs) // input_group_size
      ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
      return tuple(flatten(zip(*ret)))
    return ___wrapper
  return __decorator

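# a sketch of how the decorator regroups inputs (illustrative, not from the source): ONNX packs
# the per-parameter tensors interleaved, e.g. Adagrad with two params gets
# (X1, X2, G1, G2, H1, H2); with input_group_size=3 that's groups=2, and inputs[0::2] / inputs[1::2]
# rebuild (X1, G1, H1) and (X2, G2, H2) before calling func once per parameter.
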
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
  X, G, H = (i.detach() for i in inputs)
  grad = norm_coefficient * X + G
  H.assign(H + grad.square())
  up = grad / (H.sqrt() + epsilon)
  r = R / (1 + T * decay_factor)
  X.assign(X.detach() - r * up)
  return [X, H]

@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
         norm_coefficient_post:float=0.0):
  from tinygrad.nn.optim import Adam as TinyAdam
  X, G, V, H = inputs
  G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
  X.grad = norm_coefficient * X.detach() + G
  opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
  opt.m, opt.v, opt.lr = [V], [H], R
  # need no-op for m_hat and v_hat if T == 0
  if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
  else:
    # `T-1` since it's applied again at the start of `_step`
    opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
    opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
  opt.step()
  X = (1 - norm_coefficient_post) * X
  return [X, V, H]

@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
  from tinygrad.nn.optim import SGD
  X, G, V = inputs
  G, V = G.detach(), V.detach()
  X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
  opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
  opt.b, opt.lr = [V], R
  opt.step()
  return [X, V]

def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
  intermediate_tensors[y].backward()
  return tuple([t.grad for t in inputs])
@@ -94,10 +94,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  def test_literal_one_pow(self):
    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  # this fails because of DETACH, it shouldn't
  # update: passes after CONST(VIEW(DEVICE)) in tensor
  # TODO: pow simplification
  def test_tensor_one_pow(self):
    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
    _check_ast_count(1, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))

# folds advanced indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):

@@ -3,7 +3,7 @@ import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, temp, CI
from tinygrad.helpers import Timing, fetch, temp, CI, OSX
from tinygrad.device import is_dtype_supported

def compare_weights_both(url):
@@ -298,6 +298,7 @@ class TestDiskTensor(unittest.TestCase):
    ret = t.bitcast(dtypes.uint16).to("CLANG") + 1
    assert ret.tolist() == [2827, 3341, 3855, 4369]

  @unittest.skipIf(OSX, "new LLVM has an issue on OSX")
  def test_bf16_disk_write_read(self):
    t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
    t.to(f"disk:{temp('dt_bf16_disk_write_read_f32')}").realize()

@@ -620,7 +620,7 @@ class Kernel:
      if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
        local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
        st = store_st = ShapeTracker.from_shape(local_shape)
        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i + 1}")
        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
        if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
        local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
        srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
@@ -648,7 +648,7 @@ class Kernel:
                    (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
      st_uop = ShapeTracker.from_shape(local_shape).to_uop()
      local_size = st_uop.arg.real_size()
      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)+1}")
      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
      local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
      grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
      if op is self.reduceops[-1]: return grouped_reduce

@@ -43,13 +43,28 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) ->

from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites

def get_axis(root:UOp):
  if root.op is Ops.MULTI: return root.arg[0]
  # NOTE: they all have to share an axis, we always choose [-1]
  if root.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in root.src if x.axis is not None])) else None
  src_axis = get_axis(root.src[0])
  if root.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in root.arg[1] else src_axis
  if root.op is Ops.RESHAPE:
    if src_axis is None: return None
    arg_acc:list[sint] = list(itertools.accumulate(root.arg, operator.mul, initial=1))
    # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
    # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
    return len(arg_acc) - arg_acc[::-1].index(prod(root.src[0].shape[:src_axis])) - 1
  if root.op is Ops.PERMUTE: return root.arg.index(src_axis) if src_axis is not None else None
  raise NotImplementedError("rest should be passthrough")

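# a quick sanity check of get_axis on a PERMUTE (illustrative, not from the source): if the
# source is sharded on axis 0 and the permute arg is (1, 0), the sharded data now lives on
# axis (1, 0).index(0) == 1, which is exactly what the PERMUTE branch returns.
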
def alu_multi(root:UOp):
  msrcs = root.src
  assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

  # NOTE: they all have to share an axis, we always choose [-1]
  axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
  axis = get_axis(root)
  bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
  srcs:list[list[UOp]] = []
  not_all_real = not all(all(mlb.real) for mlb in msrcs)
  new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
@@ -64,28 +79,24 @@ def alu_multi(root:UOp):
  return UOp.multi(*new_lbs, axis=axis, real=new_real)

def reduce_multi(root:UOp, multi:UOp):
  op, axis = root.arg
  (op, axis), new_axis = root.arg, get_axis(root)
  if multi.axis is not None and multi.axis in axis:
    # all-reduce on sharded axes
    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
    # if all partitions are real, do all_reduce
    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None)
    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=new_axis)
    # only one partition is real, keep it
    return UOp.multi(*reduced_parts, axis=None, real=multi.real)
    return UOp.multi(*reduced_parts, axis=new_axis, real=multi.real)
  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real)
  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=new_axis, real=multi.real)

def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
  return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))

def reshape_multi(root:UOp, multi:UOp):
  arg = root.arg
  if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real)
  arg, new_axis = root.arg, get_axis(root)
  if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
  arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
  # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
  # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
  new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1
  assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
    f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
  lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
@@ -109,7 +120,7 @@ def pad_multi(root:UOp, multi:UOp):

def permute_multi(root:UOp, multi:UOp):
  # all permutes supported!
  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real)
  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=get_axis(root), real=multi.real)

def shrink_multi(root:UOp, multi:UOp):
  assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \

@@ -13,7 +13,257 @@ from tinygrad.device import Buffer
# creation can recurse a lot
sys.setrecursionlimit(10000)

# **** ScheduleItem return type
# **** schedule simplifier

def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
  if not all_int(x.shape): return None
  # remove reduce on unmasked const
  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
  ret = x.const_arg
  match reduce.arg[0]:
    case Ops.ADD: ret *= prshape
    case Ops.MUL: ret **= prshape
    case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
    case _: return None
  return reduce.const_like(ret)

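# a quick numeric sketch of the rule above (illustrative, not from the source): summing a
# broadcasted constant 2 over 6 reduced elements folds to the constant 12, a product to 64,
# and a max stays 2, all without materializing the reduce:
#   Ops.ADD: 2 * 6  -> 12
#   Ops.MUL: 2 ** 6 -> 64
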
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
  new_src = list(alu.src)
  for i,s in enumerate(alu.src):
    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))

sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
  # no COPY to same device, except clone (arg is True)
  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
  # remove CONST/BIND/BUFFER from SINK
  (UPat(Ops.SINK, name="root"),
   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])

remove_movement_ops = merge_views+PatternMatcher([
  # NOTE: movement ops are always applied to base
  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
  (UPat(Ops.VIEW, name="view"),
   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])

# **** UOp realization

@dataclass(frozen=True)
class ScheduleContext:
  tensor_uops: dict[UOp, list[UOp]] # this maps BUFFER uops of this schedule to the tensor uop
  assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
  realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
  allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
  ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))

# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
  if (r:=cache.get(buf)) is not None: return r
  # SINK is passthrough
  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
  # skip creating buffers for CONST/BIND/DEVICE/BUFFER
  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
  # VIEW is passthrough
  if buf is not buf.base:
    cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
    return ret
  # make things that can't be images not images
  dtype = buf.dtype
  if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
    dtype = buf.dtype.base
  # ASSIGN already has a target buffer, otherwise we create a new one
  assert isinstance(buf.device, str), f"buf device must be str, not {buf.device}"
  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
  # track the buffer uop for the simplified uop
  buffer_map[buf] = buf_uop
  # (early) bufferize
  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
  return ret

class UPatScheduled(UPat):
  def __init__(self, *args, **kwargs):
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))

def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store

def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  st = unwrap(view.st)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
  # early realize before expand
  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
  # otherwise safety check pads
  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)

def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
  del ctx.realizes[b]
  return x.view(unwrap(view.st))

def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))

do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
  # don't realize image to image casts
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
   fold_img_cast),
  # realize before COPY or BUFFER_VIEW
  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])

def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  ctx.allbufs[buf_uop] = view
  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  for x in op.base.src:
    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])

def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
  assert is_scheduled(u), f"must be a scheduled op {u}"
  return u.src[1]

def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
  rsize = unwrap(allbufs[r].st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
    if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)

def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
  # start by adding uops that always realize
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
  for r, r_uop in ctx.allbufs.items():
    if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
    if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
    if r in ctx.realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    # can only have one output
    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x in ctx.assigns for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
        if p in ctx.realizes: continue
        parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(r_uop.st)
        while len(ctx.children[tr]) == 1:
          tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
          tr = tr_uop.src[0].base.buf_uop
      group = {tr: None}
      ctx.realizes[tr] = tr
    reduce_for_op.update((tr, r) for tr in group)
    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
      # maybe fuse arange with its children
      if len(flatten(ctx.children[tr] for tr in group)) != 0:
        for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
  graph_rewrite(sink, break_sched, ctx)
  return ctx.realizes
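The practical effect of grouping is visible from the public API: elementwise ops fold into the reduce kernel that consumes them. A quick check (a sketch, assuming a working tinygrad install; the exact item count is scheduler-dependent, not a guarantee):

from tinygrad import Tensor
a = Tensor.empty(64, 64)
# the *2 and the sum typically group into one ScheduleItem; a second reduce over the
# result would start a new kernel (max one reduceop per kernel)
print(len(((a * 2).sum(axis=1)).schedule()))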
# break the SINK into stores

def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))

def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
  if b not in ctx.realizes: return x # collapse BUFFER
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))

break_sched = PatternMatcher([
  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
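PRELOAD only matters when a kernel both reads and assigns the same buffer: the read must be ordered before the in-place store. From the public API that situation looks like this (a sketch, assuming a working tinygrad install):

from tinygrad import Tensor
x = Tensor.zeros(4).contiguous().realize()
# the read of x inside x+1 becomes a PRELOAD, so toposort keeps it before the ASSIGN's store
x.assign(x + 1).realize()
print(x.tolist())  # [1.0, 1.0, 1.0, 1.0]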
# **** ScheduleItem creation

@dataclass(frozen=True)
class ScheduleItem:
@@ -31,49 +281,11 @@ class ScheduleItem:
  @functools.cached_property
  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
# **** Schedule context and big graph

@dataclass(frozen=True)
class ScheduleContext:
  tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the tensor uop
  assigns: set[UOp] = field(default_factory=set)                   # this holds all the BUFFER uops we ASSIGN to in this schedule
  realizes: dict[UOp, UOp] = field(default_factory=dict)           # this holds all the BUFFER uops we mutate in this schedule
  allbufs: dict[UOp, UOp] = field(default_factory=dict)            # this maps BUFFER uops to the actual op
  ops_metadata: dict[UOp, Metadata] = field(default_factory=dict)  # this maps fused ops to Metadata
  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
  if (r:=cache.get(buf)) is not None: return r
  # SINK is passthrough
  if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
  # skip creating buffers for CONST/BIND/DEVICE/BUFFER
  if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
  if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
  # VIEW is passthrough
  if buf is not buf.base:
    cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
    return ret
  # make things that can't be images not images
  dtype = buf.dtype
  if isinstance(dtype, ImageDType) and (prod(buf.shape)!=prod(dtype.shape) or not any(buf.shape[x]%4==0 for x in unwrap(buf.st).unit_stride_axes())):
    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
    dtype = buf.dtype.base
  # ASSIGN already has a target buffer, otherwise we create a new one
  assert isinstance(buf.device, str), f"buf device must be str, not {buf.device}"
  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
  # track the underlying tensor uop for this buffer
  ctx.tensor_uops[buf_uop] = tensor_map[buf]
  # (early) bufferize
  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
  return ret
# **** AST graph rewrite

# ** movement ops
class ScheduleItemContext:
  var_vals: dict[Variable, int]
  sts: set[ShapeTracker] = field(default_factory=set)
  bufs: list[UOp] = field(default_factory=list)

def apply_swizzle(u:UOp) -> UOp:
  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -127,14 +339,6 @@ view_right = merge_views+PatternMatcher([
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])

# ** ScheduleItem context builder

@dataclass(frozen=True)
class ScheduleItemContext:
  var_vals: dict[Variable, int]
  sts: set[ShapeTracker] = field(default_factory=set)
  bufs: list[UOp] = field(default_factory=list)

def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
  if (st:=unwrap(x.st)) in ctx.sts: return None
  st, var_vals = st.simplify().unbind()
@@ -203,222 +407,7 @@ if CAPTURE_PROCESS_REPLAY:
  def save_process_replay() -> None:
    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** Schedule grouping

def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
  assert is_scheduled(u), f"must be a scheduled op {u}"
  return u.src[1]

def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
  rsize = unwrap(allbufs[r].st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
    if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
def group_realizes(ctx:ScheduleContext) -> None:
  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
  for r, r_uop in ctx.allbufs.items():
    if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
    if FUSE_CONV_BW and is_scheduled((x:=r_uop.src[0]).base) and uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
    if r in ctx.realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    # can only have one output
    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x in ctx.assigns for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
        if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
        if p in ctx.realizes: continue
        parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(r_uop.st)
        while len(ctx.children[tr]) == 1:
          tr_next_uop = uval(ctx.allbufs[(tr_next:=next(iter(ctx.children[tr])))])
          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if (tr_uop:=uval(ctx.allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
          tr = tr_uop.src[0].base.buf_uop
      group = {tr: None}
      ctx.realizes[tr] = tr
    reduce_for_op.update((tr, r) for tr in group)
    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
      # maybe fuse arange with its children
      if len(flatten(ctx.children[tr] for tr in group)) != 0:
        for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
    if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
# **** Schedule creation and BFS toposort

# ** this is schedule level const folding

def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
  if not all_int(x.shape): return None
  # remove reduce on unmasked const
  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
  ret = x.const_arg
  match reduce.arg[0]:
    case Ops.ADD: ret *= prshape
    case Ops.MUL: ret **= prshape
    case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
    case _: return None
  return reduce.const_like(ret)
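The const-fold is plain arithmetic over the reduced axes: an ADD-reduce of an unmasked constant c over axes of total extent prshape is c*prshape, a MUL-reduce is c**prshape, and MAX passes c through. A quick check of that arithmetic (plain Python, not the scheduler itself):

from math import prod
c, reduced_shape = 2.0, (3, 4)
prshape = prod(reduced_shape)
print(c * prshape, c ** prshape, max([c]))  # ADD -> 24.0, MUL -> 4096.0, MAX -> 2.0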
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
  new_src = list(alu.src)
  for i,s in enumerate(alu.src):
    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
  # no COPY to same device, except clone (arg is True)
  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
  # remove CONST/BIND/BUFFER from SINK
  (UPat(Ops.SINK, name="root"),
   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
# ** this decides which ops get realized

class UPatScheduled(UPat):
  def __init__(self, *args, **kwargs):
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))

def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store

def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  st = unwrap(view.st)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
  # early realize before expand
  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
  # otherwise safety check pads
  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)

def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
  del ctx.realizes[b]
  return x.view(unwrap(view.st))

def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
  # don't realize image to image casts
  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
   fold_img_cast),
  # realize before COPY or BUFFER_VIEW
  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp

def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))

def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
  if b not in ctx.realizes: return x # collapse BUFFER
  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))

break_sched = PatternMatcher([
  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** Schedule context builder

def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  ctx.allbufs[buf_uop] = view
  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  for x in op.base.src:
    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
# **** movement ops

remove_movement_ops = merge_views+PatternMatcher([
  # NOTE: movement ops are always applied to base
  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
  (UPat(Ops.VIEW, name="view"),
   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])
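Movement ops (reshape, permute, expand, pad, shrink, flip) never touch data; they only rewrite the view's ShapeTracker, which is what the first rule above encodes. Observable from the public API (a sketch, assuming a working tinygrad install):

from tinygrad import Tensor
x = Tensor.arange(6).reshape(2, 3)
y = x.permute(1, 0)           # no kernel runs for the permute itself: it's just a view rewrite
print(y.shape, y.numpy()[0])  # (3, 2) [0 3]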
# **** schedule creation and toposort

@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
@@ -434,15 +423,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
    elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v

  # we group the rest of UOps into ScheduleItems
  rev_tensor_map: dict[UOp, list[UOp]] = {}
  for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
  # add BUFFER uops
  sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
  # add realizes
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
  # group realizes into kernels
  group_realizes(ctx)
  graph_rewrite(sink, break_sched, ctx)
  buffer_map: dict[UOp, UOp] = {}
  sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
  # get realizes
  buf_tensors: dict[UOp, list[UOp]] = {}
  for k,v in tensor_map.items():
    if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k)
  realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors))

  # TODO: this should be the break between the "grouper" and the "linearizer"
  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
@@ -451,11 +438,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
  # create schedule items + map buffers to realized tensors
  prescheduled: list[ScheduleItem] = []
  var_vals: dict[Variable, int] = {}
  for buf_uop,store in ctx.realizes.items():
  for buf_uop,store in realize_map.items():
    assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
    prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
    # can only schedule once
    for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
    for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
    # increment refcount for this buffer
    buf_uop.buffer.ref(1)
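End to end, this is what Tensor.schedule() drives: lazy tensor uops go in, a toposorted list of ScheduleItems comes out. A minimal driver from the public API (a sketch; the exact number of items depends on the scheduler, not a guarantee):

from tinygrad import Tensor
a, b = Tensor.empty(16, 16), Tensor.empty(16, 16)
sched = (a @ b).relu().schedule()  # list of ScheduleItems, roughly one per kernel/copy
print(type(sched[0]).__name__, len(sched))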
@@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
  # uops that aren't rendered
  SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702
  SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702

  # TODO: empty continues to exist because of tensor
  EMPTY = auto()
@@ -1,25 +1,9 @@
import platform, tempfile, pathlib, subprocess, sys
from tinygrad.helpers import cpu_objdump, capstone_flatdump
import platform, subprocess, sys
from tinygrad.helpers import capstone_flatdump
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.renderer.cstyle import ClangRenderer

# Used by ops_dsp.py
class ClangCompiler(Compiler):
  def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
    self.args = ['-march=native'] if args is None else args
    self.objdump_tool = objdump_tool
    super().__init__(cachekey)

  def compile(self, src:str) -> bytes:
    # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
    with tempfile.NamedTemporaryFile(delete=True) as output_file:
      subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
                               '-', '-o', str(output_file.name)], input=src.encode('utf-8'))
      return pathlib.Path(output_file.name).read_bytes()

  def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)

class ClangJITCompiler(Compiler):
  def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
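The moved ClangCompiler shells out to the system clang, compiles stdin into a temporary shared object, and returns its bytes. A hedged usage sketch, assuming a host clang on PATH and the post-move import location:

from tinygrad.runtime.ops_dsp import ClangCompiler

cc = ClangCompiler()  # defaults: -march=native, cache key "compile_clang"
lib = cc.compile("int add(int a, int b) { return a + b; }")
print(len(lib))  # size in bytes of the compiled shared object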
@@ -1,12 +1,11 @@
from __future__ import annotations
from typing import Tuple, Any, List
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
from tinygrad.runtime.ops_clang import ClangCompiler
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
@@ -91,10 +90,23 @@ class DSPAllocator(Allocator):
  def _copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
  def _offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)

class DSPDevice(Compiled):
  def __init__(self, device:str=""):
    self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
class ClangCompiler(Compiler):
  def __init__(self, cachekey="compile_clang", args:list[str]|None=None, objdump_tool='objdump'):
    self.args = ['-march=native'] if args is None else args
    self.objdump_tool = objdump_tool
    super().__init__(cachekey)

  def compile(self, src:str) -> bytes:
    # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
    with tempfile.NamedTemporaryFile(delete=True) as output_file:
      subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
                               '-', '-o', str(output_file.name)], input=src.encode('utf-8'))
      return pathlib.Path(output_file.name).read_bytes()

  def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)

class DSPCompiler(ClangCompiler):
  def __init__(self):
    # Generate link script to pass into clang. Aligning all used sections to 4k fixes the invoke problem.
    sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
    sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
@@ -103,15 +115,19 @@ class DSPDevice(Compiled):
    self.link_ld.flush()

    compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"]
    super().__init__(device, DSPAllocator(self), DSPRenderer(),
                     ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
    return super().__init__("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump')
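The linker-script line above uses doubled braces to emit literal { } from an f-string, giving every section 4k alignment. A quick stdlib check of what sections_link expands to:

sections = ['text', 'data']
print('\n'.join(f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections))
# .text : ALIGN(4096) { *(.text) }
# .data : ALIGN(4096) { *(.data) }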
class DSPDevice(Compiled):
  def __init__(self, device:str=""):
    self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
    super().__init__(device, DSPAllocator(self), DSPRenderer(), DSPCompiler(), functools.partial(DSPProgram, self))

    fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
    self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True))
    ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)

    self.init_dsp()
    RPCListner(self).start()
    RPCListener(self).start()

  def open_lib(self, lib):
    self.binded_lib, self.binded_lib_off = lib, 0
@@ -149,7 +165,7 @@ class DSPDevice(Compiled):
    qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd)
    qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0))

class RPCListner(threading.Thread):
class RPCListener(threading.Thread):
  def __init__(self, device:DSPDevice):
    super().__init__()
    self.device, self.daemon = device, True
@@ -11,12 +11,15 @@ def expect(x, err, ret=None):
  if x: raise RuntimeError(llvm.string_cast(err.contents) if not isinstance(err, str) else err)
  return ret

HOST_ARCH = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()]
HOST_TRIPLE = {'AArch64': 'aarch64', 'X86': 'x86_64'}[HOST_ARCH]
REQUIRED_COMPONENTS = ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']

class LLVMCompiler(Compiler):
  def __init__(self, target_machine, opt):
  def __init__(self, host_arch:str, opt:bool):
    for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{host_arch}{component}')()
    triple = ({'AArch64': 'aarch64', 'X86': 'x86_64'}[host_arch]+'-none-unknown-elf').encode()

    target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt)
    target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', b'+reserve-x18' if platform.machine() == 'arm64' else b'',
                                                  llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC, llvm.LLVMCodeModelDefault)

    self.pbo = llvm.LLVMCreatePassBuilderOptions()
    if opt:
      self.passes = b'default<O2>'
@@ -48,14 +51,5 @@ class LLVMCompiler(Compiler):

class LLVMDevice(Compiled):
  def __init__(self, device:str):
    for component in REQUIRED_COMPONENTS:
      getattr(llvm, f'LLVMInitialize{HOST_ARCH}{component}')()

    triple = f'{HOST_TRIPLE}-none-unknown-elf'.encode()
    target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt)
    features = b'+reserve-x18' if platform.machine() == 'arm64' else b''
    target_machine = llvm.LLVMCreateTargetMachine(target, triple, b'', features, llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocPIC,
                                                  llvm.LLVMCodeModelDefault)

    super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None),
                     LLVMCompiler(target_machine, getenv("LLVMOPT")), CPUProgram)
    compiler = LLVMCompiler({'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()], bool(getenv("LLVMOPT")))
    super().__init__(device, MallocAllocator, LLVMRenderer('win64cc' if sys.platform == 'win32' else None), compiler, CPUProgram)
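The refactor folds target setup into LLVMCompiler itself: pass the LLVM arch name and an opt flag, and the compiler initializes its own components and target machine. A hedged construction sketch (host mapping copied from the diff; assumes tinygrad.runtime.ops_llvm is importable on your platform):

import platform
from tinygrad.runtime.ops_llvm import LLVMCompiler

host_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()]
compiler = LLVMCompiler(host_arch, opt=False)  # the compiler now owns component init + target machine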
@@ -3314,10 +3314,10 @@ class Tensor(SimpleMathTrait):
    if not base.is_floating_point(): raise RuntimeError("base needs to be float")
    # start with b ** e = exp(e * log(b))
    ret = base.abs().log().mul(exponent).exp()
    # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
    # correct sign of negative base with odd exponent
    negative_base = (base < 0).detach().where(1, 0)
    # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
    correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
    correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base)
    # inject nan for negative base and non-integer exponent
    inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
    # apply correct_sign, inject_nan, and fix 0 ** 0 = 1
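The replaced cos trick and the new parity test compute the same sign. Worked through for base = -2, exponent = 3 (plain Python, mirroring the formulas in the diff):

import math
base, exponent = -2.0, 3.0
ret = math.exp(exponent * math.log(abs(base)))                       # 8.0
negative_base = 1.0 if base < 0 else 0.0
old_sign = 1 + negative_base * (math.cos(exponent * math.pi) - 1)    # -1.0 via the cos trick
new_sign = 1 if int(exponent) % 2 == 0 else 1 - 2 * negative_base    # -1.0 via parity
print(ret * new_sign, old_sign == new_sign)  # -8.0 True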
@@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
               Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
               Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff",
               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
               Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}