usbgpu: check hash in patcher (#10266)

This commit is contained in:
nimlgen
2025-05-12 21:08:53 +03:00
committed by GitHub
parent 94907d02c8
commit bb31cc4582

View File

@@ -1,12 +1,16 @@
#!/usr/bin/env python3
import sys, os, zlib, struct
import sys, os, zlib, struct, hashlib
from hexdump import hexdump
from tinygrad.helpers import DEBUG, getenv, fetch
from tinygrad.runtime.support.usb import USB3
def patch(input_filepath, patches):
def patch(input_filepath, file_hash, patches):
with open(input_filepath, 'rb') as infile: data = bytearray(infile.read())
if_hash = hashlib.md5(data).hexdigest()
if if_hash != file_hash:
raise ValueError(f"File hash mismatch: expected {file_hash}, got {if_hash}")
for offset, expected_bytes, new_bytes in patches:
if len(expected_bytes) != len(new_bytes):
raise ValueError("Expected bytes and new bytes must be the same length")
@@ -26,6 +30,7 @@ def patch(input_filepath, patches):
return data
path = os.path.dirname(os.path.abspath(__file__))
file_hash = "5284e618d96ef804c06f47f3b73656b7"
file_path = os.path.join(path, "Software/AS_USB4_240417_85_00_00.bin")
if not os.path.exists(file_path):
@@ -34,7 +39,7 @@ if not os.path.exists(file_path):
os.system(f'unzip -o "{path}/fw.zip" "Software/AS_USB4_240417_85_00_00.bin" -d "{path}"')
patches = [(0x2a0d + 1 + 4, b'\x0a', b'\x05')]
patched_fw = patch(file_path, patches)
patched_fw = patch(file_path, file_hash, patches)
vendor, device = [int(x, base=16) for x in getenv("USBDEV", "174C:2464").split(":")]
try: dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04)