diff --git a/examples/vgg7.py b/examples/vgg7.py index 0cac968c4a..a4a5835e53 100644 --- a/examples/vgg7.py +++ b/examples/vgg7.py @@ -26,23 +26,23 @@ def set_sample_count(samples_dir, sc): file.write(str(sc) + "\n") if len(sys.argv) < 2: - print("python3 -m examples.vgg7 import MODELJSON MODELDIR") + print("python3 -m examples.vgg7 import MODELJSON MODEL") print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json") print(" into a safetensors file") print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)") print(" *this format is used by most other commands in this program*") - print("python3 -m examples.vgg7 import_kinne MODELDIR MODEL_SAFETENSORS") + print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS") print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors") - print("python3 -m examples.vgg7 execute MODELDIR IMG_IN IMG_OUT") + print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT") print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it") print(" output image has 7 pixels removed on all edges") print(" do not run on large images, will have *hilarious* RAM use") - print("python3 -m examples.vgg7 execute_full MODELDIR IMG_IN IMG_OUT") + print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT") print(" does the 'whole thing' (padding, tiling)") print(" safe for large images, etc.") - print("python3 -m examples.vgg7 new MODELDIR") + print("python3 -m examples.vgg7 new MODEL") print(" creates a new model (experimental)") - print("python3 -m examples.vgg7 train MODELDIR SAMPLES_DIR ROUNDS ROUNDS_SAVE") + print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE") print(" trains a model (experimental)") print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)") print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.") @@ -130,7 +130,7 @@ elif cmd == "train": # This is used to try and get the network to focus on "interesting" samples, # which works nicely with the microsample system. sample_probs = None - sample_probs_path = model + "/sample_probs.bin" + sample_probs_path = model + "_sample_probs.bin" try: # try to read... sample_probs = numpy.fromfile(sample_probs_path, " numpy.ndarray: """ # file na = numpy.array(Image.open(path)) + if na.shape[2] == 4: + # RGBA -> RGB (covers opaque images with alpha channels) + na = na[:,:,0:3] # fix shape na = numpy.moveaxis(na, [2,0,1], [0,1,2]) # shape is now (3,h,w), add 1 @@ -113,6 +118,19 @@ class Vgg7: def get_parameters(self) -> list: return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters() + def load_from_pretrained(self, intent = "art", subtype = "scale2.0x"): + """ + Downloads a nagadomi/waifu2x JSON weight file and loads it. + """ + fn = Path(__file__).parents[2] / ("weights/vgg_7_" + intent + "_" + subtype + "_model.json") + download_file("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json", fn) + + import json + with open(fn, "rb") as f: + data = json.load(f) + + self.load_waifu2x_json(data) + def load_waifu2x_json(self, data: list): """ Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json @@ -126,7 +144,6 @@ class Vgg7: self.conv6.load_waifu2x_json(data[5]) self.conv7.load_waifu2x_json(data[6]) - def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray: """ Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it. diff --git a/test/models/test_waifu2x.py b/test/models/test_waifu2x.py new file mode 100644 index 0000000000..0b34ae0356 --- /dev/null +++ b/test/models/test_waifu2x.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +import pathlib +import unittest +import numpy as np +from tinygrad.tensor import Tensor +from tinygrad.ops import Device + +class TestVGG7(unittest.TestCase): + def test_vgg7(self): + from examples.vgg7_helpers.waifu2x import Vgg7, image_load + + # Create in tinygrad + Tensor.manual_seed(1337) + mdl = Vgg7() + mdl.load_from_pretrained() + + # Scale up an image + test_x = image_load(pathlib.Path(__file__).parent / 'waifu2x/input.png') + test_y = image_load(pathlib.Path(__file__).parent / 'waifu2x/output.png') + scaled = mdl.forward_tiled(test_x, 156) + scaled = np.fmax(0, np.fmin(1, scaled)) + np.testing.assert_allclose(scaled, test_y, atol=5e-3, rtol=5e-3) + +if __name__ == '__main__': + unittest.main() diff --git a/test/models/waifu2x/input.png b/test/models/waifu2x/input.png new file mode 100644 index 0000000000..9ae415a953 Binary files /dev/null and b/test/models/waifu2x/input.png differ diff --git a/test/models/waifu2x/output.png b/test/models/waifu2x/output.png new file mode 100644 index 0000000000..b105a2e2cf Binary files /dev/null and b/test/models/waifu2x/output.png differ