mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
This reverts commit 14018050c1.
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.dtype import _to_np_dtype
|
||||
from tinygrad import dtypes, Tensor
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
|
||||
@@ -14,15 +13,12 @@ K = getenv("K", N)
|
||||
CNT = getenv("CNT", 10)
|
||||
ATOL = getenv("ATOL", 1e-4)
|
||||
RTOL = getenv("RTOL", 3e-2)
|
||||
INT_LOW = getenv("INT_LOW", 0)
|
||||
INT_HIGH = getenv("INT_HIGH", 10)
|
||||
|
||||
if __name__ == "__main__":
|
||||
def init_matrix(rows, cols):
|
||||
rng = np.random.default_rng()
|
||||
if dtype_in in dtypes.ints:
|
||||
return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=_to_np_dtype(dtype_in))).realize()
|
||||
return Tensor(rng.random((rows, cols), dtype=np.float32).astype(_to_np_dtype(dtype_in))).realize()
|
||||
return Tensor.randint((rows, cols), dtype=dtype_in).realize()
|
||||
return Tensor.rand(rows, cols, dtype=dtype_in).realize()
|
||||
|
||||
a, b = init_matrix(M, K), init_matrix(K, N)
|
||||
for i in range(CNT):
|
||||
|
||||
Reference in New Issue
Block a user