mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
Add MLPerf UNet3D model (#775)
* Add ResNet inference test and cannon
* Test with ResNet50
* test_car works with resnet fix
* Add KiTS19 dataset
* KiTS19: Implement iterate
* No batch load for this dataset
* Save results on iterate
* Implement dice score
* Add data prep and eval functions
* Resolve shape issue
* Conversion works but wrong values
* Segfaults when load_from_pretrained is called
* Fix segfault and assign properly
* Final result generated, though very slow
* Store and load final result to save time
* Fix typo in finalize
* Score computes
* More bug fixes, dice score is very low
* Working broken code
* Assign output values to result
* Getting a much higher score now
* Fix dataset preprocessing
* Mean DICE score of 88.5
* Ugh, typo
* Attempt to reimplement model
* Rename layers
* Tiny model works, kinda
* Accuracy? gone
* Implement InstanceNorm and match torch
* Test instance norm 2d and 3d
* Combined input block with downsample block
* Tiny model works, support strided convtranspose
* Commands to download dataset
* Clean up a bit
* unet3d_v2 -> unet3d
* Remove duplicated code
* Oops, put tests back
.gitignore (vendored, +1)
@@ -23,6 +23,7 @@ disassemblers/cuda_ioctl_sniffer
 datasets/cifar-10-python.tar.gz
 datasets/librispeech/
 datasets/imagenet/
+datasets/kits19/
 datasets/squad/
 datasets/img_align_celeba*
 datasets/open-images-v6-mlperf
||||
131
datasets/kits19.py
Normal file
131
datasets/kits19.py
Normal file
@@ -0,0 +1,131 @@
import random
import functools
from pathlib import Path
import requests
import numpy as np
import nibabel as nib
from scipy import signal
import torch
import torch.nn.functional as F
from tinygrad.tensor import Tensor

BASEDIR = Path(__file__).parent.parent.resolve() / "datasets" / "kits19" / "data"

"""
To download the dataset:
```sh
git clone https://github.com/neheller/kits19
cd kits19
pip3 install -r requirements.txt
python3 -m starter_code.get_imaging
cd ..
mv kits19 datasets
```
"""

@functools.lru_cache(None)
def get_val_files():
  # the MLPerf validation split is published as a list of case ids
  data = requests.get("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt")
  return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.text.split("\n")])

def load_pair(file_path):
  image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
  image_spacings = image.header["pixdim"][1:4].tolist()
  image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
  image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
  return image, label, image_spacings

def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
  if tuple(image_spacings) != target_spacing:  # compare as tuples; a list never equals a tuple, so the old check always resampled
    spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
    new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
    image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
    label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
    image = np.squeeze(image.numpy(), axis=0)
    label = np.squeeze(label.numpy(), axis=0)
  return image, label
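For intuition, resampling rescales each axis by source spacing over target spacing. A minimal sketch of the shape arithmetic with hypothetical spacings (not taken from the dataset):

```python
import numpy as np

image_spacings, target_spacing = [3.0, 0.8, 0.8], (1.6, 1.2, 1.2)  # hypothetical case
shape = np.array([100, 512, 512])                                  # (D, H, W) in voxels
new_shape = (np.array(image_spacings) / np.array(target_spacing) * shape).astype(int)
print(new_shape.tolist())  # [187, 341, 341]: thick slices get upsampled, fine axes downsampled
```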
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
  image = np.clip(image, min_clip, max_clip)
  image = (image - mean) / std
  return image
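Clipping first means the window ends map to fixed z-scores. A quick check with synthetic CT values (assumes the repo root is on PYTHONPATH so datasets.kits19 imports):

```python
import numpy as np
from datasets.kits19 import normal_intensity

hu = np.array([-500.0, -79.0, 101.0, 304.0, 1000.0], dtype=np.float32)
print(normal_intensity(hu).round(2))  # [-2.34 -2.34  0.    2.64  2.64]
```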
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
  current_shape = image.shape[1:]
  bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
  paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
  image = np.pad(image, paddings, mode="edge")
  label = np.pad(label, paddings, mode="edge")
  return image, label
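A sanity check of the padding arithmetic on a hypothetical under-sized scan:

```python
import numpy as np
from datasets.kits19 import pad_to_min_shape

image = np.zeros((1, 90, 512, 512), dtype=np.float32)  # only 90 slices deep, ROI needs 128
label = np.zeros((1, 90, 512, 512), dtype=np.uint8)
image, label = pad_to_min_shape(image, label)
print(image.shape)  # (1, 128, 512, 512): the 38 missing slices are edge-replicated, split 19/19
```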
def preprocess(file_path):
  image, label, image_spacings = load_pair(file_path)
  image, label = resample3d(image, label, image_spacings)
  image = normal_intensity(image.copy())
  image, label = pad_to_min_shape(image, label)
  return image, label

def iterate(val=True, shuffle=False):
  if not val: raise NotImplementedError
  files = get_val_files()
  order = list(range(0, len(files)))
  if shuffle: random.shuffle(order)
  for file in [files[i] for i in order]:  # iterate in the (possibly shuffled) order; the order list was previously computed but unused
    X, Y = preprocess(file)
    X = np.expand_dims(X, axis=0)
    yield (X, Y)

def gaussian_kernel(n, std):
  gaussian_1d = signal.gaussian(n, std)
  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
  gaussian_3d = gaussian_3d.reshape(n, n, n)
  gaussian_3d = np.cbrt(gaussian_3d)
  gaussian_3d /= gaussian_3d.max()
  return gaussian_3d
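The two outer products leave g(i)·g(j)·g(k) at each voxel; the cube root flattens that product back into a single Gaussian falloff, and the final divide pins the center weight to 1.0. A quick check:

```python
from datasets.kits19 import gaussian_kernel

k = gaussian_kernel(128, 0.125 * 128)
print(k.shape, k.max())            # (128, 128, 128) 1.0
print(k[0, 0, 0] < k[64, 64, 64])  # True: patch borders are downweighted relative to the center
```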
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
  bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
  bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
  # F.pad takes (before, after) pairs starting from the last dimension, hence the reversed axis order
  paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
  return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
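For example, with spatial dims that are not multiples of the stride (hypothetical sizes):

```python
import numpy as np
from datasets.kits19 import pad_input

volume = np.zeros((1, 1, 190, 300, 300), dtype=np.float32)
padded, paddings = pad_input(volume, roi_shape=(128, 128, 128), strides=(64, 64, 64))
print(padded.shape)  # (1, 1, 192, 320, 320): every spatial dim rounded up to a multiple of 64
```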
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
  from tinygrad.jit import TinyJit
  mdl_run = TinyJit(lambda x: model(x).realize())
  image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
  strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
  bounds = [image_shape[i] % strides[i] for i in range(dim)]
  bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
  inputs = inputs[
    ...,
    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  ]
  labels = labels[
    ...,
    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  ]
  inputs, paddings = pad_input(inputs, roi_shape, strides)
  padded_shape = inputs.shape[2:]
  size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
  result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
  norm_patch = np.expand_dims(norm_patch, axis=0)
  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
        out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
        result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
        norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
  result /= norm_map  # overlapping patches are blended by their accumulated gaussian weights
  result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
  return result, labels

if __name__ == "__main__":
  for X, Y in iterate():
    print(X.shape, Y.shape)
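A minimal end-to-end sketch of how these pieces compose, mirroring the eval_unet3d function later in this commit; it assumes the dataset has been downloaded as described above and that the pretrained weights download succeeds:

```python
from models.unet3d import UNet3D
from datasets.kits19 import iterate, sliding_window_inference

mdl = UNet3D()
mdl.load_from_pretrained()
image, label = next(iterate())                    # image: (1, 1, D, H, W) normalized volume
pred, label = sliding_window_inference(mdl, image, label)
print(pred.shape)                                 # (1, 3, D', H', W') blended class logits
```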
examples/mlperf/helpers.py
@@ -1,5 +1,35 @@
 from collections import OrderedDict
 import unicodedata
 import numpy as np
+from scipy import signal
+
+def gaussian_kernel(n, std):
+  gaussian_1d = signal.gaussian(n, std)
+  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
+  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
+  gaussian_3d = gaussian_3d.reshape(n, n, n)
+  gaussian_3d = np.cbrt(gaussian_3d)
+  gaussian_3d /= gaussian_3d.max()
+  return gaussian_3d
+
+def prepare_arrays(image, roi_shape=(128, 128, 128)):
+  assert len(roi_shape) == 3 and any(roi_shape)
+  image_shape = list(image.shape[2:])
+  result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
+  norm_map = np.zeros_like(result)
+  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
+  return result, norm_map, norm_patch
+
+def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
+  assert len(roi_shape) == 3 and any(roi_shape)
+  assert 0 < overlap_factor < 1
+  image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
+  strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
+  size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
+  for i in range(0, strides[0] * size[0], strides[0]):
+    for j in range(0, strides[1] * size[1], strides[1]):
+      for k in range(0, strides[2] * size[2], strides[2]):
+        yield i, j, k
+
 def _get_best_indices(logits, n_best_size):
   index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
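These two helpers capture the same gaussian-blended sliding window that datasets/kits19.py implements inline. A minimal sketch with a stand-in model (the zero "network" here is hypothetical, just to keep the example self-contained):

```python
import numpy as np
from examples.mlperf.helpers import prepare_arrays, get_slice

roi = (128, 128, 128)
image = np.zeros((1, 1, 256, 256, 256), dtype=np.float32)       # already padded to stride multiples
model = lambda patch: np.zeros((1, 3, *roi), dtype=np.float32)  # stand-in for a real network
result, norm_map, norm_patch = prepare_arrays(image, roi)
for i, j, k in get_slice(image, roi, overlap_factor=0.5):
  out = model(image[..., i:i+roi[0], j:j+roi[1], k:k+roi[2]])
  result[..., i:i+roi[0], j:j+roi[1], k:k+roi[2]] += out * norm_patch
  norm_map[..., i:i+roi[0], j:j+roi[1], k:k+roi[2]] += norm_patch
result /= norm_map  # every voxel is covered here, so the division is safe
```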
examples/mlperf/metrics.py
@@ -1,6 +1,7 @@
 import re
 import string
 from collections import Counter
+import numpy as np

 def levenshtein(a, b):
   n, m = len(a), len(b)
@@ -28,6 +29,22 @@ def word_error_rate(x, y):
     scores += levenshtein(h_list, r_list)
   return float(scores) / words, float(scores), words

+def one_hot(arr, num_classes=3):
+  res = np.eye(num_classes)[np.array(arr).reshape(-1)]
+  arr = res.reshape(list(arr.shape) + [num_classes])
+  arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
+  return arr
+
+def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
+  channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
+  prediction = prediction.argmax(axis=channel_axis)
+  prediction, target = one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
+  intersection = np.sum(prediction * target, axis=reduce_axis)
+  target_sum = np.sum(target, axis=reduce_axis)
+  prediction_sum = np.sum(prediction, axis=reduce_axis)
+  result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
+  return result[0]
+
 def normalize_string(s):
   s = "".join(c for c in s.lower() if c not in string.punctuation)
   s = re.sub(r'\b(a|an|the)\b', ' ', s)
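Reading the metric end to end: argmax picks a class per voxel, one_hot moves classes onto a channel axis, the background channel (index 0) is dropped, and overlap is reduced over the spatial axes. A toy check on a 2x2x2 volume:

```python
import numpy as np
from examples.mlperf.metrics import get_dice_score

prediction = np.zeros((1, 3, 2, 2, 2), dtype=np.float32)
prediction[:, 1] = 1.0                     # argmax picks class 1 at every voxel
target = np.ones((1, 2, 2, 2), dtype=np.uint8)
print(get_dice_score(prediction, target))  # [1. 1.]: class 1 overlaps perfectly; class 2 is empty
                                           # in both volumes, which the smoothing terms score as 1.0
```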
examples/mlperf/model_eval.py
@@ -4,6 +4,7 @@ import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
 from tinygrad.helpers import getenv
+from examples.mlperf import helpers

 def eval_resnet():
   # Resnet50-v1.5
@@ -41,6 +42,24 @@ def eval_resnet():
     d += len(t)
     print(f"****** {n}/{d} {n*100.0/d:.2f}%")
     st = time.perf_counter()

+def eval_unet3d():
+  # UNet3D
+  from models.unet3d import UNet3D
+  from datasets.kits19 import iterate, sliding_window_inference
+  from examples.mlperf.metrics import get_dice_score
+  mdl = UNet3D()
+  mdl.load_from_pretrained()
+  s = 0
+  st = time.perf_counter()
+  for i, (image, label) in enumerate(iterate(), start=1):
+    mt = time.perf_counter()
+    pred, label = sliding_window_inference(mdl, image, label)
+    et = time.perf_counter()
+    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
+    s += get_dice_score(pred, label).mean()
+    print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
+    st = time.perf_counter()
+
 def eval_retinanet():
   # RetinaNet with ResNeXt50_32X4D
examples/mlperf/model_spec.py
@@ -28,7 +28,8 @@ def spec_unet3d():
   # 3D UNET
   from models.unet3d import UNet3D
   mdl = UNet3D()
-  img = Tensor.randn(1, 1, 5, 224, 224)
+  mdl.load_from_pretrained()
+  img = Tensor.randn(1, 1, 128, 128, 128)
   test_model(mdl, img)

 def spec_rnnt():
models/unet3d.py
@@ -1,42 +1,59 @@
 # https://github.com/wolny/pytorch-3dunet
 from pathlib import Path
-from extra.utils import download_file, fake_torch_load, get_child
-import tinygrad.nn as nn
+import torch
+from tinygrad import nn
 from tinygrad.tensor import Tensor
+from extra.utils import download_file, get_child

-class SingleConv:
-  def __init__(self, in_channels, out_channels):
-    self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group?
-    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
-    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False)
-  def __call__(self, x):
-    return self.conv(self.groupnorm(x)).relu()
+class DownsampleBlock:
+  def __init__(self, c0, c1, stride=2):
+    self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]

-class BasicModule:
-  def __init__(self, c0, c1, c2):
-    self.basic_module = {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)}
   def __call__(self, x):
-    return self.basic_module['SingleConv2'](self.basic_module['SingleConv1'](x))
+    return x.sequential(self.conv1).sequential(self.conv2)

+class UpsampleBlock:
+  def __init__(self, c0, c1):
+    self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
+    self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+
+  def __call__(self, x, skip):
+    x = x.sequential(self.upsample_conv)
+    x = Tensor.cat(x, skip, dim=1)
+    return x.sequential(self.conv1).sequential(self.conv2)

 class UNet3D:
-  def __init__(self):
-    ups = [16,32,64,128,256]
-    self.encoders = [BasicModule(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]
-    self.decoders = [BasicModule(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)]
-    self.final_conv = nn.Conv2d(32, 1, (1,1,1))
+  def __init__(self, in_channels=1, n_class=3):
+    filters = [32, 64, 128, 256, 320]
+    inp, out = filters[:-1], filters[1:]
+    self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
+    self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
+    self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
+    self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
+    self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
+
   def __call__(self, x):
-    intermediates = [x]
-    for e in self.encoders: intermediates.append(e(intermediates[-1]))
-    ret = intermediates[-1]
-    for d,i in zip(self.decoders, intermediates[:-1][::-1]): ret = d(ret.cat(i, dim=1))
-    return ret
+    x = self.input_block(x)
+    outputs = [x]
+    for downsample in self.downsample:
+      x = downsample(x)
+      outputs.append(x)
+    x = self.bottleneck(x)
+    for upsample, skip in zip(self.upsample, outputs[::-1]):
+      x = upsample(x, skip)
+    x = self.output["conv"](x)
+    return x

   def load_from_pretrained(self):
-    fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt"
-    download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn)
-    state_dict = fake_torch_load(open(fn, "rb").read())['model_state_dict']
+    fn = Path(__file__).parent.parent / "weights" / "unet-3d.ckpt"
+    download_file("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
+    state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
     for k, v in state_dict.items():
-      print(k, v.shape)
       obj = get_child(self, k)
       assert obj.shape == v.shape, (k, obj.shape, v.shape)
       obj.assign(v.numpy())
+
+if __name__ == "__main__":
+  mdl = UNet3D()
+  mdl.load_from_pretrained()
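A rough sketch of the shapes through the new model at the MLPerf ROI size, the same way spec_unet3d and sliding_window_inference drive it:

```python
from tinygrad.tensor import Tensor
from models.unet3d import UNet3D

mdl = UNet3D()                   # randomly initialized; load_from_pretrained() is optional here
x = Tensor.randn(1, 1, 128, 128, 128)
out = mdl(x)
print(out.shape)                 # (1, 3, 128, 128, 128): per-voxel logits for background/kidney/tumor
```

The encoder halves each spatial dim five times (128 -> 4 at the bottleneck) while widening channels 32 -> 320; each UpsampleBlock then doubles resolution and concatenates the matching skip connection before the final 1x1x1 conv maps 32 channels to the 3 classes.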
test/test_nn.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 from tinygrad.jit import TinyJit
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding
+from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
 import torch

 class TestNN(unittest.TestCase):
@@ -168,6 +168,45 @@ class TestNN(unittest.TestCase):
     z = layer(x)
     torch_x = torch.tensor(x.cpu().numpy())
     torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
+
+  def test_instancenorm_2d(self):
+    N, C, H, W = 20, 5, 10, 10
+
+    # create in tinygrad
+    layer = InstanceNorm(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
+  def test_instancenorm_3d(self):
+    N, C, D, H, W = 20, 5, 3, 10, 10
+
+    # create in tinygrad
+    layer = InstanceNorm(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, D, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
   def test_embedding(self):
tinygrad/nn/__init__.py
@@ -78,6 +78,17 @@ class GroupNorm:
     # elementwise_affine on channels
     return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])

+class InstanceNorm:
+  def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
+    self.num_features, self.eps = num_features, eps
+    self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
+    self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
+
+  def __call__(self, x:Tensor):
+    # normalize each (sample, channel) slice over its spatial dims, then apply the per-channel affine
+    x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
+    if self.weight is None or self.bias is None: return x
+    return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
+
 class LayerNorm:
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
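One way to read the new layer: flattening the spatial dims and calling layernorm standardizes each (sample, channel) slice, then the affine parameters rescale per channel. A minimal sketch:

```python
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import InstanceNorm

norm = InstanceNorm(4)
x = Tensor.randn(2, 4, 8, 8, 8)  # the same layer handles 2D (N,C,H,W) or 3D (N,C,D,H,W) inputs
y = norm(x).numpy()
# with the default weight=1 and bias=0, every (sample, channel) slice is standardized
print(np.abs(y.mean(axis=(2, 3, 4))).max() < 1e-3)  # True
print(y.std(axis=(2, 3, 4)).round(2))               # ~1.0 everywhere (eps keeps it slightly below)
```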