mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
unwrap instead of cast [pr] (#12982)
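The change below swaps typing.cast(UOp, ...) calls, which only silence the type checker, for unwrap from tinygrad.helpers, which narrows an Optional value with a runtime assertion. A minimal sketch of such a helper, assuming it is a plain assert-and-return (the real tinygrad.helpers.unwrap may differ in detail):

from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap(x: Optional[T]) -> T:
  # fail loudly at runtime instead of only telling the type checker
  assert x is not None
  return x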
@@ -1,5 +1,4 @@
 import unittest
-from typing import List, cast
 import numpy as np
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
@@ -15,15 +14,15 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.codegen import full_rewrite
 from tinygrad.engine.realize import lower_schedule_item
 
-def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
   for x in inputs: x.realize()
   # NOTE: we only toposort the stores
-  uops: List[UOp] = []
-  def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
+  uops: list[UOp] = []
+  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
   uops = dedup(flatten(_recursive_add(st) for st in stores))
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
     initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
-  inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
+  inbufs = [x.uop.base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render(uops)
   ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
     src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
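Two things change in this test helper: List[...] annotations become the built-in list[...] generics (valid on Python 3.9+, so the typing.List import above can be dropped), and the cast(UOp, x.uop) disappears, presumably because Tensor.uop is already annotated as UOp. An illustrative, self-contained comparison of the two annotation styles (old_style/new_style are made-up names):

from typing import List  # only needed for the legacy spelling

def old_style(xs: List[int]) -> List[int]: return xs
def new_style(xs: list[int]) -> list[int]: return xs  # PEP 585 built-in generic

assert old_style([1, 2]) == new_style([1, 2])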
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import flatten, get_single_element, prod
+from tinygrad.helpers import flatten, get_single_element, prod, unwrap
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -181,7 +181,7 @@ class PTXRenderer(Renderer):
 
     def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
       nonlocal c, r
-      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype.base]}_"
+      prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_"
       c[prefix] += 1
       return f"%{prefix}{c[prefix]-1}"
 
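In ssa, u is declared UOp|None, but the branch that indexes self.types is only reached when a UOp was actually passed. The old cast(UOp, u) encoded that assumption without checking it; unwrap(u) keeps the type narrowing and raises if the assumption is ever violated. A small stand-alone illustration of the difference (maybe/unwrap here are local stand-ins, not tinygrad code):

from typing import Optional, TypeVar, cast

T = TypeVar("T")

def unwrap(x: Optional[T]) -> T:
  assert x is not None
  return x

maybe: Optional[str] = None
silently_wrong = cast(str, maybe)   # passes the type checker, but is still None
try:
  unwrap(maybe)                     # fails right here instead of later
except AssertionError:
  print("unwrap caught the None that cast would have let through")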
@@ -1,5 +1,4 @@
 # all of symbolic lives here now
-from typing import cast
 import math, operator, struct, functools
 from collections import defaultdict
 from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
@@ -131,7 +130,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
 def lt_folding(x:UOp, c:int) -> UOp|None:
   p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
-    return cast(UOp, UOp.sum(*np).divides(d))<(c//d)
+    return unwrap(UOp.sum(*np).divides(d))<(c//d)
   return None
 
 def canonicalize_simplex(X:UOp) -> UOp|None:
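UOp.divides presumably returns None when the divisor does not divide a term exactly; the gcd guard in lt_folding ensures d divides every term of np, so the result can never be None on this path. unwrap turns that invariant into a runtime assert rather than an unchecked cast. A toy model of the pattern (divides and unwrap below are stand-ins, not the tinygrad implementations):

import math

def unwrap(x):
  assert x is not None
  return x

def divides(value: int, d: int) -> int | None:
  # same Optional contract as assumed above: None when division is not exact
  return value // d if value % d == 0 else None

terms = [6, 9, 12]
d = math.gcd(*terms)                       # 3, so every term divides evenly
assert [unwrap(divides(t, d)) for t in terms] == [2, 3, 4]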