mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more in pcode_transform
This commit is contained in:
@@ -19,13 +19,12 @@ _SPECIAL_CASTS = {'f32_to_u32', 'f64_to_u32', 'f16_to_u32', 'f32_to_u64', 'f64_t
|
||||
# Cast-name lookup table: for every ordered pair of distinct suffixes (s, d) in _DT_SUFFIX,
# map the name '<s>_to_<d>' to _DT_SUFFIX[d]. Names listed in _SPECIAL_CASTS are left out.
_CAST_MAP = {f'{s}_to_{d}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS}
|
||||
# Register the same (s, d) pairs under the alias 'v_cvt_<d>_<s>' (destination suffix first).
# The exclusion filter is still keyed on the '<s>_to_<d>' spelling found in _SPECIAL_CASTS.
_CAST_MAP.update({f'v_cvt_{d}_{s}': _DT_SUFFIX[d] for s in _DT_SUFFIX for d in _DT_SUFFIX if s != d and f'{s}_to_{d}' not in _SPECIAL_CASTS})
|
||||
|
||||
# CUSTOM op return types
|
||||
# CUSTOM op return types (ops that stay as CUSTOM after transformation)
|
||||
_CUSTOM_TYPES = {
|
||||
'isDENORM': dtypes.bool, 'isQuietNAN': dtypes.bool, 'isSignalNAN': dtypes.bool, 'isEven': dtypes.bool,
|
||||
'LT_NEG_ZERO': dtypes.bool, 'GT_NEG_ZERO': dtypes.bool,
|
||||
'sign': dtypes.uint32, 'exponent': dtypes.uint32, 'ABSDIFF': dtypes.uint32, 'SAT8': dtypes.uint32,
|
||||
'BYTE_PERMUTE': dtypes.uint32, 'count_ones': dtypes.uint32, 'countbits': dtypes.uint32, 'reverse_bits': dtypes.uint32,
|
||||
'u8_to_u32': dtypes.uint32, 'u4_to_u32': dtypes.uint32, 's_ff1_i32_b32': dtypes.uint32, 's_ff1_i32_b64': dtypes.uint32,
|
||||
'isDENORM': dtypes.bool, 'isQuietNAN': dtypes.bool, 'isSignalNAN': dtypes.bool,
|
||||
'sign': dtypes.uint32, 'exponent': dtypes.uint32, 'BYTE_PERMUTE': dtypes.uint32,
|
||||
'count_ones': dtypes.uint32, 'countbits': dtypes.uint32, 'reverse_bits': dtypes.uint32,
|
||||
's_ff1_i32_b32': dtypes.uint32, 's_ff1_i32_b64': dtypes.uint32,
|
||||
'v_sad_u8': dtypes.uint32, 'v_msad_u8': dtypes.uint32, 'ConvertFromFormat': dtypes.uint32, 'nop': dtypes.uint32,
|
||||
'f32_to_u32': dtypes.uint32, 'f64_to_u32': dtypes.uint32, 'signext_from_bit': dtypes.int64,
|
||||
'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16,
|
||||
@@ -62,8 +61,9 @@ def _vn(u: UOp) -> str|None: # var name
|
||||
if u.op == Ops.DEFINE_VAR: return u.arg[0] if isinstance(u.arg, tuple) else u.arg
|
||||
return _vn(u.src[0]) if u.op == Ops.CUSTOMI and u.src[0].op == Ops.DEFINE_VAR else None
|
||||
|
||||
def _typed_const(src: UOp, val) -> UOp:
  """Return a constant UOp holding `val`, typed to match `src`.

  If `src.dtype` is `dtypes.void`, the result is a void-typed `Ops.CONST` UOp
  that keeps `src` as its single source and `val` as its arg. Otherwise the
  result is `UOp.const(src.dtype, val)`.
  """
  if src.dtype == dtypes.void:
    # No concrete dtype to follow: keep `src` attached as the source of the constant.
    return UOp(Ops.CONST, dtypes.void, (src,), val)
  return UOp.const(src.dtype, val)
|
||||
|
||||
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PATTERN HANDLERS
|
||||
@@ -160,8 +160,8 @@ pcode_pm = PatternMatcher([
|
||||
(UPat(Ops.CUSTOM, arg='sin', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (x,))),
|
||||
(UPat(Ops.CUSTOM, arg='rcp', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (x,))),
|
||||
(UPat(Ops.CUSTOM, arg='fma', src=(_fpat, UPat.var('b'), UPat.var('c'))), lambda x, b, c: UOp(Ops.MULACC, x.dtype, (x, b, c))),
|
||||
(UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)), lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))),
|
||||
(UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))),
|
||||
(UPat(Ops.CUSTOM, arg='abs', src=(_fpat,)), lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _tc(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))),
|
||||
(UPat(Ops.CUSTOM, arg='cos', src=(_fpat,)), lambda x: UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _tc(x, 1.5707963267948966))),))),
|
||||
(UPat(Ops.CUSTOM, arg='floor', src=(_fpat,)), lambda x: _floor(x)),
|
||||
(UPat(Ops.CUSTOM, arg='fract', src=(_fpat,)), lambda x: UOp(Ops.SUB, x.dtype, (x, _floor(x)))),
|
||||
(UPat(Ops.CUSTOM, arg='rsqrt', src=(_fpat,)), lambda x: UOp(Ops.RECIPROCAL, x.dtype, (UOp(Ops.SQRT, x.dtype, (x,)),))),
|
||||
@@ -186,7 +186,24 @@ pcode_pm = PatternMatcher([
|
||||
# Boolean functions
|
||||
(UPat(Ops.CUSTOM, arg='isNAN', src=(UPat.var('x'),)), lambda x: UOp(Ops.CMPNE, dtypes.bool, (x, x))),
|
||||
(UPat(Ops.CUSTOM, arg='isINF', src=(UPat.var('x'),)), lambda x: UOp(Ops.OR, dtypes.bool, (
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _typed_const(x, float('-inf'))))))),
|
||||
UOp(Ops.CMPEQ, dtypes.bool, (x, _tc(x, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (x, _tc(x, float('-inf'))))))),
|
||||
# ABSDIFF: |a - b| for unsigned
|
||||
(UPat(Ops.CUSTOM, arg='ABSDIFF', src=(UPat.var('a'), UPat.var('b'))),
|
||||
lambda a, b: UOp(Ops.WHERE, a.dtype, (UOp(Ops.CMPLT, dtypes.bool, (b, a)),
|
||||
UOp(Ops.SUB, a.dtype, (a, b)), UOp(Ops.SUB, a.dtype, (b, a))))),
|
||||
# SAT8: clamp to signed 8-bit range [-128, 127]
|
||||
(UPat(Ops.CUSTOM, arg='SAT8', src=(UPat.var('x'),)),
|
||||
lambda x: UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _tc(x, -128))), _tc(x, -128),
|
||||
UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (_tc(x, 127), x)), _tc(x, 127), x))))),
|
||||
# cvtToQuietNAN: passthrough (just returns the arg)
|
||||
(UPat(Ops.CUSTOM, arg='cvtToQuietNAN', src=(UPat.var('x'),)), lambda x: x),
|
||||
# LT/GT_NEG_ZERO: signed comparison via bitcast to int (f64 -> int64, f16 -> int16, any other float -> int32)
|
||||
(UPat(Ops.CUSTOM, arg='LT_NEG_ZERO', src=(UPat.var('a', dtype=dtypes.floats), UPat.var('b', dtype=dtypes.floats))),
|
||||
lambda a, b: (idt := {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a.dtype, dtypes.int32)) and
|
||||
UOp(Ops.CMPLT, dtypes.bool, (UOp(Ops.BITCAST, idt, (a,)), UOp(Ops.BITCAST, idt, (b,))))),
|
||||
(UPat(Ops.CUSTOM, arg='GT_NEG_ZERO', src=(UPat.var('a', dtype=dtypes.floats), UPat.var('b', dtype=dtypes.floats))),
|
||||
lambda a, b: (idt := {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a.dtype, dtypes.int32)) and
|
||||
UOp(Ops.CMPLT, dtypes.bool, (UOp(Ops.BITCAST, idt, (b,)), UOp(Ops.BITCAST, idt, (a,))))),
|
||||
# min/max
|
||||
(UPat(Ops.CUSTOM, arg='min', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, True)),
|
||||
(UPat(Ops.CUSTOM, arg='max', src=(UPat.var('a'), UPat.var('b'))), lambda a, b: _minmax(a, b, False)),
|
||||
|
||||
@@ -232,18 +232,10 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
result = UOp(Ops.BITCAST, a[0].dtype, (UOp(Ops.OR, uint_dt, (UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, (1 << sign_shift) | mant_mask))),
|
||||
UOp.const(uint_dt, (bias - 1) << exp_shift))),))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPEQ, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), a[0], result))
|
||||
if name == 'cvtToQuietNAN': return a[0]
|
||||
if name == 'signext_from_bit':
|
||||
sign = UOp(Ops.SHL, a[0].dtype, (UOp.const(a[0].dtype, 1), UOp(Ops.SUB, a[0].dtype, (_cast(a[1], a[0].dtype), UOp.const(a[0].dtype, 1)))))
|
||||
result = UOp(Ops.SUB, a[0].dtype, (UOp(Ops.XOR, a[0].dtype, (a[0], sign)), sign))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPEQ, dtypes.bool, (a[1], UOp.const(a[1].dtype, 0))), UOp.const(a[0].dtype, 0), result))
|
||||
if name == 'ABSDIFF':
|
||||
gt = UOp(Ops.CMPLT, dtypes.bool, (a[1], a[0]))
|
||||
return UOp(Ops.SUB, dtypes.uint32, (UOp(Ops.WHERE, dtypes.uint32, (gt, _cast(a[0], dtypes.uint32), _cast(a[1], dtypes.uint32))),
|
||||
UOp(Ops.WHERE, dtypes.uint32, (gt, _cast(a[1], dtypes.uint32), _cast(a[0], dtypes.uint32)))))
|
||||
if name == 'SAT8':
|
||||
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, -128))), UOp.const(a[0].dtype, -128), a[0]))
|
||||
return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 127), c)), UOp.const(a[0].dtype, 127), c))
|
||||
if name == 'BYTE_PERMUTE':
|
||||
src64, sel = _cast(a[0], dtypes.uint64), UOp(Ops.AND, dtypes.uint32, (_cast(a[1], dtypes.uint32), UOp.const(dtypes.uint32, 0xff)))
|
||||
sel_idx, sel_nib = UOp(Ops.AND, dtypes.uint32, (sel, UOp.const(dtypes.uint32, 7))), UOp(Ops.AND, dtypes.uint32, (sel, UOp.const(dtypes.uint32, 0xf)))
|
||||
@@ -268,9 +260,6 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
|
||||
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, lo))), UOp.const(a[0].dtype, lo), a[0]))
|
||||
c = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 1.0), c)), UOp.const(a[0].dtype, 1.0), c))
|
||||
return UOp(Ops.CAST, out, (UOp(Ops.MUL, a[0].dtype, (c, UOp.const(a[0].dtype, scale))),))
|
||||
if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'):
|
||||
int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32)
|
||||
return UOp(Ops.CMPLT, dtypes.bool, ((UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],))) if 'LT' in name else (UOp(Ops.BITCAST, int_dt, (a[1],)), UOp(Ops.BITCAST, int_dt, (a[0],)))))
|
||||
if name in ('v_sad_u8', 'v_msad_u8'):
|
||||
result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0)
|
||||
for i in range(4):
|
||||
|
||||
Reference in New Issue
Block a user