Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
Synced 2026-04-03 03:00:17 -04:00
Tweaks to VAE, pipeline states
@@ -33,21 +33,22 @@ from apps.shark_studio.modules.ckpt_processing import (
     process_custom_pipe_weights,
 )
 from transformers import CLIPTokenizer
+from diffusers.image_processor import VaeImageProcessor
 from math import ceil
 from PIL import Image
 
 sd_model_map = {
     "clip": {
         "initializer": clip.export_clip_model,
-        "external_weight_file": None,
         "ireec_flags": [
             "--iree-flow-collapse-reduction-dims",
             "--iree-opt-const-expr-hoisting=False",
             "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
         ],
     },
     "vae_encode": {
         "initializer": vae.export_vae_model,
-        "external_weight_file": None,
         "ireec_flags": [
             "--iree-flow-collapse-reduction-dims",
             "--iree-opt-const-expr-hoisting=False",
@@ -64,11 +65,9 @@ sd_model_map = {
             "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
         ],
-        "external_weight_file": None,
     },
     "vae_decode": {
         "initializer": vae.export_vae_model,
-        "external_weight_file": None,
         "ireec_flags": [
             "--iree-flow-collapse-reduction-dims",
             "--iree-opt-const-expr-hoisting=False",
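Note: each entry's `ireec_flags` is a list of raw arguments for the IREE compiler, tuned per submodel (padding sizes, img2col conversion, const-expr hoisting). A minimal sketch of feeding such a map to the `iree-compiler` Python API follows; the target backend and MLIR path are illustrative, not values from this commit:

```python
from iree.compiler import tools as ireec

def compile_submodel(mlir_path: str, entry: dict, backend: str = "llvm-cpu") -> bytes:
    # extra_args takes the same strings stored in sd_model_map[...]["ireec_flags"].
    return ireec.compile_file(
        mlir_path,
        target_backends=[backend],
        extra_args=entry["ireec_flags"],
    )

vmfb = compile_submodel("clip.mlir", sd_model_map["clip"])
```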
@@ -109,7 +108,6 @@ class StableDiffusion(SharkPipelineBase):
         self.scheduler_obj = {}
         static_kwargs = {
             "pipe": {
-                "external_weight_path": get_checkpoints_path(),
                 "external_weights": "safetensors",
             },
             "clip": {"hf_model_name": base_model_id},
@@ -130,7 +128,6 @@ class StableDiffusion(SharkPipelineBase):
                 "hf_model_name": base_model_id,
                 "vae_model": vae.VaeModel(
                     hf_model_name=base_model_id,
-                    base_vae=False,
                     custom_vae=custom_vae,
                 ),
                 "batch_size": batch_size,
@@ -142,7 +139,6 @@ class StableDiffusion(SharkPipelineBase):
                 "hf_model_name": base_model_id,
                 "vae_model": vae.VaeModel(
                     hf_model_name=base_model_id,
-                    base_vae=False,
                     custom_vae=custom_vae,
                 ),
                 "batch_size": batch_size,
@@ -177,115 +173,28 @@ class StableDiffusion(SharkPipelineBase):
         )
         self.is_img2img = is_img2img
         schedulers = get_schedulers(self.base_model_id)
-        self.weights_path = get_checkpoints_path(
-            os.path.join("..", self.safe_name(self.pipe_id))
-        )
-        self.scheduler = schedulers[scheduler]
+        self.image_processor = VaeImageProcessor()  # do_convert_rgb=True
+        self.weights_path = os.path.join(get_checkpoints_path(), self.safe_name(self.base_model_id))
+        if not os.path.exists(self.weights_path):
+            os.mkdir(self.weights_path)
+        self.scheduler = schedulers[scheduler]
         print(f"[LOG] Loaded scheduler: {scheduler}")
         for model in adapters:
             self.model_map[model] = adapters[model]
-        if custom_weights:
+        if os.path.isfile(custom_weights):
             for i in self.model_map:
                 self.model_map[i]["external_weights_file"] = None
         elif custom_weights:
             print(
                 f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?"
             )
+        self.static_kwargs["pipe"] = {
+            # "external_weight_path": self.weights_path,
+            # "external_weights": "safetensors",
+        }
+
+        for submodel in self.static_kwargs:
+            if custom_weights:
+                if submodel not in ["clip", "clip2"]:
+                    self.static_kwargs[submodel]["external_weight_file"] = custom_weights
+                else:
+                    self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
+            else:
+                self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
         self.get_compiled_map(pipe_id=self.pipe_id)
         print("\n[LOG] Pipeline successfully prepared for runtime.")
         return
 
-    def generate_images(
-        self,
-        prompt,
-        negative_prompt,
-        image,
-        steps,
-        strength,
-        guidance_scale,
-        seed,
-        ondemand,
-        repeatable_seeds,
-        resample_type,
-        control_mode,
-        hints,
-    ):
-        # TODO: Batched args
-        self.ondemand = ondemand
-        if self.is_img2img:
-            image, _ = self.process_sd_init_image(image, resample_type)
-        else:
-            image = None
-
-        print("\n[LOG] Generating images...")
-        batched_args = [
-            prompt,
-            negative_prompt,
-            # steps,
-            # strength,
-            # guidance_scale,
-            # seed,
-            # resample_type,
-            # control_mode,
-            # hints,
-        ]
-        for arg in batched_args:
-            if not isinstance(arg, list):
-                arg = [arg] * self.batch_size
-            if len(arg) < self.batch_size:
-                arg = arg * self.batch_size
-            else:
-                arg = [arg[i] for i in range(self.batch_size)]
-
-        text_embeddings = self.encode_prompts_weight(
-            prompt,
-            negative_prompt,
-        )
-
-        uint32_info = np.iinfo(np.uint32)
-        uint32_min, uint32_max = uint32_info.min, uint32_info.max
-        if seed < uint32_min or seed >= uint32_max:
-            seed = randint(uint32_min, uint32_max)
-
-        generator = torch.manual_seed(seed)
-
-        init_latents, final_timesteps = self.prepare_latents(
-            generator=generator,
-            num_inference_steps=steps,
-            image=image,
-            strength=strength,
-        )
-
-        latents = self.produce_img_latents(
-            latents=init_latents,
-            text_embeddings=text_embeddings,
-            guidance_scale=guidance_scale,
-            total_timesteps=final_timesteps,
-            cpu_scheduling=True,  # until we have schedulers through Turbine
-        )
-
-        # Img latents -> PIL images
-        all_imgs = []
-        self.load_submodels(["vae_decode"])
-        for i in tqdm(range(0, latents.shape[0], self.batch_size)):
-            imgs = self.decode_latents(
-                latents=latents[i : i + self.batch_size],
-                use_base_vae=False,
-                cpu_scheduling=True,
-            )
-            all_imgs.extend(imgs)
-            if self.ondemand:
-                self.unload_submodels(["vae_decode"])
-
-        return all_imgs
-
     def encode_prompts_weight(
         self,
         prompt,
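The submodel loop above encodes the weight-resolution policy: a user checkpoint becomes `external_weight_file` for every submodel except the CLIP text encoders (`clip`, `clip2`), which, like the default case, get a per-submodel `.safetensors` path under the pipeline's weights directory. The same rule in isolation (a sketch, not pipeline code):

```python
import os

def resolve_weight_kwargs(submodels, custom_weights, weights_path):
    # Custom checkpoints apply to everything but the CLIP text encoders;
    # all other cases fall back to <weights_path>/<submodel>.safetensors.
    out = {}
    for submodel in submodels:
        if custom_weights and submodel not in ["clip", "clip2"]:
            out[submodel] = {"external_weight_file": custom_weights}
        else:
            out[submodel] = {
                "external_weight_path": os.path.join(weights_path, submodel + ".safetensors")
            }
    return out

# resolve_weight_kwargs(["clip", "unet", "vae_decode"], "", "./weights")
```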
@@ -446,10 +355,7 @@ class StableDiffusion(SharkPipelineBase):
         all_latents = torch.cat(latent_history, dim=0)
         return all_latents
 
-    def decode_latents(self, latents, use_base_vae=False, cpu_scheduling=True):
-        if use_base_vae:
-            latents = 1 / 0.18215 * latents
-
+    def decode_latents(self, latents, cpu_scheduling=True):
         latents_numpy = latents.to(self.dtype)
         if cpu_scheduling:
             latents_numpy = latents.detach().numpy()
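The deleted `use_base_vae` branch applied the SD-1.x latent scaling (`1 / 0.18215`) on the host before decoding; after this change any scaling is presumably baked into the exported VAE module instead. For reference, the conventional host-side version looks like this (a generic diffusers-style sketch, not this pipeline's code):

```python
import torch

SD15_SCALING_FACTOR = 0.18215  # vae.config.scaling_factor for SD 1.x

def decode_with_host_scaling(vae, latents: torch.Tensor) -> torch.Tensor:
    # Undo the scaling applied at encode time, then run the decoder.
    latents = latents / SD15_SCALING_FACTOR
    return vae.decode(latents).sample
```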
@@ -460,13 +366,9 @@ class StableDiffusion(SharkPipelineBase):
         vae_inf_time = (time.time() - vae_start) * 1000
         # end_profiling(profile_device)
         print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
-        if use_base_vae:
-            images = torch.from_numpy(images)
-            images = (images.detach().cpu() * 255.0).numpy()
-            images = images.round()
-
-        images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1)
-        pil_images = [Image.fromarray(image).convert("RGB") for image in images.numpy()]
+        images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy()
+        pil_images = self.image_processor.numpy_to_pil(images)
         return pil_images
 
     def process_sd_init_image(self, sd_init_image, resample_type):
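`VaeImageProcessor.numpy_to_pil` expects channel-last float arrays in [0, 1] and performs the `*255`/round/uint8 conversion itself, which is why the new code permutes to NHWC and stays in float rather than hand-rolling the uint8 path. A self-contained sketch with dummy data:

```python
import numpy as np
from diffusers.image_processor import VaeImageProcessor

# Pretend NCHW float output from VAE decode, assumed already in [0, 1].
decoded = np.random.rand(1, 3, 512, 512).astype(np.float32)

images = decoded.transpose(0, 2, 3, 1)  # NCHW -> NHWC, still float
pil_images = VaeImageProcessor.numpy_to_pil(images)
print(pil_images[0].size)  # (512, 512)
```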
@@ -508,6 +410,83 @@ class StableDiffusion(SharkPipelineBase):
             is_img2img = True
         image = image_arr
         return image, is_img2img
 
+    def generate_images(
+        self,
+        prompt,
+        negative_prompt,
+        image,
+        steps,
+        strength,
+        guidance_scale,
+        seed,
+        ondemand,
+        repeatable_seeds,
+        resample_type,
+        control_mode,
+        hints,
+    ):
+        # TODO: Batched args
+        self.ondemand = ondemand
+        if self.is_img2img:
+            image, _ = self.process_sd_init_image(image, resample_type)
+        else:
+            image = None
+
+        print("\n[LOG] Generating images...")
+        batched_args = [
+            prompt,
+            negative_prompt,
+            image,
+        ]
+        for arg in batched_args:
+            if not isinstance(arg, list):
+                arg = [arg] * self.batch_size
+            if len(arg) < self.batch_size:
+                arg = arg * self.batch_size
+            else:
+                arg = [arg[i] for i in range(self.batch_size)]
+
+        text_embeddings = self.encode_prompts_weight(
+            prompt,
+            negative_prompt,
+        )
+
+        uint32_info = np.iinfo(np.uint32)
+        uint32_min, uint32_max = uint32_info.min, uint32_info.max
+        if seed < uint32_min or seed >= uint32_max:
+            seed = randint(uint32_min, uint32_max)
+
+        generator = torch.manual_seed(seed)
+
+        init_latents, final_timesteps = self.prepare_latents(
+            generator=generator,
+            num_inference_steps=steps,
+            image=image,
+            strength=strength,
+        )
+
+        latents = self.produce_img_latents(
+            latents=init_latents,
+            text_embeddings=text_embeddings,
+            guidance_scale=guidance_scale,
+            total_timesteps=final_timesteps,
+            cpu_scheduling=True,  # until we have schedulers through Turbine
+        )
+
+        # Img latents -> PIL images
+        all_imgs = []
+        self.load_submodels(["vae_decode"])
+        for i in tqdm(range(0, latents.shape[0], self.batch_size)):
+            imgs = self.decode_latents(
+                latents=latents[i : i + self.batch_size],
+                cpu_scheduling=True,
+            )
+            all_imgs.extend(imgs)
+            if self.ondemand:
+                self.unload_submodels(["vae_decode"])
+
+        return all_imgs
 
 
 def shark_sd_fn_dict_input(
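One caveat in the hunk above: the broadcast loop only rebinds the loop variable, so `arg = [arg] * self.batch_size` never writes back into `batched_args`, and the listed arguments are left exactly as passed. A write-back variant of the same logic would look like this (an illustration, not part of this commit):

```python
def broadcast_args(batched_args, batch_size):
    # enumerate + index assignment makes the broadcast stick;
    # the diff's `for arg in batched_args` form silently discards it.
    for i, arg in enumerate(batched_args):
        if not isinstance(arg, list):
            arg = [arg] * batch_size
        if len(arg) < batch_size:
            arg = arg * batch_size
        else:
            arg = arg[:batch_size]
        batched_args[i] = arg
    return batched_args

# broadcast_args(["a prompt", None], 2) -> [["a prompt", "a prompt"], [None, None]]
```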
@@ -622,7 +601,6 @@ def shark_sd_fn(
         "control_mode": control_mode,
         "hints": hints,
     }
-    print(submit_pipe_kwargs)
     if (
         not global_obj.get_sd_obj()
         or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
@@ -630,7 +608,6 @@ def shark_sd_fn(
         print("\n[LOG] Initializing new pipeline...")
         global_obj.clear_cache()
         gc.collect()
-        global_obj.set_pipe_kwargs(submit_pipe_kwargs)
 
         # Initializes the pipeline and retrieves IR based on all
         # parameters that are static in the turbine output format,
@@ -640,8 +617,14 @@ def shark_sd_fn(
             **submit_pipe_kwargs,
         )
         global_obj.set_sd_obj(sd_pipe)
+        global_obj.set_pipe_kwargs(submit_pipe_kwargs)
+    if (
+        not global_obj.get_prep_kwargs()
+        or global_obj.get_prep_kwargs() != submit_prep_kwargs
+    ):
+        global_obj.set_prep_kwargs(submit_prep_kwargs)
+        global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
 
-    global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
     generated_imgs = []
     for current_batch in range(batch_count):
         start_time = time.time()
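Taken together, the last two hunks give `shark_sd_fn` a two-level cache: the pipeline object is rebuilt only when the compile-time kwargs change, and `prepare_pipe` reruns only when the preparation kwargs change. The same pattern in isolation (hypothetical names, same shape as the diff):

```python
_cache = {"pipe_kwargs": None, "obj": None, "prep_kwargs": None}

def get_pipeline(pipe_kwargs, prep_kwargs, build):
    if _cache["obj"] is None or _cache["pipe_kwargs"] != pipe_kwargs:
        _cache["obj"] = build(**pipe_kwargs)  # expensive: compile / load vmfbs
        _cache["pipe_kwargs"] = pipe_kwargs
        _cache["prep_kwargs"] = None          # force re-preparation
    if _cache["prep_kwargs"] != prep_kwargs:
        _cache["obj"].prepare_pipe(**prep_kwargs)  # cheaper: weights, submodels
        _cache["prep_kwargs"] = prep_kwargs
    return _cache["obj"]
```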
@@ -653,7 +636,7 @@ def shark_sd_fn(
                 # break
             # else:
             save_output_img(
-                out_imgs[0],
+                out_imgs[current_batch],
                 seed,
                 sd_kwargs,
             )
@@ -35,6 +35,7 @@ class SharkPipelineBase:
         import_mlir: bool = True,
     ):
         self.model_map = model_map
+        self.pipe_map = {}
         self.static_kwargs = static_kwargs
         self.base_model_id = base_model_id
         self.triple = get_iree_target_triple(device)
@@ -45,6 +46,7 @@ class SharkPipelineBase:
         if not os.path.exists(self.tmp_dir):
             os.mkdir(self.tmp_dir)
         self.tempfiles = {}
+        self.pipe_vmfb_path = ""
 
     def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
         # First checks whether we have .vmfbs precompiled, then populates the map
@@ -63,11 +65,12 @@ class SharkPipelineBase:
             for key in self.model_map:
                 self.get_compiled_map(pipe_id, submodel=key)
         else:
-            self.get_precompiled(pipe_id, submodel)
+            self.pipe_map[submodel] = {}
+            self.get_precompiled(self.pipe_id, submodel)
             ireec_flags = []
             if submodel in self.iree_module_dict:
                 return
-            elif "vmfb_path" in self.model_map[submodel]:
+            elif "vmfb_path" in self.pipe_map[submodel]:
                 return
             elif submodel not in self.tempfiles:
                 print(
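The new `pipe_map` separates per-run artifacts from the static compile recipes in `model_map`: `get_compiled_map` now seeds an empty `pipe_map[submodel]` and records discovered `.vmfb` paths there instead of mutating `model_map`. A sketch of the intended split, with illustrative data rather than values from the repo:

```python
# Static recipe, never mutated at runtime:
model_map = {"clip": {"initializer": "...", "ireec_flags": ["--iree-opt-const-expr-hoisting=False"]}}

# Per-run state, filled in as artifacts are found or produced:
pipe_map = {"clip": {}}
pipe_map["clip"]["vmfb_path"] = "/checkpoints/pipe_id/clip.vmfb"
```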
@@ -87,10 +90,8 @@ class SharkPipelineBase:
                 else []
             )
 
-            if "external_weights_file" in self.model_map[submodel]:
-                weights_path = self.model_map[submodel]["external_weights_file"]
-            else:
-                weights_path = None
+            weights_path = self.get_io_params(submodel)
 
             self.iree_module_dict[submodel] = get_iree_compiled_module(
                 self.tempfiles[submodel],
                 device=self.device,
@@ -101,20 +102,30 @@ class SharkPipelineBase:
                 write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
             )
         return
 
+    def get_io_params(self, submodel):
+        if "external_weight_file" in self.static_kwargs[submodel]:
+            # we are using custom weights
+            weights_path = self.static_kwargs[submodel]["external_weight_file"]
+        elif "external_weight_path" in self.static_kwargs[submodel]:
+            # we are using the default weights for the HF model
+            weights_path = self.static_kwargs[submodel]["external_weight_path"]
+        else:
+            # assume the torch IR contains the weights.
+            weights_path = None
+        return weights_path
+
     def get_precompiled(self, pipe_id, submodel="None"):
         if submodel == "None":
             for model in self.model_map:
                 self.get_precompiled(pipe_id, model)
         vmfbs = []
         vmfb_matches = {}
-        vmfbs_path = self.pipe_vmfb_path
-        for dirpath, dirnames, filenames in os.walk(vmfbs_path):
+        for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
             vmfbs.extend(filenames)
             break
         for file in vmfbs:
             if submodel in file:
-                self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
+                self.pipe_map[submodel]["vmfb_path"] = os.path.join(self.pipe_vmfb_path, file)
         return
 
     def import_torch_ir(self, submodel, kwargs):
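`get_io_params` defines a three-way precedence for a submodel's weights: explicit custom file, then the default per-model path, then `None` (weights assumed inlined in the torch IR). A quick usage sketch against a hand-built `static_kwargs` (paths are made up):

```python
static_kwargs = {
    "unet": {"external_weight_file": "/ckpts/custom.safetensors"},   # custom wins
    "clip": {"external_weight_path": "/ckpts/sd15/clip.safetensors"},
    "vae_decode": {},  # nothing set: weights live inside the IR
}

def get_io_params(submodel):
    if "external_weight_file" in static_kwargs[submodel]:
        return static_kwargs[submodel]["external_weight_file"]
    elif "external_weight_path" in static_kwargs[submodel]:
        return static_kwargs[submodel]["external_weight_path"]
    return None

assert get_io_params("unet") == "/ckpts/custom.safetensors"
assert get_io_params("vae_decode") is None
```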
@@ -139,9 +150,11 @@ class SharkPipelineBase:
         for submodel in submodels:
             if submodel in self.iree_module_dict:
                 print(f"\n[LOG] {submodel} is ready for inference.")
                 continue
-            if "vmfb_path" in self.model_map[submodel]:
+            if "vmfb_path" in self.pipe_map[submodel]:
+                weights_path = self.get_io_params(submodel)
                 print(
-                    f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}"
+                    f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
                 )
                 self.iree_module_dict[submodel] = {}
                 (
@@ -149,13 +162,11 @@ class SharkPipelineBase:
                     self.iree_module_dict[submodel]["config"],
                     self.iree_module_dict[submodel]["temp_file_to_unlink"],
                 ) = load_vmfb_using_mmap(
-                    self.model_map[submodel]["vmfb_path"],
+                    self.pipe_map[submodel]["vmfb_path"],
                     self.device,
                     device_idx=0,
                     rt_flags=[],
-                    external_weight_file=self.model_map[submodel][
-                        "external_weight_file"
-                    ],
+                    external_weight_file=weights_path,
                 )
             else:
                 self.get_compiled_map(self.pipe_id, submodel)
@@ -10,7 +10,7 @@
     "seed": -1,
     "batch_count": 1,
     "batch_size": 1,
-    "scheduler": "DDIM",
+    "scheduler": "EulerDiscrete",
     "base_model_id": "runwayml/stable-diffusion-v1-5",
     "custom_weights": "",
     "custom_vae": "",
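The default scheduler flips from DDIM to EulerDiscrete. If `get_schedulers` is backed by diffusers, the lookup amounts to something like the following (a sketch; SHARK's actual helper may differ):

```python
from diffusers import DDIMScheduler, EulerDiscreteScheduler

def get_schedulers(base_model_id: str) -> dict:
    # Each scheduler loads its config from the repo's scheduler subfolder.
    return {
        "DDIM": DDIMScheduler.from_pretrained(base_model_id, subfolder="scheduler"),
        "EulerDiscrete": EulerDiscreteScheduler.from_pretrained(base_model_id, subfolder="scheduler"),
    }

scheduler = get_schedulers("runwayml/stable-diffusion-v1-5")["EulerDiscrete"]  # new default
```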
@@ -226,76 +226,6 @@ def import_original(original_img, width, height):
     return EditorValue(img_dict)
 
 
-# def update_cn_input(
-#     model,
-#     stencil,
-#     preprocessed_hint,
-# ):
-#     print("update_cn_input")
-#     if model == None:
-#         stencil = None
-#         preprocessed_hint = None
-#         return [
-#             gr.update(),
-#             gr.update(),
-#             gr.update(),
-#             gr.update(),
-#             gr.update(),
-#             gr.update(),
-#             stencil,
-#             preprocessed_hint,
-#         ]
-#     elif model == "scribble":
-#         return [
-#             gr.ImageEditor(
-#                 visible=True,
-#                 interactive=True,
-#                 show_label=False,
-#                 image_mode="RGB",
-#                 type="pil",
-#                 brush=Brush(
-#                     colors=["#000000"],
-#                     color_mode="fixed",
-#                     default_size=5,
-#                 ),
-#             ),
-#             gr.Image(
-#                 visible=True,
-#                 show_label=False,
-#                 interactive=True,
-#                 show_download_button=False,
-#             ),
-#             gr.Slider(visible=True, label="Canvas Width"),
-#             gr.Slider(visible=True, label="Canvas Height"),
-#             gr.Button(visible=True),
-#             gr.Button(visible=False),
-#             stencil,
-#             preprocessed_hint,
-#         ]
-#     else:
-#         return [
-#             gr.ImageEditor(
-#                 visible=True,
-#                 interactive=True,
-#                 show_label=False,
-#                 image_mode="RGB",
-#                 type="pil",
-#             ),
-#             gr.Image(
-#                 visible=True,
-#                 show_label=False,
-#                 interactive=True,
-#                 show_download_button=False,
-#             ),
-#             gr.Slider(visible=True, label="Canvas Width"),
-#             gr.Slider(visible=True, label="Canvas Height"),
-#             gr.Button(visible=False),
-#             gr.Button(visible=True),
-#             stencil,
-#             preprocessed_hint,
-#         ]
 
 
 with gr.Blocks(title="Stable Diffusion") as sd_element:
     with gr.Row(elem_id="ui_title"):
         nod_logo = Image.open(nodlogo_loc)
@@ -12,11 +12,13 @@ def _init():
     global _sd_obj
     global _devices
     global _pipe_kwargs
+    global _prep_kwargs
     global _gen_kwargs
     global _schedulers
     _sd_obj = None
     _devices = None
     _pipe_kwargs = None
+    _prep_kwargs = None
     _gen_kwargs = None
     _schedulers = None
     set_devices()
@@ -44,15 +46,19 @@ def set_sd_status(value):
 
 def set_pipe_kwargs(value):
     global _pipe_kwargs
-    print(value)
     _pipe_kwargs = value
 
 
+def set_prep_kwargs(value):
+    global _prep_kwargs
+    _prep_kwargs = value
+
+
 def set_gen_kwargs(value):
     global _gen_kwargs
     _gen_kwargs = value
 
 
 def set_schedulers(value):
     global _schedulers
     _schedulers = value
@@ -75,10 +81,14 @@ def get_sd_status():
 
 def get_pipe_kwargs():
     global _pipe_kwargs
-    print(_pipe_kwargs)
     return _pipe_kwargs
 
 
+def get_prep_kwargs():
+    global _prep_kwargs
+    return _prep_kwargs
+
+
 def get_gen_kwargs():
     global _gen_kwargs
     return _gen_kwargs
@@ -92,14 +102,14 @@ def get_scheduler(key):
 def clear_cache():
     global _sd_obj
     global _pipe_kwargs
+    global _prep_kwargs
     global _gen_kwargs
     global _schedulers
     del _sd_obj
     del _pipe_kwargs
     del _gen_kwargs
     del _schedulers
     gc.collect()
     _sd_obj = None
     _pipe_kwargs = None
+    _prep_kwargs = None
     _gen_kwargs = None
     _schedulers = None
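Note that `clear_cache` never `del`s `_prep_kwargs` even though it deletes its siblings; since every slot is reassigned to `None` afterwards this is harmless, just asymmetric. Grouping the state into one dict would keep add/clear symmetric by construction (a sketch, not the module's API):

```python
import gc

_state = {"sd_obj": None, "pipe_kwargs": None, "prep_kwargs": None, "gen_kwargs": None, "schedulers": None}

def clear_cache():
    # One loop covers every slot, so newly added kwargs can't be forgotten here.
    for key in _state:
        _state[key] = None
    gc.collect()
```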