summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/object/worker_thread_pool.cpp28
-rw-r--r--core/object/worker_thread_pool.h7
-rw-r--r--core/templates/command_queue_mt.h50
3 files changed, 66 insertions, 19 deletions
diff --git a/core/object/worker_thread_pool.cpp b/core/object/worker_thread_pool.cpp
index b93c327004..881c825cf1 100644
--- a/core/object/worker_thread_pool.cpp
+++ b/core/object/worker_thread_pool.cpp
@@ -33,6 +33,7 @@
#include "core/object/script_language.h"
#include "core/os/os.h"
#include "core/os/thread_safe.h"
+#include "core/templates/command_queue_mt.h"
void WorkerThreadPool::Task::free_template_userdata() {
ERR_FAIL_NULL(template_userdata);
@@ -43,6 +44,8 @@ void WorkerThreadPool::Task::free_template_userdata() {
WorkerThreadPool *WorkerThreadPool::singleton = nullptr;
+thread_local CommandQueueMT *WorkerThreadPool::flushing_cmd_queue = nullptr;
+
void WorkerThreadPool::_process_task(Task *p_task) {
int pool_thread_index = thread_ids[Thread::get_caller_id()];
ThreadData &curr_thread = threads[pool_thread_index];
@@ -428,7 +431,15 @@ Error WorkerThreadPool::wait_for_task_completion(TaskID p_task_id) {
if (!task_to_process) {
caller_pool_thread->awaited_task = task;
+
+ if (flushing_cmd_queue) {
+ flushing_cmd_queue->unlock();
+ }
caller_pool_thread->cond_var.wait(lock);
+ if (flushing_cmd_queue) {
+ flushing_cmd_queue->lock();
+ }
+
DEV_ASSERT(exit_threads || caller_pool_thread->signaled || task->completed);
caller_pool_thread->awaited_task = nullptr;
}
@@ -540,7 +551,14 @@ void WorkerThreadPool::wait_for_group_task_completion(GroupID p_group) {
{
Group *group = *groupp;
+
+ if (flushing_cmd_queue) {
+ flushing_cmd_queue->unlock();
+ }
group->done_semaphore.wait();
+ if (flushing_cmd_queue) {
+ flushing_cmd_queue->lock();
+ }
uint32_t max_users = group->tasks_used + 1; // Add 1 because the thread waiting for it is also user. Read before to avoid another thread freeing task after increment.
uint32_t finished_users = group->finished.increment(); // fetch happens before inc, so increment later.
@@ -563,6 +581,16 @@ int WorkerThreadPool::get_thread_index() {
return singleton->thread_ids.has(tid) ? singleton->thread_ids[tid] : -1;
}
+void WorkerThreadPool::thread_enter_command_queue_mt_flush(CommandQueueMT *p_queue) {
+ ERR_FAIL_COND(flushing_cmd_queue != nullptr);
+ flushing_cmd_queue = p_queue;
+}
+
+void WorkerThreadPool::thread_exit_command_queue_mt_flush() {
+ ERR_FAIL_NULL(flushing_cmd_queue);
+ flushing_cmd_queue = nullptr;
+}
+
void WorkerThreadPool::init(int p_thread_count, float p_low_priority_task_ratio) {
ERR_FAIL_COND(threads.size() > 0);
if (p_thread_count < 0) {
diff --git a/core/object/worker_thread_pool.h b/core/object/worker_thread_pool.h
index 1c4758fb48..3ec4fd732f 100644
--- a/core/object/worker_thread_pool.h
+++ b/core/object/worker_thread_pool.h
@@ -41,6 +41,8 @@
#include "core/templates/rid.h"
#include "core/templates/safe_refcount.h"
+class CommandQueueMT;
+
class WorkerThreadPool : public Object {
GDCLASS(WorkerThreadPool, Object)
public:
@@ -135,6 +137,8 @@ private:
static WorkerThreadPool *singleton;
+ static thread_local CommandQueueMT *flushing_cmd_queue;
+
TaskID _add_task(const Callable &p_callable, void (*p_func)(void *), void *p_userdata, BaseTemplateUserdata *p_template_userdata, bool p_high_priority, const String &p_description);
GroupID _add_group_task(const Callable &p_callable, void (*p_func)(void *, uint32_t), void *p_userdata, BaseTemplateUserdata *p_template_userdata, int p_elements, int p_tasks, bool p_high_priority, const String &p_description);
@@ -197,6 +201,9 @@ public:
static WorkerThreadPool *get_singleton() { return singleton; }
static int get_thread_index();
+ static void thread_enter_command_queue_mt_flush(CommandQueueMT *p_queue);
+ static void thread_exit_command_queue_mt_flush();
+
void init(int p_thread_count = -1, float p_low_priority_task_ratio = 0.3);
void finish();
WorkerThreadPool();
diff --git a/core/templates/command_queue_mt.h b/core/templates/command_queue_mt.h
index 7e480653ac..b1010f7f43 100644
--- a/core/templates/command_queue_mt.h
+++ b/core/templates/command_queue_mt.h
@@ -31,6 +31,7 @@
#ifndef COMMAND_QUEUE_MT_H
#define COMMAND_QUEUE_MT_H
+#include "core/object/worker_thread_pool.h"
#include "core/os/memory.h"
#include "core/os/mutex.h"
#include "core/os/semaphore.h"
@@ -306,15 +307,15 @@ class CommandQueueMT {
struct CommandBase {
virtual void call() = 0;
- virtual void post() {}
- virtual ~CommandBase() {}
+ virtual SyncSemaphore *get_sync_semaphore() { return nullptr; }
+ virtual ~CommandBase() = default; // Won't be called.
};
struct SyncCommand : public CommandBase {
SyncSemaphore *sync_sem = nullptr;
- virtual void post() override {
- sync_sem->sem.post();
+ virtual SyncSemaphore *get_sync_semaphore() override {
+ return sync_sem;
}
};
@@ -340,6 +341,7 @@ class CommandQueueMT {
SyncSemaphore sync_sems[SYNC_SEMAPHORES];
Mutex mutex;
Semaphore *sync = nullptr;
+ uint64_t flush_read_ptr = 0;
template <class T>
T *allocate() {
@@ -362,31 +364,41 @@ class CommandQueueMT {
void _flush() {
lock();
- uint64_t read_ptr = 0;
- uint64_t limit = command_mem.size();
-
- while (read_ptr < limit) {
- uint64_t size = *(uint64_t *)&command_mem[read_ptr];
- read_ptr += 8;
- CommandBase *cmd = reinterpret_cast<CommandBase *>(&command_mem[read_ptr]);
-
- cmd->call(); //execute the function
- cmd->post(); //release in case it needs sync/ret
- cmd->~CommandBase(); //should be done, so erase the command
-
- read_ptr += size;
+ WorkerThreadPool::thread_enter_command_queue_mt_flush(this);
+ while (flush_read_ptr < command_mem.size()) {
+ uint64_t size = *(uint64_t *)&command_mem[flush_read_ptr];
+ flush_read_ptr += 8;
+ CommandBase *cmd = reinterpret_cast<CommandBase *>(&command_mem[flush_read_ptr]);
+
+ SyncSemaphore *sync_sem = cmd->get_sync_semaphore();
+ cmd->call();
+ if (sync_sem) {
+ sync_sem->sem.post(); // Release in case it needs sync/ret.
+ }
+
+ if (unlikely(flush_read_ptr == 0)) {
+ // A reentrant call flushed.
+ DEV_ASSERT(command_mem.is_empty());
+ unlock();
+ return;
+ }
+
+ flush_read_ptr += size;
}
+ WorkerThreadPool::thread_exit_command_queue_mt_flush();
command_mem.clear();
+ flush_read_ptr = 0;
unlock();
}
- void lock();
- void unlock();
void wait_for_flush();
SyncSemaphore *_alloc_sync_sem();
public:
+ void lock();
+ void unlock();
+
/* NORMAL PUSH COMMANDS */
DECL_PUSH(0)
SPACE_SEP_LIST(DECL_PUSH, 15)