From 11d0cfec770e59fff706e29ef4dc8a9169b67578 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 5 Jun 2022 14:13:08 -0700
Subject: [PATCH] more readable and faster

---
 tinygrad/ops/ops_gpu.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tinygrad/ops/ops_gpu.py b/tinygrad/ops/ops_gpu.py
index 8facf90937..bb64b6ff4b 100644
--- a/tinygrad/ops/ops_gpu.py
+++ b/tinygrad/ops/ops_gpu.py
@@ -14,19 +14,19 @@ class UnaryOp(Function):

   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return binary_op(ctx, ctx.bop, grad_output, input)
+    return binary_op(ctx, ctx.bop, input, grad_output)

 class ReLU(UnaryOp):
   fop = 'max(a, (float)0.)'
-  bop = 'a * (b >= 0)'
+  bop = 'b * (a >= 0)'

 class Log(UnaryOp):
   fop = 'log(a)'
-  bop = 'a / b'
+  bop = 'b / a'

 class Exp(UnaryOp):
   fop = 'exp(a)'
-  bop = 'a * exp(b)'
+  bop = 'b * exp(a)'

 # ************* reduce ops *************

@@ -37,7 +37,8 @@ class Sum(Function):

   def backward(ctx, grad_output):
     shape_input, = ctx.saved_tensors
-    return binary_op(ctx, 'a+b', grad_output, buffer_new(ctx, shape_input, zero=True))
+    # NOTE: the b buffer_new isn't used, since this is just for broadcast
+    return binary_op(ctx, 'a', grad_output, buffer_new(ctx, shape_input))

 class Max(Function):
   def forward(ctx, input, axis=None):
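
A note on the unary-op half of this patch: the binary_op call in
UnaryOp.backward now binds a to the saved forward input and b to grad_output,
and each bop expression string is rewritten to match, so the strings read as
the usual chain-rule formulas. Below is a minimal NumPy sketch of the new
binding; binary_op_sketch is a hypothetical CPU stand-in for the real GPU
binary_op, which (as far as this patch shows) evaluates these expression
strings elementwise over two buffers:

  import numpy as np

  def binary_op_sketch(bop, a, b):
    # Evaluate the expression string with the binding from the patched
    # backward: a = forward input, b = grad_output.
    return eval(bop, {"exp": np.exp}, {"a": a, "b": b})

  x = np.array([-1.0, 0.5, 2.0])  # saved forward input
  g = np.ones(3)                  # incoming grad_output

  print(binary_op_sketch("b * (a >= 0)", x, g))  # ReLU: grad flows where input >= 0
  print(binary_op_sketch("b / a", x, g))         # Log:  d/dx log(x) = 1/x
  print(binary_op_sketch("b * exp(a)", x, g))    # Exp:  d/dx exp(x) = exp(x)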
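
And on the Sum.backward change: the gradient of a sum is grad_output broadcast
to the input's shape, so the expression 'a' never reads the second buffer; it
only supplies the broadcast target shape. That is why buffer_new no longer
needs zero=True, and skipping that fill (plus replacing 'a+b' with 'a') is
presumably the "faster" in the subject line. A sketch of the same idea in
NumPy, with np.broadcast_to standing in for the binary_op broadcast:

  import numpy as np

  def sum_backward_sketch(grad_output, shape_input):
    # Stand-in for binary_op(ctx, 'a', grad_output, buffer_new(ctx, shape_input)):
    # expand grad_output to the input shape; the scratch buffer is never read.
    return np.broadcast_to(grad_output, shape_input)

  g = np.array(1.0)                      # gradient of a scalar sum
  print(sum_backward_sketch(g, (2, 3)))  # 2x3 array of ones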