mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
cleanup long lines [pr] (#12623)
* cleanup long lines * more * a few more * all noqa fixed * fix amd + cuda * clean that up
This commit is contained in:
@@ -165,15 +165,22 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
|
||||
if isinstance(e, RuntimeError): continue
|
||||
raise
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
if BEAM_DEBUG > 1:
|
||||
print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops",
|
||||
f"{time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run",
|
||||
f" {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}")
|
||||
elif DEBUG >= 2:
|
||||
print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)}",
|
||||
f" {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
|
||||
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
|
||||
if not exiting: beam = opts[:amt]
|
||||
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
|
||||
if DEBUG >= 2:
|
||||
print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None),
|
||||
f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
|
||||
except KeyboardInterrupt as e:
|
||||
if beam_pool is not None: beam_pool.terminate()
|
||||
raise e
|
||||
|
||||
@@ -23,10 +23,13 @@ def argfix(*x):
|
||||
if len(x) != 1: raise ValueError(f"bad arg {x}")
|
||||
return tuple(x[0])
|
||||
return x
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))
|
||||
def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items)
|
||||
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
||||
def colored(st, color:str|None, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
|
||||
def colored(st, color:str|None, background=False): # replace the termcolor library
|
||||
colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
|
||||
return f"\u001b[{10*background+60*(color.upper() == color)+30+colors.index(color.lower())}m{st}\u001b[0m" if color is not None else st
|
||||
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
|
||||
def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
|
||||
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
|
||||
@@ -218,11 +221,12 @@ class TracingKey:
|
||||
class ProfileEvent: pass
|
||||
|
||||
@dataclass
|
||||
class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
|
||||
class ProfileRangeEvent(ProfileEvent):
|
||||
device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfilePointEvent(ProfileEvent): device:str; name:str; key:Any; arg:dict=field(default_factory=dict); \
|
||||
ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
|
||||
class ProfilePointEvent(ProfileEvent):
|
||||
device:str; name:str; key:Any; arg:dict=field(default_factory=dict); ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
|
||||
|
||||
cpu_events:list[ProfileEvent] = []
|
||||
@contextlib.contextmanager
|
||||
@@ -281,7 +285,8 @@ def diskcache_put(table:str, key:dict|str|int, val:Any, prepickled=False):
|
||||
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
|
||||
cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
|
||||
_db_tables.add(table)
|
||||
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
|
||||
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)",
|
||||
tuple(key.values()) + (val if prepickled else pickle.dumps(val),))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return val
|
||||
|
||||
@@ -108,7 +108,9 @@ class CStyleLanguage(Renderer):
|
||||
extra_matcher = extra_pm
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
tmp = ""
|
||||
if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs):
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
|
||||
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
@@ -229,10 +231,12 @@ class ClangRenderer(CStyleLanguage):
|
||||
# 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
|
||||
# to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
|
||||
# wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
|
||||
prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
|
||||
out, dt1, dt2 = self.render_dtype(dtype_in.vec(N*N)), self.render_dtype(dtype_in.vec(N)), self.render_dtype(dtype_in.vec(M))
|
||||
prefix += [f"""static {out} __{name}({dt1} data1, {dt2} data2, {out} data0){{
|
||||
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
|
||||
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
|
||||
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
|
||||
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
|
||||
AMX_SET(1);\n return data0;\n}}"""]
|
||||
return prefix
|
||||
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
|
||||
|
||||
@@ -23,10 +23,12 @@ class CLCompiler(Compiler):
|
||||
build_status: int = cl.clBuildProgram(program, 1, self.dev.device_id, None, cl.clBuildProgram.argtypes[4](), None)
|
||||
if build_status != 0:
|
||||
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, log_size := ctypes.c_size_t())
|
||||
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) # noqa: E501
|
||||
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG,
|
||||
log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None)
|
||||
raise CompileError(f"OpenCL Compile Error\n\n{mstr.value.decode()}")
|
||||
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(ctypes.c_size_t), binary_sizes := (ctypes.c_size_t * 1)(), None))
|
||||
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), (ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None)) # noqa: E501
|
||||
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p),
|
||||
(ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None))
|
||||
check(cl.clReleaseProgram(program))
|
||||
return bytes(binary)
|
||||
|
||||
@@ -97,16 +99,22 @@ class CLDevice(Compiled):
|
||||
err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, num_devices := ctypes.c_uint32())
|
||||
if err == 0 and num_devices.value != 0: break
|
||||
if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) # noqa: E501
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(),
|
||||
lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))
|
||||
|
||||
self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
|
||||
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
|
||||
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
|
||||
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256,
|
||||
buf:=ctypes.create_string_buffer(256), None), buf.value.decode())[1]
|
||||
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256,
|
||||
buf:=ctypes.create_string_buffer(256), None), buf.value.decode())[1]
|
||||
if DEBUG >= 1: print(f"CLDevice: opening {self.device_name} with version {self.driver_version}")
|
||||
self.context = checked(cl.clCreateContext(None, 1, self.device_id, cl.clCreateContext.argtypes[3](), None, status := ctypes.c_int32()), status)
|
||||
self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, status), status)
|
||||
self.pending_copyin: list[memoryview] = []
|
||||
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501
|
||||
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096,
|
||||
ctypes.byref(buf := ctypes.create_string_buffer(4096)),
|
||||
ctypes.byref(total := ctypes.c_size_t())),
|
||||
ctypes.string_at(buf, size=total.value).decode())[1]
|
||||
|
||||
compilers = [(IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer,
|
||||
functools.partial(CLCompiler, self, f"compile_cl_{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}"))]
|
||||
|
||||
@@ -10,7 +10,9 @@ if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint:
|
||||
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported
|
||||
|
||||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
|
||||
if status != 0:
|
||||
error = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
|
||||
raise RuntimeError(f"CUDA Error {status}, {error}")
|
||||
|
||||
def encode_args(args, vals) -> tuple[ctypes.Structure, ctypes.Array]:
|
||||
c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +
|
||||
|
||||
@@ -60,7 +60,11 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
|
||||
check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
|
||||
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
|
||||
# -include hiprtc_runtime.h was removed
|
||||
check(set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes -Xclang -aux-triple -Xclang x86_64-unknown-linux-gnu".encode())) # noqa: E501
|
||||
options = [
|
||||
"-O3", "-mcumode", "--hip-version=6.0.32830", "-DHIP_VERSION_MAJOR=6", "-DHIP_VERSION_MINOR=0", "-DHIP_VERSION_PATCH=32830",
|
||||
"-D__HIPCC_RTC__", "-std=c++14", "-nogpuinc", "-Wno-gnu-line-marker", "-Wno-missing-prototypes", f"--offload-arch={arch}",
|
||||
"-I/opt/rocm/include", "-Xclang -disable-llvm-passes", "-Xclang -aux-triple", "-Xclang x86_64-unknown-linux-gnu"]
|
||||
check(set_options(action_info, ' '.join(options).encode()))
|
||||
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
|
||||
if status != 0:
|
||||
print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
|
||||
|
||||
@@ -22,10 +22,12 @@ def jitlink_check(status, ctx=None):
|
||||
|
||||
def pretty_ptx(s):
|
||||
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
|
||||
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
|
||||
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])',
|
||||
lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
|
||||
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
|
||||
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
|
||||
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
|
||||
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])',
|
||||
lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
|
||||
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
|
||||
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
|
||||
return s
|
||||
|
||||
@@ -2484,17 +2484,20 @@ class Tensor(MathTrait):
|
||||
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
padding_ = self._resolve_pool_pads(padding, len(HW))
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
|
||||
f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
|
||||
|
||||
# conv2d is a pooling op (with padding)
|
||||
x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
||||
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
||||
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
|
||||
# normal conv
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
|
||||
.permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
|
||||
.sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
@@ -2505,7 +2508,8 @@ class Tensor(MathTrait):
|
||||
# TODO: stride == dilation
|
||||
# use padding to round up to 4x4 output tiles
|
||||
# (bs, cin_, tyx, HWI)
|
||||
d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
|
||||
pads = [[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])]
|
||||
d = self.pad(sum(pads, []))._pool(HWI, HWO)
|
||||
# move HW to the front: # (HWI, bs, cin_, tyx)
|
||||
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
|
||||
tyx = d.shape[-len(HWI):] # dim of tiling
|
||||
|
||||
Reference in New Issue
Block a user