diff --git a/Compiler/ml.py b/Compiler/ml.py
index 9101a57c..5ac85e03 100644
--- a/Compiler/ml.py
+++ b/Compiler/ml.py
@@ -725,350 +725,6 @@ class DenseBase(Layer):
         N = len(batch)
         tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
-        A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address)
-        B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
-
-        @multithread(self.n_threads, self.d_in)
-        def _(base, size):
-            mp = B.direct_trans_mul(A, reduce=False,
-                                    indices=(regint.inc(size, base),
-                                             batch.get_vector(),
-                                             regint.inc(N),
-                                             regint.inc(self.d_out)))
-            tmp.assign_part_vector(mp, base)
-
-        progress('nabla W (matmul)')
-
-        @multithread(self.n_threads, self.d_in * self.d_out,
-                     max_size=get_program().budget)
-        def _(base, size):
-            self.nabla_W.assign_vector(
-                tmp.get_vector(base, size).reduce_after_mul(), base=base)
-
-        if self.print_random_update:
-            print_ln('backward %s', self)
-            i = regint.get_random(64) % self.d_in
-            j = regint.get_random(64) % self.d_out
-            print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s',
-                     str(self.nabla_W), i, j, tmp[i][j].v.reveal(),
-                     self.nabla_W[i][j].reveal(),
-                     A.get_column(j).reveal(),
-                     B.get_column_by_row_indices(
-                         batch.get_vector(), i).reveal())
-            print_ln('batch=%s B=%s', batch,
-                     [self.X[bi][0][i].reveal() for bi in batch])
-
-        progress('nabla W')
-
-        self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector()
-                                           for k in range(N))
-                                       for j in range(self.d)))
-
-        progress('nabla b')
-
-        if self.debug_output:
-            print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested())
-            print_ln('dense W %s', self.W.reveal_nested())
-            print_ln('dense nabla X %s', self.nabla_X.reveal_nested())
-        if self.debug:
-            limit = N * self.debug
-            @for_range_opt(self.d_in)
-            def _(i):
-                @for_range_opt(self.d_out)
-                def _(j):
-                    to_check = self.nabla_W[i][j].reveal()
-                    check = sum(to_check > limit) + sum(to_check < -limit)
-                    @if_(check)
-                    def _():
-                        print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check)
-                        print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
-                                          for k in range(N)])
-                        print_ln('X %s', [self.X[k][0][i].reveal()
-                                          for k in range(N)])
-            @for_range_opt(self.d_out)
-            def _(j):
-                to_check = self.nabla_b[j].reveal()
-                check = sum(to_check > limit) + sum(to_check < -limit)
-                @if_(check)
-                def _():
-                    print_ln('nabla b %s %s: %s', j, len(self.b), to_check)
-                    print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
-                                      for k in range(N)])
-            @for_range_opt(len(batch))
-            def _(i):
-                to_check = self.nabla_X[i].get_vector().reveal()
-                check = sum(to_check > limit) + sum(to_check < -limit)
-                @if_(check)
-                def _():
-                    print_ln('X %s %s', i, self.X[i].reveal_nested())
-                    print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
-
-class Dense(DenseBase):
-    """ Fixed-point dense (matrix multiplication) layer.
-
-    :param N: number of examples
-    :param d_in: input dimension
-    :param d_out: output dimension
-    """
-    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
-        if activation == 'id':
-            self.activation_layer = None
-        elif activation == 'relu':
-            self.activation_layer = Relu([N, d, d_out])
-        elif activation == 'square':
-            self.activation_layer = Square([N, d, d_out])
-        else:
-            raise CompilerError('activation not supported: %s', activation)
-
-        self.N = N
-        self.d_in = d_in
-        self.d_out = d_out
-        self.d = d
-        self.activation = activation
-
-        self.X = Tensor([N, d, d_in], sfix)
-        self.Y = Tensor([N, d, d_out], sfix)
-        self.W = Tensor([d_in, d_out], sfix)
-        self.b = sfix.Array(d_out)
-
-        back_N = min(N, self.back_batch_size)
-        self.nabla_Y = Tensor([back_N, d, d_out], sfix)
-        self.nabla_X = Tensor([back_N, d, d_in], sfix)
-        self.nabla_W = Tensor([d_in, d_out], sfix)
-        self.nabla_b = sfix.Array(d_out)
-
-        self.debug = debug
-
-        l = self.activation_layer
-        if l:
-            self.f_input = l.X
-            l.Y = self.Y
-            l.nabla_Y = self.nabla_Y
-        else:
-            self.f_input = self.Y
-
-    def __repr__(self):
-        return '%s(%s, %s, %s, activation=%s)' % \
-            (type(self).__name__, self.N, self.d_in,
-             self.d_out, repr(self.activation))
-
-    def reset(self):
-        d_in = self.d_in
-        d_out = self.d_out
-        r = math.sqrt(6.0 / (d_in + d_out))
-        print('Initializing dense weights in [%f,%f]' % (-r, r))
-        self.W.randomize(-r, r, n_threads=self.n_threads)
-        self.b.assign_all(0)
-
-    def input_from(self, player, **kwargs):
-        self.W.input_from(player, **kwargs)
-        if self.input_bias:
-            self.b.input_from(player, **kwargs)
-
-    def compute_f_input(self, batch):
-        N = len(batch)
-        assert self.d == 1
-        if self.input_bias:
-            prod = MultiArray([N, self.d, self.d_out], sfix)
-        else:
-            prod = self.f_input
-        max_size = get_program().budget
-        @multithread(self.n_threads, N, max_size)
-        def _(base, size):
-            X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
-            prod.assign_part_vector(
-                X_sub.direct_mul(self.W, indices=(
-                    batch.get_vector(base, size), regint.inc(self.d_in),
-                    regint.inc(self.d_in), regint.inc(self.d_out))), base)
-
-        if self.input_bias:
-            if self.d_out == 1:
-                @multithread(self.n_threads, N)
-                def _(base, size):
-                    v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
-                    self.f_input.assign_vector(v, base)
-            else:
-                @for_range_multithread(self.n_threads, 100, N)
-                def _(i):
-                    v = prod[i].get_vector() + self.b.get_vector()
-                    self.f_input[i].assign_vector(v)
-        progress('f input')
-
-    def _forward(self, batch=None):
-        if batch is None:
-            batch = regint.Array(self.N)
-            batch.assign(regint.inc(self.N))
-        self.compute_f_input(batch=batch)
-        if self.activation_layer:
-            self.activation_layer.forward(batch, training=True)
-        if self.debug_output:
-            print_ln('dense X %s', self.X.reveal_nested())
-            print_ln('dense W %s', self.W.reveal_nested())
-            print_ln('dense b %s', self.b.reveal_nested())
-            print_ln('dense Y %s', self.Y.reveal_nested())
-        if self.debug:
-            limit = self.debug
-            @for_range_opt(len(batch))
-            def _(i):
-                @for_range_opt(self.d_out)
-                def _(j):
-                    to_check = self.Y[i][0][j].reveal()
-                    check = to_check > limit
-                    @if_(check)
-                    def _():
-                        print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
-                        print_ln('X %s', self.X[i].reveal_nested())
-                        print_ln('W %s',
-                                 [self.W[k][j].reveal() for k in range(self.d_in)])
-
-    def backward(self, compute_nabla_X=True, batch=None):
-        N = len(batch)
-        d = self.d
-        d_out = self.d_out
-        X = self.X
-        Y = self.Y
-        W = self.W
-        b = self.b
-        nabla_X = self.nabla_X
-        nabla_Y = self.nabla_Y
-        nabla_W = self.nabla_W
-        nabla_b = self.nabla_b
-
-        if self.activation_layer:
-            self.activation_layer.backward(batch)
-            f_schur_Y = self.activation_layer.nabla_X
-        else:
-            f_schur_Y = nabla_Y
-
-        if compute_nabla_X:
-            nabla_X.alloc()
-            @multithread(self.n_threads, N)
-            def _(base, size):
-                B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
-                nabla_X.assign_part_vector(
-                    B.direct_mul_trans(W, indices=(regint.inc(size, base),
-                                                   regint.inc(self.d_out),
-                                                   regint.inc(self.d_out),
-                                                   regint.inc(self.d_in))),
-                    base)
-
-            if self.print_random_update:
-                print_ln('backward %s', self)
-                index = regint.get_random(64) % self.nabla_X.total_size()
-                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
-                         index, self.nabla_X.to_array()[index].reveal())
-
-            progress('nabla X')
-
-        self.backward_params(f_schur_Y, batch=batch)
-
-
-class FlexDense(Dense):
-    """ Fixed-point dense (matrix multiplication) layer with flexible number of dimensions.
-    Behaves like torch's nn.Linear which loops over the additional dimensions.
-
-    :param N: number of examples
-    :param d_in: input dimension
-    :param d_out: output dimension
-    """
-
-    def compute_f_input(self, batch):
-        N = len(batch)
-        prod = MultiArray([N, self.d, self.d_out], sfix)
-
-        # flattened_array version
-        result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
-        max_size = get_program().budget
-
-        # X is stored at full dataset indices, batch specifies which samples to use
-        X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
-
-        # Precompute batch_d_indices for all N*d elements
-        # For each sample in batch, expand to d consecutive indices
-        batch_d_indices = regint.Array(N * self.d)
-        @for_range(N)
-        def _(i):
-            actual_sample = batch[i]
-            batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
-            # @for_range(self.d)
-            # def _(d_idx):
-            #     batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx
-
-        @multithread(self.n_threads, N * self.d, max_size)
-        def _(base, size):
-            result_matrix.assign_part_vector(
-                X_sub.direct_mul(self.W, indices=(
-                    batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
-                    regint.inc(self.d_in), regint.inc(self.d_out))), base)
-
-        if self.input_bias:
-            if self.d_out == 1:
-                @multithread(self.n_threads, N)
-                def _(base, size):
-                    v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
-                    self.f_input.assign_vector(v, base)
-            else:
-                @for_range_multithread(self.n_threads, 100, [N, self.d])
-                def _(i, j):
-                    v = prod[i][j].get_vector() + self.b.get_vector()
-                    # print_ln("running bias %s %s %s", i, j, v.reveal())
-                    self.f_input[i][j].assign_vector(v)
-
-        # print_ln("FlexDense f_inpu full %s", self.f_input.reveal())
-
-        progress('f input')
-
-    def backward(self, compute_nabla_X=True, batch=None):
-        N = len(batch)
-        d = self.d
-        d_out = self.d_out
-        X = self.X
-        Y = self.Y
-        W = self.W
-        b = self.b
-        nabla_X = self.nabla_X
-        nabla_Y = self.nabla_Y
-        nabla_W = self.nabla_W
-        nabla_b = self.nabla_b
-
-        if self.activation_layer:
-            self.activation_layer.backward(batch)
-            f_schur_Y = self.activation_layer.nabla_X
-        else:
-            f_schur_Y = nabla_Y
-
-        if compute_nabla_X:
-            nabla_X.alloc()
-
-            # flattened matrix version
-            result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
-            # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
-            @multithread(self.n_threads, N * self.d)
-            def _(base, size):
-                X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
-                offset = regint.inc(size, base=base)
-
-                result_matrix.assign_part_vector(
-                    X_sub.direct_mul_trans(self.W, indices=(
-                        offset, regint.inc(self.d_out),
-                        regint.inc(self.d_out), regint.inc(self.d_in))), base)
-
-            if self.print_random_update:
-                print_ln('backward %s', self)
-                index = regint.get_random(64) % self.nabla_X.total_size()
-                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
-                         index, self.nabla_X.to_array()[index].reveal())
-
-            progress('nabla X')
-
-        self.backward_params(f_schur_Y, batch=batch)
-
-    def backward_params(self, f_schur_Y, batch):
-        print("backward params flexdense")
-        N = len(batch)
-        tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
-        # tmp.assign_all(0)
-        # A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1]
         A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
         # B (X) is stored at the full dataset size, not just the batch
@@ -1082,7 +738,7 @@ class FlexDense(Dense):
         @for_range(N)
         def _(i):
             # batch[i] gives the actual sample index in the full dataset
-            # For FlexDense with d>1, we need to map to flattened indices
+            # For Dense with d>1, we need to map to flattened indices
             actual_sample_idx = batch[i]
             @for_range(self.d)
             def _(d_idx):
@@ -1162,6 +818,182 @@ class FlexDense(Dense):
                     print_ln('X %s %s', i, self.X[i].reveal_nested())
                     print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
 
+class Dense(DenseBase):
+    """ Fixed-point dense (matrix multiplication) layer.
+    Supports inputs of size [N, d, d_in] to map to [N, d, d_out]. If d > 1, the layer
+    behaves like torch's nn.Linear, which loops over the additional dimension d.
+
+    :param N: number of examples
+    :param d_in: input dimension
+    :param d_out: output dimension
+    :param d: (optional) extra dimension
+    """
+    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
+        if activation == 'id':
+            self.activation_layer = None
+        elif activation == 'relu':
+            self.activation_layer = Relu([N, d, d_out])
+        elif activation == 'square':
+            self.activation_layer = Square([N, d, d_out])
+        else:
+            raise CompilerError('activation not supported: %s', activation)
+
+        self.N = N
+        self.d_in = d_in
+        self.d_out = d_out
+        self.d = d
+        self.activation = activation
+
+        self.X = Tensor([N, d, d_in], sfix)
+        self.Y = Tensor([N, d, d_out], sfix)
+        self.W = Tensor([d_in, d_out], sfix)
+        self.b = sfix.Array(d_out)
+
+        back_N = min(N, self.back_batch_size)
+        self.nabla_Y = Tensor([back_N, d, d_out], sfix)
+        self.nabla_X = Tensor([back_N, d, d_in], sfix)
+        self.nabla_W = Tensor([d_in, d_out], sfix)
+        self.nabla_b = sfix.Array(d_out)
+
+        self.debug = debug
+
+        l = self.activation_layer
+        if l:
+            self.f_input = l.X
+            l.Y = self.Y
+            l.nabla_Y = self.nabla_Y
+        else:
+            self.f_input = self.Y
+
+    def __repr__(self):
+        return '%s(%s, %s, %s, activation=%s)' % \
+            (type(self).__name__, self.N, self.d_in,
+             self.d_out, repr(self.activation))
+
+    def reset(self):
+        d_in = self.d_in
+        d_out = self.d_out
+        r = math.sqrt(6.0 / (d_in + d_out))
+        print('Initializing dense weights in [%f,%f]' % (-r, r))
+        self.W.randomize(-r, r, n_threads=self.n_threads)
+        self.b.assign_all(0)
+
+    def input_from(self, player, **kwargs):
+        self.W.input_from(player, **kwargs)
+        if self.input_bias:
+            self.b.input_from(player, **kwargs)
+
+    def compute_f_input(self, batch):
+        N = len(batch)
+        prod = MultiArray([N, self.d, self.d_out], sfix)
+
+        # flattened_array version
+        result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
+        max_size = get_program().budget
+
+        # X is stored at full dataset indices, batch specifies which samples to use
+        X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
+
+        # Precompute batch_d_indices for all N*d elements
+        # For each sample in batch, expand to d consecutive indices
+        batch_d_indices = regint.Array(N * self.d)
+        @for_range(N)
+        def _(i):
+            actual_sample = batch[i]
+            batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
+
+        @multithread(self.n_threads, N * self.d, max_size)
+        def _(base, size):
+            result_matrix.assign_part_vector(
+                X_sub.direct_mul(self.W, indices=(
+                    batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
+                    regint.inc(self.d_in), regint.inc(self.d_out))), base)
+
+        if self.input_bias:
+            if self.d_out == 1:
+                @multithread(self.n_threads, N)
+                def _(base, size):
+                    v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
+                    self.f_input.assign_vector(v, base)
+            else:
+                @for_range_multithread(self.n_threads, 100, [N, self.d])
+                def _(i, j):
+                    v = prod[i][j].get_vector() + self.b.get_vector()
+                    self.f_input[i][j].assign_vector(v)
+
+        progress('f input')
+
+    def _forward(self, batch=None):
+        if batch is None:
+            batch = regint.Array(self.N)
+            batch.assign(regint.inc(self.N))
+        self.compute_f_input(batch=batch)
+        if self.activation_layer:
+            self.activation_layer.forward(batch, training=True)
+        if self.debug_output:
+            print_ln('dense X %s', self.X.reveal_nested())
+            print_ln('dense W %s', self.W.reveal_nested())
+            print_ln('dense b %s', self.b.reveal_nested())
+            print_ln('dense Y %s', self.Y.reveal_nested())
+        if self.debug:
+            limit = self.debug
+            @for_range_opt(len(batch))
+            def _(i):
+                @for_range_opt(self.d_out)
+                def _(j):
+                    to_check = self.Y[i][0][j].reveal()
+                    check = to_check > limit
+                    @if_(check)
+                    def _():
+                        print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
+                        print_ln('X %s', self.X[i].reveal_nested())
+                        print_ln('W %s',
+                                 [self.W[k][j].reveal() for k in range(self.d_in)])
+
+    def backward(self, compute_nabla_X=True, batch=None):
+        N = len(batch)
+        d = self.d
+        d_out = self.d_out
+        X = self.X
+        Y = self.Y
+        W = self.W
+        b = self.b
+        nabla_X = self.nabla_X
+        nabla_Y = self.nabla_Y
+        nabla_W = self.nabla_W
+        nabla_b = self.nabla_b
+
+        if self.activation_layer:
+            self.activation_layer.backward(batch)
+            f_schur_Y = self.activation_layer.nabla_X
+        else:
+            f_schur_Y = nabla_Y
+
+        if compute_nabla_X:
+            nabla_X.alloc()
+
+            # flattened matrix version
+            result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
+            # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
+            @multithread(self.n_threads, N * self.d)
+            def _(base, size):
+                X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
+                offset = regint.inc(size, base=base)
+
+                result_matrix.assign_part_vector(
+                    X_sub.direct_mul_trans(self.W, indices=(
+                        offset, regint.inc(self.d_out),
+                        regint.inc(self.d_out), regint.inc(self.d_in))), base)
+
+            if self.print_random_update:
+                print_ln('backward %s', self)
+                index = regint.get_random(64) % self.nabla_X.total_size()
+                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
+                         index, self.nabla_X.to_array()[index].reveal())
+
+            progress('nabla X')
+
+        self.backward_params(f_schur_Y, batch=batch)
 
 
 class QuantizedDense(DenseBase):
@@ -3108,7 +2940,7 @@ class BertIntermediate(BertBase):
         input_shape = [n_examples, seq_len, hidden_size]
         output_shape = [n_examples, seq_len, intermediate_size]
         super(BertIntermediate, self).__init__(input_shape, output_shape)
-        self.dense = FlexDense(n_examples, hidden_size, intermediate_size, seq_len)
+        self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
         self.activation = Gelu([n_examples, seq_len, intermediate_size])
@@ -3150,7 +2982,7 @@ class BertOutput(BertBase):
         self.input_shape = input_shape
         print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
         super(BertOutput, self).__init__(input_shape, output_shape)
-        self.dense = FlexDense(n_examples, intermediate_size, hidden_size, seq_len)
+        self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
         self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
         self.dropout = FlexDropout([n_examples, seq_len, hidden_size], alpha=dropout)
@@ -3241,9 +3073,9 @@ class MultiHeadAttention(BertBase):
         self.seq_len = seq_len
         self.hidden_size = hidden_size
 
-        self.wq = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
-        self.wk = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
-        self.wv = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
+        self.wq = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
+        self.wk = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
+        self.wv = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
         self.dropout = FlexDropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT?
         self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx)
@@ -4621,7 +4453,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
            if name == 'Linear':
                assert mul(input_shape[1:]) == item.in_features
                assert item.bias is not None
-                layers.append(FlexDense(input_shape[0], item.in_features,
+                layers.append(Dense(input_shape[0], item.in_features,
                                    item.out_features))
                if input_via is not None:
                    shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
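
The following usage sketch is illustrative only and not part of the patch: it shows how the merged Dense layer above can be driven with an extra dimension d > 1, mirroring how the Bert layers construct it. It assumes the usual MP-SPDZ program context (a program compiled with compile.py, where Compiler.ml, regint, and sfix are available); the shapes, the constant input, and the final print are arbitrary example values.

# Illustrative sketch, assuming an MP-SPDZ program context (e.g. compiled
# via compile.py); shapes and input values below are arbitrary examples.
from Compiler import ml
from Compiler.types import regint, sfix
from Compiler.library import print_ln

n_examples, seq_len, d_in, d_out = 4, 8, 16, 32

# With d > 1, X of shape [N, d, d_in] is mapped to Y of shape [N, d, d_out],
# looping over the extra dimension like torch's nn.Linear.
layer = ml.Dense(n_examples, d_in, d_out, d=seq_len, activation='relu')
layer.reset()                  # Glorot-style weight initialization
layer.X.assign_all(0.5)        # placeholder input values

batch = regint.Array(n_examples)
batch.assign(regint.inc(n_examples))
layer.forward(batch=batch)

print_ln('Y %s', layer.Y.reveal_nested())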