From f511ad9103e3f8948d510459f374eaca32649b33 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 19 Oct 2024 13:48:59 -0400
Subject: [PATCH] No pyint again (#7156)

* Revert "bring back pyint (#7150)"

This reverts commit 37e83ca6fceac48714061263fb71b4f5f57219e6.

* remove truncate in const folding

* truncate_output=False
---
 test/helpers.py                  |  1 -
 test/test_dtype.py               |  6 +++---
 test/test_linearizer_failures.py |  2 +-
 test/test_uops.py                |  3 +++
 tinygrad/codegen/lowerer.py      | 10 +++++-----
 tinygrad/codegen/uopgraph.py     |  6 ------
 tinygrad/dtype.py                |  3 +--
 tinygrad/ops.py                  | 10 ++++------
 tinygrad/shape/view.py           |  4 ++--
 9 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/test/helpers.py b/test/helpers.py
index 0c593aa4a3..ca75dbdee5 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -29,7 +29,6 @@ def assert_jit_cache_len(fxn, expected_len):
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
 
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
-  if dtype == dtypes.pyint and device != "PYTHON": return False
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
     return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
diff --git a/test/test_dtype.py b/test/test_dtype.py
index a88f2555eb..afd4a99e8e 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -14,7 +14,7 @@ pytestmark = pytest.mark.filterwarnings("ignore")
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
 
-core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint'])
+core_dtypes = list(DTYPES_DICT.values())
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
 dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
 dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
@@ -22,7 +22,7 @@ dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_sup
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []
   # dont cast internal dtypes
-  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint']
+  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
 
 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)
@@ -806,7 +806,7 @@ class TestTensorMethod(unittest.TestCase):
 
 class TestDtypeUsage(unittest.TestCase):
   def test_max_w_alu(self):
-    for d in dtype_ints:
+    for d in dtypes.ints:
       t = Tensor([[1, 2], [3, 4]], dtype=d)
       (t*t).max().item()
 
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index b818ab404b..ece7d43787 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -1315,7 +1315,7 @@ class TestLinearizerFailures(unittest.TestCase):
             UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
             UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
     opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)]
-    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD", "METAL"])
+    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"])
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_uops.py b/test/test_uops.py
index 037121bcaa..9cf3053290 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -230,6 +230,9 @@ class TestExecALU(TestUOps):
     self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
     self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
 
+    # test no truncate
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
+
 class TestConstantFolding(unittest.TestCase):
   def test_cast_const(self):
     t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index eb2c5cd51c..941e19d270 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
 def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
     # cast for mypy, get_contraction won't be None
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
       get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
+    idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
       for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
 
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
 
   return IndexContext(idxs, ridxs)
 
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 671022cb0e..ec84df4935 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -545,9 +545,6 @@ reducer = PatternMatcher([
   (UPat(UOps.LOAD, name="load"), simplify_buffer_load),
 ])
 
-no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
-  name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
-
 # *** uop graph ***
 
 linearize_cnt = 0
@@ -559,9 +556,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   acc_number = 0
   sink = graph_rewrite(sink, sym)
 
-  # rewrite pyint to int32
-  sink = graph_rewrite(sink, no_pyint)
-
   # expand
   linearize_cnt += 1
   if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index e39767d22c..19f314226e 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -84,7 +84,6 @@ class dtypes:
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   void: Final[DType] = DType(-1, 0, "void", None, 1)
-  pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
   uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
@@ -116,7 +115,7 @@ class dtypes:
 
   floats = (float16, bfloat16, float32, float64)
   uints = (uint8, uint16, uint32, uint64)
-  sints = (int8, int16, int32, int64, pyint)
+  sints = (int8, int16, int32, int64)
   ints = uints + sints
 
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index f1a342d5c1..be439bf9e6 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -411,10 +411,11 @@ python_alu: Dict[Op, Callable] = {
   BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
   TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
-def exec_alu(op:Op, dtype:DType, operands):
+def exec_alu(op:Op, dtype:DType, operands, truncate_output=True):
   if dtype.count > 1:
     return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
-  return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
+  alu = python_alu[op](*operands)
+  return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
 
 # ***** uop helpers *****
 
@@ -691,9 +692,6 @@ spec = PatternMatcher([
   (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
   (UPat(UOps.SPECIAL, src=()), lambda: True),
 
-  # no pyint allowed here!
-  (UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),
-
   # TODO: confirm the args of both of these are shapetrackers
   (UPat(UOps.VIEW, src=()), lambda: True),
   (UPat(UOps.VIEW, src=(UPat(),)), lambda: True),
@@ -906,7 +904,7 @@ symbolic = PatternMatcher([
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
   # ** constant folding **
   (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
-   lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
+   lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
   # ALU min==max -> CONST (slow!)
   (UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
   # max folding
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index fc85ecfd68..2542b99c06 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -82,7 +82,7 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
     offs -= here * stride
   return result
 
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
 
 @dataclass(frozen=True)
 class View:
@@ -93,7 +93,7 @@ class View:
   contiguous:bool
 
   def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-    idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
     iexpr = variable_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
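
Note (not part of the patch): a minimal sketch of the truncate_output flag this patch adds to exec_alu, assuming the import paths below match the tree the patch applies to. By default exec_alu still wraps results to the dtype's range via the truncate table; the constant-folding rule in tinygrad/ops.py now passes truncate_output=False so folding sees the exact Python value, which is what the new test in test/test_uops.py asserts.

from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, exec_alu

# default behavior: the uint8 result wraps modulo 256, same as before this patch
print(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)))                         # 244
# truncate_output=False returns the raw Python integer, matching the new test above
print(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False))  # 500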