mirror of https://github.com/tinygrad/tinygrad.git
improve typing (#12611)

* improve typing and bump to 3.11
* no need for Self yet
* improve typing
* binop also
@@ -50,4 +50,7 @@ exclude = [
   "E303", "E304", "E501", "E702", "E703", "E731", "W191",
   "W291", "W293", "UP039", "C416", "RET506", "RET507", "A",
   "FURB110", "RUF018", "F541", "F841"
 ]
+
+[format]
+exclude = ["*"]
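The added [format] table opts every file out of ruff's formatter. A hedged sanity check of that setting, assuming this hunk lives in a standalone ruff config file (so tables carry no tool.ruff prefix; the file name here is an assumption, not shown in the diff), and using tomllib, which is stdlib as of the Python 3.11 floor this PR adopts:

import tomllib  # stdlib in Python 3.11+, matching the version bump in this PR

# assumption: the hunk above is from a standalone ruff.toml
with open("ruff.toml", "rb") as f:
    cfg = tomllib.load(f)
assert cfg["format"]["exclude"] == ["*"]  # formatter disabled for every file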
@@ -1526,7 +1526,7 @@ class TestSchedule(unittest.TestCase):
     # run_schedule(check_schedule(out, 1))
     run_schedule(check_schedule(out, 4))
     np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
-      b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-6)
+      b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-5)
 
   def test_shrink_pad_safe(self):
     a = Tensor.ones((3, )).contiguous().realize()
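For context on the tolerance loosening: np.testing.assert_allclose passes when |actual - desired| <= atol + rtol * |desired|, so moving rtol from 1e-6 to 1e-5 admits roughly 10x more relative error. A minimal illustration (not tinygrad code):

import numpy as np

# assert_allclose passes when |actual - desired| <= atol + rtol * |desired|
np.testing.assert_allclose(1.0 + 5e-6, 1.0, atol=0.0, rtol=1e-5)  # within the new tolerance
try:
    np.testing.assert_allclose(1.0 + 5e-6, 1.0, atol=0.0, rtol=1e-6)  # old tolerance: too tight
except AssertionError:
    print("rtol=1e-6 rejects a 5e-6 relative error; rtol=1e-5 accepts it")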
@@ -39,7 +39,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # there's no gradient for bitcast
-  (UPat(Ops.BITCAST), lambda ctx: (None,)),
+  (UPat(Ops.BITCAST), lambda: (None,)),
 ])
 
 def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
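The BITCAST change drops a ctx parameter the lambda never used. This is possible because the PatternMatcher binds callback arguments by name, passing ctx and any named UPats only when the callback declares them. A toy sketch of that by-name dispatch idea (illustrative only, not tinygrad's actual PatternMatcher internals):

import inspect

def call_rewrite(fxn, ctx, bound: dict):
    # pass only the parameters the callback actually declares
    params = inspect.signature(fxn).parameters
    kwargs = {k: v for k, v in bound.items() if k in params}
    if "ctx" in params: kwargs["ctx"] = ctx
    return fxn(**kwargs)

print(call_rewrite(lambda ctx: (None,), ctx="grad-ctx", bound={}))  # ctx requested and passed
print(call_rewrite(lambda: (None,), ctx="grad-ctx", bound={}))      # ctx not requested, not passed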
@@ -96,7 +96,7 @@ def suppress_finalizing(func):
       if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
   return wrapper
 
-def unwrap_class_type(cls_t:T): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t
+def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.partial) else cls_t
 
 def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
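The unwrap_class_type change drops a bare TypeVar annotation, presumably because a lone T on a single parameter relates to nothing else in the signature and so tells the type checker nothing. For the untouched pluralize helper above, a quick self-contained check of its behavior:

# pluralize copied verbatim from the hunk above
def pluralize(st: str, cnt: int): return f"{cnt} {st}" + ('' if cnt == 1 else 's')

assert pluralize("file", 1) == "1 file"
assert pluralize("file", 3) == "3 files"
assert pluralize("pass", 0) == "0 passs"  # naive pluralization: it just appends 's'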
@@ -1,15 +1,17 @@
+from typing import TypeVar
 from tinygrad.uop import Ops
-from tinygrad.helpers import T
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, ConstType
 
+TMathTrait = TypeVar("TMathTrait", bound="MathTrait")
 class MathTrait:
   # required to implement
-  def alu(self:T, op:Ops, *src) -> T: raise NotImplementedError
-  def const_like(self:T, b) -> T: raise NotImplementedError
+  def alu(self:TMathTrait, op:Ops, *src:TMathTrait) -> TMathTrait: raise NotImplementedError
+  def const_like(self:TMathTrait, b:ConstType) -> TMathTrait: raise NotImplementedError
 
   # great functions you get!
-  def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
-  def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
+  def ufix(self:TMathTrait, x:ConstType|TMathTrait) -> TMathTrait: return self.const_like(x) if not isinstance(x, MathTrait) else x
+  def _binop(self:TMathTrait, op:Ops, x:TMathTrait|ConstType, reverse:bool) -> TMathTrait:
+    return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
   def logical_not(self): return self.ne(True)
   def neg(self):
     if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
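Per the commit bullet "no need for Self yet", this hunk uses a TypeVar bound to MathTrait rather than typing.Self (which 3.11 would also allow). A minimal sketch, with hypothetical classes outside tinygrad, of why the bound matters: methods annotated this way return the subclass type to the checker, not the base class:

from typing import TypeVar

# a TypeVar bound to the base class, mirroring TMathTrait above
TNum = TypeVar("TNum", bound="Num")

class Num:
    def __init__(self, v: float): self.v = v
    def const_like(self: TNum, b: float) -> TNum: return type(self)(b)
    def double(self: TNum) -> TNum: return self.const_like(self.v * 2)

class UnitNum(Num): pass

x = UnitNum(3.0).double()     # type checkers infer UnitNum here, not Num
print(type(x).__name__, x.v)  # UnitNum 6.0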