mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fixes
This commit is contained in:
@@ -89,6 +89,14 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
# Handle Var.type
|
||||
if inner.op == Ops.DEFINE_VAR and inner.arg[1] is None:
|
||||
name = inner.arg[0]
|
||||
# Handle INF.f32, INF.f64, NAN.f32, NAN.f64, etc.
|
||||
if name == 'INF' or name in ('+INF', '-INF'):
|
||||
return UOp.const(dt, float('-inf') if name.startswith('-') else float('inf'))
|
||||
if name == 'NAN':
|
||||
return UOp.const(dt, float('nan'))
|
||||
if name == 'DENORM':
|
||||
denorm = {dtypes.float32: 1.17549435e-38, dtypes.float64: 2.2250738585072014e-308}.get(dt, 1.17549435e-38)
|
||||
return UOp.const(dt, denorm)
|
||||
if name in ('VCCZ', 'EXECZ'):
|
||||
return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32)
|
||||
if name.startswith('WAVE_STATUS.COND_DBG'): return UOp.const(dtypes.uint32, 0)
|
||||
@@ -118,6 +126,8 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
# Single-bit slice: base[idx:idx] -> (base >> idx) & 1
|
||||
if hi_expr is lo_expr:
|
||||
return UOp(Ops.AND, dtypes.uint32, (_cast(UOp(Ops.SHR, base.dtype, (base, _cast(lo_uop, base.dtype))), dtypes.uint32), UOp.const(dtypes.uint32, 1)))
|
||||
# Simplify the bounds to get constant values (needed when loop variables are substituted)
|
||||
hi_uop, lo_uop = hi_uop.simplify(), lo_uop.simplify()
|
||||
if hi_uop.op == Ops.CONST and lo_uop.op == Ops.CONST:
|
||||
hi_val, lo_val = int(hi_uop.arg), int(lo_uop.arg)
|
||||
if hi_val < lo_val:
|
||||
@@ -230,7 +240,29 @@ def _call_floor(v):
|
||||
return UOp(Ops.WHERE, v.dtype, (needs_adjust, UOp(Ops.SUB, v.dtype, (truncated, UOp.const(v.dtype, 1))), truncated))
|
||||
def _call_fract(v):
  """Build the UOp for the fractional part of `v`, computed as `v - floor(v)` in `v`'s dtype."""
  floored = _call_floor(v)
  return UOp(Ops.SUB, v.dtype, (v, floored))
|
||||
def _call_isNAN(v):
  """Build a bool UOp testing `v != v`, the self-inequality check used here to detect NaN."""
  operands = (v, v)
  return UOp(Ops.CMPNE, dtypes.bool, operands)
|
||||
def _call_isSignalNAN(v):
  """Build a bool UOp that is true when `v` holds a signaling NaN.

  Signaling NaN: exponent all 1s, mantissa non-zero, and the MSB of the mantissa (the quiet bit) is 0.
  """
  # Unwrap CAST to check on the original float type
  while v.op == Ops.CAST and v.dtype in FLOATS: v = v.src[0]
  # dtypes missing from FP_INFO fall back to the float32 layout
  uint_dt, _, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
  bits = UOp(Ops.BITCAST, uint_dt, (v,))
  exp = UOp(Ops.AND, uint_dt, (UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), UOp.const(uint_dt, exp_mask)))
  mant = UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, mant_mask)))
  # mask selecting the top mantissa bit for each supported float width (float32 mask when the dtype is unknown)
  quiet_bit = {dtypes.float64: 0x8000000000000, dtypes.float32: 0x400000, dtypes.float16: 0x200}.get(v.dtype, 0x400000)
  is_exp_all_ones = UOp(Ops.CMPEQ, dtypes.bool, (exp, UOp.const(uint_dt, exp_mask)))
  is_mant_nonzero = UOp(Ops.CMPNE, dtypes.bool, (mant, UOp.const(uint_dt, 0)))
  is_quiet_bit_clear = UOp(Ops.CMPEQ, dtypes.bool, (UOp(Ops.AND, uint_dt, (mant, UOp.const(uint_dt, quiet_bit))), UOp.const(uint_dt, 0)))
  return UOp(Ops.AND, dtypes.bool, (UOp(Ops.AND, dtypes.bool, (is_exp_all_ones, is_mant_nonzero)), is_quiet_bit_clear))
