more const folding

qazal
2024-12-07 13:04:28 +02:00
parent 25f5b336ef
commit 2a7d258a2b
2 changed files with 16 additions and 4 deletions


@@ -8,15 +8,17 @@ import numpy as np
import functools
from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.ops import GroupOp, UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, full_ast_rewrite, view_right, view_left, do_realize
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from tinygrad.engine.lazy import LazyBuffer
from extra.models.llama import precompute_freqs_cis
@@ -1962,5 +1964,13 @@ class TestBigGraph(unittest.TestCase):
    self.assertIs(big_graph, out.sink())
    self.assertEqual(len(realizes), 1)

  def test_early_fold_const_ops(self):
    st = ShapeTracker.from_shape((2925,)).to_uop()
    load = UOp(Ops.LOAD, dtypes.float, arg=None, src=(UOp.new_buffer(Device.DEFAULT, st.size, dtypes.float), st,))
    ast = UOp.store(UOp.new_buffer(Device.DEFAULT, st.size, dtypes.float), st, load*(UOp.const(dtypes.float, 11)*UOp.const(dtypes.float, 1)))
    assert len([x for x in ast.toposort if x.op in GroupOp.ALU]) == 2
    ast, _ = full_ast_rewrite(ast.sink(), ScheduleContext())
    assert len([x for x in ast.toposort if x.op in GroupOp.ALU]) == 1

if __name__ == '__main__':
  unittest.main(verbosity=2)
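
Note: a minimal, self-contained sketch of the folding this test exercises (plain Python, not tinygrad's UOp/full_ast_rewrite API; Const/Load/Mul are made-up stand-ins). In load*(11*1) the constants fold to 11 and the multiply-by-one identity drops out, so the expression goes from two ALU ops to one.

from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class Const:
  value: float

@dataclass(frozen=True)
class Load:
  name: str

@dataclass(frozen=True)
class Mul:
  a: "Node"
  b: "Node"

Node = Union[Const, Load, Mul]

def fold(n: Node) -> Node:
  # bottom-up constant folding: fold children first, then simplify this node
  if isinstance(n, Mul):
    a, b = fold(n.a), fold(n.b)
    if isinstance(a, Const) and isinstance(b, Const): return Const(a.value * b.value)  # const*const -> const
    if isinstance(b, Const) and b.value == 1: return a  # x*1 -> x
    if isinstance(a, Const) and a.value == 1: return b  # 1*x -> x
    return Mul(a, b)
  return n

def count_alu(n: Node) -> int:
  return 1 + count_alu(n.a) + count_alu(n.b) if isinstance(n, Mul) else 0

ast = Mul(Load("buf"), Mul(Const(11.0), Const(1.0)))
assert count_alu(ast) == 2        # load*(11*1): two multiplies before folding
assert count_alu(fold(ast)) == 1  # load*11: one multiply after folding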


@@ -3,7 +3,7 @@ from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import realized, identity_element, buffers
from tinygrad.ops import realized, identity_element, buffers, symbolic_flat
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
from tinygrad.dtype import ConstType, ImageDType, dtypes
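
Note on the hunk below: to_si switches from a bare PatternMatcher to symbolic_flat+PatternMatcher([...]), so the scheduler's AST-rewrite rules now run together with the symbolic simplification rules and constant folding happens while the schedule item is built. A rough, generic sketch of that kind of rule-set composition (toy names and rule format, not tinygrad's PatternMatcher API):

from typing import Any, Callable, List, Optional

Rule = Callable[[Any], Optional[Any]]

def combine(*rule_sets: List[Rule]) -> List[Rule]:
  # "+" on rule sets is just concatenation; earlier sets are tried first
  return [r for rs in rule_sets for r in rs]

def rewrite(node: Any, rules: List[Rule]) -> Any:
  # apply the first matching rule repeatedly until a fixpoint is reached
  changed = True
  while changed:
    changed = False
    for rule in rules:
      out = rule(node)
      if out is not None and out != node:
        node, changed = out, True
        break
  return node

# toy rules: one "symbolic" folding rule and one scheduler-specific rule
symbolic_rules: List[Rule] = [lambda n: n[1] if isinstance(n, tuple) and n[0] == "mul" and n[2] == 1 else None]   # x*1 -> x
scheduler_rules: List[Rule] = [lambda n: ("load", n[1]) if isinstance(n, tuple) and n[0] == "preload" else None]  # PRELOAD -> LOAD

combined = combine(symbolic_rules, scheduler_rules)
assert rewrite(("mul", ("preload", "b"), 1), combined) == ("load", "b")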
@@ -177,7 +177,9 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
  return x.replace(op=Ops.LOAD)
check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])
to_si = PatternMatcher([
to_si = symbolic_flat+PatternMatcher([
  (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat.cvar("x"))), lambda ctx,x:x),
  (UPat(Ops.VIEW, src=(UPat.cvar("x"),)), lambda ctx,x:x),
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
  # unmasked VALID is just CONST