diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py
index 104474b430..bff1fa7ecd 100644
--- a/invokeai/app/invocations/onnx.py
+++ b/invokeai/app/invocations/onnx.py
@@ -96,7 +96,8 @@ class ONNXPromptInvocation(BaseInvocation):
                 #import traceback
                 #print(traceback.format_exc())
                 print(f"Warn: trigger: \"{trigger}\" not found")
-
+        if loras or ti_list:
+            text_encoder.release_session()
         with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
             ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
 
@@ -127,7 +128,6 @@ class ONNXPromptInvocation(BaseInvocation):
 
             prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
-            text_encoder.release_session()
 
         conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
 
@@ -255,6 +255,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
         loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
+        if loras:
+            unet.release_session()
         with ONNXModelPatcher.apply_lora_unet(unet, loras):
             # TODO:
             _, _, h, w = latents.shape
@@ -303,7 +305,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
                 # call the callback, if provided
                 #if callback is not None and i % callback_steps == 0:
                 #    callback(i, t, latents)
-        unet.release_session()
 
         torch.cuda.empty_cache()
 
@@ -360,7 +361,6 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
         image = image.transpose((0, 2, 3, 1))
         image = VaeImageProcessor.numpy_to_pil(image)[0]
-        vae.release_session()
 
         torch.cuda.empty_cache()
 
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index 06255ac6f6..d8b851411b 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -387,7 +387,7 @@ def _calc_model_by_data(model) -> int:
 
 
 def _calc_onnx_model_by_data(model) -> int:
-    tensor_size = model.tensors.size()
+    tensor_size = model.tensors.size() * 2  # The session doubles this
     mem = tensor_size  # in bytes
     return mem
 
@@ -608,9 +608,9 @@ class IAIOnnxRuntimeModel:
         # self.io_binding = self.session.io_binding()
 
     def release_session(self):
-        # self.session = None
-        # import gc
-        # gc.collect()
+        self.session = None
+        import gc
+        gc.collect()
         return
 
     def __call__(self, **kwargs):
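
For context, a minimal sketch of the pattern this patch moves to: release_session() drops the live ONNX Runtime session so that LoRA/TI patches are applied to the raw tensors, and the session is rebuilt lazily on the next inference call. LazyOnnxModel, model_path, and the lazy __call__ below are illustrative assumptions, not InvokeAI's actual IAIOnnxRuntimeModel implementation.

    import gc

    import onnxruntime as ort


    class LazyOnnxModel:
        """Hypothetical stand-in for IAIOnnxRuntimeModel (illustration only)."""

        def __init__(self, model_path: str):
            self.model_path = model_path
            self.session = None  # created on first call

        def release_session(self):
            # Dropping the session frees ONNX Runtime's loaded copy of the
            # weights. While a session is alive, the raw tensors and that
            # copy coexist, which is what the "* 2" added to
            # _calc_onnx_model_by_data accounts for.
            self.session = None
            gc.collect()

        def __call__(self, **inputs):
            if self.session is None:
                # Lazily (re)create the session after any weight patching.
                self.session = ort.InferenceSession(self.model_path)
            return self.session.run(None, inputs)

Under this shape, the session rebuild cost is only paid when a patch was actually applied, which is why the invocations guard the release with "if loras or ti_list:" and "if loras:" rather than releasing unconditionally after inference.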