Onnx 1.15.0 gogogo (#2217)

* lol

* lol

* add Gelu

* onnx 1.50

* fix torch bool neg

* exclude regex tests

* exclude dequantizelinear for now

* is sunny in philly

* damn it affinegrid

* fixed auto_pad VALID

* skip 0 shape tests

* add temporary cast in Reduces

* tests should pass now

* added comments and cleanup

* try moving dequantizelinear to onnx.py

* fixed dequantizelinear?

* cleanup

* try?

* float16 segfaults LLVM CI..???

* cleanup comments

* pin to 1.50.0

* remove use of -np.inf since numpy is being removed

* fix version typo: 1.15.0, not 1.50

* thanks for the review, my bad

* moved Gelu higher up
geohotstan
2023-11-11 07:36:48 +08:00
committed by GitHub
parent 85d26ddc36
commit b853e9bb8c
6 changed files with 123 additions and 40 deletions

extra/onnx.py

@@ -3,9 +3,9 @@ from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from typing import List,Dict
from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto, TypeProto
from tinygrad.helpers import getenv, DEBUG, dtypes
from typing import List, Dict
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
@@ -51,7 +51,7 @@ def get_run_onnx(onnx_model: ModelProto):
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type in (1,10,6,7):
if inp.data_type in (1,10,6,7,5):
# TODO: this is shared with below
if len(inp.float_data) > 0:
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
@@ -74,7 +74,7 @@ def get_run_onnx(onnx_model: ModelProto):
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}")
elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
else: raise Exception(f"can't parse {a.type} {a}")
def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}
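
For illustration, a minimal sketch (using onnx.helper.make_attribute, part of onnx's public API) of what attribute_to_dict produces:

from onnx.helper import make_attribute
attrs = [make_attribute("axis", 1), make_attribute("starts", [0, 2])]
# INT attributes parse to int, INTS to a tuple of ints:
# attribute_to_dict(attrs) -> {"axis": 1, "starts": (0, 2)}
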
@@ -144,8 +144,10 @@ def get_run_onnx(onnx_model: ModelProto):
inp.append(t)
opt: Dict = attribute_dict[num]
if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
# some ops live here because they require some local variables
if n.op_type == "Split": # have to use n.output for cases when num_outputs is absent
# NOTE some ops live here because they require access to some local variables
# have to use n.output for cases when num_outputs is absent
if n.op_type == "Split":
axis = opt.get("axis", 0)
split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])]
if split is None:
@@ -159,7 +161,9 @@ def get_run_onnx(onnx_model: ModelProto):
ret.append(inp[0].shrink(arg=tuple(arg)))
i = i+s
ret = tuple(ret)
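
As a minimal illustration of the shrink-based splitting above (plain numpy, not the committed code): splitting a (5, 4) array into chunks of sizes [2, 3] along axis 0 takes one shrink per chunk at the running offset.

import numpy as np
x = np.arange(20).reshape(5, 4)
sizes, start, outs = [2, 3], 0, []
for s in sizes:
    outs.append(x[start:start+s])  # shrink(arg=((start, start+s), (0, 4))) in tinygrad terms
    start += s
assert [o.shape for o in outs] == [(2, 4), (3, 4)]
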
elif n.op_type == "Slice": # need to check onnx_model_version
# need to check onnx_model_version
elif n.op_type == "Slice":
if onnx_model_version < 10:
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
else:
@@ -177,11 +181,15 @@ def get_run_onnx(onnx_model: ModelProto):
new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg]))
elif n.op_type == "Gradient": # need to call backward on intermediate_tensors
# need to call backward on intermediate_tensors
elif n.op_type == "Gradient":
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} inputs and outputs have to match"
y = opt["y"]
intermediate_tensors[y].backward()
ret = tuple([t.grad for t in inp])
# onnx_ops.py
elif hasattr(onnx_ops, n.op_type):
fxn = getattr(onnx_ops, n.op_type)
if isinstance(fxn, dict):
@@ -194,6 +202,7 @@ def get_run_onnx(onnx_model: ModelProto):
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) <= len(ret), f"expected number of outputs to be at most {len(ret)}, got {n.output}"
if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
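
A hedged usage sketch of the runner this function returns ("model.onnx" and the input name are placeholders; assuming the returned callable takes a dict of input names to Tensors and returns a dict of named outputs):

import onnx
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor

model = onnx.load("model.onnx")                          # hypothetical model file
run_onnx = get_run_onnx(model)
out = run_onnx({"input": Tensor.randn(1, 3, 224, 224)})  # {output_name: Tensor}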

extra/onnx_ops.py

@@ -1,8 +1,9 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.onnx_pb import TensorProto
from onnx import TensorProto
import io
import os
import numpy as np
import functools
@@ -64,6 +65,7 @@ def Tanh(x): return x.tanh()
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
def HardSwish(input: Tensor): return input * HardSigmoid(input, 1/6, 0.5)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
def Celu(X: Tensor, alpha=1.0): return X.relu() - (-alpha*(X/alpha).exp()+1).relu()
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def Softplus(X: Tensor): return X.softplus()
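
A quick numeric sanity check of the two Gelu branches above (pure-Python sketch): the erf form and the tanh approximation agree closely for moderate inputs.

import math
def gelu_erf(x): return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))
def gelu_tanh(x): return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
assert abs(gelu_erf(1.0) - gelu_tanh(1.0)) < 1e-3  # ~2e-5 in practice
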
@@ -178,6 +180,7 @@ def Expand(input: Tensor, shape):
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
return input.reshape(x_shape).expand(shape_ret)
# **************** Complex Ops ****************
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
@@ -238,19 +241,26 @@ def _format_padding(onnx_pads, ndims=None, axes=None):
np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
return np_pads
def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None):
if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None, ceil_mode=0):
if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations, ceil_mode)
if pads is None: return X
pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
return X.pad(tuple(pads), value=constant_value)
def _auto_pad(X, auto_pad, strides, kernel_shape, dilations):
# TODO: works but is hacky and messy, think of a cleaner way to do this
def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations, ceil_mode=0):
strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(strides):], strides, kernel_shape, dilations)]
if auto_pad == "SAME_UPPER": return [pad_shape[0]//2, pad_shape[1]//2, pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2]
elif auto_pad == "SAME_LOWER": return [pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2, pad_shape[0]//2, pad_shape[1]//2]
else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented, yet")
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape])
return pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
elif auto_pad == "VALID":
out_spatial_shape = [math.ceil((sh - ((ker-1)*dil+1)) / st) + 1 if ceil_mode else math.floor((sh - ((ker-1)*dil+1)) / st) + 1 for sh, st, ker, dil in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
pad_shape = [(osh-1)*st+((ks-1)*dil+1)-ish for osh, st, ks, dil, ish in zip(out_spatial_shape, strides, kernel_shape, dilations, X.shape[-len(kernel_shape):])]
pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape])
return pad_shape[::2] + pad_shape[1::2]
else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented")
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0])
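
A worked example of the SAME_UPPER arithmetic above (illustrative numbers): spatial size 5, stride 2, kernel 3, dilation 1.

import math
sh, st, ks, di = 5, 2, 3, 1
total = (math.ceil(sh / st) - 1) * st + ((ks - 1) * di + 1) - sh  # = 2
begin, end = total // 2, total - total // 2                       # SAME_UPPER -> (1, 1)
assert (begin, end) == (1, 1)
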
@@ -297,7 +307,7 @@ def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
if dilations != 1: raise NotImplementedError(f"dilations != 1 not supported, dilations:{dilations}")
pixel_axes = tuple(range(len(X.shape)))[-2:]
pixel_axes = tuple(range(len(X.shape)))[-len(kernel_shape):]
if ceil_mode: auto_pad = "SAME_UPPER"
padding_included = _padding(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
if count_include_pad:
@@ -307,8 +317,8 @@ def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_i
return padding_included / div
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
if ceil_mode: auto_pad = "SAME_UPPER"
ret = _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations)
if ceil_mode and auto_pad == "NOTSET": auto_pad="VALID"
ret = _padding(X, pads, auto_pad, constant_value=float("-inf"), axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
ret_len, X_len = ret.numel(), X.numel()
indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
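
Why the pad constant is float("-inf") rather than 0 (numpy sketch): a zero pad would win the max over all-negative windows, while -inf cells can never be selected.

import numpy as np
x = np.array([[-3., -1.], [-2., -4.]])
assert np.pad(x, 1, constant_values=float("-inf")).max() == -1.0
assert np.pad(x, 1, constant_values=0.).max() == 0.0  # wrong pooled value
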
@@ -429,6 +439,7 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
def Round(X:Tensor): return _round(X, 0.5, "round_to_even")
# TODO clean this up, it's taking the longest in CI
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
@@ -579,21 +590,67 @@ def EyeLike(x: Tensor, dtype=None, k=0):
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
# Needs work
def IsInf(x,detect_negative=1,detect_positive=1):
def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
ret = (x == float("inf"))*detect_positive + (x == float("-inf"))*detect_negative + Tensor.zeros(*x.shape)
return ret.cast(dtypes.bool)
# Needs work
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point=0, axis=1):
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1):
axis = axis + x.ndim if axis < 0 else axis
x = x.cast(dtypes.float)
if isinstance(x_zero_point, Tensor): x_zero_point = x_zero_point.cast(dtypes.float)
x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
return (x - x_zer) * x_sc
return ((x - x_zer) * x_sc).cast(x_scale.dtype)
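# worked example: x=[130, 135], x_scale=0.5, x_zero_point=128 -> (x - 128) * 0.5 = [1.0, 3.5]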
# Needs work
def IsNaN(x):
def IsNaN(x: Tensor):
return (x < float("-inf")).cast(dtypes.bool)
# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
# without importing PIL we'll have to manually decode a bunch of image formats like PNG, JPEG, WebP, etc
def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
try:
import PIL.Image
except ImportError as e:
raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(safe_numpy(encoded_stream).tobytes()))
if pixel_format == "BGR":
return Tensor(np.array(img))[:, :, ::-1]
elif pixel_format == "RGB":
return Tensor(np.array(img))
elif pixel_format == "Grayscale":
img = img.convert("L")
decoded = Tensor(np.array(img))
return decoded.unsqueeze(-1) # (H, W) to (H, W, 1)
else:
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
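# theta is a batch of affine matrices: (N, 2, 3) for 2D or (N, 3, 4) for 3D;
# size is (N, C, H, W) or (N, C, D, H, W); the output grid is (N, H, W, 2) or (N, D, H, W, 3)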
_, _, *data_sz = safe_numpy(size).tolist()
size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz)
stackable = [original_grid]
for dim, dim_sz in enumerate(data_sz):
a = Tensor.arange(-1, 1.0001, 2/(dim_sz-1)) if align_corners == 1 else Tensor.arange(-1+1/dim_sz, 1, 2/dim_sz)
if dim == 0: stackable = [a.reshape(dim_sz, *[1]*(len(data_sz)-1)) + size_zeros, *stackable]
elif dim == 1: stackable = [a.reshape(1, dim_sz, *[1]*(len(data_sz)-2)) + size_zeros, *stackable]
else: stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable]
original_grid = Tensor.stack(stackable, dim=len(data_sz))
if original_grid.ndim == 3:
N, dim_2d, dim_homo = theta.shape
assert dim_2d == 2 and dim_homo == 3
H, W, dim_homo = original_grid.shape
assert dim_homo == 3
original_grid = original_grid.reshape(H*W, dim_homo).transpose()
return theta.matmul(original_grid).permute(0,2,1).reshape(N, H, W, dim_2d)
else:
assert original_grid.ndim == 4
N, dim_3d, dim_homo = theta.shape
assert dim_3d == 3 and dim_homo == 4
D, H, W, dim_homo = original_grid.shape
assert dim_homo == 4
original_grid = original_grid.reshape(D*H*W, dim_homo).transpose()
return theta.matmul(original_grid).permute(0,2,1).reshape(N, D, H, W, dim_3d)
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):

setup.py

@@ -44,7 +44,7 @@ setup(name='tinygrad',
"pillow",
"pytest",
"pytest-xdist",
"onnx==1.14.1",
"onnx==1.15.0",
"onnx2torch",
"opencv-python",
"tabulate",

test/external/external_test_onnx_backend.py

@@ -93,9 +93,15 @@ backend_test.exclude('test_lstm_*')
backend_test.exclude('test_simple_rnn_*')
# no control flow
# control flow uses AttributeProto.GRAPH
backend_test.exclude('test_if_*')
backend_test.exclude('test_loop*')
backend_test.exclude('test_range_float_type_positive_delta_expanded_cpu') # requires loop
backend_test.exclude('test_affine_grid_2d_align_corners_expanded_cpu')
backend_test.exclude('test_affine_grid_2d_expanded_cpu')
backend_test.exclude('test_affine_grid_3d_align_corners_expanded_cpu')
backend_test.exclude('test_affine_grid_3d_expanded_cpu')
backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')
# unsupported (strange) ops
backend_test.exclude('test_bitwise_*')
@@ -139,20 +145,26 @@ backend_test.exclude('test_resize_upsample_scales_cubic_*') # unsure how to impl
backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to implement cubic
# rest of the failing tests
backend_test.exclude('test_averagepool_2d_dilations_cpu') # dilations != 1 not supported for avgpool
backend_test.exclude('test_averagepool_2d_dilations_*') # dilations != 1 not supported for avgpool in tensor.py
backend_test.exclude('test_averagepool_3d_dilations_*') # dilations != 1 not supported for avgpool in tensor.py
backend_test.exclude('test_regex_*') # does not support string Tensors
backend_test.exclude('test_convtranspose_autopad_same_cpu') # TODO geohotstan has no idea how this is done, autopad requires output_shape but output_shape requires pads from autopad
backend_test.exclude('test_optional_has_element_empty_optional_input_cpu') # Attempts to create Tensor from None
backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu') # AttributeProto.GRAPH not implemented
backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to 0 shape
backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0
backend_test.exclude('test_reduce_min_empty_set_cpu') # reducing a tensor with 0 in shape
backend_test.exclude('test_reduce_sum_empty_set_non_reduced_axis_zero_cpu') # reducing a tensor with 0 in shape
backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip
backend_test.exclude('test_operator_addconstant_cpu') # bad data type
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
# issue 1556 https://github.com/tinygrad/tinygrad/issues/1556
backend_test.exclude('test_isinf_cpu')
backend_test.exclude('test_isinf_negative_cpu')
backend_test.exclude('test_isinf_positive_cpu')
backend_test.exclude('test_isinf_float16_cpu')
backend_test.exclude('test_isnan_float16_cpu')
backend_test.exclude('test_isnan_cpu')
# issue 1791 fast math messes with these https://github.com/tinygrad/tinygrad/issues/1791
@@ -165,15 +177,19 @@ if getenv('METAL'):
backend_test.exclude('test_maxpool_2d_pads_cpu')
backend_test.exclude('test_maxpool_2d_same_lower_cpu')
# Don't know how to treat special TensorProto like TensorProto.FLOAT8E4M3FN
if getenv("CPU") or getenv("TORCH"):
backend_test.exclude('test_dequantizelinear_axis_cpu')
backend_test.exclude('test_dequantizelinear_cpu')
# compiled backends cannot reshape to and from 0
# compiled backends cannot reshape to or from 0
if getenv('LLVM') or getenv('GPU') or getenv('CLANG') or getenv('METAL') or getenv('CUDA'):
backend_test.exclude('test_slice_start_out_of_bounds_cpu')
backend_test.exclude('test_constantofshape_int_shape_zero_cpu')
backend_test.exclude('test_reduce_l1_empty_set_cpu')
backend_test.exclude('test_reduce_sum_empty_set_cpu')
backend_test.exclude('test_reduce_l1_empty_set_expanded_cpu')
backend_test.exclude('test_reduce_sum_square_empty_set_cpu')
backend_test.exclude('test_reduce_l2_empty_set_cpu')
backend_test.exclude('test_reduce_sum_square_empty_set_expanded_cpu')
backend_test.exclude('test_reduce_l2_empty_set_expanded_cpu')
backend_test.exclude('test_reduce_log_sum_empty_set_cpu')
backend_test.exclude('test_reduce_log_sum_empty_set_expanded_cpu')
if getenv('GPU') or getenv('METAL'):
backend_test.exclude('test_mish_cpu') # weird inaccuracy
@@ -184,6 +200,7 @@ if getenv('GPU') or getenv('METAL'):
if (getenv('LLVM') or getenv('CUDA')) and CI:
backend_test.exclude('test_max_float16_cpu')
backend_test.exclude('test_min_float16_cpu')
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
# disable model tests for now since they are slow
if not getenv("MODELTESTS"):
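
For context, a minimal sketch of how such exclusions plug into onnx's stock backend-test runner (DummyBackend is a stand-in here; the real file defines a tinygrad Backend subclass):

import onnx.backend.base
import onnx.backend.test

class DummyBackend(onnx.backend.base.Backend):  # stand-in backend for illustration
    @classmethod
    def supports_device(cls, device: str) -> bool: return device == "CPU"

backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)
backend_test.exclude('test_regex_*')  # string Tensors unsupported, as above
# the runner exposes generated unittest cases for the test collector to pick up
globals().update(backend_test.enable_report().test_cases)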

tinygrad/runtime/ops_cpu.py

@@ -10,7 +10,7 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
base_fxn_for_op: Dict[Op, Callable] = {
BufferOps.MEM: lambda x: x._buf, UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
BufferOps.MEM: lambda x: x._buf, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
@@ -33,7 +33,7 @@ def einsum_mulacc(einsum, get_strides, expand):
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
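
The NEG change mirrors numpy's own restriction (quick check): negating a boolean array raises a TypeError, so boolean negation must go through logical_not.

import numpy as np
b = np.array([True, False])
# np.negative(b) raises TypeError: the `-` operator is not supported on booleans
assert (np.logical_not(b) == np.array([False, True])).all()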

tinygrad/runtime/ops_torch.py

@@ -21,7 +21,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device),
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])),
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)), BinaryOps.SUB: lambda x,y: torch.logical_xor(x, y) if y.dtype is torch.bool else torch.sub(x, y),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), # pylint: disable=E1102
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
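
The torch side has the same restriction (quick check): torch.neg on a bool tensor raises a RuntimeError, hence the logical_not branch added above.

import torch
b = torch.tensor([True, False])
# torch.neg(b) raises RuntimeError on bool tensors
assert torch.equal(torch.logical_not(b), torch.tensor([False, True]))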