diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 5cadbe969d..291c1882c6 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -112,7 +112,7 @@ class Kernel:
     ret = type(self).__new__(type(self))
 
     # base linearizer params
-    ret.opts, ret.ast = self.opts, self.ast
+    ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
 
     # things downstream of the AST
     ret.reduceop, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 2fff91c094..140b6d0498 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -211,8 +211,9 @@ class Linearizer(Kernel):
       self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
 
     # kernel name (before late upcast)
-    self.name = ("r" if self.reduceop else "E") + (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
-        colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+    self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
+                (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
+                colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
 
     # name the function something unique
     Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index dcbc1855fb..6e7d997ec9 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -1,7 +1,7 @@
 from typing import List, Dict, Optional, cast, Generator
 from dataclasses import dataclass
-from tinygrad.helpers import colored
-from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
+from tinygrad.helpers import colored, getenv
+from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
 from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
 from tinygrad.buffer import Buffer
 from tinygrad.shape.symbolic import Variable
@@ -29,7 +29,8 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
   assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
   out, ast = si.outputs[0], si.ast[0]
   if ast.op is LoadOps.COPY:
-    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: return BufferXfer()
+    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
+      return Device[si.outputs[0].device].get_runner(copy_ast(ast.arg)) if getenv("USE_COPY_KERNEL") else BufferXfer()
     return BufferCopy()
   if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
   if ast.op is LoadOps.EMPTY: return EmptyOp()
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 265751a337..123f87e54f 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -97,7 +97,7 @@ class LazyBuffer:
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
 
   def _copy(self, device:str) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
+    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
 
   def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
     # no COPY
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b7a24ea3f0..f2c379026d 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -75,6 +75,10 @@ class LazyOp:
   def vars(self) -> List[Variable]:
     return sorted(set.union(*[x.arg.st.vars() for x in self.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr))
 
+def copy_ast(sz) -> LazyOp:
+  rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((sz,))))
+  return LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
+
 # **************** independent FlopCounter ****************
 
 @dataclass