mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
optionally use a copy kernel instead of SDMA (#4116)

* optionally use a copy kernel
* lazyops in copied kernels
* add sync
* no sdma at all
* work
* copy_ast
This commit is contained in:
@@ -112,7 +112,7 @@ class Kernel:
|
||||
ret = type(self).__new__(type(self))
|
||||
|
||||
# base linearizer params
|
||||
ret.opts, ret.ast = self.opts, self.ast
|
||||
ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops
|
||||
|
||||
# things downstream of the AST
|
||||
ret.reduceop, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
|
||||
|
||||
@@ -211,8 +211,9 @@ class Linearizer(Kernel):
|
||||
self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
|
||||
|
||||
# kernel name (before late upcast)
|
||||
self.name = ("r" if self.reduceop else "E") + (f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
|
||||
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
|
||||
(f"{len(self.outbufs)}_" if len(self.outbufs) > 1 else "_") + \
|
||||
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
# name the function something unique
|
||||
Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Dict, Optional, cast, Generator
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import colored
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
|
||||
from tinygrad.helpers import colored, getenv
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
|
||||
from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
@@ -29,7 +29,8 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
|
||||
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
|
||||
out, ast = si.outputs[0], si.ast[0]
|
||||
if ast.op is LoadOps.COPY:
|
||||
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: return BufferXfer()
|
||||
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
|
||||
return Device[si.outputs[0].device].get_runner(copy_ast(ast.arg)) if getenv("USE_COPY_KERNEL") else BufferXfer()
|
||||
return BufferCopy()
|
||||
if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
|
||||
if ast.op is LoadOps.EMPTY: return EmptyOp()
|
||||
|
||||
@@ -97,7 +97,7 @@ class LazyBuffer:
|
||||
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
||||
|
||||
def _copy(self, device:str) -> LazyBuffer:
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
|
||||
|
||||
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
|
||||
# no COPY
|
||||
|
||||
@@ -75,6 +75,10 @@ class LazyOp:
|
||||
def vars(self) -> List[Variable]:
|
||||
return sorted(set.union(*[x.arg.st.vars() for x in self.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr))
|
||||
|
||||
def copy_ast(sz) -> LazyOp:
  """Build a minimal AST that copies `sz` bytes: LOAD from buffer 1, STORE to buffer 0.

  Both ops view the data as a flat uint8 tensor of shape (sz,); the walrus
  binding reuses the same ShapeTracker for the load and the store so the
  source and destination layouts are guaranteed identical.
  NOTE(review): `sz` is presumably `buffer.nbytes` from the COPY LoadOp's arg — confirm at call site.
  """
  # buffer index 1 = copy source (input), flat byte view of the whole buffer
  rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((sz,))))
  # buffer index 0 = copy destination (output), same ShapeTracker as the load
  return LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
|
||||
|
||||
# **************** independent FlopCounter ****************
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user