mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
ftz
This commit is contained in:
@@ -7,6 +7,13 @@ from extra.assembly.amd.qcode import parse, Assign, Declare, If, For
|
||||
SIGNED = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
||||
FLOATS = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
||||
# FTZ (Flush To Zero): RDNA3 default mode flushes f32 denormals to ±0
|
||||
def _ftz32(bits: int) -> float:
  """Decode the low 32 bits of `bits` as an IEEE-754 float32, flushing denormals to zero.

  A denormal (exponent field all zero, mantissa non-zero) is replaced by a zero that keeps
  the input's sign bit, so the result is +0.0 or -0.0. Every other bit pattern (zeros,
  normals, infinities, NaNs) is decoded unchanged.

  Args:
    bits: integer holding the float32 bit pattern; only the low 32 bits are used.

  Returns:
    The decoded value as a Python float.
  """
  bits &= 0xffffffff  # tolerate negative or wider-than-32-bit Python ints
  if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0:  # denormal
    # Keep only the sign bit: a negative denormal must flush to -0.0, not +0.0.
    bits &= 0x80000000
  return struct.unpack('<f', struct.pack('<I', bits))[0]
|
||||
|
||||
def _cast(x: UOp, dtype: DType) -> UOp:
  """Return `x` converted to `dtype`, or `x` itself when it already has that dtype.

  The conversion node is an Ops.BITCAST when the source and target dtypes have the same
  itemsize, and an Ops.CAST otherwise.
  """
  if x.dtype == dtype:
    return x
  same_width = dtype.itemsize == x.dtype.itemsize
  conv_op = Ops.BITCAST if same_width else Ops.CAST
  return UOp(conv_op, dtype, (x,))
|
||||
|
||||
@@ -28,6 +35,7 @@ INPUT_VARS = {
|
||||
'SIMM32': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('SIMM32', 0, 0xffffffff)),
|
||||
'PC': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('PC', 0, 0xffffffffffffffff)),
|
||||
'ADDR': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('ADDR', 0, 0xffffffffffffffff)),
|
||||
'ADDR_BASE': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('ADDR', 0, 0xffffffffffffffff)), # Alias for ADDR
|
||||
'SDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('SDATA', 0, 0xffffffffffffffff)),
|
||||
'VDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDATA', 0, 0xffffffffffffffff)),
|
||||
'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDST', 0, 0xffffffffffffffff)),
|
||||
@@ -37,6 +45,8 @@ INPUT_VARS = {
|
||||
'OFFSET': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET', 0, 0xffff)),
|
||||
'OFFSET0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET0', 0, 0xff)),
|
||||
'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)),
|
||||
'OPSEL': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL', 0, 7)),
|
||||
'OPSEL_HI': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL_HI', 0, 7)),
|
||||
}
|
||||
|
||||
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
|
||||
@@ -121,7 +131,21 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
if dt in FLOATS: return UOp(Ops.BITCAST, dt, (inner_resolved,))
|
||||
return _cast(inner_resolved, dt)
|
||||
|
||||
case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice
|
||||
case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice or array access
|
||||
# Check for array element access first: arr[idx] where arr is a vector type
|
||||
if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[1] is None and hi_expr is lo_expr:
|
||||
name = base_expr.arg[0]
|
||||
var_dtype = ctx.decls.get(name)
|
||||
if var_dtype is not None and var_dtype.count > 1:
|
||||
# Array element access - look up stored element
|
||||
idx_uop = _expr(hi_expr, ctx)
|
||||
idx_uop = idx_uop.simplify()
|
||||
if idx_uop.op == Ops.CONST:
|
||||
arr_key = f"{name}_{int(idx_uop.arg)}"
|
||||
if arr_key in ctx.vars:
|
||||
return ctx.vars[arr_key]
|
||||
# Element not set, return default value
|
||||
return UOp.const(var_dtype.scalar(), 0)
|
||||
base, hi_uop, lo_uop = _expr(base_expr, ctx), _expr(hi_expr, ctx), _expr(lo_expr, ctx)
|
||||
# Single-bit slice: base[idx:idx] -> (base >> idx) & 1
|
||||
if hi_expr is lo_expr:
|
||||
@@ -296,23 +320,35 @@ def _call_SAT8(v):
|
||||
clamped = UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (v, UOp.const(v.dtype, -128))), UOp.const(v.dtype, -128), v))
|
||||
return UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(v.dtype, 127), clamped)), UOp.const(v.dtype, 127), clamped))
|
||||
def _call_BYTE_PERMUTE(src, sel):
  """Build the uint32 UOp for one byte selected out of a 64-bit source by a selector.

  `src` is cast to uint64 and `sel` to uint32; only the low 8 bits of `sel` are used.
  The result depends on the selector's low nibble:
    0-7   -> byte number (sel & 7) of `src`, counted from the least significant byte
    8-11  -> 0xff if bit 15 / 31 / 47 / 63 of `src` is set, otherwise 0
    12    -> 0
    13-15 -> 0xff
  If selector bit 0x80 is set, the result is forced to 0 regardless of the nibble.
  """
  # src is {S0, S1} = (S0 << 32) | S1, where bytes 0-3 are S1, bytes 4-7 are S0
  src64 = _cast(src, dtypes.uint64)
  sel_val = UOp(Ops.AND, dtypes.uint32, (_cast(sel, dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
  sel_idx = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 7)))
  sel_nibble = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0xf)))
  # Normal byte select (sel 0-7): extract byte at index
  shift = UOp(Ops.SHL, dtypes.uint32, (sel_idx, UOp.const(dtypes.uint32, 3)))  # byte index * 8
  byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src64, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32)
  # Sign extension (sel 8-11): check bit 15/31/47/63 respectively
  def sign_ext_bit(bit_pos):
    # 0xff when the given bit of src64 is set, else 0
    bit = UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src64, UOp.const(dtypes.uint64, bit_pos))), UOp.const(dtypes.uint64, 1)))
    return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (bit, UOp.const(dtypes.uint64, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0)))
  sign8, sign9, sign10, sign11 = sign_ext_bit(15), sign_ext_bit(31), sign_ext_bit(47), sign_ext_bit(63)
  # Build result based on selector
  is_sel8 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 8)))
  is_sel9 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 9)))
  is_sel10 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 10)))
  is_sel11 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 11)))
  is_sel12 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12)))
  is_sel_gt12 = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble))
  # The selector tests are mutually exclusive, so the nesting order of the WHEREs does not change the value.
  result = byte_val
  result = UOp(Ops.WHERE, dtypes.uint32, (is_sel8, sign8, result))
  result = UOp(Ops.WHERE, dtypes.uint32, (is_sel9, sign9, result))
  result = UOp(Ops.WHERE, dtypes.uint32, (is_sel10, sign10, result))
  result = UOp(Ops.WHERE, dtypes.uint32, (is_sel11, sign11, result))
  result = UOp(Ops.WHERE, dtypes.uint32, (is_sel12, UOp.const(dtypes.uint32, 0), result))
  result = UOp(Ops.WHERE, dtypes.uint32, (is_sel_gt12, UOp.const(dtypes.uint32, 0xff), result))
  # High bit of selector (0x80) means return 0
  sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80)))
  return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (sel_hi, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0), result))
