lil cleanups

This commit is contained in:
George Hotz
2026-01-07 23:49:53 -08:00
parent 56ba96f5cd
commit 10836a5dba
2 changed files with 8 additions and 10 deletions

View File

@@ -5,8 +5,6 @@ from dataclasses import dataclass
from tinygrad.dtype import dtypes, DType
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp
from tinygrad.helpers import DEBUG
# DType lookup table for AMD pseudocode type suffixes
from tinygrad.dtype import INVERSE_DTYPES_DICT
_QDTYPES: dict[str, DType] = {

View File

@@ -31,7 +31,6 @@ INPUT_VARS['ADDR_BASE'] = INPUT_VARS['ADDR']
# Module-level buffer UOps, each typed as a pointer with an explicit address space.
# NOTE(review): the names suggest main memory, LDS (local data share) and the VGPR file — confirm against callers.
# uint8 pointer in the GLOBAL address space, defined with arg=0.
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
# uint8 pointer in the LOCAL address space, defined with arg=0.
LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0)
# uint32 pointer in the GLOBAL address space; arg=1 keeps it distinct from MEM_BUF (arg=0).
VGPR_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(addrspace=AddrSpace.GLOBAL), arg=1)
# Float bit layout: (uint_type, sign_shift, exp_shift, exp_mask, mantissa_mask, bias)
FP_INFO = {
@@ -211,6 +210,12 @@ def _fp_bits(v: UOp) -> tuple[UOp, int, int, int]:
uint_dt, _, exp_shift, exp_mask, mant_mask, _ = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
return UOp(Ops.BITCAST, uint_dt, (v,)), exp_shift, exp_mask, mant_mask
def _minmax(args: list[UOp], is_min: bool) -> UOp:
  """Return a UOp selecting the min (or max) of the first two or three operands.

  Each selection is a WHERE over a CMPLT: for min the compare is `lhs < rhs`,
  for max the compare operands are swapped (`rhs < lhs`); in both cases the
  WHERE yields `lhs` when the compare holds and `rhs` otherwise. The result
  dtype is always taken from `args[0]`.

  Only `args[0]`, `args[1]` and (when present) `args[2]` are read; any further
  operands are ignored. Fewer than two operands raises IndexError.
  """
  def _select(lhs: UOp, rhs: UOp) -> UOp:
    # Swapping the compare operands turns "less than" into "greater than".
    operands = (lhs, rhs) if is_min else (rhs, lhs)
    return UOp(Ops.WHERE, args[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, operands), lhs, rhs))

  best = _select(args[0], args[1])
  if len(args) > 2:
    # Three-operand form: fold the third operand into the pairwise result.
    best = _select(best, args[2])
  return best
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
if name == 'MEM': return a[0]
if name == 'fma': return UOp(Ops.MULACC, a[2].dtype, (a[0], a[1], a[2]))
@@ -294,7 +299,7 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
return UOp(Ops.EXP2, a[0].dtype, (a[1] if a[1].dtype == a[0].dtype else UOp(Ops.CAST, a[0].dtype, (a[1],)),))
if name in MATH_OPS: return UOp(MATH_OPS[name], a[0].dtype, (a[0],))
if name == 'ldexp': return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (UOp(Ops.CAST, a[0].dtype, (a[1],)),))))
if name in ('min', 'max'): return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, ((a[0], a[1]) if name == 'min' else (a[1], a[0]))), a[0], a[1]))
if name in ('min', 'max'): return _minmax(a, is_min=(name == 'min'))
if name in CVT_MAP:
dt, clamp = CVT_MAP[name]
v = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), UOp.const(a[0].dtype, 0.0), a[0])) if clamp else a[0]
@@ -309,12 +314,7 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'):
int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32)
return UOp(Ops.CMPLT, dtypes.bool, ((UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],))) if 'LT' in name else (UOp(Ops.BITCAST, int_dt, (a[1],)), UOp(Ops.BITCAST, int_dt, (a[0],)))))
if name.startswith('v_min') or name.startswith('v_max'):
cmp = lambda x, y: UOp(Ops.CMPLT, dtypes.bool, ((x, y) if 'min' in name else (y, x)))
if '3_' in name:
m = UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1]))
return UOp(Ops.WHERE, a[0].dtype, (cmp(m, a[2]), m, a[2]))
return UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1]))
if name.startswith('v_min') or name.startswith('v_max'): return _minmax(a, is_min=('min' in name))
if name in ('v_sad_u8', 'v_msad_u8'):
result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0)
for i in range(4):