mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add lockstep tracer based on TorchMLIR eager mode + examples (#243)
This commit is contained in:
73
shark/examples/shark_eager/squeezenet_lockstep.py
Normal file
73
shark/examples/shark_eager/squeezenet_lockstep.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
model = torch.hub.load(
|
||||
"pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# from PIL import Image
|
||||
# from torchvision import transforms
|
||||
# import urllib
|
||||
#
|
||||
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
|
||||
# try: urllib.URLopener().retrieve(url, filename)
|
||||
# except: urllib.request.urlretrieve(url, filename)
|
||||
#
|
||||
#
|
||||
# input_image = Image.open(filename)
|
||||
# preprocess = transforms.Compose([
|
||||
# transforms.Resize(256),
|
||||
# transforms.CenterCrop(224),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
# ])
|
||||
# input_tensor = preprocess(input_image)
|
||||
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
||||
# print(input_batch.shape) # size = [1, 3, 224, 224]
|
||||
|
||||
# The above is code for generating sample inputs from an image. We can just use
|
||||
# random values for accuracy testing though
|
||||
input_batch = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
# Focus on CPU for now
|
||||
if False and torch.cuda.is_available():
|
||||
input_batch = input_batch.to("cuda")
|
||||
model.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
|
||||
golden_confidences = output[0]
|
||||
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
|
||||
golden_probabilities = torch.nn.functional.softmax(
|
||||
golden_confidences, dim=0
|
||||
).numpy()
|
||||
|
||||
golden_confidences = golden_confidences.numpy()
|
||||
|
||||
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
|
||||
input_detached_clone = input_batch.clone()
|
||||
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)
|
||||
|
||||
print("getting torch-mlir result")
|
||||
|
||||
output = model(eager_input_batch)
|
||||
|
||||
static_output = output.elem
|
||||
confidences = static_output[0]
|
||||
probabilities = torch.nn.functional.softmax(
|
||||
torch.from_numpy(confidences), dim=0
|
||||
).numpy()
|
||||
|
||||
print("The obtained result via shark is: ", confidences)
|
||||
print("The golden result is:", golden_confidences)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
golden_confidences, confidences, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
golden_probabilities, probabilities, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
206
shark/torch_mlir_lockstep_tensor.py
Normal file
206
shark/torch_mlir_lockstep_tensor.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
import contextlib
|
||||
import re
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from torch_mlir.eager_mode.ir_building import build_mlir_module
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import (
|
||||
UnsupportedByTorchMlirEagerMode,
|
||||
normalize_args_kwargs,
|
||||
check_get_aliased_arg,
|
||||
)
|
||||
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import (
|
||||
TorchMLIRTensor,
|
||||
check_requires_grad,
|
||||
make_wrapper_subclass_from_torch_tensor,
|
||||
make_bare_wrapper_subclass,
|
||||
UNSUPPORTED_OPS,
|
||||
no_dispatch,
|
||||
)
|
||||
from torch_mlir.eager_mode import torch_mlir_tensor
|
||||
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
|
||||
|
||||
backend = EagerModeIREELinalgOnTensorsBackend("cpu")
|
||||
torch_mlir_tensor.backend = backend
|
||||
rtol = 1e-04
|
||||
atol = 1e-05
|
||||
|
||||
|
||||
class TorchMLIRLockstepTensor(TorchMLIRTensor):
|
||||
"""This class overrides the dispatching for TorchMLIRTensor to allow for an op-by-op numerical comparison between PyTorch and the Torch-MLIR -> IREE backend compilation pipeline. This only supports the IREE backend and focuses on op-by-op level verification.
|
||||
|
||||
TODO: Extend this to do a cumulative trace with summary statistics at the end. Possibly requires a wrapper environment to store full trace info.
|
||||
"""
|
||||
|
||||
def __new__(cls, elem, **kwargs):
|
||||
if kwargs.get("constructing_from_device_tensor", False):
|
||||
tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
|
||||
r = make_bare_wrapper_subclass(
|
||||
cls=cls,
|
||||
size=tensor_meta_data.size,
|
||||
strides=tensor_meta_data.strides,
|
||||
storage_offset=tensor_meta_data.storage_offset,
|
||||
dtype=tensor_meta_data.dtype,
|
||||
layout=tensor_meta_data.layout,
|
||||
device=tensor_meta_data.device,
|
||||
requires_grad=tensor_meta_data.requires_grad,
|
||||
)
|
||||
r.elem = elem
|
||||
elif isinstance(elem, torch.nn.Parameter):
|
||||
r = make_wrapper_subclass_from_torch_tensor(
|
||||
cls, elem.data, **kwargs
|
||||
)
|
||||
# This is a hack to handle non-contiguous data through IREE-backend
|
||||
nt = elem.detach().data.numpy()
|
||||
if not nt.flags["C_CONTIGUOUS"]:
|
||||
nt = np.ascontiguousarray(nt, dtype=nt.dtype)
|
||||
r.elem = backend.transfer_from_torch_to_device(torch.Tensor(nt))
|
||||
elif isinstance(elem, torch.Tensor):
|
||||
r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
|
||||
# Ditto TODO: Find a better way to handle this
|
||||
nt = elem.numpy()
|
||||
if not nt.flags["C_CONTIGUOUS"]:
|
||||
nt = np.ascontiguousarray(nt, dtype=nt.dtype)
|
||||
r.elem = backend.transfer_from_torch_to_device(torch.Tensor(nt))
|
||||
# This branch handles the case when a python scalar is passed to some op
|
||||
# or is returned from some aten op, such as _local_scalar_dense.
|
||||
elif isinstance(elem, (int, float, bool)):
|
||||
return elem
|
||||
else:
|
||||
raise ValueError(f"Unknown element type: {type(elem)}")
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
|
||||
else:
|
||||
return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__})"
|
||||
|
||||
"""This does essentially the same dispatch as TorchMLIRTensor but operates as if debug mode is enabled. The numeric verification happens after the Torch-MLIR result is obtained by comparing against the
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
|
||||
requires_grad = check_requires_grad(*args, **kwargs)
|
||||
try:
|
||||
with no_dispatch():
|
||||
if hasattr(func, "op_name"):
|
||||
op_name = func.op_name
|
||||
elif hasattr(func, "__name__"):
|
||||
# Handle builtin_function_or_method.
|
||||
op_name = func.__name__
|
||||
else:
|
||||
raise RuntimeError(f"op {func} has no name")
|
||||
|
||||
if UNSUPPORTED_OPS.match(op_name):
|
||||
raise UnsupportedByTorchMlirEagerMode(op_name)
|
||||
|
||||
if not hasattr(func, "_schema"):
|
||||
raise RuntimeError(f"op {func} has no schema.")
|
||||
|
||||
normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
|
||||
|
||||
if "layout" in normalized_kwargs and normalized_kwargs[
|
||||
"layout"
|
||||
] not in {0, None}:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
f"{normalized_kwargs['layout']} layout not supported."
|
||||
)
|
||||
if "memory_format" in normalized_kwargs and normalized_kwargs[
|
||||
"memory_format"
|
||||
] not in {0, None}:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
f"{normalized_kwargs['memory_format']} memory format not supported."
|
||||
)
|
||||
eager_module = build_mlir_module(func, normalized_kwargs)
|
||||
device_tensor_args = [
|
||||
kwarg.elem
|
||||
for _, kwarg in normalized_kwargs.items()
|
||||
if isinstance(kwarg, cls)
|
||||
]
|
||||
assert len(eager_module.body.operations[0].arguments) == len(
|
||||
device_tensor_args
|
||||
), "Number of parameters and number of arguments differs."
|
||||
op_mlir_backend_callable = backend.compile(eager_module)
|
||||
out = op_mlir_backend_callable(*device_tensor_args)
|
||||
out = tree_map(
|
||||
lambda x: cls(
|
||||
x,
|
||||
requires_grad=requires_grad,
|
||||
constructing_from_device_tensor=True,
|
||||
),
|
||||
out,
|
||||
)
|
||||
|
||||
# Numeric verification; Value for comparison comes from PyTorch eager
|
||||
with no_dispatch():
|
||||
unwrapped_args = tree_map(cls.unwrap, args)
|
||||
unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
|
||||
native_out = func(*unwrapped_args, **unwrapped_kwargs)
|
||||
|
||||
native_out = tree_map(
|
||||
lambda x: cls(x, requires_grad=requires_grad), native_out
|
||||
).elem
|
||||
tmp_out = out.elem
|
||||
|
||||
try:
|
||||
np.testing.assert_allclose(
|
||||
native_out.to_host(),
|
||||
tmp_out.to_host(),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
except Exception as e:
|
||||
shaped_args = [
|
||||
arg.shape if torch.is_tensor(arg) else arg
|
||||
for arg in unwrapped_args
|
||||
]
|
||||
shaped_kwargs = [
|
||||
kwarg.shape if torch.is_tensor(kwarg) else kwarg
|
||||
for kwarg in unwrapped_kwargs
|
||||
]
|
||||
warnings.warn(
|
||||
f"Lockstep accuracy verification failed with error: *{str(e)}*; "
|
||||
f"Dispatched function name: *{str(func)}*; "
|
||||
f"Dispatched function args: *{str(shaped_args)}*; "
|
||||
f"Dispatched function kwargs: *{str(shaped_kwargs)}*; "
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(traceback.format_exc())
|
||||
if isinstance(e, UnsupportedByTorchMlirEagerMode):
|
||||
warnings.warn(
|
||||
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
|
||||
f"Running through PyTorch eager"
|
||||
)
|
||||
|
||||
with no_dispatch():
|
||||
unwrapped_args = tree_map(cls.unwrap, args)
|
||||
unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
|
||||
out = func(*unwrapped_args, **unwrapped_kwargs)
|
||||
|
||||
out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
|
||||
|
||||
maybe_aliased_arg_name = check_get_aliased_arg(func)
|
||||
if maybe_aliased_arg_name is not None:
|
||||
warnings.warn(
|
||||
f"Found aliased arg, but didn't copy tensor contents. This could lead to incorrect results for E2E model execution but doesn't affect the validity of the lockstep op verification."
|
||||
)
|
||||
# TODO: Find a way to handle argument aliasing for IREE backend
|
||||
# backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)
|
||||
|
||||
return out
|
||||
285
tank/pytorch/v_diffusion_pytorch/cfg_sample_eager.py
Executable file
285
tank/pytorch/v_diffusion_pytorch/cfg_sample_eager.py
Executable file
@@ -0,0 +1,285 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Classifier-free guidance sampling from a diffusion model."""
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional as TF
|
||||
from tqdm import trange
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("v-diffusion-pytorch")
|
||||
from CLIP import clip
|
||||
from diffusion import get_model, get_models, sampling, utils
|
||||
|
||||
MODULE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def parse_prompt(prompt, default_weight=3.0):
|
||||
if prompt.startswith("http://") or prompt.startswith("https://"):
|
||||
vals = prompt.rsplit(":", 2)
|
||||
vals = [vals[0] + ":" + vals[1], *vals[2:]]
|
||||
else:
|
||||
vals = prompt.rsplit(":", 1)
|
||||
vals = vals + ["", default_weight][len(vals) :]
|
||||
return vals[0], float(vals[1])
|
||||
|
||||
|
||||
def resize_and_center_crop(image, size):
|
||||
fac = max(size[0] / image.size[0], size[1] / image.size[1])
|
||||
image = image.resize(
|
||||
(int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS
|
||||
)
|
||||
return TF.center_crop(image, size[::-1])
|
||||
|
||||
|
||||
# def main():
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
p.add_argument(
|
||||
"prompts", type=str, default=[], nargs="*", help="the text prompts to use"
|
||||
)
|
||||
p.add_argument(
|
||||
"--images",
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="*",
|
||||
metavar="IMAGE",
|
||||
help="the image prompts",
|
||||
)
|
||||
p.add_argument(
|
||||
"--batch-size",
|
||||
"-bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="the number of images per batch",
|
||||
)
|
||||
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
|
||||
p.add_argument("--device", type=str, help="the device to use")
|
||||
p.add_argument(
|
||||
"--eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="the amount of noise to add during sampling (0-1)",
|
||||
)
|
||||
p.add_argument("--init", type=str, help="the init image")
|
||||
p.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="plms",
|
||||
choices=["ddpm", "ddim", "prk", "plms", "pie", "plms2", "iplms"],
|
||||
help="the sampling method to use",
|
||||
)
|
||||
p.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="cc12m_1_cfg",
|
||||
choices=["cc12m_1_cfg"],
|
||||
help="the model to use",
|
||||
)
|
||||
p.add_argument(
|
||||
"-n", type=int, default=1, help="the number of images to sample"
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=0, help="the random seed")
|
||||
p.add_argument("--size", type=int, nargs=2, help="the output image size")
|
||||
p.add_argument(
|
||||
"--starting-timestep",
|
||||
"-st",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="the timestep to start at (used with init images)",
|
||||
)
|
||||
p.add_argument("--steps", type=int, default=50, help="the number of timesteps")
|
||||
args = p.parse_args()
|
||||
|
||||
if args.device:
|
||||
device = torch.device(args.device)
|
||||
else:
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print("Using device:", device)
|
||||
|
||||
model = get_model(args.model)()
|
||||
_, side_y, side_x = model.shape
|
||||
if args.size:
|
||||
side_x, side_y = args.size
|
||||
checkpoint = args.checkpoint
|
||||
if not checkpoint:
|
||||
checkpoint = MODULE_DIR / f"checkpoints/{args.model}.pth"
|
||||
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
|
||||
if device.type == "cuda":
|
||||
model = model.half()
|
||||
model = model.to(device).eval().requires_grad_(False)
|
||||
clip_model_name = (
|
||||
model.clip_model if hasattr(model, "clip_model") else "ViT-B/16"
|
||||
)
|
||||
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
|
||||
clip_model.eval().requires_grad_(False)
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
|
||||
if args.init:
|
||||
init = Image.open(utils.fetch(args.init)).convert("RGB")
|
||||
init = resize_and_center_crop(init, (side_x, side_y))
|
||||
init = (
|
||||
utils.from_pil_image(init).to(device)[None].repeat([args.n, 1, 1, 1])
|
||||
)
|
||||
|
||||
zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
|
||||
target_embeds, weights = [zero_embed], []
|
||||
|
||||
for prompt in args.prompts:
|
||||
txt, weight = parse_prompt(prompt)
|
||||
target_embeds.append(
|
||||
clip_model.encode_text(clip.tokenize(txt).to(device)).float()
|
||||
)
|
||||
weights.append(weight)
|
||||
|
||||
for prompt in args.images:
|
||||
path, weight = parse_prompt(prompt)
|
||||
img = Image.open(utils.fetch(path)).convert("RGB")
|
||||
clip_size = clip_model.visual.input_resolution
|
||||
img = resize_and_center_crop(img, (clip_size, clip_size))
|
||||
batch = TF.to_tensor(img)[None].to(device)
|
||||
embed = F.normalize(
|
||||
clip_model.encode_image(normalize(batch)).float(), dim=-1
|
||||
)
|
||||
target_embeds.append(embed)
|
||||
weights.append(weight)
|
||||
|
||||
weights = torch.tensor([1 - sum(weights), *weights], device=device)
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
|
||||
def cfg_model_fn(x, t):
|
||||
n = x.shape[0]
|
||||
n_conds = len(target_embeds)
|
||||
x_in = x.repeat([n_conds, 1, 1, 1])
|
||||
t_in = t.repeat([n_conds])
|
||||
clip_embed_in = torch.cat([*target_embeds]).repeat([n, 1])
|
||||
vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
|
||||
v = vs.mul(weights[:, None, None, None, None]).sum(0)
|
||||
return v
|
||||
|
||||
|
||||
x = torch.randn([args.n, 3, side_y, side_x], device=device)
|
||||
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
|
||||
steps = utils.get_spliced_ddpm_cosine_schedule(t)
|
||||
min_batch_size = min(args.n, args.batch_size)
|
||||
x_in = x[0:min_batch_size, :, :, :]
|
||||
ts = x_in.new_ones([x_in.shape[0]])
|
||||
t_in = t[0] * ts
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
import torch_mlir
|
||||
|
||||
fx_g = make_fx(
|
||||
cfg_model_fn,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(x_in, t_in)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# ts_g,
|
||||
# [x_in, t_in],
|
||||
# torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
# use_tracing=False,
|
||||
# )
|
||||
#
|
||||
# mlir_model = module
|
||||
# func_name = "forward"
|
||||
#
|
||||
# shark_module = SharkInference(
|
||||
# mlir_model, func_name, device="gpu", mlir_dialect="linalg"
|
||||
# )
|
||||
# shark_module.compile()
|
||||
|
||||
|
||||
def compiled_cfg_model_fn(x, t):
|
||||
x_in_eager = TorchMLIRLockstepTensor(x.clone())
|
||||
t_in_eager = TorchMLIRLockstepTensor(t.clone())
|
||||
return ts_g(x_in_eager, t_in_eager)
|
||||
|
||||
|
||||
def run(x, steps):
|
||||
if args.method == "ddpm":
|
||||
return sampling.sample(compiled_cfg_model_fn, x, steps, 1.0, {})
|
||||
if args.method == "ddim":
|
||||
return sampling.sample(compiled_cfg_model_fn, x, steps, args.eta, {})
|
||||
if args.method == "prk":
|
||||
return sampling.prk_sample(compiled_cfg_model_fn, x, steps, {})
|
||||
if args.method == "plms":
|
||||
return sampling.plms_sample(compiled_cfg_model_fn, x, steps, {})
|
||||
if args.method == "pie":
|
||||
return sampling.pie_sample(compiled_cfg_model_fn, x, steps, {})
|
||||
if args.method == "plms2":
|
||||
return sampling.plms2_sample(compiled_cfg_model_fn, x, steps, {})
|
||||
if args.method == "iplms":
|
||||
return sampling.iplms_sample(compiled_cfg_model_fn, x, steps, {})
|
||||
assert False
|
||||
|
||||
|
||||
def run_all(x, t, steps, n, batch_size):
|
||||
x = torch.randn([n, 3, side_y, side_x], device=device)
|
||||
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
|
||||
steps = utils.get_spliced_ddpm_cosine_schedule(t)
|
||||
if args.init:
|
||||
steps = steps[steps < args.starting_timestep]
|
||||
alpha, sigma = utils.t_to_alpha_sigma(steps[0])
|
||||
x = init * alpha + x * sigma
|
||||
for i in trange(0, n, batch_size):
|
||||
cur_batch_size = min(n - i, batch_size)
|
||||
outs = run(x[i : i + cur_batch_size], steps)
|
||||
for j, out in enumerate(outs):
|
||||
utils.to_pil_image(out).save(f"out_{i + j:05}.png")
|
||||
|
||||
|
||||
steps = 1
|
||||
|
||||
run_all(x, t, steps, args.n, args.batch_size)
|
||||
Reference in New Issue
Block a user