mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
@@ -67,7 +67,10 @@ class DeviceBuffer(RawBuffer):
|
||||
# this is a quick "buffer" class for flop tracking and getting the output shape
|
||||
class GenericShape:
|
||||
def __init__(self, shape:Tuple[int, ...], flops:int=0): self.shape, self.flops = shape, flops
|
||||
def consume_flops(self): return (self.flops, setattr(self, 'flops', 0))[0]
|
||||
def consume_flops(self):
|
||||
ret = self.flops
|
||||
self.flops = 0
|
||||
return ret
|
||||
shape_fxn_for_op : Dict[Op, Callable] = {
|
||||
**{op:lambda self: GenericShape(self.shape, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
|
||||
**{op:lambda self,y: GenericShape(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
|
||||
|
||||
Reference in New Issue
Block a user