assembly/amd: clean up clt/ctz hack (#13901)

* assembly/amd: clean up clt/ctz hack

* add breaks
This commit is contained in:
George Hotz
2025-12-30 11:59:28 -05:00
committed by GitHub
parent 69cdc8066d
commit 7e14cdcb06
5 changed files with 33 additions and 31 deletions

View File

@@ -251,7 +251,7 @@ def _SOP1Op_S_FF0_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(31)+1):
if S0.u32[i] == 0:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -274,7 +274,7 @@ def _SOP1Op_S_FF0_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(63)+1):
if S0.u64[i] == 0:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -297,7 +297,7 @@ def _SOP1Op_S_FF1_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(31)+1):
if S0.u32[i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -320,7 +320,7 @@ def _SOP1Op_S_FF1_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(63)+1):
if S0.u64[i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -343,7 +343,7 @@ def _SOP1Op_S_FLBIT_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
tmp = Reg(-1)
for i in range(0, int(31)+1):
if S0.u32[31 - i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -366,7 +366,7 @@ def _SOP1Op_S_FLBIT_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
tmp = Reg(-1)
for i in range(0, int(63)+1):
if S0.u64[63 - i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -389,7 +389,7 @@ def _SOP1Op_S_FLBIT_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
tmp = Reg(-1)
for i in range(1, int(31)+1):
if S0.u32[31 - i] != S0.u32[31]:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -412,7 +412,7 @@ def _SOP1Op_S_FLBIT_I32_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
tmp = Reg(-1)
for i in range(1, int(63)+1):
if S0.u64[63 - i] != S0.u64[63]:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -2964,7 +2964,7 @@ def _VOP1Op_V_FFBH_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[31 - i] == 1:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -2984,7 +2984,7 @@ def _VOP1Op_V_FFBL_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[i] == 1:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -3004,7 +3004,7 @@ def _VOP1Op_V_FFBH_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(1, int(31)+1):
if S0.i32[31 - i] != S0.i32[31]:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -14702,7 +14702,7 @@ def _VOP3AOp_V_FFBH_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[31 - i] == 1:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -14722,7 +14722,7 @@ def _VOP3AOp_V_FFBL_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[i] == 1:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -14742,7 +14742,7 @@ def _VOP3AOp_V_FFBH_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
D0.i32 = -1
for i in range(1, int(31)+1):
if S0.i32[31 - i] != S0.i32[31]:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result

View File

@@ -185,7 +185,7 @@ def _SOP1Op_S_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
tmp = Reg(-1)
for i in range(1, int(31)+1):
if S0.u32[31 - i] != S0.u32[31]:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -208,7 +208,7 @@ def _SOP1Op_S_CLS_I32_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(1, int(63)+1):
if S0.u64[63 - i] != S0.u64[63]:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -4190,7 +4190,7 @@ def _VOP1Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(1, int(31)+1):
if S0.i32[31 - i] != S0.i32[31]:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -9472,7 +9472,7 @@ def _VOP3Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(1, int(31)+1):
if S0.i32[31 - i] != S0.i32[31]:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result

View File

@@ -185,7 +185,7 @@ def _SOP1Op_S_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
tmp = Reg(-1)
for i in range(1, int(31)+1):
if S0.u32[31 - i] != S0.u32[31]:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -208,7 +208,7 @@ def _SOP1Op_S_CLS_I32_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(1, int(63)+1):
if S0.u64[63 - i] != S0.u64[63]:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -4184,7 +4184,7 @@ def _VOP1Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(1, int(31)+1):
if S0.i32[31 - i] != S0.i32[31]:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -9128,7 +9128,7 @@ def _VOP3Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
D0.i32 = -1
for i in range(1, int(31)+1):
if S0.i32[31 - i] != S0.i32[31]:
D0.i32 = i
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result

View File

@@ -642,7 +642,7 @@ def compile_pseudocode(pseudocode: str) -> str:
joined_lines.append(line)
lines = []
indent, need_pass = 0, False
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
@@ -671,14 +671,14 @@ def compile_pseudocode(pseudocode: str) -> str:
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass = False
need_pass, in_first_match_loop = False, False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass = True
need_pass, in_first_match_loop = True, True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
@@ -697,7 +697,12 @@ def compile_pseudocode(pseudocode: str) -> str:
break
else:
lhs, rhs = line.split('=', 1)
lines.append(' ' * indent + _assign(lhs.strip(), _expr(rhs.strip())))
lhs_s, rhs_s = lhs.strip(), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
# CLZ/CTZ pattern: assignment of loop var to tmp/D0.i32 in first-match loop needs break
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
# If we ended with a control statement that needs a body, add pass
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
@@ -1014,11 +1019,6 @@ from extra.assembly.amd.pcode import *
code = compile_pseudocode(pc)
# NOTE: Do NOT add more code.replace() hacks here. Fix issues properly in the DSL
# (compile_pseudocode, helper functions, or Reg/TypedView classes) instead.
# CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
# Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i
if 'CLZ' in op.name or 'CTZ' in op.name:
code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break')
code = code.replace('D0.i32 = i', 'D0.i32 = i; break')
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64

View File

@@ -208,6 +208,8 @@ D0.u32 = tmp.u32""")
for i in 0 : 31 do
if S0.u32[i] == 1 then
tmp = i
endif
endfor
D0.i32 = tmp""")
ctx = ExecContext(s0=0b1000) # Bit 3 is set
ctx.run(code)