From 9513f025c5e1b581a6b10aa2b7b88822ce8ddbc3 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:16:31 +0300 Subject: [PATCH] apply multi before rangeify (#12298) * it doesn't realize it when i reshape * cleaner graph * map out * REDUCE_AXIS also gives the wrong answer * maybe * work * back here * try * more * refactor tests * check MultiBuffer * or copy * fine with this * don't need graph_rewrite_map in rangeify --- test/test_multitensor.py | 11 +++++++++++ tinygrad/schedule/multi.py | 5 ++++- tinygrad/schedule/rangeify.py | 8 +------- tinygrad/tensor.py | 5 +++++ 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 86159ba375..b688edde61 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -54,6 +54,17 @@ class TestMultiTensor(unittest.TestCase): assert lb.shape == (128,) (X + X).realize() + def _test_shard_op(self, op, out, n=4): + t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0) + r = op(t).realize() + assert t.uop.is_realized, "shard didn't realize" + self.assertEqual(r.tolist(), out) + def test_shard_reshape(self): self._test_shard_op(lambda t:t.reshape(2, 2), [[1.,1.],[1.,1.]]) + def test_shard_elementwise(self): self._test_shard_op(lambda t:(t+t).reshape(2, 2), [[2.,2.],[2.,2.]]) + def test_shard_reduce(self): + self._test_shard_op(lambda t:t.reshape(2, 3).sum(axis=1), [3.,3.], n=6) + self._test_shard_op(lambda t:t.reshape(2, 3).sum(axis=0), [2.,2.,2.], n=6) + def test_shard_not_multiple(self): X = Tensor.ones(256).contiguous().realize() with self.assertRaises(RuntimeError): diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 7c695db824..a1cfe18c1a 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,7 +1,7 @@ from typing import cast, TypeVar import functools, itertools, operator from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap -from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve +from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve, track_rewrites, graph_rewrite_map from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Device @@ -239,3 +239,6 @@ multi_pm = PatternMatcher([ (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), ])+replace_allreduce + +@track_rewrites() +def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return graph_rewrite_map(big_sink, multi_pm) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index d283bddd46..4581981850 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,11 +2,9 @@ from typing import Any, cast import functools, operator from dataclasses import dataclass, field from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify from tinygrad.uop.symbolic import sym, symbolic_simple from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup -from tinygrad.schedule.multi import multi_pm - from tinygrad.schedule.kernelize import Kernel from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType @@ -594,10 +592,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: uop_list: list[UOp] = [] tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops") - # HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi - msink = graph_rewrite_map(tsink, multi_pm, name="multi") - tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None}) - tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites") realize_map: dict[UOp, UOp] = {} graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph") diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fe15acb2ec..4ceb9576f3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -16,6 +16,7 @@ from tinygrad.engine.realize import run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars from tinygrad.schedule.rangeify import get_rangeify_map +from tinygrad.schedule.multi import get_multi_map from tinygrad.schedule.kernelize import get_kernelize_map # *** all in scope Tensors are here. this gets relevant UOps *** @@ -241,6 +242,10 @@ class Tensor(MathTrait): # verify Tensors match the spec if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec) + if RANGEIFY and any(isinstance(x._device, tuple) for x in big_sink.toposort()): + _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map") + big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst])) + becomes_map = get_rangeify_map(big_sink) if RANGEIFY else get_kernelize_map(big_sink) _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map") return self