From 1e3703fe6408979830a93fb2f600cd05866f39d8 Mon Sep 17 00:00:00 2001 From: reuk Date: Wed, 12 Jun 2024 19:31:17 +0100 Subject: [PATCH] URLConnectionState: Simplify and improve thread safety --- modules/juce_core/native/juce_Network_mac.mm | 584 +++++++++---------- 1 file changed, 276 insertions(+), 308 deletions(-) diff --git a/modules/juce_core/native/juce_Network_mac.mm b/modules/juce_core/native/juce_Network_mac.mm index 6d338c8a1e..b3a8b3faaa 100644 --- a/modules/juce_core/native/juce_Network_mac.mm +++ b/modules/juce_core/native/juce_Network_mac.mm @@ -119,322 +119,286 @@ bool JUCE_CALLTYPE Process::openEmailWithAttachments ([[maybe_unused]] const Str } //============================================================================== -class URLConnectionStateBase : public Thread +class URLConnectionState final { public: - explicit URLConnectionStateBase (NSURLRequest* req, int maxRedirects) - : Thread ("http connection"), - request ([req retain]), - data ([[NSMutableData data] retain]), - numRedirectsToFollow (maxRedirects) + URLConnectionState (NSUniquePtr req, const int maxRedirects) + : request (std::move (req)), + numRedirects (maxRedirects) { + DelegateClass::setState (delegate.get(), this); } - virtual ~URLConnectionStateBase() = default; - - virtual void cancel() = 0; - virtual bool start (WebInputStream&, WebInputStream::Listener*) = 0; - virtual int read (char* dest, int numBytes) = 0; - - int64 getContentLength() const noexcept { return contentLength; } - NSDictionary* getHeaders() const noexcept { return headers; } - int getStatusCode() const noexcept { return statusCode; } - NSInteger getErrorCode() const noexcept { return nsUrlErrorCode; } - -protected: - CriticalSection dataLock, createConnectionLock; - id delegate = nil; - NSDictionary* headers = nil; - NSURLRequest* request = nil; - NSMutableData* data = nil; - int64 contentLength = -1; - int statusCode = 0; - NSInteger nsUrlErrorCode = 0; - - std::atomic initialised { false }, hasFailed { false }, hasFinished { false }; - const int numRedirectsToFollow; - int numRedirects = 0; - int64 latestTotalBytes = 0; - bool hasBeenCancelled = false; - -private: - - JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR (URLConnectionStateBase) -}; - -//============================================================================== -class API_AVAILABLE (macos (10.9)) URLConnectionState final : public URLConnectionStateBase -{ -public: - URLConnectionState (NSURLRequest* req, const int maxRedirects) - : URLConnectionStateBase (req, maxRedirects) + ~URLConnectionState() { - static DelegateClass cls; - delegate = [cls.createInstance() init]; - DelegateClass::setState (delegate, this); + cancel(); + + std::unique_lock lock { mutex }; + [session.get() finishTasksAndInvalidate]; + condvar.wait (lock, [&] { return state == State::invalidated; }); } - ~URLConnectionState() override + void cancel() { - signalThreadShouldExit(); + const std::scoped_lock lock { mutex }; + // When a task completes, URLSession:task:didCompleteWithError: will be called on the + // delegate, even if the task is cancelled. + + if (auto* toCancel = task.get()) + [toCancel cancel]; + } + + int64 getContentLength() const noexcept + { + const std::scoped_lock lock { mutex }; + return contentLength; + } + + NSUniquePtr getHeaders() const noexcept + { + const std::scoped_lock lock { mutex }; + return NSUniquePtr { [headers.get() copy] }; + } + + int getStatusCode() const noexcept + { + const std::scoped_lock lock { mutex }; + return statusCode; + } + + bool start (WebInputStream& inputStream, WebInputStream::Listener* listener) + { + std::unique_lock lock { mutex }; + [task.get() resume]; + + while (! condvar.wait_for (lock, + std::chrono::milliseconds { 1 }, + [&] { return state != State::beforeStart; })) { - const ScopedLock sl (dataLock); - isBeingDeleted = true; - [task cancel]; - DelegateClass::setState (delegate, nullptr); - } - - stopThread (10000); - [task release]; - [request release]; - [headers release]; - - [session finishTasksAndInvalidate]; - [session release]; - - const ScopedLock sl (dataLock); - [delegate release]; - [data release]; - } - - void cancel() override - { - { - const ScopedLock lock (createConnectionLock); - hasBeenCancelled = true; - } - - signalThreadShouldExit(); - stopThread (10000); - } - - bool start (WebInputStream& inputStream, WebInputStream::Listener* listener) override - { - { - const ScopedLock lock (createConnectionLock); - - if (hasBeenCancelled) + if (listener != nullptr + && ! listener->postDataSendProgress (inputStream, + (int) latestTotalBytes, + (int) [[request.get() HTTPBody] length])) + { return false; - } - - startThread(); - - while (isThreadRunning() && ! initialised) - { - if (listener != nullptr) - if (! listener->postDataSendProgress (inputStream, (int) latestTotalBytes, (int) [[request HTTPBody] length])) - return false; - - Thread::sleep (1); + } } return true; } - int read (char* dest, int numBytes) override + int read (char* dest, int numBytes) { int numDone = 0; while (numBytes > 0) { - const ScopedLock sl (dataLock); - auto available = jmin (numBytes, (int) [data length]); + std::unique_lock lock { mutex }; - if (available > 0) - { - [data getBytes: dest length: (NSUInteger) available]; - [data replaceBytesInRange: NSMakeRange (0, (NSUInteger) available) withBytes: nil length: 0]; + const auto getNumAvailable = [&] { return jmin (numBytes, (int) [data.get() length]); }; + condvar.wait (lock, [&] { return getNumAvailable() > 0 || state == State::requestFinished; }); - numDone += available; - numBytes -= available; - dest += available; - } - else - { - if (hasFailed || hasFinished) - break; + const auto available = getNumAvailable(); - const ScopedUnlock ul (dataLock); - Thread::sleep (1); - } + if (available <= 0) + break; + + [data.get() getBytes: dest length: (NSUInteger) available]; + [data.get() replaceBytesInRange: NSMakeRange (0, (NSUInteger) available) withBytes: nil length: 0]; + + numDone += available; + numBytes -= available; + dest += available; } return numDone; } - void didReceiveResponse (NSURLResponse* response, id completionHandler) +private: + void didReceiveResponse (NSURLResponse* response, + void (^completionHandler) (NSURLSessionResponseDisposition)) { { - const ScopedLock sl (dataLock); - if (isBeingDeleted) - return; + const std::scoped_lock lock { mutex }; - [data setLength: 0]; + contentLength = [response expectedContentLength]; + + if ([response isKindOfClass: [NSHTTPURLResponse class]]) + { + auto httpResponse = (NSHTTPURLResponse*) response; + headers.reset ([[httpResponse allHeaderFields] retain]); + statusCode = (int) [httpResponse statusCode]; + } + + if (state == State::beforeStart) + state = State::started; } - contentLength = [response expectedContentLength]; - - [headers release]; - headers = nil; - - if ([response isKindOfClass: [NSHTTPURLResponse class]]) - { - auto httpResponse = (NSHTTPURLResponse*) response; - headers = [[httpResponse allHeaderFields] retain]; - statusCode = (int) [httpResponse statusCode]; - } - - initialised = true; - - if (completionHandler != nil) - { - // Need to wrangle this parameter back into an obj-C block, - // and call it to allow the transfer to continue.. - void (^callbackBlock)(NSURLSessionResponseDisposition) = completionHandler; - callbackBlock (NSURLSessionResponseAllow); - } + condvar.notify_one(); + completionHandler (NSURLSessionResponseAllow); } - void didComplete (NSError* error) + void didComplete ([[maybe_unused]] NSError* error) { - const ScopedLock sl (dataLock); + { + const std::scoped_lock lock { mutex }; - if (isBeingDeleted) - return; + if (state != State::invalidated) + state = State::requestFinished; + } + + condvar.notify_one(); #if JUCE_DEBUG if (error != nullptr) DBG (nsStringToJuce ([error description])); #endif + } - hasFailed = (error != nullptr); - initialised = true; - signalThreadShouldExit(); + void didBecomeInvalid ([[maybe_unused]] NSError* error) + { + { + const std::scoped_lock lock { mutex }; + state = State::invalidated; + } + + condvar.notify_one(); + + #if JUCE_DEBUG + if (error != nullptr) + DBG (nsStringToJuce ([error description])); + #endif } void didReceiveData (NSData* newData) { - const ScopedLock sl (dataLock); + { + const std::scoped_lock lock { mutex }; + [data.get() appendData: newData]; - if (isBeingDeleted) - return; + if (state == State::beforeStart) + state = State::started; + } - [data appendData: newData]; - initialised = true; + condvar.notify_one(); } void didSendBodyData (int64_t totalBytesWritten) { - latestTotalBytes = static_cast (totalBytesWritten); + const std::scoped_lock lock { mutex }; + latestTotalBytes = totalBytesWritten; } - void willPerformHTTPRedirection (NSURLRequest* urlRequest, void (^completionHandler)(NSURLRequest *)) + void willPerformHTTPRedirection (NSURLRequest* urlRequest, void (^completionHandler) (NSURLRequest *)) { - { - const ScopedLock sl (dataLock); - - if (isBeingDeleted) - return; - } - - completionHandler (numRedirects++ < numRedirectsToFollow ? urlRequest : nil); + // No lock required here because numRedirects is only accessed from the session's work queue + // after the task has started. + completionHandler (--numRedirects >= 0 ? urlRequest : nil); } - void run() override - { - jassert (task == nil && session == nil); - - session = [[NSURLSession sessionWithConfiguration: [NSURLSessionConfiguration defaultSessionConfiguration] - delegate: delegate - delegateQueue: [NSOperationQueue currentQueue]] retain]; - - { - const ScopedLock lock (createConnectionLock); - - if (! hasBeenCancelled) - task = [session dataTaskWithRequest: request]; - } - - if (task == nil) - return; - - [task retain]; - [task resume]; - - while (! threadShouldExit()) - wait (5); - - hasFinished = true; - initialised = true; - } - -private: //============================================================================== - struct DelegateClass final : public ObjCClass + struct DelegateClass final : public ObjCClass> { - DelegateClass() : ObjCClass ("JUCE_URLDelegate_") + DelegateClass() + : ObjCClass ("JUCE_URLDelegate_") { addIvar ("state"); addMethod (@selector (URLSession:dataTask:didReceiveResponse:completionHandler:), - didReceiveResponse); - addMethod (@selector (URLSession:didBecomeInvalidWithError:), didBecomeInvalidWithError); - addMethod (@selector (URLSession:dataTask:didReceiveData:), didReceiveData); + [] (id self, + SEL, + NSURLSession*, + NSURLSessionDataTask*, + NSURLResponse* response, + void (^completionHandler) (NSURLSessionResponseDisposition)) + { + getState (self)->didReceiveResponse (response, completionHandler); + }); + + addMethod (@selector (URLSession:didBecomeInvalidWithError:), + [] (id self, SEL, NSURLSession*, NSError* error) + { + getState (self)->didBecomeInvalid (error); + }); + + addMethod (@selector (URLSession:dataTask:didReceiveData:), + [] (id self, SEL, NSURLSession*, NSURLSessionDataTask*, NSData* newData) + { + getState (self)->didReceiveData (newData); + }); + addMethod (@selector (URLSession:task:didSendBodyData:totalBytesSent:totalBytesExpectedToSend:), - didSendBodyData); + [] (id self, + SEL, + NSURLSession*, + NSURLSessionTask*, + int64_t, + int64_t totalBytesWritten, + int64_t) + { + getState (self)->didSendBodyData (totalBytesWritten); + }); + addMethod (@selector (URLSession:task:willPerformHTTPRedirection:newRequest:completionHandler:), - willPerformHTTPRedirection); - addMethod (@selector (URLSession:task:didCompleteWithError:), didCompleteWithError); + [] (id self, + SEL, + NSURLSession*, + NSURLSessionTask*, + NSHTTPURLResponse*, + NSURLRequest* req, + void (^completionHandler) (NSURLRequest *)) + { + getState (self)->willPerformHTTPRedirection (req, completionHandler); + }); + + addMethod (@selector (URLSession:task:didCompleteWithError:), + [] (id self, SEL, NSURLConnection*, NSURLSessionTask*, NSError* error) + { + getState (self)->didComplete (error); + }); registerClass(); } - static void setState (id self, URLConnectionState* state) { object_setInstanceVariable (self, "state", state); } - static URLConnectionState* getState (id self) { return getIvar (self, "state"); } - - private: - static void didReceiveResponse (id self, SEL, NSURLSession*, NSURLSessionDataTask*, NSURLResponse* response, id completionHandler) - { - if (auto state = getState (self)) - state->didReceiveResponse (response, completionHandler); - } - - static void didBecomeInvalidWithError (id self, SEL, NSURLSession*, NSError* error) - { - if (auto state = getState (self)) - state->didComplete (error); - } - - static void didReceiveData (id self, SEL, NSURLSession*, NSURLSessionDataTask*, NSData* newData) - { - if (auto state = getState (self)) - state->didReceiveData (newData); - } - - static void didSendBodyData (id self, SEL, NSURLSession*, NSURLSessionTask*, int64_t, int64_t totalBytesWritten, int64_t) - { - if (auto state = getState (self)) - state->didSendBodyData (totalBytesWritten); - } - - static void willPerformHTTPRedirection (id self, SEL, NSURLSession*, NSURLSessionTask*, NSHTTPURLResponse*, - NSURLRequest* request, void (^completionHandler)(NSURLRequest *)) - { - if (auto state = getState (self)) - state->willPerformHTTPRedirection (request, completionHandler); - } - - static void didCompleteWithError (id self, SEL, NSURLConnection*, NSURLSessionTask*, NSError* error) - { - if (auto state = getState (self)) - state->didComplete (error); - } + static void setState (NSObject* self, URLConnectionState* state) { object_setInstanceVariable (self, "state", state); } + static URLConnectionState* getState (NSObject* self) { return getIvar (self, "state"); } }; - NSURLSession* session = nil; - NSURLSessionTask* task = nil; - bool isBeingDeleted = false; + static DelegateClass& getDelegateClass() + { + static DelegateClass cls; + return cls; + } + + enum class State + { + beforeStart, + started, + requestFinished, + invalidated, + }; + + mutable std::mutex mutex; + std::condition_variable condvar; + + NSUniquePtr headers; + NSUniquePtr request; + NSUniquePtr> delegate { [getDelegateClass().createInstance() init] }; + NSUniquePtr data { [[NSMutableData data] retain] }; + NSUniquePtr session + { + [[NSURLSession sessionWithConfiguration: [NSURLSessionConfiguration defaultSessionConfiguration] + delegate: delegate.get() + delegateQueue: nil] retain] + }; + NSUniquePtr task { [[session.get() dataTaskWithRequest: request.get()] retain] }; + + int64 latestTotalBytes = 0; + int64 contentLength = -1; + int statusCode = 0; + int numRedirects = 0; + State state = State::beforeStart; JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR (URLConnectionState) }; @@ -713,11 +677,11 @@ class WebInputStream::Pimpl { public: Pimpl (WebInputStream& pimplOwner, const URL& urlToUse, bool addParametersToBody) - : owner (pimplOwner), - url (urlToUse), - addParametersToRequestBody (addParametersToBody), - hasBodyDataToSend (addParametersToRequestBody || url.hasBodyDataToSend()), - httpRequestCmd (hasBodyDataToSend ? "POST" : "GET") + : owner (pimplOwner), + url (urlToUse), + addParametersToRequestBody (addParametersToBody), + hasBodyDataToSend (addParametersToRequestBody || url.hasBodyDataToSend()), + httpRequestCmd (hasBodyDataToSend ? "POST" : "GET") { } @@ -737,7 +701,7 @@ public: createConnection(); } - if (connection == nullptr) + if (! connection.has_value()) return false; if (! connection->start (owner, webInputListener)) @@ -746,32 +710,30 @@ public: return false; } - if (auto* connectionHeaders = connection->getHeaders()) - { - statusCode = connection->getStatusCode(); + const auto connectionHeaders = connection->getHeaders(); - NSEnumerator* enumerator = [connectionHeaders keyEnumerator]; + if (connectionHeaders == nullptr) + return false; - while (NSString* key = [enumerator nextObject]) - responseHeaders.set (nsStringToJuce (key), - nsStringToJuce ((NSString*) [connectionHeaders objectForKey: key])); + statusCode = connection->getStatusCode(); - return true; - } + NSEnumerator* enumerator = [connectionHeaders.get() keyEnumerator]; - return false; + while (NSString* key = [enumerator nextObject]) + responseHeaders.set (nsStringToJuce (key), + nsStringToJuce ((NSString*) [connectionHeaders.get() objectForKey: key])); + + return true; } void cancel() { - { - const ScopedLock lock (createConnectionLock); + const ScopedLock lock (createConnectionLock); - if (connection != nullptr) - connection->cancel(); + if (connection.has_value()) + connection->cancel(); - hasBeenCancelled = true; - } + hasBeenCancelled = true; } //============================================================================== @@ -795,8 +757,8 @@ public: int getStatusCode() const { return statusCode; } //============================================================================== - bool isError() const { return (connection == nullptr || connection->getHeaders() == nullptr); } - int64 getTotalLength() { return connection == nullptr ? -1 : connection->getContentLength(); } + bool isError() const { return (! connection.has_value() || connection->getHeaders() == nullptr); } + int64 getTotalLength() { return ! connection.has_value() ? -1 : connection->getContentLength(); } bool isExhausted() { return finished; } int64 getPosition() { return position; } @@ -844,7 +806,7 @@ public: private: WebInputStream& owner; URL url; - std::unique_ptr connection; + std::optional connection; String headers; MemoryBlock postData; int64 position = 0; @@ -859,58 +821,64 @@ private: void createConnection() { - jassert (connection == nullptr); + jassert (! connection.has_value()); - if (NSURL* nsURL = [NSURL URLWithString: juceStringToNS (url.toString (! addParametersToRequestBody))]) + NSUniquePtr nsURL { [[NSURL URLWithString: juceStringToNS (url.toString (! addParametersToRequestBody))] retain] }; + + if (nsURL == nullptr) + return; + + const auto timeOutSeconds = [this] { - const auto timeOutSeconds = [this] - { - if (timeOutMs > 0) - return timeOutMs / 1000.0; + if (timeOutMs > 0) + return timeOutMs / 1000.0; - return timeOutMs < 0 ? std::numeric_limits::infinity() : 60.0; - }(); + return timeOutMs < 0 ? std::numeric_limits::infinity() : 60.0; + }(); - if (NSMutableURLRequest* req = [NSMutableURLRequest requestWithURL: nsURL - cachePolicy: NSURLRequestReloadIgnoringLocalCacheData - timeoutInterval: timeOutSeconds]) - { - if (NSString* httpMethod = [NSString stringWithUTF8String: httpRequestCmd.toRawUTF8()]) - { - [req setHTTPMethod: httpMethod]; + NSUniquePtr req { [[NSMutableURLRequest requestWithURL: nsURL.get() + cachePolicy: NSURLRequestReloadIgnoringLocalCacheData + timeoutInterval: timeOutSeconds] retain] }; - if (hasBodyDataToSend) - { - WebInputStream::createHeadersAndPostData (url, - headers, - postData, - addParametersToRequestBody); + if (req == nullptr) + return; - if (! postData.isEmpty()) - [req setHTTPBody: [NSData dataWithBytes: postData.getData() - length: postData.getSize()]]; - } + NSUniquePtr httpMethod { [[NSString stringWithUTF8String: httpRequestCmd.toRawUTF8()] retain] }; - StringArray headerLines; - headerLines.addLines (headers); - headerLines.removeEmptyStrings (true); + if (httpMethod == nullptr) + return; - for (int i = 0; i < headerLines.size(); ++i) - { - auto key = headerLines[i].upToFirstOccurrenceOf (":", false, false).trim(); - auto value = headerLines[i].fromFirstOccurrenceOf (":", false, false).trim(); + [req.get() setHTTPMethod: httpMethod.get()]; - if (key.isNotEmpty() && value.isNotEmpty()) - [req addValue: juceStringToNS (value) forHTTPHeaderField: juceStringToNS (key)]; - } + if (hasBodyDataToSend) + { + WebInputStream::createHeadersAndPostData (url, + headers, + postData, + addParametersToRequestBody); - // Workaround for an Apple bug. See https://github.com/AFNetworking/AFNetworking/issues/2334 - [req HTTPBody]; - - connection = std::make_unique (req, numRedirectsToFollow); - } - } + if (! postData.isEmpty()) + [req.get() setHTTPBody: [NSData dataWithBytes: postData.getData() + length: postData.getSize()]]; } + + StringArray headerLines; + headerLines.addLines (headers); + headerLines.removeEmptyStrings (true); + + for (int i = 0; i < headerLines.size(); ++i) + { + auto key = headerLines[i].upToFirstOccurrenceOf (":", false, false).trim(); + auto value = headerLines[i].fromFirstOccurrenceOf (":", false, false).trim(); + + if (key.isNotEmpty() && value.isNotEmpty()) + [req.get() addValue: juceStringToNS (value) forHTTPHeaderField: juceStringToNS (key)]; + } + + // Workaround for an Apple bug. See https://github.com/AFNetworking/AFNetworking/issues/2334 + [req.get() HTTPBody]; + + connection.emplace (std::move (req), numRedirectsToFollow); } JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR (Pimpl)