mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
more const folding
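
What changed: the scheduler's to_si rewriter now runs the generic symbolic_flat simplification rules in addition to its own patterns, and two new VIEW rules collapse a VIEW that merely wraps a CONST, so constant-only ALU expressions fold during schedule rewrite instead of reaching codegen. A minimal sketch of the effect, using only names that appear in this diff (UOp, Ops, graph_rewrite, and symbolic_flat are internal tinygrad APIs, so exact import paths may differ in other versions):

    from tinygrad import dtypes
    from tinygrad.ops import UOp, Ops, graph_rewrite, symbolic_flat

    # a MUL of two CONSTs; without folding this survives as an ALU op
    expr = UOp.const(dtypes.float, 11) * UOp.const(dtypes.float, 1)
    folded = graph_rewrite(expr, symbolic_flat)
    # the symbolic rules (x*1 -> x, constant arithmetic) leave one CONST behind
    assert folded.op is Ops.CONST and folded.arg == 11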
test/test_schedule.py
@@ -8,15 +8,17 @@ import numpy as np
 import functools
 from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes, Device, Tensor
 from tinygrad.device import is_dtype_supported
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
+from tinygrad.ops import GroupOp, UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
+from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, full_ast_rewrite, view_right, view_left, do_realize
 from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
 from tinygrad.engine.lazy import LazyBuffer
 from extra.models.llama import precompute_freqs_cis
@@ -1962,5 +1964,13 @@ class TestBigGraph(unittest.TestCase):
     self.assertIs(big_graph, out.sink())
     self.assertEqual(len(realizes), 1)
 
+  def test_early_fold_const_ops(self):
+    st = ShapeTracker.from_shape((2925,)).to_uop()
+    load = UOp(Ops.LOAD, dtypes.float, arg=None, src=(UOp.new_buffer(Device.DEFAULT, st.size, dtypes.float), st,))
+    ast = UOp.store(UOp.new_buffer(Device.DEFAULT, st.size, dtypes.float), st, load*(UOp.const(dtypes.float, 11)*UOp.const(dtypes.float, 1)))
+    assert len([x for x in ast.toposort if x.op in GroupOp.ALU]) == 2
+    ast, _ = full_ast_rewrite(ast.sink(), ScheduleContext())
+    assert len([x for x in ast.toposort if x.op in GroupOp.ALU]) == 1
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
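The new test exercises the fold end to end: the STORE's value is load*(CONST(11)*CONST(1)), so the pre-rewrite AST holds two ALU ops (both MULs). After full_ast_rewrite(ast.sink(), ScheduleContext()) applies the schedule-time rewrites, CONST(11)*CONST(1) has collapsed into the constant 11 and only the load*11 MUL remains, hence the ALU count drops from 2 to 1.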
tinygrad/engine/schedule.py
@@ -3,7 +3,7 @@ from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
 from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
-from tinygrad.ops import realized, identity_element, buffers
+from tinygrad.ops import realized, identity_element, buffers, symbolic_flat
 from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
 from tinygrad.dtype import ConstType, ImageDType, dtypes
@@ -177,7 +177,9 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
   return x.replace(op=Ops.LOAD)
 check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])
 
-to_si = PatternMatcher([
+to_si = symbolic_flat+PatternMatcher([
+  (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat.cvar("x"))), lambda ctx,x:x),
+  (UPat(Ops.VIEW, src=(UPat.cvar("x"),)), lambda ctx,x:x),
   (UPat(Ops.VIEW, name="x"), _append_st_vars),
   (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
   # unmasked VALID is just CONST
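
The call-site change stays small because PatternMatcher composes: symbolic_flat+PatternMatcher([...]) concatenates the generic symbolic rules with the schedule-specific ones, so every graph_rewrite through to_si now also const-folds. The two added patterns strip a VIEW that only wraps a CONST, since a view over a constant carries no information beyond the constant itself. A hedged sketch of that composition, reusing only names visible in this diff and assuming '+' on PatternMatcher concatenates the rule lists:

    from tinygrad.ops import Ops, PatternMatcher, UPat, symbolic_flat

    # mirrors the second new to_si rule: a VIEW wrapping a CONST rewrites to the CONST
    drop_const_view = PatternMatcher([
      (UPat(Ops.VIEW, src=(UPat.cvar("x"),)), lambda ctx,x: x),
    ])
    to_si_like = symbolic_flat+drop_const_view  # rules from both matchers, tried in order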