diff --git a/Kernel/tests/stl/coroutines.cpp b/Kernel/tests/stl/coroutines.cpp index d052f80b..59605500 100644 --- a/Kernel/tests/stl/coroutines.cpp +++ b/Kernel/tests/stl/coroutines.cpp @@ -24,6 +24,7 @@ #include "../kernel.h" #include +#include /* https://gist.github.com/Qix-/caa277fbf1a4e6ca55a27f2242df3b9a */ @@ -58,7 +59,7 @@ struct resumable::promise_type auto initial_suspend() { return std::suspend_always(); } auto final_suspend() noexcept { return std::suspend_always(); } void return_void() {} - void unhandled_exception() { assert(!"std::terminate();"); } + void unhandled_exception() { std::terminate(); } }; resumable foo() @@ -70,6 +71,122 @@ resumable foo() /* ===================================================================== */ +struct Generator +{ + struct promise_type + { + int current_value; + + Generator get_return_object() + { + return Generator{std::coroutine_handle::from_promise(*this)}; + } + + std::suspend_always initial_suspend() + { + return {}; + } + + std::suspend_always final_suspend() noexcept + { + return {}; + } + + void return_void() + { + } + + std::suspend_always yield_value(int value) + { + current_value = value; + return {}; + } + + void unhandled_exception() + { + std::terminate(); + } + }; + + std::coroutine_handle handle; + + Generator(std::coroutine_handle h) : handle(h) {} + + ~Generator() + { + if (handle) + handle.destroy(); + } + + bool next() + { + if (!handle || handle.done()) + return false; + + handle.resume(); + return true; + } + + int value() const + { + int ret = handle.promise().current_value; + return ret; + } +}; + +Generator CountToThree() +{ + debug("1"); + co_yield 1; + debug("2"); + co_yield 2; + debug("3"); + co_yield 3; + debug("end"); +} + +/* ===================================================================== */ + +struct Task +{ + struct promise_type + { + Task get_return_object() { return Task{std::coroutine_handle::from_promise(*this)}; } + std::suspend_never initial_suspend() { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } + void return_void() {} + void unhandled_exception() { std::terminate(); } + }; + + std::coroutine_handle handle; + Task(std::coroutine_handle h) : handle(h) {} + ~Task() + { + if (handle) + handle.destroy(); + } +}; + +struct Awaiter +{ + bool await_ready() { return false; } + void await_suspend(std::coroutine_handle<> h) + { + std::this_thread::sleep_for(std::chrono::seconds(1)); + h.resume(); + } + void await_resume() {} +}; + +Task AsyncFunc() +{ + debug("waiting"); + co_await Awaiter{}; + debug("done"); +} + +/* ===================================================================== */ + class SyscallAwaitable { public: @@ -120,7 +237,7 @@ public: void unhandled_exception() { - assert("std::terminate();"); + std::terminate(); } }; @@ -146,6 +263,17 @@ void coroutineTest() auto task = perform_syscall(); task.handle.resume(); + /* async task */ + AsyncFunc(); + + /* generator */ + auto gen = CountToThree(); + while (gen.next()) + { + auto a = gen.value(); + debug("%d", a); + } + /* Example of coroutine */ auto p = foo(); while (p.resume())