mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
fix some multitensor on rangeify (#12162)
* fix some multitensor on rangeify
* rangeify multi hacks
* copy on const
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -526,6 +526,8 @@ jobs:
|
||||
-k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \
|
||||
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \
|
||||
test/test_outerworld_range.py test/test_sample.py test/test_randomness.py
|
||||
- name: Test multitensor
|
||||
run: RANGEIFY=1 PYTHONPATH="." python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
|
||||
- name: Test GPU=1 RANGEIFY=1
|
||||
run: GPU=1 RANGEIFY=1 pytest -n auto test/test_ops.py
|
||||
- name: Test CPU=1 RANGEIFY=2
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any, cast
|
||||
import functools, operator
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
@@ -438,8 +438,8 @@ pm_add_buffers = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
#(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
# lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
|
||||
])
|
||||
|
||||
# *****************
|
||||
@@ -496,6 +496,10 @@ rangeify_codegen = PatternMatcher([
|
||||
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
|
||||
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),
|
||||
|
||||
# copy on const is const
|
||||
# TODO: this can be moved into codegen. this rule is probably in other places
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("c",), UPat())), lambda c: c),
|
||||
|
||||
# TODO: hack for group for reduce
|
||||
(UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
|
||||
lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
|
||||
@@ -535,7 +539,12 @@ add_tags = PatternMatcher([
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
|
||||
tsink = graph_rewrite(tsink, multi_pm+earliest_rewrites, name="earliest rewrites")
|
||||
|
||||
# HACK: run multi_pm through graph_rewrite_map so the tag logic does not have to be added to multi; tags are copied back onto the untagged rewritten UOps below
|
||||
msink = graph_rewrite_map(tsink, multi_pm, name="multi")
|
||||
tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None})
|
||||
|
||||
tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites")
|
||||
realize_map: dict[UOp, UOp] = {}
|
||||
graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")
|
||||
# NOTE: we don't use contiguous here, contiguous is a user op
|
||||
|
||||
Reference in New Issue
Block a user