no decorators for image methods. move out RawMallocBuffer. -7 lines

This commit is contained in:
George Hotz
2023-03-12 16:28:45 -07:00
parent ed9ab6ff03
commit b512edc9ff
5 changed files with 90 additions and 94 deletions

View File

@@ -1,94 +1,88 @@
from tinygrad.helpers import IMAGE, prod
from tinygrad.helpers import prod, IMAGE
from tinygrad.lazy import get_single_root
def image_dot_decorator(normal_dot):
if IMAGE == 0: return normal_dot
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
return image_dot
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
def image_conv2d_decorator(normal_conv):
if IMAGE == 0: return normal_conv
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
rcout = cout//groups
x, w = self, weight.reshape(groups, rcout, cin, H, W)
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
rcout = cout//groups
x, w = self, weight.reshape(groups, rcout, cin, H, W)
# hack for non multiples of 4 on cin
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
added_input_channels = 4 - (cin % 4)
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
cin = cin + added_input_channels
x = x.reshape(bs, groups*cin, iy, ix)
# hack for non multiples of 4 on cin
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
added_input_channels = 4 - (cin % 4)
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
cin = cin + added_input_channels
x = x.reshape(bs, groups*cin, iy, ix)
# hack for non multiples of 4 on rcout
added_output_channels = 0
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
added_output_channels = 4 - (rcout % 4)
rcout += added_output_channels
cout = groups * rcout
w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))
# hack for non multiples of 4 on rcout
added_output_channels = 0
if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
added_output_channels = 4 - (rcout % 4)
rcout += added_output_channels
cout = groups * rcout
w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
cin_last = iy == 1 and ix == 1
if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
cin_last = iy == 1 and ix == 1
if cin == 1: w = w.reshape(cout//4,4,H*W).permute(0,2,1)
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
x, w = x.contiguous(), w.contiguous()
if get_single_root(w.lazydata).realized: w.realize()
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
x, w = x.contiguous(), w.contiguous()
if get_single_root(w.lazydata).realized: w.realize()
# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
# padding
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
# padding
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
# prepare input
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
oy, ox = x.shape[4:6]
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
# prepare input
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
oy, ox = x.shape[4:6]
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
# prepare weights
w = w.permute(0,4,2,5,1,3)
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
# prepare weights
w = w.permute(0,4,2,5,1,3)
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
# the conv!
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)
if IMAGE >= 3: ret = ret.contiguous()
# the conv!
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)
if IMAGE >= 3: ret = ret.contiguous()
# undo hack for non multiples of 4 on C.rcout
if added_output_channels != 0:
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
rcout -= added_output_channels
cout = groups * rcout
# undo hack for non multiples of 4 on C.rcout
if added_output_channels != 0:
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
rcout -= added_output_channels
cout = groups * rcout
# NCHW output
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
return image_conv2d
# NCHW output
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, itertools, operator, random
import functools, itertools, operator, random, ctypes
import numpy as np
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar, Set, Final
@@ -59,6 +59,13 @@ class RawBufferMapped(RawBufferCopyIn):
# Return a numpy view over the raw buffer without copying: np.frombuffer wraps the
# object returned by self._buffer() (buffer protocol) using the buffer's numpy dtype.
# NOTE(review): mutations of the returned array write through to the underlying storage.
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=self.dtype.np)
# Copy host data into the buffer in place: toCPU() is a view over the underlying
# storage, so np.copyto writes through to it. x is flattened to match the 1-D view.
def copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
# this one is simple enough that i moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped):
  """A RawBufferMapped whose backing store is a plain ctypes array on the host heap."""
  def __init__(self, size, dtype: DType):
    super().__init__(size, dtype)
    # ctypes has no half-precision float type, so float16 elements are stored as raw
    # 16-bit ints; unsupported dtypes raise KeyError. ctypes arrays are zero-initialized.
    element_type = {dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype]
    self._buf = (element_type * size)()
  def _buffer(self):
    # expose the ctypes array through the buffer protocol for zero-copy toCPU()
    return memoryview(self._buf)
