diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index e22b077eb4..aae187cef2 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,8 +1,8 @@ -from typing import List, Dict, Optional, cast, Generator, Tuple +from typing import List, Dict, Optional, cast, Generator import time from dataclasses import dataclass from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen -from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast, LazyOp +from tinygrad.ops import ScheduleItem, BufferOps, LoadOps from tinygrad.device import Runner, Device from tinygrad.device import Buffer from tinygrad.shape.symbolic import Variable, sym_infer @@ -74,20 +74,16 @@ class ExecItem: self.prg.first_run = False return et -def lower_runner(dname:str, ast:Tuple[LazyOp, ...], bufs) -> ExecItem: - runner = Device[dname].get_runner(*ast) - # TODO: globals isn't on the stupid diskrunner, remove the need for it - return ExecItem(runner, [bufs[x[0]] for x in runner.p.globals] if hasattr(runner, 'p') else bufs) - def lower_schedule_item(si:ScheduleItem) -> ExecItem: - assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY - if si.ast[0].op is BufferOps.STORE: return lower_runner(si.outputs[0].device, si.ast, si.bufs) + assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL") + if si.ast[0].op is BufferOps.STORE: + runner = Device[si.outputs[0].device].get_runner(*si.ast) + return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals]) assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput" out, ast = si.outputs[0], si.ast[0] if ast.op is LoadOps.COPY: kernel_type = BufferCopy if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: - if getenv("USE_COPY_KERNEL"): return lower_runner(out.device, (copy_ast(ast.arg),), si.bufs) kernel_type = BufferXfer return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs)) if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs)) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 73cbe62602..8ae8ae33e5 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,6 +1,6 @@ import sys, pickle, atexit from collections import defaultdict, deque -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Tuple, List, Dict, Optional, Set, DefaultDict from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer @@ -267,6 +267,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe for out in ps.outputs: realized_lazybuffer(out, kernel_number) var_vals = merge_dicts([var_vals, ps.var_vals]) for out in ps.outputs: del out.srcs # can only schedule once + if getenv("USE_COPY_KERNEL") and ps.ast[0].op == LoadOps.COPY and ps.outputs[0].device.split(":")[0] == ps.inputs[0].device.split(":")[0]: + rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((ps.ast[0].arg,)))) + ps = replace(ps, ast=(LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)),)) schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))) if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n") for x in graph[ps.outputs[0]]: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 3b1c8b06bd..5ace19289c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -92,10 +92,6 @@ class LazyOp: const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)] return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr)) -def copy_ast(sz) -> LazyOp: - rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((sz,)))) - return LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)) - # **************** independent FlopCounter **************** @dataclass