work
@@ -127,7 +127,22 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
       base, hi_uop, lo_uop = _expr(expr, ctx), _expr(hi, ctx), _expr(lo, ctx)
       if hi_uop.op == Ops.CONST and lo_uop.op == Ops.CONST:
         hi_val, lo_val = int(hi_uop.arg), int(lo_uop.arg)
-        if hi_val < lo_val: hi_val, lo_val = lo_val, hi_val
+        if hi_val < lo_val:
+          # Reversed slice [lo:hi] means bit reversal - build explicit reversal using shifts and ORs
+          width = lo_val - hi_val + 1
+          if width == 32: # Full 32-bit reversal
+            result = UOp.const(dtypes.uint32, 0)
+            for i in range(32):
+              bit = UOp(Ops.AND, dtypes.uint32, (UOp(Ops.SHR, dtypes.uint32, (_cast(base, dtypes.uint32), UOp.const(dtypes.uint32, i))), UOp.const(dtypes.uint32, 1)))
+              result = UOp(Ops.OR, dtypes.uint32, (result, UOp(Ops.SHL, dtypes.uint32, (bit, UOp.const(dtypes.uint32, 31 - i)))))
+            return result
+          elif width == 64: # Full 64-bit reversal
+            result = UOp.const(dtypes.uint64, 0)
+            for i in range(64):
+              bit = UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (_cast(base, dtypes.uint64), UOp.const(dtypes.uint64, i))), UOp.const(dtypes.uint64, 1)))
+              result = UOp(Ops.OR, dtypes.uint64, (result, UOp(Ops.SHL, dtypes.uint64, (bit, UOp.const(dtypes.uint64, 63 - i)))))
+            return result
+          hi_val, lo_val = lo_val, hi_val # Fall through to normal slice for partial reversal
         shifted = UOp(Ops.SHR, base.dtype, (base, UOp.const(base.dtype, lo_val))) if lo_val else base
         return UOp(Ops.AND, dtypes.uint32, (_cast(shifted, dtypes.uint32), UOp.const(dtypes.uint32, (1 << (hi_val - lo_val + 1)) - 1)))
       raise ValueError(f"Non-constant slice bounds: {node}")
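
The reversed-slice lowering above unrolls an explicit shift-and-OR bit reversal. A minimal pure-Python sketch of the same semantics (illustrative only, not part of the patch):

  def brev(x: int, width: int) -> int:
    # Reverse the low `width` bits of x: bit i moves to position width-1-i.
    out = 0
    for i in range(width):
      out |= ((x >> i) & 1) << (width - 1 - i)
    return out

  assert brev(0x00000001, 32) == 0x80000000
  assert brev(0b0011, 4) == 0b1100
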
@@ -197,6 +212,7 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
 
     case Pack(exprs):
       if len(exprs) == 2:
+        # AMD Pack {a, b} typically means {high, low} for concatenation (e.g., ALIGNBIT)
         hi, lo = _expr(exprs[0], ctx), _expr(exprs[1], ctx)
         if lo.dtype.itemsize >= 4:
           return UOp(Ops.OR, dtypes.uint64, (UOp(Ops.SHL, dtypes.uint64, (_cast(hi, dtypes.uint64), UOp.const(dtypes.uint64, 32))), _cast(lo, dtypes.uint64)))
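
For reference, the Pack case treats {a, b} as {high, low}, so the two 32-bit operands concatenate into one 64-bit value (illustrative sketch, values chosen arbitrarily):

  def pack64(hi: int, lo: int) -> int:
    # {hi, lo} concatenation: hi occupies bits 63:32, lo occupies bits 31:0.
    return ((hi & 0xffffffff) << 32) | (lo & 0xffffffff)

  assert pack64(0xdeadbeef, 0x12345678) == 0xdeadbeef12345678
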
@@ -262,6 +278,32 @@ def _call_ABSDIFF(a, b):
 def _call_SAT8(v):
   clamped = UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (v, UOp.const(v.dtype, -128))), UOp.const(v.dtype, -128), v))
   return UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(v.dtype, 127), clamped)), UOp.const(v.dtype, 127), clamped))
+def _call_BYTE_PERMUTE(src, sel):
+  # BYTE_PERMUTE({S0, S1}, sel): select byte based on sel[3:0]
+  # sel[3:0]=0-7: select byte from {S0,S1}; 8-11: sign-extend byte N-8; 12: 0x00; 13-15: 0xFF; sel[7]=1: 0x00
+  # Pack gives (S0<<32)|S1 but BYTE_PERMUTE indexes bytes 0-3 from S0, so swap words
+  src_fixed = UOp(Ops.OR, dtypes.uint64,
+    (UOp(Ops.SHL, dtypes.uint64, (UOp(Ops.AND, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 0xffffffff))), UOp.const(dtypes.uint64, 32))),
+     UOp(Ops.SHR, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 32)))))
+  sel_val = UOp(Ops.AND, dtypes.uint32, (_cast(sel, dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
+  sel_idx = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 7))) # sel[2:0]
+  sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80))) # sel[7]
+  sel_nibble = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0xf))) # sel[3:0]
+  # Extract byte: (src_fixed >> (sel_idx * 8)) & 0xff
+  shift = UOp(Ops.SHL, dtypes.uint32, (sel_idx, UOp.const(dtypes.uint32, 3)))
+  byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src_fixed, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32)
+  # Sign extend for sel[3:0]=8-11: output 0xFF if byte's MSB is 1, else 0x00
+  byte_msb = UOp(Ops.AND, dtypes.uint32, (byte_val, UOp.const(dtypes.uint32, 0x80))) # bit 7 of byte
+  sign_ext_val = UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (byte_msb, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0)))
+  is_sign_ext = UOp(Ops.AND, dtypes.bool, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 7), sel_nibble)), UOp(Ops.CMPLT, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12))))) # 8 <= sel_nibble <= 11
+  # sel[3:0]=12: return 0x00; sel[3:0]=13-15: return 0xFF
+  is_const_zero = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))
+  is_const_ff = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble)) # 13, 14, 15
+  result = UOp(Ops.WHERE, dtypes.uint32, (is_const_ff, UOp.const(dtypes.uint32, 0xff), byte_val))
+  result = UOp(Ops.WHERE, dtypes.uint32, (is_const_zero, UOp.const(dtypes.uint32, 0), result))
+  result = UOp(Ops.WHERE, dtypes.uint32, (is_sign_ext, sign_ext_val, result))
+  # sel[7]=1: return 0x00 (overrides everything)
+  return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (sel_hi, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0), result))
 
 CALL_DISPATCH = {
   'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
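
The WHERE chain above flattens the BYTE_PERMUTE selector table into branchless selects. A plain-Python model of the per-selector behavior it encodes (illustrative only; `byte_permute` is a hypothetical reference helper, not the UOp builder, and mirrors this patch's reading of the semantics):

  def byte_permute(s0: int, s1: int, sel: int) -> int:
    # After the word swap, bytes 0-3 come from S0 and bytes 4-7 from S1.
    src = ((s1 & 0xffffffff) << 32) | (s0 & 0xffffffff)
    sel &= 0xff
    if sel & 0x80: return 0x00            # sel[7]=1 forces zero
    n = sel & 0xf
    if n == 12: return 0x00               # constant 0x00
    if n >= 13: return 0xff               # constants 13-15 -> 0xFF
    byte = (src >> ((sel & 7) * 8)) & 0xff
    if 8 <= n <= 11:                      # sign-extend: replicate the byte's MSB
      return 0xff if byte & 0x80 else 0x00
    return byte

  assert byte_permute(0x04030201, 0x08070605, 0x00) == 0x01  # byte 0 of S0
  assert byte_permute(0x04030201, 0x08070605, 0x07) == 0x08  # byte 3 of S1
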
@@ -269,6 +311,7 @@ CALL_DISPATCH = {
   'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF,
   'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
   'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
+  'BYTE_PERMUTE': _call_BYTE_PERMUTE,
 }
 
 def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
@@ -414,10 +457,11 @@ def _transform_if(branches: tuple, ctx: Ctx):
     ctx.outputs.append((var, result, dtype))
 
 def _transform_for(var: str, start, end, body: tuple, ctx: Ctx):
-  """Unroll for loop and transform body."""
+  """Unroll for loop and transform body. Iterates in reverse for first-match semantics in find-first patterns."""
   start_val = start.value if isinstance(start, Const) else int(_expr(start, ctx).arg)
   end_val = end.value if isinstance(end, Const) else int(_expr(end, ctx).arg)
-  for i in range(int(start_val), int(end_val) + 1):
+  # Reverse iteration: builds WHERE chain so earlier loop iterations have priority (first-match semantics)
+  for i in range(int(end_val), int(start_val) - 1, -1):
     ctx.vars[var] = UOp.const(dtypes.uint32, i)
     for s in body:
       if isinstance(s, If): _transform_if(s.branches, ctx)
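
Why reverse iteration gives first-match semantics: each unrolled iteration wraps the running result in a WHERE, so the last wrap applied is the outermost select; iterating end down to start puts the start iteration on top. A scalar sketch of the unrolled chain (illustrative only):

  def find_first_one(x: int, nbits: int = 32) -> int:
    # Mirrors the unrolled chain: result = WHERE(bit_i, i, result) per iteration.
    result = -1                                # fallthrough when no bit is set
    for i in range(nbits - 1, -1, -1):         # reverse, like the new loop
      result = i if (x >> i) & 1 else result   # later wraps (smaller i) win
    return result

  assert find_first_one(0b10100) == 2          # lowest set bit: first match in ascending order
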
@@ -539,8 +583,7 @@ _SKIP_OPS = {
   'V_TRIG_PREOP_F64', # lookup table for 2/PI mantissa bits
   'V_MIN_F16', 'V_MIN_F32', 'V_MIN_F64', 'V_MAX_F16', 'V_MAX_F32', 'V_MAX_F64', # neg zero handling: -0 < +0
   'V_SIN_F16', 'V_SIN_F32', 'V_COS_F16', 'V_COS_F32', # transcendental with special range reduction
-  # Bit manipulation ops (need CLZ/CTZ/BREV intrinsics)
-  'V_CLZ_I32_U32', 'V_CTZ_I32_B32', 'S_BREV_B32', 'S_BREV_B64',
+  # Bit manipulation ops (find-first patterns now work via reversed loop unrolling, bit reversal via explicit OR chain)
   'S_FF0_I32_B32', 'S_FF0_I32_B64', 'S_FF1_I32_B32', 'S_FF1_I32_B64',
   'S_FLBIT_I32', 'S_FLBIT_I32_B32', 'S_FLBIT_I32_B64', 'S_FLBIT_I32_I32', 'S_FLBIT_I32_I64',
   'S_BITREPLICATE_B64_B32', # bit replication loop
@@ -555,7 +598,7 @@ _SKIP_OPS = {
   'V_MBCNT_HI_U32_B32', 'V_MBCNT_LO_U32_B32',
   'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64',
   'V_SWAPREL_B32', # VGPR[laneId][addr] register array access
-  'V_PERM_B32', # BYTE_PERMUTE function not implemented
+  'V_PERM_B32', # BYTE_PERMUTE with sign-extend selectors has complex semantics
   # Control flow / special ops (no actual computation)
   'S_NOP', 'S_SETHALT', 'S_TRAP',
   # 65-bit intermediate results / multi-output with carry
@@ -873,7 +873,8 @@ def safe_pow(x, y):
 python_alu: dict[Ops, Callable] = {
   Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
   Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
+  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
+  Ops.TRUNC: lambda x: x if math.isinf(x) or math.isnan(x) else math.trunc(x),
   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
   Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,
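
The Ops.TRUNC change guards against non-finite inputs: math.trunc raises OverflowError on ±inf and ValueError on nan, so the new lambda passes those values through unchanged. Quick check (illustrative only):

  import math
  trunc = lambda x: x if math.isinf(x) or math.isnan(x) else math.trunc(x)
  assert trunc(3.7) == 3
  assert trunc(-math.inf) == -math.inf   # math.trunc(-math.inf) would raise OverflowError
  assert math.isnan(trunc(math.nan))     # math.trunc(math.nan) would raise ValueError
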