mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Disktensors! (#819)
* make empty a real thing * start ops_disk * disk tensor works * interpreted cleanup * slice write to disk * preprocess imagenet * fix custom function
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -24,3 +24,5 @@ datasets/cifar-10-python.tar.gz
|
||||
datasets/librispeech/
|
||||
datasets/imagenet/
|
||||
datasets/squad/
|
||||
datasets/img_align_celeba*
|
||||
datasets/open-images-v6-mlperf
|
||||
|
||||
22
datasets/preprocess_imagenet.py
Normal file
22
datasets/preprocess_imagenet.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from datasets.imagenet import iterate, get_val_files
|
||||
|
||||
if __name__ == "__main__":
  # Fixed sample budget; the commented-out line would instead use the number of
  # entries returned by get_val_files().
  #sz = len(get_val_files())
  sz = 32*100
  X = Y = None
  idx = 0

  for x,y in iterate(shuffle=False):
    print(x.shape, y.shape)
    bs = x.shape[0]
    assert bs == y.shape[0]

    # The disk-backed tensors are created on the first batch, because only then
    # are the per-sample shapes (everything after the batch dimension) known.
    if X is None:
      # TODO: need uint8 support
      X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x")
      Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y")
      print(X.shape, Y.shape)

    # Write this batch into rows [idx, end) of each on-disk tensor.
    end = idx + bs
    X[idx:end].assign(x)
    Y[idx:end].assign(y)
    idx = end
    if idx >= sz: break
|
||||
@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
|
||||
return ret.realized
|
||||
|
||||
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
|
||||
return Device[ret.device].buffer(np.arctan2(a.realized._buf, b.realized._buf))
|
||||
return Device[ret.device].from_underlying(np.arctan2(a.realized._buf, b.realized._buf))
|
||||
|
||||
# *** second, we write the ATan2 mlop ***
|
||||
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
|
||||
|
||||
@@ -56,6 +56,8 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
|
||||
class TestOps(unittest.TestCase):
|
||||
def test_zeros(self):
|
||||
helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
|
||||
def test_empty_0(self):
|
||||
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
|
||||
def test_ones(self):
|
||||
helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
|
||||
def test_eye(self):
|
||||
|
||||
56
test/unit/test_disk_tensor.py
Normal file
56
test/unit/test_disk_tensor.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import pathlib
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestDiskTensor(unittest.TestCase):
  """Tests for tensors placed on a ``disk:<path>`` device.

  Each test removes its backing file under /tmp first so it starts from a
  fresh file.
  """

  def test_empty(self):
    """Creating an empty disk tensor at a fresh path should not raise."""
    pathlib.Path("/tmp/dt1").unlink(missing_ok=True)
    Tensor.empty(100, 100, device="disk:/tmp/dt1")

  def test_write_ones(self):
    """Moving a CPU tensor of ones to disk writes the expected bytes, and they can be read back."""
    pathlib.Path("/tmp/dt2").unlink(missing_ok=True)

    out = Tensor.ones(10, 10, device="CPU")
    outdisk = out.to("disk:/tmp/dt2")
    print(outdisk)
    outdisk.realize()
    del out, outdisk

    # test file: 100 copies of the little-endian float32 encoding of 1.0
    with open("/tmp/dt2", "rb") as f:
      assert f.read() == b"\x00\x00\x80\x3F" * 100

    # test load alt: an "empty" tensor over the same path exposes the existing file contents
    reloaded = Tensor.empty(10, 10, device="disk:/tmp/dt2")
    out = reloaded.numpy()
    assert np.all(out == 1.)

  def test_slice(self):
    """Slicing a 1-D disk tensor and moving the slice to CPU yields the right element."""
    pathlib.Path("/tmp/dt3").unlink(missing_ok=True)
    Tensor.arange(10, device="disk:/tmp/dt3").realize()

    slice_me = Tensor.empty(10, device="disk:/tmp/dt3")
    print(slice_me)
    is_3 = slice_me[3:4].cpu()
    assert is_3.numpy()[0] == 3

  def test_slice_2d(self):
    """Indexing the first dimension of a 2-D disk tensor returns the matching row."""
    pathlib.Path("/tmp/dt5").unlink(missing_ok=True)
    Tensor.arange(100, device="CPU").to("disk:/tmp/dt5").realize()
    slice_me = Tensor.empty(10, 10, device="disk:/tmp/dt5")
    tst = slice_me[1].numpy()
    print(tst)
    np.testing.assert_allclose(tst, np.arange(10, 20))

  def test_assign_slice(self):
    """Assigning into a slice of a disk tensor changes only the sliced elements."""
    pathlib.Path("/tmp/dt4").unlink(missing_ok=True)
    cc = Tensor.arange(10, device="CPU").to("disk:/tmp/dt4").realize()
    #cc.assign(np.ones(10)).realize()
    print(cc[3:5].numpy())
    cc[3:5].assign([13, 12]).realize()
    print(cc.numpy())
    # Without this check the test passes even if the write never reaches the backing buffer.
    np.testing.assert_allclose(cc.numpy(), [0, 1, 2, 13, 12, 5, 6, 7, 8, 9])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,6 +8,7 @@ def multidevice_test(fxn):
|
||||
exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
|
||||
def ret(self):
|
||||
for device in Device._buffers:
|
||||
if device == "DISK": continue
|
||||
print(device)
|
||||
if device in exclude_devices:
|
||||
print(f"WARNING: {device} test is excluded")
|
||||
|
||||
@@ -102,7 +102,7 @@ class LazyBuffer:
|
||||
if GRAPH >= 3: log_op(self, self.op, phantom=True)
|
||||
|
||||
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"
|
||||
def _device_extra_args(self) -> Dict[str, int]: return {"device": int(self.device.split(":")[1])} if ":" in self.device else {}
|
||||
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":")[1]} if ":" in self.device else {}
|
||||
|
||||
def realize(self:LazyBuffer) -> LazyBuffer:
|
||||
if self.realized is None:
|
||||
@@ -113,6 +113,8 @@ class LazyBuffer:
|
||||
else:
|
||||
if DEBUG >= 4: print(f"copying {self.op.arg.shape}:{dtypes.from_np(self.op.arg.dtype)} -> {self.device}")
|
||||
self.realized = Device[self.device].buffer.fromCPU(self.op.arg(), **self._device_extra_args())
|
||||
elif self.op.op == LoadOps.EMPTY:
|
||||
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
|
||||
elif self.op.op == LoadOps.CONTIGUOUS:
|
||||
realized = self.op.src[0].realize().realized
|
||||
if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
|
||||
@@ -161,6 +163,10 @@ class LazyBuffer:
|
||||
def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
|
||||
return create_lazybuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x), dtypes.from_np(x.dtype))
|
||||
|
||||
@staticmethod
|
||||
def empty(shape, dtype, device) -> LazyBuffer:
|
||||
return create_lazybuffer(device, shape, LoadOps, LazyOp(LoadOps.EMPTY, tuple()), dtype)
|
||||
|
||||
# create a constant with the shape and dtype of self
|
||||
def const_like(self, val) -> LazyBuffer:
|
||||
return create_lazybuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \
|
||||
|
||||
@@ -12,7 +12,7 @@ class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto();
|
||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class FusedOps(Enum): MULACC = auto() # noqa: E702
|
||||
class LoadOps(Enum): FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); FROMCPU = auto(); CONTIGUOUS = auto(); TOCPU = auto(); CUSTOM = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]
|
||||
@@ -34,24 +34,25 @@ def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
|
||||
# **************** for Interpreted Buffers ****************
|
||||
|
||||
class Interpreted:
|
||||
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf):
|
||||
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf, from_underlying=None):
|
||||
self.buffer = buffer
|
||||
self.fxn_for_op = fxn_for_op
|
||||
self.from_lazybuffer = from_lazybuffer
|
||||
self.from_underlying = buffer if from_underlying is None else from_underlying
|
||||
self.to_underlying = to_underlying
|
||||
self.synchronize = lambda: None
|
||||
self.codegen = None
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, context=None):
|
||||
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
|
||||
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
|
||||
created_context = context is None
|
||||
if context is None: context = dict()
|
||||
if not created_context and ast in context: return context[ast]
|
||||
srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
if DEBUG >= 3: st = time.perf_counter()
|
||||
ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
if not created_context: context[ast] = ret
|
||||
if output is not None and output.output_buffer is not None:
|
||||
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
||||
@@ -151,7 +152,7 @@ class Compiled:
|
||||
k = self.codegen(ast, output)
|
||||
|
||||
# this is the default now
|
||||
if getenv("ENABLE_METHOD_CACHE", 1):
|
||||
if hasattr(k, 'key') and getenv("ENABLE_METHOD_CACHE", 1):
|
||||
if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime)
|
||||
elif DEBUG >= 5: print(f"method cache hit : {k.key}")
|
||||
prg = self.method_cache[k.key]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import operator
|
||||
from typing import Callable, Dict, Tuple
|
||||
from tinygrad.helpers import dtypes
|
||||
from typing import Callable, Dict, Tuple, Optional
|
||||
from tinygrad.helpers import dtypes, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, Op, Interpreted
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
@@ -35,8 +35,8 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
}}
|
||||
|
||||
class RawNumpyBuffer(RawBuffer):
|
||||
def __init__(self, buf:np.ndarray): super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
|
||||
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
|
||||
@classmethod
|
||||
def fromCPU(cls, x): return cls(x)
|
||||
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
|
||||
def toCPU(self): return self._buf
|
||||
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
|
||||
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
|
||||
|
||||
25
tinygrad/runtime/ops_disk.py
Normal file
25
tinygrad/runtime/ops_disk.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import os, mmap
|
||||
from typing import Optional
|
||||
from typing import Callable, Dict
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.runtime.lib import RawBufferMapped
|
||||
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps
|
||||
|
||||
class RawDiskBuffer(RawBufferMapped):
  """Raw buffer whose storage is a memory-mapped file, or a view into one.

  The logical ``shape`` is tracked on the buffer so that ``shrink`` can turn a
  slice of the first dimension into a byte range of the underlying buffer.
  """

  def __init__(self, size, dtype, device:Optional[str]=None, buf=None, shape=None):
    """Create a disk buffer.

    size:   number of elements in the buffer.
    dtype:  element type; its ``itemsize`` is used for all byte arithmetic.
    device: path of the backing file. When given, the file is mapped and any
            ``buf`` argument is ignored.
    buf:    existing buffer to wrap (used by ``reshape``/``shrink`` to create views).
    shape:  logical shape; defaults to the flat ``(size,)``.

    Raises AssertionError if neither ``device`` nor ``buf`` is supplied.
    """
    self.shape = (size, ) if shape is None else shape
    assert device is not None or buf is not None, "disk tensor needs a path or a buf"
    if device is not None:
      nbytes = size * dtype.itemsize
      # "a+b" creates the file when missing and never truncates existing contents.
      with open(device, "a+b") as f:
        # Grow (never shrink) the file so the whole mapping is backed by it.
        if os.path.getsize(device) < nbytes: os.ftruncate(f.fileno(), nbytes)
        buf = memoryview(mmap.mmap(f.fileno(), nbytes))
    super().__init__(size, dtype, buf)

  def reshape(self, arg):
    """Return a buffer over the same storage with logical shape ``arg``."""
    return RawDiskBuffer(self.size, self.dtype, buf=self._buffer(), shape=arg)

  def shrink(self, arg):
    """Return a view restricted to ``arg[0] == (start, stop)`` along the first dimension.

    ``arg`` holds one ``(start, stop)`` pair per dimension; every dimension
    after the first must be taken whole, otherwise AssertionError is raised.
    """
    assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
    start, stop = arg[0]
    rest = tuple(self.shape[1:])
    row_elems = prod(rest)
    row_bytes = row_elems * self.dtype.itemsize
    # size is an element count, so it must include the trailing dimensions.
    # The shape keeps them too, so that a later shrink of this view still
    # passes the check above and computes the right row stride.
    return RawDiskBuffer((stop-start)*row_elems, self.dtype, buf=self._buffer()[start*row_bytes:stop*row_bytes], shape=(stop-start,)+rest)

  def _buffer(self): return self._buf
