mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
bring back pyint (#7150)
fixed test_failure_52 and resnet. need to understand this better
@@ -29,6 +29,7 @@ def assert_jit_cache_len(fxn, expected_len):
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
 
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
+  if dtype == dtypes.pyint and device != "PYTHON": return False
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
     return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
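The gate added above is the whole mechanism on the test side: pyint only reports itself as supported on the pure-Python backend. A minimal standalone sketch of that idea (hypothetical names, not the real helper):

def supported(dtype_name: str, device: str) -> bool:
  # an interpreter-only dtype must never reach a compiled backend
  if dtype_name == "pyint" and device != "PYTHON": return False
  return True

assert supported("pyint", "PYTHON") and not supported("pyint", "CUDA")
assert supported("int32", "CUDA")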
@@ -14,7 +14,7 @@ pytestmark = pytest.mark.filterwarnings("ignore")
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
 
-core_dtypes = list(DTYPES_DICT.values())
+core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint'])
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16)  # NOTE: this is for teenygrad, don't remove
 dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
 dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
@@ -22,7 +22,7 @@ dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_sup
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []
   # dont cast internal dtypes
-  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
+  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint']
 
 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)
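Both test-side hunks filter the internal dtype out by dictionary key, the same way underscore-prefixed internals are already excluded. A toy sketch of the filter (made-up dict, not the real DTYPES_DICT):

TOY_DTYPES = {"bool": "bool", "int32": "int", "_arg_int32": "int", "pyint": "int"}
castable = [k for k in TOY_DTYPES if not k.startswith("_") and k != "pyint"]
assert castable == ["bool", "int32"]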
@@ -806,7 +806,7 @@ class TestTensorMethod(unittest.TestCase):
 
 class TestDtypeUsage(unittest.TestCase):
   def test_max_w_alu(self):
-    for d in dtypes.ints:
+    for d in dtype_ints:
       t = Tensor([[1, 2], [3, 4]], dtype=d)
       (t*t).max().item()
 
@@ -1258,7 +1258,7 @@ class TestLinearizerFailures(unittest.TestCase):
                 UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
                 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
     opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16)]
-    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["CUDA", "NV", "METAL"])
+    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
 
   def test_failure_53(self):
     # COMPILE_ERROR, val scope issue
@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
 def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
     # cast for mypy, get_contraction won't be None
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
            get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
+    idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
       for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
 
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
 
   return IndexContext(idxs, ridxs)
 
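Every loop index and SPECIAL index the lowerer emits now starts life as pyint. The point of an arbitrary-precision index dtype is that index arithmetic is done with Python's unbounded ints during symbolic simplification, so intermediate terms cannot wrap; only the simplified result has to fit a fixed width. A plain-int illustration of that property (no UOps involved):

# the intermediate product exceeds the int32 range, but the folded result fits,
# so committing to int32 only after simplification loses nothing
big = (2**40) * 3 // (2**40)
assert big == 3 and -(2**31) <= big < 2**31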
@@ -545,6 +545,9 @@ reducer = PatternMatcher([
   (UPat(UOps.LOAD, name="load"), simplify_buffer_load),
 ])
 
+no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
+                  name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
+
 # *** uop graph ***
 
 linearize_cnt = 0
@@ -556,6 +559,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   acc_number = 0
   sink = graph_rewrite(sink, sym)
 
+  # rewrite pyint to int32
+  sink = graph_rewrite(sink, no_pyint)
+
   # expand
   linearize_cnt += 1
   if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
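These two hunks are the core of the change: a rewrite pass that retypes anything still carrying pyint to int32, run right after symbolic simplification so nothing downstream ever has to render the internal dtype. A self-contained sketch of that kind of pass (toy Node class, not the real UOp/PatternMatcher API):

from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class Node:
  op: str
  dtype: str
  src: Tuple["Node", ...] = ()

def drop_pyint(n: Node) -> Node:
  # retype this node and everything below it; the real code does this with a graph_rewrite
  new_src = tuple(drop_pyint(s) for s in n.src)
  return replace(n, dtype="int32" if n.dtype == "pyint" else n.dtype, src=new_src)

sink = Node("ALU", "pyint", (Node("RANGE", "pyint"), Node("CONST", "pyint")))
out = drop_pyint(sink)
assert out.dtype == "int32" and all(s.dtype == "int32" for s in out.src)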
@@ -84,6 +84,7 @@ class dtypes:
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   void: Final[DType] = DType(-1, 0, "void", None, 1)
+  pyint: Final[DType] = DType(-1, 8, "pyint", None, 1)  # arbitrary precision integer, same itemsize to int64 so min/max works
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
   uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
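The itemsize of 8 is deliberate: per the inline comment, pyint has to report the same min/max as int64 so reductions that need an identity value keep working. A quick sketch of how signed-integer bounds follow from itemsize (assumed formula matching the comment, not a quote of tinygrad's helper):

def int_bounds(itemsize: int):
  bits = itemsize * 8
  return -(2**(bits-1)), 2**(bits-1) - 1

assert int_bounds(8) == (-(2**63), 2**63 - 1)  # what pyint advertises
assert int_bounds(4) == (-(2**31), 2**31 - 1)  # int32, for comparison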
@@ -115,7 +116,7 @@ class dtypes:
 
   floats = (float16, bfloat16, float32, float64)
   uints = (uint8, uint16, uint32, uint64)
-  sints = (int8, int16, int32, int64)
+  sints = (int8, int16, int32, int64, pyint)
   ints = uints + sints
 
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
@@ -691,6 +691,9 @@ spec = PatternMatcher([
   (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
   (UPat(UOps.SPECIAL, src=()), lambda: True),
 
+  # no pyint allowed here!
+  (UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),
+
   # TODO: confirm the args of both of these are shapetrackers
   (UPat(UOps.VIEW, src=()), lambda: True),
   (UPat(UOps.VIEW, src=(UPat(),)), lambda: True),
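The verifier rule makes the contract explicit: after full_graph_rewrite, no ALU op may still be typed pyint. A standalone stand-in for that check (plain tuples instead of UOps):

def graph_ok(nodes):
  # nodes are (op, dtype) pairs of an already-lowered graph
  return not any(op == "ALU" and dtype == "pyint" for op, dtype in nodes)

assert graph_ok([("CONST", "int32"), ("ALU", "int32")])
assert not graph_ok([("ALU", "pyint")])  # pyint must be rewritten away before this point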
@@ -82,7 +82,7 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
     offs -= here * stride
   return result
 
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
 
 @dataclass(frozen=True)
 class View:
@@ -93,7 +93,7 @@ class View:
   contiguous:bool
 
   def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
     iexpr = variable_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
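to_indexed_uops builds the flat index as offset + sum(idx*stride), skipping dimensions whose extent is 1 or whose stride is 0 (the resolve(...) guard above). The same arithmetic with plain ints, as a sanity check of what the UOp expression computes:

def flat_index(idxs, shape, strides, offset=0):
  iexpr = offset
  for idx, sh, st in zip(idxs, shape, strides):
    if sh != 1 and st != 0: iexpr = iexpr + idx * st
  return iexpr

# a (4, 1, 3) view with a broadcast middle dim: strides (3, 0, 1)
assert flat_index((2, 0, 1), (4, 1, 3), (3, 0, 1)) == 2*3 + 1*1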