mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
fix pcode
This commit is contained in:
@@ -175,7 +175,11 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
|
||||
case UOp(Ops.CAST, dt, (inner,)):
|
||||
inner_resolved = _expr(inner, ctx, dt)
|
||||
if dt in FLOATS: return UOp(Ops.CAST, dt, (inner_resolved,))
|
||||
if dt in FLOATS:
|
||||
# For 32'F(0xffc00000) etc, treat integer constants as BITCAST (interpret bits as float)
|
||||
if inner_resolved.op == Ops.CONST and inner_resolved.dtype not in FLOATS:
|
||||
return UOp(Ops.BITCAST, dt, (inner_resolved,))
|
||||
return UOp(Ops.CAST, dt, (inner_resolved,))
|
||||
if inner_resolved.dtype.itemsize == dt.itemsize:
|
||||
return _cast(inner_resolved, dt)
|
||||
if dt in SIGNED and inner_resolved.dtype in SIGNED: return UOp(Ops.CAST, dt, (inner_resolved,))
|
||||
@@ -291,6 +295,15 @@ def _call_cvtToQuietNAN(v): return v
|
||||
def _call_isINF(v):
|
||||
return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, float('inf')))),
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, float('-inf'))))))
|
||||
def _call_isDENORM(v):
|
||||
# Denormalized float: exponent is 0, mantissa is non-zero, value is not zero
|
||||
uint_dt, _, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
|
||||
bits = UOp(Ops.BITCAST, uint_dt, (v,))
|
||||
exp = UOp(Ops.AND, uint_dt, (UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), UOp.const(uint_dt, exp_mask)))
|
||||
mant = UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, mant_mask)))
|
||||
is_exp_zero = UOp(Ops.CMPEQ, dtypes.bool, (exp, UOp.const(uint_dt, 0)))
|
||||
is_mant_nonzero = UOp(Ops.CMPNE, dtypes.bool, (mant, UOp.const(uint_dt, 0)))
|
||||
return UOp(Ops.AND, dtypes.bool, (is_exp_zero, is_mant_nonzero))
|
||||
def _call_sign(v):
|
||||
uint_dt, sign_shift, _, _, _ = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
|
||||
bits = UOp(Ops.BITCAST, uint_dt, (v,))
|
||||
@@ -300,9 +313,18 @@ def _call_exponent(v):
|
||||
bits = UOp(Ops.BITCAST, uint_dt, (v,))
|
||||
return UOp(Ops.AND, dtypes.uint32, (_cast(UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), dtypes.uint32), UOp.const(dtypes.uint32, exp_mask)))
|
||||
def _call_mantissa(v):
|
||||
uint_dt, _, _, _, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
|
||||
bits, out_dt = UOp(Ops.BITCAST, uint_dt, (v,)), dtypes.uint64 if v.dtype == dtypes.float64 else dtypes.uint32
|
||||
return UOp(Ops.AND, out_dt, (_cast(bits, out_dt) if out_dt != uint_dt else bits, UOp.const(out_dt, mant_mask)))
|
||||
# AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range like math.frexp()[0]
|
||||
# For normalized floats: set exponent to bias-1 (makes value in [0.5,1.0))
|
||||
# For zero: return zero; for inf/nan: should be handled by caller
|
||||
uint_dt, sign_shift, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
|
||||
bias = {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}.get(v.dtype, 127)
|
||||
bits = UOp(Ops.BITCAST, uint_dt, (v,))
|
||||
sign_and_mant = UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, (1 << sign_shift) | mant_mask)))
|
||||
new_exp = UOp.const(uint_dt, (bias - 1) << exp_shift) # exponent = -1 in biased form
|
||||
result_bits = UOp(Ops.OR, uint_dt, (sign_and_mant, new_exp))
|
||||
result = UOp(Ops.BITCAST, v.dtype, (result_bits,))
|
||||
is_zero = UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, 0.0)))
|
||||
return UOp(Ops.WHERE, v.dtype, (is_zero, v, result))
|
||||
def _call_isEven(v):
|
||||
int_val = UOp(Ops.CAST, dtypes.int64, (v,))
|
||||
return UOp(Ops.CMPEQ, dtypes.bool, (UOp(Ops.AND, dtypes.int64, (int_val, UOp.const(dtypes.int64, 1))), UOp.const(dtypes.int64, 0)))
|
||||
@@ -354,7 +376,7 @@ def _call_BYTE_PERMUTE(src, sel):
|
||||
CALL_DISPATCH = {
|
||||
'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
|
||||
'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isQuietNAN,
|
||||
'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF,
|
||||
'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF, 'isDENORM': _call_isDENORM,
|
||||
'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
|
||||
'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
|
||||
'BYTE_PERMUTE': _call_BYTE_PERMUTE, 'bf16_to_f32': _call_bf16_to_f32,
|
||||
@@ -435,7 +457,11 @@ def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, s
|
||||
# Array element access with variable index
|
||||
return name, var_dtype.scalar(), None, None, None, idx # Return idx as variable name for array_idx
|
||||
return name, dtypes.uint32, None, None, idx, None
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None, None
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)):
|
||||
# If the variable already exists, use its dtype; otherwise default to uint32
|
||||
existing = ctx.vars.get(name)
|
||||
dtype = existing.dtype if existing is not None else dtypes.uint32
|
||||
return name, dtype, None, None, None, None
|
||||
raise ValueError(f"Cannot parse LHS: {lhs}")
|
||||
|
||||
def _stmt(stmt, ctx: Ctx):
|
||||
@@ -660,6 +686,20 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
if u.op == Ops.XOR: return int(l) ^ int(r)
|
||||
if u.op == Ops.SHR: return int(l) >> int(r)
|
||||
if u.op == Ops.SHL: return int(l) << int(r)
|
||||
if u.op == Ops.NEG:
|
||||
v = _eval_uop(u.src[0])
|
||||
return -v if v is not None else None
|
||||
if u.op in (Ops.CMPEQ, Ops.CMPNE, Ops.CMPLT, Ops.CMPLE):
|
||||
l, r = _eval_uop(u.src[0]), _eval_uop(u.src[1])
|
||||
if l is None or r is None: return None
|
||||
if u.op == Ops.CMPEQ: return l == r
|
||||
if u.op == Ops.CMPNE: return l != r
|
||||
if u.op == Ops.CMPLT: return l < r
|
||||
if u.op == Ops.CMPLE: return l <= r
|
||||
if u.op == Ops.WHERE:
|
||||
c, t, f = _eval_uop(u.src[0]), _eval_uop(u.src[1]), _eval_uop(u.src[2])
|
||||
if c is None or t is None or f is None: return None
|
||||
return t if c else f
|
||||
return None
|
||||
|
||||
def _extract_results(s, MEM=None):
|
||||
@@ -733,12 +773,8 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
return _extract_results(sink.substitute(dvars).simplify())
|
||||
return fn
|
||||
|
||||
# Ops that need Python exec features (inline conditionals, complex PDF fixes, precise FMA) - fall back to pcode.py
|
||||
_SKIP_OPS: set[str] = {'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64',
|
||||
'V_DIV_FIXUP_F32', 'V_DIV_FIXUP_F64', 'V_TRIG_PREOP_F64',
|
||||
'V_FMA_F64', 'V_FMA_F32', # FMA needs precise math.fma semantics
|
||||
'V_FREXP_MANT_F64', 'V_FREXP_MANT_F32', # mantissa() returns [0.5,1.0) range float
|
||||
'V_DOT2_F32_BF16'} # compound assignment parsing issues
|
||||
# Ops that need Python exec features (inline conditionals, complex PDF fixes) - fall back to pcode.py
|
||||
_SKIP_OPS: set[str] = {'V_TRIG_PREOP_F64'}
|
||||
|
||||
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[')
|
||||
_WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255',
|
||||
@@ -746,6 +782,62 @@ _WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDA
|
||||
|
||||
def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
|
||||
"""Apply known fixes for PDF pseudocode bugs - same as pcode.py but for raw pseudocode."""
|
||||
if op_name == 'V_DIV_FMAS_F32':
|
||||
pcode = pcode.replace('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))')
|
||||
if op_name == 'V_DIV_FMAS_F64':
|
||||
pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))')
|
||||
if op_name == 'V_DIV_FIXUP_F32':
|
||||
# When S0 (estimate) is NaN but inputs are valid, return OVERFLOW instead of NaN
|
||||
pcode = pcode.replace('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)',
|
||||
'D0.f32 = isNAN(S0.f32) ? (sign_out ? -OVERFLOW_F32 : OVERFLOW_F32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))')
|
||||
if op_name == 'V_DIV_FIXUP_F64':
|
||||
pcode = pcode.replace('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)',
|
||||
'D0.f64 = isNAN(S0.f64) ? (sign_out ? -OVERFLOW_F64 : OVERFLOW_F64) : (sign_out ? -abs(S0.f64) : abs(S0.f64))')
|
||||
if op_name == 'V_DIV_SCALE_F32':
|
||||
# Fix 0: Replace DENORM comparisons with isDENORM() calls (order matters - do longer patterns first)
|
||||
pcode = pcode.replace('S2.f32 / S1.f32 == DENORM.f32', 'isDENORM(S2.f32 / S1.f32)')
|
||||
pcode = pcode.replace('1.0 / 64\'F(S1.f32) == DENORM.f64', 'isDENORM(1.0 / 64\'F(S1.f32))')
|
||||
pcode = pcode.replace('S1.f32 == DENORM.f32', 'isDENORM(S1.f32)')
|
||||
# Fix 1: Set VCC=1 when returning NAN for zero inputs
|
||||
pcode = pcode.replace('D0.f32 = NAN.f32', 'VCC = 0x1LL;\nD0.f32 = NAN.f32')
|
||||
# Fix 2: Remove the S1==DENORM branch (it's wrong), handle at end
|
||||
pcode = pcode.replace('elsif isDENORM(S1.f32) then\nD0.f32 = ldexp(S0.f32, 64)',
|
||||
'elsif 1 == 0 then\nD0.f32 = S0.f32')
|
||||
# Fix 3: Set VCC=1 for tiny numerator case
|
||||
pcode = pcode.replace('elsif exponent(S2.f32) <= 23 then\n// Numerator is tiny\nD0.f32 = ldexp(S0.f32, 64)',
|
||||
'elsif exponent(S2.f32) <= 23 then\nVCC = 0x1LL;\nD0.f32 = ldexp(S0.f32, 64)')
|
||||
# Fix 4: Simplify S2/S1==DENORM case (just set VCC, don't check S0==S2)
|
||||
pcode = pcode.replace('elsif isDENORM(S2.f32 / S1.f32) then\nVCC = 0x1LL;\nif S0.f32 == S2.f32 then\n// Only scale the numerator\nD0.f32 = ldexp(S0.f32, 64)\nendif',
|
||||
'elsif isDENORM(S2.f32 / S1.f32) then\nVCC = 0x1LL;\nD0.f32 = S0.f32')
|
||||
# Fix 5: Add else to nested ifs that don't have D0 assignment
|
||||
pcode = pcode.replace('D0.f32 = ldexp(S0.f32, 64)\nendif\nelsif', 'D0.f32 = ldexp(S0.f32, 64)\nelse\nD0.f32 = S0.f32\nendif\nelsif')
|
||||
# Fix 6: Add else clause to outermost if before final endif, and check for S1==DENORM at end
|
||||
lines = pcode.rstrip().split('\n')
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
if lines[i].strip() == 'endif':
|
||||
lines.insert(i, 'else\nD0.f32 = S0.f32')
|
||||
break
|
||||
pcode = '\n'.join(lines) + ';\nif isDENORM(S1.f32) then\nD0.f32 = NAN.f32\nendif'
|
||||
if op_name == 'V_DIV_SCALE_F64':
|
||||
pcode = pcode.replace('S2.f64 / S1.f64 == DENORM.f64', 'isDENORM(S2.f64 / S1.f64)')
|
||||
pcode = pcode.replace('1.0 / S1.f64 == DENORM.f64', 'isDENORM(1.0 / S1.f64)')
|
||||
pcode = pcode.replace('S1.f64 == DENORM.f64', 'isDENORM(S1.f64)')
|
||||
pcode = pcode.replace('D0.f64 = NAN.f64', 'VCC = 0x1LL;\nD0.f64 = NAN.f64')
|
||||
pcode = pcode.replace('elsif isDENORM(S1.f64) then\nD0.f64 = ldexp(S0.f64, 128)',
|
||||
'elsif 1 == 0 then\nD0.f64 = S0.f64')
|
||||
pcode = pcode.replace('elsif exponent(S2.f64) <= 52 then\n// Numerator is tiny\nD0.f64 = ldexp(S0.f64, 128)',
|
||||
'elsif exponent(S2.f64) <= 52 then\nVCC = 0x1LL;\nD0.f64 = ldexp(S0.f64, 128)')
|
||||
pcode = pcode.replace('elsif isDENORM(S2.f64 / S1.f64) then\nVCC = 0x1LL;\nif S0.f64 == S2.f64 then\n// Only scale the numerator\nD0.f64 = ldexp(S0.f64, 128)\nendif',
|
||||
'elsif isDENORM(S2.f64 / S1.f64) then\nVCC = 0x1LL;\nD0.f64 = S0.f64')
|
||||
pcode = pcode.replace('D0.f64 = ldexp(S0.f64, 128)\nendif\nelsif', 'D0.f64 = ldexp(S0.f64, 128)\nelse\nD0.f64 = S0.f64\nendif\nelsif')
|
||||
lines = pcode.rstrip().split('\n')
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
if lines[i].strip() == 'endif':
|
||||
lines.insert(i, 'else\nD0.f64 = S0.f64')
|
||||
break
|
||||
pcode = '\n'.join(lines) + ';\nif isDENORM(S1.f64) then\nD0.f64 = NAN.f64\nendif'
|
||||
return pcode
|
||||
|
||||
@functools.cache
|
||||
|
||||
@@ -878,7 +878,8 @@ python_alu: dict[Ops, Callable] = {
|
||||
Ops.TRUNC: lambda x: x if math.isinf(x) or math.isnan(x) else math.copysign(math.trunc(x), x),
|
||||
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.CMPLE: operator.le,
|
||||
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
|
||||
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,
|
||||
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,
|
||||
Ops.MULACC: lambda x,y,z: math.fma(x,y,z) if not (math.isinf(x) or math.isinf(y) or math.isnan(x) or math.isnan(y)) else (x*y)+z,
|
||||
Ops.FDIV: lambda x,y: x/y if y != 0 else (math.nan if x == 0 else math.copysign(math.inf, x*y))}
|
||||
|
||||
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
||||
|
||||
Reference in New Issue
Block a user