Mirror of https://github.com/tinygrad/tinygrad.git,
synced 2026-01-24 22:38:16 -05:00.
use default_float.np to construct test data in test_ops (#3701)
first step of #2797
This commit is contained in:
@@ -61,10 +61,12 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
|
||||
(shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")
|
||||
|
||||
def prepare_test_op(low, high, shps, vals, forward_only=False):
  """Build matching torch and tinygrad test tensors.

  Args:
    low, high: bounds for np.random.uniform when generating data from shapes.
    shps: list of shapes to generate random data for, or None to use `vals`.
    vals: explicit values used when `shps` is None.
    forward_only: if True, tensors are created without gradient tracking.

  Returns:
    (ts, tst): a list of torch tensors and the corresponding list of
    tinygrad Tensors built from the same underlying numpy data.
  """
  # Seed both RNGs once so every call produces identical test data.
  torch.manual_seed(0)
  np.random.seed(0)
  if shps is None:
    # Explicit values provided — hand them straight to torch.
    ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
  else:
    # Generate numpy data in tinygrad's default float dtype so both
    # frameworks see bit-identical inputs (see PR #3701 / issue #2797).
    np_data = [np.random.uniform(low=low, high=high, size=size).astype(dtypes.default_float.np) for size in shps]
    ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
  # Mirror the torch tensors into tinygrad; FORWARD_ONLY is a module-level
  # env flag that globally disables gradients for the test run.
  tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
  return ts, tst
|
||||
|
||||
|
||||
Reference in New Issue
Block a user