mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
use copy kernel in schedule (#4520)
* use copy kernel in schedule * imports
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import List, Dict, Optional, cast, Generator, Tuple
|
||||
from typing import List, Dict, Optional, cast, Generator
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast, LazyOp
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
|
||||
from tinygrad.device import Runner, Device
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
@@ -74,20 +74,16 @@ class ExecItem:
|
||||
self.prg.first_run = False
|
||||
return et
|
||||
|
||||
def lower_runner(dname:str, ast:Tuple[LazyOp, ...], bufs) -> ExecItem:
|
||||
runner = Device[dname].get_runner(*ast)
|
||||
# TODO: globals isn't on the stupid diskrunner, remove the need for it
|
||||
return ExecItem(runner, [bufs[x[0]] for x in runner.p.globals] if hasattr(runner, 'p') else bufs)
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
|
||||
assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY
|
||||
if si.ast[0].op is BufferOps.STORE: return lower_runner(si.outputs[0].device, si.ast, si.bufs)
|
||||
assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL")
|
||||
if si.ast[0].op is BufferOps.STORE:
|
||||
runner = Device[si.outputs[0].device].get_runner(*si.ast)
|
||||
return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals])
|
||||
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
|
||||
out, ast = si.outputs[0], si.ast[0]
|
||||
if ast.op is LoadOps.COPY:
|
||||
kernel_type = BufferCopy
|
||||
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
|
||||
if getenv("USE_COPY_KERNEL"): return lower_runner(out.device, (copy_ast(ast.arg),), si.bufs)
|
||||
kernel_type = BufferXfer
|
||||
return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
|
||||
if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import sys, pickle, atexit
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
|
||||
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
|
||||
@@ -267,6 +267,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
|
||||
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
|
||||
var_vals = merge_dicts([var_vals, ps.var_vals])
|
||||
for out in ps.outputs: del out.srcs # can only schedule once
|
||||
if getenv("USE_COPY_KERNEL") and ps.ast[0].op == LoadOps.COPY and ps.outputs[0].device.split(":")[0] == ps.inputs[0].device.split(":")[0]:
|
||||
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((ps.ast[0].arg,))))
|
||||
ps = replace(ps, ast=(LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)),))
|
||||
schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
|
||||
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
|
||||
for x in graph[ps.outputs[0]]:
|
||||
|
||||
@@ -92,10 +92,6 @@ class LazyOp:
|
||||
const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
|
||||
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
|
||||
|
||||
def copy_ast(sz) -> LazyOp:
|
||||
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((sz,))))
|
||||
return LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
|
||||
|
||||
# **************** independent FlopCounter ****************
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user