lazy cleanup

This commit is contained in:
George Hotz
2023-02-07 07:39:53 -06:00
parent d93563f39f
commit 02d8cb0959

View File

@@ -35,11 +35,13 @@ def realize_buffers(real_srcs, x:LazyOp) -> LazyOp:
# **** realize functions ****
# TODO: make all _realize functions return an AST, perhaps unrealized
# NOTE: loadops and movementops aren't valid ASTs and won't become kernels
def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], Optional[OpType]]:
if self.op.op == LoadOps.FROMCPU:
return Device._buffers[self.device].fromCPU(self.op.arg), [], LoadOps
elif self.op.op == LoadOps.CONTIGUOUS:
# under the hood, this is an AST or a no op. rename to MetaOps?
real_src = self.op.src[0].realize(self.device)
ret = real_src.contiguous()
return ret, [real_src], LoadOps if id(ret) != id(real_src) else None
@@ -50,6 +52,7 @@ def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuff
src = self.op.src[0]
# fuse RESHAPE and ReduceOps
# TODO: add MetaOps.TOIMAGE instead?
if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
return _realize_reduceops_w_shape(src, output_shape = self.op.arg)
@@ -57,8 +60,8 @@ def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuff
return real_src.movement_op(self.op.op, self.op.arg), [real_src], MovementOps
def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_src_x, real_src_w = [x.realize(self.device) for x in self.op.src]
return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w], ProcessingOps
ast = LazyOp(self.op.op, tuple(x.realize(self.device) for x in self.op.src), self.op.arg)
return self.dbuffer.exec_ast(ast), get_buffers(ast), ProcessingOps
# this supports late merging an upstream Elementwise op
def _realize_reduceops_w_shape(self:LazyBuffer, output_shape=None) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
@@ -69,11 +72,9 @@ def _realize_reduceops_w_shape(self:LazyBuffer, output_shape=None) -> Tuple[Devi
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_buffers(src.op)}
ast = LazyOp(self.op.op, (realize_buffers(real_srcs, src.op),), self.op.arg)
else:
real_src = src.realize(self.device)
real_srcs = {src:real_src}
ast = LazyOp(self.op.op, (real_src,), self.op.arg)
ast = LazyOp(self.op.op, (src.realize(self.device),), self.op.arg)
if output_shape is not None: ast = LazyOp(MovementOps.RESHAPE, (ast, ), output_shape)
return self.dbuffer.exec_ast(ast), list(real_srcs.values()), ReduceOps
return self.dbuffer.exec_ast(ast), get_buffers(ast), ReduceOps
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]: return _realize_reduceops_w_shape(self)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
@@ -115,8 +116,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
if real_srcs[x] is None:
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
ast = LazyOp(MovementOps.RESHAPE, (realize_buffers(real_srcs, self.op), ), self.shape)
ret = self.dbuffer.exec_ast(ast)
return ret, [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
return self.dbuffer.exec_ast(ast), get_buffers(ast), op_type
_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}