mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
enforce UOps.CONST arg has python type based on dtype (#3952)
added an assert in uops, remove the cast in renderer
This commit is contained in:
@@ -85,8 +85,7 @@ class PythonProgram:
|
||||
elif arg[1][0] == 'l':
|
||||
ul[i] = [x[2-arg[0]] for x in warp]
|
||||
elif uop is UOps.CONST:
|
||||
casted_arg = int(arg) if dtypes.is_int(dtype) else float(arg)
|
||||
ul[i] = [[casted_arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [casted_arg] * warp_size
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.LOOP:
|
||||
@@ -105,7 +104,7 @@ class PythonProgram:
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
else:
|
||||
casted = [float(x) if dtypes.is_float(dtype) else int(x) if dtypes.is_int(dtype) else x for x in inp[0]]
|
||||
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
|
||||
overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
|
||||
overflow_fixed = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) if dtypes.is_int(dtype) else x for x in casted]
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *overflow_fixed)))
|
||||
|
||||
Reference in New Issue
Block a user