mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -82,9 +82,9 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# split_reduceop
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
|
||||
# COPY(CONST) creates a new CONST on the destination device
|
||||
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"),), allow_any_len=True), lambda root,x: root.const_like(x.arg)),
|
||||
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
|
||||
# store a shrink before COPY, otherwise view after the COPY
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"),), name="copy", allow_any_len=True), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
|
||||
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device).view(v.st)),
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
|
||||
|
||||
@@ -126,13 +126,13 @@ def flip_multi(root:UOp, multi:UOp):
|
||||
|
||||
def copy_multi(multi:UOp, device:UOp):
|
||||
# if we already have a copy on the device, return that
|
||||
if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
|
||||
if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device))
|
||||
# copy lbs to device, pad to final shape, and sum
|
||||
llbs:list[UOp] = []
|
||||
for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
|
||||
if not real: continue
|
||||
pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
|
||||
llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
|
||||
llbs.append(lb.copy_to_device(device).pad(pad_arg))
|
||||
return functools.reduce(operator.add, llbs)
|
||||
|
||||
def assign_multi(dest:UOp, src:UOp):
|
||||
|
||||
@@ -490,9 +490,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
assert op is Ops.BIND, f"unknown op {op}"
|
||||
var, val = arg.unbind()
|
||||
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
|
||||
def copy_to_device(self, device:str|tuple[str, ...], arg=None):
|
||||
if isinstance(device, tuple): return UOp(Ops.COPY, self.dtype, (self,)+tuple(UOp(Ops.DEVICE, arg=d) for d in device), arg)
|
||||
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device)), arg)
|
||||
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
|
||||
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
|
||||
def clone(self) -> UOp: return self.copy_to_device(self.device)
|
||||
@property
|
||||
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
|
||||
@@ -536,9 +535,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def _device(self) -> Optional[str|tuple[str, ...]]:
|
||||
if self.op is Ops.DEVICE: return self.arg
|
||||
if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
|
||||
if self.op in {Ops.COPY, Ops.BUFFER}:
|
||||
if len(self.src) > 2: return tuple(cast(str, x.device) for x in self.src[1:])
|
||||
return self.src[1].device
|
||||
if self.op in {Ops.COPY, Ops.BUFFER}: return self.src[1].device
|
||||
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
|
||||
@@ -76,7 +76,7 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
lambda root,x: root.dtype == x.dtype),
|
||||
|
||||
# COPY
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE)), allow_any_len=True), lambda copy,x: copy.dtype == x.dtype),
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
|
||||
])
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
Reference in New Issue
Block a user