mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
use ast engine for merged reduceop
This commit is contained in:
@@ -111,8 +111,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
|
||||
buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}
|
||||
|
||||
return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()], \
|
||||
self.dbuffer.code_for_op[self.op.op].replace("A", _ast(src.op, buf_names, self.dbuffer.code_for_op)), \
|
||||
start=self.dbuffer.start_for_op[self.op.op]), \
|
||||
_ast(LazyOp(self.op.op, [src.op], self.op.arg), buf_names, self.dbuffer.code_for_op), start=self.dbuffer.start_for_op[self.op.op]), \
|
||||
list(real_srcs.values()), ReduceOps
|
||||
else:
|
||||
real_src = src.realize(self.device)
|
||||
@@ -136,6 +135,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
|
||||
conv_args : Optional[ConvArgs] = None
|
||||
psrcs = [x for x in real_srcs.keys() if x.optype == ProcessingOps and x.realized is None and len(x.children) <= 1]
|
||||
if len(psrcs) == 1 and MERGE_ONE_CONV_INTO_ELEMENTWISE:
|
||||
# TODO: do something similar to what i did with reduceop to use the ast engine?
|
||||
conv_args = psrcs[0].op.arg
|
||||
del real_srcs[psrcs[0]]
|
||||
real_srcs[psrcs[0].op.src[0]], real_srcs[psrcs[0].op.src[1]] = None, None
|
||||
|
||||
Reference in New Issue
Block a user