fix some multitensor on rangeify (#12162)

* fix some multitensor on rangeify

* rangeify multi hacks

* copy on const
This commit is contained in:
George Hotz
2025-09-14 14:31:57 +08:00
committed by GitHub
parent 4b7904eca9
commit d5bc27797b
2 changed files with 15 additions and 4 deletions

View File

@@ -526,6 +526,8 @@ jobs:
-k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \
test/test_outerworld_range.py test/test_sample.py test/test_randomness.py
- name: Test multitensor
run: RANGEIFY=1 PYTHONPATH="." python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
- name: Test GPU=1 RANGEIFY=1
run: GPU=1 RANGEIFY=1 pytest -n auto test/test_ops.py
- name: Test CPU=1 RANGEIFY=2

View File

@@ -2,7 +2,7 @@ from typing import Any, cast
import functools, operator
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map
from tinygrad.uop.symbolic import sym
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.multi import multi_pm
@@ -438,8 +438,8 @@ pm_add_buffers = pm_mops+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
# move RESHAPEs through MSELECT/MSTACK
#(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
# lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
])
# *****************
@@ -496,6 +496,10 @@ rangeify_codegen = PatternMatcher([
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),
# copy on const is const
# TODO: this can be moved into codegen. this rule is probably in other places
(UPat(Ops.COPY, src=(UPat.cvar("c",), UPat())), lambda c: c),
# TODO: hack for group for reduce
(UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
@@ -535,7 +539,12 @@ add_tags = PatternMatcher([
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
tsink = graph_rewrite(tsink, multi_pm+earliest_rewrites, name="earliest rewrites")
# HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi
msink = graph_rewrite_map(tsink, multi_pm, name="multi")
tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None})
tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites")
realize_map: dict[UOp, UOp] = {}
graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")
# NOTE: we don't use contiguous here, contiguous is a user op