|
||||
def _call_isQuietNAN(v):
  """Build a bool UOp that is true when `v` holds a quiet NaN.

  Quiet NaN: exponent all 1s and the MSB of the mantissa (the quiet bit) is 1.
  """
  # Unwrap CAST to check on the original float type
  while v.op == Ops.CAST and v.dtype in FLOATS: v = v.src[0]
  # dtypes missing from FP_INFO fall back to the float32 layout; the mantissa mask is not needed here
  uint_dt, _, exp_shift, exp_mask, _ = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
  bits = UOp(Ops.BITCAST, uint_dt, (v,))
  exp = UOp(Ops.AND, uint_dt, (UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), UOp.const(uint_dt, exp_mask)))
  # mask selecting the top mantissa bit for each supported float width (float32 mask when the dtype is unknown)
  quiet_bit = {dtypes.float64: 0x8000000000000, dtypes.float32: 0x400000, dtypes.float16: 0x200}.get(v.dtype, 0x400000)
  is_exp_all_ones = UOp(Ops.CMPEQ, dtypes.bool, (exp, UOp.const(uint_dt, exp_mask)))
  is_quiet_bit_set = UOp(Ops.CMPNE, dtypes.bool, (UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, quiet_bit))), UOp.const(uint_dt, 0)))
  return UOp(Ops.AND, dtypes.bool, (is_exp_all_ones, is_quiet_bit_set))
