padding
@@ -53,10 +53,10 @@ class CLProgram:
                vals:tuple[int, ...]=(), wait=False) -> float|None:
     for i,(b,_) in enumerate(bufs):
       if isinstance(dt:=self.buf_dtypes[i], ImageDType):
         # TODO: verify this is zero copy
-        b = checked(
+        try: b = checked(
           cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize]),
           cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b), None, status:=ctypes.c_int32()), status)
+        except RuntimeError as e: raise ValueError(f"{i=} {dt=}") from e
       check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
     for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
     if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
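For context, a minimal sketch of the check/checked helpers this hunk leans on (the bodies here are an assumption; only the names and the status out-parameter pattern are visible in the diff):

import ctypes

def check(status: int):
  # assumed contract: any non-zero OpenCL status code raises
  if status != 0: raise RuntimeError(f"OpenCL Error {status}")

def checked(ret, status: ctypes.c_int32):
  # validate the clCreateImage status out-parameter, then pass the handle through
  check(status.value)
  return ret

With the new try/except, a failing clCreateImage now surfaces as a ValueError naming the offending buffer index and ImageDType, instead of only the raw status code.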
@@ -3879,7 +3879,7 @@ class Tensor(OpMixin):
     return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)

   def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
-    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+    base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4)

     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
     x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
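Note dtsz is just the itemsize of the chosen image dtype, the same value the CLProgram hunk above keys its cl_image_format on ({2: CL_HALF_FLOAT, 4: CL_FLOAT}). A quick sanity check, assuming the tinygrad dtypes module:

from tinygrad import dtypes

# half-float image pixels are 2 bytes each, float32 pixels are 4
assert dtypes.imageh((1,1,4)).itemsize == 2
assert dtypes.imagef((1,1,4)).itemsize == 4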
@@ -3892,6 +3892,19 @@ class Tensor(OpMixin):
     w = w.pad_to(None, None, cin, None, None)
     x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix)

+    # hack for pitch alignment
+    added_width = 0
+    if (ix*groups*cin) % (64 // dtsz):
+      added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
+      ix = ix + added_width
+      x = x.pad_to(None, None, None, ix)
+
+    added_weight = 0
+    if (H*W*cin) % (64 // dtsz):
+      added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H
+      H = H + added_weight
+      w = w.pad_to(None, None, None, H, None)
+
     # hack for non multiples of 4 on rcout
     added_output_channels = 0
     if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
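The pitch-alignment arithmetic can be checked in isolation. A minimal sketch, assuming tinygrad's round_up(num, amt) rounds num up to a multiple of amt; the concrete numbers (float32, ix=100, groups*cin=3) are illustrative only:

import math

def round_up(num: int, amt: int) -> int:
  # same contract as tinygrad's helper: smallest multiple of amt that is >= num
  return (num + amt - 1) // amt * amt

dtsz, ix, groups, cin = 4, 100, 3, 1  # float32 image: 64 // dtsz = 16 elements per 64-byte pitch unit
assert (ix * groups * cin) % (64 // dtsz) == 12  # 300 % 16 != 0: row pitch misaligned
# ix only needs to reach a multiple of (64//dtsz) // gcd(groups*cin, 64//dtsz) = 16 here
added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
assert added_width == 12
assert ((ix + added_width) * groups * cin) % (64 // dtsz) == 0  # 112 * 3 = 336 = 21 * 16

The gcd keeps the padding minimal: any factor that groups*cin already shares with 64//dtsz does not have to come from ix.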
@@ -3911,10 +3924,12 @@ class Tensor(OpMixin):
     if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
     x, w = x.contiguous(), w.contiguous()

+    if added_weight: w, H = w[:, :-added_weight, :, :, :, :], H - added_weight
+
     # expand out
     rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1)
     group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1)
-    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo)
     if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
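The reshape with -1 is what lets the padded width flow through otherwise unchanged code: after the image cast, the width axis is inferred rather than assumed to be the original ix, while the added_weight rows are sliced off again right after the contiguous image copy, restoring the original kernel height for the expand logic. A toy numpy sketch of the shape bookkeeping (numbers illustrative):

import numpy as np

bs, iy, ix, groups, rcin_hi, rcin_lo = 1, 4, 112, 3, 1, 1  # ix already padded to 112
x = np.zeros((bs * iy, ix * groups * rcin_hi * rcin_lo // 4, 4))  # image-shaped (rows, pixels, 4)
x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo)
assert x.shape[2] == ix  # -1 recovers the (possibly padded) width without hardcoding it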
@@ -3933,6 +3948,11 @@ class Tensor(OpMixin):
       ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
       cout = groups * (rcout - added_output_channels)

+    # undo pitch alignment hack
+    if added_width:
+      ret = ret.reshape(bs, oy, ox, groups, cout)[:, :, :-added_width, :, :]
+      ox = ox - added_width
+
     # NCHW output
     ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
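The undo step for the width padding is likewise just a crop on the output's width axis, so the hack stays invisible to callers; a toy numpy sketch of the effect (shapes illustrative):

import numpy as np

bs, oy, ox, cout, added_width = 1, 8, 112, 12, 12
ret = np.zeros((bs, oy, ox, cout))
ret = ret[:, :, :-added_width, :]  # drop the columns that only existed for pitch alignment
assert ret.shape == (bs, oy, ox - added_width, cout)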