Mirror of https://github.com/tinygrad/tinygrad.git
apply multi before rangeify (#12298)
Squashed commit messages:

* it doesn't realize it when i reshape
* cleaner graph
* map out
* REDUCE_AXIS also gives the wrong answer
* maybe
* work
* back here
* try
* more
* refactor tests
* check MultiBuffer
* or copy
* fine with this
* don't need graph_rewrite_map in rangeify
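In effect, MULTI lowering becomes its own pass that runs before rangeify instead of being patched into it. A minimal sketch of the user-visible behavior the tests below lock in, assuming a tinygrad checkout with RANGEIFY enabled (the device pair is illustrative; any two devices work):

from tinygrad import Tensor

# shard a realized tensor across two devices along axis 0
t = Tensor.ones(6).contiguous().realize().shard(("CPU", "CPU:1"), 0)
# reshape + reduce on the sharded tensor; per the commit message, this
# previously failed to realize or gave the wrong answer through REDUCE_AXIS
print(t.reshape(2, 3).sum(axis=1).tolist())  # expected: [3.0, 3.0]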
test/test_multitensor.py
@@ -54,6 +54,17 @@ class TestMultiTensor(unittest.TestCase):
     assert lb.shape == (128,)
     (X + X).realize()
+
+  def _test_shard_op(self, op, out, n=4):
+    t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
+    r = op(t).realize()
+    assert t.uop.is_realized, "shard didn't realize"
+    self.assertEqual(r.tolist(), out)
+  def test_shard_reshape(self): self._test_shard_op(lambda t:t.reshape(2, 2), [[1.,1.],[1.,1.]])
+  def test_shard_elementwise(self): self._test_shard_op(lambda t:(t+t).reshape(2, 2), [[2.,2.],[2.,2.]])
+  def test_shard_reduce(self):
+    self._test_shard_op(lambda t:t.reshape(2, 3).sum(axis=1), [3.,3.], n=6)
+    self._test_shard_op(lambda t:t.reshape(2, 3).sum(axis=0), [2.,2.,2.], n=6)
 
   def test_shard_not_multiple(self):
     X = Tensor.ones(256).contiguous().realize()
     with self.assertRaises(RuntimeError):
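The `assert t.uop.is_realized` line in the helper pins down the regression named by "it doesn't realize it when i reshape": the sharded tensor itself must stay realized after an op on it is realized. A standalone version of that check, with `devices_2` swapped for an illustrative device pair:

from tinygrad import Tensor

t = Tensor.ones(4).contiguous().realize().shard(("CPU", "CPU:1"), 0)
(t + t).realize()
assert t.uop.is_realized, "shard didn't realize"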
tinygrad/schedule/multi.py
@@ -1,7 +1,7 @@
 from typing import cast, TypeVar
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
-from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
+from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve, track_rewrites, graph_rewrite_map
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Device
 

@@ -239,3 +239,6 @@ multi_pm = PatternMatcher([
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
     src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
 ])+replace_allreduce
+
+@track_rewrites()
+def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return graph_rewrite_map(big_sink, multi_pm)
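`graph_rewrite_map` differs from plain `graph_rewrite` in that it returns the mapping from every original UOp to its rewritten form, which is what lets the caller repoint live Tensors afterwards. A toy, self-contained illustration of that rewrite-map pattern (not tinygrad code; `Node`, `rewrite_map`, and the `strip_multi` rule are all stand-ins):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:                      # stand-in for a UOp
  op: str
  src: tuple = ()

def rewrite_map(root: Node, rule) -> dict[Node, Node]:
  done: dict[Node, Node] = {}
  def go(n: Node) -> Node:
    if n in done: return done[n]
    new = Node(n.op, tuple(go(s) for s in n.src))  # rewrite sources bottom-up
    done[n] = rule(new) or new                     # record what each node became
    return done[n]
  go(root)
  return done

# rule: strip "multi" wrappers, the way multi_pm lowers Ops.MULTI
strip_multi = lambda n: n.src[0] if n.op == "multi" and n.src else None
sink = Node("sink", (Node("multi", (Node("buf"),)),))
m = rewrite_map(sink, strip_multi)
print(m[sink])  # Node(op='sink', src=(Node(op='buf', src=()),))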
tinygrad/schedule/rangeify.py
@@ -2,11 +2,9 @@ from typing import Any, cast
 import functools, operator
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
 from tinygrad.uop.symbolic import sym, symbolic_simple
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
-from tinygrad.schedule.multi import multi_pm
-
 from tinygrad.schedule.kernelize import Kernel
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
 

@@ -594,10 +592,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   uop_list: list[UOp] = []
   tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
 
-  # HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi
-  msink = graph_rewrite_map(tsink, multi_pm, name="multi")
-  tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None})
-
   tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites")
   realize_map: dict[UOp, UOp] = {}
   graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")
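The deleted block existed only because multi ran after `add_tags` had numbered the UOps: any node the multi rewrite replaced came back untagged and had to inherit its predecessor's tag via `substitute`. With multi applied before rangeify, the graph entering `get_rangeify_map` contains no MULTI ops, so the re-tagging is unnecessary. A toy version of what that re-tagging did, using plain dicts instead of UOps:

tags = {"a": 1, "b": 2}            # node -> numbering tag, assigned by add_tags
rewritten = {"a": "a2", "b": "b"}  # old node -> replacement from the multi rewrite
# replacements ("a2") carry no tag, so they inherit the tag of what they replaced
retagged = {new: tags[old] for old, new in rewritten.items()}
print(retagged)                    # {'a2': 1, 'b': 2}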
tinygrad/tensor.py
@@ -16,6 +16,7 @@ from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.schedule.rangeify import get_rangeify_map
+from tinygrad.schedule.multi import get_multi_map
 from tinygrad.schedule.kernelize import get_kernelize_map
 
 # *** all in scope Tensors are here. this gets relevant UOps ***

@@ -241,6 +242,10 @@ class Tensor(MathTrait):
     # verify Tensors match the spec
     if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
 
+    if RANGEIFY and any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+      _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
+      big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
+
     becomes_map = get_rangeify_map(big_sink) if RANGEIFY else get_kernelize_map(big_sink)
     _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
     return self
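The trickiest line in this hunk is the sink rebuild: after the multi map is applied, a Tensor whose uop is still an `Ops.MULTI` contributes that op's per-device sources, while every other Tensor contributes its uop unchanged. A toy version of that flatten (stand-in classes, not tinygrad internals):

from itertools import chain

class U:                                  # minimal stand-in for a UOp
  def __init__(self, op, src=()): self.op, self.src = op, src

MULTI = "MULTI"
uops = [U(MULTI, (U("dev0 buf"), U("dev1 buf"))), U("add")]
flat = list(chain.from_iterable(u.src if u.op is MULTI else [u] for u in uops))
print([u.op for u in flat])               # ['dev0 buf', 'dev1 buf', 'add']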