Reduce line count (#424)

* save a line, save a life

* save a line, save a life

* change order of tern
This commit is contained in:
Daniel Davis
2022-11-09 13:07:22 -05:00
committed by GitHub
parent 0994705166
commit 64ff1ddc10
2 changed files with 2 additions and 4 deletions

View File

@@ -64,8 +64,7 @@ class Max(Function):
max_is_1s = x.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, x.shape))
# sum of locations, averaged
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape)
div = div.movement_op(MovementOps.EXPAND, x.shape)
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).movement_op(MovementOps.EXPAND, x.shape)
max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, x.shape)

View File

@@ -19,8 +19,7 @@ class Tensor:
data = data.realize().toCPU()
if isinstance(data, np.ndarray):
if data.shape == tuple():
data = data.reshape((1,))
data = data if data.shape else data.reshape((1,))
self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
elif isinstance(data, LazyBuffer):
self.lazydata = data