diff --git a/libstdc++-v3/include/std/variant b/libstdc++-v3/include/std/variant index dd8847cf829..56de78407c4 100644 --- a/libstdc++-v3/include/std/variant +++ b/libstdc++-v3/include/std/variant @@ -182,7 +182,7 @@ namespace __variant // used for raw visitation with indices passed in struct __variant_idx_cookie { using type = __variant_idx_cookie; }; // Used to enable deduction (and same-type checking) for std::visit: - template struct __deduce_visit_result { }; + template struct __deduce_visit_result { using type = _Tp; }; // Visit variants that might be valueless. template @@ -1017,7 +1017,22 @@ namespace __variant static constexpr auto _S_apply() - { return _Array_type{&__visit_invoke}; } + { + constexpr bool __visit_ret_type_mismatch = + _Array_type::__result_is_deduced::value + && !is_same_v(), + std::declval<_Variants>()...))>; + if constexpr (__visit_ret_type_mismatch) + { + static_assert(!__visit_ret_type_mismatch, + "std::visit requires the visitor to have the same " + "return type for all alternatives of a variant"); + return __nonesuch{}; + } + else + return _Array_type{&__visit_invoke}; + } }; template @@ -1692,6 +1707,27 @@ namespace __variant std::forward<_Variants>(__variants)...); } + template + struct __same_types : public std::bool_constant< + std::__and_...>::value> {}; + + template + decltype(auto) __check_visitor_result(_Visitor&& __vis, + _Variant&& __variant) + { + return std::forward<_Visitor>(__vis)( + std::get<_Idx>(std::forward<_Variant>(__variant))); + } + + template + constexpr bool __check_visitor_results(std::index_sequence<_Idxs...>) + { + return __same_types( + std::declval<_Visitor>(), + std::declval<_Variant>()))...>::value; + } + + template constexpr decltype(auto) visit(_Visitor&& __visitor, _Variants&&... __variants) @@ -1704,8 +1740,28 @@ namespace __variant using _Tag = __detail::__variant::__deduce_visit_result<_Result_type>; - return std::__do_visit<_Tag>(std::forward<_Visitor>(__visitor), - std::forward<_Variants>(__variants)...); + if constexpr (sizeof...(_Variants) == 1) + { + constexpr bool __visit_rettypes_match = + __check_visitor_results<_Visitor, _Variants...>( + std::make_index_sequence< + std::variant_size...>::value>()); + if constexpr (!__visit_rettypes_match) + { + static_assert(__visit_rettypes_match, + "std::visit requires the visitor to have the same " + "return type for all alternatives of a variant"); + return; + } + else + return std::__do_visit<_Tag>( + std::forward<_Visitor>(__visitor), + std::forward<_Variants>(__variants)...); + } + else + return std::__do_visit<_Tag>( + std::forward<_Visitor>(__visitor), + std::forward<_Variants>(__variants)...); } #if __cplusplus > 201703L