diff --git a/test/test_schedule.py b/test/test_schedule.py
index 8456293e30..089a546cce 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -8,15 +8,15 @@ import numpy as np
 import functools
 from typing import List, Optional, Union, cast
 
 from tinygrad import nn, dtypes, Device, Tensor
 from tinygrad.device import is_dtype_supported
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
+from tinygrad.ops import GroupOp, UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
+from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, full_ast_rewrite, view_right, view_left, do_realize
 from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
 from tinygrad.engine.lazy import LazyBuffer
 from extra.models.llama import precompute_freqs_cis
@@ -1962,5 +1962,14 @@ class TestBigGraph(unittest.TestCase):
     self.assertIs(big_graph, out.sink())
     self.assertEqual(len(realizes), 1)
 
+  def test_early_fold_const_ops(self):
+    # the stored value is load*(11*1): two ALU ops before the rewrite, and exactly one is expected after full_ast_rewrite
+    st = ShapeTracker.from_shape((2925,)).to_uop()
+    load = UOp(Ops.LOAD, dtypes.float, arg=None, src=(UOp.new_buffer(Device.DEFAULT, st.size, dtypes.float), st,))
+    ast = UOp.store(UOp.new_buffer(Device.DEFAULT, st.size, dtypes.float), st, load*(UOp.const(dtypes.float, 11)*UOp.const(dtypes.float, 1)))
+    self.assertEqual(len([x for x in ast.toposort if x.op in GroupOp.ALU]), 2)
+    ast, _ = full_ast_rewrite(ast.sink(), ScheduleContext())
+    self.assertEqual(len([x for x in ast.toposort if x.op in GroupOp.ALU]), 1)
+
 if __name__ == '__main__':
unittest.main(verbosity=2) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2d0b72bde2..9e5ab73a8c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -3,7 +3,7 @@ from collections import defaultdict, deque from dataclasses import dataclass, field from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views -from tinygrad.ops import realized, identity_element, buffers +from tinygrad.ops import realized, identity_element, buffers, symbolic_flat from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG from tinygrad.dtype import ConstType, ImageDType, dtypes @@ -177,7 +177,9 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: return x.replace(op=Ops.LOAD) check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),]) -to_si = PatternMatcher([ +to_si = symbolic_flat+PatternMatcher([ + (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat.cvar("x"))), lambda ctx,x:x), + (UPat(Ops.VIEW, src=(UPat.cvar("x"),)), lambda ctx,x:x), (UPat(Ops.VIEW, name="x"), _append_st_vars), (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))), # unmasked VALID is just CONST