mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00

.github/workflows/test.yml (vendored, 3 changes)
@@ -228,9 +228,6 @@ jobs:
     - if: ${{ matrix.task == 'onnx' }}
       name: Test MLPerf optimizers
       run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py --durations=20
-    - if: ${{ matrix.task == 'onnx' }}
-      name: Test THREEFRY
-      run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py

   #testwebgpu:
   #  name: WebGPU Tests
@@ -6,10 +6,6 @@ from tinygrad.shape.symbolic import Variable
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d

-def rand(*shape):
-  if CI: return Tensor(np.random.rand(*shape))
-  return Tensor.rand(*shape)
-
 class TestBeamSearch(unittest.TestCase):
   def setUp(self):
     self.old_beam = BEAM.value
@@ -18,44 +14,44 @@ class TestBeamSearch(unittest.TestCase):
     BEAM.value = self.old_beam

   def test_variable_ast_beam(self):
-    a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
+    a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
     a = (a+1).realize()

-  def test_big_prime_number(self):
-    a = rand(367, 367)
-    b = rand(367, 367)
+  def test_big_prime_number_matmul(self):
+    a = Tensor.rand(367, 367)
+    b = Tensor.rand(367, 367)
     c = (a@b).realize()
     np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)

   def test_big_prime_number_max(self):
-    a = -rand(367, 367)
-    b = rand(367, 367)
+    a = -Tensor.rand(367, 367)
+    b = Tensor.rand(367, 367)
     # if incorrectly padded 0, the max would be 0 instead of a negative number
     c = (a*b).max(1)
     np.testing.assert_allclose(c.numpy(), (a.numpy() * b.numpy()).max(1), atol=1e-4, rtol=1e-4)

   def test_big_prime_number_sum(self):
-    a = rand(367, 367)
-    b = rand(367, 367)
+    a = Tensor.rand(367, 367)
+    b = Tensor.rand(367, 367)
     # if incorrectly padded 0, the sum would be inf
     c = (a/b).sum(1).realize()
     np.testing.assert_allclose(c.numpy(), (a.numpy() / b.numpy()).sum(1), atol=1e-4, rtol=1e-4)

   def test_variable_big_prime_number(self):
     v = Variable("v", 1, 400).bind(367)
-    a = rand(367, 367)
-    b = rand(367, 367)
+    a = Tensor.rand(367, 367)
+    b = Tensor.rand(367, 367)
     c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
     np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)

   def test_variable_shrink_prime_number(self):
     v = Variable("v", 1, 400).bind(367)
-    a = rand(400, 367)
+    a = Tensor.rand(400, 367)
     b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
     np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)

   def test_no_mutate_rawbuffers(self):
-    a = rand(3, 3).realize()
+    a = Tensor.rand(3, 3).realize()
     desired = a.numpy() + 1
     a.assign(a+1)
     actual = a.numpy()
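The padding comments in these tests state the invariant that matters for PADTO-style optimizations: padding must use the identity element of the reduce that follows (-inf for max, 0 for sum). A standalone NumPy sketch of the failure mode, with illustrative sizes (not tinygrad code):

  import numpy as np

  a = -np.random.rand(367, 367)   # every entry is negative, as in test_big_prime_number_max
  pad = 9                         # pad 367 up to 376

  # padding with 0 is wrong for max: 0 beats every negative entry
  bad  = np.pad(a, ((0, 0), (0, pad)), constant_values=0.0).max(1)
  good = np.pad(a, ((0, 0), (0, pad)), constant_values=-np.inf).max(1)
  assert np.allclose(good, a.max(1)) and not np.allclose(bad, a.max(1))

  # 0 is the correct identity for sum
  assert np.allclose(np.pad(a, ((0, 0), (0, pad)), constant_values=0.0).sum(1), a.sum(1))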
@@ -64,7 +60,7 @@ class TestBeamSearch(unittest.TestCase):
   @unittest.skipIf(CI, "flaky. CL_OUT_OF_RESOURCES")
   def test_conv_beam(self):
     c = Conv2d(3, 16, (3,3))
-    x = rand(1,3,32,32)
+    x = Tensor.rand(1,3,32,32)
     with Timing():
       c(x).realize()

@@ -18,8 +18,6 @@ def assert_jit_cache_len(fxn, expected_len):
     assert len(fxn.jit_cache) == expected_len
   else:
     assert len(fxn.jit_cache) == 1
     # until we have a better way of typing the prg in JitItem
     assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
@@ -10,28 +10,28 @@ def tensors_allocated():
 class TestGC(unittest.TestCase):

   def test_gc(self):
-    a = Tensor.rand(4, 4, requires_grad=True)
+    a = Tensor.zeros(4, 4, requires_grad=True)
     b = Tensor.zeros(4, 4, requires_grad=True)
     (a*b).mean().backward()
     assert(tensors_allocated() > 0)
     del a,b
-    assert(tensors_allocated() == 1) # one for Tensor._rng_counter
+    assert(tensors_allocated() == 0)

   def test_gc_complex(self):
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
-    b = Tensor.rand(4, 4, requires_grad=True)
-    assert(tensors_allocated() == 3)
+    b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
+    assert(tensors_allocated() == 2)
     (a*b).mean().backward()
-    assert(tensors_allocated() == 5)
+    assert(tensors_allocated() == 4)
     del b
-    assert(tensors_allocated() == 3)
+    assert(tensors_allocated() == 2)
     b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     print(tensors_allocated())
     (a*b).mean().backward()
     print(tensors_allocated())
-    assert(tensors_allocated() == 5)
+    assert(tensors_allocated() == 4)
     del b
-    assert(tensors_allocated() == 3)
+    assert(tensors_allocated() == 2)

 if __name__ == '__main__':
   unittest.main()
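For context, the tensors_allocated helper named in the hunk header counts live Tensor objects; a hedged sketch of that idea (the file's real helper may differ in detail):

  import gc
  from tinygrad.tensor import Tensor

  def tensors_allocated() -> int:
    # count every live Tensor object the garbage collector can see
    return sum(isinstance(obj, Tensor) for obj in gc.get_objects())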
@@ -182,7 +182,7 @@ class TestJit(unittest.TestCase):
     a = Tensor.randn(10, 10).realize()  # realize these before resetting the random seed
     b = Tensor.randn(10, 10).realize()

-    Tensor.manual_seed(1234)
+    Tensor._seed = 1234
     jf = TinyJit(f)
     res = set()
     for _ in range(5):
@@ -190,7 +190,7 @@ class TestJit(unittest.TestCase):
       res.add(o1.numpy()[0][0])
     assert len(res) == 5, "All values should be different, rand works in jit."

-    Tensor.manual_seed(1234)
+    Tensor._seed = 1234
     jf2 = TinyJit(f)
     res2 = set()
     for _ in range(5):
@@ -199,7 +199,7 @@ class TestJit(unittest.TestCase):
     assert len(res2) == 5, "All values should be different, rand works in jit."
     assert res == res2, "Jit rand is not reproducible with the same seed"

-    Tensor.manual_seed(3421)
+    Tensor._seed = 3421
     jf3 = TinyJit(f)
     res3 = set()
     for _ in range(5):

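The pattern these hunks test, as a hedged standalone sketch (f here is a stand-in; the test file's real f is not shown in this hunk):

  from tinygrad import Tensor, TinyJit

  @TinyJit
  def f(a: Tensor) -> Tensor:
    return (a + Tensor.rand(10, 10)).realize()

  Tensor._seed = 1234
  a = Tensor.randn(10, 10).realize()
  res = {f(a).numpy()[0][0] for _ in range(5)}
  assert len(res) == 5  # rand must regenerate on every jitted call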
@@ -199,7 +199,7 @@ class TestLinearizer(unittest.TestCase):
     np.testing.assert_allclose(np_c, r.numpy(), atol=tc_atol, rtol=tc_rtol)

   def test_limit_dims_to_max_5d_global(self):
-    t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
+    t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
     sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
     assert len(sched) == 1
     lin = Linearizer(*sched[0].ast)
@@ -740,7 +740,7 @@ class TestLinearizerOpts(unittest.TestCase):

   def test_padto_where(self):
     N = 17 * 17
-    a = (Tensor.empty(N, N).max(axis=0, keepdim=True) > 1).where(1, 0)
+    a = (Tensor.rand(N, N).max(axis=0, keepdim=True) > 1).where(1, 0)
     helper_linearizer_opt(a.max(0), [
       [Opt(OptOps.PADTO, 0, 32)],
       [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],

@@ -5,7 +5,6 @@ from functools import partial
 import numpy as np
 import torch
 from tinygrad import nn, dtypes, Tensor
-from tinygrad.helpers import THREEFRY
 from test.helpers import is_dtype_supported

 # https://gist.github.com/devries/11405101
@@ -42,7 +41,7 @@ def kstest(l1, l2):
   prob = ksprob((nesq + 0.12 + 0.11 / nesq) * d)
   return prob

-def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.04):
+def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.05):
   Tensor.manual_seed(1337)
   torch.manual_seed(1337)
   np.random.seed(1337)
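What equal_distribution checks, as a hedged sketch that substitutes scipy's two-sample KS test for the file's own kstest/ksprob helpers (which follow the gist linked above):

  import numpy as np
  import scipy.stats

  def equal_distribution_sketch(tiny_samples: np.ndarray, ref_samples: np.ndarray, alpha: float = 0.05) -> bool:
    # null hypothesis: both samples are drawn from the same distribution;
    # pass when we cannot reject it at significance level alpha
    return scipy.stats.ks_2samp(tiny_samples.ravel(), ref_samples.ravel()).pvalue >= alpha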
@@ -61,7 +60,6 @@ class TestRandomness(unittest.TestCase):
     self.assertFalse(normal_test(Tensor.rand))
     self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)))

-  @unittest.skipIf(THREEFRY.value, "broken with threefry")
   def test_rand_half(self):
     N = 128
     x = Tensor.rand((2, N, N), dtype=dtypes.half)
@@ -73,16 +71,6 @@ class TestRandomness(unittest.TestCase):
     self.assertTrue(zeros.size > 0)
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

-  @unittest.skipIf(not THREEFRY.value, "not using threefry")
-  def test_threefly_against_reference(self):
-    Tensor.manual_seed(1337)
-    # generated using
-    # (jax.extend.random.threefry_2x32((np.uint32(1337), np.uint32(0x0)), np.arange(20, dtype=np.uint32)) >> 8).astype(float) / np.float32(2**24)
-    jr = np.array([0.30984968, 0.42723763, 0.92448753, 0.27268296, 0.48820806, 0.29587173, 0.3213513, 0.05805135, 0.4954177, 0.23303074,
-                   0.62478125, 0.51861334, 0.24712527, 0.12718695, 0.5236074, 0.50704265, 0.9166272, 0.6918763, 0.6530086, 0.34640658])
-    r = Tensor.rand(20).numpy()
-    np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
-
   @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
   def test_rand_bfloat16(self):
     N = 128
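The removed test documents how its reference vector was produced; a runnable form of that comment (requires jax, and assumes jax.extend.random.threefry_2x32 exactly as named there):

  import numpy as np
  import jax

  key = (np.uint32(1337), np.uint32(0x0))
  bits = jax.extend.random.threefry_2x32(key, np.arange(20, dtype=np.uint32))
  jr = (np.asarray(bits) >> 8).astype(float) / np.float32(2**24)  # top 24 bits -> floats in [0, 1)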
@@ -127,10 +115,16 @@ class TestRandomness(unittest.TestCase):
                                        lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))

   def test_kaiming_uniform(self):
+    Tensor.manual_seed(1337)
+    torch.manual_seed(1337)
+    np.random.seed(1337)
     for shape in [(128, 64, 3, 3), (20, 24)]:
       self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))

   def test_kaiming_normal(self):
+    Tensor.manual_seed(1337)
+    torch.manual_seed(1337)
+    np.random.seed(1337)
     for shape in [(128, 64, 3, 3), (20, 24)]:
       self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))

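For the comparison these tests make, the kaiming-uniform reference distribution can be written out directly. A hedged NumPy sketch (standard leaky_relu gain with fan_in over the trailing dims; the a value here is illustrative, not taken from the diff):

  import math
  import numpy as np

  def kaiming_uniform_ref(shape, a: float = 0.01) -> np.ndarray:
    fan_in = math.prod(shape[1:])
    gain = math.sqrt(2.0 / (1 + a ** 2))    # leaky_relu gain
    bound = gain * math.sqrt(3.0 / fan_in)  # uniform(-bound, bound) has std gain/sqrt(fan_in)
    return np.random.uniform(-bound, bound, size=shape)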
@@ -181,7 +181,7 @@ class TestSchedule(unittest.TestCase):
     # run
     img = Tensor.rand(2,3,64,64)
     out = c1(img).elu()
-    check_schedule(out, 1, [c1.weight, c1.bias, img])
+    check_schedule(out, 1, [c1.weight, c1.bias])

   def test_two_sum(self):
     img = Tensor.empty(64,64)
@@ -336,7 +336,7 @@ class TestSchedule(unittest.TestCase):
     out = bn1(conv1(x)).relu()
     out = bn2(conv2(out))
     out = (out + x).relu()
-    check_schedule(out, 2, [conv1.weight, conv2.weight])
+    check_schedule(out, 4)

   def test_contiguous_while_contiguous(self):
     x = Tensor.empty(1, 64, 32, 32)

@@ -378,17 +378,17 @@ class TestMoveTensor(unittest.TestCase):

 class TestZeroShapeTensor(unittest.TestCase):
   def test_shape_stride(self):
-    t = Tensor.empty(3, 2, 0)
+    t = Tensor.rand(3, 2, 0)
     assert t.shape == (3, 2, 0)
     # numpy has stride 0, 0, 0; torch has stride 2, 1, 1
     assert t.lazydata.st.real_strides() == (0, 0, 1)

-    t = Tensor.empty(3, 0, 2)
+    t = Tensor.rand(3, 0, 2)
     assert t.shape == (3, 0, 2)
     # numpy has stride 0, 0, 0; torch has stride 2, 2, 1
     assert t.lazydata.st.real_strides() == (0, 2, 1)

-    t = Tensor.empty(0, 0, 0)
+    t = Tensor.rand(0, 0, 0)
     assert t.shape == (0, 0, 0)
     # numpy has stride 0, 0, 0; torch has stride 1, 1, 1
     assert t.lazydata.st.real_strides() == (0, 0, 1)

@@ -120,7 +120,7 @@ class TestSafetensors(unittest.TestCase):
     for dtype in dtypes.fields().values():
       if dtype in [dtypes.bfloat16]: continue # not supported in numpy
       path = temp(f"ones.{dtype}.safetensors")
-      ones = Tensor(np.random.rand(10,10), dtype=dtype)
+      ones = Tensor.rand((10,10), dtype=dtype)
       safe_save(get_state_dict(ones), path)
       np.testing.assert_equal(ones.numpy(), list(safe_load(path).values())[0].numpy())

@@ -41,7 +41,7 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz
   tms = []
   for _ in range(cnt):
     if clear_l2:
-      with Context(DEBUG=0, BEAM=0): Tensor.ones(1024,1024).contiguous().realize()
+      with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
     tms.append(cast(float, car(rawbufs, var_vals, wait=True, do_update_stats=False))*factor)
     if early_stop is not None and early_stop < tms[-1]: break
   return tms

@@ -94,8 +94,7 @@ class ContextVar:
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x

-DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
-WINO, THREEFRY = ContextVar("WINO", 0), ContextVar("THREEFRY", 0)
+DEBUG, IMAGE, WINO, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("WINO", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
 GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")

 class Timing(contextlib.ContextDecorator):

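Usage sketch for these ContextVars: they read environment variables at import time, compare like ints via the operators on the class above, and can be temporarily overridden with the Context manager seen in the _time_program hunk:

  from tinygrad.helpers import Context, DEBUG, BEAM

  if DEBUG >= 2: print("verbose kernel logging is on")  # compares DEBUG.value
  with Context(BEAM=4):
    pass  # code here sees BEAM.value == 4; the previous value is restored on exit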
@@ -7,12 +7,11 @@ from collections import defaultdict
 import numpy as np

 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar
-from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
-from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
+from tinygrad.helpers import argfix, make_pair, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
-from tinygrad.device import Buffer, Device
+from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import sint
 from tinygrad.realize import run_schedule, create_schedule

@@ -213,7 +212,7 @@ class Tensor:
   # ***** creation llop entrypoint *****

   @staticmethod
-  def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
+  def _loadop(op, shape, device:Optional[Union[Tuple[str], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
     if isinstance(device, tuple):
       return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \
                                      for d in device], None), device, dtype, **kwargs)
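One nuance in the annotation this hunk reverts: in Python typing, Tuple[str] and Tuple[str, ...] are not the same thing.

  from typing import Tuple

  one: Tuple[str] = ("GPU",)                  # exactly one element
  many: Tuple[str, ...] = ("GPU:0", "GPU:1")  # variable-length homogeneous tuple

The multi-device branch iterates over the tuple, which matches the variable-length form.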
@@ -223,34 +222,16 @@ class Tensor:
   def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)

   _seed: int = int(time.time())
-  _rng_counter: Optional[Tensor] = None
   @staticmethod
-  def manual_seed(seed=0): Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+  def manual_seed(seed=0): Tensor._seed = seed

   @staticmethod
-  def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
-    if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
-    if not THREEFRY.value:
-      if dtype == dtypes.bfloat16:
-        return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
-      return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
-
-    # threefry
-    if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
-    counts = (Tensor.arange(num, device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize().pad(((0,num%2),))
-    Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
-
-    rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
-    ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
-
-    x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
-    for i in range(5):
-      for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] * (2 ** r)) + (x[1].div(2 ** (32 - r), upcast=False)))
-      x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
-    out = x[0].cat(x[1])[:num].div(2 ** 8, upcast=False).cast(dtypes.float32).div(2 ** 24)
-    out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
-    out.requires_grad = kwargs.get("requires_grad")
-    return out.contiguous()
+  def rand(*shape, **kwargs):
+    if kwargs.get("dtype") == dtypes.bfloat16:
+      # TODO: remove this once we use threefry for rand.
+      kwargs.pop("dtype")
+      return Tensor.rand(*shape, **kwargs, dtype=dtypes.float).cast(dtypes.bfloat16)
+    return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs)

   # ***** creation helper functions *****

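For readers of the removed block: it implements the Threefry-2x32 block cipher on tensors, emulating a 32-bit rotate-left with a multiply plus a non-upcasting divide. A hedged NumPy transcription of the same rounds (my own sketch, not tinygrad code; rotation constants and key schedule copied from the diff):

  import numpy as np

  def threefry_2x32(x0: np.ndarray, x1: np.ndarray, seed: int) -> tuple:
    rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
    ks = [np.uint32(0), np.uint32(seed) ^ np.uint32(0x1BD11BDA), np.uint32(seed)]
    x = [x0 + ks[-1], x1 + ks[0]]  # key injection before round 0
    for i in range(5):
      for r in rotations[i % 2]:
        x[0] = x[0] + x[1]                                  # uint32 addition wraps mod 2**32
        x[1] = ((x[1] << r) | (x[1] >> (32 - r))) ^ x[0]    # rotate-left, then mix
      x = [x[0] + ks[i % 3], x[1] + ks[(i + 1) % 3] + np.uint32(i + 1)]
    return x[0], x[1]

  # counts are split into two contiguous halves, like counts.chunk(2) above
  counts = np.arange(20, dtype=np.uint32)
  h0, h1 = threefry_2x32(counts[:10], counts[10:], seed=1337)
  out = (np.concatenate([h0, h1]) >> 8).astype(np.float32) / np.float32(2**24)  # floats in [0, 1)

Dividing the top 24 bits by 2**24 matches both the float conversion in the removed code and the jax reference quoted in the deleted randomness test.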
@@ -267,7 +248,6 @@ class Tensor:
   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
-    assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
     return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
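The cumsum construction in arange is easy to sanity-check outside tinygrad; a NumPy verification of the formula, with illustrative values:

  import math
  import numpy as np

  start, stop, step = 3, 10, 2
  n = math.ceil((stop - start) / step)               # 4 elements
  vals = np.full(n, step).cumsum() + (start - step)  # cumsum starts at step, so shift to start: [3, 5, 7, 9]
  assert (vals == np.arange(start, stop, step)).all()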
@@ -886,10 +866,10 @@ class Tensor:
     if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self)
     if not isinstance(x, Tensor) and x == -1.0: return -self
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self
-  def div(self, x:Union[Tensor, Scalar], reverse=False, upcast=True) -> Tensor:
+  def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
     x = self._to_const_val(x)
-    if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x)
-    if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return mlops.Div.apply(*self._broadcasted(x, reverse))
+    if not isinstance(x, Tensor) and not reverse and x != 0: return self.mul(1/x)
+    if isinstance(x, Tensor) and dtypes.is_float(x.dtype): return mlops.Div.apply(*self._broadcasted(x, reverse))
     return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
   def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))

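With the upcast=False escape hatch gone (it existed only for threefry's bit-twiddling), div always upcasts integers to float. A quick NumPy analogy of the two behaviors:

  import numpy as np

  a = np.uint32(7)
  assert a / np.uint32(2) == 3.5   # true division upcasts to float (what div does now)
  assert a // np.uint32(2) == 3    # what upcast=False allowed: stay in integer math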
@@ -1042,7 +1022,7 @@ if IMAGE:
   setattr(Tensor, "conv2d", image_conv2d)
   setattr(Tensor, "dot", image_dot)

-# TODO: eventually remove this
+# TODO: remove the custom op and replace with threefry
 def custom_random(out:Buffer):
   Tensor._seed += 1
   if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
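The hunk cuts off before custom_random's body finishes; what follows in the file is, in essence, a numpy fill of the output buffer. A hedged sketch of that remainder (names beyond those shown above are assumptions):

  import numpy as np
  from tinygrad.device import Buffer
  from tinygrad.tensor import Tensor

  def custom_random_rest(out: Buffer) -> None:
    rng = np.random.default_rng(Tensor._seed)           # seed was already bumped above
    rand = rng.random(size=out.size, dtype=np.float32)  # the real code also handles other dtypes
    out.copyin(rand.data)                               # copy host bytes into the device buffer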