From 64ff1ddc1074cbe7dc2f4b8f620eb25fbe14ef0c Mon Sep 17 00:00:00 2001 From: Daniel Davis Date: Wed, 9 Nov 2022 13:07:22 -0500 Subject: [PATCH] Reduce line count (#424) * save a line, save a life * save a line, save a life * change order of tern --- tinygrad/mlops.py | 3 +-- tinygrad/tensor.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 00ad2a8531..d82bff4922 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -64,8 +64,7 @@ class Max(Function): max_is_1s = x.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, x.shape)) # sum of locations, averaged - div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape) - div = div.movement_op(MovementOps.EXPAND, x.shape) + div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).movement_op(MovementOps.EXPAND, x.shape) max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div) grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, x.shape) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b3dff6a02f..bb58b7313c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -19,8 +19,7 @@ class Tensor: data = data.realize().toCPU() if isinstance(data, np.ndarray): - if data.shape == tuple(): - data = data.reshape((1,)) + data = data if data.shape else data.reshape((1,)) self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device) elif isinstance(data, LazyBuffer): self.lazydata = data