This commit is contained in:
George Hotz
2026-01-04 12:22:01 -08:00
parent 2be5f8b688
commit cfeeab8485
2 changed files with 51 additions and 7 deletions

View File

@@ -127,7 +127,22 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
base, hi_uop, lo_uop = _expr(expr, ctx), _expr(hi, ctx), _expr(lo, ctx)
if hi_uop.op == Ops.CONST and lo_uop.op == Ops.CONST:
hi_val, lo_val = int(hi_uop.arg), int(lo_uop.arg)
if hi_val < lo_val: hi_val, lo_val = lo_val, hi_val
if hi_val < lo_val:
# Reversed slice [lo:hi] means bit reversal - build explicit reversal using shifts and ORs
width = lo_val - hi_val + 1
if width == 32: # Full 32-bit reversal
result = UOp.const(dtypes.uint32, 0)
for i in range(32):
bit = UOp(Ops.AND, dtypes.uint32, (UOp(Ops.SHR, dtypes.uint32, (_cast(base, dtypes.uint32), UOp.const(dtypes.uint32, i))), UOp.const(dtypes.uint32, 1)))
result = UOp(Ops.OR, dtypes.uint32, (result, UOp(Ops.SHL, dtypes.uint32, (bit, UOp.const(dtypes.uint32, 31 - i)))))
return result
elif width == 64: # Full 64-bit reversal
result = UOp.const(dtypes.uint64, 0)
for i in range(64):
bit = UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (_cast(base, dtypes.uint64), UOp.const(dtypes.uint64, i))), UOp.const(dtypes.uint64, 1)))
result = UOp(Ops.OR, dtypes.uint64, (result, UOp(Ops.SHL, dtypes.uint64, (bit, UOp.const(dtypes.uint64, 63 - i)))))
return result
hi_val, lo_val = lo_val, hi_val # Fall through to normal slice for partial reversal
shifted = UOp(Ops.SHR, base.dtype, (base, UOp.const(base.dtype, lo_val))) if lo_val else base
return UOp(Ops.AND, dtypes.uint32, (_cast(shifted, dtypes.uint32), UOp.const(dtypes.uint32, (1 << (hi_val - lo_val + 1)) - 1)))
raise ValueError(f"Non-constant slice bounds: {node}")
@@ -197,6 +212,7 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
case Pack(exprs):
if len(exprs) == 2:
# AMD Pack {a, b} typically means {high, low} for concatenation (e.g., ALIGNBIT)
hi, lo = _expr(exprs[0], ctx), _expr(exprs[1], ctx)
if lo.dtype.itemsize >= 4:
return UOp(Ops.OR, dtypes.uint64, (UOp(Ops.SHL, dtypes.uint64, (_cast(hi, dtypes.uint64), UOp.const(dtypes.uint64, 32))), _cast(lo, dtypes.uint64)))
@@ -262,6 +278,32 @@ def _call_ABSDIFF(a, b):
def _call_SAT8(v):
clamped = UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (v, UOp.const(v.dtype, -128))), UOp.const(v.dtype, -128), v))
return UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(v.dtype, 127), clamped)), UOp.const(v.dtype, 127), clamped))
def _call_BYTE_PERMUTE(src, sel):
# BYTE_PERMUTE({S0, S1}, sel): select byte based on sel[3:0]
# sel[3:0]=0-7: select byte from {S0,S1}; 8-11: sign-extend byte N-8; 12: 0x00; 13-15: 0xFF; sel[7]=1: 0x00
# Pack gives (S0<<32)|S1 but BYTE_PERMUTE indexes bytes 0-3 from S0, so swap words
src_fixed = UOp(Ops.OR, dtypes.uint64,
(UOp(Ops.SHL, dtypes.uint64, (UOp(Ops.AND, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 0xffffffff))), UOp.const(dtypes.uint64, 32))),
UOp(Ops.SHR, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 32)))))
sel_val = UOp(Ops.AND, dtypes.uint32, (_cast(sel, dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
sel_idx = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 7))) # sel[2:0]
sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80))) # sel[7]
sel_nibble = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0xf))) # sel[3:0]
# Extract byte: (src_fixed >> (sel_idx * 8)) & 0xff
shift = UOp(Ops.SHL, dtypes.uint32, (sel_idx, UOp.const(dtypes.uint32, 3)))
byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src_fixed, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32)
# Sign extend for sel[3:0]=8-11: output 0xFF if byte's MSB is 1, else 0x00
byte_msb = UOp(Ops.AND, dtypes.uint32, (byte_val, UOp.const(dtypes.uint32, 0x80))) # bit 7 of byte
sign_ext_val = UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (byte_msb, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0)))
is_sign_ext = UOp(Ops.AND, dtypes.bool, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 7), sel_nibble)), UOp(Ops.CMPLT, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12))))) # 8 <= sel_nibble <= 11
# sel[3:0]=12: return 0x00; sel[3:0]=13-15: return 0xFF
is_const_zero = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))
is_const_ff = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble)) # 13, 14, 15
result = UOp(Ops.WHERE, dtypes.uint32, (is_const_ff, UOp.const(dtypes.uint32, 0xff), byte_val))
result = UOp(Ops.WHERE, dtypes.uint32, (is_const_zero, UOp.const(dtypes.uint32, 0), result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sign_ext, sign_ext_val, result))
# sel[7]=1: return 0x00 (overrides everything)
return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (sel_hi, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0), result))
CALL_DISPATCH = {
'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
@@ -269,6 +311,7 @@ CALL_DISPATCH = {
'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF,
'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
'BYTE_PERMUTE': _call_BYTE_PERMUTE,
}
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
@@ -414,10 +457,11 @@ def _transform_if(branches: tuple, ctx: Ctx):
ctx.outputs.append((var, result, dtype))
def _transform_for(var: str, start, end, body: tuple, ctx: Ctx):
"""Unroll for loop and transform body."""
"""Unroll for loop and transform body. Iterates in reverse for first-match semantics in find-first patterns."""
start_val = start.value if isinstance(start, Const) else int(_expr(start, ctx).arg)
end_val = end.value if isinstance(end, Const) else int(_expr(end, ctx).arg)
for i in range(int(start_val), int(end_val) + 1):
# Reverse iteration: builds WHERE chain so earlier loop iterations have priority (first-match semantics)
for i in range(int(end_val), int(start_val) - 1, -1):
ctx.vars[var] = UOp.const(dtypes.uint32, i)
for s in body:
if isinstance(s, If): _transform_if(s.branches, ctx)
@@ -539,8 +583,7 @@ _SKIP_OPS = {
'V_TRIG_PREOP_F64', # lookup table for 2/PI mantissa bits
'V_MIN_F16', 'V_MIN_F32', 'V_MIN_F64', 'V_MAX_F16', 'V_MAX_F32', 'V_MAX_F64', # neg zero handling: -0 < +0
'V_SIN_F16', 'V_SIN_F32', 'V_COS_F16', 'V_COS_F32', # transcendental with special range reduction
# Bit manipulation ops (need CLZ/CTZ/BREV intrinsics)
'V_CLZ_I32_U32', 'V_CTZ_I32_B32', 'S_BREV_B32', 'S_BREV_B64',
# Bit manipulation ops (find-first patterns now work via reversed loop unrolling, bit reversal via explicit OR chain)
'S_FF0_I32_B32', 'S_FF0_I32_B64', 'S_FF1_I32_B32', 'S_FF1_I32_B64',
'S_FLBIT_I32', 'S_FLBIT_I32_B32', 'S_FLBIT_I32_B64', 'S_FLBIT_I32_I32', 'S_FLBIT_I32_I64',
'S_BITREPLICATE_B64_B32', # bit replication loop
@@ -555,7 +598,7 @@ _SKIP_OPS = {
'V_MBCNT_HI_U32_B32', 'V_MBCNT_LO_U32_B32',
'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64',
'V_SWAPREL_B32', # VGPR[laneId][addr] register array access
'V_PERM_B32', # BYTE_PERMUTE function not implemented
'V_PERM_B32', # BYTE_PERMUTE with sign-extend selectors has complex semantics
# Control flow / special ops (no actual computation)
'S_NOP', 'S_SETHALT', 'S_TRAP',
# 65-bit intermediate results / multi-output with carry

View File

@@ -873,7 +873,8 @@ def safe_pow(x, y):
python_alu: dict[Ops, Callable] = {
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
Ops.TRUNC: lambda x: x if math.isinf(x) or math.isnan(x) else math.trunc(x),
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,