|
||||
|
||||
CALL_DISPATCH = {
|
||||
@@ -370,34 +406,50 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
return result
|
||||
raise ValueError(f"Unknown function: {name}")
|
||||
|
||||
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None]:
|
||||
"""Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var)"""
|
||||
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None, int|None]:
|
||||
"""Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx)"""
|
||||
match lhs:
|
||||
case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None
|
||||
case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None
|
||||
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))),)):
|
||||
return name, dt, int(hi), int(lo), None
|
||||
return name, dt, int(hi), int(lo), None, None
|
||||
case UOp(Ops.BITCAST, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
|
||||
return name, dtypes.uint64, None, None, idx
|
||||
return name, dtypes.uint64, None, None, idx, None
|
||||
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
|
||||
return name, dt, None, None, idx
|
||||
return name, dt, None, None, idx, None
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
|
||||
return name, dtypes.uint32, int(hi), int(lo), None
|
||||
return name, dtypes.uint32, int(hi), int(lo), None, None
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, idx), _)) if lhs.src[1] is lhs.src[2]:
|
||||
# Check if this is array element access (variable is a vector type)
|
||||
var_dtype = ctx.decls.get(name)
|
||||
if var_dtype is not None and var_dtype.count > 1:
|
||||
return name, var_dtype.scalar(), None, None, None, int(idx)
|
||||
return name, dtypes.uint32, int(idx), int(idx), None, None
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
|
||||
return name, dtypes.uint32, int(hi), int(lo), None
|
||||
return name, dtypes.uint32, int(hi), int(lo), None, None
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
|
||||
return name, dt, None, None, idx
|
||||
# Handle tmp[i] where i is a variable (single-bit index)
|
||||
return name, dt, None, None, idx, None
|
||||
# Handle arr[i] where i is a variable - check if it's array element or bit index
|
||||
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
|
||||
return name, dtypes.uint32, None, None, idx
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None
|
||||
var_dtype = ctx.decls.get(name)
|
||||
if var_dtype is not None and var_dtype.count > 1:
|
||||
# Array element access with variable index
|
||||
return name, var_dtype.scalar(), None, None, None, idx # Return idx as variable name for array_idx
|
||||
return name, dtypes.uint32, None, None, idx, None
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None, None
|
||||
raise ValueError(f"Cannot parse LHS: {lhs}")
|
||||
|
||||
def _stmt(stmt, ctx: Ctx):
|
||||
match stmt:
|
||||
case Declare(name, dtype):
|
||||
ctx.decls[name] = dtype
|
||||
# Initialize declared variable with zero value
|
||||
ctx.vars[name] = UOp.const(dtype, 0)
|
||||
# Special handling for S array - it maps to source operands S0, S1, S2
|
||||
if name == 'S' and dtype.count == 3:
|
||||
ctx.vars['S_0'] = ctx.vars['S0']
|
||||
ctx.vars['S_1'] = ctx.vars['S1']
|
||||
ctx.vars['S_2'] = ctx.vars['S2']
|
||||
else:
|
||||
# Initialize declared variable with zero value
|
||||
ctx.vars[name] = UOp.const(dtype, 0)
|
||||
case Assign(lhs, rhs):
|
||||
# Handle MEM[addr].type = value -> memory store
|
||||
if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM':
|
||||
@@ -429,9 +481,27 @@ def _stmt(stmt, ctx: Ctx):
|
||||
offset += bits
|
||||
return
|
||||
|
||||
var, dtype, hi, lo, idx_var = _get_lhs_info(lhs, ctx)
|
||||
var, dtype, hi, lo, idx_var, array_idx = _get_lhs_info(lhs, ctx)
|
||||
out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
|
||||
|
||||
# Handle array element assignment: arr[idx] = value
|
||||
if array_idx is not None:
|
||||
var_dtype = ctx.decls.get(var)
|
||||
if var_dtype is None: raise ValueError(f"Unknown array variable: {var}")
|
||||
rhs_uop = _expr(rhs, ctx, dtype)
|
||||
# array_idx can be an int or a variable name (str)
|
||||
if isinstance(array_idx, str):
|
||||
# Variable index - resolve it
|
||||
idx_uop = ctx.vars.get(array_idx)
|
||||
if idx_uop is not None and idx_uop.op == Ops.CONST:
|
||||
arr_key = f"{var}_{int(idx_uop.arg)}"
|
||||
else:
|
||||
raise ValueError(f"Non-constant array index: {array_idx}")
|
||||
else:
|
||||
arr_key = f"{var}_{array_idx}"
|
||||
ctx.vars[arr_key] = rhs_uop
|
||||
return
|
||||
|
||||
if idx_var is not None:
|
||||
base, idx = ctx.vars.get(var), ctx.vars.get(idx_var)
|
||||
if base is None or idx is None: raise ValueError(f"Unknown variable: {var} or {idx_var}")
|
||||
@@ -564,7 +634,21 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
return v if v is not None else None
|
||||
if u.op == Ops.BITCAST:
|
||||
v = _eval_uop(u.src[0])
|
||||
return v if v is not None else None
|
||||
if v is None: return None
|
||||
# Convert between int and float bit representations
|
||||
if u.dtype == dtypes.float64 and u.src[0].dtype in (dtypes.uint64, dtypes.int64):
|
||||
return struct.unpack('<d', struct.pack('<Q', int(v) & 0xffffffffffffffff))[0]
|
||||
if u.dtype == dtypes.float32 and u.src[0].dtype in (dtypes.uint32, dtypes.int32):
|
||||
return _ftz32(int(v)) # Apply FTZ for f32
|
||||
if u.dtype in (dtypes.uint64, dtypes.int64) and u.src[0].dtype == dtypes.float64:
|
||||
return struct.unpack('<Q', struct.pack('<d', float(v)))[0]
|
||||
if u.dtype in (dtypes.uint32, dtypes.int32) and u.src[0].dtype == dtypes.float32:
|
||||
return struct.unpack('<I', struct.pack('<f', float(v)))[0]
|
||||
return v
|
||||
if u.op == Ops.MULACC:
|
||||
a, b, c = _eval_uop(u.src[0]), _eval_uop(u.src[1]), _eval_uop(u.src[2])
|
||||
if a is None or b is None or c is None: return None
|
||||
return math.fma(float(a), float(b), float(c))
|
||||
if u.op in (Ops.ADD, Ops.SUB, Ops.MUL, Ops.AND, Ops.OR, Ops.XOR, Ops.SHR, Ops.SHL):
|
||||
l, r = _eval_uop(u.src[0]), _eval_uop(u.src[1])
|
||||
if l is None or r is None: return None
|
||||
@@ -633,7 +717,7 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
return _extract_results(s2, MEM)
|
||||
return fn
|
||||
else:
|
||||
def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):
|
||||
def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):
|
||||
simm16 = (literal if -32768 <= literal <= 32767 else (literal - 65536 if literal < 65536 else 0)) if literal is not None else 0
|
||||
dvars = {
|
||||
input_vars['S0']: UOp.const(dtypes.uint32, s0 & 0xffffffff), input_vars['S1']: UOp.const(dtypes.uint32, s1 & 0xffffffff),
|
||||
@@ -644,30 +728,24 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
input_vars['EXEC']: UOp.const(dtypes.uint64, exec_mask), input_vars['laneId']: UOp.const(dtypes.uint32, laneId),
|
||||
input_vars['SIMM16']: UOp.const(dtypes.int32, simm16), input_vars['SIMM32']: UOp.const(dtypes.uint32, literal or 0),
|
||||
input_vars['PC']: UOp.const(dtypes.uint64, pc or 0),
|
||||
input_vars['OPSEL']: UOp.const(dtypes.uint32, opsel), input_vars['OPSEL_HI']: UOp.const(dtypes.uint32, opsel_hi),
|
||||
}
|
||||
return _extract_results(sink.substitute(dvars).simplify())
|
||||
return fn
|
||||
|
||||
# Ops that need Python exec features (inline conditionals, complex PDF fixes, precise FMA) - fall back to pcode.py
_SKIP_OPS: set[str] = {'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64',
                       'V_DIV_FIXUP_F32', 'V_DIV_FIXUP_F64', 'V_TRIG_PREOP_F64',
                       'V_FMA_F64', 'V_FMA_F32',  # FMA needs precise math.fma semantics
                       'V_FREXP_MANT_F64', 'V_FREXP_MANT_F32',  # mantissa() returns [0.5,1.0) range float
                       'V_DOT2_F32_BF16'}  # compound assignment parsing issues

