From 2653d332920be2f1875c2c8f352de5ebc6e9b4e4 Mon Sep 17 00:00:00 2001
From: 20kdc
Date: Thu, 13 May 2021 07:48:51 +0100
Subject: [PATCH] vgg7 (image upscaling) implementation - not the best, but it
 works (#255)

* vgg7 implementation - not the best, but it works

* VGG7 implementation: Spread nansbane to deter NaNs, maybe improved training
  experience

* VGG7 implementation: Fix training, for real this time (results actually
  attempt to approximate the input)

* VGG7 implementation: Sample probability management
---
 examples/vgg7.py | 251 +++++++++++++++++++++++++++++++++++++++++++++++
 extra/kinne.py   |  76 ++++++++++++++
 extra/waifu2x.py | 178 +++++++++++++++++++++++++++++++++
 3 files changed, 505 insertions(+)
 create mode 100644 examples/vgg7.py
 create mode 100644 extra/kinne.py
 create mode 100644 extra/waifu2x.py

diff --git a/examples/vgg7.py b/examples/vgg7.py
new file mode 100644
index 0000000000..e705bfcede
--- /dev/null
+++ b/examples/vgg7.py
@@ -0,0 +1,251 @@
+from PIL import Image
+from tinygrad.tensor import Tensor
+from tinygrad.optim import SGD
+import extra.waifu2x
+from extra.kinne import KinneDir
+import sys
+import os
+import random
+import json
+import numpy
+
+# amount of context erased by model
+CONTEXT = 7
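+
+# Illustration only - a tiny helper that is not used by the commands below.
+# Each of the network's seven unpadded 3x3 convolutions erases one pixel per
+# edge (hence CONTEXT = 7), so a (1, 3, h, w) input yields a
+# (1, 3, h - 2 * CONTEXT, w - 2 * CONTEXT) output, e.g. 32x32 -> 18x18.
+def expected_output_size(h, w):
+  return (h - 2 * CONTEXT, w - 2 * CONTEXT)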
+
+def get_sample_count(samples_dir):
+  try:
+    with open(samples_dir + "/sample_count.txt", "r") as f:
+      return int(f.readline())
+  except Exception:
+    return 0
+
+def set_sample_count(samples_dir, sc):
+  with open(samples_dir + "/sample_count.txt", "w") as f:
+    f.write(str(sc) + "\n")
+
+if len(sys.argv) < 2:
+  print("python3 -m examples.vgg7 import MODELJSON MODELDIR")
+  print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
+  print(" into a directory of float binaries along with a meta.txt file containing tensor sizes")
+  print(" weight tensors are ordered in tinygrad/ncnn form, like so: (outC,inC,H,W)")
+  print(" *this format is used by all other commands in this program*")
+  print("python3 -m examples.vgg7 execute MODELDIR IMG_IN IMG_OUT")
+  print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
+  print(" output image has 7 pixels removed on all edges")
+  print(" do not run on large images, will have *hilarious* RAM use")
+  print("python3 -m examples.vgg7 execute_full MODELDIR IMG_IN IMG_OUT")
+  print(" does the 'whole thing' (padding, tiling)")
+  print(" safe for large images, etc.")
+  print("python3 -m examples.vgg7 new MODELDIR")
+  print(" creates a new model (experimental)")
+  print("python3 -m examples.vgg7 train MODELDIR SAMPLES_DIR ROUNDS ROUNDS_SAVE")
+  print(" trains a model (experimental)")
+  print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
+  print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
+  print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
+  print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
+  print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
+  print("  my_samples/0b.png is the first original image)")
+  print(" in addition, SAMPLES_DIR/sample_count.txt indicates sample count")
+  print(" won't pad or tile, so keep image sizes sane")
+  print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
+  print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
+  print(" maintains/creates sample_count.txt automatically")
+  print(" unlike training, IMG_A must be exactly half the size of IMG_B")
+  sys.exit(1)
+
+cmd = sys.argv[1]
+vgg7 = extra.waifu2x.Vgg7()
+
+def nansbane(p):
+  if numpy.isnan(numpy.min(p.data)):
+    raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
+
+def load_and_save(path, save):
+  if save:
+    for v in vgg7.get_parameters():
+      nansbane(v)
+  kn = KinneDir(path, save)
+  kn.parameters(vgg7.get_parameters())
+  kn.close()
+  if not save:
+    for v in vgg7.get_parameters():
+      nansbane(v)
+
+if cmd == "import":
+  src = sys.argv[2]
+  model = sys.argv[3]
+
+  vgg7.load_waifu2x_json(json.load(open(src, "rb")))
+
+  os.mkdir(model)
+  load_and_save(model, True)
+elif cmd == "execute":
+  model = sys.argv[2]
+  in_file = sys.argv[3]
+  out_file = sys.argv[4]
+
+  load_and_save(model, False)
+
+  extra.waifu2x.image_save(out_file, vgg7.forward(Tensor(extra.waifu2x.image_load(in_file))).data)
+elif cmd == "execute_full":
+  model = sys.argv[2]
+  in_file = sys.argv[3]
+  out_file = sys.argv[4]
+
+  load_and_save(model, False)
+
+  extra.waifu2x.image_save(out_file, vgg7.forward_tiled(extra.waifu2x.image_load(in_file), 156))
+elif cmd == "new":
+  model = sys.argv[2]
+
+  os.mkdir(model)
+  load_and_save(model, True)
+elif cmd == "train":
+  model = sys.argv[2]
+  samples_base = sys.argv[3]
+  samples_count = get_sample_count(samples_base)
+  rounds = int(sys.argv[4])
+  rounds_per_save = int(sys.argv[5])
+
+  load_and_save(model, False)
+
+  # Initialize sample probabilities.
+  # This is used to try and get the network to focus on "interesting" samples,
+  # which works nicely with the microsample system.
+  sample_probs = None
+  sample_probs_path = model + "/sample_probs.bin"
+  try:
+    # try to read...
+    sample_probs = numpy.fromfile(sample_probs_path, "<f8")
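+    # Sketch, and a loud assumption rather than anything this patch confirms:
+    # with a normalized probability vector like this one, the weighted pick of
+    # the next training sample would be a single call such as
+    #   sample_idx = numpy.random.choice(samples_count, p=sample_probs)
+    # with the probabilities of samples that recently produced a high loss
+    # nudged upward, then renormalized so the vector still sums to 1.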
diff --git a/extra/waifu2x.py b/extra/waifu2x.py
new file mode 100644
--- /dev/null
+++ b/extra/waifu2x.py
@@ -0,0 +1,178 @@
+import numpy
+from PIL import Image
+from tinygrad.tensor import Tensor
+
+def image_load(path) -> numpy.ndarray:
+  """
+  Loads an image in the shape expected by other functions in this module.
+  Doesn't Tensor it, in case you need to do further work with it.
+  """
+  # file
+  na = numpy.array(Image.open(path))
+  # fix shape
+  na = numpy.moveaxis(na, [2,0,1], [0,1,2])
+  # shape is now (3,h,w), add 1
+  na = na.reshape(1,3,na.shape[1],na.shape[2])
+  # change type
+  na = na.astype("float32") / 255.0
+  return na
+
+def image_save(path, na: numpy.ndarray):
+  """
+  Saves an image of the shape expected by other functions in this module.
+  However, note this expects a numpy array.
+  """
+  # change type
+  na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8")
+  # shape is now (1,3,h,w), remove 1
+  na = na.reshape(3,na.shape[2],na.shape[3])
+  # fix shape
+  na = numpy.moveaxis(na, [0,1,2], [2,0,1])
+  # shape is now (h,w,3)
+  # file
+  Image.fromarray(na).save(path)
+
+# The Model
+
+class Conv3x3Biased:
+  """
+  A 3x3 convolution layer with some utility functions.
+  """
+  def __init__(self, inC, outC, last = False):
+    # Massively overstate the weights to get them to be focused on,
+    # since otherwise the biases overrule everything
+    self.weight = Tensor.uniform(outC, inC, 3, 3) * 16.0
+    # Layout-wise this is the same cheat serious_mnist uses; presumably each
+    # bias dimension must either be 1 or match the target dimension.
+    # Values-wise, entirely different blatant cheat.
+    # In most cases, use uniform bias, but tiny.
+    # For the last layer, use just 0.5, constant.
+    if last:
+      self.bias = Tensor.zeros(1, outC, 1, 1) + 0.5
+    else:
+      self.bias = Tensor.uniform(1, outC, 1, 1)
+
+  def forward(self, x):
+    # You might be thinking, "but what about padding?"
+    # Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it.
+    return x.conv2d(self.weight).add(self.bias)
+
+  def get_parameters(self) -> list:
+    return [self.weight, self.bias]
+
+  def load_waifu2x_json(self, layer: dict):
+    # Weights in this file are outChannel,inChannel,X,Y.
+    # Not outChannel,inChannel,Y,X.
+    # Therefore, transpose the last two axes before assignment.
+    # I have long since forgotten how I worked this out.
+    self.weight.assign(Tensor(layer["weight"]).reshape(shape=self.weight.shape).transpose(order=(0, 1, 3, 2)))
+    self.bias.assign(Tensor(layer["bias"]).reshape(shape=self.bias.shape))
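+
+# Illustration only (not used by the model): an unpadded 3x3 convolution trims
+# one pixel from each edge, so Conv3x3Biased maps (1, inC, h, w) to
+# (1, outC, h - 2, w - 2). Chaining seven of these is exactly the 7-pixel
+# context border the full network loses.
+def _demo_conv_shrink():
+  conv = Conv3x3Biased(3, 32)
+  out = conv.forward(Tensor.zeros(1, 3, 16, 16))
+  assert out.shape == (1, 32, 14, 14)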
+
+class Vgg7:
+  """
+  The 'vgg7' waifu2x network.
+  Lower quality and slower than even upconv7 (never mind cunet), but it is very easy to implement and test.
+  """
+
+  def __init__(self):
+    self.conv1 = Conv3x3Biased(3, 32)
+    self.conv2 = Conv3x3Biased(32, 32)
+    self.conv3 = Conv3x3Biased(32, 64)
+    self.conv4 = Conv3x3Biased(64, 64)
+    self.conv5 = Conv3x3Biased(64, 128)
+    self.conv6 = Conv3x3Biased(128, 128)
+    self.conv7 = Conv3x3Biased(128, 3, True)
+
+  def forward(self, x):
+    """
+    Forward pass: Actually runs the network.
+    Input format: (1, 3, Y, X)
+    Output format: (1, 3, Y - 14, X - 14)
+    (the - 14 represents the 7-pixel context border that is lost)
+    """
+    x = self.conv1.forward(x).leakyrelu(0.1)
+    x = self.conv2.forward(x).leakyrelu(0.1)
+    x = self.conv3.forward(x).leakyrelu(0.1)
+    x = self.conv4.forward(x).leakyrelu(0.1)
+    x = self.conv5.forward(x).leakyrelu(0.1)
+    x = self.conv6.forward(x).leakyrelu(0.1)
+    x = self.conv7.forward(x)
+    return x
+
+  def get_parameters(self) -> list:
+    return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
+
+  def load_waifu2x_json(self, data: list):
+    """
+    Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json
+    data (passed in) is assumed to be the output of json.load or some similar on such a file
+    """
+    self.conv1.load_waifu2x_json(data[0])
+    self.conv2.load_waifu2x_json(data[1])
+    self.conv3.load_waifu2x_json(data[2])
+    self.conv4.load_waifu2x_json(data[3])
+    self.conv5.load_waifu2x_json(data[4])
+    self.conv6.load_waifu2x_json(data[5])
+    self.conv7.load_waifu2x_json(data[6])
+
+  def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
+    """
+    Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
+    Note that you really shouldn't try to run anything not (1, 3, *, *) through this.
+    """
+    # Constant that gets repeated a lot here.
+    context = 7
+    context2 = context + context
+
+    # Notably, numpy is used here because it makes this fine manipulation a lot simpler.
+    # Scaling first - repeat on axis 2 and axis 3 (Y & X)
+    image = image.repeat(2, 2).repeat(2, 3)
+
+    # Resulting image buffer. This is made before the input is padded,
+    # since the input has the padded shape right now.
+    image_out = numpy.zeros(image.shape)
+
+    # Padding next. Note that this padding is done on the whole image.
+    # Padding the tiles would lose critical context, cause seams, etc.
+    image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")
+
+    # Now for tiling.
+    # The output tile size is the usable output from an input tile (tile_size).
+    # As such, the tiles overlap.
+    out_tile_size = tile_size - context2
+    for out_y in range(0, image_out.shape[2], out_tile_size):
+      for out_x in range(0, image_out.shape[3], out_tile_size):
+        # Input is sourced from the same coordinates, but a few things are
+        # worth noting for future reference:
+        # + out_x/y's equivalent position w/ the padding is out_x + context.
+        # + The output, however, is without context. Input needs context.
+        # + Therefore, the input rectangle is expanded on all sides by context.
+        # + Therefore, the input position has the context subtracted again.
+        # + Therefore:
+        in_y = out_y
+        in_x = out_x
+        # not shown: in_w/in_h = tile_size (as opposed to out_tile_size)
+        # Extract tile.
+        # Note that numpy will auto-crop this at the bottom-right.
+        # This will never be a problem, as tiles are specifically chosen within the padded section.
+        tile = image[:, :, in_y:in_y + tile_size, in_x:in_x + tile_size]
+        # Extracted tile dimensions -> output dimensions
+        # This is important because of said cropping, otherwise it'd be interior tile size.
+        out_h = tile.shape[2] - context2
+        out_w = tile.shape[3] - context2
+        # Process tile.
+        tile_t = Tensor(tile)
+        tile_fwd_t = self.forward(tile_t)
+        # Replace tile.
+        image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.data
+
+    return image_out
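+
+# Illustration only, mirroring examples/vgg7.py's execute_full command; this
+# function is not part of the module's API and the paths are hypothetical.
+# With tile_size = 156, each input tile yields a 156 - 14 = 142 pixel output
+# tile, which is exactly the stride the loops above advance by.
+def _demo_execute_full(model_json_path: str, in_path: str, out_path: str):
+  import json
+  vgg7 = Vgg7()
+  with open(model_json_path, "rb") as f:
+    vgg7.load_waifu2x_json(json.load(f))
+  image_save(out_path, vgg7.forward_tiled(image_load(in_path), 156))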