AUI Framework  develop
Cross-platform base for C++ UI apps
Loading...
Searching...
No Matches
AThreadPool.h
    1/*
    2 * AUI Framework - Declarative UI toolkit for modern C++20
    3 * Copyright (C) 2020-2024 Alex2772 and Contributors
    4 *
    5 * SPDX-License-Identifier: MPL-2.0
    6 *
    7 * This Source Code Form is subject to the terms of the Mozilla Public
    8 * License, v. 2.0. If a copy of the MPL was not distributed with this
    9 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
   10 */
   11
   12#pragma once
   13
#include <AUI/Core.h>
#include <algorithm>
#include <atomic>
#include <cassert>
#include <type_traits>
#include <utility>

#include <AUI/Common/AVector.h>
#include <AUI/Common/AQueue.h>
#include <AUI/Common/AException.h>
#include <AUI/Thread/AThread.h>
#include <glm/glm.hpp>
#include "AUI/Traits/concepts.h"
   25
   26template <typename T>
   27class AFuture;
   28
   33class API_AUI_CORE AThreadPool {
   34public:
   35    class API_AUI_CORE Worker : public AThread {
   36    private:
   37        bool mEnabled = true;
   38        bool processQueue(std::unique_lock<std::mutex>& mutex, AQueue<std::function<void()>>& queue);
   39        AThreadPool& mTP;
   40
   41        void iteration(std::unique_lock<std::mutex>& tpLock);
   42        void wait(std::unique_lock<std::mutex>& tpLock);
   43
   44    public:
   45        Worker(AThreadPool& tp, size_t index);
   46        ~Worker();
   47        void aboutToDelete();
   48
   49        template <aui::predicate ShouldContinue>
   50        void loop(ShouldContinue&& shouldContinue) {
   51            std::unique_lock lock(mTP.mQueueLock);
   52            while (shouldContinue()) {
   53                iteration(lock);
   54                if (!shouldContinue()) {
   55                    return;
   56                }
   57                wait(lock);
   58            }
   59        }
   60
   61        AThreadPool& threadPool() noexcept { return mTP; }
   62    };
   63
   64    enum Priority {
   65        PRIORITY_HIGHEST,
   66        PRIORITY_MEDIUM,
   67        PRIORITY_LOWEST,
   68    };
   69
   70protected:
   71    typedef std::function<void()> task;
   72    AVector<_<Worker>> mWorkers;
   73    AQueue<task> mQueueHighest;
   74    AQueue<task> mQueueMedium;
   75    AQueue<task> mQueueLowest;
   76    AQueue<task> mQueueTryLater;
   77    std::mutex mQueueLock;
   78    std::condition_variable mCV;
   79    size_t mIdleWorkers = 0;
   80
   81public:
   86    AThreadPool(size_t size);
   87
   93    ~AThreadPool();
   94    size_t getPendingTaskCount();
   95    size_t getTotalTaskCount() {
   96        return getPendingTaskCount() + getTotalWorkerCount() - getIdleWorkerCount();
   97    }
   98    void run(const std::function<void()>& fun, Priority priority = PRIORITY_MEDIUM);
   99    void clear();
  100    void runLaterTasks();
  101    static void enqueue(const std::function<void()>& fun, Priority priority = PRIORITY_MEDIUM);
  102
  103    void setWorkersCount(std::size_t workersCount);
  104
  105    void wakeUpAll() {
  106        std::unique_lock lck(mQueueLock);
  107        mCV.notify_all();
  108    }
  109
  113    static AThreadPool& global();
  114
  115    [[nodiscard]]
  116    const AVector<_<Worker>>& workers() const {
  117        return mWorkers;
  118    }
  119
  120    size_t getTotalWorkerCount() const { return mWorkers.size(); }
  121    size_t getIdleWorkerCount() const { return mIdleWorkers; }
  122
  142    template <typename Iterator, typename Functor>
  143    auto parallel(Iterator begin, Iterator end, Functor&& functor);
  144
  145    template <aui::invocable Callable>
  146    [[nodiscard]] inline auto operator*(Callable fun) {
  147        using Value = std::invoke_result_t<Callable>;
  148        AFuture<Value> future(std::move(fun));
  149        run(
  150            [innerWeak = future.inner().weak()]() {
  151                /*
  152                 * Avoid holding a strong reference - we need to keep future cancellation on reference count exceeding
  153                 * even while actual future execution.
  154                 */
  155                if (auto lock = innerWeak.lock()) {
  156                    auto innerUnsafePointer = lock->ptr().get();   // using .get() here in order to bypass
  157                                                                   // null check in operator->
  158
  159                    lock = nullptr;   // destroy strong ref
  160
  161                    innerUnsafePointer->tryExecute(innerWeak);   // there's a check inside tryExecute to check its
  162                                                                 // validity
  163                }
  164            },
  165            AThreadPool::PRIORITY_LOWEST);
  166        return future;
  167    }
  168
  169    class TryLaterException {};
  170};
  171
  172#include <AUI/Thread/AFuture.h>
  173
  184template <typename T = void>
  185class AFutureSet : public AVector<AFuture<T>> {
  186private:
  187    using super = AVector<AFuture<T>>;
  188
  189public:
  190    using AVector<AFuture<T>>::AVector;
  191
  196    void waitForAll() {
  197        // wait from the end to avoid idling (see AFuture::wait for details)
  198        for (const AFuture<T>& v : aui::reverse_iterator_wrap(*this)) {
  199            v.operator*();
  200        }
  201    }
  202
  207    void checkForExceptions() const {
  208        for (const AFuture<T>& v : *this) {
  209            if (v.hasResult()) {
  210                v.operator*();   // TODO bad design
  211            }
  212        }
  213    }
  214
  225    template <aui::invocable OnComplete>
  226    void onAllComplete(OnComplete&& onComplete) {
  227        // check if all futures is already complete.
  228        for (const AFuture<T>& v : *this) {
  229            if (!v.hasResult()) {
  230                goto setupTheHell;
  231            }
  232        }
  233        onComplete();
  234        return;
  235
  236    setupTheHell:
  237        struct Temporary {
  238            OnComplete onComplete;
  239            AFutureSet myCopy;
  240            std::atomic_bool canBeCalled = true;
  241        };
  242        auto temporary = _new<Temporary>(std::forward<OnComplete>(onComplete), *this);
  243        for (const AFuture<T>& v : *this) {
  244            v.onSuccess([temporary](const auto& v) {
  245                for (const AFuture<T>& v : temporary->myCopy) {
  246                    if (!v.hasResult()) {
  247                        return;
  248                    }
  249                }
  250                // yay! all tasks are completed. the last thing to check if the callback is already called
  251                if (temporary->canBeCalled.exchange(false)) {
  252                    temporary->onComplete();
  253                }
  254            });
  255        }
  256    }
  257};
  258
  259template <typename Iterator, typename Functor>
  260auto AThreadPool::parallel(Iterator begin, Iterator end, Functor&& functor) {
  261    using ResultType = decltype(std::declval<Functor>()(std::declval<Iterator>(), std::declval<Iterator>()));
  262    AFutureSet<ResultType> futureSet;
  263
  264    size_t itemCount = end - begin;
  265    size_t affinity = (glm::min) (AThreadPool::global().getTotalWorkerCount(), itemCount);
  266    if (affinity == 0)
  267        return futureSet;
  268    size_t itemsPerThread = itemCount / affinity;
  269
  270    for (size_t threadIndex = 0; threadIndex < affinity; ++threadIndex) {
  271        auto forThreadBegin = begin;
  272        begin += itemsPerThread;
  273        auto forThreadEnd = threadIndex + 1 == affinity ? end : begin;
  274        futureSet.push_back(
  275            *this * [functor = std::forward<Functor>(functor), forThreadBegin, forThreadEnd]() -> decltype(auto) {
  276                return functor(forThreadBegin, forThreadEnd);
  277            });
  278    }
  279
  280    return futureSet;
  281}
  282
  283#include <AUI/Reflect/AReflect.h>
