save weights

This commit is contained in:
Comma Device
2022-07-19 19:14:14 -07:00
parent 46e7dfade1
commit b8a67905e5

View File

@@ -19,7 +19,7 @@ import numpy as np
import tinygrad.ops as ops
from tinygrad.llops.ops_gpu import CL
from tinygrad.llops.ops_gpu import CL, CLProgram, CLBuffer
from extra.utils import fetch
from extra.onnx import get_run_onnx
from test.test_onnx import run_onnx_torch
@@ -176,8 +176,19 @@ def compile(input, output_fn):
saved_objs = set()
saved_binaries = set()
gobj = 0
kernels_to_save = set()
import pyopencl as cl
for self, args in local_cl_cache:
for i,a in enumerate(args[2:]):
access_qualifer = self.clprg.get_arg_info(i, cl.kernel_arg_info.ACCESS_QUALIFIER)
type_qualifer = self.clprg.get_arg_info(i, cl.kernel_arg_info.TYPE_QUALIFIER)
type_name = self.clprg.get_arg_info(i, cl.kernel_arg_info.TYPE_NAME)
if cl.kernel_arg_access_qualifier.READ_ONLY == access_qualifer or cl.kernel_arg_type_qualifier.CONST == type_qualifer:
kernels_to_save.add(a)
else:
kernels_to_save.discard(a)
gobj = 0
for self, args in local_cl_cache:
#if self.name not in jdat['programs']:
# jdat['programs'][self.name] = {"src": self.prg, "options": ' '.join(self.options)}
@@ -208,17 +219,39 @@ def compile(input, output_fn):
ptr = struct.pack("Q", a.global_id).decode("latin_1")
if ptr not in saved_objs:
if isinstance(a, cl.Buffer):
needs_load = a in kernels_to_save
jdat['objects'].append({
"id": ptr, "arg_type": "float*", "needs_load": False, "size": a.size,
"id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size,
})
if needs_load:
data = np.empty(a.size, dtype=np.uint8)
CL.enqueue_copy(data, a, is_blocking=True)
weights.append(data)
elif isinstance(a, cl.Image):
# multiple of 32 isn't enough
needs_load = a in kernels_to_save
row_pitch = (a.shape[0]*4*2 + 63)//64 * 64
size = row_pitch * a.shape[1]
buf = CLBuffer(size)
CLProgram("from_image_strided", """
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
l.y = get_global_id(1);
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
""", argdtypes=(None, None, np.int32))(a.shape, None, a, buf.cl, row_pitch)
# multiple of 32 isn't enough
jdat['objects'].append({
"id": ptr, "needs_load": False, "size": size, "arg_type": "image2d_t",
"id": ptr, "needs_load": True, "size": size, "arg_type": "image2d_t",
"width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch,
})
if needs_load:
data = np.empty(size, dtype=np.uint8)
CL.enqueue_copy(data, buf.cl, is_blocking=True)
weights.append(data)
else:
raise Exception("unknown object", a)
#print(jdat['objects'][-1])