diff --git a/tinygrad/codegen/assembly_arm64.py b/tinygrad/codegen/assembly_arm64.py index 900c18f5ec..9faa21cfdf 100644 --- a/tinygrad/codegen/assembly_arm64.py +++ b/tinygrad/codegen/assembly_arm64.py @@ -20,8 +20,8 @@ def specialize_to_arm64(fn_nm, asm): var_size = 16 prev_uop:Optional[UOps] = None ins = [] - x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)] - s_regs = ['s' + str(i) for i in reversed(range(3,30))] + x_regs = ['x' + str(i) for i in reversed(range(12))] + s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16] type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'} alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", @@ -58,12 +58,12 @@ def specialize_to_arm64(fn_nm, asm): if len(available_regs) == 0: # ARM needs the stack 16-byte aligned var_size += 16 - available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11') + available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12') mem_vars[v.nm] = var_size rtor[v.nm] = available_regs.pop() temp_floats = ['s0', 's1', 's2'] - temp_ints = ['x11', 'x12', 'x13'] + temp_ints = ['x12', 'x13', 'x16'] for i, (uop, out, vin, arg) in enumerate(asm): # Clear regs out of interval for var, reg in list(rtor.items()): @@ -83,7 +83,7 @@ def specialize_to_arm64(fn_nm, asm): if arg.startswith('data'): # data 8 to n into the stack if int(arg[4:]) >= 8: - ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]") + ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]") ins.append(f"mov {rtor[out.nm]}, x15") else: ins.append(f"mov {rtor[out.nm]}, #0") @@ -161,7 +161,7 @@ def specialize_to_arm64(fn_nm, asm): if out is not None and out.nm in mem_vars: ins.append(f"mov x15, {mem_vars[out.nm]}") ins.append(f"str {rtor[out.nm]}, [sp, x15]") - return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"]) + return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"]) def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]: lang = ARM64Language()