Allow the future request stop from the callback

In `Future` there is a notion of cancelation / stop-request. The task
callback doesn't have such a notion.
Some tasks (like the map-generation) are stopable. It did that in a
thread unsave way.

A task is canceled when the future is destroied or when `CancelOrWait`
is called on it.
This commit is contained in:
phosit 2024-10-17 19:50:11 +02:00
parent 3e4238876f
commit 778972602a
8 changed files with 154 additions and 121 deletions

View file

@ -28,6 +28,7 @@
#include "maths/MathUtil.h"
#include "ps/CLogger.h"
#include "ps/FileIo.h"
#include "ps/Future.h"
#include "ps/scripting/JSInterface_VFS.h"
#include "ps/TemplateLoader.h"
#include "scriptinterface/FunctionWrapper.h"
@ -43,23 +44,11 @@
#include <string>
#include <vector>
extern bool IsQuitRequested();
namespace
{
constexpr const char* GENERATOR_NAME{"GenerateMap"};
bool MapGenerationInterruptCallback(JSContext* UNUSED(cx))
{
// This may not use SDL_IsQuitRequested(), because it runs in a thread separate to SDL, see SDL_PumpEvents
if (IsQuitRequested())
{
LOGWARNING("Quit requested!");
return false;
}
return true;
}
bool MapGenerationInterruptCallback(JSContext* cx);
/**
* Provides callback's for the JavaScript.
@ -69,8 +58,9 @@ class CMapGenerationCallbacks
public:
// Only the constructor and the destructor are called by C++.
CMapGenerationCallbacks(std::atomic<int>& progress, ScriptInterface& scriptInterface,
Script::StructuredClone& mapData, const u16 flags) :
CMapGenerationCallbacks(const StopToken stopToken, std::atomic<int>& progress,
ScriptInterface& scriptInterface, Script::StructuredClone& mapData, const u16 flags) :
m_StopToken{stopToken},
m_Progress{progress},
m_ScriptInterface{scriptInterface},
m_MapData{mapData}
@ -128,6 +118,8 @@ public:
m_ScriptInterface.SetCallbackData(nullptr);
}
StopToken m_StopToken;
private:
// These functions are called by JS.
@ -374,10 +366,16 @@ private:
*/
CTemplateLoader m_TemplateLoader;
};
bool MapGenerationInterruptCallback(JSContext* cx)
{
return !ScriptInterface::ObjectFromCBData<CMapGenerationCallbacks>(
ScriptInterface::CmptPrivate::GetScriptInterface(cx))->m_StopToken.IsStopRequested();
}
} // anonymous namespace
Script::StructuredClone RunMapGenerationScript(std::atomic<int>& progress, ScriptInterface& scriptInterface,
const VfsPath& script, const std::string& settings, const u16 flags)
Script::StructuredClone RunMapGenerationScript(const StopToken stopToken, std::atomic<int>& progress,
ScriptInterface& scriptInterface, const VfsPath& script, const std::string& settings, const u16 flags)
{
ScriptRequest rq(scriptInterface);
@ -406,7 +404,7 @@ Script::StructuredClone RunMapGenerationScript(std::atomic<int>& progress, Scrip
scriptInterface.ReplaceNondeterministicRNG(mapGenRNG);
Script::StructuredClone mapData;
CMapGenerationCallbacks callbackData{progress, scriptInterface, mapData, flags};
CMapGenerationCallbacks callbackData{stopToken, progress, scriptInterface, mapData, flags};
// Copy settings to global variable
JS::RootedValue global(rq.cx, rq.globalValue());

View file

@ -1,4 +1,4 @@
/* Copyright (C) 2023 Wildfire Games.
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
@ -19,6 +19,7 @@
#define INCLUDED_MAPGENERATOR
#include "lib/file/vfs/vfs_path.h"
#include "ps/Future.h"
#include "scriptinterface/ScriptTypes.h"
#include "scriptinterface/StructuredClone.h"
@ -28,6 +29,8 @@
/**
* Generate the map. This does take a long time.
*
* @param st request to fastly stop the function. The returned value is
* unspecified.
* @param progress Destination to write the function progress to. You must not
* write to it while `RunMapGenerationScript` is running.
* @param script The VFS path for the script, e.g. "maps/random/latium.js".
@ -38,7 +41,7 @@
* data, according to this format:
* https://trac.wildfiregames.com/wiki/Random_Map_Generator_Internals#Dataformat
*/
Script::StructuredClone RunMapGenerationScript(std::atomic<int>& progress,
Script::StructuredClone RunMapGenerationScript(const StopToken stopToken, std::atomic<int>& progress,
ScriptInterface& scriptInterface, const VfsPath& script, const std::string& settings,
const u16 flags = JSPROP_ENUMERATE | JSPROP_READONLY | JSPROP_PERMANENT);

View file

@ -59,6 +59,8 @@
#include <boost/algorithm/string/predicate.hpp>
extern bool IsQuitRequested();
#if defined(_MSC_VER) && _MSC_VER > 1900
#pragma warning(disable: 4456) // Declaration hides previous local declaration.
#pragma warning(disable: 4458) // Declaration hides class member.
@ -1342,7 +1344,7 @@ int CMapReader::StartMapGeneration(const CStrW& scriptFile)
// The settings are stringified to pass them to the task.
m_GeneratorState->task = Threading::TaskManager::Instance().PushTask(
[&progress = m_GeneratorState->progress, scriptFile,
settings = Script::StringifyJSON(rq, &m_ScriptSettings)]
settings = Script::StringifyJSON(rq, &m_ScriptSettings)](const StopToken stopToken)
{
PROFILE2("Map Generation");
@ -1352,7 +1354,7 @@ int CMapReader::StartMapGeneration(const CStrW& scriptFile)
MAP_GENERATION_CONTEXT_SIZE)};
ScriptInterface mapgenInterface{"Engine", "MapGenerator", mapgenContext};
return RunMapGenerationScript(progress, mapgenInterface, scriptPath, settings);
return RunMapGenerationScript(stopToken, progress, mapgenInterface, scriptPath, settings);
});
return 0;
@ -1362,13 +1364,19 @@ int CMapReader::StartMapGeneration(const CStrW& scriptFile)
{
throw PSERROR_Game_World_MapLoadFailed{
"Error generating random map.\nCheck application log for details."};
};
}
int CMapReader::PollMapGeneration()
{
ENSURE(m_GeneratorState);
if (!m_GeneratorState->task.IsReady())
if (IsQuitRequested())
{
LOGWARNING("Quit requested!");
return -1;
}
if (!m_GeneratorState->task.IsDone())
return m_GeneratorState->progress.load();
const Script::StructuredClone results{m_GeneratorState->task.Get()};

View file

@ -488,7 +488,7 @@ bool CTextureConverter::ConvertTexture(const CTexturePtr& texture, const VfsPath
bool CTextureConverter::Poll(CTexturePtr& texture, VfsPath& dest, bool& ok)
{
#if CONFIG2_NVTT
if (m_ResultQueue.empty() || !m_ResultQueue.front().IsReady())
if (m_ResultQueue.empty() || !m_ResultQueue.front().IsDone())
{
// no work to do
return false;

View file

@ -17,6 +17,7 @@
#include "graphics/MapGenerator.h"
#include "ps/Filesystem.h"
#include "ps/Future.h"
#include "simulation2/system/ComponentTest.h"
#include <atomic>
@ -56,9 +57,10 @@ public:
ScriptTestSetup(scriptInterface);
std::atomic<int> progress{1};
const Script::StructuredClone result{RunMapGenerationScript(progress, scriptInterface,
path, "{\"Seed\": 0}", JSPROP_ENUMERATE | JSPROP_PERMANENT)};
std::atomic<bool> stopRequest{false};
const Script::StructuredClone result{RunMapGenerationScript(StopToken{stopRequest},
progress, scriptInterface, path, "{\"Seed\": 0}",
JSPROP_ENUMERATE | JSPROP_PERMANENT)};
TS_ASSERT_DIFFERS(result, nullptr);

View file

@ -31,16 +31,27 @@
template<typename Callback>
class PackagedTask;
namespace FutureSharedStateDetail
class StopToken
{
enum class Status
{
PENDING,
STARTED,
DONE,
CANCELED
public:
explicit StopToken(const std::atomic<bool>& request) noexcept :
m_Request{request}
{}
bool IsStopRequested() const noexcept
{
return m_Request.load();
}
private:
const std::atomic<bool>& m_Request;
};
template<typename Callback>
using CallbackResult = typename std::conditional_t<std::is_invocable_v<Callback, StopToken>,
std::invoke_result<Callback, StopToken>, std::invoke_result<Callback>>::type;
namespace FutureSharedStateDetail
{
template<typename T>
using ResultHolder = std::conditional_t<std::is_void_v<T>, std::nullopt_t, std::optional<T>>;
@ -50,51 +61,39 @@ using ResultHolder = std::conditional_t<std::is_void_v<T>, std::nullopt_t, std::
template<typename ResultType>
class Receiver
{
static constexpr bool VoidResult = std::is_same_v<ResultType, void>;
public:
Receiver() = default;
~Receiver()
{
ENSURE(IsDoneOrCanceled());
ENSURE(IsDone());
}
Receiver(const Receiver&) = delete;
Receiver(Receiver&&) = delete;
bool IsDoneOrCanceled() const
bool IsDone() const noexcept
{
return m_Status == Status::DONE || m_Status == Status::CANCELED;
return m_Done.load();
}
void Wait()
{
// Fast path: we're already done.
if (IsDoneOrCanceled())
if (IsDone())
return;
// Slow path: we aren't done when we run the above check. Lock and wait until we are.
std::unique_lock<std::mutex> lock(m_Mutex);
m_ConditionVariable.wait(lock, [this]() -> bool { return IsDoneOrCanceled(); });
m_ConditionVariable.wait(lock, [this]{ return IsDone(); });
}
/**
* If the task is pending, cancel it: the status becomes CANCELED and if the task was completed, the result is destroyed.
* @return true if the task was indeed cancelled, false otherwise (the task is running or already done).
* Request the executing thread to stop as fast as possible. This is only
* a request the execution therad might ignore it.
* @see GetResult must not be called after a call to @p RequestStop.
*/
bool Cancel()
void RequestStop() noexcept
{
Status expected = Status::PENDING;
bool cancelled = m_Status.compare_exchange_strong(expected, Status::CANCELED);
// If we're done, invalidate, if we're pending, atomically cancel, otherwise fail.
if (cancelled || m_Status == Status::DONE)
{
if (m_Status == Status::DONE)
m_Status = Status::CANCELED;
if constexpr (!VoidResult)
std::get<ResultHolder<ResultType>>(m_Outcome).reset();
m_ConditionVariable.notify_all();
return cancelled;
}
return false;
m_StopRequest.store(true);
}
/**
@ -102,15 +101,14 @@ public:
*/
ResultType GetResult()
{
// The caller must ensure that this is only called if we have a result.
// The caller must ensure that this is only called if there is a result.
ENSURE(IsDone());
if constexpr (!std::is_void_v<ResultType>)
ENSURE(std::get<ResultHolder<ResultType>>(m_Outcome).has_value() ||
std::get<std::exception_ptr>(m_Outcome));
m_Status = Status::CANCELED;
if (std::get<std::exception_ptr>(m_Outcome))
std::rethrow_exception(std::get<std::exception_ptr>(m_Outcome));
std::rethrow_exception(std::exchange(std::get<std::exception_ptr>(m_Outcome), {}));
if constexpr (std::is_void_v<ResultType>)
return;
@ -122,7 +120,10 @@ public:
}
}
std::atomic<Status> m_Status = Status::PENDING;
// This is only set by the executing thread and read by the receiving thread. It is never reset.
std::atomic<bool> m_Done{false};
// This is only set by the receiving thread and read by the executing thread. It is never reset.
std::atomic<bool> m_StopRequest{false};
std::mutex m_Mutex;
std::condition_variable m_ConditionVariable;
@ -142,7 +143,7 @@ struct SharedState
{}
Callback callback;
Receiver<std::invoke_result_t<Callback>> receiver;
Receiver<CallbackResult<Callback>> receiver;
};
} // namespace FutureSharedStateDetail
@ -157,8 +158,6 @@ struct SharedState
* Future is _not_ thread-safe. Call it from a single thread or ensure synchronization externally.
*
* The callback never runs after the @p Future is destroyed.
* TODO:
* - Handle exceptions.
*/
template<typename ResultType>
class Future
@ -166,9 +165,6 @@ class Future
template<typename T>
friend class PackagedTask;
static constexpr bool VoidResult = std::is_same_v<ResultType, void>;
using Status = FutureSharedStateDetail::Status;
public:
Future() = default;
Future(const Future& o) = delete;
@ -193,24 +189,23 @@ public:
/**
* Move the result out of the future, and invalidate the future.
* If the future is not complete, calls Wait().
* If the future is canceled, asserts.
* If the future is invalid, asserts.
*/
ResultType Get()
{
ENSURE(!!m_Receiver);
Wait();
ENSURE(m_Receiver->m_Status != Status::CANCELED);
// This mark the state invalid - can't call Get again.
return m_Receiver->GetResult();
return std::exchange(m_Receiver, nullptr)->GetResult();
}
/**
* @return true if the shared state is valid and has a result (i.e. Get can be called).
*/
bool IsReady() const
bool IsDone() const
{
return !!m_Receiver && m_Receiver->m_Status == Status::DONE;
return !!m_Receiver && m_Receiver->IsDone();
}
/**
@ -218,7 +213,7 @@ public:
*/
bool Valid() const
{
return !!m_Receiver && m_Receiver->m_Status != Status::CANCELED;
return !!m_Receiver;
}
void Wait()
@ -227,17 +222,12 @@ public:
m_Receiver->Wait();
}
/**
* Cancels the task, waiting if the task is currently started.
* Use this function over Cancel() if you need to ensure determinism (i.e. in the simulation).
* @see Cancel.
*/
void CancelOrWait()
{
if (!Valid())
return;
if (!m_Receiver->Cancel())
m_Receiver->Wait();
m_Receiver->RequestStop();
m_Receiver->Wait();
m_Receiver.reset();
}
@ -262,26 +252,22 @@ public:
void operator()()
{
FutureSharedStateDetail::Status expected = FutureSharedStateDetail::Status::PENDING;
if (!m_SharedState->receiver.m_Status.compare_exchange_strong(expected,
FutureSharedStateDetail::Status::STARTED))
if (!m_SharedState->receiver.m_StopRequest.load())
{
return;
}
try
{
using ResultType = std::invoke_result_t<Callback>;
if constexpr (std::is_void_v<ResultType>)
m_SharedState->callback();
else
std::get<FutureSharedStateDetail::ResultHolder<ResultType>>(
m_SharedState->receiver.m_Outcome).emplace(m_SharedState->callback());
}
catch(...)
{
std::get<std::exception_ptr>(m_SharedState->receiver.m_Outcome) =
std::current_exception();
try
{
using ResultType = CallbackResult<Callback>;
if constexpr (std::is_void_v<ResultType>)
Invoke();
else
std::get<FutureSharedStateDetail::ResultHolder<ResultType>>(
m_SharedState->receiver.m_Outcome).emplace(Invoke());
}
catch(...)
{
std::get<std::exception_ptr>(m_SharedState->receiver.m_Outcome) =
std::current_exception();
}
}
// Because we might have threads waiting on us, we need to make sure that they either:
@ -290,7 +276,7 @@ public:
// This requires locking the mutex (@see Wait).
{
std::lock_guard<std::mutex> lock(m_SharedState->receiver.m_Mutex);
m_SharedState->receiver.m_Status = FutureSharedStateDetail::Status::DONE;
m_SharedState->receiver.m_Done.store(true);
}
m_SharedState->receiver.m_ConditionVariable.notify_all();
@ -299,13 +285,15 @@ public:
m_SharedState.reset();
}
void Cancel()
private:
CallbackResult<Callback> Invoke()
{
m_SharedState->Cancel();
m_SharedState.reset();
if constexpr (std::is_invocable_v<Callback, StopToken>)
return m_SharedState->callback(StopToken{m_SharedState->receiver.m_StopRequest});
else
return m_SharedState->callback();
}
private:
std::shared_ptr<FutureSharedStateDetail::SharedState<Callback>> m_SharedState;
};
@ -313,8 +301,10 @@ template<typename ResultType>
template<typename Callback>
PackagedTask<Callback> Future<ResultType>::Wrap(Callback&& callback)
{
static_assert(std::is_same_v<std::invoke_result_t<Callback>, ResultType>,
static_assert(std::is_same_v<CallbackResult<Callback>, ResultType>,
"The return type of the wrapped function is not the same as the type the Future expects.");
static_assert(std::is_invocable_v<Callback, StopToken> || !std::is_invocable_v<Callback, StopToken&>,
"Consider taking the `StopToken` by value");
CancelOrWait();
auto temp = std::make_shared<FutureSharedStateDetail::SharedState<Callback>>(std::move(callback));
m_Receiver = {temp, &temp->receiver};

View file

@ -64,9 +64,9 @@ public:
* Push a task to be executed.
*/
template<typename T>
Future<std::invoke_result_t<T>> PushTask(T&& func, TaskPriority priority = TaskPriority::NORMAL)
Future<CallbackResult<T>> PushTask(T&& func, TaskPriority priority = TaskPriority::NORMAL)
{
Future<std::invoke_result_t<T>> ret;
Future<CallbackResult<T>> ret;
DoPushTask(ret.Wrap(std::move(func)), priority);
return ret;
}

View file

@ -67,11 +67,6 @@ public:
TS_ASSERT_EQUALS(future.Get().value, 1);
}
TS_ASSERT_EQUALS(destroyed, 1);
{
Future<NonDef> future;
std::function<void()> task = future.Wrap([]() { return NonDef{1}; });
}
TS_ASSERT_EQUALS(destroyed, 1);
/**
* TODO: find a way to test this
{
@ -103,16 +98,16 @@ public:
future = std::move(*f);
function = std::move(*c);
// Let's move the packaged task while at it.
std::function<void()> task2 = std::move(task);
task2();
TS_ASSERT_EQUALS(future.Get(), 7);
// Destroy and clear the memory
f->~Future();
c->~function();
memset(&futureStorage, 0xFF, sizeof(decltype(futureStorage)));
memset(&functionStorage, 0xFF, sizeof(decltype(functionStorage)));
// Let's move the packaged task while at it.
std::function<void()> task2 = std::move(task);
task2();
TS_ASSERT_EQUALS(future.Get(), 7);
}
void test_move_only_function()
@ -150,7 +145,7 @@ public:
});
packedTask();
TS_ASSERT(future.IsReady());
TS_ASSERT(future.IsDone());
TS_ASSERT_THROWS(future.Get(), const TestException&);
}
@ -163,7 +158,7 @@ public:
});
packedTask();
TS_ASSERT(future.IsReady());
TS_ASSERT(future.IsDone());
TS_ASSERT_THROWS(future.Get(), const TestException&);
}
@ -189,7 +184,44 @@ public:
});
packedTask();
TS_ASSERT(future.IsReady());
TS_ASSERT(future.IsDone());
TS_ASSERT_THROWS(future.Get(), const TestException&);
}
void test_stop_token_overload()
{
{
class DifferentValues
{
public:
bool operator()()
{
return false;
}
bool operator()(StopToken)
{
return true;
}
};
Future<bool> future;
future.Wrap(DifferentValues{})();
TS_ASSERT_EQUALS(future.Get(), true);
}
{
class DifferentTypes
{
public:
void operator()()
{}
bool operator()(StopToken)
{
return true;
}
};
Future<bool> future;
future.Wrap(DifferentTypes{})();
}
}
};