# Substrings looked for in pseudocode text.
# NOTE(review): the code that consumes these two tuples is not visible here - confirm how a match is handled.
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[')
_WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255',
                         'VDATA[95', 'VDATA[127', 'RETURN_DATA[95', 'RETURN_DATA[127')
|
||||
|
||||
def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
  """Return `pcode` with the known correction for `op_name` applied, or unchanged if there is none.

  Mirrors the PDF-pseudocode bug fixes done in pcode.py, but operates on the raw pseudocode text.
  """
  # V_DIV_FMAS: fix scaling factor (exact-text substitution; a no-op when the text is absent)
  scaling_rewrites = {
    'V_DIV_FMAS_F32': ('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
                       'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)'),
    'V_DIV_FMAS_F64': ('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
                       'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)'),
  }
  # V_DIV_SCALE: D0 defaults to S0 if no branch sets it
  default_prefixes = {
    'V_DIV_SCALE_F32': 'D0.f32 = S0.f32\n',
    'V_DIV_SCALE_F64': 'D0.f64 = S0.f64\n',
  }
  rewrite = scaling_rewrites.get(op_name)
  if rewrite is not None:
    wrong, fixed = rewrite
    return pcode.replace(wrong, fixed)
  return default_prefixes.get(op_name, '') + pcode
