This commit is contained in:
George Hotz
2026-01-04 16:32:35 -08:00
parent b52ff63896
commit 87e72f1540
2 changed files with 129 additions and 48 deletions

View File

@@ -7,6 +7,13 @@ from extra.assembly.amd.qcode import parse, Assign, Declare, If, For
SIGNED = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
FLOATS = (dtypes.float16, dtypes.float32, dtypes.float64)
# FTZ (Flush To Zero): RDNA3 default mode flushes f32 denormals to ±0
def _ftz32(bits: int) -> float:
bits = bits & 0xffffffff
if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0: # denormal
return 0.0
return struct.unpack('<f', struct.pack('<I', bits))[0]
def _cast(x: UOp, dtype: DType) -> UOp:
  """Coerce `x` to `dtype`: identity if already that dtype, BITCAST when the
  byte sizes match (pure reinterpretation), otherwise a value-converting CAST."""
  if x.dtype == dtype:
    return x
  op = Ops.BITCAST if dtype.itemsize == x.dtype.itemsize else Ops.CAST
  return UOp(op, dtype, (x,))
@@ -28,6 +35,7 @@ INPUT_VARS = {
'SIMM32': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('SIMM32', 0, 0xffffffff)),
'PC': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('PC', 0, 0xffffffffffffffff)),
'ADDR': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('ADDR', 0, 0xffffffffffffffff)),
'ADDR_BASE': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('ADDR', 0, 0xffffffffffffffff)), # Alias for ADDR
'SDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('SDATA', 0, 0xffffffffffffffff)),
'VDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDATA', 0, 0xffffffffffffffff)),
'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDST', 0, 0xffffffffffffffff)),
@@ -37,6 +45,8 @@ INPUT_VARS = {
'OFFSET': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET', 0, 0xffff)),
'OFFSET0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET0', 0, 0xff)),
'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)),
'OPSEL': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL', 0, 7)),
'OPSEL_HI': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL_HI', 0, 7)),
}
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
@@ -121,7 +131,21 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
if dt in FLOATS: return UOp(Ops.BITCAST, dt, (inner_resolved,))
return _cast(inner_resolved, dt)
case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice
case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice or array access
# Check for array element access first: arr[idx] where arr is a vector type
if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[1] is None and hi_expr is lo_expr:
name = base_expr.arg[0]
var_dtype = ctx.decls.get(name)
if var_dtype is not None and var_dtype.count > 1:
# Array element access - look up stored element
idx_uop = _expr(hi_expr, ctx)
idx_uop = idx_uop.simplify()
if idx_uop.op == Ops.CONST:
arr_key = f"{name}_{int(idx_uop.arg)}"
if arr_key in ctx.vars:
return ctx.vars[arr_key]
# Element not set, return default value
return UOp.const(var_dtype.scalar(), 0)
base, hi_uop, lo_uop = _expr(base_expr, ctx), _expr(hi_expr, ctx), _expr(lo_expr, ctx)
# Single-bit slice: base[idx:idx] -> (base >> idx) & 1
if hi_expr is lo_expr:
@@ -296,23 +320,35 @@ def _call_SAT8(v):
clamped = UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (v, UOp.const(v.dtype, -128))), UOp.const(v.dtype, -128), v))
return UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(v.dtype, 127), clamped)), UOp.const(v.dtype, 127), clamped))
def _call_BYTE_PERMUTE(src, sel):
src_fixed = UOp(Ops.OR, dtypes.uint64,
(UOp(Ops.SHL, dtypes.uint64, (UOp(Ops.AND, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 0xffffffff))), UOp.const(dtypes.uint64, 32))),
UOp(Ops.SHR, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 32)))))
# src is {S0, S1} = (S0 << 32) | S1, where bytes 0-3 are S1, bytes 4-7 are S0
src64 = _cast(src, dtypes.uint64)
sel_val = UOp(Ops.AND, dtypes.uint32, (_cast(sel, dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
sel_idx = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 7)))
sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80)))
sel_nibble = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0xf)))
# Normal byte select (sel 0-7): extract byte at index
shift = UOp(Ops.SHL, dtypes.uint32, (sel_idx, UOp.const(dtypes.uint32, 3)))
byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src_fixed, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32)
byte_msb = UOp(Ops.AND, dtypes.uint32, (byte_val, UOp.const(dtypes.uint32, 0x80)))
sign_ext_val = UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (byte_msb, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0)))
is_sign_ext = UOp(Ops.AND, dtypes.bool, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 7), sel_nibble)), UOp(Ops.CMPLT, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))))
is_const_zero = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))
is_const_ff = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble))
result = UOp(Ops.WHERE, dtypes.uint32, (is_const_ff, UOp.const(dtypes.uint32, 0xff), byte_val))
result = UOp(Ops.WHERE, dtypes.uint32, (is_const_zero, UOp.const(dtypes.uint32, 0), result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sign_ext, sign_ext_val, result))
byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src64, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32)
# Sign extension (sel 8-11): check bit 15/31/47/63 respectively
def sign_ext_bit(bit_pos):
bit = UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src64, UOp.const(dtypes.uint64, bit_pos))), UOp.const(dtypes.uint64, 1)))
return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (bit, UOp.const(dtypes.uint64, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0)))
sign8, sign9, sign10, sign11 = sign_ext_bit(15), sign_ext_bit(31), sign_ext_bit(47), sign_ext_bit(63)
# Build result based on selector
is_sel8 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 8)))
is_sel9 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 9)))
is_sel10 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 10)))
is_sel11 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 11)))
is_sel12 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))
is_sel_gt12 = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble))
result = byte_val
result = UOp(Ops.WHERE, dtypes.uint32, (is_sel8, sign8, result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sel9, sign9, result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sel10, sign10, result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sel11, sign11, result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sel12, UOp.const(dtypes.uint32, 0), result))
result = UOp(Ops.WHERE, dtypes.uint32, (is_sel_gt12, UOp.const(dtypes.uint32, 0xff), result))
# High bit of selector (0x80) means return 0
sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80)))
return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (sel_hi, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0), result))
CALL_DISPATCH = {
@@ -370,34 +406,50 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
return result
raise ValueError(f"Unknown function: {name}")
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None]:
"""Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var)"""
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None, int|None]:
"""Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx)"""
match lhs:
case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None
case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))),)):
return name, dt, int(hi), int(lo), None
return name, dt, int(hi), int(lo), None, None
case UOp(Ops.BITCAST, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
return name, dtypes.uint64, None, None, idx
return name, dtypes.uint64, None, None, idx, None
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
return name, dt, None, None, idx
return name, dt, None, None, idx, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
return name, dtypes.uint32, int(hi), int(lo), None
return name, dtypes.uint32, int(hi), int(lo), None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, idx), _)) if lhs.src[1] is lhs.src[2]:
# Check if this is array element access (variable is a vector type)
var_dtype = ctx.decls.get(name)
if var_dtype is not None and var_dtype.count > 1:
return name, var_dtype.scalar(), None, None, None, int(idx)
return name, dtypes.uint32, int(idx), int(idx), None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
return name, dtypes.uint32, int(hi), int(lo), None
return name, dtypes.uint32, int(hi), int(lo), None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
return name, dt, None, None, idx
# Handle tmp[i] where i is a variable (single-bit index)
return name, dt, None, None, idx, None
# Handle arr[i] where i is a variable - check if it's array element or bit index
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
return name, dtypes.uint32, None, None, idx
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None
var_dtype = ctx.decls.get(name)
if var_dtype is not None and var_dtype.count > 1:
# Array element access with variable index
return name, var_dtype.scalar(), None, None, None, idx # Return idx as variable name for array_idx
return name, dtypes.uint32, None, None, idx, None
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None, None
raise ValueError(f"Cannot parse LHS: {lhs}")
def _stmt(stmt, ctx: Ctx):
match stmt:
case Declare(name, dtype):
ctx.decls[name] = dtype
# Initialize declared variable with zero value
ctx.vars[name] = UOp.const(dtype, 0)
# Special handling for S array - it maps to source operands S0, S1, S2
if name == 'S' and dtype.count == 3:
ctx.vars['S_0'] = ctx.vars['S0']
ctx.vars['S_1'] = ctx.vars['S1']
ctx.vars['S_2'] = ctx.vars['S2']
else:
# Initialize declared variable with zero value
ctx.vars[name] = UOp.const(dtype, 0)
case Assign(lhs, rhs):
# Handle MEM[addr].type = value -> memory store
if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM':
@@ -429,9 +481,27 @@ def _stmt(stmt, ctx: Ctx):
offset += bits
return
var, dtype, hi, lo, idx_var = _get_lhs_info(lhs, ctx)
var, dtype, hi, lo, idx_var, array_idx = _get_lhs_info(lhs, ctx)
out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
# Handle array element assignment: arr[idx] = value
if array_idx is not None:
var_dtype = ctx.decls.get(var)
if var_dtype is None: raise ValueError(f"Unknown array variable: {var}")
rhs_uop = _expr(rhs, ctx, dtype)
# array_idx can be an int or a variable name (str)
if isinstance(array_idx, str):
# Variable index - resolve it
idx_uop = ctx.vars.get(array_idx)
if idx_uop is not None and idx_uop.op == Ops.CONST:
arr_key = f"{var}_{int(idx_uop.arg)}"
else:
raise ValueError(f"Non-constant array index: {array_idx}")
else:
arr_key = f"{var}_{array_idx}"
ctx.vars[arr_key] = rhs_uop
return
if idx_var is not None:
base, idx = ctx.vars.get(var), ctx.vars.get(idx_var)
if base is None or idx is None: raise ValueError(f"Unknown variable: {var} or {idx_var}")
@@ -564,7 +634,21 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
return v if v is not None else None
if u.op == Ops.BITCAST:
v = _eval_uop(u.src[0])
return v if v is not None else None
if v is None: return None
# Convert between int and float bit representations
if u.dtype == dtypes.float64 and u.src[0].dtype in (dtypes.uint64, dtypes.int64):
return struct.unpack('<d', struct.pack('<Q', int(v) & 0xffffffffffffffff))[0]
if u.dtype == dtypes.float32 and u.src[0].dtype in (dtypes.uint32, dtypes.int32):
return _ftz32(int(v)) # Apply FTZ for f32
if u.dtype in (dtypes.uint64, dtypes.int64) and u.src[0].dtype == dtypes.float64:
return struct.unpack('<Q', struct.pack('<d', float(v)))[0]
if u.dtype in (dtypes.uint32, dtypes.int32) and u.src[0].dtype == dtypes.float32:
return struct.unpack('<I', struct.pack('<f', float(v)))[0]
return v
if u.op == Ops.MULACC:
a, b, c = _eval_uop(u.src[0]), _eval_uop(u.src[1]), _eval_uop(u.src[2])
if a is None or b is None or c is None: return None
return math.fma(float(a), float(b), float(c))
if u.op in (Ops.ADD, Ops.SUB, Ops.MUL, Ops.AND, Ops.OR, Ops.XOR, Ops.SHR, Ops.SHL):
l, r = _eval_uop(u.src[0]), _eval_uop(u.src[1])
if l is None or r is None: return None
@@ -633,7 +717,7 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
return _extract_results(s2, MEM)
return fn
else:
def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):
def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):
simm16 = (literal if -32768 <= literal <= 32767 else (literal - 65536 if literal < 65536 else 0)) if literal is not None else 0
dvars = {
input_vars['S0']: UOp.const(dtypes.uint32, s0 & 0xffffffff), input_vars['S1']: UOp.const(dtypes.uint32, s1 & 0xffffffff),
@@ -644,30 +728,24 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
input_vars['EXEC']: UOp.const(dtypes.uint64, exec_mask), input_vars['laneId']: UOp.const(dtypes.uint32, laneId),
input_vars['SIMM16']: UOp.const(dtypes.int32, simm16), input_vars['SIMM32']: UOp.const(dtypes.uint32, literal or 0),
input_vars['PC']: UOp.const(dtypes.uint64, pc or 0),
input_vars['OPSEL']: UOp.const(dtypes.uint32, opsel), input_vars['OPSEL_HI']: UOp.const(dtypes.uint32, opsel_hi),
}
return _extract_results(sink.substitute(dvars).simplify())
return fn
_SKIP_OPS: set[str] = set()
# Ops that need Python exec features (inline conditionals, complex PDF fixes, precise FMA) - fall back to pcode.py
_SKIP_OPS: set[str] = {'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64',
'V_DIV_FIXUP_F32', 'V_DIV_FIXUP_F64', 'V_TRIG_PREOP_F64',
'V_FMA_F64', 'V_FMA_F32', # FMA needs precise math.fma semantics
'V_FREXP_MANT_F64', 'V_FREXP_MANT_F32', # mantissa() returns [0.5,1.0) range float
'V_DOT2_F32_BF16'} # compound assignment parsing issues
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[')
_WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255',
'VDATA[95', 'VDATA[127')
'VDATA[95', 'VDATA[127', 'RETURN_DATA[95', 'RETURN_DATA[127')
def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
"""Apply known fixes for PDF pseudocode bugs - same as pcode.py but for raw pseudocode."""
# V_DIV_FMAS: fix scaling factor
if op_name == 'V_DIV_FMAS_F32':
pcode = pcode.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op_name == 'V_DIV_FMAS_F64':
pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
# V_DIV_SCALE: D0 defaults to S0 if no branch sets it
if op_name == 'V_DIV_SCALE_F32':
pcode = 'D0.f32 = S0.f32\n' + pcode
if op_name == 'V_DIV_SCALE_F64':
pcode = 'D0.f64 = S0.f64\n' + pcode
return pcode
@functools.cache

View File

@@ -32,6 +32,9 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
int_val = int(trunc(v)) & ((1 << (from_scalar.itemsize * 8)) - 1)
# Convert to output type
result = struct.unpack('<'+to_scalar.fmt, struct.pack('<'+int_fmt, int_val))[0]
# FTZ: flush f32 denormals to zero (for AMD GPU emulation - RDNA3 default mode)
if to_scalar.fmt == 'f' and (int_val & 0x7f800000) == 0 and (int_val & 0x007fffff) != 0:
result = 0.0
# Don't fold if result is NaN with non-canonical bits (as_const normalizes all NaN to math.nan)
if isinstance(result, float) and math.isnan(result):
canonical_nan_bits = struct.unpack('<'+int_fmt, struct.pack('<'+to_scalar.fmt, math.nan))[0]
@@ -109,9 +112,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST
and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# NOTE: this can be wrong for loaded NaN - disabled for AMD emulator correctness
# (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST
# and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# *** cast/bitcast ***
(UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),