mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import functools, itertools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -55,14 +54,11 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
||||
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
|
||||
full_shape = ast.full_shape
|
||||
first_upcasted = len(full_shape)-ki.upcasted
|
||||
first_output_st: ShapeTracker = ast.src[0].st_arg
|
||||
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
|
||||
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
|
||||
local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
|
||||
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
|
||||
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
|
||||
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
|
||||
first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
|
||||
group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
|
||||
global_dims = first_reduce-ki.local_dims
|
||||
|
||||
if opts.has_local:
|
||||
|
||||
Reference in New Issue
Block a user