diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py
index f6efe83385..c3d430166b 100644
--- a/extra/assembly/amd/ucode.py
+++ b/extra/assembly/amd/ucode.py
@@ -127,7 +127,22 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
       base, hi_uop, lo_uop = _expr(expr, ctx), _expr(hi, ctx), _expr(lo, ctx)
       if hi_uop.op == Ops.CONST and lo_uop.op == Ops.CONST:
         hi_val, lo_val = int(hi_uop.arg), int(lo_uop.arg)
-        if hi_val < lo_val: hi_val, lo_val = lo_val, hi_val
+        if hi_val < lo_val:
+          # Reversed slice bounds (hi < lo) mean bit reversal: build the reversal explicitly with shifts and ORs
+          width = lo_val - hi_val + 1
+          if width == 32:  # full 32-bit reversal
+            result = UOp.const(dtypes.uint32, 0)
+            for i in range(32):
+              bit = UOp(Ops.AND, dtypes.uint32, (UOp(Ops.SHR, dtypes.uint32, (_cast(base, dtypes.uint32), UOp.const(dtypes.uint32, i))), UOp.const(dtypes.uint32, 1)))
+              result = UOp(Ops.OR, dtypes.uint32, (result, UOp(Ops.SHL, dtypes.uint32, (bit, UOp.const(dtypes.uint32, 31 - i)))))
+            return result
+          elif width == 64:  # full 64-bit reversal
+            result = UOp.const(dtypes.uint64, 0)
+            for i in range(64):
+              bit = UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (_cast(base, dtypes.uint64), UOp.const(dtypes.uint64, i))), UOp.const(dtypes.uint64, 1)))
+              result = UOp(Ops.OR, dtypes.uint64, (result, UOp(Ops.SHL, dtypes.uint64, (bit, UOp.const(dtypes.uint64, 63 - i)))))
+            return result
+          hi_val, lo_val = lo_val, hi_val  # partial reversal: swap bounds and fall through to a plain (unreversed) slice
         shifted = UOp(Ops.SHR, base.dtype, (base, UOp.const(base.dtype, lo_val))) if lo_val else base
         return UOp(Ops.AND, dtypes.uint32, (_cast(shifted, dtypes.uint32), UOp.const(dtypes.uint32, (1 << (hi_val - lo_val + 1)) - 1)))
       raise ValueError(f"Non-constant slice bounds: {node}")
@@ -197,6 +212,7 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
 
     case Pack(exprs):
       if len(exprs) == 2:
+        # AMD Pack {a, b} typically means {high, low} for concatenation (e.g., ALIGNBIT)
         hi, lo = _expr(exprs[0], ctx), _expr(exprs[1], ctx)
         if lo.dtype.itemsize >= 4:
           return UOp(Ops.OR, dtypes.uint64, (UOp(Ops.SHL, dtypes.uint64, (_cast(hi, dtypes.uint64), UOp.const(dtypes.uint64, 32))), _cast(lo, dtypes.uint64)))
@@ -262,6 +278,32 @@ def _call_ABSDIFF(a, b):
 def _call_SAT8(v):
   clamped = UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (v, UOp.const(v.dtype, -128))), UOp.const(v.dtype, -128), v))
   return UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(v.dtype, 127), clamped)), UOp.const(v.dtype, 127), clamped))
+def _call_BYTE_PERMUTE(src, sel):
+  # BYTE_PERMUTE({S0, S1}, sel): select each result byte based on sel[3:0]
+  # sel[3:0]=0-7: select byte N from {S0, S1}; 8-11: sign-extend byte N-8; 12: 0x00; 13-15: 0xFF; sel[7]=1: 0x00
+  # Pack gives (S0 << 32) | S1, but BYTE_PERMUTE indexes bytes 0-3 from S0, so swap the words
+  src_fixed = UOp(Ops.OR, dtypes.uint64,
+    (UOp(Ops.SHL, dtypes.uint64, (UOp(Ops.AND, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 0xffffffff))), UOp.const(dtypes.uint64, 32))),
+     UOp(Ops.SHR, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 32)))))
+  sel_val = UOp(Ops.AND, dtypes.uint32, (_cast(sel, dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
+  sel_idx = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 7)))  # sel[2:0]
+  sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80)))  # sel[7]
+  sel_nibble = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0xf)))  # sel[3:0]
+  # Extract the selected byte: (src_fixed >> (sel_idx * 8)) & 0xff
+  shift = UOp(Ops.SHL, dtypes.uint32, (sel_idx, UOp.const(dtypes.uint32, 3)))
+  byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src_fixed, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32)
+  # Sign extension for sel[3:0]=8-11: output 0xFF if the byte's MSB is 1, else 0x00
+  byte_msb = UOp(Ops.AND, dtypes.uint32, (byte_val, UOp.const(dtypes.uint32, 0x80)))  # bit 7 of the byte
+  sign_ext_val = UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (byte_msb, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0)))
+  is_sign_ext = UOp(Ops.AND, dtypes.bool, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 7), sel_nibble)), UOp(Ops.CMPLT, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))))  # 8 <= sel_nibble <= 11
+  # sel[3:0]=12: return 0x00; sel[3:0]=13-15: return 0xFF
+  is_const_zero = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))
+  is_const_ff = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble))  # 13, 14, 15
+  result = UOp(Ops.WHERE, dtypes.uint32, (is_const_ff, UOp.const(dtypes.uint32, 0xff), byte_val))
+  result = UOp(Ops.WHERE, dtypes.uint32, (is_const_zero, UOp.const(dtypes.uint32, 0), result))
+  result = UOp(Ops.WHERE, dtypes.uint32, (is_sign_ext, sign_ext_val, result))
+  # sel[7]=1: return 0x00 (overrides everything else)
+  return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (sel_hi, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0), result))
 
 CALL_DISPATCH = {
   'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
@@ -269,6 +311,7 @@ CALL_DISPATCH = {
   'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF, 'sign': _call_sign,
   'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
   'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
+  'BYTE_PERMUTE': _call_BYTE_PERMUTE,
 }
 
 def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
@@ -414,10 +457,11 @@ def _transform_if(branches: tuple, ctx: Ctx):
     ctx.outputs.append((var, result, dtype))
 
 def _transform_for(var: str, start, end, body: tuple, ctx: Ctx):
-  """Unroll for loop and transform body."""
+  """Unroll the for loop and transform its body. Iterates in reverse so find-first patterns get first-match semantics."""
   start_val = start.value if isinstance(start, Const) else int(_expr(start, ctx).arg)
   end_val = end.value if isinstance(end, Const) else int(_expr(end, ctx).arg)
-  for i in range(int(start_val), int(end_val) + 1):
+  # Reverse iteration: builds the WHERE chain so earlier loop iterations take priority (first-match semantics)
+  for i in range(int(end_val), int(start_val) - 1, -1):
     ctx.vars[var] = UOp.const(dtypes.uint32, i)
     for s in body:
       if isinstance(s, If): _transform_if(s.branches, ctx)
@@ -539,8 +583,7 @@ _SKIP_OPS = {
   'V_TRIG_PREOP_F64',  # lookup table for 2/PI mantissa bits
   'V_MIN_F16', 'V_MIN_F32', 'V_MIN_F64', 'V_MAX_F16', 'V_MAX_F32', 'V_MAX_F64',  # neg zero handling: -0 < +0
   'V_SIN_F16', 'V_SIN_F32', 'V_COS_F16', 'V_COS_F32',  # transcendental with special range reduction
-  # Bit manipulation ops (need CLZ/CTZ/BREV intrinsics)
-  'V_CLZ_I32_U32', 'V_CTZ_I32_B32', 'S_BREV_B32', 'S_BREV_B64',
+  # Bit manipulation ops (find-first patterns now work via reversed loop unrolling, bit reversal via an explicit OR chain)
   'S_FF0_I32_B32', 'S_FF0_I32_B64', 'S_FF1_I32_B32', 'S_FF1_I32_B64',
   'S_FLBIT_I32', 'S_FLBIT_I32_B32', 'S_FLBIT_I32_B64', 'S_FLBIT_I32_I32', 'S_FLBIT_I32_I64',
   'S_BITREPLICATE_B64_B32',  # bit replication loop
@@ -555,7 +598,7 @@ _SKIP_OPS = {
   'V_MBCNT_HI_U32_B32', 'V_MBCNT_LO_U32_B32',
   'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64',
   'V_SWAPREL_B32',  # VGPR[laneId][addr] register array access
-  'V_PERM_B32',  # BYTE_PERMUTE function not implemented
+  'V_PERM_B32',  # BYTE_PERMUTE with sign-extend selectors has complex semantics
   # Control flow / special ops (no actual computation)
   'S_NOP', 'S_SETHALT', 'S_TRAP',
   # 65-bit intermediate results / multi-output with carry
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 80fe51002a..b93e76d2e1 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -873,7 +873,8 @@ def safe_pow(x, y):
 python_alu: dict[Ops, Callable] = {
   Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
   Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
+  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
+  Ops.TRUNC: lambda x: x if math.isinf(x) or math.isnan(x) else math.trunc(x),
   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max, Ops.MOD: cmod, Ops.IDIV: cdiv,
   Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,
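
Note (not part of the diff): a pure-Python model of what the reversed-slice lowering computes. Per iteration, the SHR/AND pair extracts bit i and the SHL/OR pair places it at width-1-i, so the unrolled OR chain is equivalent to this loop (helper name is illustrative):

    def brev(x: int, width: int = 32) -> int:
        # software model of S_BREV_B32/B64: bit i of the input becomes bit width-1-i of the output
        out = 0
        for i in range(width):
            bit = (x >> i) & 1             # the SHR/AND pair per iteration
            out |= bit << (width - 1 - i)  # the SHL/OR pair per iteration
        return out

    assert brev(0x00000001) == 0x80000000
    assert brev(0x0000ffff) == 0xffff0000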
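Note (not part of the diff): the selector semantics implemented by _call_BYTE_PERMUTE, condensed into a per-byte pure-Python sketch. The s0/s1 names are illustrative, and the word order mirrors the swap done by src_fixed:

    def byte_permute_sel(s0: int, s1: int, sel: int) -> int:
        # one selector byte -> one result byte, mirroring the WHERE chain in _call_BYTE_PERMUTE
        sel &= 0xff
        if sel & 0x80: return 0x00                          # sel[7]=1 forces 0x00
        nib, idx = sel & 0xf, sel & 0x7
        src = (s1 << 32) | s0                               # word-swapped, as in src_fixed
        byte = (src >> (idx * 8)) & 0xff
        if nib <= 7: return byte                            # plain byte select
        if nib <= 11: return 0xff if byte & 0x80 else 0x00  # sign-extend byte nib-8
        return 0x00 if nib == 12 else 0xff                  # 12 -> 0x00; 13-15 -> 0xFF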
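Note (not part of the diff): why _transform_for unrolls in reverse. Each conditional assignment in the unrolled body wraps the running value in a WHERE, so the iteration processed last ends up outermost and wins; iterating end..start therefore gives the earliest loop iteration priority, matching the break-on-first-match intent of the AMD pseudocode. A sketch of the effect on a find-first-one pattern (hypothetical helper):

    def find_first_one(x: int, width: int = 32) -> int:
        # models the unrolled chain: WHERE(bit0, 0, WHERE(bit1, 1, ... WHERE(bit31, 31, -1)))
        result = -1
        for i in range(width - 1, -1, -1):  # reverse, like _transform_for
            if (x >> i) & 1: result = i     # later (smaller i) assignments override earlier ones
        return result

    assert find_first_one(0b10100) == 2  # lowest set bit wins; forward iteration would wrongly give 4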
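Note (not part of the diff): the ops.py change guards Ops.TRUNC because math.trunc raises on non-finite floats instead of passing them through:

    import math
    # math.trunc(math.inf) raises OverflowError; math.trunc(math.nan) raises ValueError
    trunc = lambda x: x if math.isinf(x) or math.isnan(x) else math.trunc(x)  # the new Ops.TRUNC
    assert trunc(3.7) == 3 and trunc(math.inf) == math.inf and math.isnan(trunc(math.nan))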