This commit is contained in:
George Hotz
2026-01-11 14:47:11 +09:00
parent 31ab6f8107
commit 634b86654e
2 changed files with 43 additions and 47 deletions

View File

@@ -335,7 +335,46 @@ def _transform_stmt(stmt, ctx: dict):
case UOp(): return _transform_uop(stmt, ctx)
case _: return stmt
def parse_transform(pcode: str) -> tuple:
def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
"""Apply known fixes for PDF pseudocode bugs."""
if op_name == 'V_DIV_FMAS_F32':
pcode = pcode.replace('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))')
if op_name == 'V_DIV_FMAS_F64':
pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))')
if op_name == 'V_DIV_FIXUP_F32':
pcode = pcode.replace('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)',
'D0.f32 = isNAN(S0.f32) ? (sign_out ? -OVERFLOW_F32 : OVERFLOW_F32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))')
if op_name == 'V_DIV_FIXUP_F64':
pcode = pcode.replace('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)',
'D0.f64 = isNAN(S0.f64) ? (sign_out ? -OVERFLOW_F64 : OVERFLOW_F64) : (sign_out ? -abs(S0.f64) : abs(S0.f64))')
if 'V_DIV_SCALE' in op_name:
dt = 'f32' if 'F32' in op_name else 'f64'
exp_lim, ldexp_val = ('23', '64') if dt == 'f32' else ('52', '128')
pcode = pcode.replace(f'S2.{dt} / S1.{dt} == DENORM.{dt}', f'isDENORM(S2.{dt} / S1.{dt})')
pcode = pcode.replace(f"1.0 / 64'F(S1.{dt}) == DENORM.f64", f"isDENORM(1.0 / 64'F(S1.{dt}))")
pcode = pcode.replace(f'1.0 / S1.{dt} == DENORM.{dt}', f'isDENORM(1.0 / S1.{dt})')
pcode = pcode.replace(f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})')
pcode = pcode.replace(f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}')
pcode = pcode.replace(f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}')
pcode = pcode.replace(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})')
pcode = pcode.replace(f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nif S0.{dt} == S2.{dt} then\n// Only scale the numerator\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nD0.{dt} = S0.{dt}')
pcode = pcode.replace(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif', f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\nD0.{dt} = S0.{dt}\nendif\nelsif')
lines = pcode.rstrip().split('\n')
for i in range(len(lines) - 1, -1, -1):
if lines[i].strip() == 'endif':
lines.insert(i, f'else\nD0.{dt} = S0.{dt}')
break
pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif'
if op_name == 'V_TRIG_PREOP_F64':
pcode = pcode.replace("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)")
return pcode
def parse_transform(pcode: str, op_name: str | None = None) -> tuple:
if op_name is not None: pcode = _apply_pseudocode_fixes(op_name, pcode)
ctx: dict[str, DType] = {'SCC': dtypes.bool, 'VCC': dtypes.uint64, 'EXEC': dtypes.uint64,
'VDATA': dtypes.uint64, 'SDATA': dtypes.uint64, 'ADDR': dtypes.uint64, 'VDST': dtypes.uint32,
'ROUND_MODE': dtypes.uint32, 'ROUND_TOWARD_ZERO': dtypes.uint32, 'HW_REGISTERS': dtypes.uint32,

View File

@@ -523,10 +523,10 @@ _DTYPE_ACCESSOR = {dtypes.uint8: 'u8', dtypes.int8: 'i8', dtypes.uint16: 'u16',
dtypes.uint32: 'u32', dtypes.int32: 'i32', dtypes.uint64: 'u64', dtypes.int64: 'i64',
dtypes.float32: 'u32', dtypes.float64: 'u64'}
def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF, op_name: str | None = None) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
ctx = Ctx(mem_buf=mem_buf)
try:
stmts = parse_transform(pseudocode)
stmts = parse_transform(pseudocode, op_name)
except AssertionError as e:
print("issue parsing")
print(pseudocode)
@@ -606,51 +606,8 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
return _extract_results(s1_sub)
return fn
# ═══════════════════════════════════════════════════════════════════════════════
# PSEUDOCODE FIXES
# ═══════════════════════════════════════════════════════════════════════════════
def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
"""Apply known fixes for PDF pseudocode bugs."""
if op_name == 'V_DIV_FMAS_F32':
pcode = pcode.replace('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))')
if op_name == 'V_DIV_FMAS_F64':
pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))')
if op_name == 'V_DIV_FIXUP_F32':
pcode = pcode.replace('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)',
'D0.f32 = isNAN(S0.f32) ? (sign_out ? -OVERFLOW_F32 : OVERFLOW_F32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))')
if op_name == 'V_DIV_FIXUP_F64':
pcode = pcode.replace('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)',
'D0.f64 = isNAN(S0.f64) ? (sign_out ? -OVERFLOW_F64 : OVERFLOW_F64) : (sign_out ? -abs(S0.f64) : abs(S0.f64))')
if 'V_DIV_SCALE' in op_name:
dt = 'f32' if 'F32' in op_name else 'f64'
exp_lim, ldexp_val = ('23', '64') if dt == 'f32' else ('52', '128')
pcode = pcode.replace(f'S2.{dt} / S1.{dt} == DENORM.{dt}', f'isDENORM(S2.{dt} / S1.{dt})')
pcode = pcode.replace(f"1.0 / 64'F(S1.{dt}) == DENORM.f64", f"isDENORM(1.0 / 64'F(S1.{dt}))")
pcode = pcode.replace(f'1.0 / S1.{dt} == DENORM.{dt}', f'isDENORM(1.0 / S1.{dt})')
pcode = pcode.replace(f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})')
pcode = pcode.replace(f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}')
pcode = pcode.replace(f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}')
pcode = pcode.replace(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})',
f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})')
pcode = pcode.replace(f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nif S0.{dt} == S2.{dt} then\n// Only scale the numerator\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif',
f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nD0.{dt} = S0.{dt}')
pcode = pcode.replace(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif', f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\nD0.{dt} = S0.{dt}\nendif\nelsif')
lines = pcode.rstrip().split('\n')
for i in range(len(lines) - 1, -1, -1):
if lines[i].strip() == 'endif':
lines.insert(i, f'else\nD0.{dt} = S0.{dt}')
break
pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif'
if op_name == 'V_TRIG_PREOP_F64':
pcode = pcode.replace("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)")
return pcode
@functools.cache
def compile_uop(op_name: str, pseudocode: str):
pseudocode = _apply_pseudocode_fixes(op_name, pseudocode)
mem_buf = LDS_BUF if op_name.startswith('DS_') else MEM_BUF
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf)
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf, op_name)
return _make_fn(sink, output_info, input_vars, mem_stores)