mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
no RESHAPEs in the AST
This commit is contained in:
@@ -153,8 +153,7 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y),
|
||||
BinaryOps.DIV: lambda builder,x,y: builder.fdiv(x,y),
|
||||
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y]),
|
||||
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType()),
|
||||
MovementOps.RESHAPE: lambda builder,x: x,
|
||||
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y), ir.FloatType())
|
||||
}
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
|
||||
super().__init__(shape, hostbuf)
|
||||
|
||||
@@ -67,28 +67,31 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
|
||||
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
|
||||
op_type : OpType = BinaryOps
|
||||
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
|
||||
intermediate_shape = self.shape
|
||||
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
|
||||
if psrcs[0][1].optype == ProcessingOps:
|
||||
real_srcs[psrcs[0][0]] = psrcs[0][1].op
|
||||
real_srcs[psrcs[0][1].op.src[0]], real_srcs[psrcs[0][1].op.src[1]] = None, None
|
||||
for x in psrcs[0][1].op.src:
|
||||
real_srcs[x] = x.realize(self.device)
|
||||
op_type = ProcessingOps
|
||||
elif psrcs[0][1].optype == ReduceOps:
|
||||
src = psrcs[0][1].op.src[0]
|
||||
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
|
||||
src = src.op
|
||||
for x in (get_buffers(src) if isinstance(src, LazyOp) else [src]):
|
||||
real_srcs[x] = None
|
||||
real_srcs[x] = x.realize(self.device)
|
||||
real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
|
||||
op_type = ReduceOps
|
||||
# a reshape is allowed after the ReduceOp, it's a nop in the backend
|
||||
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
|
||||
if psrcs[0][0].shape != psrcs[0][1].shape:
|
||||
real_srcs[psrcs[0][0]] = LazyOp(MovementOps.RESHAPE, (real_srcs[psrcs[0][0]],), psrcs[0][0].shape)
|
||||
intermediate_shape = psrcs[0][1].shape
|
||||
assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}"
|
||||
# NOTE: these RESHAPEs will return self if they don't change the shape
|
||||
for x in real_srcs.keys():
|
||||
if real_srcs[x] is None:
|
||||
real_srcs[x] = x.realize(self.device)
|
||||
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
|
||||
ret = self.dbuffer.exec_ast(realize_buffers(real_srcs, self.op))
|
||||
assert ret.shape == self.shape, f"shape mismatch {ret.shape} != {self.shape}"
|
||||
return ret, [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
|
||||
return ret.movement_op(MovementOps.RESHAPE, self.shape), [x for x in real_srcs.values() if not isinstance(x, LazyOp) and x is not None], op_type
|
||||
|
||||
_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class GPUBuffer(ExplicitExecAST):
|
||||
UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
|
||||
BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)", MovementOps.RESHAPE: "(A)"
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
|
||||
}
|
||||
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user