diff --git a/test/test_dtype.py b/test/test_dtype.py index d451252104..82d6e3ddbd 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -48,6 +48,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType): _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8]) _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16]) _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]]) + _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy()) class TestBFloat16DType(unittest.TestCase): def test_bf16_to_float(self): diff --git a/tinygrad/codegen/assembly_arm64.py b/tinygrad/codegen/assembly_arm64.py index 5e4df3aea2..900c18f5ec 100644 --- a/tinygrad/codegen/assembly_arm64.py +++ b/tinygrad/codegen/assembly_arm64.py @@ -139,12 +139,11 @@ def specialize_to_arm64(fn_nm, asm): ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]") if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}") elif uop == UOps.STORE: - shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"} #NOTE: if need casting load var in s/h0 or x/w12 temp regs reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm]) if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}") ins.append(f"mov x15, #{arg[0]}") - ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]") + ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]") elif uop == UOps.COND_BRANCH: #TODO: this is a hack it shouldn't always be a cmp before a cond branch? if prev_uop == UOps.LOAD: