mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import variable_to_uop
+from tinygrad.shape.view import sint_to_uop
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
 from tinygrad.renderer import Renderer
@@ -76,11 +76,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
            get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
-            for i,g in enumerate(full_shape[:first_reduce])]
+    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
@@ -91,7 +90,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), (1000+a, True))
 
   return IndexContext(idxs, ridxs)
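For readers following the hunks above: `get_index` builds one `Ops.RANGE` UOp per loop, and the boolean in each RANGE's arg tuple marks whether that axis is a reduce loop. A minimal pure-Python sketch of that enumeration, using plain `(axis, end, is_reduce)` tuples as stand-ins for real `UOp(Ops.RANGE, ...)` nodes (`enumerate_ranges` is a hypothetical helper for illustration, not tinygrad API):

```python
def enumerate_ranges(full_shape, first_reduce, group_for_reduces, first_upcasted):
  # parallel loops: one range per output dimension before the first reduce
  idxs = [(i, g, False) for i, g in enumerate(full_shape[:first_reduce])]
  # reduce loops: axes between the grouped reduces and the first upcasted dim
  idxs += [(i, g, True) for i, g in enumerate(
      full_shape[first_reduce+group_for_reduces:first_upcasted],
      start=first_reduce+group_for_reduces)]
  return idxs

print(enumerate_ranges((4, 8, 16), first_reduce=1, group_for_reduces=0, first_upcasted=3))
# [(0, 4, False), (1, 8, True), (2, 16, True)]
```

The `(1000+a, True)` ridxs in the second hunk follow the same pattern, just with offset axis ids so the late group-for-reduce ranges don't collide with the originals.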
@@ -74,15 +74,14 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
   return tuple(reversed(new_mask))
 
 def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
-  strides = strides_for_shape(shape)
   result = []
-  for stride in strides:
+  for stride in strides_for_shape(shape):
     here = offs // stride if stride != 0 else 0
     result.append(here)
     offs -= here * stride
   return result
 
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
 
 @dataclass(frozen=True)
 class View:
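Two things happen in this hunk: `un1d` inlines its strides lookup, and the renamed `sint_to_uop` keeps the same behavior as `variable_to_uop`: wrap a Python int in `UOp.const(dtypes.int, x)`, pass a symbolic UOp through unchanged. `un1d` itself inverts row-major flattening by peeling one coordinate off per stride. A self-contained sketch under the assumption of a simplified local `strides_for_shape` (the real tinygrad version also zeroes strides for size-1 dims):

```python
from typing import List, Tuple

def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
  # plain row-major strides, innermost dim has stride 1
  strides, acc = [], 1
  for s in reversed(shape):
    strides.append(acc)
    acc *= s
  return tuple(reversed(strides))

def un1d(shape: Tuple[int, ...], offs: int) -> List[int]:
  # peel off one coordinate per stride, largest stride first
  result = []
  for stride in strides_for_shape(shape):
    here = offs // stride if stride != 0 else 0
    result.append(here)
    offs -= here * stride
  return result

assert strides_for_shape((3, 4, 5)) == (20, 5, 1)
assert un1d((3, 4, 5), 37) == [1, 3, 2]   # 1*20 + 3*5 + 2*1 == 37
```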
@@ -100,7 +99,7 @@ class View:
 
   def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
     idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
-    iexpr = variable_to_uop(self.offset)
+    iexpr = sint_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
       if m is not None:
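`to_indexed_uops` accumulates `offset + sum(idx*stride)` over every dimension that isn't broadcast away, and ANDs each mask's bounds into a validity expression. A sketch of the same accumulation over plain ints (`indexed` is a hypothetical stand-in; the real method returns symbolic UOp expressions and uses `resolve` to evaluate the `sh != 1` / `st != 0` comparisons, which may themselves be symbolic):

```python
from typing import List, Optional, Tuple

def indexed(shape, strides, offset, mask, idxs: List[int]) -> Tuple[int, bool]:
  # mirrors the iexpr/vexpr accumulation: offset + sum(idx*stride), plus mask validity
  iexpr, valid = offset, True
  masks = mask if mask is not None else [None] * len(shape)
  for idx, sh, st, m in zip(idxs, shape, strides, masks):
    if sh != 1 and st != 0: iexpr += idx * st
    if m is not None: valid = valid and (m[0] <= idx < m[1])
  return iexpr, valid

# a 2x3 view with offset 1 and only the first column valid in the second dim
print(indexed((2, 3), (3, 1), 1, ((0, 2), (0, 1)), [1, 2]))  # (6, False)
```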