PTX assembly support (#977)

* ptx assembly

* all ops tests pass

* fix tests
This commit is contained in:
George Hotz
2023-06-13 12:31:42 -07:00
committed by GitHub
parent 727416201f
commit ba4eadb04c
9 changed files with 280 additions and 26 deletions

View File

@@ -471,7 +471,7 @@ class Tensor:
def cumsum(self, axis=0):
x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))
# ***** mlops (unary) *****
def contiguous(self): return mlops.Contiguous.apply(self)
@@ -481,12 +481,12 @@ class Tensor:
def sin(self): return mlops.Sin.apply(self)
def cos(self): return ((math.pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
@staticmethod
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
# ***** math functions (unary) *****
def __neg__(self): return 0.0-self
@@ -527,7 +527,12 @@ class Tensor:
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x == 2.0: return self*self
if x == -1.0: return 1/self
return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)