From 8de160d08ee386faef2e48a63d178259d93b16d8 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 2 Jan 2024 12:52:20 -0800
Subject: [PATCH] hotfix: remove dead code, save lines

---
 tinygrad/mlops.py | 46 +++++++++++++++-------------------------------
 1 file changed, 15 insertions(+), 31 deletions(-)

diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index efe8639516..05dd925192 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -20,8 +20,7 @@ class Cast(Function):
     self.input_dtype, self.bitcast = x.dtype, bitcast
     return x.cast(dtype, bitcast)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.cast(self.input_dtype, self.bitcast)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)
 
 # ************* unary ops *************
 
@@ -55,16 +54,14 @@ class Log(Function):
     self.x = x
     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.DIV, self.x)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.e(BinaryOps.MUL, grad_output)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
 
 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
@@ -88,29 +85,23 @@ class Sigmoid(Function):
 # ************* binary ops *************
 
 class Less(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPLT, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
 
 class Eq(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPEQ, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
 
 class Xor(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.XOR, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
 
 class Add(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.ADD, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None
 
 class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype = x.dtype, y.dtype
-    return x.e(BinaryOps.SUB, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
@@ -153,8 +144,7 @@ class Sum(Function):
     self.input_shape = x.shape
     return x.r(ReduceOps.SUM, new_shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.expand(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
   def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
@@ -175,45 +165,39 @@ class Expand(Function):
     self.input_shape = x.shape
     return x.expand(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.r(ReduceOps.SUM, self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.input_shape)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.reshape(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)
 
 class Permute(Function):
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
     return x.permute(order)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.permute(argsort(self.input_order))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))
 
 class Pad(Function):
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
     return x.pad(arg)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.shrink(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)
 
 class Shrink(Function):
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride(self.arg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)