mirror of
https://github.com/tinygrad/tinygrad.git
Onnx Layer/Group/RMS/Batch-Norm ReduceL2 fp32 intermediates for fp16 (#12109)
* match onnx spec
* use least_upper_dtype
* promote the square
* just cast before the square
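For context, fp16 tops out at about 65504, so squaring and summing half-precision values inside an L2 reduction or a normalization overflows long before the final result does. A minimal numpy sketch of that failure mode, using the same shape and scale as the new test below (illustration only, not tinygrad code):

import numpy as np

# squaring and summing 32*32*32 values of magnitude ~100 blows past the fp16 max (~65504)
x = (np.random.randn(1, 1, 32, 32, 32) * 100).astype(np.float16)

naive = np.sqrt(np.square(x).sum(dtype=np.float16))                           # all-fp16: overflows to inf
promoted = np.sqrt(np.square(x.astype(np.float32)).sum()).astype(np.float16)  # fp32 intermediate, cast back

print(naive)     # inf
print(promoted)  # finite, close to the true L2 norm (~1.8e4)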
test/external/external_test_onnx_ops.py (vendored): 6 changed lines
@@ -272,6 +272,10 @@ class TestMainOnnxOps(TestOnnxOps):
   def test_qlinearmatmul_2D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 2)
   def test_qlinearmatmul_3D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 3)
 
+  def test_reduce_l2_half(self):
+    inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
+    self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
+
 class TestTrainingOnnxOps(TestOnnxOps):
   # NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
   DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
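A quick check on the test's magnitudes: each element is roughly N(0, 100^2), so E[x^2] is about 1e4, and summing 32*32*32 = 32768 of them gives a sum of squares near 3.3e8, far above the fp16 maximum of 65504, while the final L2 norm (about 1.8e4) fits in fp16 comfortably. Squaring and reducing in fp32 and casting only the final sqrt back to half is therefore enough to keep this test finite.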
@@ -487,4 +491,4 @@ class TestContribOnnxOps(TestOnnxOps):
     self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
 
 if __name__ == "__main__":
-  unittest.main()
+  unittest.main()
@@ -5,7 +5,7 @@ from io import BufferedReader
 from tinygrad.nn.state import TensorIO
 from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
 from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN
-from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate
+from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate, least_upper_dtype
 from tinygrad.device import is_dtype_supported, Device
 
 # ***** protobuf definitions ******
@@ -670,7 +670,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
     return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
   def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-    return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
+    dtype = dtypes.float if data.dtype in (dtypes.float16, dtypes.bfloat16) else data.dtype
+    return ReduceSum(data.cast(dtype).square(), axes, keepdims, noop_with_empty_axes).sqrt().cast(data.dtype)
   def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
     return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
   def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
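A standalone sketch of the promoted ReduceL2 pattern above, written against the public tinygrad Tensor API; the helper name reduce_l2_fp32 and its axis/keepdim defaults are illustrative, not part of the diff:

from tinygrad import Tensor, dtypes

def reduce_l2_fp32(data: Tensor, axis=None, keepdim=False) -> Tensor:
  # square and reduce in fp32 only when the input is half/bfloat16, then cast the result back
  dtype = dtypes.float if data.dtype in (dtypes.float16, dtypes.bfloat16) else data.dtype
  return data.cast(dtype).square().sum(axis=axis, keepdim=keepdim).sqrt().cast(data.dtype)

x = Tensor.randn(1, 1, 32, 32, 32, dtype=dtypes.float16) * 100
print(reduce_l2_fp32(x).item())  # finite; squaring the same values directly in fp16 can overflow to inf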
@@ -897,7 +898,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                          training_mode:int=0, spatial=1, is_test=0):
     if training_mode:
-      x_detached = X.detach()
+      x_detached = X.detach().cast(least_upper_dtype(X.dtype, dtypes.float32))
       current_mean = x_detached.mean(axis=(0,2,3))
       y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
       current_var = (y*y).mean(axis=(0,2,3))
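least_upper_dtype, newly imported above, returns the join of its arguments on tinygrad's dtype-promotion lattice, so the detached input is widened to at least float32 without downcasting inputs that are already wider. A small sketch of the behavior this line relies on (output comments are indicative):

from tinygrad.dtype import dtypes, least_upper_dtype

print(least_upper_dtype(dtypes.float16, dtypes.float32))   # float32: half inputs get fp32 batch stats
print(least_upper_dtype(dtypes.bfloat16, dtypes.float32))  # float32
print(least_upper_dtype(dtypes.float64, dtypes.float32))   # float64: wider inputs are left alone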
@@ -906,18 +907,20 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
       running_mean = input_mean * momentum + current_mean * (1 - momentum)
       running_var = input_var * momentum + current_var * (1 - momentum)
 
-      return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
+      return X.batchnorm(scale, B, current_mean, current_invstd).cast(X.dtype),running_mean.cast(input_mean.dtype),running_var.cast(input_var.dtype)
     return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt())
-  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
-    x = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape)
+  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
+    x = x.reshape(x.shape[0], num_groups, -1).cast(dtypes.float).layernorm(eps=epsilon).cast(x.dtype).reshape(x.shape)
     return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2))
   def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
     return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon)
   def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
     assert stash_type == 1, "only float32 is supported"
     axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
-    mean = x.mean(axis=axes, keepdim=True)
-    return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
+    mean = (x32:=x.cast(dtypes.float)).mean(axis=axes, keepdim=True)
+    inv_std_dev = (x32.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
+    return (x32.sub(mean)*inv_std_dev).cast(x.dtype).mul(scale).add(bias), mean, inv_std_dev
   def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
     x = x + skip
     if bias is not None: x = x + bias
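A hedged numpy restatement of the LayerNormalization change (the function name and shapes below are mine, not the diff's): with stash_type == 1 the statistics are computed and returned in fp32, and only the normalized activations are cast back to the input dtype before scale and bias are applied.

import numpy as np

def layer_norm_fp32_stash(x, scale, bias, axis=-1, epsilon=1e-5):
  axes = tuple(range(axis % x.ndim, x.ndim))
  x32 = x.astype(np.float32)                       # compute statistics in fp32
  mean = x32.mean(axis=axes, keepdims=True)
  inv_std_dev = 1.0 / np.sqrt(np.square(x32 - mean).mean(axis=axes, keepdims=True) + epsilon)
  y = ((x32 - mean) * inv_std_dev).astype(x.dtype) * scale + bias  # output back in the input dtype
  return y, mean, inv_std_dev                      # Mean and InvStdDev stay fp32

x = (np.random.randn(2, 4, 8) * 100).astype(np.float16)
scale, bias = np.ones(8, np.float16), np.zeros(8, np.float16)
y, mean, inv_std = layer_norm_fp32_stash(x, scale, bias)
print(y.dtype, mean.dtype, inv_std.dtype)  # float16 float32 float32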
@@ -1089,9 +1092,10 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     return output, present_key, present_value, qk_matmul_return_val
   Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib}
 
-  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5):
-    norm = X.square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
-    return X * norm * scale
+  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
+    norm = X.cast(dtypes.float).square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
+    return X.cast(X.dtype) * norm * scale
 
   def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None,
                       rotary_embedding_dim:int=0):
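Similarly, a hedged numpy restatement of the RMSNormalization change (helper name and shapes are illustrative): the mean of squares runs in fp32, so individual fp16 values above roughly 256 no longer overflow when squared.

import numpy as np

def rms_norm_fp32(x, scale, axis=-1, epsilon=1e-5):
  axes = tuple(range(axis % x.ndim, x.ndim))
  # mean of squares in fp32, mirroring the stash_type == 1 path above
  norm = 1.0 / np.sqrt(np.square(x.astype(np.float32)).mean(axis=axes, keepdims=True) + epsilon)
  return x * norm * scale  # the fp16 input promotes against the fp32 norm

x = (np.random.randn(2, 8) * 300).astype(np.float16)  # elements above ~256 overflow if squared in fp16
print(rms_norm_fp32(x, np.ones(8, np.float16)))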