mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
assembly/amd: clean up clt/ctz hack (#13901)
* assembly/amd: clean up clt/ctz hack * add breaks
This commit is contained in:
@@ -251,7 +251,7 @@ def _SOP1Op_S_FF0_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 0:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -274,7 +274,7 @@ def _SOP1Op_S_FF0_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(63)+1):
|
||||
if S0.u64[i] == 0:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -297,7 +297,7 @@ def _SOP1Op_S_FF1_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -320,7 +320,7 @@ def _SOP1Op_S_FF1_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(63)+1):
|
||||
if S0.u64[i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -343,7 +343,7 @@ def _SOP1Op_S_FLBIT_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[31 - i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -366,7 +366,7 @@ def _SOP1Op_S_FLBIT_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(63)+1):
|
||||
if S0.u64[63 - i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -389,7 +389,7 @@ def _SOP1Op_S_FLBIT_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
|
||||
tmp = Reg(-1)
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.u32[31 - i] != S0.u32[31]:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -412,7 +412,7 @@ def _SOP1Op_S_FLBIT_I32_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
tmp = Reg(-1)
|
||||
for i in range(1, int(63)+1):
|
||||
if S0.u64[63 - i] != S0.u64[63]:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -2964,7 +2964,7 @@ def _VOP1Op_V_FFBH_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[31 - i] == 1:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -2984,7 +2984,7 @@ def _VOP1Op_V_FFBL_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 1:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -3004,7 +3004,7 @@ def _VOP1Op_V_FFBH_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.i32[31 - i] != S0.i32[31]:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -14702,7 +14702,7 @@ def _VOP3AOp_V_FFBH_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[31 - i] == 1:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -14722,7 +14722,7 @@ def _VOP3AOp_V_FFBL_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 1:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -14742,7 +14742,7 @@ def _VOP3AOp_V_FFBH_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR
|
||||
D0.i32 = -1
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.i32[31 - i] != S0.i32[31]:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
@@ -185,7 +185,7 @@ def _SOP1Op_S_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
tmp = Reg(-1)
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.u32[31 - i] != S0.u32[31]:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -208,7 +208,7 @@ def _SOP1Op_S_CLS_I32_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(1, int(63)+1):
|
||||
if S0.u64[63 - i] != S0.u64[63]:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -4190,7 +4190,7 @@ def _VOP1Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.i32[31 - i] != S0.i32[31]:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -9472,7 +9472,7 @@ def _VOP3Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.i32[31 - i] != S0.i32[31]:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
@@ -185,7 +185,7 @@ def _SOP1Op_S_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
tmp = Reg(-1)
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.u32[31 - i] != S0.u32[31]:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -208,7 +208,7 @@ def _SOP1Op_S_CLS_I32_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(1, int(63)+1):
|
||||
if S0.u64[63 - i] != S0.u64[63]:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -4184,7 +4184,7 @@ def _VOP1Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.i32[31 - i] != S0.i32[31]:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -9128,7 +9128,7 @@ def _VOP3Op_V_CLS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
|
||||
D0.i32 = -1
|
||||
for i in range(1, int(31)+1):
|
||||
if S0.i32[31 - i] != S0.i32[31]:
|
||||
D0.i32 = i
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
@@ -642,7 +642,7 @@ def compile_pseudocode(pseudocode: str) -> str:
|
||||
joined_lines.append(line)
|
||||
|
||||
lines = []
|
||||
indent, need_pass = 0, False
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
@@ -671,14 +671,14 @@ def compile_pseudocode(pseudocode: str) -> str:
|
||||
elif line.startswith('endfor'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass = False
|
||||
need_pass, in_first_match_loop = False, False
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
need_pass, in_first_match_loop = True, True
|
||||
elif '=' in line and not line.startswith('=='):
|
||||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
@@ -697,7 +697,12 @@ def compile_pseudocode(pseudocode: str) -> str:
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lines.append(' ' * indent + _assign(lhs.strip(), _expr(rhs.strip())))
|
||||
lhs_s, rhs_s = lhs.strip(), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
# CLZ/CTZ pattern: assignment of loop var to tmp/D0.i32 in first-match loop needs break
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
# If we ended with a control statement that needs a body, add pass
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
@@ -1014,11 +1019,6 @@ from extra.assembly.amd.pcode import *
|
||||
code = compile_pseudocode(pc)
|
||||
# NOTE: Do NOT add more code.replace() hacks here. Fix issues properly in the DSL
|
||||
# (compile_pseudocode, helper functions, or Reg/TypedView classes) instead.
|
||||
# CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
|
||||
# Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i
|
||||
if 'CLZ' in op.name or 'CTZ' in op.name:
|
||||
code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break')
|
||||
code = code.replace('D0.i32 = i', 'D0.i32 = i; break')
|
||||
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
|
||||
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
|
||||
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
|
||||
|
||||
@@ -208,6 +208,8 @@ D0.u32 = tmp.u32""")
|
||||
for i in 0 : 31 do
|
||||
if S0.u32[i] == 1 then
|
||||
tmp = i
|
||||
endif
|
||||
endfor
|
||||
D0.i32 = tmp""")
|
||||
ctx = ExecContext(s0=0b1000) # Bit 3 is set
|
||||
ctx.run(code)
|
||||
|
||||
Reference in New Issue
Block a user