mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-06 12:44:58 -05:00
debug: add optional detailed BEAM_LOG logging (#3883)
* debug: add optional detailed BEAM_LOG logging show uop count, compile and run times for each candidate in search also add --timing to verify_kernel.py to make it easier to explore hand-crafted applied opts * fix linter
This commit is contained in:
8
test/external/verify_kernel.py
vendored
8
test/external/verify_kernel.py
vendored
@@ -4,6 +4,7 @@ from extra.optimization.helpers import kern_str_to_lin
|
||||
from test.external.fuzz_linearizer import compare_linearizer
|
||||
from tinygrad.helpers import colored
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.features.search import time_linearizer
|
||||
|
||||
# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
|
||||
|
||||
@@ -17,6 +18,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
|
||||
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
|
||||
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
|
||||
parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.kernel is not None:
|
||||
@@ -38,9 +40,13 @@ if __name__ == "__main__":
|
||||
test_lin = kern_str_to_lin(kern_str)
|
||||
for op in test_lin.ast: print_tree(op)
|
||||
print(test_lin.colored_shape())
|
||||
if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
|
||||
(msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)
|
||||
if msg != "PASS":
|
||||
failed_ids.append(i)
|
||||
failures[msg].append((test_lin.ast, test_lin.applied_opts))
|
||||
if args.timing:
|
||||
tm = time_linearizer(test_lin, rb, allow_test_size=False, cnt=10)
|
||||
print(f"final time {tm*1e6:9.0f} us")
|
||||
|
||||
for msg, errors in failures.items():
|
||||
for i, (ast, opts) in enumerate(errors):
|
||||
|
||||
Reference in New Issue
Block a user