apply multi before rangeify (#12298)

* it doesn't realize it when I reshape

* cleaner graph

* map out

* REDUCE_AXIS also gives the wrong answer

* maybe

* work

* back here

* try

* more

* refactor tests

* check MultiBuffer

* or copy

* fine with this

* don't need graph_rewrite_map in rangeify
This commit is contained in:
qazal
2025-09-29 14:16:31 +03:00
committed by GitHub
parent b899392f30
commit 9513f025c5
4 changed files with 21 additions and 8 deletions

View File

@@ -54,6 +54,17 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
def _test_shard_op(self, op, out, n=4):
    """Shard an n-element ones tensor over devices_2, apply `op`, and check the result.

    Also asserts the sharded source tensor is actually realized (the bug this
    helper guards against is sharding silently skipping realization).
    """
    source = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
    result = op(source).realize()
    assert source.uop.is_realized, "shard didn't realize"
    self.assertEqual(result.tolist(), out)
def test_shard_reshape(self):
    # A plain reshape on a sharded tensor should still produce the right values.
    reshape_op = lambda t: t.reshape(2, 2)
    self._test_shard_op(reshape_op, [[1., 1.], [1., 1.]])
def test_shard_elementwise(self):
    # Elementwise add on a sharded tensor, followed by a reshape of the result.
    add_then_reshape = lambda t: (t + t).reshape(2, 2)
    self._test_shard_op(add_then_reshape, [[2., 2.], [2., 2.]])
def test_shard_reduce(self):
    # Reduce a sharded 6-element tensor reshaped to (2, 3) along each axis.
    self._test_shard_op(lambda t: t.reshape(2, 3).sum(axis=1), [3., 3.], n=6)
    self._test_shard_op(lambda t: t.reshape(2, 3).sum(axis=0), [2., 2., 2.], n=6)
def test_shard_not_multiple(self):
X = Tensor.ones(256).contiguous().realize()
with self.assertRaises(RuntimeError):

View File

@@ -1,7 +1,7 @@
from typing import cast, TypeVar
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve, track_rewrites, graph_rewrite_map
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Device
@@ -239,3 +239,6 @@ multi_pm = PatternMatcher([
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])+replace_allreduce
@track_rewrites()
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
  """Run the multi-device rewrite rules over `big_sink`.

  Returns the mapping from each original UOp to its replacement after
  applying `multi_pm` via graph_rewrite_map.
  """
  rewritten = graph_rewrite_map(big_sink, multi_pm)
  return rewritten

View File

@@ -2,11 +2,9 @@ from typing import Any, cast
import functools, operator
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
@@ -594,10 +592,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
# HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi
msink = graph_rewrite_map(tsink, multi_pm, name="multi")
tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None})
tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites")
realize_map: dict[UOp, UOp] = {}
graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")

View File

@@ -16,6 +16,7 @@ from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
from tinygrad.schedule.kernelize import get_kernelize_map
# *** all in scope Tensors are here. this gets relevant UOps ***
@@ -241,6 +242,10 @@ class Tensor(MathTrait):
# verify Tensors match the spec
if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
if RANGEIFY and any(isinstance(x._device, tuple) for x in big_sink.toposort()):
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
becomes_map = get_rangeify_map(big_sink) if RANGEIFY else get_kernelize_map(big_sink)
_apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
return self