Mirror of https://github.com/invoke-ai/InvokeAI.git
Textual Inversion for M1
Update main.py
Update ddpm.py
Update personalized.py
Update personalized_style.py
Update v1-finetune.yaml
Update environment-mac.yaml
Rename v1-finetune.yaml to v1-m1-finetune.yaml
Create v1-finetune.yaml
Update main.py
Update main.py
Update environment-mac.yaml
Update v1-inference.yaml
@@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
         self.image_paths = [
             os.path.join(self.data_root, file_path)
-            for file_path in os.listdir(self.data_root)
+            for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
         ]

         # self._length = len(self.image_paths)
@@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
         self.image_paths = [
             os.path.join(self.data_root, file_path)
-            for file_path in os.listdir(self.data_root)
+            for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
         ]

         # self._length = len(self.image_paths)
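Note on the two hunks above: macOS drops a hidden .DS_Store metadata file into every folder the Finder opens, and os.listdir returns it alongside the real training images, so the image loader would try to open it and fail. A slightly more general sketch of the same fix (a hypothetical helper, not code from this commit) skips every dotfile rather than one hard-coded name:

import os

def list_image_paths(data_root):
    # Skip macOS Finder metadata (.DS_Store) and any other hidden files,
    # not just the single name the diff filters out.
    return [
        os.path.join(data_root, file_path)
        for file_path in os.listdir(data_root)
        if not file_path.startswith(".")
    ]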
@@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
     @rank_zero_only
     @torch.no_grad()
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
        # only for very first batch
        if (
            self.scale_by_std
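Note on this hunk: on_train_batch_start is a PyTorch Lightning hook, and some Lightning releases call it without a dataloader_idx argument, so defaulting the parameter presumably keeps the override working across Lightning versions. A minimal sketch of the idea (the class and body are illustrative, not from the commit):

import pytorch_lightning as pl

class ExampleModule(pl.LightningModule):
    # With dataloader_idx defaulted, this override accepts both the older
    # three-argument call and a newer two-argument call from the trainer.
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
        if batch_idx == 0:
            print("first batch of the epoch")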
@@ -1890,7 +1890,7 @@ class LatentDiffusion(DDPM):
         N=8,
         n_row=4,
         sample=True,
-        ddim_steps=200,
+        ddim_steps=50,
         ddim_eta=1.0,
         return_keys=None,
         quantize_denoised=True,
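Note on this hunk: ddim_steps is the default step count used when sample grids are rendered during fine-tuning, so cutting it from 200 to 50 makes each periodic logging pass roughly four times faster at some cost in sample fidelity. A sketch of the downstream effect, assuming the DDIMSampler that ships in the same repository (the helper function itself is hypothetical):

from ldm.models.diffusion.ddim import DDIMSampler

def log_samples(model, cond, shape, ddim_steps=50, ddim_eta=1.0):
    # Fewer DDIM steps means fewer denoising iterations per logged image;
    # 50 is a common speed/quality tradeoff for preview sampling.
    sampler = DDIMSampler(model)
    samples, _ = sampler.sample(
        S=ddim_steps,
        conditioning=cond,
        batch_size=cond.shape[0],
        shape=shape,
        eta=ddim_eta,
        verbose=False,
    )
    return samples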
@@ -169,9 +169,14 @@ class EmbeddingManager(nn.Module):
                 placeholder_embedding.shape[0], max_step_tokens
             )

-            placeholder_rows, placeholder_cols = torch.where(
-                tokenized_text == placeholder_token.to(device)
-            )
+            if torch.cuda.is_available():
+                placeholder_rows, placeholder_cols = torch.where(
+                    tokenized_text == placeholder_token.to(device)
+                )
+            else:
+                placeholder_rows, placeholder_cols = torch.where(
+                    tokenized_text == placeholder_token
+                )

            if placeholder_rows.nelement() == 0:
                continue
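Note on this hunk: the original line moved placeholder_token onto device unconditionally, which fails when that device is CUDA on a machine without it, including the Apple-silicon target of this commit, so the change branches on torch.cuda.is_available(). A device-agnostic alternative sketch (a hypothetical helper, not what the commit does) follows the text tensor's own device and covers CUDA, MPS, and CPU with one code path:

import torch

def find_placeholder_positions(tokenized_text, placeholder_token):
    # Moving the token to whatever device the text already lives on
    # avoids the explicit is_available() branch entirely.
    token = placeholder_token.to(tokenized_text.device)
    return torch.where(tokenized_text == token)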