Onnx slice fixups (#952)

* resolved some slice test errors and added some more debugging logs

* use same device in cumsum

* increased float priority

* onnx debug ouput match input
This commit is contained in:
Diogo
2023-06-07 22:44:30 -04:00
committed by GitHub
parent e8a23d4331
commit 666d151f8a
5 changed files with 36 additions and 32 deletions

View File

@@ -76,7 +76,7 @@ class dtypes:
def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
bool: Final[DType] = DType(0, 1, "bool", bool)
float16: Final[DType] = DType(0, 2, "half", np.float16)
float32: Final[DType] = DType(1, 4, "float", np.float32)
float32: Final[DType] = DType(4, 4, "float", np.float32)
int8: Final[DType] = DType(0, 1, "char", np.int8)
int32: Final[DType] = DType(1, 4, "int", np.int32)
int64: Final[DType] = DType(2, 8, "int64", np.int64)

View File

@@ -143,7 +143,7 @@ class Tensor:
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
@staticmethod
def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step).cumsum() + (start - step)
def arange(stop, start=0, step=1, **kwargs): return Tensor.full(((stop-start)//step,), step, **kwargs).cumsum() + (start - step)
@staticmethod
def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
@@ -470,7 +470,7 @@ class Tensor:
def cumsum(self, axis=0):
x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis]), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))
return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))
# ***** mlops (unary) *****