fixup run thneed

This commit is contained in:
Comma Device
2022-08-18 08:22:53 -07:00
parent b132de677d
commit 1f23517d92

View File

@@ -46,7 +46,6 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
ptr = nptr
if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
assert not o['needs_load']
if o['arg_type'] == "image2d_t":
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
# hack: use a image1d since we can back that with a buffer
@@ -63,6 +62,10 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
buf = cl.Image(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR,
cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT),
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
elif o['needs_load']:
buf = cl.Image(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR,
cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT),
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
else:
buf = cl.Image(ctx, mf.READ_WRITE, image_fmt, shape=(o['width'], o['height']))
if o['arg_type'] == "image1d_t":
@@ -86,6 +89,12 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
ptr = nptr
inputs, vnum, vision, outputs = [], [], [], []
for k in jdat['inputs'] if 'inputs' in jdat else []:
inputs.append(bufs[k['buffer_id']])
for k in jdat['outputs'] if 'outputs' in jdat else []:
outputs.append(bufs[k['buffer_id']])
for i,k in enumerate(jdat['kernels']):
if k['name'] == 'zero_pad_image_float':
inputs.append(bufs[k['args'][1]])
@@ -121,8 +130,8 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
for a,b in zip(real_inputs, inp):
if debug:
print(a.size, b.size*b.itemsize)
assert a.size == (b.size * b.itemsize) or float32
cl.enqueue_copy(q, a, np.array(b, dtype=np.float32 if a != vision[0] else np.float16))
#assert a.size == (b.size * b.itemsize) or float32
cl.enqueue_copy(q, a, np.array(b, dtype=np.float16 if len(vision) > 0 and a == vision[0] else np.float32))
#jdat['kernels'] = jdat['kernels'][0:8]