variable_to_uop -> sint_to_uop [pr] (#7847)

and added a type annotation to it
Author: chenyu
Date: 2024-11-22 10:54:59 -05:00
Committed by: GitHub
Parent: 40d7535eeb
Commit: f6d1201c48
2 changed files with 7 additions and 9 deletions

tinygrad/codegen/lowerer.py

@@ -4,7 +4,7 @@ import functools, itertools, operator
 from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import variable_to_uop
+from tinygrad.shape.view import sint_to_uop
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
 from tinygrad.renderer import Renderer
@@ -76,11 +76,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
              get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
-      for i,g in enumerate(full_shape[:first_reduce])]
+    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])]

   # reduce loops
-  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

   # upcast loops
@@ -91,7 +90,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), (1000+a, True))

   return IndexContext(idxs, ridxs)

tinygrad/shape/view.py

@@ -74,15 +74,14 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
   return tuple(reversed(new_mask))

 def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
-  strides = strides_for_shape(shape)
   result = []
-  for stride in strides:
+  for stride in strides_for_shape(shape):
     here = offs // stride if stride != 0 else 0
     result.append(here)
     offs -= here * stride
   return result

-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x

 @dataclass(frozen=True)
 class View:
@@ -100,7 +99,7 @@ class View:
   def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
     idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
-    iexpr = variable_to_uop(self.offset)
+    iexpr = sint_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
       if m is not None:
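
For reference, a minimal sketch of what the renamed helper does at its call sites. The sample values are illustrative only; the imports and names are the ones used in the diff above.

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.shape.view import sint_to_uop

# a plain Python int is wrapped into a constant int UOp ...
c = sint_to_uop(16)
assert c.dtype == dtypes.int and c.arg == 16

# ... while a value that is already a UOp (the symbolic case of sint) is returned unchanged
u = UOp.const(dtypes.int, 3)
assert sint_to_uop(u) is u

This is why both bounds of a RANGE can now go through the same helper: a literal 0 and a shape dimension g, whether concrete or symbolic, both come back as UOps.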