|
||||
def _call_cvtToQuietNAN(v): return v
|
||||
def _call_isINF(v):
|
||||
return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, float('inf')))),
|
||||
@@ -285,7 +317,7 @@ def _call_BYTE_PERMUTE(src, sel):
|
||||
|
||||
CALL_DISPATCH = {
|
||||
'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
|
||||
'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isNAN,
|
||||
'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isQuietNAN,
|
||||
'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF,
|
||||
'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
|
||||
'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
|
||||
@@ -299,6 +331,10 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
result_dt = a[0].dtype if a[0].dtype in FLOATS else hint or dtypes.float32
|
||||
return UOp(Ops.EXP2, result_dt, (a[1] if a[1].dtype == result_dt else UOp(Ops.CAST, result_dt, (a[1],)),))
|
||||
if name in MATH_OPS: return UOp(MATH_OPS[name], a[0].dtype, (a[0],))
|
||||
if name == 'ldexp':
|
||||
# ldexp(x, exp) = x * 2^exp
|
||||
exp_float = UOp(Ops.CAST, a[0].dtype, (a[1],)) if a[1].dtype != a[0].dtype else a[1]
|
||||
return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (exp_float,))))
|
||||
if name in ('min', 'max'):
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, ((a[0], a[1]) if name == 'min' else (a[1], a[0]))), a[0], a[1]))
|
||||
if name in CVT_MAP:
|
||||
@@ -350,12 +386,18 @@ def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, s
|
||||
return name, dtypes.uint32, int(hi), int(lo), None
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
|
||||
return name, dt, None, None, idx
|
||||
# Handle tmp[i] where i is a variable (single-bit index)
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
|
||||
return name, dtypes.uint32, None, None, idx
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None
|
||||
raise ValueError(f"Cannot parse LHS: {lhs}")
|
||||
|
||||
def _stmt(stmt, ctx: Ctx):
|
||||
match stmt:
|
||||
case Declare(name, dtype): ctx.decls[name] = dtype
|
||||
case Declare(name, dtype):
|
||||
ctx.decls[name] = dtype
|
||||
# Initialize declared variable with zero value
|
||||
ctx.vars[name] = UOp.const(dtype, 0)
|
||||
case Assign(lhs, rhs):
|
||||
# Handle MEM[addr].type = value -> memory store
|
||||
if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM':
|
||||
@@ -367,6 +409,26 @@ def _stmt(stmt, ctx: Ctx):
|
||||
ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx, val_uop)))
|
||||
return
|
||||
|
||||
# Handle CAT (multi-output assignment) like {D1.u1, D0.u64} = ...
|
||||
if lhs.op == Ops.CAT:
|
||||
rhs_uop = _expr(rhs, ctx)
|
||||
out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
|
||||
offset = 0
|
||||
for part in reversed(lhs.src): # CAT is hi, lo order, so reverse to get lo first
|
||||
if part.op == Ops.BITCAST and part.src[0].op == Ops.DEFINE_VAR:
|
||||
dt, name = part.dtype, part.src[0].arg[0]
|
||||
# Map non-standard dtypes to real dtypes
|
||||
if dt.name == 'u1': bits, real_dt = 1, dtypes.uint32
|
||||
elif dt == dtypes.ulong or dt.name == 'ulong': bits, real_dt = 64, dtypes.uint64
|
||||
else: bits, real_dt = dt.itemsize * 8, dt
|
||||
mask = (1 << bits) - 1
|
||||
extracted = UOp(Ops.AND, rhs_uop.dtype, (UOp(Ops.SHR, rhs_uop.dtype, (rhs_uop, UOp.const(rhs_uop.dtype, offset))), UOp.const(rhs_uop.dtype, mask)))
|
||||
val = _cast(extracted, real_dt)
|
||||
ctx.vars[name] = val
|
||||
if name in out_vars: ctx.outputs.append((name, val, real_dt))
|
||||
offset += bits
|
||||
return
|
||||
|
||||
var, dtype, hi, lo, idx_var = _get_lhs_info(lhs, ctx)
|
||||
out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
|
||||
|
||||
@@ -420,25 +482,41 @@ def _transform_if(branches: tuple, ctx: Ctx):
|
||||
for s in body: _stmt(s, sub_ctx)
|
||||
parsed.append((cond_uop, sub_ctx))
|
||||
|
||||
# Collect all assigned variables across all branches
|
||||
# Collect all assigned variables across all branches (both outputs and locals)
|
||||
out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
|
||||
assigned = set()
|
||||
assigned_outputs = set()
|
||||
assigned_locals = set()
|
||||
for _, sub_ctx in parsed:
|
||||
for name, _, _ in sub_ctx.outputs:
|
||||
if name in out_vars: assigned.add(name)
|
||||
if name in out_vars: assigned_outputs.add(name)
|
||||
# Track local variables that were modified in branches
|
||||
for name, val in sub_ctx.vars.items():
|
||||
if name not in ctx.vars or ctx.vars[name] is not val:
|
||||
if name not in out_vars and name not in INPUT_VARS:
|
||||
assigned_locals.add(name)
|
||||
|
||||
for var in assigned:
|
||||
# Merge output variables
|
||||
for var in assigned_outputs:
|
||||
dtype = next((d for _, sub_ctx in parsed for n, _, d in sub_ctx.outputs if n == var), dtypes.uint32)
|
||||
result = ctx.vars.get(var, UOp.const(dtype, 0))
|
||||
for cond_uop, sub_ctx in reversed(parsed):
|
||||
branch_val = next((u for n, u, _ in sub_ctx.outputs if n == var), None)
|
||||
if branch_val is not None:
|
||||
result = branch_val if cond_uop is None else UOp(Ops.WHERE, branch_val.dtype, (cond_uop, branch_val, _cast(result, branch_val.dtype)))
|
||||
|
||||
ctx.vars[var] = result
|
||||
ctx.outputs = [(n, u, d) for n, u, d in ctx.outputs if n != var]
|
||||
ctx.outputs.append((var, result, dtype))
|
||||
|
||||
# Merge local variables (like 'result')
|
||||
for var in assigned_locals:
|
||||
dtype = ctx.decls.get(var, dtypes.uint32)
|
||||
result = ctx.vars.get(var, UOp.const(dtype, 0))
|
||||
for cond_uop, sub_ctx in reversed(parsed):
|
||||
if var in sub_ctx.vars and (var not in ctx.vars or sub_ctx.vars[var] is not ctx.vars[var]):
|
||||
branch_val = sub_ctx.vars[var]
|
||||
result = branch_val if cond_uop is None else UOp(Ops.WHERE, branch_val.dtype, (cond_uop, branch_val, _cast(result, branch_val.dtype)))
|
||||
ctx.vars[var] = result
|
||||
|
||||
def _transform_for(var: str, start: UOp, end: UOp, body: tuple, ctx: Ctx):
|
||||
start_val = start.arg if start.op == Ops.CONST else int(_expr(start, ctx).arg)
|
||||
end_val = end.arg if end.op == Ops.CONST else int(_expr(end, ctx).arg)
|
||||
@@ -478,6 +556,28 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
is_lds = any(u.op == Ops.DEFINE_LOCAL for u in topo)
|
||||
is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in topo)
|
||||
|
||||
def _eval_uop(u: UOp) -> int|float|None:
  """Recursively evaluate a UOp tree to a constant value.

  Returns the folded value, or None when the tree contains an op outside the supported set
  (CONST, CAST, BITCAST, ADD, SUB, MUL, AND, OR, XOR, SHR, SHL).
  """
  if u.op == Ops.CONST: return u.arg
  # CAST and BITCAST pass the operand's value through unchanged; no dtype conversion is applied here
  if u.op in (Ops.CAST, Ops.BITCAST): return _eval_uop(u.src[0])
  if u.op in (Ops.ADD, Ops.SUB, Ops.MUL, Ops.AND, Ops.OR, Ops.XOR, Ops.SHR, Ops.SHL):
    l, r = _eval_uop(u.src[0]), _eval_uop(u.src[1])
    if l is None or r is None: return None
    if u.op == Ops.ADD: return l + r
    if u.op == Ops.SUB: return l - r
    if u.op == Ops.MUL: return l * r
    # bitwise and shift ops coerce both operands to int first
    if u.op == Ops.AND: return int(l) & int(r)
    if u.op == Ops.OR: return int(l) | int(r)
    if u.op == Ops.XOR: return int(l) ^ int(r)
    if u.op == Ops.SHR: return int(l) >> int(r)
    if u.op == Ops.SHL: return int(l) << int(r)
  return None
