fix cuda sync (#3888)

This commit is contained in:
George Hotz
2024-03-22 19:02:30 -07:00
committed by GitHub
parent 2d3ce53348
commit f0c4e06ffd
2 changed files with 7 additions and 9 deletions

View File

@@ -84,8 +84,9 @@ class LazyBuffer:
def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST
def _copy(self, device:str) -> LazyBuffer:
if self.device.startswith("EXT") or self.device.startswith("DISK"):
if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart):
# DISK/EXT don't sync
# copies in HSA/CUDA to other HSA/CUDA don't sync either
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)

View File

@@ -2,7 +2,7 @@ import sys
from collections import defaultdict, deque
from typing import Deque, List, Dict, Optional, cast, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, Compiled, BufferOptions
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats
from tinygrad.features.graph import realized_lazybuffer, log_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
@@ -27,18 +27,16 @@ class SyncOp(JITRunner):
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
if si.ast[0].op not in {LoadOps.COPY, LoadOps.WAIT}: assert len(set(x.device for x in si.outputs+si.inputs)) == 1
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op in {LoadOps.COPY, LoadOps.WAIT}
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
out, ast = si.outputs[0], si.ast[0]
if ast.op is LoadOps.COPY:
if hasattr(Device[out.device].allocator, 'transfer') and type(Device[out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: return BufferXfer()
if si.inputs[0].device.startswith("DISK"): return BufferRead()
return BufferCopy()
if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
if ast.op is LoadOps.SYNC and out.device.startswith("CUDA") and si.inputs[0].device.startswith("CUDA"): return None
if ast.op is LoadOps.SYNC and out.device.startswith("HSA") and si.inputs[0].device.startswith("HSA"): return None
if ast.op is LoadOps.SYNC: return SyncOp(out.device) if isinstance(Device[out.device], Compiled) else None
if ast.op is LoadOps.SYNC: return SyncOp(out.device)
return None
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
@@ -53,13 +51,12 @@ def run_schedule(schedule:List[ScheduleItem]):
for out in si.outputs:
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if out.size > 0:
options = BufferOptions(host=True, signal=True) if si.ast[0].op is LoadOps.SYNC else None
if out.op is LoadOps.ASSIGN and out.srcs[1].base.realized is not None:
# if the buffer isn't realized, it might be a const or something. this is fine
out.realized = out.srcs[1].base.realized
else:
out.realized = out.output_buffer if out.output_buffer is not None else \
Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options)
Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None)
del out.srcs
# run the function (put it in JIT)