diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 353d4e96cf..ff0a94b00f 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -190,15 +190,15 @@ class AMDComputeQueue(HWQueue): self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, instance_broadcast_writes=1) # Wait for FINISH_PENDING==0 self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ), - self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.field_mask('finish_pending'), 4) + self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), 4) # Wait for FINISH_DONE!=0 self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_NEQ), - self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.field_mask('finish_done'), 4) + self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_done'), 4) # Disable SQTT self.sqtt_config(tracing=False) # Wait for BUSY==0 self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ), - self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.field_mask('busy'), 4) + self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), 4) # Copy WPTR to memory (src_sel = perf, dst_sel = tc_l2, wr_confirm = True) self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr, 0, *data64_le(wptrs.va_addr+(se*4))) # Restore global broadcasting diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index d4ae5ca0f3..22e88c0bdb 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -18,7 +18,7 @@ class AMRegister(AMDReg): def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs)) - def update(self, **kwargs): self.write(self.encode(**{**self.read_bitfields(), **kwargs})) + def update(self, **kwargs): self.write(self.read() & ~self.fields_mask(*kwargs.keys()), **kwargs) class AMFirmware: def __init__(self, adev): diff --git a/tinygrad/runtime/support/amd.py b/tinygrad/runtime/support/amd.py index 37034867cf..8177418ed6 100644 --- a/tinygrad/runtime/support/amd.py +++ b/tinygrad/runtime/support/amd.py @@ -11,10 +11,9 @@ class AMDReg: def encode(self, **kwargs) -> int: return functools.reduce(int.__or__, (value << self.fields[name][0] for name,value in kwargs.items()), 0) def decode(self, val: int) -> dict: return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()} - def field_mask(self, field_name) -> int: - start, end = self.fields[field_name] - num_bits = end - start + 1 - return ((1 << num_bits) - 1) << start + + def fields_mask(self, *names) -> int: + return functools.reduce(int.__or__, ((((1 << (self.fields[nm][1]-self.fields[nm][0]+1)) - 1) << self.fields[nm][0]) for nm in names), 0) @property def addr(self): return self.bases[self.segment] + self.offset