mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix cuda sync (#3888)
This commit is contained in:
@@ -84,8 +84,9 @@ class LazyBuffer:
|
||||
def is_unrealized_contiguous_const(self):
  """Return True when this buffer is its own base, that base has no realized value, and its op is LoadOps.CONST."""
  base = self.base
  # a buffer whose base is a different object never qualifies
  if not base == self: return False
  # anything already realized is excluded
  if base.realized: return False
  return self.op is LoadOps.CONST
|
||||
|
||||
def _copy(self, device:str) -> LazyBuffer:
|
||||
if self.device.startswith("EXT") or self.device.startswith("DISK"):
|
||||
if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart):
|
||||
# DISK/EXT don't sync
|
||||
# copies in HSA/CUDA to other HSA/CUDA don't sync either
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
|
||||
sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
|
||||
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
from collections import defaultdict, deque
|
||||
from typing import Deque, List, Dict, Optional, cast, Set, DefaultDict
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, Compiled, BufferOptions
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats
|
||||
from tinygrad.features.graph import realized_lazybuffer, log_lazybuffer
|
||||
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, prod, dedup, all_int
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
@@ -27,18 +27,16 @@ class SyncOp(JITRunner):
|
||||
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
  """Select the runner that executes one ScheduleItem, or None when there is nothing to run.

  Result by the op of si.ast[0]:
    - BufferOps.STORE: whatever the output device's get_runner builds from si.ast.
    - LoadOps.COPY:    BufferXfer when the output allocator has a `transfer` attribute and the
                       output and input device names share the same prefix before ":";
                       BufferRead when the input device name starts with "DISK";
                       BufferCopy otherwise.
    - LoadOps.CUSTOM:  CustomOp wrapping ast.arg.
    - LoadOps.SYNC:    SyncOp on the output device.
    - any other op:    None.

  Raises AssertionError when a non-COPY/WAIT item spans more than one device, or when a
  non-STORE item has more than one ast or more than one output.
  """
  # only COPY and WAIT are allowed to touch buffers on more than one device
  assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op in {LoadOps.COPY, LoadOps.WAIT}
  if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
  assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
  out, ast = si.outputs[0], si.ast[0]
  if ast.op is LoadOps.COPY:
    # compare the device-name prefix before ":" so that e.g. "X:0" and "X:1" count as the same kind
    same_kind = out.device.split(":")[0] == si.inputs[0].device.split(":")[0]
    if hasattr(Device[out.device].allocator, 'transfer') and same_kind: return BufferXfer()
    if si.inputs[0].device.startswith("DISK"): return BufferRead()
    return BufferCopy()
  if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
  if ast.op is LoadOps.SYNC: return SyncOp(out.device)
  return None
|
||||
|
||||
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
|
||||
@@ -53,13 +51,12 @@ def run_schedule(schedule:List[ScheduleItem]):
|
||||
for out in si.outputs:
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if out.size > 0:
|
||||
options = BufferOptions(host=True, signal=True) if si.ast[0].op is LoadOps.SYNC else None
|
||||
if out.op is LoadOps.ASSIGN and out.srcs[1].base.realized is not None:
|
||||
# if the buffer isn't realized, it might be a const or something. this is fine
|
||||
out.realized = out.srcs[1].base.realized
|
||||
else:
|
||||
out.realized = out.output_buffer if out.output_buffer is not None else \
|
||||
Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options)
|
||||
Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None)
|
||||
del out.srcs
|
||||
|
||||
# run the function (put it in JIT)
|
||||
|
||||
Reference in New Issue
Block a user