mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
regen
This commit is contained in:
@@ -1224,13 +1224,7 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
|
||||
|
||||
# v_fmaak/v_fmamk/v_madak/v_madmk literal handling - need literal= keyword for VOP2
|
||||
# fmamk/madmk: dst = src0 * K + vsrc1, fmaak/madak: dst = src0 * vsrc1 + K
|
||||
lit_s = ""
|
||||
_ak_ops = ('v_fmaak_f32', 'v_fmaak_f16', 'v_madak_f32', 'v_madak_f16')
|
||||
_mk_ops = ('v_fmamk_f32', 'v_fmamk_f16', 'v_madmk_f32', 'v_madmk_f16')
|
||||
if mn in _ak_ops and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
|
||||
elif mn in _mk_ops and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
|
||||
# v_fmaak/v_fmamk/v_madak/v_madmk - autogen functions take K positionally, syntax matches signature
|
||||
|
||||
# VCC ops cleanup
|
||||
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
|
||||
@@ -1274,7 +1268,6 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
if inline_abs: neg_hi = inline_abs
|
||||
|
||||
all_kw = list(kw)
|
||||
if lit_s: all_kw.append(lit_s.lstrip(', '))
|
||||
if opsel is not None: all_kw.append(f'opsel={opsel}')
|
||||
if opsel_hi is not None:
|
||||
all_kw.append(f'opsel_hi={opsel_hi & 3}')
|
||||
|
||||
@@ -974,8 +974,8 @@ v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
|
||||
v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
|
||||
v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
|
||||
v_mac_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAC_F16)
|
||||
v_madmk_f16_e32 = functools.partial(VOP2, VOP2Op.V_MADMK_F16)
|
||||
v_madak_f16_e32 = functools.partial(VOP2, VOP2Op.V_MADAK_F16)
|
||||
def v_madmk_f16_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_MADMK_F16, vdst, src0, vsrc1, literal=K)
|
||||
def v_madak_f16_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_MADAK_F16, vdst, src0, vsrc1, literal=K)
|
||||
v_add_u16_e32 = functools.partial(VOP2, VOP2Op.V_ADD_U16)
|
||||
v_sub_u16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_U16)
|
||||
v_subrev_u16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_U16)
|
||||
|
||||
@@ -352,9 +352,9 @@ def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
|
||||
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
|
||||
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
|
||||
suffix = "_e32" if fmt in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt == "VOP3" and op_val < 512 else ""
|
||||
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
|
||||
if name in ('V_FMAMK_F32', 'V_FMAMK_F16', 'V_MADMK_F32', 'V_MADMK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
|
||||
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16', 'V_MADAK_F32', 'V_MADAK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
else: lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
|
||||
src_names = {name for _, name in src_enum.items()}
|
||||
|
||||
Reference in New Issue
Block a user