diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 4287844a..67410678 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -17,6 +17,7 @@ from Compiler import instructions_base import Compiler.GC.instructions as inst import operator import math +import itertools from functools import reduce class bits(Tape.Register, _structure, _bit): @@ -1182,9 +1183,20 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): return self.from_vec(other * x for x in self.v) elif isinstance(other, sbitfixvec): return NotImplemented + other_bits = util.bit_decompose(other) + m = float('inf') + for x in itertools.chain(self.v, other_bits): + try: + m = min(m, x.n) + except: + pass + if m == 1: + op = operator.mul + else: + op = operator.and_ matrix = [] - for i, b in enumerate(util.bit_decompose(other)): - matrix.append([x & b for x in self.v[:len(self.v)-i]]) + for i, b in enumerate(other_bits): + matrix.append([op(x, b) for x in self.v[:len(self.v)-i]]) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ diff --git a/Compiler/types.py b/Compiler/types.py index 10b2c424..1531c49d 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5263,7 +5263,8 @@ class Array(_vectorizable): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() - if size is not None and isinstance(base, _register): + if size is not None and isinstance(base, _register) \ + and not issubclass(self.value_type, _vec): base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ @@ -6063,6 +6064,7 @@ class SubMultiArray(_vectorizable): assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() + self.value_type.matrix_mul A = self.get_vector() B = other.get_vector() res_matrix.assign_vector( diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9cdde3dc..85098d18 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -146,7 +146,7 @@ public: if (this != &res) res.get_regs().assign(this->get_regs().begin(), this->get_regs().begin() - + max(size_t(n_bits), this->get_regs().size())); + + min(size_t(n_bits), this->get_regs().size())); res.resize_regs(n_bits); } diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 5b0589b6..1d7c883d 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -666,6 +666,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const return r[1] + size; else return 0; + case TRANS: + if (reg_type == SBIT) + { + int n_outputs = n; + auto& args = start; + int n_inputs = args.size() - n_outputs; + long long res = 0; + for (int i = 0; i < n_outputs; i++) + res = max(res, args[i] + DIV_CEIL(n_inputs, 64)); + for (int j = 0; j < n_inputs; j++) + res = max(res, args[n_outputs] + DIV_CEIL(n_outputs, 64)); + return res; + } + else + return 0; default: if (get_reg_type() != reg_type) return 0; @@ -731,7 +746,6 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case ANDM: case NOTS: case NOTCB: - case TRANS: size = DIV_CEIL(n, 64); break; case CONVCBIT2S: