from typing import List
from models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.helpers import ansilen, DEBUG
from tinygrad.graph import print_tree

if __name__ == "__main__":
  mdl = ResNet50()
  seen = set()

  # the device we are optimizing for
  device: Compiled = Device[Device.DEFAULT]
  print(f"optimizing for {Device.DEFAULT}")

  # first model run to init the weights; everything scheduled here is recorded in seen
  mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)

  # run the model again: with the weight init already in seen, only the new schedule
  # items are returned, and these are the kernels of the model itself
  x = Tensor.empty(64, 3, 224, 224)
  out = mdl(x)
  sched = out.lazydata.schedule(seen)
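  # LoadOps schedule items (roughly: buffer allocation, fills, and copies) aren't
  # compute kernels, so there is nothing to linearize or time in them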
  sched = [x for x in sched if x.ast.op not in LoadOps]

  # work with the schedule
  total_tm = 0
  for i,si in enumerate(sched):
    if DEBUG >= 2: print_tree(si.ast)

    # uncomment the next line to focus on a single kernel
    #if i != 1: continue

    # "linearize" the op into uops in different ways
    lins:List[Linearizer] = []

    if Device.DEFAULT == "METAL" and i == 1:
      # through careful work, we discovered (big_chomp, lil_chomp, upcasted) = (1,8,0)
      for big_chomp in [1,2]: #[1,2,4,8,16]:
        for lil_chomp in [2,4,7,8,14]:
          for upcasted in [0,1,2]:
            lin = Linearizer(si.ast, device.linearizer_opts)
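            # split the hardcoded 4096 and 56 output axes into (outer, tile) pairs,
            # then permute the outer halves to the front, keeping the last two axes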
            lin.reshape_and_permute(lambda x: (4096//big_chomp,big_chomp,56//lil_chomp,lil_chomp,56//lil_chomp,lil_chomp)+x[-2:], [0,2,4,1,3,5,6,7])
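            # presumably: upcasted dims unroll into registers, local dims map onto
            # the GPU workgroup (hence += 3 for the three tile dims)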
            lin.upcasted += upcasted
            lin.local_dims += 3
            lins.append(lin)
    else:
      # try with and without tensor cores
      for tc in [0,1]:
        lin = Linearizer(si.ast, device.linearizer_opts)
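        # hand_coded_optimizations applies tinygrad's built-in heuristics
        # (roughly: upcasts, local dims, and optionally tensor cores)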
        lin.hand_coded_optimizations(use_tensor_cores=tc)
        lins.append(lin)

    # create output/input buffers
    rawbufs = [device.buffer(si.out.st.size(), si.out.dtype)] + [device.buffer(x.st.size(), x.dtype) for x in si.inputs]
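    # note: the buffers are never filled, so the outputs are garbage, but that
    # shouldn't matter for timing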

    # benchmark the programs
    choices = []
    for lin in lins:
      prg = device.to_program(lin)

      # benchmark it by running 10 times, keeping the fastest run
      try:
        tm = min([prg(rawbufs, force_wait=True) for _ in range(10)])
        choices.append((tm, lin))
      except AssertionError:
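        # presumably a configuration the backend rejects (e.g. workgroup too large);
        # score it as infinitely slow so it is never chosen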
        tm = float('inf')

      # print all kernels
      if DEBUG >= 1: print(f" kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {lin.info.flops*1e-9/tm:6.0f} GFLOPS")
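    # keep the fastest linearization for this kernel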
    tm, lin = sorted(choices, key=lambda x: x[0])[0]
    total_tm += tm
    print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {lin.info.flops*1e-9/tm:6.0f} GFLOPS")
  print(f"******* total {total_tm*1000:.2f} ms")