diff --git a/extra/assembly/amd/pcode_transform.py b/extra/assembly/amd/pcode_transform.py index ca99923b47..cefc2e687e 100644 --- a/extra/assembly/amd/pcode_transform.py +++ b/extra/assembly/amd/pcode_transform.py @@ -335,7 +335,46 @@ def _transform_stmt(stmt, ctx: dict): case UOp(): return _transform_uop(stmt, ctx) case _: return stmt -def parse_transform(pcode: str) -> tuple: +def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str: + """Apply known fixes for PDF pseudocode bugs.""" + if op_name == 'V_DIV_FMAS_F32': + pcode = pcode.replace('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)', + 'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))') + if op_name == 'V_DIV_FMAS_F64': + pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)', + 'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))') + if op_name == 'V_DIV_FIXUP_F32': + pcode = pcode.replace('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)', + 'D0.f32 = isNAN(S0.f32) ? (sign_out ? -OVERFLOW_F32 : OVERFLOW_F32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))') + if op_name == 'V_DIV_FIXUP_F64': + pcode = pcode.replace('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)', + 'D0.f64 = isNAN(S0.f64) ? (sign_out ? -OVERFLOW_F64 : OVERFLOW_F64) : (sign_out ? 
-abs(S0.f64) : abs(S0.f64))') + if 'V_DIV_SCALE' in op_name: + dt = 'f32' if 'F32' in op_name else 'f64' + exp_lim, ldexp_val = ('23', '64') if dt == 'f32' else ('52', '128') + pcode = pcode.replace(f'S2.{dt} / S1.{dt} == DENORM.{dt}', f'isDENORM(S2.{dt} / S1.{dt})') + pcode = pcode.replace(f"1.0 / 64'F(S1.{dt}) == DENORM.f64", f"isDENORM(1.0 / 64'F(S1.{dt}))") + pcode = pcode.replace(f'1.0 / S1.{dt} == DENORM.{dt}', f'isDENORM(1.0 / S1.{dt})') + pcode = pcode.replace(f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})') + pcode = pcode.replace(f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}') + pcode = pcode.replace(f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}') + pcode = pcode.replace(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', + f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})') + pcode = pcode.replace(f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nif S0.{dt} == S2.{dt} then\n// Only scale the numerator\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif', + f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nD0.{dt} = S0.{dt}') + pcode = pcode.replace(f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif\nelsif', f'D0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nelse\nD0.{dt} = S0.{dt}\nendif\nelsif') + lines = pcode.rstrip().split('\n') + for i in range(len(lines) - 1, -1, -1): + if lines[i].strip() == 'endif': + lines.insert(i, f'else\nD0.{dt} = S0.{dt}') + break + pcode = '\n'.join(lines) + f';\nif isDENORM(S1.{dt}) then\nD0.{dt} = NAN.{dt}\nendif' + if op_name == 'V_TRIG_PREOP_F64': + pcode = pcode.replace("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)") + return pcode + +def parse_transform(pcode: str, op_name: str | None = None) -> tuple: + if op_name is not None: pcode = _apply_pseudocode_fixes(op_name, pcode) 
ctx: dict[str, DType] = {'SCC': dtypes.bool, 'VCC': dtypes.uint64, 'EXEC': dtypes.uint64, 'VDATA': dtypes.uint64, 'SDATA': dtypes.uint64, 'ADDR': dtypes.uint64, 'VDST': dtypes.uint32, 'ROUND_MODE': dtypes.uint32, 'ROUND_TOWARD_ZERO': dtypes.uint32, 'HW_REGISTERS': dtypes.uint32, diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index a6c619c5df..cd69d133ee 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -523,10 +523,10 @@ _DTYPE_ACCESSOR = {dtypes.uint8: 'u8', dtypes.int8: 'i8', dtypes.uint16: 'u16', dtypes.uint32: 'u32', dtypes.int32: 'i32', dtypes.uint64: 'u64', dtypes.int64: 'i64', dtypes.float32: 'u32', dtypes.float64: 'u64'} -def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]: +def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF, op_name: str | None = None) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]: ctx = Ctx(mem_buf=mem_buf) try: - stmts = parse_transform(pseudocode) + stmts = parse_transform(pseudocode, op_name) except AssertionError as e: print("issue parsing") print(pseudocode) @@ -606,51 +606,8 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s return _extract_results(s1_sub) return fn -# ═══════════════════════════════════════════════════════════════════════════════ -# PSEUDOCODE FIXES -# ═══════════════════════════════════════════════════════════════════════════════ - -def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str: - """Apply known fixes for PDF pseudocode bugs.""" - if op_name == 'V_DIV_FMAS_F32': - pcode = pcode.replace('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)', - 'D0.f32 = (exponent(S2.f32) > 127) ? 
(2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))') - if op_name == 'V_DIV_FMAS_F64': - pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)', - 'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))') - if op_name == 'V_DIV_FIXUP_F32': - pcode = pcode.replace('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)', - 'D0.f32 = isNAN(S0.f32) ? (sign_out ? -OVERFLOW_F32 : OVERFLOW_F32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))') - if op_name == 'V_DIV_FIXUP_F64': - pcode = pcode.replace('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)', - 'D0.f64 = isNAN(S0.f64) ? (sign_out ? -OVERFLOW_F64 : OVERFLOW_F64) : (sign_out ? -abs(S0.f64) : abs(S0.f64))') - if 'V_DIV_SCALE' in op_name: - dt = 'f32' if 'F32' in op_name else 'f64' - exp_lim, ldexp_val = ('23', '64') if dt == 'f32' else ('52', '128') - pcode = pcode.replace(f'S2.{dt} / S1.{dt} == DENORM.{dt}', f'isDENORM(S2.{dt} / S1.{dt})') - pcode = pcode.replace(f"1.0 / 64'F(S1.{dt}) == DENORM.f64", f"isDENORM(1.0 / 64'F(S1.{dt}))") - pcode = pcode.replace(f'1.0 / S1.{dt} == DENORM.{dt}', f'isDENORM(1.0 / S1.{dt})') - pcode = pcode.replace(f'S1.{dt} == DENORM.{dt}', f'isDENORM(S1.{dt})') - pcode = pcode.replace(f'D0.{dt} = NAN.{dt}', f'VCC = 0x1LL;\nD0.{dt} = NAN.{dt}') - pcode = pcode.replace(f'elsif isDENORM(S1.{dt}) then\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', f'elsif 1 == 0 then\nD0.{dt} = S0.{dt}') - pcode = pcode.replace(f'elsif exponent(S2.{dt}) <= {exp_lim} then\n// Numerator is tiny\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})', - f'elsif exponent(S2.{dt}) <= {exp_lim} then\nVCC = 0x1LL;\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})') - pcode = pcode.replace(f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nif S0.{dt} == S2.{dt} then\n// Only scale the numerator\nD0.{dt} = ldexp(S0.{dt}, {ldexp_val})\nendif', - f'elsif isDENORM(S2.{dt} / S1.{dt}) then\nVCC = 0x1LL;\nD0.{dt} = S0.{dt}') - pcode = 
@functools.cache
def compile_uop(op_name: str, pseudocode: str):
  """Build the function for `op_name` from its pseudocode; results are memoized per (op_name, pseudocode)."""
  # ops whose name starts with DS_ are compiled against LDS_BUF, all others against MEM_BUF
  if op_name.startswith('DS_'): target_buf = LDS_BUF
  else: target_buf = MEM_BUF
  # op_name is forwarded so the parsing stage can select the op-specific pseudocode fixes
  graph_sink, outputs, inputs, stores = _compile_pseudocode(pseudocode, target_buf, op_name)
  return _make_fn(graph_sink, outputs, inputs, stores)