class RawBufferCopyInOut(RawBufferCopyIn):
def copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

View File

@@ -1,14 +1,7 @@
import os, time, ctypes, hashlib, subprocess, platform
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import CompiledBuffer, RawBufferMapped, Specialized
from tinygrad.ops import CompiledBuffer, Specialized, RawMallocBuffer
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
class RawMallocBuffer(RawBufferMapped):
  """Host-memory buffer backed by a ctypes array (malloc'd by the Python runtime)."""
  # NOTE: float16 has no ctypes equivalent, so its storage is raw c_int16 slots;
  # any other dtype raises KeyError at construction time.
  def __init__(self, size, dtype: DType):
    super().__init__(size, dtype)
    ctype_for = {dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}
    self._buf = (ctype_for[dtype] * size)()
  def _buffer(self):
    """Return a memoryview of the backing array for zero-copy host access."""
    return memoryview(self._buf)
class ClangProgram:
def __init__(self, name:str, prg:str):
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n" + prg

View File

@@ -1,7 +1,6 @@
import time, hashlib, ctypes
from typing import ClassVar
from tinygrad.ops import CompiledBuffer, Specialized
from tinygrad.runtime.ops_clang import RawMallocBuffer
from tinygrad.ops import CompiledBuffer, Specialized, RawMallocBuffer
from tinygrad.helpers import getenv, DEBUG
from ctypes import CFUNCTYPE
from tinygrad.codegen.llvm import LLVMCodegen

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import math, functools, itertools
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten, DType, dtypes, LazyNumpyArray
from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, LazyNumpyArray
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.nn.image import image_conv2d_decorator, image_dot_decorator
# An instantiation of the Function is the Context
class Function:
@@ -324,7 +323,6 @@ class Tensor:
def avg_pool2d(self, kernel_size=(2,2), stride=None):
  """Average pooling over the trailing spatial dims.

  kernel_size: int or pair, normalized via make_pair. stride: defaults to
  kernel_size (non-overlapping windows) when None.
  """
  k_ = make_pair(kernel_size)  # hoisted: original evaluated make_pair twice
  # _pool appends one axis per kernel dim; reduce (mean) over those trailing axes
  return self._pool(k_, stride if stride is not None else kernel_size).mean(axis=tuple(range(-len(k_), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None):
  """Max pooling over the trailing spatial dims.

  kernel_size: int or pair, normalized via make_pair. stride: defaults to
  kernel_size (non-overlapping windows) when None.
  """
  k_ = make_pair(kernel_size)  # hoisted: original evaluated make_pair twice
  # _pool appends one axis per kernel dim; reduce (max) over those trailing axes
  return self._pool(k_, stride if stride is not None else kernel_size).max(axis=tuple(range(-len(k_), 0)))
@image_conv2d_decorator
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
(bs,cin_,_,_), (cout,cin,H,W) = self.shape, weight.shape
assert cin*groups == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({cin*groups} vs. {cin_})"
@@ -341,7 +339,6 @@ class Tensor:
ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
@image_dot_decorator
def dot(self, w:Tensor) -> Tensor:
x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
@@ -460,3 +457,9 @@ class Tensor:
# Install convenience conversion methods (e.g. .cpu(), .gpu(), .cpu_()) on Tensor,
# one pair per registered device backend, forwarding to Tensor.to / Tensor.to_.
for device in Device._buffers:
setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, device))
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
# NOTE(review): monkeypatching keeps the image kernels in tinygrad.nn.image while
# preserving the public Tensor.conv2d / Tensor.dot API unchanged for callers.
from tinygrad.nn.image import image_conv2d, image_dot
if IMAGE:
setattr(Tensor, "conv2d", image_conv2d)
setattr(Tensor, "dot", image_dot)