Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)

Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
This commit is contained in:
Kyle Herndon
2023-03-24 01:10:50 -07:00
committed by GitHub
parent 4fac46f7bb
commit 0b0526699a

View File

@@ -222,7 +222,12 @@ def lora_train(
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
device_str = device.split("=>", 1)[1].strip().split("://")
if len(device_str) > 1:
device_str = device_str[0] + ":" + device_str[1]
else:
device_str = device_str[0]
args.device = device_str
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(