mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fix incorrect device argument initialization for LoRA training by extracting the device type and number and formatting it for pytorch (#1237)
Co-authored-by: Kyle Herndon <kyle@nod-labs.com>
This commit is contained in:
@@ -222,7 +222,12 @@ def lora_train(
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
device_str = device.split("=>", 1)[1].strip().split("://")
|
||||
if len(device_str) > 1:
|
||||
device_str = device_str[0] + ":" + device_str[1]
|
||||
else:
|
||||
device_str = device_str[0]
|
||||
args.device = device_str
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user