mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix handcode_resnet50_opt.py (#2558)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from typing import List
|
||||
from extra.models.resnet import ResNet50
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps, Device, Compiled
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.search import time_linearizer, beam_search
|
||||
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
|
||||
from tinygrad.helpers import ansilen, DEBUG, getenv
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
@@ -33,9 +34,7 @@ if __name__ == "__main__":
|
||||
total_tm = 0
|
||||
running_gflops = 0
|
||||
for i,si in enumerate(sched):
|
||||
# create output/input buffers (NOTE: bufs_from_lin is slower, so we don't use it. TODO: fix)
|
||||
rawbufs = [device.buffer(si.out.st.size(), si.out.dtype)] + [device.buffer(x.st.size(), x.dtype) for x in si.inputs]
|
||||
#rawbufs = bufs_from_lin(lin)
|
||||
rawbufs = bufs_from_lin(Linearizer(si.ast))
|
||||
|
||||
# "linearize" the op into uops in different ways
|
||||
lins:List[Linearizer] = []
|
||||
|
||||
Reference in New Issue
Block a user