mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
test/fuzz_linearizer: add --atol/rtol and change half distribution (#4352)
This commit is contained in:
10
test/external/fuzz_linearizer.py
vendored
10
test/external/fuzz_linearizer.py
vendored
@@ -33,6 +33,8 @@ def get_fuzz_rawbufs(lin):
|
||||
data = np.random.randint(-100, 100, size=rawbuf.size, dtype=rawbuf.dtype.np)
|
||||
elif rawbuf.dtype == dtypes.bool:
|
||||
data = np.random.choice([True, False], size=rawbuf.size)
|
||||
elif rawbuf.dtype == dtypes.half:
|
||||
data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=rawbuf.dtype.np)
|
||||
else:
|
||||
data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=rawbuf.dtype.np)
|
||||
rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
|
||||
@@ -111,7 +113,7 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
|
||||
|
||||
return ("PASS", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
def fuzz_linearizer(lin: Linearizer):
|
||||
def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2):
|
||||
SEED = getenv("SEED", 42)
|
||||
random.seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
@@ -153,7 +155,7 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
|
||||
if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape())
|
||||
|
||||
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth)
|
||||
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)
|
||||
if msg != "PASS":
|
||||
print(test_lin.ast)
|
||||
print(test_lin.applied_opts)
|
||||
@@ -178,6 +180,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
|
||||
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
|
||||
parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
|
||||
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
|
||||
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.ast is not None:
|
||||
@@ -206,7 +210,7 @@ if __name__ == "__main__":
|
||||
tested += 1
|
||||
lin = ast_str_to_lin(ast)
|
||||
|
||||
fuzz_failures = fuzz_linearizer(lin)
|
||||
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
|
||||
if fuzz_failures: failed_ids.append(i)
|
||||
for k, v in fuzz_failures.items():
|
||||
for f in v:
|
||||
|
||||
Reference in New Issue
Block a user