From fcb7e0fc20fd836323004c1c678fc98ff8e687ed Mon Sep 17 00:00:00 2001 From: attila Date: Tue, 27 Sep 2022 20:14:21 +0200 Subject: [PATCH] WinRT midi: Ensure object lifetimes in WinRT async callbacks While the affected callbacks are cancelled before the referenced state is deleted, we have had user reports that they can still be accessed by the cancelled callbacks causing crashes. After only finding warnings that WinRT AsyncCallback cancellation is not a guaranteed thing, we saw it best to wrap the pointers. --- .../native/juce_win32_Midi.cpp | 140 +++++++++++------- .../juce_core/system/juce_StandardHeader.h | 1 + 2 files changed, 88 insertions(+), 53 deletions(-) diff --git a/modules/juce_audio_devices/native/juce_win32_Midi.cpp b/modules/juce_audio_devices/native/juce_win32_Midi.cpp index 803f5b34f6..c2b13f7098 100644 --- a/modules/juce_audio_devices/native/juce_win32_Midi.cpp +++ b/modules/juce_audio_devices/native/juce_win32_Midi.cpp @@ -29,6 +29,39 @@ namespace juce { +template +class CheckedReference +{ +public: + template + friend auto createCheckedReference (Ptr*); + + void clear() + { + std::lock_guard lock { mutex }; + ptr = nullptr; + } + + template + void access (Callback&& callback) + { + std::lock_guard lock { mutex }; + callback (ptr); + } + +private: + explicit CheckedReference (T* ptrIn) : ptr (ptrIn) {} + + T* ptr; + std::mutex mutex; +}; + +template +auto createCheckedReference (Ptr* ptrIn) +{ + return std::shared_ptr> { new CheckedReference (ptrIn) }; +} + class MidiInput::Pimpl { public: @@ -1417,59 +1450,56 @@ private: }; //============================================================================== - template - struct OpenMidiPortThread : public Thread + template + static void openMidiPortThread (String threadName, + String midiDeviceID, + ComSmartPtr& comFactory, + ComSmartPtr& comPort) { - OpenMidiPortThread (String threadName, String midiDeviceID, - ComSmartPtr& comFactory, - ComSmartPtr& comPort) - : Thread (threadName), - deviceID (midiDeviceID), - factory (comFactory), - port (comPort) + std::thread { [&] { - } + Thread::setCurrentThreadName (threadName); - ~OpenMidiPortThread() - { - stopThread (2000); - } - - void run() override - { - WinRTWrapper::ScopedHString hDeviceId (deviceID); + const WinRTWrapper::ScopedHString hDeviceId { midiDeviceID }; ComSmartPtr> asyncOp; - auto hr = factory->FromIdAsync (hDeviceId.get(), asyncOp.resetAndGetPointerAddress()); + const auto hr = comFactory->FromIdAsync (hDeviceId.get(), asyncOp.resetAndGetPointerAddress()); if (FAILED (hr)) return; - hr = asyncOp->put_Completed (Callback> ( - [this] (IAsyncOperation* asyncOpPtr, AsyncStatus) - { - if (asyncOpPtr == nullptr) - return E_ABORT; + std::promise> promise; + auto future = promise.get_future(); - auto hr = asyncOpPtr->GetResults (port.resetAndGetPointerAddress()); + auto callback = [p = std::move (promise)] (IAsyncOperation* asyncOpPtr, AsyncStatus) mutable + { + if (asyncOpPtr == nullptr) + { + p.set_value (nullptr); + return E_ABORT; + } - if (FAILED (hr)) - return hr; + ComSmartPtr result; + const auto hr = asyncOpPtr->GetResults (result.resetAndGetPointerAddress()); - portOpened.signal(); - return S_OK; - } - ).Get()); + if (FAILED (hr)) + { + p.set_value (nullptr); + return hr; + } - // We need to use a timeout here, rather than waiting indefinitely, as the - // WinRT API can occasionally hang! - portOpened.wait (2000); - } + p.set_value (std::move (result)); + return S_OK; + }; - const String deviceID; - ComSmartPtr& factory; - ComSmartPtr& port; - WaitableEvent portOpened { true }; - }; + const auto ir = asyncOp->put_Completed (Callback> (std::move (callback)).Get()); + + if (FAILED (ir)) + return; + + if (future.wait_for (std::chrono::milliseconds (2000)) == std::future_status::ready) + comPort = future.get(); + } }.join(); + } //============================================================================== template @@ -1565,12 +1595,7 @@ private: inputDevice (input), callback (cb) { - OpenMidiPortThread portThread ("Open WinRT MIDI input port", - deviceInfo.deviceID, - service.midiInFactory, - midiPort); - portThread.startThread(); - portThread.waitForThreadToExit (-1); + openMidiPortThread ("Open WinRT MIDI input port", deviceInfo.deviceID, service.midiInFactory, midiPort); if (midiPort == nullptr) { @@ -1582,7 +1607,18 @@ private: auto hr = midiPort->add_MessageReceived ( Callback> ( - [this] (IMidiInPort*, IMidiMessageReceivedEventArgs* args) { return midiInMessageReceived (args); } + [self = checkedReference] (IMidiInPort*, IMidiMessageReceivedEventArgs* args) + { + HRESULT hr = S_OK; + + self->access ([&hr, args] (auto* ptr) + { + if (ptr != nullptr) + hr = ptr->midiInMessageReceived (args); + }); + + return hr; + } ).Get(), &midiInMessageToken); @@ -1595,6 +1631,7 @@ private: ~WinRTInputWrapper() { + checkedReference->clear(); disconnect(); } @@ -1706,6 +1743,8 @@ private: double startTime = 0; bool isStarted = false; + std::shared_ptr> checkedReference = createCheckedReference (this); + JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR (WinRTInputWrapper); }; @@ -1716,12 +1755,7 @@ private: WinRTOutputWrapper (WinRTMidiService& service, const String& deviceIdentifier) : WinRTIOWrapper (*service.bleDeviceWatcher, *service.outputDeviceWatcher, deviceIdentifier) { - OpenMidiPortThread portThread ("Open WinRT MIDI output port", - deviceInfo.deviceID, - service.midiOutFactory, - midiPort); - portThread.startThread(); - portThread.waitForThreadToExit (-1); + openMidiPortThread ("Open WinRT MIDI output port", deviceInfo.deviceID, service.midiOutFactory, midiPort); if (midiPort == nullptr) throw std::runtime_error ("Timed out waiting for midi output port creation"); diff --git a/modules/juce_core/system/juce_StandardHeader.h b/modules/juce_core/system/juce_StandardHeader.h index 05c3fbb953..5e3a497bde 100644 --- a/modules/juce_core/system/juce_StandardHeader.h +++ b/modules/juce_core/system/juce_StandardHeader.h @@ -55,6 +55,7 @@ #include #include #include +#include #include #include #include