[TESTING] better cudagraph-based benchmarking (#2394)

This commit is contained in:
Philippe Tillet
2023-09-25 21:41:26 -07:00
committed by GitHub
parent 80adbbb87b
commit eea0718445
2 changed files with 62 additions and 63 deletions

View File

@@ -29,22 +29,22 @@ matmul_data = {
# NOTE:
'a100': {
# square
(512, 512, 512): {'float16': 0.061, 'float32': 0.097, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.283, 'float32': 0.313, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.618, 'float32': 0.532, 'int8': 0.34},
(8192, 8192, 8192): {'float16': 0.786, 'float32': 0.754, 'int8': 0.51},
(512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.355, 'float32': 0.313, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.653, 'float32': 0.532, 'int8': 0.34},
(8192, 8192, 8192): {'float16': 0.839, 'float32': 0.754, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.006, 'float32': 0.009, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.057, 'float32': 0.051, 'int8': 0.026},
(16, 8192, 8192): {'float16': 0.077, 'float32': 0.077, 'int8': 0.043},
(64, 1024, 1024): {'float16': 0.018, 'float32': 0.023, 'int8': 0.017},
(64, 4096, 4096): {'float16': 0.150, 'float32': 0.000, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.214, 'float32': 0.000, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.029, 'float32': 0.046, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.136, 'float32': 0.214, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177},
(16, 1024, 1024): {'float16': 0.015, 'float32': 0.009, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.080, 'float32': 0.051, 'int8': 0.026},
(16, 8192, 8192): {'float16': 0.083, 'float32': 0.077, 'int8': 0.043},
(64, 1024, 1024): {'float16': 0.045, 'float32': 0.023, 'int8': 0.017},
(64, 4096, 4096): {'float16': 0.170, 'float32': 0.000, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.227, 'float32': 0.000, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.040, 'float32': 0.046, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.160, 'float32': 0.214, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.272, 'float32': 0.000, 'int8': 0.177},
# test EVEN_K==False
(8192, 8192, 8176): {'float16': 0.786, 'float32': 0.743, 'int8': 0.51},
(8192, 8192, 8176): {'float16': 0.828, 'float32': 0.743, 'int8': 0.51},
}
}
@@ -100,15 +100,15 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
# Expected baselines for the elementwise benchmark, keyed by GPU name, then
# element count N, then dtype. NOTE(review): values look like fractions of
# peak throughput — confirm units against the benchmark harness.
# The original literal contained each key twice (residue of a stripped diff);
# since the later entry wins in a Python dict literal, the duplicates were
# removed, keeping the effective (post-commit) values. Runtime value unchanged.
elementwise_data = {
    'a100': {
        1024 * 16: {'float16': 0.031, 'float32': 0.060},
        1024 * 64: {'float16': 0.120, 'float32': 0.224},
        1024 * 256: {'float16': 0.394, 'float32': 0.691},
        1024 * 1024: {'float16': 1.06, 'float32': 1.453},
        1024 * 16384: {'float16': 0.832, 'float32': 0.862},
        1024 * 65536: {'float16': 0.873, 'float32': 0.882},
        # Non pow 2
        1020 * 100: {'float16': 0.173, 'float32': 0.327},
        10003 * 7007: {'float16': 0.522, 'float32': 0.873},
    }
}
@@ -143,30 +143,30 @@ def test_elementwise(N, dtype_str):
# Expected baselines for the flash-attention benchmark, keyed by GPU name and
# a (Z, H, N_CTX, D_HEAD, causal?, seq?, mode, dtype) tuple — the exact
# meaning of the two boolean flags comes from the test that consumes this
# table (not visible here); TODO confirm against test_flash_attention.
# The original literal repeated many keys (residue of a stripped diff); the
# later entry wins in a Python dict literal, so duplicates were removed while
# keeping the effective (post-commit) values. Runtime value unchanged.
flash_attention_data = {
    "a100": {
        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
        (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
        (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
        (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
        (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
        (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
        (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
        (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.263,
        (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255,
        (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144,
        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306,
        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
        (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
        (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
    }
}
@@ -238,8 +238,8 @@ def _sum(x_ptr, y_ptr, output_ptr, n_elements,
# Expected baselines for the reduction (_sum) benchmark, keyed by GPU name,
# then element count N, then dtype. NOTE(review): units presumed to match the
# other baseline tables in this file — confirm against the benchmark harness.
# The original literal contained both keys twice (residue of a stripped diff);
# the later entry wins in a Python dict literal, so the duplicates were
# removed, keeping the effective (post-commit) values. Runtime value unchanged.
reduction_data = {
    'a100': {
        1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.022, 'int32': 0.048},
        1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.022, 'int32': 0.049},
    }
}