Manages multiple futures.
Definition AThreadPool.h:185
void waitForAll()
Wait for the result of every AFuture.
Definition AThreadPool.h:196
void checkForExceptions() const
Find AFutures that encountered an exception. If such AFuture is found, AInvocationTargetException is ...
Definition AThreadPool.h:207
void onAllComplete(OnComplete &&onComplete)
Specifies a callback which will be called once every future in the future set has a result.
Definition AThreadPool.h:226
Represents a value that will be available at some point in the future.
Definition AFuture.h:621
const AFuture & onSuccess(Callback &&callback) const noexcept
Add onSuccess callback to the future.
Definition AFuture.h:709
A std::queue with AUI extensions.
Definition AQueue.h:24
Definition AThreadPool.h:169
static AThreadPool & global()
Global thread pool created with the default constructor.
AThreadPool(size_t size)
Initializes the thread pool with the given number (size) of threads.
auto parallel(Iterator begin, Iterator end, Functor &&functor)
Definition AThreadPool.h:260
AThreadPool()
Initializes the thread pool with max(std::thread::hardware_concurrency() - 1, 2) of threads or --aui-t...
A std::vector with AUI extensions.
Definition AVector.h:39
bool hasResult() const noexcept
Definition AFuture.h:395
Definition iterators.h:34