George Hotz
2022-08-06 15:21:26 +00:00
parent 27f209b80b
commit f300caa486
2 changed files with 25 additions and 3 deletions


@@ -4,22 +4,44 @@ Generic folded reduce may not work.
GPUs:
AMD:
RDNA2: https://developer.amd.com/wp-content/resources/RDNA2_Shader_ISA_November2020.pdf
-We have RX6900XT with 80 CU, 40 WGP, and 20 "processors"
-@ 1.825 GHz, there's 18,688 FP32 GFLOPS of compute. 10240 FLOPS/cycle, 128 per CU
+We have RX6900XT with 80 CU, 40 WGP, and 1 "processor"
+@ 1.825 GHz, there's 18,688 FP32 GFLOPS of compute. 10240 FLOPS/cycle, 128 per CU (32 FMAs per vALU, 2 vALUs per CU)
+286 GFLOP for ENET=2 BS=64. At theoretical max, (286/18688)*1000 = 15.3 ms (checked in the sketch after this diff).
+We observe about a 10x factor off with PyTorch.
+We will focus on speed for AMD, since we have complete docs for that GPU.
+Each "processor" has an "ultra threaded dispatch processor".
+M1:
+On M1 GPU, theoretical is 2.275 TFLOPS. https://www.notebookcheck.net/Apple-M1-GPU-Benchmarks-and-Specs.503610.0.html
+We observe 2000 ms for BS=8 (37 GFLOP). 37/2275 = 16.3 ms, so tinygrad is over a factor of 100x off (similar on the AMD GPU).
+NOTE: the timer in the M1 OpenCL doesn't seem to be anywhere close to wall time.
+Adreno:
+TBD, no comma three here. Image > Buffer because the L1 cache is used. Would UBWC help on weights?
+We have a good bit of work on this in hyperthneed. Let's get the disassembler out and make this fast.
TPUs:
These use really big systolic arrays and have a lot less flexibility.
IIRC, their vector math unit is similar to a GPU's.
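
A quick check of the RDNA2 peak-FLOPS arithmetic above, as a standalone sketch (the constants are the figures quoted in the note; the variable names are ours):

# RX6900XT peak FP32: 80 CUs * 2 SIMD32 vALUs * 32 FMAs * 2 FLOPs/FMA = 10240 FLOPS/cycle
cus, valus_per_cu, fmas_per_valu, flops_per_fma = 80, 2, 32, 2
flops_per_cycle = cus * valus_per_cu * fmas_per_valu * flops_per_fma  # 10240
peak_gflops = flops_per_cycle * 1.825                                 # 18688.0 at 1.825 GHz
print(286 / peak_gflops * 1000)  # ~15.3 ms best case for the 286 GFLOP ENET=2 BS=64 workload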
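The same back-of-envelope for the M1 numbers, including the gap factor (again a sketch using only the figures from the note):

m1_peak_gflops = 2275                        # 2.275 TFLOPS theoretical
best_case_ms = 37 / m1_peak_gflops * 1000    # ~16.3 ms for the 37 GFLOP BS=8 workload
print(2000 / best_case_ms)                   # ~123x: the observed 2000 ms is over a factor of 100x off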


@@ -58,7 +58,7 @@ class CLProgram:
if DEBUG >= 1:
CL.time_sum += 0 if DEBUG <= 1 or CL.CACHE is not None else (e.profile.end - e.profile.start)
CL.ops_sum += op_estimate
print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} size {prod(args[0]):8d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:5.1f}M/{CL.ops_sum/1e9:7.2f}G " + \
print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:6.1f}M/{CL.ops_sum/1e9:7.2f}G " + \
("" if DEBUG <= 1 or CL.CACHE is not None else f"tm {(e.profile.end - e.profile.start)/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({op_estimate/(e.profile.end - e.profile.start):8.2f} GFLOPS)"))
if DEBUG >= 4: print(self.prg)
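
As an aside on the units in the print above: OpenCL event profiling timestamps (e.profile.start/end) are in nanoseconds, so ops divided by nanoseconds comes out directly in GFLOPS. A minimal sketch of the same conversions, with made-up numbers (start_ns, end_ns, and the op count here are hypothetical):

start_ns, end_ns = 1_000_000, 1_250_000      # hypothetical cl_event profiling values (ns)
op_estimate = 5_000_000                      # hypothetical estimated op count for the kernel
tm_us = (end_ns - start_ns) / 1e3            # ns -> us, the "tm" field
gflops = op_estimate / (end_ns - start_ns)   # ops/ns == 1e9 ops/s == GFLOPS
print(f"tm {tm_us:9.2f}us ({gflops:8.2f} GFLOPS)")   # tm    250.00us (   20.00 GFLOPS)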