vgg7 (image upscaling) implementation - not the best, but it works (#255)

* vgg7 implementation - not the best, but it works

* VGG7 implementation: Spread nansbane to deter NaNs, maybe improved training experience

* VGG7 implementation: Fix training, for real this time

Results actually attempt to approximate the input

* VGG7 implementation: Sample probability management
Author: 20kdc
Date: 2021-05-13 07:48:51 +01:00
Committed by: GitHub
Parent: 81bf933a91
Commit: 2653d33292

3 changed files with 505 additions and 0 deletions

extra/kinne.py (new file, +76)

@@ -0,0 +1,76 @@
from tinygrad.tensor import Tensor
import numpy
import os
# Format Details:
# A KINNE parameter set is stored as a set of files named "snoop_bin_*.bin",
# where the * is a number starting at 0.
# Each file is simply raw little-endian floats,
# as readable by: numpy.fromfile(path, "<f4")
# and as writable by: t.data.astype("<f4", "C").tofile(path)
# This format is intended to be extremely simple to get into literally anything.
# It is not intended to be structural or efficient - reloading a network when
# unnecessary is inefficient anyway.
# Ultimately, the idea behind this is as a format that, while it will always
# require code to implement, requires as little code as possible, and therefore
# works as a suitable interchange for any situation.
# To add to the usability of the format, some informal metadata is provided,
# in "meta.txt", which provides human-readable shape information.
# This is intended to help with debugging other implementations of the network,
# by providing concrete human-readable information on tensor shapes.
# It is NOT meant to be read by machines.
class KinneDir:
  """
  A KinneDir is an intermediate object used to save or load a model.
  """
  def __init__(self, base: str, save: bool):
    """
    Opens a new KINNE directory with the given base path.
    If save is true, the directory is created if possible
    (parents are not created), and tensors passed in later are saved;
    otherwise, they are loaded.
    The base path is of the form "models/abc" - no trailing slash.
    It is important that if you wish to save in the current directory,
    you use ".", not the empty string.
    """
    if save:
      try:
        os.mkdir(base)
      except FileExistsError:
        # The directory may already exist from an earlier save; that is fine.
        pass
    self.base = base + "/snoop_bin_"
    self.next_part_index = 0
    self.save = save
    if save:
      self.metadata = open(base + "/meta.txt", "w")

  def parameter(self, t: Tensor):
    """
    parameter loads or saves a single parameter, given as a tensor.
    """
    path = self.base + str(self.next_part_index) + ".bin"
    if self.save:
      t.data.astype("<f4", "C").tofile(path)
      self.metadata.write(str(self.next_part_index) + ": " + str(t.shape) + "\n")
    else:
      t.assign(Tensor(numpy.fromfile(path, "<f4")).reshape(shape=t.shape))
    self.next_part_index += 1

  def parameters(self, params):
    """
    parameters loads or saves a sequence of parameters.
    It's intended for easily attaching to an existing model,
    assuming that your parameter list's order is consistent.
    (In other words, usage with tinygrad.utils.get_parameters isn't advised -
    it's too 'implicit'.)
    """
    for t in params:
      self.parameter(t)

  def close(self):
    if self.save:
      self.metadata.close()
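
A minimal usage sketch (not part of the commit) for the format described in kinne.py's header comment, assuming the 2021-era tinygrad API in which a Tensor exposes its numpy buffer as .data:

from tinygrad.tensor import Tensor
from extra.kinne import KinneDir

# Two stand-in parameters; any model's tensor list works the same way.
w = Tensor.uniform(4, 4)
b = Tensor.zeros(1, 4)

# Save: "models" must already exist, since KinneDir does not create parents.
# This writes models/demo/snoop_bin_0.bin, snoop_bin_1.bin and meta.txt.
k = KinneDir("models/demo", True)
k.parameters([w, b])
k.close()

# Load: the same tensors, in the same order and with the same shapes.
k = KinneDir("models/demo", False)
k.parameters([w, b])
k.close()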

extra/waifu2x.py (new file, +178)

@@ -0,0 +1,178 @@
# Implementation of waifu2x vgg7 in tinygrad.
# Obviously, not developed, supported, etc. by the original waifu2x author(s).
import numpy
from tinygrad.tensor import Tensor
from PIL import Image
# File Formats
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
def image_load(path) -> numpy.ndarray:
  """
  Loads an image in the shape expected by other functions in this module.
  Doesn't Tensor it, in case you need to do further work with it.
  """
  # file
  na = numpy.array(Image.open(path))
  # fix shape
  na = numpy.moveaxis(na, [2, 0, 1], [0, 1, 2])
  # shape is now (3, h, w), add 1
  na = na.reshape(1, 3, na.shape[1], na.shape[2])
  # change type
  na = na.astype("float32") / 255.0
  return na

def image_save(path, na: numpy.ndarray):
  """
  Saves an image of the shape expected by other functions in this module.
  However, note this expects a numpy array.
  """
  # change type
  na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8")
  # shape is now (1, 3, h, w), remove 1
  na = na.reshape(3, na.shape[2], na.shape[3])
  # fix shape
  na = numpy.moveaxis(na, [0, 1, 2], [2, 0, 1])
  # shape is now (h, w, 3)
  # file
  Image.fromarray(na).save(path)
# The Model
class Conv3x3Biased:
  """
  A 3x3 convolution layer with some utility functions.
  """
  def __init__(self, inC, outC, last = False):
    # Massively overstate the weights to get them to be focused on,
    # since otherwise the biases overrule everything.
    self.weight = Tensor.uniform(outC, inC, 3, 3) * 16.0
    # Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is?
    # Values-wise, entirely different blatant cheat.
    # In most cases, use uniform bias, but tiny.
    # For the last layer, use just 0.5, constant.
    if last:
      self.bias = Tensor.zeros(1, outC, 1, 1) + 0.5
    else:
      self.bias = Tensor.uniform(1, outC, 1, 1)

  def forward(self, x):
    # You might be thinking, "but what about padding?"
    # Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it.
    return x.conv2d(self.weight).add(self.bias)
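
  # Shape note: with no padding, a (1, inC, h, w) input comes out as
  # (1, outC, h - 2, w - 2); an unpadded 3x3 conv trims one pixel from every edge.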

  def get_parameters(self) -> list:
    return [self.weight, self.bias]

  def load_waifu2x_json(self, layer: dict):
    # Weights in this file are outChannel,inChannel,X,Y.
    # Not outChannel,inChannel,Y,X.
    # Therefore, transpose it before assignment.
    # I have long since forgotten how I worked this out.
    self.weight.assign(Tensor(layer["weight"]).reshape(shape=self.weight.shape).transpose(order=(0, 1, 3, 2)))
    self.bias.assign(Tensor(layer["bias"]).reshape(shape=self.bias.shape))

class Vgg7:
  """
  The 'vgg7' waifu2x network.
  Lower quality and slower than even upconv7 (never mind cunet), but it is very easy to implement and test.
  """
  def __init__(self):
    self.conv1 = Conv3x3Biased(3, 32)
    self.conv2 = Conv3x3Biased(32, 32)
    self.conv3 = Conv3x3Biased(32, 64)
    self.conv4 = Conv3x3Biased(64, 64)
    self.conv5 = Conv3x3Biased(64, 128)
    self.conv6 = Conv3x3Biased(128, 128)
    self.conv7 = Conv3x3Biased(128, 3, True)

  def forward(self, x):
    """
    Forward pass: Actually runs the network.
    Input format: (1, 3, Y, X)
    Output format: (1, 3, Y - 14, X - 14)
    (the - 14 represents the 7-pixel context border that is lost)
    """
    x = self.conv1.forward(x).leakyrelu(0.1)
    x = self.conv2.forward(x).leakyrelu(0.1)
    x = self.conv3.forward(x).leakyrelu(0.1)
    x = self.conv4.forward(x).leakyrelu(0.1)
    x = self.conv5.forward(x).leakyrelu(0.1)
    x = self.conv6.forward(x).leakyrelu(0.1)
    x = self.conv7.forward(x)
    return x

  def get_parameters(self) -> list:
    return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()

  def load_waifu2x_json(self, data: list):
    """
    Loads weights from one of the waifu2x JSON files, e.g. waifu2x/models/vgg_7/art/noise0_model.json.
    data (passed in) is assumed to be the output of json.load or something similar on such a file.
    """
    self.conv1.load_waifu2x_json(data[0])
    self.conv2.load_waifu2x_json(data[1])
    self.conv3.load_waifu2x_json(data[2])
    self.conv4.load_waifu2x_json(data[3])
    self.conv5.load_waifu2x_json(data[4])
    self.conv6.load_waifu2x_json(data[5])
    self.conv7.load_waifu2x_json(data[6])

  def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
    """
    Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
    Note that you really shouldn't try to run anything not (1, 3, *, *) through this.
    """
    # Constant that just gets repeated a lot here.
    context = 7
    context2 = context + context

    # Notably, numpy is used here because it makes this fine manipulation a lot simpler.
    # Scaling first - repeat on axis 2 and axis 3 (Y & X).
    image = image.repeat(2, 2).repeat(2, 3)

    # Resulting image buffer. This is made before the input is padded,
    # since at this point the input still has the output (unpadded) shape.
    image_out = numpy.zeros(image.shape)

    # Padding next. Note that this padding is done on the whole image.
    # Padding the tiles would lose critical context and cause seams.
    image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")

    # Now for tiling.
    # The output tile size is the usable output produced from an input tile of size tile_size.
    # As such, the tiles overlap.
    out_tile_size = tile_size - context2
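    # Worked example: with tile_size = 100, out_tile_size = 100 - 14 = 86, so
    # output tiles are laid out every 86 pixels while each input tile spans 100
    # pixels, overlapping its neighbours by the 14 pixels of lost context.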
    for out_y in range(0, image_out.shape[2], out_tile_size):
      for out_x in range(0, image_out.shape[3], out_tile_size):
        # Input is sourced from the same coordinates, but a few things ought to be
        # noted here for future reference:
        # + out_x/y's equivalent position with the padding is out_x + context.
        # + The output, however, is without context. Input needs context.
        # + Therefore, the input rectangle is expanded on all sides by context.
        # + Therefore, the input position has the context subtracted again.
        # + Therefore:
        in_y = out_y
        in_x = out_x
        # not shown: in_w/in_h = tile_size (as opposed to out_tile_size)

        # Extract tile.
        # Note that numpy will auto-crop this at the bottom-right.
        # This will never be a problem, as tiles are specifically chosen within the padded section.
        tile = image[:, :, in_y:in_y + tile_size, in_x:in_x + tile_size]

        # Extracted tile dimensions -> output dimensions.
        # This is important because of said cropping; otherwise it'd be the interior tile size.
        out_h = tile.shape[2] - context2
        out_w = tile.shape[3] - context2

        # Process tile.
        tile_t = Tensor(tile)
        tile_fwd_t = self.forward(tile_t)

        # Replace tile.
        image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.data

    return image_out
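
For reference, an end-to-end sketch (not part of the commit) tying the two files together. It assumes the waifu2x JSON model path cited in the docstring above; the input and output filenames are placeholders:

import json
from extra.waifu2x import Vgg7, image_load, image_save

vgg7 = Vgg7()
with open("waifu2x/models/vgg_7/art/noise0_model.json") as f:
  vgg7.load_waifu2x_json(json.load(f))

# image_load returns a (1, 3, h, w) float32 array in [0, 1].
img = image_load("in.png")
# forward_tiled doubles the resolution; each 156-pixel input tile
# yields a 142-pixel (156 - 14) output tile.
out = vgg7.forward_tiled(img, 156)
image_save("out.png", out)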