|
||||
|
||||
@functools.cache
|
||||
|
||||
@@ -32,6 +32,9 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
|
||||
int_val = int(trunc(v)) & ((1 << (from_scalar.itemsize * 8)) - 1)
|
||||
# Convert to output type
|
||||
result = struct.unpack('<'+to_scalar.fmt, struct.pack('<'+int_fmt, int_val))[0]
|
||||
# FTZ: flush f32 denormals to zero (for AMD GPU emulation - RDNA3 default mode)
|
||||
if to_scalar.fmt == 'f' and (int_val & 0x7f800000) == 0 and (int_val & 0x007fffff) != 0:
|
||||
result = 0.0
|
||||
# Don't fold if result is NaN with non-canonical bits (as_const normalizes all NaN to math.nan)
|
||||
if isinstance(result, float) and math.isnan(result):
|
||||
canonical_nan_bits = struct.unpack('<'+int_fmt, struct.pack('<'+to_scalar.fmt, math.nan))[0]
|
||||
@@ -109,9 +112,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST
|
||||
and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# NOTE: this can be wrong for loaded NaN - disabled for AMD emulator correctness
|
||||
# (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST
|
||||
# and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# *** cast/bitcast ***
|
||||
(UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
|
||||
Reference in New Issue
Block a user