simplify group_for_reduces in get_index [pr] (#7851)

what was that
This commit is contained in:
chenyu
2024-11-22 11:53:21 -05:00
committed by GitHub
parent af5d77f684
commit a352a6938f

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import functools, itertools, operator
from dataclasses import dataclass
from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.renderer import Renderer
@@ -55,14 +54,11 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
full_shape = ast.full_shape
first_upcasted = len(full_shape)-ki.upcasted
first_output_st: ShapeTracker = ast.src[0].st_arg
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
global_dims = first_reduce-ki.local_dims
if opts.has_local: