single device copy [pr] (#10221)

* single device copy [pr]

* simpler
This commit is contained in:
George Hotz
2025-05-08 15:23:22 -07:00
committed by GitHub
parent 1d0f239df7
commit 0b7e3e86d0
4 changed files with 8 additions and 11 deletions

View File

@@ -82,9 +82,9 @@ sym = symbolic_simple+PatternMatcher([
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"),), allow_any_len=True), lambda root,x: root.const_like(x.arg)),
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"),), name="copy", allow_any_len=True), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device).view(v.st)),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),

View File

@@ -126,13 +126,13 @@ def flip_multi(root:UOp, multi:UOp):
def copy_multi(multi:UOp, device:UOp):
# if we already have a copy on the device, return that
if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device))
# copy lbs to device, pad to final shape, and sum
llbs:list[UOp] = []
for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
if not real: continue
pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
llbs.append(lb.copy_to_device(device).pad(pad_arg))
return functools.reduce(operator.add, llbs)
def assign_multi(dest:UOp, src:UOp):

View File

@@ -490,9 +490,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert op is Ops.BIND, f"unknown op {op}"
var, val = arg.unbind()
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
def copy_to_device(self, device:str|tuple[str, ...], arg=None):
if isinstance(device, tuple): return UOp(Ops.COPY, self.dtype, (self,)+tuple(UOp(Ops.DEVICE, arg=d) for d in device), arg)
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device)), arg)
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
def clone(self) -> UOp: return self.copy_to_device(self.device)
@property
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
@@ -536,9 +535,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _device(self) -> Optional[str|tuple[str, ...]]:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
if self.op in {Ops.COPY, Ops.BUFFER}:
if len(self.src) > 2: return tuple(cast(str, x.device) for x in self.src[1:])
return self.src[1].device
if self.op in {Ops.COPY, Ops.BUFFER}: return self.src[1].device
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:

View File

@@ -76,7 +76,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
lambda root,x: root.dtype == x.dtype),
# COPY
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), allow_any_len=True), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
])
# ***** uop type spec *****