diff --git a/.runsettings b/.runsettings new file mode 100644 index 000000000..5f10e44f8 --- /dev/null +++ b/.runsettings @@ -0,0 +1,24 @@ + + + + .\TestResults + 60000 + true + + + + + + + Single + (?i:Test) + true + on + true + Verbose + AdditionalInfo + ShortInfo + , + 20000 + + \ No newline at end of file diff --git a/README.md b/README.md index 1749fea0c..19e1493dc 100644 --- a/README.md +++ b/README.md @@ -34,3 +34,8 @@ a dev command prompt at the root of the repo _after_ following the above build i * Run `build_prior_projection.cmd` in the dev command prompt as well * Run `prepare_versionless_diffs.cmd` which removes version stamps on both current and prior projection * Use a directory-level differencing tool to compare `_build\$(arch)\$(flavor)\winrt` and `_reference\$(arch)\$(flavor)\winrt` + +## Testing +This repository uses the [Catch2](https://github.com/catchorg/Catch2) testing framework. +- From a Visual Studio command line, you should run `build_tests_all.cmd` to build and run the tests. To Debug the tests, you can debug the associated `_build\$(arch)\$(flavor)\.exe` under the debugger of your choice. +- Optionally, you can install the [Catch2Adapter](https://marketplace.visualstudio.com/items?itemName=JohnnyHendriks.ext01) to run the tests from Visual Studio. \ No newline at end of file diff --git a/cppwinrt/code_writers.h b/cppwinrt/code_writers.h index ffc0ee948..35e9935d0 100644 --- a/cppwinrt/code_writers.h +++ b/cppwinrt/code_writers.h @@ -1873,7 +1873,37 @@ namespace cppwinrt } } - static void write_produce_method(writer& w, MethodDef const& method) + static void write_produce_upcall_TryLookup(writer& w, std::string_view const& upcall, method_signature const& method_signature) + { + auto name = method_signature.return_param_name(); + + w.write("auto out_param_val = %(%, trylookup_from_abi);", + upcall, + bind(method_signature)); + w.write(R"( + if (out_param_val.has_value()) + { + *% = detach_from<%>(std::move(*out_param_val)); + } + else + { + return impl::error_out_of_bounds; + } +)", + name, method_signature.return_signature()); + + for (auto&& [param, param_signature] : method_signature.params()) + { + if (param.Flags().Out() && !param_signature->Type().is_szarray() && is_object(param_signature->Type())) + { + auto param_name = param.Name(); + + w.write("\n if (%) *% = detach_abi(winrt_impl_%);", param_name, param_name, param_name); + } + } + } + + static void write_produce_method(writer& w, MethodDef const& method, TypeDef const& type) { std::string_view format; @@ -1902,13 +1932,45 @@ namespace cppwinrt method_signature signature{ method }; auto async_types_guard = w.push_async_types(signature.is_async()); std::string upcall = "this->shim()."; - upcall += get_name(method); + auto name = get_name(method); + upcall += name; - w.write(format, - get_abi_name(method), - bind(signature), - bind(signature), - bind(upcall, signature)); + auto typeName = type.TypeName(); + if (((typeName == "IMapView`2") || (typeName == "IMap`2")) + && (name == "Lookup")) + { + // Special-case IMap*::Lookup to look for a TryLookup here, to avoid extranous throw/originates + std::string tryLookupUpCall = "this->shim().TryLookup"; + format = R"( int32_t __stdcall %(%) noexcept final try + { +% typename D::abi_guard guard(this->shim()); + if constexpr (has_TryLookup_v) + { + % + } + else + { + % + } + return 0; + } + catch (...) { return to_hresult(); } +)"; + w.write(format, + get_abi_name(method), + bind(signature), + bind(signature), // clear_abi + bind(tryLookupUpCall, signature), + bind(upcall, signature)); + } + else + { + w.write(format, + get_abi_name(method), + bind(signature), + bind(signature), + bind(upcall, signature)); + } } static void write_fast_produce_methods(writer& w, TypeDef const& default_interface) @@ -1951,7 +2013,7 @@ namespace cppwinrt break; } - w.write_each(info.type.MethodList()); + w.write_each(info.type.MethodList(), info.type); } } @@ -1973,7 +2035,7 @@ namespace cppwinrt bind(generics), type, type, - bind_each(type.MethodList()), + bind_each(type.MethodList(), type), bind(type)); } diff --git a/strings/base_collections_base.h b/strings/base_collections_base.h index f3ede9ed4..fec3de660 100644 --- a/strings/base_collections_base.h +++ b/strings/base_collections_base.h @@ -506,6 +506,20 @@ WINRT_EXPORT namespace winrt template struct map_view_base : iterable_base, Version> { + // specialization of Lookup that avoids throwing the hresult + std::optional TryLookup(K const& key, trylookup_from_abi_t) const + { + [[maybe_unused]] auto guard = static_cast(*this).acquire_shared(); + auto pair = static_cast(*this).get_container().find(static_cast(*this).wrap_value(key)); + + if (pair == static_cast(*this).get_container().end()) + { + return std::nullopt; + } + + return static_cast(*this).unwrap_value(pair->second); + } + V Lookup(K const& key) const { [[maybe_unused]] auto guard = static_cast(*this).acquire_shared(); @@ -536,6 +550,7 @@ WINRT_EXPORT namespace winrt first = nullptr; second = nullptr; } + }; template diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index 670a65403..87aaed24e 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -351,6 +351,10 @@ namespace winrt::impl return m_promise->enable_cancellation_propagation(value); } + bool originate_on_cancel(bool value = true) const noexcept + { + return m_promise->originate_on_cancel(value); + } private: Promise* m_promise; @@ -484,7 +488,14 @@ namespace winrt::impl if (m_status.load(std::memory_order_relaxed) == AsyncStatus::Started) { m_status.store(AsyncStatus::Canceled, std::memory_order_relaxed); - m_exception = std::make_exception_ptr(hresult_canceled()); + if (cancellable_promise::originate_on_cancel()) + { + m_exception = std::make_exception_ptr(hresult_canceled()); + } + else + { + m_exception = std::make_exception_ptr(hresult_canceled(hresult_error::no_originate)); + } cancel = std::move(m_cancel); } } @@ -628,7 +639,14 @@ namespace winrt::impl { if (Status() == AsyncStatus::Canceled) { - throw winrt::hresult_canceled(); + if (cancellable_promise::originate_on_cancel()) + { + throw winrt::hresult_canceled(); + } + else + { + throw winrt::hresult_canceled(hresult_error::no_originate); + } } return std::forward(expression); diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index 0faaa1acd..7fa8789c7 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -180,12 +180,23 @@ WINRT_EXPORT namespace winrt return m_propagate_cancellation; } + bool originate_on_cancel(bool value = true) noexcept + { + return std::exchange(m_originate_on_cancel, value); + } + + bool should_originate_on_cancel() const noexcept + { + return m_originate_on_cancel; + } + private: static inline auto const cancelling_ptr = reinterpret_cast(1); std::atomic m_canceller{ nullptr }; void* m_context{ nullptr }; bool m_propagate_cancellation{ false }; + bool m_originate_on_cancel{ true }; // By default, will call RoOriginateError before throwing a cancel error code. }; template diff --git a/strings/base_error.h b/strings/base_error.h index 85de70f63..41929336b 100644 --- a/strings/base_error.h +++ b/strings/base_error.h @@ -84,6 +84,9 @@ WINRT_EXPORT namespace winrt { struct hresult_error { + struct no_originate_t {}; + static constexpr no_originate_t no_originate{}; + using from_abi_t = take_ownership_from_abi_t; static constexpr auto from_abi{ take_ownership_from_abi }; @@ -109,6 +112,10 @@ WINRT_EXPORT namespace winrt originate(code, nullptr, sourceInformation); } + explicit hresult_error(hresult const code, no_originate_t) noexcept : m_code(verify_error(code)) + { + } + hresult_error(hresult const code, param::hstring const& message, winrt::impl::slim_source_location const& sourceInformation = winrt::impl::slim_source_location::current()) noexcept : m_code(verify_error(code)) { originate(code, get_abi(message), sourceInformation); @@ -325,6 +332,7 @@ WINRT_EXPORT namespace winrt struct hresult_canceled : hresult_error { hresult_canceled(winrt::impl::slim_source_location const& sourceInformation = winrt::impl::slim_source_location::current()) noexcept : hresult_error(impl::error_canceled, sourceInformation) {} + hresult_canceled(hresult_error::no_originate_t) noexcept : hresult_error(impl::error_canceled, hresult_error::no_originate) {} hresult_canceled(param::hstring const& message, winrt::impl::slim_source_location const& sourceInformation = winrt::impl::slim_source_location::current()) noexcept : hresult_error(impl::error_canceled, message, sourceInformation) {} hresult_canceled(take_ownership_from_abi_t, winrt::impl::slim_source_location const& sourceInformation = winrt::impl::slim_source_location::current()) noexcept : hresult_error(impl::error_canceled, take_ownership_from_abi, sourceInformation) {} }; diff --git a/strings/base_meta.h b/strings/base_meta.h index 25deb42ec..7dbb4c386 100644 --- a/strings/base_meta.h +++ b/strings/base_meta.h @@ -10,6 +10,10 @@ WINRT_EXPORT namespace winrt struct take_ownership_from_abi_t {}; inline constexpr take_ownership_from_abi_t take_ownership_from_abi{}; + // Map implementations can implement TryLookup with trylookup_from_abi_t as an optimization + struct trylookup_from_abi_t {}; + inline constexpr trylookup_from_abi_t trylookup_from_abi{}; + template struct com_ptr; @@ -298,4 +302,16 @@ namespace winrt::impl return (func(Types{}) || ...); } }; + + template + struct has_TryLookup + { + template ().TryLookup(std::declval(), trylookup_from_abi))> static constexpr bool get_value(int) { return true; } + template static constexpr bool get_value(...) { return false; } + public: + static constexpr bool value = get_value(0); + }; + + template + inline constexpr bool has_TryLookup_v = has_TryLookup::value; } diff --git a/test/old_tests/UnitTests/Errors.cpp b/test/old_tests/UnitTests/Errors.cpp index 472d633f1..f8656a508 100644 --- a/test/old_tests/UnitTests/Errors.cpp +++ b/test/old_tests/UnitTests/Errors.cpp @@ -233,6 +233,7 @@ TEST_CASE("Errors") // Make sure trimming works. hresult_error e(E_FAIL, L":) is \u263A \n \t "); + auto x = e.message(); REQUIRE(e.message() == L":) is \u263A"); // Make sure delegates propagate correctly. diff --git a/test/old_tests/UnitTests/TryLookup.cpp b/test/old_tests/UnitTests/TryLookup.cpp index 3c6c54580..350fa9006 100644 --- a/test/old_tests/UnitTests/TryLookup.cpp +++ b/test/old_tests/UnitTests/TryLookup.cpp @@ -143,4 +143,129 @@ TEST_CASE("TryLookup TryRemove error") REQUIRE(!map.TryLookup(123)); REQUIRE(!map.TryRemove(123)); -} \ No newline at end of file +} + +TEST_CASE("trylookup_from_abi specialization") +{ + // A map that throws a specific error, used to verify various edge cases. + // and implements tryLookup, to take advantage of an optimization to avoid a throw. + struct map_with_try_lookup : implements> + { + hresult codeToThrow{ S_OK }; + bool shouldThrowOnTryLookup{ false }; + std::optional TryLookup(int, trylookup_from_abi_t) + { + if (shouldThrowOnTryLookup) + { + throw_hresult(codeToThrow); + } + else + { + return { std::nullopt }; + } + } + int Lookup(int) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + int32_t Size() { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + bool HasKey(int) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + void Split(IMapView&, IMapView&) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + }; + + auto self = make_self(); + IMapView map = *self; + + // Make sure that we use the TryLookup specialization, and don't throw an unexpected exception. + self->shouldThrowOnTryLookup = false; + REQUIRE(!map.TryLookup(123)); + // make sure regular lookup stll throws bounds + REQUIRE_THROWS_AS(map.Lookup(123), hresult_out_of_bounds); + + // Simulate a non-agile map that is being accessed from the wrong thread. + // "Try" operations should throw rather than erroneously report "not found". + // Because they didn't even try. The operation never got off the ground. + self->shouldThrowOnTryLookup = true; + self->codeToThrow = RPC_E_WRONG_THREAD; + REQUIRE_THROWS_AS(map.TryLookup(123), hresult_wrong_thread); + // regular lookup should throw the same error + REQUIRE_THROWS_AS(map.Lookup(123), hresult_wrong_thread); +} + +TEST_CASE("trylookup_from_abi NOT opt-in, no special tag") +{ + // Makes sure that an existing TryLookup method is not called without the trylookup_from_abi_t tag. + struct map_without_try_lookup : implements> + { + hresult codeToThrow{ S_OK }; + std::optional TryLookup(int) // notice no trylookup_from_abi_t, so no opt-in + { + // throw an unexpectd hresult, this should not be called. + throw_hresult(RPC_E_WRONG_THREAD); + } + int Lookup(int) { return 42; } // Behave as if the item was found + + int32_t Size() { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + bool HasKey(int) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + void Split(IMapView&, IMapView&) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + }; + + auto self = make_self(); + IMapView map = *self; + + // Make sure that we don't use the TryLookup specialization, we use the Successful Lookup + REQUIRE(map.TryLookup(123).value() == 42); + REQUIRE(map.Lookup(123) == 42); +} + +TEST_CASE("trylookup_from_abi specialization with IInspectable") +{ + // A map that throws a specific error, used to verify various edge cases. + // and implements tryLookup, to take advantage of an optimization to avoid a throw. + struct map_with_try_lookup : implements> + { + hresult codeToThrow{ S_OK }; + bool shouldThrowOnTryLookup{ false }; + bool returnNullptr{ false }; + std::optional TryLookup(int, trylookup_from_abi_t) + { + if (returnNullptr) + { + return { nullptr }; + } + else if (shouldThrowOnTryLookup) + { + throw_hresult(codeToThrow); + } + else + { + return { std::nullopt }; + } + } + IInspectable Lookup(int) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + int32_t Size() { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + bool HasKey(int) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + void Split(IMapView&, IMapView&) { throw_hresult(E_UNEXPECTED); } // shouldn't be called by the test + }; + + auto self = make_self(); + IMapView map = *self; + + // Ensure that we return a value on nullptr, a nullptr is a valid IInspectable in the Map + self->returnNullptr = true; + REQUIRE(map.TryLookup(123) == IInspectable{nullptr}); + REQUIRE(map.Lookup(123) == IInspectable{nullptr}); + + // Make sure that we use the TryLookup specialization, and don't throw an unexpected exception. + self->shouldThrowOnTryLookup = false; + self->returnNullptr = false; + REQUIRE(map.TryLookup(123) == IInspectable{nullptr}); + // make sure regular lookup stll throws bounds + REQUIRE_THROWS_AS(map.Lookup(123), hresult_out_of_bounds); + + // Simulate a non-agile map that is being accessed from the wrong thread. + // "Try" operations should throw rather than erroneously report "not found". + // Because they didn't even try. The operation never got off the ground. + self->shouldThrowOnTryLookup = true; + self->codeToThrow = RPC_E_WRONG_THREAD; + REQUIRE_THROWS_AS(map.TryLookup(123), hresult_wrong_thread); + // regular lookup should throw the same error + REQUIRE_THROWS_AS(map.Lookup(123), hresult_wrong_thread); +} diff --git a/test/test/async_check_cancel.cpp b/test/test/async_check_cancel.cpp index 7547609f6..d88fdcaf0 100644 --- a/test/test/async_check_cancel.cpp +++ b/test/test/async_check_cancel.cpp @@ -11,6 +11,28 @@ namespace using std::experimental::suspend_never; #endif + static bool s_exceptionLoggerCalled = false; + + static struct { + uint32_t lineNumber; + char const* fileName; + char const* functionName; + void* returnAddress; + winrt::hresult result; + } s_exceptionLoggerArgs{}; + + void __stdcall exceptionLogger(uint32_t lineNumber, char const* fileName, char const* functionName, void* returnAddress, winrt::hresult const result) noexcept + { + s_exceptionLoggerArgs = { + /*.lineNumber =*/ lineNumber, + /*.fileName =*/ fileName, + /*.functionName =*/ functionName, + /*.returnAddress =*/ returnAddress, + /*.result =*/ result, + }; + s_exceptionLoggerCalled = true; + } + // // Checks that manual cancellation checks work. // @@ -61,6 +83,7 @@ namespace co_return 1; } + IAsyncOperationWithProgress OperationWithProgress(HANDLE event, bool& canceled) { co_await resume_on_signal(event); @@ -77,6 +100,59 @@ namespace co_return 1; } + IAsyncAction OperationCancelLogged(HANDLE event, bool& canceled) + { + REQUIRE(!s_exceptionLoggerCalled); + REQUIRE(!winrt_throw_hresult_handler); + winrt_throw_hresult_handler = exceptionLogger; + + co_await resume_on_signal(event); + auto cancel = co_await get_cancellation_token(); + + if (cancel()) + { + REQUIRE(!canceled); + canceled = true; + REQUIRE(s_exceptionLoggerCalled); + REQUIRE(s_exceptionLoggerArgs.result == HRESULT_FROM_WIN32(ERROR_CANCELLED)); + } + + winrt_throw_hresult_handler = nullptr; + s_exceptionLoggerCalled = false; + + co_await suspend_never(); + + REQUIRE(false); + co_return; + } + + IAsyncAction OperationAvoidLoggingCancel(HANDLE event, bool& canceled) + { + REQUIRE(!s_exceptionLoggerCalled); + REQUIRE(!winrt_throw_hresult_handler); + winrt_throw_hresult_handler = exceptionLogger; + + auto cancel = co_await get_cancellation_token(); + cancel.originate_on_cancel(false); + + co_await resume_on_signal(event); + + if (cancel()) + { + REQUIRE(!canceled); + canceled = true; + REQUIRE(!s_exceptionLoggerCalled); + } + + winrt_throw_hresult_handler = nullptr; + s_exceptionLoggerCalled = false; + + co_await suspend_never(); + + REQUIRE(false); + co_return; + } + template void Check(F make) { @@ -96,7 +172,7 @@ namespace async.Cancel(); SetEvent(start.get()); - REQUIRE(WaitForSingleObject(completed.get(), 1000) == WAIT_OBJECT_0); + REQUIRE(WaitForSingleObject(completed.get(), IsDebuggerPresent() ? INFINITE : 1000) == WAIT_OBJECT_0); REQUIRE(async.Status() == AsyncStatus::Canceled); REQUIRE(async.ErrorCode() == HRESULT_FROM_WIN32(ERROR_CANCELLED)); @@ -115,4 +191,6 @@ TEST_CASE("async_check_cancel") Check(ActionWithProgress); Check(Operation); Check(OperationWithProgress); + Check(OperationCancelLogged); + Check(OperationAvoidLoggingCancel); } diff --git a/test/test_cpp20/custom_error.cpp b/test/test_cpp20/custom_error.cpp index d7b055e47..b925628d9 100644 --- a/test/test_cpp20/custom_error.cpp +++ b/test/test_cpp20/custom_error.cpp @@ -37,9 +37,9 @@ namespace #if defined(_LIBCPP_VERSION) && _LIBCPP_VERSION < 170000 // not available in libc++ before LLVM 16 -TEST_CASE("custom_error_logger", "[!shouldfail]") +TEST_CASE("custom_error_logger_on_throw", "[!shouldfail]") #else -TEST_CASE("custom_error_logger") +TEST_CASE("custom_error_logger_on_throw") #endif { // Set up global handler @@ -72,3 +72,62 @@ TEST_CASE("custom_error_logger") winrt_throw_hresult_handler = nullptr; s_loggerCalled = false; } +template +void HresultOnLine80(Args... args) +{ + // Validate that handler translated on creating an HRESULT +#line 80 // Force next line to be reported as line number 80 + winrt::hresult_canceled(std::forward(args)...); +} + +#if defined(_LIBCPP_VERSION) && _LIBCPP_VERSION < 170000 +// not available in libc++ before LLVM 16 +TEST_CASE("custom_error_logger_on_originate", "[!shouldfail]") +#else +TEST_CASE("custom_error_logger_on_originate") +#endif +{ + // Set up global handler + REQUIRE(!s_loggerCalled); + REQUIRE(!winrt_throw_hresult_handler); + winrt_throw_hresult_handler = logger; + + HresultOnLine80(); + REQUIRE(s_loggerCalled); + // In C++20 these fields should be filled in by std::source_location + REQUIRE(s_loggerArgs.lineNumber == 80); + const auto fileNameSv = std::string_view(s_loggerArgs.fileName); + REQUIRE(!fileNameSv.empty()); + REQUIRE(fileNameSv.find("custom_error.cpp") != std::string::npos); +#ifdef _DEBUG + const auto functionNameSv = std::string_view(s_loggerArgs.functionName); + REQUIRE(!functionNameSv.empty()); + // Every compiler has a slightly different naming approach for this function, and even the same + // compiler can change its mind over time. Instead of matching the entire function name just + // match against the part we care about. + REQUIRE((functionNameSv.find("HresultOnLine80") != std::string_view::npos)); +#else + REQUIRE(s_loggerArgs.functionName == nullptr); +#endif // _DEBUG + + REQUIRE(s_loggerArgs.returnAddress); + REQUIRE(s_loggerArgs.result == HRESULT_FROM_WIN32(ERROR_CANCELLED)); // E_ILLEGAL_DELEGATE_ASSIGNMENT) + + s_loggerCalled = false; + s_loggerArgs.lineNumber = 0; + // verify HRESULT with a custom message + HresultOnLine80(L"with custom message"); + REQUIRE(s_loggerCalled); + REQUIRE(s_loggerArgs.lineNumber == 80); + + s_loggerCalled = false; + s_loggerArgs.lineNumber = 0; + // verify that no_originate does _not_ call the logger. + HresultOnLine80(winrt::hresult_error::no_originate); + REQUIRE(!s_loggerCalled); + REQUIRE(s_loggerArgs.lineNumber == 0); + + // Remove global handler + winrt_throw_hresult_handler = nullptr; + s_loggerCalled = false; +}