diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 9811450ea7..314ff2eefa 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -530,7 +530,6 @@ add_tags = PatternMatcher([
 ])
 
 # support for using a contiguous permuted view instead of the parent view if one exists
-# modified from kernelize.py to not use ShapeTracker
 def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
   x = src
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ad26051b8e..0f52ac2d7c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -162,10 +162,10 @@ class Tensor(OpMixin):
       # data might be on a different device
       if isinstance(_device, str): self.uop:UOp = data if data.device == _device else data.copy_to_device(_device)
-      # if device is a tuple, we should have/construct a MultiLazyBuffer
+      # if device is a tuple, we should have/construct a multi-device UOp
       elif isinstance(data.device, str): self.uop = Tensor(data).shard(_device).uop
       else:
-        assert data.device == _device, f"MultiLazyBuffer device mismatch, {data.device} != {_device}"
+        assert data.device == _device, f"multi-device UOp device mismatch, {data.device} != {_device}"
         self.uop = data
 
     # add to all_tensors after construction succeeds
@@ -397,7 +397,7 @@ class Tensor(OpMixin):
     print(t.shard((t.device, t.device), axis=1).uop)
     ```
     """
-    if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
+    if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
     if len(devices) == 1: return self.to(devices[0])
     devices = tuple(canonicalize_device(x) for x in devices)
     mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py
index 5779a1f0d1..ce917e33d6 100644
--- a/tinygrad/uop/divandmod.py
+++ b/tinygrad/uop/divandmod.py
@@ -3,7 +3,7 @@
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
-# NOTE: this cache is only on index UOps and matches the cache in the old ShapeTracker in spirit
+# NOTE: this cache is only on index UOps
 @functools.cache
 def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
   x, y = d.src
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index e8cd8bb054..443dd5276d 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -236,7 +236,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       case Ops.RESHAPE:
         if self.src[0]._shape is None: return self.marg
 
-    # movement ops change the shape. this is the logic from the old ShapeTracker
+    # movement ops change the shape
     # NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
     if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
       ps = self.src[0]._shape
@@ -465,14 +465,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
   def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
 
-  # *** ShapeTracker helpers ***
-
   def split_uop(self:UOp, sep:Ops):
     if self.op is sep:
       for s in self.src: yield from s.split_uop(sep)
     else: yield self
 
-  # *** from MultiLazyBuffer ***
+  # *** multi-device helpers ***
 
   def multi(self, axis:int|None):
     assert isinstance(self.device, tuple), f"multi device must be tuple, {self.device} isn't"
@@ -514,8 +512,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
   def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis).multi(axis)
 
-  # *** from LazyBuffer ***
-
   def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
     assert arg is None or isinstance(self.device, tuple)
     inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
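For readers of the tensor.py hunks: `Tensor.shard` is the API whose error message is renamed above. A minimal usage sketch, following the docstring context visible in the @@ -397 hunk (the shape and the duplicated `t.device` are illustrative, not from the patch):

```python
from tinygrad import Tensor

t = Tensor.arange(8).reshape(2, 4)         # single-device tensor, t.device is a str
s = t.shard((t.device, t.device), axis=1)  # split axis 1 across two (here identical) devices
print(s.device)                            # a tuple now, so s.uop is a multi-device UOp
print(s.shape)                             # logical shape is unchanged: (2, 4)
```

Calling `shard` again on `s` raises the `RuntimeError` renamed above, since `self.device` is already a tuple rather than a str.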
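`split_uop`, which this diff moves out from under the old "ShapeTracker helpers" banner, flattens a nested chain of one associative op into its leaves. A small sketch of the behavior, assuming `UOp.variable` to build symbolic ints (the variable names are made up):

```python
from tinygrad.uop.ops import UOp, Ops

a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
c = UOp.variable("c", 0, 10)
# (a+b)+c is a tree of ADD UOps; split_uop recurses into matching ops
# and yields the leaf terms in order
terms = list(((a + b) + c).split_uop(Ops.ADD))
print(terms)  # [a, b, c]
```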
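The `functools.cache` on `fold_divmod_general` in divandmod.py works because UOps are interned by `UOpMetaClass` (named in the hunk headers above): constructing the same index expression twice returns the same Python object, so cache lookups are cheap. A sketch of that property, under the same `UOp.variable` assumption as above:

```python
from tinygrad.uop.ops import UOp

x = UOp.variable("x", 0, 100)
# hash-consing: structurally identical UOps are the same object,
# which is what makes caching on index UOps viable
assert (x // 4) is (x // 4)
```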