From cab034b8633adb0d4a274f3dfdd9d557cf638c61 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 11 Oct 2025 16:20:23 +0800 Subject: [PATCH] improve typing (#12611) * improve typing and bump to 3.11 * no need for Self yet * improve typing * binop also --- ruff.toml | 5 ++++- test/test_schedule.py | 2 +- tinygrad/gradient.py | 2 +- tinygrad/helpers.py | 2 +- tinygrad/uop/mathtraits.py | 14 ++++++++------ 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/ruff.toml b/ruff.toml index 22f9bf566b..0d5b7cb8f0 100644 --- a/ruff.toml +++ b/ruff.toml @@ -50,4 +50,7 @@ exclude = [ "E303", "E304", "E501", "E702", "E703", "E731", "W191", "W291", "W293", "UP039", "C416", "RET506", "RET507", "A", "FURB110", "RUF018", "F541", "F841" -] \ No newline at end of file +] + +[format] +exclude = ["*"] diff --git a/test/test_schedule.py b/test/test_schedule.py index 4817272317..7761a0ca95 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1526,7 +1526,7 @@ class TestSchedule(unittest.TestCase): # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 4)) np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \ - b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-6) + b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-5) def test_shrink_pad_safe(self): a = Tensor.ones((3, )).contiguous().realize() diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index c538555ad2..3d68868fdb 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -39,7 +39,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)), (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src), # there's no gradient for bitcast - (UPat(Ops.BITCAST), lambda ctx: (None,)), + (UPat(Ops.BITCAST), lambda: (None,)), ]) def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index d9c21933d6..c2f37a8738 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -96,7 +96,7 @@ def suppress_finalizing(func): if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing return wrapper -def unwrap_class_type(cls_t:T): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t +def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's') diff --git a/tinygrad/uop/mathtraits.py b/tinygrad/uop/mathtraits.py index 0de976c90b..2da0ea887a 100644 --- a/tinygrad/uop/mathtraits.py +++ b/tinygrad/uop/mathtraits.py @@ -1,15 +1,17 @@ +from typing import TypeVar from tinygrad.uop import Ops -from tinygrad.helpers import T -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, ConstType +TMathTrait = TypeVar("TMathTrait", bound="MathTrait") class MathTrait: # required to implement - def alu(self:T, op:Ops, *src) -> T: raise NotImplementedError - def const_like(self:T, b) -> T: raise NotImplementedError + def alu(self:TMathTrait, op:Ops, *src:TMathTrait) -> TMathTrait: raise NotImplementedError + def const_like(self:TMathTrait, b:ConstType) -> TMathTrait: raise NotImplementedError # great functions you get! - def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x - def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x)) + def ufix(self:TMathTrait, x:ConstType|TMathTrait) -> TMathTrait: return self.const_like(x) if not isinstance(x, MathTrait) else x + def _binop(self:TMathTrait, op:Ops, x:TMathTrait|ConstType, reverse:bool) -> TMathTrait: + return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x)) def logical_not(self): return self.ne(True) def neg(self): if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")