diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index 8092914be7..9e0559d44f 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -1,5 +1,4 @@
 import unittest
-from typing import List, cast
 import numpy as np
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
@@ -15,15 +14,15 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.codegen import full_rewrite
 from tinygrad.engine.realize import lower_schedule_item
 
-def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
   for x in inputs: x.realize()
   # NOTE: we only toposort the stores
-  uops: List[UOp] = []
-  def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
+  uops: list[UOp] = []
+  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
   uops = dedup(flatten(_recursive_add(st) for st in stores))
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
     initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
-  inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
+  inbufs = [x.uop.base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render(uops)
   ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, uops[-1], uops=uops,
                                   local_size=local_size))
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index b94ff21d61..e61fc3eda1 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import flatten, get_single_element, prod
+from tinygrad.helpers import flatten, get_single_element, prod, unwrap
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -181,7 +181,7 @@ class PTXRenderer(Renderer):
 
     def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
       nonlocal c, r
-      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype.base]}_"
+      prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_"
       c[prefix] += 1
       return f"%{prefix}{c[prefix]-1}"
 
diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py
index f22ac72c70..b22590e3a3 100644
--- a/tinygrad/uop/symbolic.py
+++ b/tinygrad/uop/symbolic.py
@@ -1,5 +1,4 @@
 # all of symbolic lives here now
-from typing import cast
 import math, operator, struct, functools
 from collections import defaultdict
 from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
@@ -131,7 +130,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
 def lt_folding(x:UOp, c:int) -> UOp|None:
   p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
-    return cast(UOp, UOp.sum(*np).divides(d))<(c//d)
+    return unwrap(UOp.sum(*np).divides(d))<(c//d)
   return None
 
 def canonicalize_simplex(X:UOp) -> UOp|None:
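For context on the substitution this diff makes everywhere: `typing.cast` is a typing-only no-op, so `cast(UOp, x)` silently passes `None` through at runtime, while tinygrad's `unwrap` helper both narrows the type for the checker and asserts the value is present. The sketch below is an illustration of that difference, not the library source; the stand-in `unwrap` here is an assumption about what `tinygrad.helpers.unwrap` does (roughly an `assert x is not None; return x`), and the variable names are made up for the example.

```python
from typing import Optional, TypeVar, cast

T = TypeVar("T")

# Stand-in for tinygrad.helpers.unwrap (assumed behavior): narrow Optional[T]
# to T and actually check the assumption at runtime.
def unwrap(x: Optional[T]) -> T:
  assert x is not None, "unwrap of None"
  return x

maybe: Optional[int] = 42
a = cast(int, maybe)   # typing-only: no runtime check, a wrong cast is never caught
b = unwrap(maybe)      # runtime assert: fails loudly if the value is actually None
print(a, b)            # -> 42 42
```

The practical effect in the three touched call sites is the same code path when the value is non-None, but a clear `AssertionError` instead of a later `AttributeError` (e.g. on `.dtype.base` or `.divides(d)`) when it is not.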