diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index a9ded30e29..d0e5db725a 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import variable_to_uop
+from tinygrad.shape.view import sint_to_uop
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
 from tinygrad.renderer import Renderer
@@ -76,11 +76,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
       get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
-            for i,g in enumerate(full_shape[:first_reduce])]
+    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
@@ -91,7 +90,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), (1000+a, True))
 
   return IndexContext(idxs, ridxs)
 
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index d6c6489e53..1034147548 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -74,15 +74,14 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
   return tuple(reversed(new_mask))
 
 def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
-  strides = strides_for_shape(shape)
   result = []
-  for stride in strides:
+  for stride in strides_for_shape(shape):
     here = offs // stride if stride != 0 else 0
     result.append(here)
     offs -= here * stride
   return result
 
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
 
 @dataclass(frozen=True)
 class View:
@@ -100,7 +99,7 @@ class View:
   def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
     idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
-    iexpr = variable_to_uop(self.offset)
+    iexpr = sint_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
       if m is not None:
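
Note: a minimal sketch of the renamed helper's contract, using only names already imported in the diff above; the bound `g = 16` is a made-up value for illustration, not taken from the patch.

    from tinygrad.dtype import dtypes
    from tinygrad.ops import UOp, Ops, sint

    # sint_to_uop normalizes a "sint" (a concrete Python int or a symbolic UOp)
    # into a UOp, so callers like the Ops.RANGE construction in lowerer.py can
    # pass either form without special-casing.
    def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x

    g = 16  # concrete loop bound; a symbolic UOp would pass through unchanged
    rng = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (0, False))

The tighter `x:sint` signature (and dropping the unused `ctx=None`) documents that the helper is only meant for shape-typed values, which is what every call site in this patch passes.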