mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
am: fix reg update (#10707)
This commit is contained in:
@@ -190,15 +190,15 @@ class AMDComputeQueue(HWQueue):
|
||||
self.wreg(self.gc.regGRBM_GFX_INDEX, se_index=se, instance_broadcast_writes=1)
|
||||
# Wait for FINISH_PENDING==0
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ),
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.field_mask('finish_pending'), 4)
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), 4)
|
||||
# Wait for FINISH_DONE!=0
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_NEQ),
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.field_mask('finish_done'), 4)
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_done'), 4)
|
||||
# Disable SQTT
|
||||
self.sqtt_config(tracing=False)
|
||||
# Wait for BUSY==0
|
||||
self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ),
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.field_mask('busy'), 4)
|
||||
self.gc.regSQ_THREAD_TRACE_STATUS.addr, 0, 0, self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), 4)
|
||||
# Copy WPTR to memory (src_sel = perf, dst_sel = tc_l2, wr_confirm = True)
|
||||
self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr, 0, *data64_le(wptrs.va_addr+(se*4)))
|
||||
# Restore global broadcasting
|
||||
|
||||
@@ -18,7 +18,7 @@ class AMRegister(AMDReg):
|
||||
|
||||
def write(self, _am_val:int=0, **kwargs):
  """Write a value to this register through the owning device.

  The value written is `_am_val` OR-ed with the encoding of any bitfields
  passed as keyword arguments.
  """
  raw_value = _am_val | self.encode(**kwargs)
  self.adev.wreg(self.addr, raw_value)
|
||||
|
||||
def update(self, **kwargs):
  """Read the current bitfields, override the given ones, and write the result.

  NOTE(review): the written value is rebuilt purely from `read_bitfields()`
  plus `kwargs`; presumably any register bits not described by a named field
  are therefore not carried over -- confirm against `read_bitfields`.
  """
  merged_fields = dict(self.read_bitfields())
  merged_fields.update(kwargs)  # caller-supplied fields win over current values
  self.write(self.encode(**merged_fields))
|
||||
def update(self, **kwargs):
  """Read-modify-write: replace only the named bitfields of this register.

  The current raw value is read, the bits belonging to the fields named in
  `kwargs` are cleared, and the result is handed to `write` together with
  the new field values.
  """
  touched_mask = self.fields_mask(*kwargs)
  preserved_bits = self.read() & ~touched_mask
  self.write(preserved_bits, **kwargs)
|
||||
|
||||
class AMFirmware:
|
||||
def __init__(self, adev):
|
||||
|
||||
@@ -11,10 +11,9 @@ class AMDReg:
|
||||
|
||||
def encode(self, **kwargs) -> int:
  """Pack named bitfield values into a single integer.

  Each keyword value is shifted left by the first element of its entry in
  `self.fields`, and the shifted values are OR-ed together. With no keyword
  arguments the result is 0.
  """
  shifted_values = [value << self.fields[name][0] for name, value in kwargs.items()]
  return functools.reduce(int.__or__, shifted_values, 0)
|
||||
def decode(self, val: int) -> dict:
  """Split a raw register value into a dict of named bitfield values.

  For every `(name, (start, end))` entry in `self.fields`, the result maps
  `name` to `getbits(val, start, end)`.
  """
  decoded = {}
  for name, (start, end) in self.fields.items():
    decoded[name] = getbits(val, start, end)
  return decoded
|
||||
def field_mask(self, field_name) -> int:
  """Return the bit mask covering a single named field.

  The field's `(start, end)` pair from `self.fields` is treated as an
  inclusive bit range: the mask has `end - start + 1` set bits, positioned
  at bit `start`.
  """
  first_bit, last_bit = self.fields[field_name]
  width = last_bit - first_bit + 1
  all_ones = (1 << width) - 1
  return all_ones << first_bit
|
||||
|
||||
def fields_mask(self, *names) -> int:
  """Return the combined bit mask covering all of the named fields.

  Each field's `(start, end)` pair from `self.fields` contributes
  `end - start + 1` set bits positioned at bit `start`; the per-field masks
  are OR-ed together. With no names the result is 0.
  """
  per_field_masks = []
  for nm in names:
    first_bit, last_bit = self.fields[nm][0], self.fields[nm][1]
    per_field_masks.append(((1 << (last_bit - first_bit + 1)) - 1) << first_bit)
  return functools.reduce(int.__or__, per_field_masks, 0)
|
||||
|
||||
@property
def addr(self):
  """Register address: the base selected by `self.segment` plus `self.offset`."""
  segment_base = self.bases[self.segment]
  return segment_base + self.offset
|
||||
|
||||
Reference in New Issue
Block a user