Mirror of https://github.com/tinygrad/tinygrad.git
Back off minor speed stuff for simplicity (#1053)
* passing in buffers doesn't increase speed
* functools.reduce
* no more get_buffers
tinygrad/lazy.py
@@ -29,7 +29,7 @@ def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   src = self.op.src[0]
   if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype == BinaryOps and len(src.children) <= 1:
     src = src.op # type: ignore
-  return LazyOp(self.op.op, (src,), self.op.arg, src.get_buffers())
+  return LazyOp(self.op.op, (src,), self.op.arg)

 # this supports late merging an upstream Reduce op and even an Elementwise op above that
 def _ast_binaryops(self:LazyBuffer) -> LazyOp:
@@ -55,7 +55,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   for x in real_srcs.keys():
     if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape)
   ast = self.op.map_buffers(real_srcs)
-  return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape, ast.buffers) if intermediate_shape != self.shape else ast
+  return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast

 # **** lazy operations ****

@@ -135,7 +135,7 @@ class LazyBuffer:
       self.dtype = self.realized.dtype

       # log to the graph
-      if self.realized.__class__ is not RawConst or GRAPH >= 2:
+      if self.realized.__class__ is not RawConst or GRAPH >= 2:
         from tinygrad.graph import log_op
         log_op(self, self.op)
@@ -149,7 +149,7 @@ class LazyBuffer:

   @staticmethod
   def fromCPU(x: np.ndarray) -> LazyBuffer:
-    return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None, ()), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
+    return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))

   # create a constant with the shape and dtype of self
   def const_like(self, val) -> LazyBuffer:
@@ -168,19 +168,19 @@ class LazyBuffer:
   def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
-    return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None, (self,)), self.dtype)
+    return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)

   def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0:
       return self.op.replace_with_movement_ops([(op, arg)])
-    ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg, (self,)), self.dtype)
+    ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype)
     if REMOVE_MOVEMENT_NOPS and not self.realized and not ret.realized and ret.st.contiguous:
       # MovementOps aren't stacked any more, they each have one parent, find the root
       root = get_movementroot(self)
       if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
         return root.reshape_op(ret.st.shape)
     return ret

   def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
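The REMOVE_MOVEMENT_NOPS branch above walks a chain of single-parent movement ops back to its root and, when the result is a contiguous view over the same number of elements, replaces the whole chain with one reshape of the root. A minimal standalone sketch of that idea (ToyView and the toy_* helpers are made-up names for illustration, not tinygrad code):

# Standalone sketch of the REMOVE_MOVEMENT_NOPS pruning above, under toy assumptions.
from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple

@dataclass
class ToyView:
  shape: Tuple[int, ...]
  contiguous: bool = True
  parent: "Optional[ToyView]" = None  # the view this one was derived from

def toy_movementroot(x: ToyView) -> ToyView:
  # movement ops each have one parent, so the root is found by walking up
  return toy_movementroot(x.parent) if x.parent is not None else x

def toy_prune(ret: ToyView) -> ToyView:
  root = toy_movementroot(ret)
  if root.contiguous and root is not ret and prod(ret.shape) == prod(root.shape):
    # the chain is a no-op: keep only a reshape of the root (stands in for root.reshape_op)
    return ToyView(ret.shape, contiguous=True, parent=root)
  return ret

base = ToyView((4, 4))
chain = ToyView((4, 4), parent=ToyView((4, 4), contiguous=False, parent=base))
assert toy_prune(chain).parent is base  # the two-op chain collapses onto the root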
@@ -227,26 +227,27 @@ class LazyBuffer:
       return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \
         .reshape_op(ShapeTracker(self.st).permute(arg).shape)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)

   def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)

   def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     local_st = ShapeTracker(self.shape).stride(arg)
     if self.shape == local_st.shape and local_st.contiguous: return self
     if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg)))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)

+  @property
+  def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
   def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
-  def get_buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
   def get_lazyops(self) -> List[Any]: return []
   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
     y = self
     for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
     return y

 def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
   new_srcs = []
   for x in srcs:
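replace_with_movement_ops above is a plain fold of a dispatch table over (op, arg) pairs. A standalone sketch of the pattern, with made-up names standing in for MOVEMENT_OPS_DISPATCHER and the real movement ops:

# Standalone sketch of the dispatcher-fold pattern in replace_with_movement_ops.
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Tuple

class ToyMovementOps(Enum):
  RESHAPE = auto()
  PERMUTE = auto()

# the "buffer" here is just a list holding the current shape
TOY_DISPATCHER: Dict[ToyMovementOps, Callable[[List[int], Any], List[int]]] = {
  ToyMovementOps.RESHAPE: lambda shape, arg: list(arg),
  ToyMovementOps.PERMUTE: lambda shape, arg: [shape[i] for i in arg],
}

def toy_replace_with_movement_ops(shape: List[int], ops: List[Tuple[ToyMovementOps, Any]]) -> List[int]:
  y = shape
  for op, arg in ops: y = TOY_DISPATCHER[op](y, arg)  # same fold as the method above
  return y

# permute [2, 3, 4] to [4, 2, 3], then reshape to [4, 6]
assert toy_replace_with_movement_ops([2, 3, 4], [(ToyMovementOps.PERMUTE, (2, 0, 1)), (ToyMovementOps.RESHAPE, (4, 6))]) == [4, 6]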
@@ -333,7 +334,7 @@ def _realize_empty(buffer: LazyBuffer) -> None:
 def _realize_rand(buffer: LazyBuffer) -> None:
   rng = np.random.default_rng(buffer.op.arg)
   buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=buffer.dtype.np), **buffer._device_extra_args()) # type: ignore

 def _realize_const(buffer: LazyBuffer) -> None:
   if hasattr(Device[buffer.device].codegen, 'supports_constant_folding'):
     buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg))
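_realize_rand seeds numpy's Generator with the op's arg, so realizing the same random buffer twice gives identical contents. A numpy-only illustration (the seed and shape are arbitrary stand-ins for buffer.op.arg and buffer.shape):

# numpy-only illustration of the seeded-RNG pattern used by _realize_rand above.
import numpy as np

seed, shape = 1337, (2, 3)   # stand-ins for buffer.op.arg and buffer.shape
a = np.random.default_rng(seed).random(size=shape, dtype=np.float32)
b = np.random.default_rng(seed).random(size=shape, dtype=np.float32)
assert np.array_equal(a, b)  # same seed, same realized contents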
tinygrad/ops.py
@@ -29,16 +29,12 @@ class LazyOp:
   arg: Any
   buffers: Tuple[LazyBuffer, ...]

-  def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None, buffers: Optional[Tuple[LazyBuffer, ...]] = None):
+  def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None):
     self.op = op
     self.src = src
     self.arg = arg
-    if not buffers:
-      buffers = tuple()
-      for s in src:
-        try: buffers += s.get_buffers()
-        except AttributeError: pass
-    self.buffers = buffers
+    # TODO: this hasattr is required because the key function maps the buffers to ints
+    self.buffers = functools.reduce(lambda x,s: (x+s.buffers) if hasattr(s, 'buffers') else x, src, tuple())

   def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
   def __eq__(self, __value: object) -> bool:
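The functools.reduce one-liner computes the same tuple as the deleted loop: fold over src, concatenating each source's buffers and skipping anything without that attribute (the cache-key function can substitute ints for buffers, hence the hasattr guard). A standalone sketch with a toy leaf class (ToyLeaf stands in for LazyBuffer; it is not tinygrad code):

# Standalone sketch: the new reduce and the removed loop collect the same buffers.
import functools

class ToyLeaf:
  # stand-in for LazyBuffer: its buffers are just itself
  @property
  def buffers(self): return (self,)
  def get_buffers(self): return (self,)   # the old API this commit removes

a, b = ToyLeaf(), ToyLeaf()
src = (a, b, 42)  # 42 models a source whose buffers were already mapped to an int

# new: one fold, skipping sources without a .buffers attribute
new_way = functools.reduce(lambda x, s: (x + s.buffers) if hasattr(s, 'buffers') else x, src, tuple())

# old: explicit loop over get_buffers(), ignoring sources that don't have it
old_way = tuple()
for s in src:
  try: old_way += s.get_buffers()
  except AttributeError: pass

assert new_way == old_way == (a, b)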
@@ -51,15 +47,13 @@ class LazyOp:

   # Any == Union[LazyBuffer, DeviceBuffer]
   def map_buffers(self, real_srcs: Dict[Any, Any]): return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
-  def get_buffers(self) -> Tuple[LazyBuffer, ...]: return self.buffers
   def get_lazyops(self) -> List['LazyOp']: return [self] + [item for x in self.src for item in x.get_lazyops()]

   def replace_with_movement_ops(self: LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
     from tinygrad.lazy import elementwise_op
     assert self.op in BinaryOps or self.op in UnaryOps
     return elementwise_op(self.op, *[z.replace_with_movement_ops(ops) for z in self.src], arg=self.arg) # type: ignore

   @property
   def st(self): raise NotImplementedError
   @property
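map_buffers rebuilds the op tree while looking leaves up in a substitution dict, and get_lazyops flattens the tree into a pre-order list of op nodes. A toy model of those two methods (ToyLeaf and ToyOp are illustrative stand-ins, not the real classes):

# Toy model of LazyOp.map_buffers / get_lazyops: nodes rebuild themselves,
# leaves are substituted through a dict, and ops are listed pre-order.
from typing import Any, Dict, List, Tuple

class ToyLeaf:
  def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
  def get_lazyops(self) -> List["ToyOp"]: return []

class ToyOp:
  def __init__(self, name: str, src: Tuple[Any, ...]):
    self.name, self.src = name, src
  def map_buffers(self, real_srcs: Dict[Any, Any]) -> "ToyOp":
    return ToyOp(self.name, tuple(y.map_buffers(real_srcs) for y in self.src))
  def get_lazyops(self) -> List["ToyOp"]:
    return [self] + [item for x in self.src for item in x.get_lazyops()]

x, y, z = ToyLeaf(), ToyLeaf(), ToyLeaf()
ast = ToyOp("ADD", (ToyOp("MUL", (x, y)), x))
assert [op.name for op in ast.get_lazyops()] == ["ADD", "MUL"]
remapped = ast.map_buffers({x: z})          # substitute x with z everywhere
assert remapped.src[1] is z and remapped.src[0].src[0] is z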