diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 3ff2c889..05b6dbcc 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -588,6 +588,7 @@ class Merger: if instr.indices_values is not None and instr.first_factor_base_addresses is not None and instr.second_factor_base_addresses is not None: # Determine which values get accessed by the MATMULSM instruction and only add the according dependencies. for matmul_idx in range(len(instr.first_factor_base_addresses)): + start_time = time.time() first_base = instr.first_factor_base_addresses[matmul_idx] second_base = instr.second_factor_base_addresses[matmul_idx] @@ -599,18 +600,35 @@ class Merger: first_factor_row_length = instr.args[12 * matmul_idx + 10] second_factor_row_length = instr.args[12 * matmul_idx + 11] + # Add dependencies to the first factor. for i in range(instr.args[12 * matmul_idx + 3]): - for j in range(instr.args[12 * matmul_idx + 5]): - for k in range(instr.args[12 * matmul_idx + 4]): - first_factor_addr = first_base + \ - first_factor_row_length * first_factor_row_indices[i] + \ - first_factor_column_indices[k] - handle_mem_access(first_factor_addr, 's', last_mem_read_of, last_mem_write_of) + if (time.time() - start_time) > 10: + # Abort building the dependencies if that takes too much time. + if block.warn_about_mem and not block.parent.warned_about_mem: + print('WARNING: Order of memory instructions not preserved due to long vector, errors possible') + block.parent.warned_about_mem = True + break - second_factor_addr = second_base + \ - second_factor_row_length * second_factor_row_indices[k] + \ - second_factor_column_indices[j] - handle_mem_access(second_factor_addr, 's', last_mem_read_of, last_mem_write_of) + for k in range(instr.args[12 * matmul_idx + 4]): + first_factor_addr = first_base + \ + first_factor_row_length * first_factor_row_indices[i] + \ + first_factor_column_indices[k] + handle_mem_access(first_factor_addr, 's', last_mem_read_of, last_mem_write_of) + + # Add dependencies to the second factor. + for k in range(instr.args[12 * matmul_idx + 4]): + if (time.time() - start_time) > 10: + # Abort building the dependencies if that takes too much time. + if block.warn_about_mem and not block.parent.warned_about_mem: + print('WARNING: Order of memory instructions not preserved due to long vector, errors possible') + block.parent.warned_about_mem = True + break + + for j in range(instr.args[12 * matmul_idx + 5]): + second_factor_addr = second_base + \ + second_factor_row_length * second_factor_row_indices[k] + \ + second_factor_column_indices[j] + handle_mem_access(second_factor_addr, 's', last_mem_read_of, last_mem_write_of) else: # If the accessed values cannot be determined, be cautious I guess. for i in last_mem_write_of.values(): diff --git a/Programs/Source/test_dot.mpc b/Programs/Source/test_dot.mpc index 148caa0d..92f0bad0 100644 --- a/Programs/Source/test_dot.mpc +++ b/Programs/Source/test_dot.mpc @@ -158,3 +158,8 @@ M = sint.Matrix(10, 10) M.direct_mul(M, indices=[regint(0), regint.inc(10), regint.inc(10), regint(0)]) stop_timer(9) + + +start_timer(10) +sint.Matrix(1000, 1000) * sint.Matrix(1000, 1000) +stop_timer(10)