mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -20,9 +20,9 @@ ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox
|
||||
def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, out_shape=None):
|
||||
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
|
||||
cout,cin,H,W = w_shape
|
||||
-sy,sx = (stride, stride) if isinstance(stride, int) else stride
|
||||
+sy,sx = make_pair(stride)
|
||||
px,px_,py,py_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
||||
-dy,dx = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
+dy,dx = make_pair(dilation)
|
||||
bs,cin_,iy,ix = x_shape
|
||||
|
||||
# this can change px_ and py_ to make the out_shape right
|
||||
@@ -50,4 +50,4 @@ def get_available_llops():
|
||||
_buffers[name] = [cls for cname, cls in inspect.getmembers(importlib.import_module('tinygrad.llops.'+op), inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
|
||||
except ImportError as e: # NOTE: this can't be put on one line due to mypy issue
|
||||
print(op, "not available", e)
|
||||
-return _buffers, DEFAULT
|
||||
+return _buffers, DEFAULT
|
||||
|
||||
Reference in New Issue
Block a user