mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-11 07:58:08 -05:00)

commit: fix GPU import error and old python Tuple
@@ -3,17 +3,19 @@
 import unittest
 import numpy as np
-from typing import Optional
+from typing import Optional, Tuple
 from tinygrad.helpers import prod
 
 # *** first, we implement the atan2 op at the lowest level ***
-# `atan2_op` can handle both GPUBuffers and CPUBuffers
+# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
 
-from tinygrad.ops import ASTRunner
-from tinygrad.runtime.ops_gpu import GPUBuffer
+from tinygrad.ops import ASTRunner, CompiledBuffer
 from tinygrad.runtime.ops_cpu import CPUBuffer
 
-def atan2_gpu(a:GPUBuffer, b:GPUBuffer) -> GPUBuffer:
+# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
+def atan2_gpu(a:CompiledBuffer, b:CompiledBuffer) -> CompiledBuffer:
+  from tinygrad.runtime.ops_gpu import GPUBuffer
   assert type(a) == GPUBuffer and type(b) == GPUBuffer, "gpu function requires GPUBuffers"
   ret = GPUBuffer(a.shape)
   ASTRunner("atan2", """
     __kernel void atan2(global float *c, global float *a, global float *b) {
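Why the import moves in this hunk: a top-level `from tinygrad.runtime.ops_gpu import GPUBuffer` makes the whole file fail to import on machines without an OpenCL stack, even when only the CPU path is wanted. Deferring the import into `atan2_gpu` (and typing the signature against the abstract `CompiledBuffer`) means the ImportError can only surface if the GPU op is actually invoked. The same motive drives the removal of `atan2_dispatch` in the next hunk: its `isinstance(a, GPUBuffer)` check is what forced the module-level import in the first place. A minimal standalone sketch of the deferred-import pattern (pyopencl stands in for the GPU dependency; this is not tinygrad code):

import numpy as np

def gpu_only_op(a, b):
  # deferred: an ImportError is raised here, at call time, not when the
  # module containing this function is imported
  import pyopencl as cl
  ...

def cpu_op(a, b):
  # the CPU path never touches the GPU stack, so it works everywhere
  return np.arctan2(a, b)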
@@ -25,12 +27,6 @@ def atan2_gpu(a:GPUBuffer, b:GPUBuffer) -> GPUBuffer:
 def atan2_cpu(a:CPUBuffer, b:CPUBuffer) -> CPUBuffer:
   return CPUBuffer(np.arctan2(a._buf, b._buf))
 
-def atan2_dispatch(a, b):
-  assert prod(a.shape) == prod(b.shape) and type(a) == type(b), "shape or type mismatch"
-  if isinstance(a, GPUBuffer): return atan2_gpu(a, b)
-  elif isinstance(a, CPUBuffer): return atan2_cpu(a, b)
-  else: raise NotImplementedError(f"no atan2 implemented for {type(a)}")
-
 # *** second, we write the ATan2 mlop ***
 # NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
 # In general, it is also optional to write a backward function, just your backward pass won't work without it
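The NOTE above is easy to verify: writing denom = a^2 + b^2, the partial derivatives are d(atan2(a,b))/da = b/denom and d(atan2(a,b))/db = -a/denom, expressible entirely with ops tinygrad already has, which is why no custom backward kernel is needed. The `backward` in the next hunk computes exactly these. A quick numerical sanity check of the formulas (plain numpy, not tinygrad):

import numpy as np

a, b, eps = 0.7, -1.3, 1e-6
denom = a*a + b*b
da, db = b/denom, -a/denom  # analytic partials of atan2(a, b)
# central finite differences should match to well under 1e-6
assert abs((np.arctan2(a+eps, b) - np.arctan2(a-eps, b))/(2*eps) - da) < 1e-6
assert abs((np.arctan2(a, b+eps) - np.arctan2(a, b-eps))/(2*eps) - db) < 1e-6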
@@ -41,10 +37,11 @@ from tinygrad.tensor import Function
 class ATan2(Function):
   def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
     assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
     self.a, self.b = a, b
-    ast = LazyOp(LoadOps.CUSTOM, (a, b), atan2_dispatch)
+    ast = LazyOp(LoadOps.CUSTOM, (a, b), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
     return LazyBuffer(a.device, a.shape, LoadOps, ast)
-  def backward(self, grad_output:LazyBuffer) -> tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b))
     return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
       grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None
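The other half of the commit message, "old python Tuple": annotations like `tuple[Optional[LazyBuffer], Optional[LazyBuffer]]` use PEP 585 builtin generics, which only work on Python 3.9+; on 3.8 and earlier, subscripting the builtin `tuple` raises `TypeError: 'type' object is not subscriptable` the moment the `def` line executes. Switching to `typing.Tuple` (newly imported in this commit's first hunk) keeps the file importable on older interpreters. A standalone sketch of the failure mode:

from typing import Optional, Tuple

# On Python <= 3.8 the following definition fails at import time, because
# the builtin `tuple` type is not subscriptable there:
#   def backward(g: int) -> tuple[Optional[int], Optional[int]]: ...
# The typing-module spelling below works on old and new interpreters alike:
def backward(g: int) -> Tuple[Optional[int], Optional[int]]:
  return (g, -g)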