Files
tinygrad/examples/handcode_resnet50_opt.py
2023-10-07 13:22:20 -07:00

76 lines
2.9 KiB
Python

from typing import List
from models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.helpers import ansilen, DEBUG
from tinygrad.graph import print_tree
if __name__ == "__main__":
  # Script entry point: schedule ResNet50, build several Linearizer candidates
  # for every non-load kernel in the schedule, time each candidate on the
  # default device and report the fastest one per kernel plus the summed time.
  mdl = ResNet50()
  seen = set()

  # the device we are optimizing for
  device: Compiled = Device[Device.DEFAULT]
  print(f"optimizing for {Device.DEFAULT}")

  # first model run to init the weights, they are saved in seen
  mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)

  # run model again to get only what changes, these are the kernels of the model
  x = Tensor.empty(64, 3, 224, 224)
  out = mdl(x)
  sched = out.lazydata.schedule(seen)
  # drop schedule items whose op is a LoadOps member
  sched = [item for item in sched if item.ast.op not in LoadOps]

  # work with the schedule
  total_tm = 0
  skipped: List[int] = []  # indices of kernels for which no candidate could be timed
  for i,si in enumerate(sched):
    if DEBUG >= 2: print_tree(si.ast)

    # enable only one kernel to focus on it
    #if i != 1: continue

    # "linearize" the op into uops in different ways
    lins:List[Linearizer] = []
    if Device.DEFAULT == "METAL" and i == 1:
      # through careful work, we discovered 1,8,0
      for big_chomp in [1,2]: #[1,2,4,8,16]:
        for lil_chomp in [2,4,7,8,14]:
          for upcasted in [0,1,2]:
            lin = Linearizer(si.ast, device.linearizer_opts)
            lin.reshape_and_permute(lambda x: (4096//big_chomp,big_chomp,56//lil_chomp,lil_chomp,56//lil_chomp,lil_chomp)+x[-2:], [0,2,4,1,3,5,6,7])
            lin.upcasted += upcasted
            lin.local_dims += 3
            lins.append(lin)
    else:
      # try with and without tensor cores
      for tc in [0,1]:
        lin = Linearizer(si.ast, device.linearizer_opts)
        lin.hand_coded_optimizations(use_tensor_cores=tc)
        lins.append(lin)

    # create output/input buffers
    rawbufs = [device.buffer(si.out.st.size(), si.out.dtype)] + [device.buffer(inp.st.size(), inp.dtype) for inp in si.inputs]

    # benchmark the programs
    choices = []
    for lin in lins:
      prg = device.to_program(lin)
      # benchmark it by running 10 times and keeping the best time
      try:
        tm = min(prg(rawbufs, force_wait=True) for _ in range(10))
        choices.append((tm, lin))
      except AssertionError:
        # this candidate could not be run; it is reported below with an infinite time
        tm = float('inf')

      # print all kernels
      if DEBUG >= 1: print(f" kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {lin.info.flops*1e-9/tm:6.0f} GFLOPS")

    # every candidate failed: there is nothing to pick, so report it and move on
    # instead of crashing on an empty list
    if not choices:
      print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} had no candidate that could be benchmarked, skipping")
      skipped.append(i)
      continue

    # fastest candidate wins (first one on ties, same as a stable sort)
    tm, lin = min(choices, key=lambda choice: choice[0])
    print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {lin.info.flops*1e-9/tm:6.0f} GFLOPS")
    total_tm += tm

  print(f"******* total {total_tm*1000:.2f} ms")
  # make it visible that the total above is incomplete when kernels were skipped
  if skipped: print(f"******* {len(skipped)} kernel(s) excluded from the total (no runnable candidate): {skipped}")