fix: use proper dtype for bias during convolution

This commit is contained in:
Umut
2022-11-18 13:48:46 +01:00
parent 0ea77bde46
commit 58689d5806

View File

@@ -391,9 +391,19 @@ def _trace_or_eval(
if isinstance(x, Tracer):
return _trace_conv(x, weight, bias, pads, strides, dilations, group, conv_func)
bias = np.zeros(weight.shape[0]) if bias is None else bias
assert isinstance(x, np.ndarray)
assert isinstance(weight, np.ndarray)
dtype = (
np.float64
if np.issubdtype(x.dtype, np.floating) or np.issubdtype(weight.dtype, np.floating)
else np.int64
)
bias = np.zeros(weight.shape[0], dtype=dtype) if bias is None else bias
assert isinstance(bias, np.ndarray)
return _evaluate_conv(x, weight, bias, pads, strides, dilations, group, conv_func)