feat(installer): add options to include or exclude xFormers based on the GPU model

This commit is contained in:
Eugene Brodsky
2024-10-08 09:58:45 -04:00
committed by Kent Keirsey
parent 3c46522595
commit 8deeac1372
2 changed files with 19 additions and 17 deletions

View File

@@ -282,12 +282,6 @@ class InvokeAiInstance:
shutil.copy(src, dest)
os.chmod(dest, 0o0755)
def update(self):
pass
def remove(self):
pass
### Utility functions ###
@@ -402,7 +396,7 @@ def get_torch_source() -> Tuple[str | None, str | None]:
:rtype: list
"""
from messages import select_gpu
from messages import GpuType, select_gpu
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
device = select_gpu()
@@ -412,16 +406,20 @@ def get_torch_source() -> Tuple[str | None, str | None]:
url = None
optional_modules: str | None = None
if OS == "Linux":
if device.value == "rocm":
if device == GpuType.ROCM:
url = "https://download.pytorch.org/whl/rocm5.6"
elif device.value == "cpu":
elif device == GpuType.CPU:
url = "https://download.pytorch.org/whl/cpu"
elif device.value == "cuda":
# CUDA uses the default PyPi index
# CUDA uses the default PyPi index
elif device == GpuType.CUDA:
optional_modules = "[onnx-cuda]"
elif device == GpuType.CUDA_WITH_XFORMERS:
optional_modules = "[xformers,onnx-cuda]"
elif OS == "Windows":
if device.value == "cuda":
if device == GpuType.CUDA:
url = "https://download.pytorch.org/whl/cu124"
optional_modules = "[onnx-cuda]"
elif device == GpuType.CUDA_WITH_XFORMERS:
optional_modules = "[xformers,onnx-cuda]"
elif device.value == "cpu":
# CPU uses the default PyPi index, no optional modules

View File

@@ -206,6 +206,7 @@ def dest_path(dest: Optional[str | Path] = None) -> Path | None:
class GpuType(Enum):
CUDA_WITH_XFORMERS = "xformers"
CUDA = "cuda"
ROCM = "rocm"
CPU = "cpu"
@@ -221,11 +222,15 @@ def select_gpu() -> GpuType:
return GpuType.CPU
nvidia = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
"an [gold1 b]NVIDIA[/] RTX 3060 or newer GPU using CUDA",
GpuType.CUDA,
)
vintage_nvidia = (
"an [gold1 b]NVIDIA[/] RTX 20xx or older GPU using CUDA+xFormers",
GpuType.CUDA_WITH_XFORMERS,
)
amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)",
"an [gold1 b]AMD[/] GPU using ROCm",
GpuType.ROCM,
)
cpu = (
@@ -235,14 +240,13 @@ def select_gpu() -> GpuType:
options = []
if OS == "Windows":
options = [nvidia, cpu]
options = [nvidia, vintage_nvidia, cpu]
if OS == "Linux":
options = [nvidia, amd, cpu]
options = [nvidia, vintage_nvidia, amd, cpu]
elif OS == "Darwin":
options = [cpu]
if len(options) == 1:
print(f'Your platform [gold1]{OS}-{ARCH}[/] only supports the "{options[0][1]}" driver. Proceeding with that.')
return options[0][1]
options = {str(i): opt for i, opt in enumerate(options, 1)}