diff --git a/Source/Core/Core/PowerPC/JitArm64/Jit.h b/Source/Core/Core/PowerPC/JitArm64/Jit.h index 8d50c24b13..9936b55a2d 100644 --- a/Source/Core/Core/PowerPC/JitArm64/Jit.h +++ b/Source/Core/Core/PowerPC/JitArm64/Jit.h @@ -141,6 +141,7 @@ public: void frspx(UGeckoInstruction inst); void fctiwzx(UGeckoInstruction inst); void fresx(UGeckoInstruction inst); + void frsqrtex(UGeckoInstruction inst); // Paired void ps_maddXX(UGeckoInstruction inst); @@ -149,6 +150,7 @@ public: void ps_sel(UGeckoInstruction inst); void ps_sumX(UGeckoInstruction inst); void ps_res(UGeckoInstruction inst); + void ps_rsqrte(UGeckoInstruction inst); // Loadstore paired void psq_l(UGeckoInstruction inst); @@ -235,6 +237,7 @@ protected: void GenerateAsm(); void GenerateCommonAsm(); void GenerateFres(); + void GenerateFrsqrte(); void GenerateConvertDoubleToSingle(); void GenerateConvertSingleToDouble(); void GenerateFPRF(bool single); diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_FloatingPoint.cpp b/Source/Core/Core/PowerPC/JitArm64/JitArm64_FloatingPoint.cpp index 7c2d94cd4e..47cdd42de7 100644 --- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_FloatingPoint.cpp +++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_FloatingPoint.cpp @@ -456,6 +456,32 @@ void JitArm64::fresx(UGeckoInstruction inst) m_float_emit.FMOV(EncodeRegToDouble(VD), ARM64Reg::X0); } +void JitArm64::frsqrtex(UGeckoInstruction inst) +{ + INSTRUCTION_START + JITDISABLE(bJITFloatingPointOff); + FALLBACK_IF(inst.Rc); + FALLBACK_IF(SConfig::GetInstance().bFPRF && js.op->wantsFPRF); + + const u32 b = inst.FB; + const u32 d = inst.FD; + + gpr.Lock(ARM64Reg::W0, ARM64Reg::W1, ARM64Reg::W2, ARM64Reg::W3, ARM64Reg::W4, ARM64Reg::W30); + fpr.Lock(ARM64Reg::Q0); + + const ARM64Reg VB = fpr.R(b, RegType::LowerPair); + m_float_emit.FMOV(ARM64Reg::X1, EncodeRegToDouble(VB)); + m_float_emit.FRSQRTE(ARM64Reg::D0, EncodeRegToDouble(VB)); + + BL(GetAsmRoutines()->frsqrte); + + gpr.Unlock(ARM64Reg::W0, ARM64Reg::W1, ARM64Reg::W2, ARM64Reg::W3, ARM64Reg::W4, ARM64Reg::W30); + fpr.Unlock(ARM64Reg::Q0); + + const ARM64Reg VD = fpr.RW(d, RegType::LowerPair); + m_float_emit.FMOV(EncodeRegToDouble(VD), ARM64Reg::X0); +} + // Since the following float conversion functions are used in non-arithmetic PPC float // instructions, they must convert floats bitexact and never flush denormals to zero or turn SNaNs // into QNaNs. This means we can't just use FCVT/FCVTL/FCVTN. diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_Paired.cpp b/Source/Core/Core/PowerPC/JitArm64/JitArm64_Paired.cpp index 326d01d515..ebbfe3fbd2 100644 --- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_Paired.cpp +++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_Paired.cpp @@ -384,3 +384,34 @@ void JitArm64::ps_res(UGeckoInstruction inst) fpr.FixSinglePrecision(d); } + +void JitArm64::ps_rsqrte(UGeckoInstruction inst) +{ + INSTRUCTION_START + JITDISABLE(bJITPairedOff); + FALLBACK_IF(inst.Rc); + FALLBACK_IF(SConfig::GetInstance().bFPRF && js.op->wantsFPRF); + + const u32 b = inst.FB; + const u32 d = inst.FD; + + gpr.Lock(ARM64Reg::W0, ARM64Reg::W1, ARM64Reg::W2, ARM64Reg::W3, ARM64Reg::W4, ARM64Reg::W30); + fpr.Lock(ARM64Reg::Q0); + + const ARM64Reg VB = fpr.R(b, RegType::Register); + const ARM64Reg VD = fpr.RW(d, RegType::Register); + + m_float_emit.FMOV(ARM64Reg::X1, EncodeRegToDouble(VB)); + m_float_emit.FRSQRTE(64, ARM64Reg::Q0, EncodeRegToQuad(VB)); + BL(GetAsmRoutines()->frsqrte); + m_float_emit.UMOV(64, ARM64Reg::X1, EncodeRegToQuad(VB), 1); + m_float_emit.DUP(64, ARM64Reg::Q0, ARM64Reg::Q0, 1); + m_float_emit.FMOV(EncodeRegToDouble(VD), ARM64Reg::X0); + BL(GetAsmRoutines()->frsqrte); + m_float_emit.INS(64, EncodeRegToQuad(VD), 1, ARM64Reg::X0); + + gpr.Unlock(ARM64Reg::W0, ARM64Reg::W1, ARM64Reg::W2, ARM64Reg::W3, ARM64Reg::W4, ARM64Reg::W30); + fpr.Unlock(ARM64Reg::Q0); + + fpr.FixSinglePrecision(d); +} diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_Tables.cpp b/Source/Core/Core/PowerPC/JitArm64/JitArm64_Tables.cpp index 04dc2ca615..471a4566c8 100644 --- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_Tables.cpp +++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_Tables.cpp @@ -106,23 +106,23 @@ constexpr std::array table4{{ }}; constexpr std::array table4_2{{ - {10, &JitArm64::ps_sumX}, // ps_sum0 - {11, &JitArm64::ps_sumX}, // ps_sum1 - {12, &JitArm64::ps_mulsX}, // ps_muls0 - {13, &JitArm64::ps_mulsX}, // ps_muls1 - {14, &JitArm64::ps_maddXX}, // ps_madds0 - {15, &JitArm64::ps_maddXX}, // ps_madds1 - {18, &JitArm64::fp_arith}, // ps_div - {20, &JitArm64::fp_arith}, // ps_sub - {21, &JitArm64::fp_arith}, // ps_add - {23, &JitArm64::ps_sel}, // ps_sel - {24, &JitArm64::ps_res}, // ps_res - {25, &JitArm64::fp_arith}, // ps_mul - {26, &JitArm64::FallBackToInterpreter}, // ps_rsqrte - {28, &JitArm64::ps_maddXX}, // ps_msub - {29, &JitArm64::ps_maddXX}, // ps_madd - {30, &JitArm64::ps_maddXX}, // ps_nmsub - {31, &JitArm64::ps_maddXX}, // ps_nmadd + {10, &JitArm64::ps_sumX}, // ps_sum0 + {11, &JitArm64::ps_sumX}, // ps_sum1 + {12, &JitArm64::ps_mulsX}, // ps_muls0 + {13, &JitArm64::ps_mulsX}, // ps_muls1 + {14, &JitArm64::ps_maddXX}, // ps_madds0 + {15, &JitArm64::ps_maddXX}, // ps_madds1 + {18, &JitArm64::fp_arith}, // ps_div + {20, &JitArm64::fp_arith}, // ps_sub + {21, &JitArm64::fp_arith}, // ps_add + {23, &JitArm64::ps_sel}, // ps_sel + {24, &JitArm64::ps_res}, // ps_res + {25, &JitArm64::fp_arith}, // ps_mul + {26, &JitArm64::ps_rsqrte}, // ps_rsqrte + {28, &JitArm64::ps_maddXX}, // ps_msub + {29, &JitArm64::ps_maddXX}, // ps_madd + {30, &JitArm64::ps_maddXX}, // ps_nmsub + {31, &JitArm64::ps_maddXX}, // ps_nmadd }}; constexpr std::array table4_3{{ @@ -324,16 +324,16 @@ constexpr std::array table63{{ }}; constexpr std::array table63_2{{ - {18, &JitArm64::fp_arith}, // fdivx - {20, &JitArm64::fp_arith}, // fsubx - {21, &JitArm64::fp_arith}, // faddx - {23, &JitArm64::fselx}, // fselx - {25, &JitArm64::fp_arith}, // fmulx - {26, &JitArm64::FallBackToInterpreter}, // frsqrtex - {28, &JitArm64::fp_arith}, // fmsubx - {29, &JitArm64::fp_arith}, // fmaddx - {30, &JitArm64::fp_arith}, // fnmsubx - {31, &JitArm64::fp_arith}, // fnmaddx + {18, &JitArm64::fp_arith}, // fdivx + {20, &JitArm64::fp_arith}, // fsubx + {21, &JitArm64::fp_arith}, // faddx + {23, &JitArm64::fselx}, // fselx + {25, &JitArm64::fp_arith}, // fmulx + {26, &JitArm64::frsqrtex}, // frsqrtex + {28, &JitArm64::fp_arith}, // fmsubx + {29, &JitArm64::fp_arith}, // fmaddx + {30, &JitArm64::fp_arith}, // fnmsubx + {31, &JitArm64::fp_arith}, // fnmaddx }}; constexpr std::array dynaOpTable = [] { diff --git a/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp b/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp index afbcd14596..e38a5706e2 100644 --- a/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp +++ b/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp @@ -205,6 +205,10 @@ void JitArm64::GenerateCommonAsm() GenerateFres(); JitRegister::Register(GetAsmRoutines()->fres, GetCodePtr(), "JIT_fres"); + GetAsmRoutines()->frsqrte = GetCodePtr(); + GenerateFrsqrte(); + JitRegister::Register(GetAsmRoutines()->frsqrte, GetCodePtr(), "JIT_frsqrte"); + GetAsmRoutines()->cdts = GetCodePtr(); GenerateConvertDoubleToSingle(); JitRegister::Register(GetAsmRoutines()->cdts, GetCodePtr(), "JIT_cdts"); @@ -276,6 +280,71 @@ void JitArm64::GenerateFres() RET(); } +// Input: X1 contains input, and D0 contains result of running the input through AArch64 FRSQRTE. +// Output in X0 and memory (PPCState). Clobbers X0-X4 and flags. +void JitArm64::GenerateFrsqrte() +{ + // The idea behind this implementation: AArch64's frsqrte instruction calculates the exponent and + // sign the same way as PowerPC's frsqrtex does. For the special inputs zero, negative, NaN and + // inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in + // the normal case we calculate the mantissa using the table-based algorithm from the interpreter. + + TSTI2R(ARM64Reg::X1, Common::DOUBLE_EXP | Common::DOUBLE_FRAC); + m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0); + FixupBranch zero = B(CCFlags::CC_EQ); + ANDI2R(ARM64Reg::X2, ARM64Reg::X1, Common::DOUBLE_EXP); + MOVI2R(ARM64Reg::X3, Common::DOUBLE_EXP); + CMP(ARM64Reg::X2, ARM64Reg::X3); + FixupBranch nan_or_inf = B(CCFlags::CC_EQ); + FixupBranch negative = TBNZ(ARM64Reg::X1, 63); + ANDI2R(ARM64Reg::X3, ARM64Reg::X1, Common::DOUBLE_FRAC); + FixupBranch normal = CBNZ(ARM64Reg::X2); + + // "Normalize" denormal values + CLZ(ARM64Reg::X3, ARM64Reg::X3); + SUB(ARM64Reg::X4, ARM64Reg::X3, 11); + MOVI2R(ARM64Reg::X2, 0x00C0'0000'0000'0000); + LSLV(ARM64Reg::X4, ARM64Reg::X1, ARM64Reg::X4); + SUB(ARM64Reg::X2, ARM64Reg::X2, ARM64Reg::X3, ArithOption(ARM64Reg::X3, ShiftType::LSL, 52)); + ANDI2R(ARM64Reg::X3, ARM64Reg::X4, Common::DOUBLE_FRAC - 1); + + SetJumpTarget(normal); + LSR(ARM64Reg::X2, ARM64Reg::X2, 48); + ANDI2R(ARM64Reg::X2, ARM64Reg::X2, 0x10); + MOVP2R(ARM64Reg::X1, &Common::frsqrte_expected); + ORR(ARM64Reg::X2, ARM64Reg::X2, ARM64Reg::X3, ArithOption(ARM64Reg::X8, ShiftType::LSR, 48)); + EORI2R(ARM64Reg::X2, ARM64Reg::X2, 0x10); + ADD(ARM64Reg::X2, ARM64Reg::X1, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3)); + LDP(IndexType::Signed, ARM64Reg::W1, ARM64Reg::W2, ARM64Reg::X2, 0); + UBFX(ARM64Reg::X3, ARM64Reg::X3, 37, 11); + ANDI2R(ARM64Reg::X0, ARM64Reg::X0, Common::DOUBLE_SIGN | Common::DOUBLE_EXP); + MSUB(ARM64Reg::W3, ARM64Reg::W3, ARM64Reg::W2, ARM64Reg::W1); + ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X3, ArithOption(ARM64Reg::X3, ShiftType::LSL, 26)); + RET(); + + SetJumpTarget(zero); + LDR(IndexType::Unsigned, ARM64Reg::W4, PPC_REG, PPCSTATE_OFF(fpscr)); + FixupBranch skip_set_zx = TBNZ(ARM64Reg::W4, 26); + ORRI2R(ARM64Reg::W4, ARM64Reg::W4, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2); + STR(IndexType::Unsigned, ARM64Reg::W4, PPC_REG, PPCSTATE_OFF(fpscr)); + SetJumpTarget(skip_set_zx); + RET(); + + SetJumpTarget(nan_or_inf); + MOVI2R(ARM64Reg::X3, Common::BitCast(-std::numeric_limits::infinity())); + CMP(ARM64Reg::X1, ARM64Reg::X3); + FixupBranch nan_or_positive_inf = B(CCFlags::CC_NEQ); + + SetJumpTarget(negative); + LDR(IndexType::Unsigned, ARM64Reg::W4, PPC_REG, PPCSTATE_OFF(fpscr)); + FixupBranch skip_set_vxsqrt = TBNZ(ARM64Reg::W4, 9); + ORRI2R(ARM64Reg::W4, ARM64Reg::W4, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2); + STR(IndexType::Unsigned, ARM64Reg::W4, PPC_REG, PPCSTATE_OFF(fpscr)); + SetJumpTarget(skip_set_vxsqrt); + SetJumpTarget(nan_or_positive_inf); + RET(); +} + // Input in X0, output in W1, clobbers X0-X3 and flags. void JitArm64::GenerateConvertDoubleToSingle() { diff --git a/Source/UnitTests/Core/CMakeLists.txt b/Source/UnitTests/Core/CMakeLists.txt index b442581df0..134f7da78b 100644 --- a/Source/UnitTests/Core/CMakeLists.txt +++ b/Source/UnitTests/Core/CMakeLists.txt @@ -26,6 +26,7 @@ elseif(_M_ARM_64) PowerPC/JitArm64/ConvertSingleDouble.cpp PowerPC/JitArm64/FPRF.cpp PowerPC/JitArm64/Fres.cpp + PowerPC/JitArm64/Frsqrte.cpp PowerPC/JitArm64/MovI2R.cpp ) else() diff --git a/Source/UnitTests/Core/PowerPC/JitArm64/Frsqrte.cpp b/Source/UnitTests/Core/PowerPC/JitArm64/Frsqrte.cpp new file mode 100644 index 0000000000..749b147dcb --- /dev/null +++ b/Source/UnitTests/Core/PowerPC/JitArm64/Frsqrte.cpp @@ -0,0 +1,66 @@ +// Copyright 2021 Dolphin Emulator Project +// Licensed under GPLv2+ +// Refer to the license.txt file included. + +#include + +#include "Common/Arm64Emitter.h" +#include "Common/BitUtils.h" +#include "Common/CommonTypes.h" +#include "Core/PowerPC/Interpreter/Interpreter_FPUtils.h" +#include "Core/PowerPC/JitArm64/Jit.h" +#include "Core/PowerPC/PowerPC.h" + +#include "../TestValues.h" + +#include + +namespace +{ +using namespace Arm64Gen; + +class TestFrsqrte : public JitArm64 +{ +public: + TestFrsqrte() + { + AllocCodeSpace(4096); + + const u8* raw_frsqrte = GetCodePtr(); + GenerateFrsqrte(); + + frsqrte = Common::BitCast(GetCodePtr()); + MOV(ARM64Reg::X15, ARM64Reg::X30); + MOV(ARM64Reg::X14, PPC_REG); + MOVP2R(PPC_REG, &PowerPC::ppcState); + MOV(ARM64Reg::X1, ARM64Reg::X0); + m_float_emit.FMOV(ARM64Reg::D0, ARM64Reg::X0); + m_float_emit.FRSQRTE(ARM64Reg::D0, ARM64Reg::D0); + BL(raw_frsqrte); + MOV(ARM64Reg::X30, ARM64Reg::X15); + MOV(PPC_REG, ARM64Reg::X14); + RET(); + } + + std::function frsqrte; +}; + +} // namespace + +TEST(JitArm64, Frsqrte) +{ + TestFrsqrte test; + + for (const u64 ivalue : double_test_values) + { + const double dvalue = Common::BitCast(ivalue); + + const u64 expected = Common::BitCast(Common::ApproximateReciprocalSquareRoot(dvalue)); + const u64 actual = test.frsqrte(ivalue); + + if (expected != actual) + fmt::print("{:016x} -> {:016x} == {:016x}\n", ivalue, actual, expected); + + EXPECT_EQ(expected, actual); + } +} diff --git a/Source/UnitTests/UnitTests.vcxproj b/Source/UnitTests/UnitTests.vcxproj index ef95e22c1f..3d288aab80 100644 --- a/Source/UnitTests/UnitTests.vcxproj +++ b/Source/UnitTests/UnitTests.vcxproj @@ -85,6 +85,7 @@ +