no RESHAPEs in the AST

This commit is contained in:
George Hotz
2022-10-28 18:25:30 -07:00
parent 294ab9e2f8
commit 71b336503f
3 changed files with 12 additions and 10 deletions

View File

@@ -153,8 +153,7 @@ class LLVMBuffer(ExplicitExecAST):
BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y),
BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y),
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y]),
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType()),
MovementOps.RESHAPE: lambda builder,x: x,
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType())
}
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
super().__init__(shape, hostbuf)

View File

@@ -67,28 +67,31 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
op_type : OpType = BinaryOps
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape = self.shape
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
if psrcs[0][1].optype == ProcessingOps:
real_srcs[psrcs[0][0]] = psrcs[0][1].op
real_srcs[psrcs[0][1].op.src[0]], real_srcs[psrcs[0][1].op.src[1]] = None, None
for x in psrcs[0][1].op.src:
real_srcs[x] = x.realize(self.device)
op_type = ProcessingOps
elif psrcs[0][1].optype == ReduceOps:
src = psrcs[0][1].op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
src = src.op
for x in (get_buffers(src) if isinstance(src, LazyOp) else [src]):
real_srcs[x] = None
real_srcs[x] = x.realize(self.device)
real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
op_type = ReduceOps
# a reshape is allowed after the ReduceOp, it's a nop in the backend
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if psrcs[0][0].shape != psrcs[0][1].shape:
real_srcs[psrcs[0][0]] = LazyOp(MovementOps.RESHAPE, (real_srcs[psrcs[0][0]],), psrcs[0][0].shape)
intermediate_shape = psrcs[0][1].shape
assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}"
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None:
real_srcs[x] = x.realize(self.device)
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
ret = self.dbuffer.exec_ast(realize_buffers(real_srcs, self.op))
assert ret.shape == self.shape, f"shape mismatch {ret.shape} != {self.shape}"
return ret, [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
return ret.movement_op(MovementOps.RESHAPE, self.shape), [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}

View File

@@ -82,7 +82,7 @@ class GPUBuffer(ExplicitExecAST):
UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)", MovementOps.RESHAPE: "(A)"
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
}
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}