diff --git a/tinygrad/runtime/support/amd.py b/tinygrad/runtime/support/amd.py index 9086d1c176..7863ad9d25 100644 --- a/tinygrad/runtime/support/amd.py +++ b/tinygrad/runtime/support/amd.py @@ -1,7 +1,7 @@ import functools, importlib, time from collections import defaultdict from dataclasses import dataclass -from tinygrad.helpers import getbits, round_up +from tinygrad.helpers import getbits, round_up, getenv from tinygrad.runtime.autogen import pci from tinygrad.runtime.support.usb import ASM24Controller @@ -33,8 +33,10 @@ def import_module(name:str, version:tuple[int, ...], version_prefix:str=""): raise ImportError(f"Failed to load autogen module for {name.upper()} {'.'.join(map(str, version))}") def setup_pci_bars(usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]: - try: usb.pcie_cfg_req(pci.PCI_VENDOR_ID, bus=gpu_bus, dev=0, fn=0, size=2) - except RuntimeError: + try: need_reset = (usb.pcie_cfg_req(pci.PCI_VENDOR_ID, bus=gpu_bus, dev=0, fn=0, size=2) != 0x1002) + except RuntimeError: need_reset = True + + if need_reset or getenv("USB_RESCAN_BUS", 0) == 1: for bus in range(gpu_bus): usb.pcie_cfg_req(pci.PCI_SUBORDINATE_BUS, bus=bus, dev=0, fn=0, value=gpu_bus, size=1) usb.pcie_cfg_req(pci.PCI_SECONDARY_BUS, bus=bus, dev=0, fn=0, value=bus+1, size=1)