improve typing (#12611)

* improve typing and bump to 3.11

* no need for Self yet

* improve typing

* binop also
This commit is contained in:
George Hotz
2025-10-11 16:20:23 +08:00
committed by GitHub
parent 4300ebc455
commit cab034b863
5 changed files with 15 additions and 10 deletions

View File

@@ -50,4 +50,7 @@ exclude = [
"E303", "E304", "E501", "E702", "E703", "E731", "W191",
"W291", "W293", "UP039", "C416", "RET506", "RET507", "A",
"FURB110", "RUF018", "F541", "F841"
]
]
[format]
exclude = ["*"]

View File

@@ -1526,7 +1526,7 @@ class TestSchedule(unittest.TestCase):
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-6)
b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-5)
def test_shrink_pad_safe(self):
a = Tensor.ones((3, )).contiguous().realize()

View File

@@ -39,7 +39,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)),
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda ctx: (None,)),
(UPat(Ops.BITCAST), lambda: (None,)),
])
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:

View File

@@ -96,7 +96,7 @@ def suppress_finalizing(func):
if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
return wrapper
def unwrap_class_type(cls_t:T): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t
def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t
def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')

View File

@@ -1,15 +1,17 @@
from typing import TypeVar
from tinygrad.uop import Ops
from tinygrad.helpers import T
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, ConstType
TMathTrait = TypeVar("TMathTrait", bound="MathTrait")
class MathTrait:
# required to implement
def alu(self:T, op:Ops, *src) -> T: raise NotImplementedError
def const_like(self:T, b) -> T: raise NotImplementedError
def alu(self:TMathTrait, op:Ops, *src:TMathTrait) -> TMathTrait: raise NotImplementedError
def const_like(self:TMathTrait, b:ConstType) -> TMathTrait: raise NotImplementedError
# great functions you get!
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def ufix(self:TMathTrait, x:ConstType|TMathTrait) -> TMathTrait: return self.const_like(x) if not isinstance(x, MathTrait) else x
def _binop(self:TMathTrait, op:Ops, x:TMathTrait|ConstType, reverse:bool) -> TMathTrait:
return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def logical_not(self): return self.ne(True)
def neg(self):
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")