fix mypy, add TODOs

This commit is contained in:
George Hotz
2022-07-08 09:57:22 -07:00
parent 8557ed88df
commit 7e17f2ae8d

View File

@@ -24,6 +24,7 @@ DEBUG = int(os.getenv("DEBUG", "0"))
GRAPH = int(os.getenv("GRAPH", "0"))
OPT = int(os.getenv("OPT", "1"))
# TODO: movement ops that only change shape are really nops. treat them as such
MERGE_MOVEMENT_OPS, REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS = OPT>=1, OPT>=1, OPT>=1
MERGE_ELEMENTWISE_OPS, MERGE_ONE_CONV_INTO_ELEMENTWISE, MERGE_ELEMENTWISE_INTO_REDUCE = OPT>=2, OPT>=2, OPT>=2
SHUFFLE_MOVEMENT_OPS = OPT>=3
@@ -104,6 +105,7 @@ def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer],
return Device._buffers[self.device].fromCPU(self.op.arg), [], LoadOps
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
# TODO: this can also corealize a binary op after the reduce, not just before
src = self.op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and getattr(self.dbuffer, "start_for_op", None) and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
# TODO: this code is (somewhat) repeated in _realize_binaryops
@@ -111,7 +113,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}
return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()], \
_ast(LazyOp(self.op.op, [src.op], self.op.arg), buf_names, self.dbuffer.code_for_op), start=self.dbuffer.start_for_op[self.op.op]), \
_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), start=self.dbuffer.start_for_op[self.op.op]), \
list(real_srcs.values()), ReduceOps
else:
real_src = src.realize(self.device)