// // Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net) // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #ifndef BOOST_COBALT_DETAIL_THREAD_HPP #define BOOST_COBALT_DETAIL_THREAD_HPP #include <boost/cobalt/config.hpp> #include <boost/cobalt/detail/forward_cancellation.hpp> #include <boost/cobalt/detail/handler.hpp> #include <boost/cobalt/concepts.hpp> #include <boost/cobalt/this_coro.hpp> #include <boost/asio/cancellation_signal.hpp> #include <thread> namespace boost::cobalt { struct as_tuple_tag; struct as_result_tag; namespace detail { struct thread_promise; } struct thread; namespace detail { struct signal_helper_2 { asio::cancellation_signal signal; }; struct thread_state { asio::io_context ctx{1u}; asio::cancellation_signal signal; std::mutex mtx; std::optional<completion_handler<std::exception_ptr>> waitor; std::atomic<bool> done = false; }; struct thread_promise : signal_helper_2, promise_cancellation_base<asio::cancellation_slot, asio::enable_total_cancellation>, promise_throw_if_cancelled_base, enable_awaitables<thread_promise>, enable_await_allocator<thread_promise>, enable_await_executor<thread_promise> { BOOST_COBALT_DECL thread_promise(); struct initial_awaitable { bool await_ready() const {return false;} void await_suspend(std::coroutine_handle<thread_promise> h) { h.promise().mtx.unlock(); } void await_resume() {} }; auto initial_suspend() noexcept { return initial_awaitable{}; } std::suspend_never final_suspend() noexcept { wexec_.reset(); return {}; } void unhandled_exception() { throw; } void return_void() { } using executor_type = typename cobalt::executor; const executor_type & get_executor() const {return *exec_;} #if !defined(BOOST_COBALT_NO_PMR) using allocator_type = pmr::polymorphic_allocator<void>; using resource_type = pmr::unsynchronized_pool_resource; resource_type * resource; allocator_type get_allocator() const { return allocator_type(resource); } #endif using promise_cancellation_base<asio::cancellation_slot, asio::enable_total_cancellation>::await_transform; using promise_throw_if_cancelled_base::await_transform; using enable_awaitables<thread_promise>::await_transform; using enable_await_allocator<thread_promise>::await_transform; using enable_await_executor<thread_promise>::await_transform; BOOST_COBALT_DECL boost::cobalt::thread get_return_object(); void set_executor(asio::io_context::executor_type exec) { wexec_.emplace(exec); exec_.emplace(exec); } std::mutex mtx; private: std::optional<asio::executor_work_guard<asio::io_context::executor_type>> wexec_; std::optional<cobalt::executor> exec_; }; struct thread_awaitable { asio::cancellation_slot cl; std::optional<std::tuple<std::exception_ptr>> res; bool await_ready(const boost::source_location & loc = BOOST_CURRENT_LOCATION) const { if (state_ == nullptr) boost::throw_exception(std::invalid_argument("Thread expired"), loc); std::lock_guard<std::mutex> lock{state_->mtx}; return state_->done; } template<typename Promise> bool await_suspend(std::coroutine_handle<Promise> h) { BOOST_ASSERT(state_); std::lock_guard<std::mutex> lock{state_->mtx}; if (state_->done) return false; if constexpr (requires (Promise p) {p.get_cancellation_slot();}) if ((cl = h.promise().get_cancellation_slot()).is_connected()) { cl.assign( [st = state_](asio::cancellation_type type) { std::lock_guard<std::mutex> lock{st->mtx}; asio::post(st->ctx, [st, type] { BOOST_ASIO_HANDLER_LOCATION((__FILE__, __LINE__, __func__)); st->signal.emit(type); }); }); } state_->waitor.emplace(h, res); return true; } void await_resume() { if (cl.is_connected()) cl.clear(); if (thread_) thread_->join(); if (!res) // await_ready return; if (auto ee = std::get<0>(*res)) std::rethrow_exception(ee); } system::result<void, std::exception_ptr> await_resume(const as_result_tag &) { if (cl.is_connected()) cl.clear(); if (thread_) thread_->join(); if (!res) // await_ready return {system::in_place_value}; if (auto ee = std::get<0>(*res)) return {system::in_place_error, std::move(ee)}; return {system::in_place_value}; } std::tuple<std::exception_ptr> await_resume(const as_tuple_tag &) { if (cl.is_connected()) cl.clear(); if (thread_) thread_->join(); return std::get<0>(*res); } explicit thread_awaitable(std::shared_ptr<detail::thread_state> state) : state_(std::move(state)) {} explicit thread_awaitable(std::thread thread, std::shared_ptr<detail::thread_state> state) : thread_(std::move(thread)), state_(std::move(state)) {} private: std::optional<std::thread> thread_; std::shared_ptr<detail::thread_state> state_; }; } } #endif //BOOST_COBALT_DETAIL_THREAD_HPP