mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
chore: ruff
This commit is contained in:
@@ -587,9 +587,9 @@ def invocation(
|
|||||||
for field_name, field_info in cls.model_fields.items():
|
for field_name, field_info in cls.model_fields.items():
|
||||||
annotation = field_info.annotation
|
annotation = field_info.annotation
|
||||||
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
||||||
assert isinstance(
|
assert isinstance(field_info.json_schema_extra, dict), (
|
||||||
field_info.json_schema_extra, dict
|
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||||
), f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
)
|
||||||
|
|
||||||
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||||
|
|
||||||
@@ -712,9 +712,9 @@ def invocation_output(
|
|||||||
for field_name, field_info in cls.model_fields.items():
|
for field_name, field_info in cls.model_fields.items():
|
||||||
annotation = field_info.annotation
|
annotation = field_info.annotation
|
||||||
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
|
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
|
||||||
assert isinstance(
|
assert isinstance(field_info.json_schema_extra, dict), (
|
||||||
field_info.json_schema_extra, dict
|
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||||
), f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
)
|
||||||
|
|
||||||
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||||
|
|
||||||
|
|||||||
@@ -184,9 +184,9 @@ class SegmentAnythingInvocation(BaseInvocation):
|
|||||||
# Find the largest mask.
|
# Find the largest mask.
|
||||||
return [max(masks, key=lambda x: float(x.sum()))]
|
return [max(masks, key=lambda x: float(x.sum()))]
|
||||||
elif self.mask_filter == "highest_box_score":
|
elif self.mask_filter == "highest_box_score":
|
||||||
assert (
|
assert bounding_boxes is not None, (
|
||||||
bounding_boxes is not None
|
"Bounding boxes must be provided to use the 'highest_box_score' mask filter."
|
||||||
), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
|
)
|
||||||
assert len(masks) == len(bounding_boxes)
|
assert len(masks) == len(bounding_boxes)
|
||||||
# Find the index of the bounding box with the highest score.
|
# Find the index of the bounding box with the highest score.
|
||||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
||||||
|
|||||||
@@ -482,9 +482,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
try:
|
try:
|
||||||
# Meta is not included in the model fields, so we need to validate it separately
|
# Meta is not included in the model fields, so we need to validate it separately
|
||||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||||
assert (
|
assert config.schema_version == CONFIG_SCHEMA_VERSION, (
|
||||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
)
|
||||||
return config
|
return config
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||||
|
|||||||
@@ -379,13 +379,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
bytes_ = path.read_bytes()
|
bytes_ = path.read_bytes()
|
||||||
workflow_from_file = WorkflowValidator.validate_json(bytes_)
|
workflow_from_file = WorkflowValidator.validate_json(bytes_)
|
||||||
|
|
||||||
assert workflow_from_file.id.startswith(
|
assert workflow_from_file.id.startswith("default_"), (
|
||||||
"default_"
|
f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
|
||||||
), f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
|
)
|
||||||
|
|
||||||
assert (
|
assert workflow_from_file.meta.category is WorkflowCategory.Default, (
|
||||||
workflow_from_file.meta.category is WorkflowCategory.Default
|
f"Invalid default workflow category: {workflow_from_file.meta.category}"
|
||||||
), f"Invalid default workflow category: {workflow_from_file.meta.category}"
|
)
|
||||||
|
|
||||||
workflows_from_file.append(workflow_from_file)
|
workflows_from_file.append(workflow_from_file)
|
||||||
|
|
||||||
|
|||||||
@@ -115,19 +115,19 @@ class ModelMerger(object):
|
|||||||
base_models: Set[BaseModelType] = set()
|
base_models: Set[BaseModelType] = set()
|
||||||
variant = None if self._installer.app_config.precision == "float32" else "fp16"
|
variant = None if self._installer.app_config.precision == "float32" else "fp16"
|
||||||
|
|
||||||
assert (
|
assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, (
|
||||||
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
"When merging three models, only the 'add_difference' merge method is supported"
|
||||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
)
|
||||||
|
|
||||||
for key in model_keys:
|
for key in model_keys:
|
||||||
info = store.get_model(key)
|
info = store.get_model(key)
|
||||||
model_names.append(info.name)
|
model_names.append(info.name)
|
||||||
assert isinstance(
|
assert isinstance(info, MainDiffusersConfig), (
|
||||||
info, MainDiffusersConfig
|
f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||||
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
)
|
||||||
assert info.variant == ModelVariantType(
|
assert info.variant == ModelVariantType("normal"), (
|
||||||
"normal"
|
f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
)
|
||||||
|
|
||||||
# tally base models used
|
# tally base models used
|
||||||
base_models.add(info.base)
|
base_models.add(info.base)
|
||||||
|
|||||||
@@ -211,12 +211,12 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
|||||||
assert job.bytes > 0, "expected download bytes to be positive"
|
assert job.bytes > 0, "expected download bytes to be positive"
|
||||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||||
assert job.download_path == tmp_path / "sdxl-turbo"
|
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||||
assert Path(
|
assert Path(tmp_path, "sdxl-turbo/model_index.json").exists(), (
|
||||||
tmp_path, "sdxl-turbo/model_index.json"
|
f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
)
|
||||||
assert Path(
|
assert Path(tmp_path, "sdxl-turbo/text_encoder/config.json").exists(), (
|
||||||
tmp_path, "sdxl-turbo/text_encoder/config.json"
|
f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
||||||
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
)
|
||||||
|
|
||||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|||||||
@@ -48,9 +48,9 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
|
|||||||
model_keys = set(model.state_dict().keys())
|
model_keys = set(model.state_dict().keys())
|
||||||
|
|
||||||
for converted_key_prefix in converted_key_prefixes:
|
for converted_key_prefix in converted_key_prefixes:
|
||||||
assert any(
|
assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), (
|
||||||
model_key.startswith(converted_key_prefix) for model_key in model_keys
|
f"'{converted_key_prefix}' did not match any model keys."
|
||||||
), f"'{converted_key_prefix}' did not match any model keys."
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_lora_model_from_flux_aitoolkit_state_dict():
|
def test_lora_model_from_flux_aitoolkit_state_dict():
|
||||||
|
|||||||
Reference in New Issue
Block a user