From d5bc27797b92e91bab862c9148c4b6ccc0fa403a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 14 Sep 2025 14:31:57 +0800 Subject: [PATCH] fix some multitensor on rangeify (#12162) * fix some multitensor on rangeify * rangeify multi hacks * copy on const --- .github/workflows/test.yml | 2 ++ tinygrad/schedule/rangeify.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6d486c13b5..1a0490da87 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -526,6 +526,8 @@ jobs: -k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \ test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \ test/test_outerworld_range.py test/test_sample.py test/test_randomness.py + - name: Test multitensor + run: RANGEIFY=1 PYTHONPATH="." python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W - name: Test GPU=1 RANGEIFY=1 run: GPU=1 RANGEIFY=1 pytest -n auto test/test_ops.py - name: Test CPU=1 RANGEIFY=2 diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 65815e36d6..93049b738a 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,7 +2,7 @@ from typing import Any, cast import functools, operator from dataclasses import dataclass, field from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map from tinygrad.uop.symbolic import sym from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup from tinygrad.schedule.multi import multi_pm @@ -438,8 +438,8 @@ pm_add_buffers = pm_mops+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store), # move RESHAPEs through MSELECT/MSTACK - #(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"), - # lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)), + (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"), + lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)), ]) # ***************** @@ -496,6 +496,10 @@ rangeify_codegen = PatternMatcher([ (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD), lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())), + # copy on const is const + # TODO: this can be moved into codegen. this rule is probably in other places + (UPat(Ops.COPY, src=(UPat.cvar("c",), UPat())), lambda c: c), + # TODO: hack for group for reduce (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)), lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))), @@ -535,7 +539,12 @@ add_tags = PatternMatcher([ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: uop_list: list[UOp] = [] tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops") - tsink = graph_rewrite(tsink, multi_pm+earliest_rewrites, name="earliest rewrites") + + # HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi + msink = graph_rewrite_map(tsink, multi_pm, name="multi") + tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None}) + + tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites") realize_map: dict[UOp, UOp] = {} graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph") # NOTE: we don't use contiguous here, contiguous is a user op