mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
use shard api to eval resnet fast (#3136)
* use shard api to eval resnet fast * to supports shard * test to in multitensor
This commit is contained in:
@@ -14,7 +14,8 @@ def eval_resnet():
|
||||
# Resnet50-v1.5
|
||||
from extra.models.resnet import ResNet50
|
||||
tlog("imports")
|
||||
Device.DEFAULT
|
||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))]
|
||||
for x in GPUS: Device[x]
|
||||
tlog("got devices") # NOTE: this is faster with rocm-smi running
|
||||
|
||||
class ResnetRunner:
|
||||
@@ -31,34 +32,32 @@ def eval_resnet():
|
||||
x /= self.input_std
|
||||
return self.mdl(x).argmax(axis=1).realize()
|
||||
|
||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))]
|
||||
mdljit = [TinyJit(ResnetRunner(d)) for d in GPUS]
|
||||
mdl = ResnetRunner(GPUS)
|
||||
tlog("loaded models")
|
||||
|
||||
# evaluation on the mlperf classes of the validation set from imagenet
|
||||
from examples.mlperf.dataloader import batch_load_resnet
|
||||
iterator = batch_load_resnet(getenv("BS", 128), val=getenv("VAL", 1), shuffle=False)
|
||||
def data_get(device):
|
||||
iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False)
|
||||
def data_get():
|
||||
x,y,cookie = next(iterator)
|
||||
return x.to(device).realize(), y, cookie
|
||||
return x.shard(GPUS, axis=0).realize(), y, cookie
|
||||
n,d = 0,0
|
||||
proc = [data_get(d) for d in GPUS]
|
||||
proc = data_get()
|
||||
tlog("loaded initial data")
|
||||
st = time.perf_counter()
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
proc = [(m(x), y, c) for m,(x,y,c) in zip(mdljit, proc)] # this frees the images
|
||||
proc = (mdl(proc[0]), proc[1], proc[2]) # this frees the images
|
||||
run = time.perf_counter()
|
||||
# load the next data here
|
||||
try: next_proc = [data_get(d) for d in GPUS]
|
||||
try: next_proc = data_get()
|
||||
except StopIteration: next_proc = None
|
||||
nd = time.perf_counter()
|
||||
proc = [t.numpy() == y for t, y, _ in proc] # this realizes the models and frees the cookies
|
||||
for match in proc:
|
||||
n += match.sum()
|
||||
d += len(match)
|
||||
proc = proc[0].numpy() == proc[1] # this realizes the models and frees the cookies
|
||||
n += proc.sum()
|
||||
d += proc.size
|
||||
et = time.perf_counter()
|
||||
tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(match)*len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
|
||||
tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
|
||||
st = et
|
||||
proc, next_proc = next_proc, None
|
||||
tlog("done")
|
||||
|
||||
Reference in New Issue
Block a user