From 71c5901fc19cfc0b4f2cabea894d6c64dc7e55f7 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Tue, 13 Aug 2024 02:51:00 +0800
Subject: [PATCH] refactor ast arg and op [compare_schedule] (#6052)

---
 tinygrad/engine/schedule.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 40fd24b3e2..266e839a17 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -1,8 +1,8 @@
 import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
-from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
   GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
@@ -55,7 +55,7 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
   """recursively create a lazyop"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
-  arg = buf.arg
+  assert buf.op is not None, "base must be a base itself"
 
   # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
@@ -63,11 +63,11 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
     var_vals.update(st_var_vals)
   # if it's a const, we generate it
   if buf.op is MetaOps.CONST:
-    if isinstance(arg, Variable):
-      arg, var_val = arg.unbind()
-      var_vals[arg] = var_val
-    else: assert isinstance(arg, get_args(ConstType)), f"cannot create ConstBuffer with value {arg}"
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(arg, buf.dtype, unbound_st))
+    if isinstance(val:=buf.arg, Variable):
+      val, var_val = val.unbind()
+      var_vals[val] = var_val
+    else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
   # otherwise, it's a load and we add it to the inputs
   if buf in assign_targets:
     # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
@@ -80,18 +80,20 @@ def _recursive_lazyop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer,
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
+    rinfo = reduce_info.get((buf, st))
+    rsrc = _recursive_lazyop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
     # if we are merging the reduce, skip it
-    if (buf, st) not in reduce_info:
-      assert buf.srcs[0].base.op is buf.op, f"can't merge reduceop {buf.op} with {buf.srcs[0].base.op}\n{st}"
-      return _recursive_lazyop(buf.srcs[0], st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
-    st, arg = reduce_info[(buf, st)]
+    if rinfo is None:
+      assert rsrc.op is buf.op, f"can't merge reduceop {buf.op} with {rsrc.op}\n{st}"
+      return rsrc
+    return cache.setdefault((buf, st), LazyOp(buf.op, (rsrc,), rinfo[1]))
 
   # elementwise ops pass shapetracker
   in_ops = tuple(_recursive_lazyop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
   if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_ops[0]
-  return cache.setdefault((buf, st), LazyOp(cast(Op, buf.op), in_ops, arg))
+  return cache.setdefault((buf, st), LazyOp(buf.op, in_ops, buf.arg))
 
 def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
   permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
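
Addendum (not part of the patch): a minimal runnable sketch of the control flow the refactored
ReduceOps branch now follows. ToyOp and lower_reduce are hypothetical stand-ins for tinygrad's
LazyOp/_recursive_lazyop, and rinfo is modeled as the plain tuple returned by
reduce_info.get((buf, st)); only the branching logic mirrors the patch, everything else is an
illustrative assumption.

  # sketch.py -- illustrative only, not tinygrad code
  from dataclasses import dataclass
  from typing import Optional, Tuple

  @dataclass(frozen=True)
  class ToyOp:  # hypothetical stand-in for LazyOp
    op: str
    src: Tuple["ToyOp", ...] = ()
    arg: object = None

  def lower_reduce(buf_op:str, lowered_src:ToyOp, rinfo:Optional[Tuple[object, Tuple[int, ...]]]) -> ToyOp:
    # rinfo plays the role of reduce_info.get((buf, st)) in the patch:
    # None means this reduce is being merged into an already-lowered identical reduce
    if rinfo is None:
      assert lowered_src.op == buf_op, f"can't merge reduceop {buf_op} with {lowered_src.op}"
      return lowered_src  # merge case: reuse the lowered source node unchanged
    # normal case: wrap the lowered source, taking the axis arg straight from rinfo[1]
    return ToyOp(buf_op, (lowered_src,), rinfo[1])

  if __name__ == "__main__":
    print(lower_reduce("SUM", ToyOp("SUM", (ToyOp("LOAD"),), (1,)), None))   # merged: returns the inner SUM as-is
    print(lower_reduce("SUM", ToyOp("LOAD"), (None, (0,))))                  # wrapped: new SUM node with arg=(0,)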