Revert "less lines (#643)" (#644)

This reverts commit 30f2238994.
This commit is contained in:
George Hotz
2023-03-05 08:41:11 -08:00
committed by GitHub
parent 30f2238994
commit e8de3f5736

View File

@@ -67,7 +67,10 @@ class DeviceBuffer(RawBuffer):
# this is a quick "buffer" class for flop tracking and getting the output shape
class GenericShape:
def __init__(self, shape:Tuple[int, ...], flops:int=0): self.shape, self.flops = shape, flops
def consume_flops(self): return (self.flops, setattr(self, 'flops', 0))[0]
def consume_flops(self):
ret = self.flops
self.flops = 0
return ret
shape_fxn_for_op : Dict[Op, Callable] = {
**{op:lambda self: GenericShape(self.shape, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
**{op:lambda self,y: GenericShape(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},