diff --git a/source/graphics/MapReader.cpp b/source/graphics/MapReader.cpp index 919cbb556f..324592d239 100644 --- a/source/graphics/MapReader.cpp +++ b/source/graphics/MapReader.cpp @@ -1379,7 +1379,7 @@ int CMapReader::StartMapGeneration(const CStrW& scriptFile) m_GeneratorState = std::make_unique(); // The settings are stringified to pass them to the task. - m_GeneratorState->task = g_TaskManager.PushTask( + m_GeneratorState->task = {g_TaskManager, [&progress = m_GeneratorState->progress, scriptFile, settings = Script::StringifyJSON(rq, &m_ScriptSettings)](const StopToken stopToken) { @@ -1398,7 +1398,7 @@ int CMapReader::StartMapGeneration(const CStrW& scriptFile) }}; return RunMapGenerationScript(stopToken, progress, mapgenInterface, scriptPath, settings); - }); + }}; return 0; } diff --git a/source/graphics/TextureConverter.cpp b/source/graphics/TextureConverter.cpp index 9ce6b12a16..caf3899105 100644 --- a/source/graphics/TextureConverter.cpp +++ b/source/graphics/TextureConverter.cpp @@ -471,7 +471,7 @@ bool CTextureConverter::ConvertTexture(const CTexturePtr& texture, const VfsPath delete[] rgba; } - m_ResultQueue.push(g_TaskManager.PushTask([request = std::move(request)] + m_ResultQueue.push({g_TaskManager, [request = std::move(request)] { PROFILE2("compress"); // Set up the result object @@ -487,7 +487,7 @@ bool CTextureConverter::ConvertTexture(const CTexturePtr& texture, const VfsPath request->outputOptions); return result; - }, Threading::TaskPriority::LOW)); + }, Threading::TaskPriority::LOW}); return true; diff --git a/source/ps/Future.h b/source/ps/Future.h index 984ff12b95..b43a820862 100644 --- a/source/ps/Future.h +++ b/source/ps/Future.h @@ -144,93 +144,6 @@ struct SharedState } // namespace FutureSharedStateDetail -/** - * Corresponds to std::future. - * Unlike std::future, Future can request the cancellation of the task that would produce the result. - * This makes it more similar to Java's CancellableTask or C#'s Task. - * The name Future was kept over Task so it would be more familiar to C++ users, - * but this all should be revised once Concurrency TS wraps up. - * - * Future is _not_ thread-safe. Call it from a single thread or ensure synchronization externally. - * - * The callback never runs after the @p Future is destroyed. - */ -template -class Future -{ - template - friend class PackagedTask; - -public: - Future() = default; - Future(const Future& o) = delete; - Future(Future&&) = default; - Future& operator=(Future&& other) - { - CancelOrWait(); - m_Receiver = std::move(other.m_Receiver); - return *this; - } - ~Future() - { - CancelOrWait(); - } - - /** - * Make the future wait for the result of @a callback. - */ - template - PackagedTask Wrap(Callback&& callback); - - /** - * Move the result out of the future, and invalidate the future. - * If the future is not complete, calls Wait(). - * If the future is invalid, asserts. - */ - ResultType Get() - { - ENSURE(!!m_Receiver); - - Wait(); - // This mark the state invalid - can't call Get again. - return std::exchange(m_Receiver, nullptr)->GetResult(); - } - - /** - * @return true if the shared state is valid and has a result (i.e. Get can be called). - */ - bool IsDone() const - { - return !!m_Receiver && m_Receiver->IsDone(); - } - - /** - * @return true if the future has a shared state and it's not been invalidated, ie. pending, started or done. - */ - bool Valid() const - { - return !!m_Receiver; - } - - void Wait() - { - if (Valid()) - m_Receiver->Wait(); - } - - void CancelOrWait() - { - if (!Valid()) - return; - m_Receiver->RequestStop(); - m_Receiver->Wait(); - m_Receiver.reset(); - } - -protected: - std::shared_ptr> m_Receiver; -}; - /** * Corresponds somewhat to std::packaged_task. * Like packaged_task, this holds a function acting as a promise. @@ -295,18 +208,107 @@ private: std::shared_ptr> m_SharedState; }; +/** + * Corresponds to std::future. + * Unlike std::future, Future can request the cancellation of the task that would produce the result. + * This makes it more similar to Java's CancellableTask or C#'s Task. + * The name Future was kept over Task so it would be more familiar to C++ users, + * but this all should be revised once Concurrency TS wraps up. + * + * Future is _not_ thread-safe. Call it from a single thread or ensure synchronization externally. + * + * The callback never runs after the @p Future is destroyed. + */ template -template -PackagedTask Future::Wrap(Callback&& callback) +class Future { - static_assert(std::is_same_v, ResultType>, + template + friend class PackagedTask; + +public: + Future() = default; + Future(const Future& o) = delete; + Future(Future&&) = default; + Future& operator=(Future&& other) + { + CancelOrWait(); + m_Receiver = std::move(other.m_Receiver); + return *this; + } + + /** + * Make the future wait for the result of @a callback. + */ + template + Future(auto& taskManager, Callback&& callback, Args&&... args) + { + static_assert(std::is_same_v, ResultType>, "The return type of the wrapped function is not the same as the type the Future expects."); - static_assert(std::is_invocable_v || !std::is_invocable_v, - "Consider taking the `StopToken` by value"); - CancelOrWait(); - auto temp = std::make_shared>(std::move(callback)); - m_Receiver = {temp, &temp->receiver}; - return PackagedTask(std::move(temp)); -} + static_assert(std::is_invocable_v || !std::is_invocable_v, + "Consider taking the `StopToken` by value"); + + auto temp = std::make_shared>( + std::forward(callback)); + m_Receiver = {temp, &temp->receiver}; + + taskManager.PushTask(PackagedTask(std::move(temp)), std::forward(args)...); + } + + ~Future() + { + CancelOrWait(); + } + + /** + * Move the result out of the future, and invalidate the future. + * If the future is not complete, calls Wait(). + * If the future is invalid, asserts. + */ + ResultType Get() + { + ENSURE(!!m_Receiver); + + Wait(); + // This mark the state invalid - can't call Get again. + return std::exchange(m_Receiver, nullptr)->GetResult(); + } + + /** + * @return true if the shared state is valid and has a result (i.e. Get can be called). + */ + bool IsDone() const + { + return !!m_Receiver && m_Receiver->IsDone(); + } + + /** + * @return true if the future has a shared state and it's not been invalidated, ie. pending, started or done. + */ + bool Valid() const + { + return !!m_Receiver; + } + + void Wait() + { + if (Valid()) + m_Receiver->Wait(); + } + + void CancelOrWait() + { + if (!Valid()) + return; + m_Receiver->RequestStop(); + m_Receiver->Wait(); + m_Receiver.reset(); + } + +protected: + std::shared_ptr> m_Receiver; +}; + +template +Future(auto& taskManager, Callback&& callback, Args&&... args) -> Future>; #endif // INCLUDED_FUTURE diff --git a/source/ps/Profiler2.cpp b/source/ps/Profiler2.cpp index 33f71a162e..b7334bd2b3 100644 --- a/source/ps/Profiler2.cpp +++ b/source/ps/Profiler2.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -107,18 +108,11 @@ static void* MgCallback(mg_event event, struct mg_connection *conn, const struct std::string uri = request_info->uri; if (uri == "/download") - { - Threading::TaskManager::GetSingleton().PushTask([&] - { - profiler->SaveToFile(); - }).Get(); - } + Future{g_TaskManager, std::bind_front(&CProfiler2::SaveToFile, profiler)}.Get(); else if (uri == "/overview") { - Threading::TaskManager::GetSingleton().PushTask([&] - { - profiler->ConstructJSONOverview(stream); - }).Get(); + Future{g_TaskManager, + std::bind_front(&CProfiler2::ConstructJSONOverview, profiler, std::ref(stream))}.Get(); } else if (uri == "/query") { @@ -138,10 +132,10 @@ static void* MgCallback(mg_event event, struct mg_connection *conn, const struct } std::string thread(buf); - const char* err = Threading::TaskManager::GetSingleton().PushTask([&] - { - return profiler->ConstructJSONResponse(stream, thread); - }).Get(); + const char* err = Future{g_TaskManager, + std::bind_front(&CProfiler2::ConstructJSONResponse, profiler, std::ref(stream), + std::ref(thread))}.Get(); + if (err) { mg_printf(conn, "%s (%s)", header400, err); diff --git a/source/ps/TaskManager.cpp b/source/ps/TaskManager.cpp index 3567ba6f16..20277797d1 100644 --- a/source/ps/TaskManager.cpp +++ b/source/ps/TaskManager.cpp @@ -185,7 +185,7 @@ size_t TaskManager::GetNumberOfWorkers() const return m->m_Workers.size(); } -void TaskManager::DoPushTask(std::function&& task, TaskPriority priority) +void TaskManager::PushTask(std::function task, TaskPriority priority) { m->PushTask(std::move(task), priority); } diff --git a/source/ps/TaskManager.h b/source/ps/TaskManager.h index 9cf868191f..1f6c6beb1c 100644 --- a/source/ps/TaskManager.h +++ b/source/ps/TaskManager.h @@ -18,7 +18,6 @@ #ifndef INCLUDED_THREADING_TASKMANAGER #define INCLUDED_THREADING_TASKMANAGER -#include "ps/Future.h" #include "ps/Singleton.h" #include @@ -59,19 +58,11 @@ public: /** * Push a task to be executed. */ - template - Future> PushTask(T&& func, TaskPriority priority = TaskPriority::NORMAL) - { - Future> ret; - DoPushTask(ret.Wrap(std::move(func)), priority); - return ret; - } + void PushTask(std::function func, TaskPriority priority = TaskPriority::NORMAL); private: TaskManager(size_t numberOfWorkers); - void DoPushTask(std::function&& task, TaskPriority priority); - class Impl; const std::unique_ptr m; }; diff --git a/source/ps/tests/test_Future.h b/source/ps/tests/test_Future.h index 4cc8c29360..ff57778220 100644 --- a/source/ps/tests/test_Future.h +++ b/source/ps/tests/test_Future.h @@ -25,25 +25,37 @@ #include #include #include +#include class TestFuture : public CxxTest::TestSuite { public: + struct TestTaskManager + { + std::vector> tasks; + void PushTask(std::function task) + { + tasks.push_back(std::move(task)); + } + }; + void test_future_basic() { bool executed{false}; - Future noret; - auto task = noret.Wrap([&]{ executed = true; }); - task(); + TestTaskManager ttm; + Future noret{ttm, [&]{ executed = true; }}; + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT(executed); } void test_future_return() { + TestTaskManager ttm; { - Future future; - std::function task = future.Wrap([]() { return 1; }); - task(); + Future future{ttm, []{ return 1; }}; + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT_EQUALS(future.Get(), 1); } @@ -64,9 +76,9 @@ public: }; TS_ASSERT_EQUALS(destroyed, 0); { - Future future; - std::function task = future.Wrap([]() { return NonDef{1}; }); - task(); + Future future{ttm, []{ return NonDef{1}; }}; + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT_EQUALS(future.Get().value, 1); } TS_ASSERT_EQUALS(destroyed, 1); @@ -86,6 +98,7 @@ public: { Future future; std::function function; + TestTaskManager ttm; // Set things up so all temporaries passed into the futures will be reset to obviously invalid memory. std::aligned_storage_t), alignof(Future)> futureStorage; @@ -96,14 +109,14 @@ public: c = new (&functionStorage) std::function{}; *c = []() { return 7; }; - std::function task = f->Wrap(std::move(*c)); + *f = {ttm, std::move(*c)}; future = std::move(*f); function = std::move(*c); - // Let's move the packaged task while at it. - std::function task2 = std::move(task); - task2(); + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); + TS_ASSERT_EQUALS(future.Get(), 7); // Destroy and clear the memory @@ -115,8 +128,6 @@ public: void test_move_only_function() { - Future future; - class MoveOnlyType { public: @@ -128,8 +139,11 @@ public: int fn() const { return 7; } }; - auto task = future.Wrap([t = MoveOnlyType{}]{ return t.fn(); }); - task(); + TestTaskManager ttm; + + Future future{ttm, [t = MoveOnlyType{}]{ return t.fn(); }}; + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT_EQUALS(future.Get(), 7); } @@ -141,26 +155,28 @@ public: void test_exception() { - Future future; - auto packedTask = future.Wrap([]() -> int + TestTaskManager ttm; + Future future{ttm, []() -> int { throw TestException{}; - }); + }}; - packedTask(); + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT(future.IsDone()); TS_ASSERT_THROWS(future.Get(), const TestException&); } void test_voidException() { - Future future; - auto packedTask = future.Wrap([] + TestTaskManager ttm; + Future future{ttm, [] { throw TestException{}; - }); + }}; - packedTask(); + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT(future.IsDone()); TS_ASSERT_THROWS(future.Get(), const TestException&); } @@ -180,19 +196,22 @@ public: } }; - Future future; - auto packedTask = future.Wrap([] + TestTaskManager ttm; + + Future future{ttm, [] { return ThrowsOnMove{}; - }); + }}; - packedTask(); + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT(future.IsDone()); TS_ASSERT_THROWS(future.Get(), const TestException&); } void test_stop_token_overload() { + TestTaskManager ttm; { class DifferentValues { @@ -207,8 +226,9 @@ public: } }; - Future future; - future.Wrap(DifferentValues{})(); + Future future{ttm, DifferentValues{}}; + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); TS_ASSERT_EQUALS(future.Get(), true); } { @@ -223,8 +243,9 @@ public: } }; - Future future; - future.Wrap(DifferentTypes{})(); + Future future{ttm, DifferentTypes{}}; + TS_ASSERT_EQUALS(ttm.tasks.size(), 1); + std::exchange(ttm.tasks, {})[0](); } } }; diff --git a/source/ps/tests/test_TaskManager.h b/source/ps/tests/test_TaskManager.h index 61e7df2340..4b6c7c0a52 100644 --- a/source/ps/tests/test_TaskManager.h +++ b/source/ps/tests/test_TaskManager.h @@ -37,7 +37,7 @@ public: std::atomic tasks_run = 0; auto increment_run = [&tasks_run]() { tasks_run++; }; - Future future = g_TaskManager.PushTask(increment_run); + Future future{g_TaskManager, increment_run}; future.Wait(); TS_ASSERT_EQUALS(tasks_run.load(), 1); @@ -45,7 +45,7 @@ public: std::condition_variable cv; std::mutex mutex; std::atomic go = false; - future = g_TaskManager.PushTask([&]() { + future = {g_TaskManager, [&]{ std::unique_lock lock(mutex); cv.wait(lock, [&go]() -> bool { return go; }); lock.unlock(); @@ -54,7 +54,7 @@ public: go = false; lock.unlock(); cv.notify_all(); - }); + }}; TS_ASSERT_EQUALS(tasks_run.load(), 1); std::unique_lock lock(mutex); go = true; @@ -72,15 +72,15 @@ public: std::atomic tasks_run = 0; // Push general tasks auto increment_run = [&tasks_run]() { tasks_run++; }; - Future future = g_TaskManager.PushTask(increment_run); - Future futureLow = g_TaskManager.PushTask(increment_run, Threading::TaskPriority::LOW); + Future future = {g_TaskManager, increment_run}; + Future futureLow = {g_TaskManager, increment_run, Threading::TaskPriority::LOW}; future.Wait(); futureLow.Wait(); TS_ASSERT_EQUALS(tasks_run.load(), 2); // Also check with no waiting expected. - g_TaskManager.PushTask(increment_run).Wait(); + Future{g_TaskManager, increment_run}.Wait(); TS_ASSERT_EQUALS(tasks_run.load(), 3); - g_TaskManager.PushTask(increment_run, Threading::TaskPriority::LOW).Wait(); + Future{g_TaskManager, increment_run, Threading::TaskPriority::LOW}.Wait(); TS_ASSERT_EQUALS(tasks_run.load(), 4); } @@ -91,20 +91,20 @@ public: futures.resize(ITERATIONS); std::vector values(ITERATIONS); - auto f1 = g_TaskManager.PushTask([&futures]() { + Future f1{g_TaskManager, [&futures]{ for (u32 i = 0; i < ITERATIONS; i+=3) - futures[i] = g_TaskManager.PushTask([]() { return 5; }); - }); + futures[i] = {g_TaskManager, []{ return 5; }}; + }}; - auto f2 = g_TaskManager.PushTask([&futures]() { + Future f2{g_TaskManager, [&futures]{ for (u32 i = 1; i < ITERATIONS; i+=3) - futures[i] = g_TaskManager.PushTask([]() { return 5; }, Threading::TaskPriority::LOW); - }); + futures[i] = {g_TaskManager, []{ return 5; }, Threading::TaskPriority::LOW}; + }}; - auto f3 = g_TaskManager.PushTask([&futures]() { + Future f3{g_TaskManager, [&futures]{ for (u32 i = 2; i < ITERATIONS; i+=3) - futures[i] = g_TaskManager.PushTask([]() { return 5; }); - }); + futures[i] = {g_TaskManager, []{ return 5; }}; + }}; f1.Wait(); f2.Wait(); diff --git a/source/simulation2/components/CCmpPathfinder.cpp b/source/simulation2/components/CCmpPathfinder.cpp index 92169fa113..14137f0a7d 100644 --- a/source/simulation2/components/CCmpPathfinder.cpp +++ b/source/simulation2/components/CCmpPathfinder.cpp @@ -861,13 +861,13 @@ void CCmpPathfinder::StartProcessingMoves(bool useMax) ENSURE(!m_Futures[i].Valid()); // Pass the i+1th vertex pathfinder to keep the first for the main thread, // each thread get its own instance to avoid conflicts in cached data. - m_Futures[i] = g_TaskManager.PushTask( + m_Futures[i] = {g_TaskManager, [&pathfinder=*this, &vertexPfr=m_VertexPathfinders[i + 1]]() { PROFILE2("Async pathfinding"); pathfinder.m_ShortPathRequests.Compute(pathfinder, vertexPfr); pathfinder.m_LongPathRequests.Compute(pathfinder, *pathfinder.m_LongPathfinder); - }); + }}; } }