mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
delete ShapeTracker to_valid_uop and substitute [pr] (#12563)
This commit is contained in:
@@ -3,14 +3,14 @@ import unittest
|
||||
import numpy as np
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, views_to_valid_uop
|
||||
from tinygrad import Variable
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite
|
||||
from tinygrad.codegen.late.devectorizer import sym
|
||||
from itertools import product
|
||||
|
||||
def shapetracker_getitem(st:ShapeTracker, val:int):
|
||||
valid_idx = st.reshape((st.size,)).to_valid_uop([UOp.const(dtypes.int, val)])
|
||||
valid_idx = views_to_valid_uop(st.reshape((st.size,)).views, (UOp.const(dtypes.int, val),))
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
|
||||
assert idx.op is Ops.CONST and valid.op is Ops.CONST
|
||||
|
||||
@@ -53,9 +53,6 @@ class ShapeTracker:
|
||||
@property
|
||||
def size(self) -> int: return self.views[-1].size()
|
||||
|
||||
def to_valid_uop(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> UOp:
|
||||
return views_to_valid_uop(self.views, tuple(_idxs) if _idxs is not None else None)
|
||||
|
||||
def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
|
||||
|
||||
@property
|
||||
@@ -65,7 +62,6 @@ class ShapeTracker:
|
||||
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
||||
if all(len(x) == 0 for x in var_vals): return self, {}
|
||||
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
||||
def substitute(self, dvars:dict[UOp, UOp]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))
|
||||
|
||||
def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]:
|
||||
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
|
||||
|
||||
Reference in New Issue
Block a user