|
||||
|
||||
def _extract_results(s, MEM=None):
|
||||
for u in s.src:
|
||||
if u.op == Ops.STORE:
|
||||
@@ -489,8 +589,13 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
setattr(MEM[addr], acc, int(val))
|
||||
result = {}
|
||||
for i, (name, dtype) in enumerate(output_info):
|
||||
if i >= len(s.src) or s.src[i].op != Ops.CONST: continue
|
||||
result[name] = _float_to_bits(s.src[i].arg, dtype) if dtype in FLOATS else int(s.src[i].arg) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff)
|
||||
if i >= len(s.src): continue
|
||||
if s.src[i].op == Ops.CONST:
|
||||
val = s.src[i].arg
|
||||
else:
|
||||
val = _eval_uop(s.src[i])
|
||||
if val is None: continue
|
||||
result[name] = _float_to_bits(val, dtype) if dtype in FLOATS else int(val) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff)
|
||||
return result
|
||||
|
||||
if is_lds:
|
||||
@@ -543,44 +648,34 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
return _extract_results(sink.substitute(dvars).simplify())
|
||||
return fn
|
||||
|
||||
_SKIP_OPS = {
|
||||
'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64',
|
||||
'V_CMP_CLASS_F16', 'V_CMP_CLASS_F32', 'V_CMP_CLASS_F64', 'V_CMPX_CLASS_F16', 'V_CMPX_CLASS_F32', 'V_CMPX_CLASS_F64',
|
||||
'V_FREXP_MANT_F16', 'V_FREXP_MANT_F32', 'V_FREXP_MANT_F64',
|
||||
'V_DIV_FIXUP_F16', 'V_DIV_FIXUP_F32', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64',
|
||||
'V_TRIG_PREOP_F64',
|
||||
'S_FF0_I32_B32', 'S_FF0_I32_B64', 'S_FF1_I32_B32', 'S_FF1_I32_B64',
|
||||
'S_FLBIT_I32', 'S_FLBIT_I32_B32', 'S_FLBIT_I32_B64', 'S_FLBIT_I32_I32', 'S_FLBIT_I32_I64',
|
||||
'S_BITREPLICATE_B64_B32',
|
||||
'S_BITSET0_B32', 'S_BITSET0_B64', 'S_BITSET1_B32', 'S_BITSET1_B64',
|
||||
'S_QUADMASK_B32', 'S_QUADMASK_B64', 'S_WQM_B32', 'S_WQM_B64',
|
||||
'S_GETREG_B32', 'S_SETREG_B32', 'S_SETREG_IMM32_B32',
|
||||
'S_MOVRELD_B32', 'S_MOVRELD_B64', 'S_MOVRELS_B32', 'S_MOVRELS_B64', 'S_MOVRELSD_2_B32',
|
||||
'V_MOVRELD_B32', 'V_MOVRELS_B32', 'V_MOVRELSD_B32', 'V_MOVRELSD_2_B32',
|
||||
'V_READFIRSTLANE_B32', 'V_WRITELANE_B32', 'V_READLANE_B32',
|
||||
'V_PERMLANE16_B32', 'V_PERMLANE64_B32', 'V_PERMLANEX16_B32',
|
||||
'V_MBCNT_HI_U32_B32', 'V_MBCNT_LO_U32_B32',
|
||||
'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64',
|
||||
'V_SWAPREL_B32', 'V_PERM_B32',
|
||||
'S_NOP', 'S_SETHALT', 'S_TRAP',
|
||||
'V_MAD_U64_U32', 'V_MAD_I64_I32',
|
||||
'V_DOT2_F32_BF16',
|
||||
'V_DOT4_I32_IU8', 'V_DOT4_U32_U8',
|
||||
'V_DOT8_I32_IU4', 'V_DOT8_U32_U4',
|
||||
'V_FMA_MIX_F32', 'V_FMA_MIXLO_F16', 'V_FMA_MIXHI_F16',
|
||||
'V_CVT_OFF_F32_I4',
|
||||
}
|
||||
_SKIP_OPS: set[str] = set()
|
||||
|
||||
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[',
|
||||
'DATA2', 'OFFSET0', 'OFFSET1')
|
||||
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[')
|
||||
_WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255',
|
||||
'VDATA[95', 'VDATA[127')
|
||||
|
||||
def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
|
||||
"""Apply known fixes for PDF pseudocode bugs - same as pcode.py but for raw pseudocode."""
|
||||
# V_DIV_FMAS: fix scaling factor
|
||||
if op_name == 'V_DIV_FMAS_F32':
|
||||
pcode = pcode.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
||||
if op_name == 'V_DIV_FMAS_F64':
|
||||
pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
||||
# V_DIV_SCALE: D0 defaults to S0 if no branch sets it
|
||||
if op_name == 'V_DIV_SCALE_F32':
|
||||
pcode = 'D0.f32 = S0.f32\n' + pcode
|
||||
if op_name == 'V_DIV_SCALE_F64':
|
||||
pcode = 'D0.f64 = S0.f64\n' + pcode
|
||||
return pcode
|
||||
|
||||
@functools.cache
|
||||
def compile_uop(op_name: str, pseudocode: str):
|
||||
if op_name in _SKIP_OPS: return None
|
||||
if any(p in pseudocode for p in _PCODE_PATTERNS): return None
|
||||
if any(p in pseudocode for p in _WIDE_OUTPUT_PATTERNS): return None
|
||||
pseudocode = _apply_pseudocode_fixes(op_name, pseudocode)
|
||||
is_ds = op_name.startswith('DS_')
|
||||
mem_buf = LDS_BUF if is_ds else MEM_BUF
|
||||
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf)
|
||||
|
||||
Reference in New Issue
Block a user