diff --git a/libstdc++-v3/include/bits/hashtable.h b/libstdc++-v3/include/bits/hashtable.h index 8fac385570b..9e721aad8cc 100644 --- a/libstdc++-v3/include/bits/hashtable.h +++ b/libstdc++-v3/include/bits/hashtable.h @@ -337,6 +337,13 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION bool _Constant_iteratorsa> friend struct __detail::_Insert; + template + friend struct __detail::_Equality; + public: using size_type = typename __hashtable_base::size_type; using difference_type = typename __hashtable_base::difference_type; diff --git a/libstdc++-v3/include/bits/hashtable_policy.h b/libstdc++-v3/include/bits/hashtable_policy.h index 7bbfdfd375b..4024e6c37fa 100644 --- a/libstdc++-v3/include/bits/hashtable_policy.h +++ b/libstdc++-v3/include/bits/hashtable_policy.h @@ -34,6 +34,7 @@ #include // for std::tuple, std::forward_as_tuple #include // for std::numeric_limits #include // for std::min. +#include // for std::is_permutation. namespace std _GLIBCXX_VISIBILITY(default) { @@ -1815,65 +1816,6 @@ namespace __detail _M_eq() const { return _EqualEBO::_M_cget(); } }; - /** - * struct _Equality_base. - * - * Common types and functions for class _Equality. - */ - struct _Equality_base - { - protected: - template - static bool - _S_is_permutation(_Uiterator, _Uiterator, _Uiterator); - }; - - // See std::is_permutation in N3068. - template - bool - _Equality_base:: - _S_is_permutation(_Uiterator __first1, _Uiterator __last1, - _Uiterator __first2) - { - for (; __first1 != __last1; ++__first1, ++__first2) - if (!(*__first1 == *__first2)) - break; - - if (__first1 == __last1) - return true; - - _Uiterator __last2 = __first2; - std::advance(__last2, std::distance(__first1, __last1)); - - for (_Uiterator __it1 = __first1; __it1 != __last1; ++__it1) - { - _Uiterator __tmp = __first1; - while (__tmp != __it1 && !bool(*__tmp == *__it1)) - ++__tmp; - - // We've seen this one before. - if (__tmp != __it1) - continue; - - std::ptrdiff_t __n2 = 0; - for (__tmp = __first2; __tmp != __last2; ++__tmp) - if (*__tmp == *__it1) - ++__n2; - - if (!__n2) - return false; - - std::ptrdiff_t __n1 = 0; - for (__tmp = __it1; __tmp != __last1; ++__tmp) - if (*__tmp == *__it1) - ++__n1; - - if (__n1 != __n2) - return false; - } - return true; - } - /** * Primary class template _Equality. * @@ -1889,7 +1831,7 @@ namespace __detail bool _Unique_keys = _Traits::__unique_keys::value> struct _Equality; - /// Specialization. + /// unordered_map and unordered_set specializations. template:: _M_equal(const __hashtable& __other) const { + using __node_base = typename __hashtable::__node_base; + using __node_type = typename __hashtable::__node_type; const __hashtable* __this = static_cast(this); - if (__this->size() != __other.size()) return false; for (auto __itx = __this->begin(); __itx != __this->end(); ++__itx) { - const auto __ity = __other.find(_ExtractKey()(*__itx)); - if (__ity == __other.end() || !bool(*__ity == *__itx)) + std::size_t __ybkt = __other._M_bucket_index(__itx._M_cur); + __node_base* __prev_n = __other._M_buckets[__ybkt]; + if (!__prev_n) return false; + + for (__node_type* __n = static_cast<__node_type*>(__prev_n->_M_nxt);; + __n = __n->_M_next()) + { + if (__n->_M_v() == *__itx) + break; + + if (!__n->_M_nxt + || __other._M_bucket_index(__n->_M_next()) != __ybkt) + return false; + } } + return true; } - /// Specialization. + /// unordered_multiset and unordered_multimap specializations. template struct _Equality<_Key, _Value, _Alloc, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, _Traits, false> - : public _Equality_base { using __hashtable = _Hashtable<_Key, _Value, _Alloc, _ExtractKey, _Equal, _H1, _H2, _Hash, _RehashPolicy, _Traits>; @@ -1952,25 +1907,51 @@ namespace __detail _H1, _H2, _Hash, _RehashPolicy, _Traits, false>:: _M_equal(const __hashtable& __other) const { + using __node_base = typename __hashtable::__node_base; + using __node_type = typename __hashtable::__node_type; const __hashtable* __this = static_cast(this); - if (__this->size() != __other.size()) return false; for (auto __itx = __this->begin(); __itx != __this->end();) { - const auto __xrange = __this->equal_range(_ExtractKey()(*__itx)); - const auto __yrange = __other.equal_range(_ExtractKey()(*__itx)); + std::size_t __x_count = 1; + auto __itx_end = __itx; + for (++__itx_end; __itx_end != __this->end() + && __this->key_eq()(_ExtractKey()(*__itx), + _ExtractKey()(*__itx_end)); + ++__itx_end) + ++__x_count; + + std::size_t __ybkt = __other._M_bucket_index(__itx._M_cur); + __node_base* __y_prev_n = __other._M_buckets[__ybkt]; + if (!__y_prev_n) + return false; + + __node_type* __y_n = static_cast<__node_type*>(__y_prev_n->_M_nxt); + for (;; __y_n = __y_n->_M_next()) + { + if (__this->key_eq()(_ExtractKey()(__y_n->_M_v()), + _ExtractKey()(*__itx))) + break; + + if (!__y_n->_M_nxt + || __other._M_bucket_index(__y_n->_M_next()) != __ybkt) + return false; + } + + typename __hashtable::const_iterator __ity(__y_n); + for (auto __ity_end = __ity; __ity_end != __other.end(); ++__ity_end) + if (--__x_count == 0) + break; - if (std::distance(__xrange.first, __xrange.second) - != std::distance(__yrange.first, __yrange.second)) + if (__x_count != 0) return false; - if (!_S_is_permutation(__xrange.first, __xrange.second, - __yrange.first)) + if (!std::is_permutation(__itx, __itx_end, __ity)) return false; - __itx = __xrange.second; + __itx = __itx_end; } return true; } diff --git a/libstdc++-v3/testsuite/23_containers/unordered_multiset/operators/1.cc b/libstdc++-v3/testsuite/23_containers/unordered_multiset/operators/1.cc index 4b87f62b74a..7252cad29c2 100644 --- a/libstdc++-v3/testsuite/23_containers/unordered_multiset/operators/1.cc +++ b/libstdc++-v3/testsuite/23_containers/unordered_multiset/operators/1.cc @@ -99,8 +99,64 @@ void test01() VERIFY( !(ums1 != cums2) ); } +void test02() +{ + std::unordered_multiset us1 + { 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9 }; + std::unordered_multiset us2 + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + + VERIFY( us1 == us2 ); +} + +struct Hash +{ + std::size_t + operator()(const std::pair& p) const + { return p.first; } +}; + +struct Equal +{ + bool + operator()(const std::pair& lhs, const std::pair& rhs) const + { return lhs.first == rhs.first; } +}; + +void test03() +{ + std::unordered_multiset, Hash, Equal> us1 + { + { 0, 0 }, { 1, 0 }, { 2, 0 }, { 3, 0 }, { 4, 0 }, + { 0, 1 }, { 1, 1 }, { 2, 1 }, { 3, 1 }, { 4, 1 }, + { 5, 0 }, { 6, 0 }, { 7, 0 }, { 8, 0 }, { 9, 0 }, + { 5, 1 }, { 6, 1 }, { 7, 1 }, { 8, 1 }, { 9, 1 } + }; + std::unordered_multiset, Hash, Equal> us2 + { + { 5, 1 }, { 6, 1 }, { 7, 1 }, { 8, 1 }, { 9, 1 }, + { 0, 1 }, { 1, 1 }, { 2, 1 }, { 3, 1 }, { 4, 1 }, + { 5, 0 }, { 6, 0 }, { 7, 0 }, { 8, 0 }, { 9, 0 }, + { 0, 0 }, { 1, 0 }, { 2, 0 }, { 3, 0 }, { 4, 0 } + }; + + VERIFY( us1 == us2 ); + + std::unordered_multiset, Hash, Equal> us3 + { + { 5, 1 }, { 6, 1 }, { 7, 1 }, { 8, 1 }, { 9, 1 }, + { 0, 1 }, { 1, 1 }, { 2, 1 }, { 3, 1 }, { 4, 1 }, + { 5, 0 }, { 6, 0 }, { 7, 1 }, { 8, 0 }, { 9, 0 }, + { 0, 0 }, { 1, 0 }, { 2, 0 }, { 3, 0 }, { 4, 0 } + }; + + VERIFY( us1 != us3 ); +} + int main() { test01(); + test02(); + test03(); return 0; } diff --git a/libstdc++-v3/testsuite/23_containers/unordered_set/operators/1.cc b/libstdc++-v3/testsuite/23_containers/unordered_set/operators/1.cc index d841246e2c1..36a45dfa099 100644 --- a/libstdc++-v3/testsuite/23_containers/unordered_set/operators/1.cc +++ b/libstdc++-v3/testsuite/23_containers/unordered_set/operators/1.cc @@ -99,8 +99,56 @@ void test01() VERIFY( !(us1 != cus2) ); } +void test02() +{ + std::unordered_set us1 { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + std::unordered_set us2 { 0, 2, 4, 6, 8, 1, 3, 5, 7, 9 }; + + VERIFY( us1 == us2 ); +} + +struct Hash +{ + std::size_t + operator()(const std::pair& p) const + { return p.first; } +}; + +struct Equal +{ + bool + operator()(const std::pair& lhs, const std::pair& rhs) const + { return lhs.first == rhs.first; } +}; + +void test03() +{ + std::unordered_set, Hash, Equal> us1 + { + { 0, 0 }, { 1, 1 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, + { 5, 5 }, { 6, 6 }, { 7, 7 }, { 8, 8 }, { 9, 9 } + }; + std::unordered_set, Hash, Equal> us2 + { + { 5, 5 }, { 6, 6 }, { 7, 7 }, { 8, 8 }, { 9, 9 }, + { 0, 0 }, { 1, 1 }, { 2, 2 }, { 3, 3 }, { 4, 4 } + }; + + VERIFY( us1 == us2 ); + + std::unordered_set, Hash, Equal> us3 + { + { 5, -5 }, { 6, 6 }, { 7, 7 }, { 8, 8 }, { 9, 9 }, + { 0, 0 }, { 1, 1 }, { 2, 2 }, { 3, 3 }, { 4, 4 } + }; + + VERIFY( us1 != us3 ); +} + int main() { test01(); + test02(); + test03(); return 0; }