mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fixup run thneed
This commit is contained in:
@@ -46,7 +46,6 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
|
||||
ptr = nptr
|
||||
|
||||
if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
|
||||
assert not o['needs_load']
|
||||
if o['arg_type'] == "image2d_t":
|
||||
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
|
||||
# hack: use an image1d since we can back that with a buffer
|
||||
@@ -63,6 +62,10 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
|
||||
buf = cl.Image(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR,
|
||||
cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT),
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
|
||||
elif o['needs_load']:
|
||||
buf = cl.Image(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR,
|
||||
cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT),
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
|
||||
else:
|
||||
buf = cl.Image(ctx, mf.READ_WRITE, image_fmt, shape=(o['width'], o['height']))
|
||||
if o['arg_type'] == "image1d_t":
|
||||
@@ -86,6 +89,12 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
|
||||
ptr = nptr
|
||||
|
||||
inputs, vnum, vision, outputs = [], [], [], []
|
||||
|
||||
for k in jdat['inputs'] if 'inputs' in jdat else []:
|
||||
inputs.append(bufs[k['buffer_id']])
|
||||
for k in jdat['outputs'] if 'outputs' in jdat else []:
|
||||
outputs.append(bufs[k['buffer_id']])
|
||||
|
||||
for i,k in enumerate(jdat['kernels']):
|
||||
if k['name'] == 'zero_pad_image_float':
|
||||
inputs.append(bufs[k['args'][1]])
|
||||
@@ -121,8 +130,8 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
|
||||
for a,b in zip(real_inputs, inp):
|
||||
if debug:
|
||||
print(a.size, b.size*b.itemsize)
|
||||
assert a.size == (b.size * b.itemsize) or float32
|
||||
cl.enqueue_copy(q, a, np.array(b, dtype=np.float32 if a != vision[0] else np.float16))
|
||||
#assert a.size == (b.size * b.itemsize) or float32
|
||||
cl.enqueue_copy(q, a, np.array(b, dtype=np.float16 if len(vision) > 0 and a == vision[0] else np.float32))
|
||||
|
||||
#jdat['kernels'] = jdat['kernels'][0:8]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user