cvt functions

This commit is contained in:
George Hotz
2026-01-08 04:57:04 -08:00
parent 0dfdad0e76
commit 544a877960
3 changed files with 36 additions and 19 deletions

View File

@@ -48,6 +48,15 @@ def _floor(x):
trunc = UOp(Ops.TRUNC, x.dtype, (x,))
return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
def _cvt(src_dt: DType, dst_dt: DType):
  """Return a one-argument helper that type-checks a UOp and wraps it in a CAST to dst_dt.

  src_dt is the dtype the operand is expected to carry; dst_dt is the dtype given to the
  resulting Ops.CAST UOp. The returned helper raises AssertionError when the operand's
  dtype is not one of the accepted dtypes.
  """
  def convert(x: UOp) -> UOp:
    # Accepted operand dtypes, checked in this order:
    #   src_dt        - the exact expected type
    #   dtypes.void   - unresolved
    #   dtypes.uint32 - unresolved array access/slice
    # TODO: should only allow exact match
    assert any(x.dtype == ok for ok in (src_dt, dtypes.void, dtypes.uint32)), f"Expected {src_dt}, got {x.dtype}"
    cast_srcs = (x,)
    return UOp(Ops.CAST, dst_dt, cast_srcs)
  return convert
# Function expansions: name -> lambda(*srcs) -> UOp
_FN_EXPAND: dict[str, callable] = {
'trunc': lambda x: UOp(Ops.TRUNC, x.dtype, (x,)),
@@ -69,6 +78,17 @@ _FN_EXPAND: dict[str, callable] = {
'max': lambda a, b: UOp(Ops.WHERE, a.dtype, (UOp(Ops.CMPLT, dtypes.bool, (b, a)), a, b)),
'clamp': lambda x, lo, hi: (c := UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, lo)), lo, x)),
UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (hi, c)), hi, c)))[1],
# Conversions (type-checked casts)
'f32_to_i32': _cvt(dtypes.float32, dtypes.int32), 'f32_to_f16': _cvt(dtypes.float32, dtypes.float16),
'f32_to_f64': _cvt(dtypes.float32, dtypes.float64), 'f32_to_i8': _cvt(dtypes.float32, dtypes.int8),
'f32_to_u8': _cvt(dtypes.float32, dtypes.uint8), 'f32_to_i16': _cvt(dtypes.float32, dtypes.int16),
'f32_to_u16': _cvt(dtypes.float32, dtypes.uint16), 'f64_to_i32': _cvt(dtypes.float64, dtypes.int32),
'f64_to_f32': _cvt(dtypes.float64, dtypes.float32), 'f16_to_f32': _cvt(dtypes.float16, dtypes.float32),
'f16_to_i16': _cvt(dtypes.float16, dtypes.int16), 'f16_to_u16': _cvt(dtypes.float16, dtypes.uint16),
'i32_to_f32': _cvt(dtypes.int32, dtypes.float32), 'i32_to_f64': _cvt(dtypes.int32, dtypes.float64),
'u32_to_f32': _cvt(dtypes.uint32, dtypes.float32), 'u32_to_f64': _cvt(dtypes.uint32, dtypes.float64),
'i16_to_f16': _cvt(dtypes.int16, dtypes.float16), 'u16_to_f16': _cvt(dtypes.uint16, dtypes.float16),
'v_cvt_u16_f32': _cvt(dtypes.float32, dtypes.uint16), 'v_cvt_i16_f32': _cvt(dtypes.float32, dtypes.int16),
}
# Function return type inference for CUSTOM ops
@@ -80,18 +100,12 @@ _U32_FNS = {'sign', 'exponent', 'ABSDIFF', 'SAT8', 'BYTE_PERMUTE', 'count_ones',
'u8_to_u32', 'u4_to_u32', 'u32_to_u16', 's_ff1_i32_b32', 's_ff1_i32_b64', 'v_sad_u8', 'v_msad_u8',
'v_min_u16', 'v_min_u32', 'v_max_u16', 'v_max_u32', 'v_min3_u16', 'v_min3_u32', 'v_max3_u16', 'v_max3_u32'}
_I32_FNS = {'v_min_i16', 'v_min_i32', 'v_max_i16', 'v_max_i32', 'v_min3_i16', 'v_min3_i32', 'v_max3_i16', 'v_max3_i32'}
_CVT_FNS = { # conversion functions: name -> output dtype
'f32_to_i32': dtypes.int32, 'f32_to_u32': dtypes.uint32, 'f32_to_f16': dtypes.float16, 'f32_to_f64': dtypes.float64,
'f32_to_i8': dtypes.int8, 'f32_to_u8': dtypes.uint8, 'f32_to_i16': dtypes.int16, 'f32_to_u16': dtypes.uint16,
'f64_to_i32': dtypes.int32, 'f64_to_u32': dtypes.uint32, 'f64_to_f32': dtypes.float32,
'f16_to_f32': dtypes.float32, 'f16_to_i16': dtypes.int16, 'f16_to_u16': dtypes.uint16,
'i32_to_f32': dtypes.float32, 'i32_to_f64': dtypes.float64, 'i32_to_i16': dtypes.int16,
'u32_to_f32': dtypes.float32, 'u32_to_f64': dtypes.float64,
'i16_to_f16': dtypes.float16, 'u16_to_f16': dtypes.float16,
'bf16_to_f32': dtypes.float32, 'f32_to_bf16': dtypes.bfloat16,
'v_cvt_u16_f32': dtypes.uint16, 'v_cvt_i16_f32': dtypes.int16,
'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16,
'signext': dtypes.int64, 'signext_from_bit': dtypes.int64,
# Lookup table from conversion-function name to the dtype of the value it produces.
# Per the inline note, it lists only the conversions that are not in _FN_EXPAND;
# the trailing comment on each row gives the reason that row is kept here.
# NOTE(review): 'u32_to_u16' maps to dtypes.uint32 while 'i32_to_i16' maps to dtypes.int16 —
# confirm this asymmetry is intended ('u32_to_u16' is also listed in _U32_FNS).
_CVT_FNS = { # conversion functions: name -> output dtype (only those not in _FN_EXPAND)
'f32_to_u32': dtypes.uint32, 'f64_to_u32': dtypes.uint32, # need clamping
'i32_to_i16': dtypes.int16, 'u32_to_u16': dtypes.uint32, # need masking
'bf16_to_f32': dtypes.float32, 'f32_to_bf16': dtypes.bfloat16, # bit manipulation
'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16, # scaling
'signext': dtypes.int64, 'signext_from_bit': dtypes.int64, # special handling
}
def _infer_fn_dtype(name: str, srcs: tuple[UOp, ...]) -> DType:

View File

@@ -158,6 +158,14 @@ def _norm(s, keep_structure=False):
s = re.sub(r'//[^\n]*', '', s)
s = re.sub(r'0x[0-9a-fA-F]+', lambda m: str(int(m[0], 16)), s) # convert hex before stripping whitespace
s = re.sub(r"(\d+)U(?!LL)", r"\1", s) # strip U suffix early before whitespace removal
# Normalize conversion functions to typed cast syntax (f32_to_f64 -> 64'F, etc.)
cvt_map = {'f32_to_i32': "32'I", 'f32_to_f16': "16'F", 'f32_to_f64': "64'F", 'f32_to_i8': "8'I",
'f32_to_u8': "8'U", 'f32_to_i16': "16'I", 'f32_to_u16': "16'U", 'f64_to_i32': "32'I",
'f64_to_f32': "32'F", 'f16_to_f32': "32'F", 'f16_to_i16': "16'I", 'f16_to_u16': "16'U",
'i32_to_f32': "32'F", 'i32_to_f64': "64'F", 'u32_to_f32': "32'F", 'u32_to_f64': "64'F",
'i16_to_f16': "16'F", 'u16_to_f16': "16'F"}
for fn, cast in cvt_map.items():
s = re.sub(rf'\b{fn}\b', cast, s)
if keep_structure:
s = re.sub(r';', '', s)
s = re.sub(r'\n\s*\n', '\n', s)

View File

@@ -203,13 +203,8 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
# FUNCTION CALLS
# ═══════════════════════════════════════════════════════════════════════════════
CVT_MAP = {'u32_to_f32': (dtypes.float32, False), 'i32_to_f32': (dtypes.float32, False), 'f32_to_u32': (dtypes.uint32, True),
'f32_to_i32': (dtypes.int32, False), 'f16_to_f32': (dtypes.float32, False), 'f32_to_f16': (dtypes.float16, False),
'f32_to_u8': (dtypes.uint8, False), 'f32_to_i8': (dtypes.int8, False), 'f32_to_u16': (dtypes.uint16, False),
'f32_to_i16': (dtypes.int16, False), 'v_cvt_u16_f32': (dtypes.uint16, False), 'v_cvt_i16_f32': (dtypes.int16, False),
'f64_to_i32': (dtypes.int32, False), 'f64_to_u32': (dtypes.uint32, True), 'i32_to_f64': (dtypes.float64, False),
'u32_to_f64': (dtypes.float64, False), 'f64_to_f32': (dtypes.float32, False), 'f32_to_f64': (dtypes.float64, False),
'u16_to_f16': (dtypes.float16, False), 'i16_to_f16': (dtypes.float16, False), 'f16_to_u16': (dtypes.uint16, False), 'f16_to_i16': (dtypes.int16, False)}
# Conversions that need special handling (clamping negative to 0 before cast).
# Each value is a (destination dtype, bool) pair; both remaining entries carry True.
# NOTE(review): presumably the bool marks entries that need the clamp — verify against the code that reads CVT_MAP.
CVT_MAP = {'f32_to_u32': (dtypes.uint32, True), 'f64_to_u32': (dtypes.uint32, True)}
def _fp_bits(v: UOp) -> tuple[UOp, int, int, int]:
"""Get float as bits with its layout info. Unwraps CAST to check original float type."""