mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
am: add versioned_header to load_fw (#11702)
* am: add versioned_header to load_fw * fix mypy
This commit is contained in:
@@ -29,7 +29,7 @@ class AMFirmware:
|
||||
# Load SOS firmware
|
||||
self.sos_fw = {}
|
||||
|
||||
blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
|
||||
blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", versioned_header='struct_psp_firmware_header')
|
||||
fw_bin = sos_hdr.psp_fw_bin
|
||||
|
||||
for fw_i in range(sos_hdr.psp_fw_bin_count):
|
||||
@@ -45,11 +45,11 @@ class AMFirmware:
|
||||
self.smu_psp_desc = self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
|
||||
|
||||
# SDMA firmware
|
||||
blob, hdr, hdr_v3 = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0, am.struct_sdma_firmware_header_v3_0)
|
||||
blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", versioned_header='struct_sdma_firmware_header')
|
||||
if hdr.header.header_version_major < 3:
|
||||
self.descs += [self.desc(blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH1)]
|
||||
self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
|
||||
else: self.descs += [self.desc(blob, hdr_v3.header.ucode_array_offset_bytes, hdr_v3.ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
|
||||
else: self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
|
||||
|
||||
# PFP, ME, MEC firmware
|
||||
for (fw_name, fw_cnt) in ([('PFP', 1), ('ME', 1)] if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else []) + [('MEC', 1)]:
|
||||
@@ -83,10 +83,13 @@ class AMFirmware:
|
||||
|
||||
self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)]
|
||||
|
||||
def load_fw(self, fname:str, *headers):
|
||||
def load_fw(self, fname:str, *headers, versioned_header:str|None=None):
|
||||
fpath = fetch(f"https://gitlab.com/kernel-firmware/linux-firmware/-/raw/45f59212aebd226c7630aff4b58598967c0c8c91/amdgpu/{fname}", subdir="fw")
|
||||
blob = memoryview(bytearray(fpath.read_bytes()))
|
||||
if AM_DEBUG >= 1: print(f"am {self.adev.devfmt}: loading firmware {fname}: {hashlib.sha256(blob).hexdigest()}")
|
||||
if versioned_header:
|
||||
chdr = am.struct_common_firmware_header.from_address(mv_address(blob))
|
||||
headers += (getattr(am, versioned_header + f"_v{chdr.header_version_major}_{chdr.header_version_minor}"),)
|
||||
return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])
|
||||
|
||||
def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size])
|
||||
|
||||
Reference in New Issue
Block a user