commit bf7fb2309a
parent 530eb6e682
Author: Christopher Milan
Date: 2025-12-17 23:13:28 +00:00

2 changed files with 24 additions and 4 deletions


@@ -53,10 +53,10 @@ class CLProgram:
                vals:tuple[int, ...]=(), wait=False) -> float|None:
     for i,(b,_) in enumerate(bufs):
       if isinstance(dt:=self.buf_dtypes[i], ImageDType):
         # TODO: verify this is zero copy
-        b = checked(
+        try: b = checked(
           cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize]),
           cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b), None, status:=ctypes.c_int32()), status)
+        except RuntimeError as e: raise ValueError(f"{i=} {dt=}") from e
       check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
     for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
     if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
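
The only behavioral change in this hunk is error context: the clCreateImage call is wrapped in try/except so a failure reports which kernel-argument index and which ImageDType was being bound, instead of a bare status code. A minimal sketch of the same pattern with stand-in helpers (checked_status and make_image below are illustrative, not the runtime's real API, whose checked() takes the ctypes status object):

def checked_status(ret, status):
  # stand-in for a checked() helper: raise if the OpenCL-style status is nonzero
  if status != 0: raise RuntimeError(f"image creation failed with status {status}")
  return ret

def bind_image_arg(i, dt, make_image):
  # make_image is any callable returning (handle, status); on failure re-raise
  # with the argument index and dtype so the offending buffer is identifiable
  try: return checked_status(*make_image(dt))
  except RuntimeError as e: raise ValueError(f"{i=} {dt=}") from e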


@@ -3879,7 +3879,7 @@ class Tensor(OpMixin):
       return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
   def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
-    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+    base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4)
     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
     x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
@@ -3892,6 +3892,19 @@ class Tensor(OpMixin):
       w = w.pad_to(None, None, cin, None, None)
       x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix)
+
+    # hack for pitch alignment
+    added_width = 0
+    if (ix*groups*cin) % (64 // dtsz):
+      added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
+      ix = ix + added_width
+      x = x.pad_to(None, None, None, ix)
+
+    added_weight = 0
+    if (H*W*cin) % (64 // dtsz):
+      added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H
+      H = H + added_weight
+      w = w.pad_to(None, None, None, H, None)
 
     # hack for non multiples of 4 on rcout
     added_output_channels = 0
     if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
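
The new pitch-alignment hack pads the activation width ix (and the weight's H) until one row of the backing image is a whole number of 64-byte units, and the gcd keeps that padding as small as possible. A standalone sketch of the same arithmetic, with a locally defined round_up (the 64-byte alignment figure is taken from the diff, not derived here):

import math

def round_up(x, a): return -(-x // a) * a  # smallest multiple of a that is >= x

def pitch_pad(ix, groups, cin, dtsz, align=64):
  # one image row holds ix*groups*cin values of dtsz bytes each; pad ix until
  # that byte count is a multiple of align. align // (dtsz * gcd(...)) is the
  # smallest granularity ix must be rounded to for the row to come out aligned.
  if (ix * groups * cin) % (align // dtsz) == 0: return 0
  return round_up(ix, align // (dtsz * math.gcd(groups * cin, align // dtsz))) - ix

# float32 (dtsz=4) with groups*cin=12 and ix=10: a row is 10*12*4 = 480 bytes,
# not a multiple of 64; padding ix to 12 gives 12*12*4 = 576 = 9*64 bytes.
assert pitch_pad(10, 1, 12, 4) == 2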
@@ -3911,10 +3924,12 @@ class Tensor(OpMixin):
     if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
     x, w = x.contiguous(), w.contiguous()
+    if added_weight: w, H = w[:, :-added_weight, :, :, :, :], H - added_weight
 
     # expand out
     rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1)
     group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1)
-    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo)
     if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
@@ -3933,6 +3948,11 @@ class Tensor(OpMixin):
       ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
       cout = groups * (rcout - added_output_channels)
 
+    # undo pitch alignment hack
+    if added_width:
+      ret = ret.reshape(bs, oy, ox, groups, cout)[:, :, :-added_width, :, :]
+      ox = ox - added_width
+
     # NCHW output
     ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
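
The last hunk mirrors the padding on the way out: the columns that exist only because of the pitch hack are sliced off the result before the NCHW permute. A rough check of the width bookkeeping for the simple stride-1, no-padding case (illustrative numbers, not taken from the diff):

ix, added_width, W = 10, 2, 3     # original width, pitch padding, kernel width
ix_padded = ix + added_width      # x = x.pad_to(None, None, None, ix)
ox_padded = ix_padded - (W - 1)   # valid stride-1 conv output width at the padded size
ox = ox_padded - added_width      # ret[:, :, :-added_width, :, :]
assert ox == ix - (W - 1)         # same output width as an unpadded input would give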