diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7b7cf87b3d..8373f33a82 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -529,7 +529,9 @@ jobs: - name: Test const folding run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding" - name: Test multitensor - run: CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W + run: | + CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W + CPU=1 RANGEIFY=1 python3 -m pytest test/test_multitensor.py::TestMultiAssign -k 'not (multi_assign_piece_noncontig or multi_assign_var_offset)' - name: Test CPU=1 RANGEIFY=2 run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20 # slow (and still wrong on beautiful_mnist) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 5e9950ed0f..cd655cafd9 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -211,7 +211,7 @@ def assign_multi(dest:UOp, src:UOp): return dest.src[0].assign(src.src[0]).multi(src.axis) def passthrough_multi(root:UOp, multi:UOp): - return root.replace(src=(multi.src[0],)).multi(multi.axis) + return UOp(root.op, root.dtype, (multi.src[0],), root.arg).multi(multi.axis) # NOTE: this is the same pattern as Ops.UNROLL multi_pm = PatternMatcher([ diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 995c5477f2..af496cfdb8 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -582,7 +582,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph # if it's not tagged by here, it's out - tsink = UOp.sink(*[x for x in tsink.parents if (x.op is Ops.BUFFERIZE or x.base.op in {Ops.CONST}) and x.tag is not None]) + tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.CONST} and x.tag is not None]) if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")