use shard api to eval resnet fast (#3136)

* use shard api to eval resnet fast

* `to` supports shard

* test `to` in multitensor
This commit is contained in:
George Hotz
2024-01-15 16:49:38 -08:00
committed by GitHub
parent ca0beeef38
commit cec0a7bc37
3 changed files with 29 additions and 21 deletions

View File

@@ -14,7 +14,8 @@ def eval_resnet():
# Resnet50-v1.5
from extra.models.resnet import ResNet50
tlog("imports")
Device.DEFAULT
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))]
for x in GPUS: Device[x]
tlog("got devices") # NOTE: this is faster with rocm-smi running
class ResnetRunner:
@@ -31,34 +32,32 @@ def eval_resnet():
x /= self.input_std
return self.mdl(x).argmax(axis=1).realize()
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 6))]
mdljit = [TinyJit(ResnetRunner(d)) for d in GPUS]
mdl = ResnetRunner(GPUS)
tlog("loaded models")
# evaluation on the mlperf classes of the validation set from imagenet
from examples.mlperf.dataloader import batch_load_resnet
iterator = batch_load_resnet(getenv("BS", 128), val=getenv("VAL", 1), shuffle=False)
def data_get(device):
iterator = batch_load_resnet(getenv("BS", 128*6), val=getenv("VAL", 1), shuffle=False)
def data_get():
x,y,cookie = next(iterator)
return x.to(device).realize(), y, cookie
return x.shard(GPUS, axis=0).realize(), y, cookie
n,d = 0,0
proc = [data_get(d) for d in GPUS]
proc = data_get()
tlog("loaded initial data")
st = time.perf_counter()
while proc is not None:
GlobalCounters.reset()
proc = [(m(x), y, c) for m,(x,y,c) in zip(mdljit, proc)] # this frees the images
proc = (mdl(proc[0]), proc[1], proc[2]) # this frees the images
run = time.perf_counter()
# load the next data here
try: next_proc = [data_get(d) for d in GPUS]
try: next_proc = data_get()
except StopIteration: next_proc = None
nd = time.perf_counter()
proc = [t.numpy() == y for t, y, _ in proc] # this realizes the models and frees the cookies
for match in proc:
n += match.sum()
d += len(match)
proc = proc[0].numpy() == proc[1] # this realizes the models and frees the cookies
n += proc.sum()
d += proc.size
et = time.perf_counter()
tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(match)*len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
tlog(f"****** {n:5d}/{d:5d} {n*100.0/d:.2f}% -- {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
st = et
proc, next_proc = next_proc, None
tlog("done")