Compare commits

..

1 Commits

Author SHA1 Message Date
psychedelicious
096c17465b fix(backend): use random seed for SDE & Ancestral schedulers
SDE and Ancestral schedulers use some randomness at each step when applying conditioning. We were using a static seed of `0` for this, regardless of the initial noise used. This could cause results to be a bit same-y.

Unfortunately, we do not have easy access to the seed used to create the initial noise at this time.

Changing this to use a random seed value instead of always 0.
2023-08-07 09:09:12 +10:00
7 changed files with 19 additions and 140 deletions

View File

@@ -180,6 +180,10 @@ class TextToLatentsInvocation(BaseInvocation):
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
# for ancestral and sde schedulers
generator = torch.Generator(device=unet.device)
generator.seed()
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
@@ -198,7 +202,7 @@ class TextToLatentsInvocation(BaseInvocation):
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=unet.device).manual_seed(0),
generator=generator,
)
return conditioning_data

View File

@@ -1,73 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_management.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import os
import hashlib
from imohash import hashfile
from pathlib import Path
from typing import Dict, Union
class FastModelHash(object):
    """Fast hasher for diffusers and checkpoint-style models.

    Exposes a single public class method, hash(), which dispatches to a
    file- or directory-hashing strategy based on what model_location is.
    """

    # Files below this size are ignored when hashing a directory: the small
    # config files contain locally-varying data (e.g. the diffusers point
    # version) that would make the hash unstable across installs.
    MINIMUM_FILE_SIZE = 100000

    @classmethod
    def hash(cls, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest string for model located at model_location.

        :param model_location: Path to the model (file or directory)
        :raises InvalidModelException: if model_location is neither a file
            nor a directory
        """
        model_location = Path(model_location)
        if model_location.is_file():
            return cls._hash_file(model_location)
        if model_location.is_dir():
            return cls._hash_dir(model_location)
        # avoid circular import
        from .models import InvalidModelException

        raise InvalidModelException(f"Not a valid file or directory: {model_location}")

    @classmethod
    def _hash_file(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash a single file and return its hexdigest.

        :param model_location: Path to the model file
        """
        # sha256 of the imohash filehash keeps the digest length consistent
        # with what _hash_dir() returns.
        return hashlib.sha256(hashfile(model_location)).hexdigest()

    @classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        """Hash all large files under model_location into one hexdigest."""
        per_file_hashes: Dict[str, str] = {}
        for root, _dirs, files in os.walk(model_location):
            for filename in files:
                candidate = Path(root) / filename
                # Only pay attention to the big files (see MINIMUM_FILE_SIZE).
                if candidate.stat().st_size < cls.MINIMUM_FILE_SIZE:
                    continue
                per_file_hashes[str(candidate)] = cls._hash_file(candidate)
        # Combine the per-file hashes in alphabetic path order so the result
        # is independent of os.walk traversal order.
        digest = hashlib.sha256()
        for _path, file_hash in sorted(per_file_hashes.items()):
            digest.update(file_hash.encode("utf-8"))
        return digest.hexdigest()

View File

@@ -260,7 +260,6 @@ from .models import (
InvalidModelException,
DuplicateModelException,
)
from .model_hash import FastModelHash
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
@@ -365,8 +364,6 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
if not model_config.get("hash"):
model_config["hash"] = FastModelHash.hash(self.resolve_model_path(model_config["path"]))
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
@@ -434,28 +431,6 @@ class ModelManager(object):
with open(config_path, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
def get_model_by_hash(
self,
model_hash: str,
submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
"""
Given a model's unique hash, return its ModelInfo.
:param model_hash: Unique hash for this model.
"""
info = self.list_models()
keys = [x for x in info if x["hash"] == model_hash]
if len(keys) == 0:
raise InvalidModelException(f"No model with hash {model_hash} found")
if len(keys) > 1:
raise DuplicateModelException(f"Duplicate models detected: {keys}")
return self.get_model(
keys[0]["model_name"],
base_model=keys[0]["base_model"],
model_type=keys[0]["model_type"],
)
def get_model(
self,
model_name: str,
@@ -525,12 +500,14 @@ class ModelManager(object):
self.cache_keys[model_key] = set()
self.cache_keys[model_key].add(model_context.key)
model_hash = "<NO_HASH>" # TODO:
return ModelInfo(
context=model_context,
name=model_name,
base_model=base_model,
type=submodel_type or model_type,
hash=model_config.hash,
hash=model_hash,
location=model_path, # TODO:
precision=self.cache.precision,
_cache=self.cache,
@@ -683,22 +660,12 @@ class ModelManager(object):
if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path)))
if not model_attributes.get("hash"):
hash = FastModelHash.hash(self.resolve_model_path(model_attributes["path"]))
model_attributes["hash"] = hash
model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
if not clobber:
if model_key in self.models:
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
try:
i = self.get_model_by_hash(model_attributes["hash"])
raise DuplicateModelException(f"There is already a model with hash {hash}: {i['name']}")
except:
pass
if model_key in self.models and not clobber:
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
old_model = self.models.pop(model_key, None)
if old_model is not None:
@@ -974,11 +941,7 @@ class ModelManager(object):
raise DuplicateModelException(f"Model with key {model_key} added twice")
model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config(
str(model_path),
hash=FastModelHash.hash(model_path),
model_base=cur_base_model,
)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
except DuplicateModelException as e:

View File

@@ -345,12 +345,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif lora_token_vector_length is None: # variant w/o the text encoder!
return BaseModelType.StableDiffusion1
else:
raise InvalidModelException(
f"Unknown LoRA type: {self.checkpoint_path}, lora_token_vector_length={lora_token_vector_length}"
)
raise InvalidModelException(f"Unknown LoRA type")
class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@@ -89,7 +89,6 @@ class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
hash: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
class Config:
@@ -198,16 +197,15 @@ class ModelBase(metaclass=ABCMeta):
def create_config(cls, **kwargs) -> ModelConfigBase:
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
config = configs[kwargs["model_format"]](**kwargs)
return config
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
model_format=cls.detect_format(path),
hash=kwargs["hash"],
)
@classmethod

View File

@@ -13,11 +13,8 @@ from .base import (
read_checkpoint_meta,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
app_config = InvokeAIAppConfig.get_config()
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
@@ -25,7 +22,7 @@ class StableDiffusionXLModelFormat(str, Enum):
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwritten properly
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
@@ -82,19 +79,14 @@ class StableDiffusionXLModel(DiffusersModel):
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None and "model_base" in kwargs:
ckpt_config_path = (
app_config.legacy_conf_path / "sd_xl_base.yaml"
if kwargs["model_base"] == BaseModelType.StableDiffusionXL
else app_config.legacy_conf_path / "sd_xl_refiner.yaml"
if kwargs["model_base"] == BaseModelType.StableDiffusionXLRefiner
else None
)
if ckpt_config_path is None:
# TO DO: implement picking
pass
return cls.create_config(
path=path,
model_format=model_format,
config=str(ckpt_config_path),
config=ckpt_config_path,
variant=variant,
)

View File

@@ -55,7 +55,6 @@ dependencies = [
"flask_socketio==5.3.0",
"flaskwebgui==1.0.3",
"huggingface-hub>=0.11.1",
"imohash~=1.0.0",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model