mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
lil cleanups
This commit is contained in:
@@ -5,8 +5,6 @@ from dataclasses import dataclass
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.helpers import DEBUG
|
||||
|
||||
# DType lookup table for AMD pseudocode type suffixes
|
||||
from tinygrad.dtype import INVERSE_DTYPES_DICT
|
||||
_QDTYPES: dict[str, DType] = {
|
||||
|
||||
@@ -31,7 +31,6 @@ INPUT_VARS['ADDR_BASE'] = INPUT_VARS['ADDR']
|
||||
|
||||
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
|
||||
LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0)
|
||||
VGPR_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(addrspace=AddrSpace.GLOBAL), arg=1)
|
||||
|
||||
# Float bit layout: (uint_type, sign_shift, exp_shift, exp_mask, mantissa_mask, bias)
|
||||
FP_INFO = {
|
||||
@@ -211,6 +210,12 @@ def _fp_bits(v: UOp) -> tuple[UOp, int, int, int]:
|
||||
uint_dt, _, exp_shift, exp_mask, mant_mask, _ = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
|
||||
return UOp(Ops.BITCAST, uint_dt, (v,)), exp_shift, exp_mask, mant_mask
|
||||
|
||||
def _minmax(args: list[UOp], is_min: bool) -> UOp:
|
||||
"""Build min/max expression for 2 or 3 arguments."""
|
||||
cmp = lambda x, y: UOp(Ops.CMPLT, dtypes.bool, (x, y) if is_min else (y, x))
|
||||
result = UOp(Ops.WHERE, args[0].dtype, (cmp(args[0], args[1]), args[0], args[1]))
|
||||
return UOp(Ops.WHERE, args[0].dtype, (cmp(result, args[2]), result, args[2])) if len(args) > 2 else result
|
||||
|
||||
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
if name == 'MEM': return a[0]
|
||||
if name == 'fma': return UOp(Ops.MULACC, a[2].dtype, (a[0], a[1], a[2]))
|
||||
@@ -294,7 +299,7 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
return UOp(Ops.EXP2, a[0].dtype, (a[1] if a[1].dtype == a[0].dtype else UOp(Ops.CAST, a[0].dtype, (a[1],)),))
|
||||
if name in MATH_OPS: return UOp(MATH_OPS[name], a[0].dtype, (a[0],))
|
||||
if name == 'ldexp': return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (UOp(Ops.CAST, a[0].dtype, (a[1],)),))))
|
||||
if name in ('min', 'max'): return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, ((a[0], a[1]) if name == 'min' else (a[1], a[0]))), a[0], a[1]))
|
||||
if name in ('min', 'max'): return _minmax(a, is_min=(name == 'min'))
|
||||
if name in CVT_MAP:
|
||||
dt, clamp = CVT_MAP[name]
|
||||
v = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), UOp.const(a[0].dtype, 0.0), a[0])) if clamp else a[0]
|
||||
@@ -309,12 +314,7 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'):
|
||||
int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32)
|
||||
return UOp(Ops.CMPLT, dtypes.bool, ((UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],))) if 'LT' in name else (UOp(Ops.BITCAST, int_dt, (a[1],)), UOp(Ops.BITCAST, int_dt, (a[0],)))))
|
||||
if name.startswith('v_min') or name.startswith('v_max'):
|
||||
cmp = lambda x, y: UOp(Ops.CMPLT, dtypes.bool, ((x, y) if 'min' in name else (y, x)))
|
||||
if '3_' in name:
|
||||
m = UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1]))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (cmp(m, a[2]), m, a[2]))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1]))
|
||||
if name.startswith('v_min') or name.startswith('v_max'): return _minmax(a, is_min=('min' in name))
|
||||
if name in ('v_sad_u8', 'v_msad_u8'):
|
||||
result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0)
|
||||
for i in range(4):
|
||||
|
||||
Reference in New Issue
Block a user