diff --git a/extra/assembly/amd/pcode_parse.py b/extra/assembly/amd/pcode_parse.py index fafa75d8cb..9191a1e8fe 100644 --- a/extra/assembly/amd/pcode_parse.py +++ b/extra/assembly/amd/pcode_parse.py @@ -5,8 +5,6 @@ from dataclasses import dataclass from tinygrad.dtype import dtypes, DType from tinygrad.uop import Ops from tinygrad.uop.ops import UOp -from tinygrad.helpers import DEBUG - # DType lookup table for AMD pseudocode type suffixes from tinygrad.dtype import INVERSE_DTYPES_DICT _QDTYPES: dict[str, DType] = { diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index f37498396e..e8f75ca8df 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -31,7 +31,6 @@ INPUT_VARS['ADDR_BASE'] = INPUT_VARS['ADDR'] MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0) LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0) -VGPR_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(addrspace=AddrSpace.GLOBAL), arg=1) # Float bit layout: (uint_type, sign_shift, exp_shift, exp_mask, mantissa_mask, bias) FP_INFO = { @@ -211,6 +210,12 @@ def _fp_bits(v: UOp) -> tuple[UOp, int, int, int]: uint_dt, _, exp_shift, exp_mask, mant_mask, _ = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32]) return UOp(Ops.BITCAST, uint_dt, (v,)), exp_shift, exp_mask, mant_mask +def _minmax(args: list[UOp], is_min: bool) -> UOp: + """Build min/max expression for 2 or 3 arguments.""" + cmp = lambda x, y: UOp(Ops.CMPLT, dtypes.bool, (x, y) if is_min else (y, x)) + result = UOp(Ops.WHERE, args[0].dtype, (cmp(args[0], args[1]), args[0], args[1])) + return UOp(Ops.WHERE, args[0].dtype, (cmp(result, args[2]), result, args[2])) if len(args) > 2 else result + def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: if name == 'MEM': return a[0] if name == 'fma': return UOp(Ops.MULACC, a[2].dtype, (a[0], a[1], a[2])) @@ -294,7 +299,7 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: return UOp(Ops.EXP2, a[0].dtype, (a[1] if a[1].dtype == a[0].dtype else UOp(Ops.CAST, a[0].dtype, (a[1],)),)) if name in MATH_OPS: return UOp(MATH_OPS[name], a[0].dtype, (a[0],)) if name == 'ldexp': return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (UOp(Ops.CAST, a[0].dtype, (a[1],)),)))) - if name in ('min', 'max'): return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, ((a[0], a[1]) if name == 'min' else (a[1], a[0]))), a[0], a[1])) + if name in ('min', 'max'): return _minmax(a, is_min=(name == 'min')) if name in CVT_MAP: dt, clamp = CVT_MAP[name] v = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), UOp.const(a[0].dtype, 0.0), a[0])) if clamp else a[0] @@ -309,12 +314,7 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'): int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32) return UOp(Ops.CMPLT, dtypes.bool, ((UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],))) if 'LT' in name else (UOp(Ops.BITCAST, int_dt, (a[1],)), UOp(Ops.BITCAST, int_dt, (a[0],))))) - if name.startswith('v_min') or name.startswith('v_max'): - cmp = lambda x, y: UOp(Ops.CMPLT, dtypes.bool, ((x, y) if 'min' in name else (y, x))) - if '3_' in name: - m = UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1])) - return UOp(Ops.WHERE, a[0].dtype, (cmp(m, a[2]), m, a[2])) - return UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1])) + if name.startswith('v_min') or name.startswith('v_max'): return _minmax(a, is_min=('min' in name)) if name in ('v_sad_u8', 'v_msad_u8'): result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0) for i in range(4):