functools.partial keeps mypy compiler working

This commit is contained in:
George Hotz
2023-02-08 18:04:32 -06:00
parent cfd13c083b
commit 4c2faa4140

View File

@@ -52,7 +52,7 @@ shape_fxn_for_op : Dict[Op, Callable] = {
**{op:lambda self: GenericShape(self.shape, self.flops + prod(self.shape)) for op in UnaryOps},
**{op:lambda self,y: GenericShape(self.shape, self.flops + y.flops + prod(self.shape)) for op in BinaryOps},
**{op:lambda self,new_shape: GenericShape(new_shape, self.flops + prod(self.shape)) for op in ReduceOps},
**{op:(lambda mop: (lambda self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.flops)))(op) for op in MovementOps},
**{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.flops), op) for op in MovementOps},
# https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
**{op:lambda self,w,C: GenericShape(C.out_shape, 2 * (C.bs * C.cout * C.oy * C.ox) * (C.cin * C.H * C.W)) for op in ProcessingOps}}