diff --git a/src/android/app/src/main/jni/CMakeLists.txt b/src/android/app/src/main/jni/CMakeLists.txt index a53e014e2..acad4e622 100644 --- a/src/android/app/src/main/jni/CMakeLists.txt +++ b/src/android/app/src/main/jni/CMakeLists.txt @@ -39,6 +39,7 @@ add_library(citra-android SHARED vr/layers/PassthroughLayer.cpp vr/layers/UILayer.cpp vr/utils/JniUtils.cpp + vr/utils/JniClassNames.cpp vr/utils/MessageQueue.cpp ) diff --git a/src/android/app/src/main/jni/vr/layers/UILayer.cpp b/src/android/app/src/main/jni/vr/layers/UILayer.cpp index 8bfa712e5..d135303d5 100644 --- a/src/android/app/src/main/jni/vr/layers/UILayer.cpp +++ b/src/android/app/src/main/jni/vr/layers/UILayer.cpp @@ -167,7 +167,8 @@ UILayer::UILayer(const std::string& className, const XrVector3f&& position, , mEnv(env) { const int32_t initializationStatus = Init(className, activityObject, position, session); if (initializationStatus < 0) { - FAIL("Could not initialize UILayer(%s) -- error '%d'", className.c_str(), initializationStatus); + FAIL("Could not initialize UILayer(%s) -- error '%d'", className.c_str(), + initializationStatus); } } diff --git a/src/android/app/src/main/jni/vr/utils/JniClassNames.cpp b/src/android/app/src/main/jni/vr/utils/JniClassNames.cpp new file mode 100644 index 000000000..db531cd1a --- /dev/null +++ b/src/android/app/src/main/jni/vr/utils/JniClassNames.cpp @@ -0,0 +1,47 @@ +#include "JniClassNames.h" + +#include "LogUtils.h" + +#include + +namespace VR { +namespace JniGlobalRef { +jmethodID gFindClassMethodID = nullptr; +jobject gClassLoader = nullptr; +} // namespace JniGlobalRef +} // namespace VR + +void VR::JNI::InitJNI(JNIEnv* jni, jobject activityObject) { + assert(jni != nullptr); + const jclass activityClass = jni->GetObjectClass(activityObject); + if (activityClass == nullptr) { FAIL("Failed to get activity class"); } + + // Get the getClassLoader method ID + const jmethodID getClassLoaderMethod = + jni->GetMethodID(activityClass, "getClassLoader", "()Ljava/lang/ClassLoader;"); + if (getClassLoaderMethod == nullptr) { FAIL("Failed to get getClassLoader method ID"); } + + // Call getClassLoader of the activity object to obtain the class loader + const jobject classLoaderObject = jni->CallObjectMethod(activityObject, getClassLoaderMethod); + if (classLoaderObject == nullptr) { FAIL("Failed to get class loader object"); } + + JniGlobalRef::gClassLoader = jni->NewGlobalRef(classLoaderObject); + + // Step 3: Cache the findClass method ID + jclass classLoaderClass = jni->FindClass("java/lang/ClassLoader"); + if (classLoaderClass == nullptr) { FAIL("Failed to find class loader class"); } + JniGlobalRef::gFindClassMethodID = + jni->GetMethodID(classLoaderClass, "findClass", "(Ljava/lang/String;)Ljava/lang/Class;"); + if (JniGlobalRef::gFindClassMethodID == nullptr) { FAIL("Failed to get findClass method ID"); } + + // Cleanup local references + jni->DeleteLocalRef(activityClass); + jni->DeleteLocalRef(classLoaderClass); +} + +void VR::JNI::CleanupJNI(JNIEnv* jni) { + assert(jni != nullptr); + if (JniGlobalRef::gClassLoader != nullptr) { jni->DeleteGlobalRef(JniGlobalRef::gClassLoader); } + JniGlobalRef::gClassLoader = nullptr; + JniGlobalRef::gFindClassMethodID = nullptr; +} diff --git a/src/android/app/src/main/jni/vr/utils/JniClassNames.h b/src/android/app/src/main/jni/vr/utils/JniClassNames.h new file mode 100644 index 000000000..9756b7ad3 --- /dev/null +++ b/src/android/app/src/main/jni/vr/utils/JniClassNames.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace VR { +namespace JniGlobalRef { +extern jmethodID gFindClassMethodID; +extern jobject gClassLoader; + +} // namespace JniGlobalRef + +namespace JNI { +// Called during JNI_OnLoad +void InitJNI(JNIEnv* env, jobject activityObject); +// Called during JNI_OnUnload +void CleanupJNI(JNIEnv* env); +} // namespace JNI +} // namespace VR diff --git a/src/android/app/src/main/jni/vr/utils/JniUtils.cpp b/src/android/app/src/main/jni/vr/utils/JniUtils.cpp index 6ac7db87e..bec2487ec 100644 --- a/src/android/app/src/main/jni/vr/utils/JniUtils.cpp +++ b/src/android/app/src/main/jni/vr/utils/JniUtils.cpp @@ -12,74 +12,39 @@ License : Licensed under GPLv3 or any later version. #include "JniUtils.h" +#include "JniClassNames.h" #include "LogUtils.h" -jclass JniUtils::GetGlobalClassReference(JNIEnv* jni, jobject activityObject, +jclass JniUtils::GetGlobalClassReference(JNIEnv* env, jobject activityObject, const std::string& className) { - // First, get the class object of the activity to get its class loader - const jclass activityClass = jni->GetObjectClass(activityObject); - if (activityClass == nullptr) { - ALOGE("Failed to get activity class"); + // Convert dot ('.') to slash ('/') in class name (Java uses dots, JNI uses slashes for class + // names) + std::string correctedClassName = className; + std::replace(correctedClassName.begin(), correctedClassName.end(), '.', '/'); + + // Convert std::string to jstring + jstring classNameJString = env->NewStringUTF(correctedClassName.c_str()); + + // Use the global class loader to find the class + jclass clazz = static_cast(env->CallObjectMethod( + VR::JniGlobalRef::gClassLoader, VR::JniGlobalRef::gFindClassMethodID, classNameJString)); + if (clazz == nullptr) { + // Class not found + ALOGE("Class not found: %s", correctedClassName.c_str()); return nullptr; } - // Get the getClassLoader method ID - const jmethodID getClassLoaderMethod = - jni->GetMethodID(activityClass, "getClassLoader", "()Ljava/lang/ClassLoader;"); - if (getClassLoaderMethod == nullptr) { - ALOGE("Failed to get getClassLoader method ID"); - return nullptr; + // Clean up the local reference to the class name jstring + env->DeleteLocalRef(classNameJString); + + // Check for exceptions and handle them. This is crucial to prevent crashes due to uncaught + // exceptions. + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return nullptr; // Class not found or other issue } - // Call getClassLoader of the activity object to obtain the class loader - const jobject classLoaderObject = jni->CallObjectMethod(activityObject, getClassLoaderMethod); - if (classLoaderObject == nullptr) { - ALOGE("Failed to get class loader object"); - return nullptr; - } - - // Get the class loader class - const jclass classLoaderClass = jni->FindClass("java/lang/ClassLoader"); - if (classLoaderClass == nullptr) { - ALOGE("Failed to get class loader class"); - return nullptr; - } - - // Get the findClass method ID from the class loader class - const jmethodID findClassMethod = - jni->GetMethodID(classLoaderClass, "findClass", "(Ljava/lang/String;)Ljava/lang/Class;"); - if (findClassMethod == nullptr) { - ALOGE("Failed to get findClass method ID"); - return nullptr; - } - - // Convert the class name string to a jstring - const jstring javaClassName = jni->NewStringUTF(className.c_str()); - if (javaClassName == nullptr) { - ALOGE("Failed to convert class name to jstring"); - return nullptr; - } - - // Call findClass on the class loader object with the class name - const jclass classToFind = static_cast( - jni->CallObjectMethod(classLoaderObject, findClassMethod, javaClassName)); - - // Clean up local references - jni->DeleteLocalRef(activityClass); - jni->DeleteLocalRef(classLoaderObject); - jni->DeleteLocalRef(classLoaderClass); - jni->DeleteLocalRef(javaClassName); - - if (classToFind == nullptr) { - // Handle error (Class not found) - return nullptr; - } - - // Create a global reference to the class - const jclass globalClassRef = reinterpret_cast(jni->NewGlobalRef(classToFind)); - - // Clean up the local reference of the class - jni->DeleteLocalRef(classToFind); - - return globalClassRef; + // Return a global reference to the class + return static_cast(env->NewGlobalRef(clazz)); } diff --git a/src/android/app/src/main/jni/vr/vr_main.cpp b/src/android/app/src/main/jni/vr/vr_main.cpp index cabd4d0da..d81d40533 100644 --- a/src/android/app/src/main/jni/vr/vr_main.cpp +++ b/src/android/app/src/main/jni/vr/vr_main.cpp @@ -21,6 +21,7 @@ License : Licensed under GPLv3 or any later version. #include "vr_settings.h" #include "utils/Common.h" +#include "utils/JniClassNames.h" #include "utils/MessageQueue.h" #include "utils/SyspropUtils.h" #include "utils/XrMath.h" @@ -1126,6 +1127,7 @@ Java_org_citra_citra_1emu_vr_VrActivity_nativeOnCreate(JNIEnv* env, jobject thiz // time to first frame. gOnCreateStartTime = std::chrono::steady_clock::now(); + VR::JNI::InitJNI(env, thiz); JavaVM* jvm; env->GetJavaVM(&jvm); auto ret = VRAppHandle(new VRAppThread(jvm, env, thiz)).l; @@ -1138,6 +1140,7 @@ Java_org_citra_citra_1emu_vr_VrActivity_nativeOnDestroy(JNIEnv* env, jobject thi ALOGI("nativeOnDestroy {}", static_cast(handle)); if (handle != 0) { delete VRAppHandle(handle).p; } + VR::JNI::CleanupJNI(env); } extern "C" JNIEXPORT jint JNICALL