less lines (#643)

This commit is contained in:
Peter McDevitt
2023-03-05 09:37:14 -07:00
committed by GitHub
parent 3da56ab41d
commit 30f2238994

View File

@@ -67,10 +67,7 @@ class DeviceBuffer(RawBuffer):
# this is a quick "buffer" class for flop tracking and getting the output shape
class GenericShape:
def __init__(self, shape:Tuple[int, ...], flops:int=0): self.shape, self.flops = shape, flops
def consume_flops(self):
ret = self.flops
self.flops = 0
return ret
def consume_flops(self): return (self.flops, setattr(self, 'flops', 0))[0]
shape_fxn_for_op : Dict[Op, Callable] = {
**{op:lambda self: GenericShape(self.shape, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
**{op:lambda self,y: GenericShape(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},