diff --git a/Compiler/ml.py b/Compiler/ml.py index a1f1c4dc..30059d5b 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -885,7 +885,10 @@ class Dense(DenseBase): def compute_f_input(self, batch): N = len(batch) - prod = MultiArray([N, self.d, self.d_out], sfix) + if self.input_bias: + prod = MultiArray([N, self.d, self.d_out], sfix) + else: + prod = self.f_input # flattened_array version result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)