needs_load in image correct

This commit is contained in:
Comma Device
2022-07-19 19:25:47 -07:00
parent 314d70ff17
commit aa00a3948e

View File

@@ -177,6 +177,7 @@ def compile(input, output_fn):
saved_binaries = set()
kernels_to_save = set()
kernels_to_not_save = set()
import pyopencl as cl
for self, args in local_cl_cache:
for i,a in enumerate(args[2:]):
@@ -186,7 +187,9 @@ def compile(input, output_fn):
if cl.kernel_arg_access_qualifier.READ_ONLY == access_qualifer or cl.kernel_arg_type_qualifier.CONST == type_qualifer:
kernels_to_save.add(a)
else:
kernels_to_save.discard(a)
# this is written to at some point, we don't have to save it
kernels_to_not_save.add(a)
kernels_to_save -= kernels_to_not_save
gobj = 0
for self, args in local_cl_cache:
@@ -226,7 +229,7 @@ def compile(input, output_fn):
if needs_load:
data = np.empty(a.size, dtype=np.uint8)
CL.enqueue_copy(data, a, is_blocking=True)
weights.append(data)
weights.append(data.tobytes())
elif isinstance(a, cl.Image):
needs_load = a in kernels_to_save
row_pitch = (a.shape[0]*4*2 + 63)//64 * 64
@@ -249,14 +252,14 @@ def compile(input, output_fn):
# multiple of 32 isn't enough
jdat['objects'].append({
"id": ptr, "needs_load": True, "size": size, "arg_type": "image2d_t",
"id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
"width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch,
})
if needs_load:
data = np.empty(size, dtype=np.uint8)
CL.enqueue_copy(data, buf.cl, is_blocking=True)
weights.append(data)
weights.append(data.tobytes())
else:
raise Exception("unknown object", a)
#print(jdat['objects'][-1])