|
||||
|
||||
# Ops available on the disk backend: a pass-through NOOP plus the two movement
# ops that RawDiskBuffer implements as methods.
disk_fxn_for_op: Dict[Op, Callable] = {
  UnaryOps.NOOP: lambda x: x,
  MovementOps.RESHAPE: RawDiskBuffer.reshape,
  MovementOps.SHRINK: RawDiskBuffer.shrink,
}

# The op functions above take and return RawDiskBuffer objects directly, so
# both conversion hooks are identity functions.
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, to_underlying=lambda x:x, from_underlying=lambda x:x)
|
||||
@@ -31,7 +31,7 @@ CL = _CL()
|
||||
|
||||
# TODO: merge CLImage in here
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype, device=0):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
if isinstance(dtype, ImageDType):
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
@@ -39,7 +39,7 @@ class CLBuffer(RawBufferCopyInOut):
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
setattr(buf, 'device', device) # device is tracked on the underlying buffer
|
||||
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
|
||||
super().__init__(size, dtype, buf)
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import torch
|
||||
from typing import Dict, Callable
|
||||
from typing import Dict, Callable, Optional
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, Op, Interpreted
|
||||
from tinygrad.helpers import getenv, dtypes, prod
|
||||
from tinygrad.helpers import getenv, dtypes, prod, DType
|
||||
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int32: dtypes.int32, torch.int64: dtypes.int64}
|
||||
inverse_type_map = {v:k for k,v in type_map.items()}
|
||||
|
||||
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
|
||||
@@ -18,8 +19,10 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
}}
|
||||
|
||||
class RawTorchBuffer(RawBuffer):
|
||||
def __init__(self, buf:torch.Tensor): super().__init__(prod(buf.shape), type_map[buf.dtype], buf)
|
||||
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None): super().__init__(size, dtype, buf if buf is not None else torch.empty([size], dtype=inverse_type_map[dtype]))
|
||||
@classmethod
|
||||
def fromCPU(cls, x): return cls(torch.from_numpy(x).requires_grad_(False).to(device))
|
||||
def fromCPU(cls, x):
|
||||
buf = torch.from_numpy(x).requires_grad_(False).to(device)
|
||||
return cls(prod(x.shape), type_map[buf.dtype], buf)
|
||||
def toCPU(self): return self._buf.cpu().numpy()
|
||||
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)
|
||||
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, from_underlying=lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
|
||||
|
||||
@@ -34,7 +34,7 @@ class Tensor:
|
||||
default_type: ClassVar[DType] = dtypes.float32
|
||||
|
||||
def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
device = device.upper().replace(":0", "") # canonicalize device
|
||||
device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
|
||||
if isinstance(data, list):
|
||||
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
|
||||
elif isinstance(data, LazyBuffer) and data.device != device:
|
||||
@@ -90,8 +90,13 @@ class Tensor:
|
||||
return self
|
||||
|
||||
def assign(self, x) -> Tensor:
|
||||
if not isinstance(x, Tensor): x = Tensor(x)
|
||||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||
# TODO: this is a hack for writing to DISK
|
||||
if self.device.startswith("DISK"):
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
|
||||
return self
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
|
||||
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
|
||||
@@ -126,7 +131,9 @@ class Tensor:
|
||||
def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def empty(*shape, **kwargs): return Tensor.zeros(*shape, **kwargs)
|
||||
def empty(*shape, device=Device.DEFAULT, dtype:Optional[DType]=None, **kwargs):
|
||||
# NOTE: we do the reshape to fix interpreted buffers
|
||||
return Tensor(LazyBuffer.empty([prod(shape)], Tensor.default_type if dtype is None else dtype, device), dtype=dtype, device=device, **kwargs).reshape(*shape)
|
||||
|
||||
@staticmethod
|
||||
def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
|
||||
@@ -457,7 +464,7 @@ class Tensor:
|
||||
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
|
||||
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
|
||||
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
|
||||
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse else self.mul(1/x)
|
||||
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
|
||||
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
|
||||
|
||||
Reference in New Issue
Block a user