diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 4da3ca94d6..06b2c22a35 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -197,6 +197,7 @@ for k,v in dat['state_dict'].items(): print(f"{str(v.shape):30s}", w, k) if w is not None: assert w.shape == v.shape + w.assign(v.astype(np.float32)) IMG = "/Users/kafka/fun/mps/stable-diffusion/outputs/txt2img-samples/grid-0006.png" from PIL import Image @@ -230,6 +231,47 @@ tmodel = AutoencoderKL( lossconfig={"target": "torch.nn.Identity"}, embed_dim=4) tmodel.load_state_dict(sd, strict=True) +nz = np.load("datasets/stable_diffusion_apple.npy") +zmodel = model.first_stage_model + +x_torch = torch.tensor(nz) +x_tiny = Tensor(nz) + +x_torch = tmodel.post_quant_conv(x_torch) +x_tiny = zmodel.post_quant_conv(x_tiny) + +x_torch = tmodel.decoder.conv_in(x_torch) +x_tiny = zmodel.decoder.conv_in(x_tiny) + +#x_torch = tmodel.decoder.mid.block_1(x_torch, None) +#x_tiny = zmodel.decoder.mid['block_1'](x_tiny) + +x_torch = tmodel.decoder.mid.block_1.norm1(x_torch) +x_tiny = zmodel.decoder.mid['block_1'].norm1(x_tiny) + +x_torch = x_torch * torch.sigmoid(x_torch) +x_tiny = x_tiny.swish() + +print(zmodel.decoder.mid['block_1'].conv1.weight.shape) +print(x_tiny.shape) + +x_torch = tmodel.decoder.mid.block_1.conv1(x_torch) +x_tiny = zmodel.decoder.mid['block_1'].conv1(x_tiny) + +#print(tmodel.decoder.mid.block_1.conv1.weight) +#print(zmodel.decoder.mid['block_1'].conv1.weight.numpy()) + +print(abs(x_torch.detach().numpy() - x_tiny.numpy()).mean()) +print(x_torch.shape, x_tiny.shape) + +exit(0) + + +#exit(0) + +x = model.first_stage_model.post_quant_conv(x_tiny) +x = model.first_stage_model.decoder(x) +x = x.reshape((3,512,512)).permute((1,2,0)) """ posterior = tmodel.encode(torch.tensor(realimg.numpy())) @@ -240,12 +282,10 @@ nz = z.detach().numpy() np.save("/tmp/apple.npy", nz) exit(0) """ -nz = np.load("datasets/stable_diffusion_apple.npy") -nz *= -1 #x, latent = 
tmodel(torch.tensor(realimg.numpy())) -x = tmodel.decode(torch.tensor(nz)) -x = x.reshape(3,512,512).permute(1,2,0) +#x = tmodel.decode(torch.tensor(nz)) +#x = x.reshape(3,512,512).permute(1,2,0) """ x = Tensor.randn(1,4,64,64) diff --git a/test/test_ops.py b/test/test_ops.py index 64d203c345..10262c10fb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -184,6 +184,13 @@ class TestOps(unittest.TestCase): arg = (4,3,2,6) helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg)) + @unittest.skip("internal conv shape overflows an int") + def test_sd_big_conv(self): + # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows an int + helper_test_op([(1,512,64,64), (512,512,3,3)], + lambda x,w: torch.nn.functional.conv2d(x, w), + lambda x,w: x.conv2d(w), atol=1e-4) + def test_biased_conv2d(self): C = 8 helper_test_op([(1,C,5,5), (C,C,1,1), (C,)], diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 23df087dff..a5bee53cbf 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -116,6 +116,7 @@ class GPUBuffer: def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer: assert C is None + for _, b in bufs: assert prod(b.shape) < 2**32, f"GPU buffers must be under 2**32, {b.shape} isn't" # get the input/output shape and the reduce amount reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape