unwrap instead of cast [pr] (#12982)

Author: chenyu
Date:   2025-10-28 21:29:23 -04:00
Committed by: GitHub
Commit: ef16e6c68c (parent f55fcfecf9)

3 changed files with 7 additions and 9 deletions
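For context on the change itself: typing.cast only silences the type checker and does nothing at runtime, while tinygrad's unwrap (imported from tinygrad.helpers in the PTX diff below) also asserts that the value is not None. A minimal sketch of the difference, assuming unwrap has the usual Optional[T] -> T shape implied by its call sites in this diff:

from typing import Optional, TypeVar, cast

T = TypeVar("T")

def unwrap(x: Optional[T]) -> T:
  # assumed shape of tinygrad.helpers.unwrap: fail loudly instead of lying to the type checker
  assert x is not None
  return x

maybe: Optional[int] = None
a = cast(int, maybe)   # cast is erased at runtime: a is silently None despite its int type
try:
  b = unwrap(maybe)    # unwrap checks at runtime and raises right at the bad value
except AssertionError:
  print("unwrap caught the None immediately")

So the commit trades unchecked type assertions for checked ones; where the cast could be dropped entirely (the x.uop case in the test helper below), it simply was.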


@@ -1,5 +1,4 @@
 import unittest
-from typing import List, cast
 import numpy as np
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
@@ -15,15 +14,15 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.codegen import full_rewrite
 from tinygrad.engine.realize import lower_schedule_item
 
-def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
   for x in inputs: x.realize()
   # NOTE: we only toposort the stores
-  uops: List[UOp] = []
-  def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
+  uops: list[UOp] = []
+  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
   uops = dedup(flatten(_recursive_add(st) for st in stores))
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
     initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
-  inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
+  inbufs = [x.uop.base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render(uops)
   ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
                                   src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))


@@ -6,7 +6,7 @@ from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import flatten, get_single_element, prod
+from tinygrad.helpers import flatten, get_single_element, prod, unwrap
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -181,7 +181,7 @@ class PTXRenderer(Renderer):
     def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
       nonlocal c, r
-      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype.base]}_"
+      prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_"
       c[prefix] += 1
       return f"%{prefix}{c[prefix]-1}"


@@ -1,5 +1,4 @@
 # all of symbolic lives here now
-from typing import cast
 import math, operator, struct, functools
 from collections import defaultdict
 from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
@@ -131,7 +130,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
 def lt_folding(x:UOp, c:int) -> UOp|None:
   p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
-    return cast(UOp, UOp.sum(*np).divides(d))<(c//d)
+    return unwrap(UOp.sum(*np).divides(d))<(c//d)
   return None
 
 def canonicalize_simplex(X:UOp) -> UOp|None:
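A note on the lt_folding change: UOp.divides returns None when the division is not exact (which is why the old code needed cast at all), but d is the gcd of the constant factors of the np terms and of c, so the division succeeds by construction. unwrap keeps the type checker satisfied and also enforces that invariant at runtime. A small integer analogue, with divides as a hypothetical stand-in for the UOp method:

import math
from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap(x: Optional[T]) -> T:
  # same assumed shape as tinygrad.helpers.unwrap
  assert x is not None
  return x

def divides(x:int, v:int) -> Optional[int]:
  # hypothetical stand-in for UOp.divides: exact quotient, or None when v does not divide x
  return x // v if x % v == 0 else None

factors = [6, 9, 12]
d = math.gcd(*factors)                                # 3
quotients = [unwrap(divides(f, d)) for f in factors]  # [2, 3, 4]: the gcd guarantees exactness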