diff --git a/test/test_vmap.py b/test/test_vmap.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index f23445578a..91cafbf48d 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -2,7 +2,8 @@ from __future__ import annotations import math, itertools from collections import defaultdict from typing import cast, Final -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp +from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos from tinygrad.device import Buffer from tinygrad.dtype import dtypes, ImageDType from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten @@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) -# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters -axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3, - AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5} - class Scheduler: def __init__(self, ast:UOp, ren:Renderer): self.ast, self.ren = ast, ren diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index c4fd7a3a4c..883fd1452b 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -469,7 +469,7 @@ pm_add_range_tags = PatternMatcher([ ]) def split_store(ctx:list[UOp], x:UOp) -> UOp|None: - if len(x.ranges): return None + if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None # local kernel rewrite lctx = LocalAddBufferContext() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 8100c07728..600ffe1fa4 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -14,11 +14,16 @@ if TYPE_CHECKING: class AxisType(Enum): def __repr__(self): return str(self) GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702 - THREAD = auto() + THREAD = auto(); OUTER = auto() # noqa: E702 axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u", - AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"} + AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"} axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", - AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"} + AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta", + AxisType.OUTER: "green"} + +# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters +axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3, + AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2} range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}