From 389fd74a3f3e9422a965263b6961b51295c55976 Mon Sep 17 00:00:00 2001 From: Mark Wielaard Date: Sun, 1 Aug 2021 12:57:33 +0200 Subject: [PATCH] union support for hir type checking and gcc backend Treat a union as a Struct variant like a tuple struct. Add an iterator and get_identifier functions to the AST Union class. Same for the HIR Union class, plus a get_generics_params method. Add a new ADTKind enum and adt_kind field to the ADTType to select the underlying abstract data type (struct struct, tuple struct or union, with enum as possible future variant). An union constructor can have only one field. Add an union_index field to StructExprStruct which is set during type checking in the TypeCheckStructExpr HIR StructExprStructFields visitor. For the Gcc_backend class rename fill_in_struct to fill_in_fields and use it from a new union_type method. Handle union_index in constructor_expression (so only one field is initialized). --- gcc/rust/ast/rust-item.h | 11 +++ gcc/rust/backend/rust-compile-context.h | 8 +- gcc/rust/backend/rust-compile-expr.h | 3 +- gcc/rust/backend/rust-compile.cc | 2 +- gcc/rust/hir/rust-ast-lower-item.h | 51 +++++++++++ gcc/rust/hir/rust-ast-lower-stmt.h | 53 +++++++++++ gcc/rust/hir/tree/rust-hir-expr.h | 8 +- gcc/rust/hir/tree/rust-hir-item.h | 16 ++++ gcc/rust/resolve/rust-ast-resolve-item.h | 22 +++++ gcc/rust/resolve/rust-ast-resolve-stmt.h | 32 +++++++ gcc/rust/resolve/rust-ast-resolve-toplevel.h | 14 +++ gcc/rust/rust-backend.h | 5 +- gcc/rust/rust-gcc.cc | 91 ++++++++++++++----- gcc/rust/typecheck/rust-hir-type-check-stmt.h | 58 +++++++++++- .../typecheck/rust-hir-type-check-toplevel.h | 57 +++++++++++- gcc/rust/typecheck/rust-hir-type-check.cc | 59 ++++++++---- gcc/rust/typecheck/rust-tycheck-dump.h | 6 ++ gcc/rust/typecheck/rust-tyty.cc | 4 +- gcc/rust/typecheck/rust-tyty.h | 22 +++-- gcc/testsuite/rust/compile/torture/union.rs | 35 +++++++ .../rust/compile/torture/union_union.rs | 27 ++++++ 21 files changed, 529 insertions(+), 55 deletions(-) create mode 100644 gcc/testsuite/rust/compile/torture/union.rs create mode 100644 gcc/testsuite/rust/compile/torture/union_union.rs diff --git a/gcc/rust/ast/rust-item.h b/gcc/rust/ast/rust-item.h index 6d29c5b2e22..5605b0bb79c 100644 --- a/gcc/rust/ast/rust-item.h +++ b/gcc/rust/ast/rust-item.h @@ -2489,6 +2489,15 @@ public: std::vector &get_variants () { return variants; } const std::vector &get_variants () const { return variants; } + void iterate (std::function cb) + { + for (auto &variant : variants) + { + if (!cb (variant)) + return; + } + } + std::vector > &get_generic_params () { return generic_params; @@ -2505,6 +2514,8 @@ public: return where_clause; } + Identifier get_identifier () const { return union_name; } + protected: /* Use covariance to implement clone function as returning this object * rather than base */ diff --git a/gcc/rust/backend/rust-compile-context.h b/gcc/rust/backend/rust-compile-context.h index 0aaf084f04f..8007c2fa00d 100644 --- a/gcc/rust/backend/rust-compile-context.h +++ b/gcc/rust/backend/rust-compile-context.h @@ -417,9 +417,13 @@ public: fields.push_back (std::move (f)); } - Btype *struct_type_record = ctx->get_backend ()->struct_type (fields); + Btype *type_record; + if (type.is_union ()) + type_record = ctx->get_backend ()->union_type (fields); + else + type_record = ctx->get_backend ()->struct_type (fields); Btype *named_struct - = ctx->get_backend ()->named_type (type.get_name (), struct_type_record, + = ctx->get_backend ()->named_type (type.get_name (), type_record, ctx->get_mappings ()->lookup_location ( type.get_ty_ref ())); diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h index 2a147abcf09..46582954895 100644 --- a/gcc/rust/backend/rust-compile-expr.h +++ b/gcc/rust/backend/rust-compile-expr.h @@ -80,7 +80,7 @@ public: } translated - = ctx->get_backend ()->constructor_expression (tuple_type, vals, + = ctx->get_backend ()->constructor_expression (tuple_type, vals, -1, expr.get_locus ()); } @@ -595,6 +595,7 @@ public: translated = ctx->get_backend ()->constructor_expression (type, vals, + struct_expr.union_index, struct_expr.get_locus ()); } diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc index 5ffd11a422c..aa9aa2d90ab 100644 --- a/gcc/rust/backend/rust-compile.cc +++ b/gcc/rust/backend/rust-compile.cc @@ -79,7 +79,7 @@ CompileExpr::visit (HIR::CallExpr &expr) }); translated - = ctx->get_backend ()->constructor_expression (type, vals, + = ctx->get_backend ()->constructor_expression (type, vals, -1, expr.get_locus ()); } else diff --git a/gcc/rust/hir/rust-ast-lower-item.h b/gcc/rust/hir/rust-ast-lower-item.h index 80ca29859fb..f168c7d3e88 100644 --- a/gcc/rust/hir/rust-ast-lower-item.h +++ b/gcc/rust/hir/rust-ast-lower-item.h @@ -193,6 +193,57 @@ public: struct_decl.get_locus ()); } + void visit (AST::Union &union_decl) override + { + std::vector > generic_params; + if (union_decl.has_generics ()) + { + generic_params + = lower_generic_params (union_decl.get_generic_params ()); + } + + std::vector > where_clause_items; + HIR::WhereClause where_clause (std::move (where_clause_items)); + HIR::Visibility vis = HIR::Visibility::create_public (); + + std::vector variants; + union_decl.iterate ([&] (AST::StructField &variant) mutable -> bool { + HIR::Visibility vis = HIR::Visibility::create_public (); + HIR::Type *type + = ASTLoweringType::translate (variant.get_field_type ().get ()); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, variant.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id ( + crate_num)); + + HIR::StructField translated_variant (mapping, variant.get_field_name (), + std::unique_ptr (type), + vis, variant.get_locus (), + variant.get_outer_attrs ()); + variants.push_back (std::move (translated_variant)); + return true; + }); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, union_decl.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id (crate_num)); + + translated + = new HIR::Union (mapping, union_decl.get_identifier (), vis, + std::move (generic_params), std::move (where_clause), + std::move (variants), union_decl.get_outer_attrs (), + union_decl.get_locus ()); + + mappings->insert_defid_mapping (mapping.get_defid (), translated); + mappings->insert_hir_item (mapping.get_crate_num (), mapping.get_hirid (), + translated); + mappings->insert_location (crate_num, mapping.get_hirid (), + union_decl.get_locus ()); + } + void visit (AST::StaticItem &var) override { HIR::Visibility vis = HIR::Visibility::create_public (); diff --git a/gcc/rust/hir/rust-ast-lower-stmt.h b/gcc/rust/hir/rust-ast-lower-stmt.h index 9df6b746bb7..2e97ca63a13 100644 --- a/gcc/rust/hir/rust-ast-lower-stmt.h +++ b/gcc/rust/hir/rust-ast-lower-stmt.h @@ -215,6 +215,59 @@ public: struct_decl.get_locus ()); } + void visit (AST::Union &union_decl) override + { + std::vector > generic_params; + if (union_decl.has_generics ()) + { + generic_params + = lower_generic_params (union_decl.get_generic_params ()); + } + + std::vector > where_clause_items; + HIR::WhereClause where_clause (std::move (where_clause_items)); + HIR::Visibility vis = HIR::Visibility::create_public (); + + std::vector variants; + union_decl.iterate ([&] (AST::StructField &variant) mutable -> bool { + HIR::Visibility vis = HIR::Visibility::create_public (); + HIR::Type *type + = ASTLoweringType::translate (variant.get_field_type ().get ()); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, variant.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id ( + crate_num)); + + // FIXME + // AST::StructField is missing Location info + Location variant_locus; + HIR::StructField translated_variant (mapping, variant.get_field_name (), + std::unique_ptr (type), + vis, variant_locus, + variant.get_outer_attrs ()); + variants.push_back (std::move (translated_variant)); + return true; + }); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, union_decl.get_node_id (), + mappings->get_next_hir_id (crate_num), + mappings->get_next_localdef_id (crate_num)); + + translated + = new HIR::Union (mapping, union_decl.get_identifier (), vis, + std::move (generic_params), std::move (where_clause), + std::move (variants), union_decl.get_outer_attrs (), + union_decl.get_locus ()); + + mappings->insert_hir_stmt (mapping.get_crate_num (), mapping.get_hirid (), + translated); + mappings->insert_location (crate_num, mapping.get_hirid (), + union_decl.get_locus ()); + } + void visit (AST::EmptyStmt &empty) override { auto crate_num = mappings->get_current_crate (); diff --git a/gcc/rust/hir/tree/rust-hir-expr.h b/gcc/rust/hir/tree/rust-hir-expr.h index 65c40d60bdf..8d815c5adc6 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.h +++ b/gcc/rust/hir/tree/rust-hir-expr.h @@ -1449,6 +1449,10 @@ public: // FIXME make unique_ptr StructBase *struct_base; + // For unions there is just one field, the index + // is set when type checking + int union_index = -1; + std::string as_string () const override; bool has_struct_base () const { return struct_base != nullptr; } @@ -1467,7 +1471,8 @@ public: // copy constructor with vector clone StructExprStructFields (StructExprStructFields const &other) - : StructExprStruct (other), struct_base (other.struct_base) + : StructExprStruct (other), struct_base (other.struct_base), + union_index (other.union_index) { fields.reserve (other.fields.size ()); for (const auto &e : other.fields) @@ -1479,6 +1484,7 @@ public: { StructExprStruct::operator= (other); struct_base = other.struct_base; + union_index = other.union_index; fields.reserve (other.fields.size ()); for (const auto &e : other.fields) diff --git a/gcc/rust/hir/tree/rust-hir-item.h b/gcc/rust/hir/tree/rust-hir-item.h index 7d976c5c991..182fe87a6b9 100644 --- a/gcc/rust/hir/tree/rust-hir-item.h +++ b/gcc/rust/hir/tree/rust-hir-item.h @@ -1989,10 +1989,26 @@ public: Union (Union &&other) = default; Union &operator= (Union &&other) = default; + std::vector > &get_generic_params () + { + return generic_params; + } + + Identifier get_identifier () const { return union_name; } + Location get_locus () const { return locus; } void accept_vis (HIRVisitor &vis) override; + void iterate (std::function cb) + { + for (auto &variant : variants) + { + if (!cb (variant)) + return; + } + } + protected: /* Use covariance to implement clone function as returning this object * rather than base */ diff --git a/gcc/rust/resolve/rust-ast-resolve-item.h b/gcc/rust/resolve/rust-ast-resolve-item.h index 539229d60fa..c67121d72f9 100644 --- a/gcc/rust/resolve/rust-ast-resolve-item.h +++ b/gcc/rust/resolve/rust-ast-resolve-item.h @@ -260,6 +260,28 @@ public: resolver->get_type_scope ().pop (); } + void visit (AST::Union &union_decl) override + { + NodeId scope_node_id = union_decl.get_node_id (); + resolver->get_type_scope ().push (scope_node_id); + + if (union_decl.has_generics ()) + { + for (auto &generic : union_decl.get_generic_params ()) + { + ResolveGenericParam::go (generic.get (), union_decl.get_node_id ()); + } + } + + union_decl.iterate ([&] (AST::StructField &field) mutable -> bool { + ResolveType::go (field.get_field_type ().get (), + union_decl.get_node_id ()); + return true; + }); + + resolver->get_type_scope ().pop (); + } + void visit (AST::StaticItem &var) override { ResolveType::go (var.get_type ().get (), var.get_node_id ()); diff --git a/gcc/rust/resolve/rust-ast-resolve-stmt.h b/gcc/rust/resolve/rust-ast-resolve-stmt.h index 210a9fc047d..b6044327b27 100644 --- a/gcc/rust/resolve/rust-ast-resolve-stmt.h +++ b/gcc/rust/resolve/rust-ast-resolve-stmt.h @@ -131,6 +131,38 @@ public: resolver->get_type_scope ().pop (); } + void visit (AST::Union &union_decl) override + { + auto path = CanonicalPath::new_seg (union_decl.get_node_id (), + union_decl.get_identifier ()); + resolver->get_type_scope ().insert ( + path, union_decl.get_node_id (), union_decl.get_locus (), false, + [&] (const CanonicalPath &, NodeId, Location locus) -> void { + RichLocation r (union_decl.get_locus ()); + r.add_range (locus); + rust_error_at (r, "redefined multiple times"); + }); + + NodeId scope_node_id = union_decl.get_node_id (); + resolver->get_type_scope ().push (scope_node_id); + + if (union_decl.has_generics ()) + { + for (auto &generic : union_decl.get_generic_params ()) + { + ResolveGenericParam::go (generic.get (), union_decl.get_node_id ()); + } + } + + union_decl.iterate ([&] (AST::StructField &field) mutable -> bool { + ResolveType::go (field.get_field_type ().get (), + union_decl.get_node_id ()); + return true; + }); + + resolver->get_type_scope ().pop (); + } + void visit (AST::Function &function) override { auto path = ResolveFunctionItemToCanonicalPath::resolve (function); diff --git a/gcc/rust/resolve/rust-ast-resolve-toplevel.h b/gcc/rust/resolve/rust-ast-resolve-toplevel.h index a042f5c3dcb..57a0534de48 100644 --- a/gcc/rust/resolve/rust-ast-resolve-toplevel.h +++ b/gcc/rust/resolve/rust-ast-resolve-toplevel.h @@ -81,6 +81,20 @@ public: }); } + void visit (AST::Union &union_decl) override + { + auto path + = prefix.append (CanonicalPath::new_seg (union_decl.get_node_id (), + union_decl.get_identifier ())); + resolver->get_type_scope ().insert ( + path, union_decl.get_node_id (), union_decl.get_locus (), false, + [&] (const CanonicalPath &, NodeId, Location locus) -> void { + RichLocation r (union_decl.get_locus ()); + r.add_range (locus); + rust_error_at (r, "redefined multiple times"); + }); + } + void visit (AST::StaticItem &var) override { auto path = prefix.append ( diff --git a/gcc/rust/rust-backend.h b/gcc/rust/rust-backend.h index be23fd3d852..4635796e953 100644 --- a/gcc/rust/rust-backend.h +++ b/gcc/rust/rust-backend.h @@ -178,6 +178,9 @@ public: // Get a struct type. virtual Btype *struct_type (const std::vector &fields) = 0; + // Get a union type. + virtual Btype *union_type (const std::vector &fields) = 0; + // Get an array type. virtual Btype *array_type (Btype *element_type, Bexpression *length) = 0; @@ -424,7 +427,7 @@ public: // corresponding fields in BTYPE. virtual Bexpression * constructor_expression (Btype *btype, const std::vector &vals, - Location) + int, Location) = 0; // Return an expression that constructs an array of BTYPE with INDEXES and diff --git a/gcc/rust/rust-gcc.cc b/gcc/rust/rust-gcc.cc index 44617a68d2a..3e47a7cba7a 100644 --- a/gcc/rust/rust-gcc.cc +++ b/gcc/rust/rust-gcc.cc @@ -265,6 +265,8 @@ public: Btype *struct_type (const std::vector &); + Btype *union_type (const std::vector &); + Btype *array_type (Btype *, Bexpression *); Btype *placeholder_pointer_type (const std::string &, Location, bool); @@ -377,7 +379,7 @@ public: Location); Bexpression *constructor_expression (Btype *, - const std::vector &, + const std::vector &, int, Location); Bexpression *array_constructor_expression (Btype *, @@ -531,7 +533,7 @@ private: Bfunction *make_function (tree t) { return new Bfunction (t); } - Btype *fill_in_struct (Btype *, const std::vector &); + Btype *fill_in_fields (Btype *, const std::vector &); Btype *fill_in_array (Btype *, Btype *, Bexpression *); @@ -1145,14 +1147,23 @@ Gcc_backend::function_ptr_type (Btype *result_type, Btype * Gcc_backend::struct_type (const std::vector &fields) { - return this->fill_in_struct (this->make_type (make_node (RECORD_TYPE)), + return this->fill_in_fields (this->make_type (make_node (RECORD_TYPE)), + fields); +} + +// Make a union type. + +Btype * +Gcc_backend::union_type (const std::vector &fields) +{ + return this->fill_in_fields (this->make_type (make_node (UNION_TYPE)), fields); } -// Fill in the fields of a struct type. +// Fill in the fields of a struct or union type. Btype * -Gcc_backend::fill_in_struct (Btype *fill, +Gcc_backend::fill_in_fields (Btype *fill, const std::vector &fields) { tree fill_tree = fill->get_tree (); @@ -1311,7 +1322,7 @@ Gcc_backend::set_placeholder_struct_type ( { tree t = placeholder->get_tree (); gcc_assert (TREE_CODE (t) == RECORD_TYPE && TYPE_FIELDS (t) == NULL_TREE); - Btype *r = this->fill_in_struct (placeholder, fields); + Btype *r = this->fill_in_fields (placeholder, fields); if (TYPE_NAME (t) != NULL_TREE) { @@ -1321,7 +1332,7 @@ Gcc_backend::set_placeholder_struct_type ( DECL_ORIGINAL_TYPE (TYPE_NAME (t)) = copy; TYPE_SIZE (copy) = NULL_TREE; Btype *bc = this->make_type (copy); - this->fill_in_struct (bc, fields); + this->fill_in_fields (bc, fields); delete bc; } @@ -1758,7 +1769,8 @@ Gcc_backend::struct_field_expression (Bexpression *bstruct, size_t index, if (struct_tree == error_mark_node || TREE_TYPE (struct_tree) == error_mark_node) return this->error_expression (); - gcc_assert (TREE_CODE (TREE_TYPE (struct_tree)) == RECORD_TYPE); + gcc_assert (TREE_CODE (TREE_TYPE (struct_tree)) == RECORD_TYPE + || TREE_CODE (TREE_TYPE (struct_tree)) == UNION_TYPE); tree field = TYPE_FIELDS (TREE_TYPE (struct_tree)); if (field == NULL_TREE) { @@ -2041,7 +2053,7 @@ Gcc_backend::lazy_boolean_expression (LazyBooleanOperator op, Bexpression *left, Bexpression * Gcc_backend::constructor_expression (Btype *btype, const std::vector &vals, - Location location) + int union_index, Location location) { tree type_tree = btype->get_tree (); if (type_tree == error_mark_node) @@ -2053,11 +2065,15 @@ Gcc_backend::constructor_expression (Btype *btype, tree sink = NULL_TREE; bool is_constant = true; tree field = TYPE_FIELDS (type_tree); - for (std::vector::const_iterator p = vals.begin (); - p != vals.end (); ++p, field = DECL_CHAIN (field)) + if (union_index != -1) { - gcc_assert (field != NULL_TREE); - tree val = (*p)->get_tree (); + gcc_assert (TREE_CODE (type_tree) == UNION_TYPE); + tree val = vals.front ()->get_tree (); + for (int i = 0; i < union_index; i++) + { + gcc_assert (field != NULL_TREE); + field = DECL_CHAIN (field); + } if (TREE_TYPE (field) == error_mark_node || val == error_mark_node || TREE_TYPE (val) == error_mark_node) return this->error_expression (); @@ -2070,17 +2086,49 @@ Gcc_backend::constructor_expression (Btype *btype, // would have been added as a map element for its // side-effects and construct an empty map. append_to_statement_list (val, &sink); - continue; } + else + { + constructor_elt empty = {NULL, NULL}; + constructor_elt *elt = init->quick_push (empty); + elt->index = field; + elt->value = this->convert_tree (TREE_TYPE (field), val, location); + if (!TREE_CONSTANT (elt->value)) + is_constant = false; + } + } + else + { + gcc_assert (TREE_CODE (type_tree) == RECORD_TYPE); + for (std::vector::const_iterator p = vals.begin (); + p != vals.end (); ++p, field = DECL_CHAIN (field)) + { + gcc_assert (field != NULL_TREE); + tree val = (*p)->get_tree (); + if (TREE_TYPE (field) == error_mark_node || val == error_mark_node + || TREE_TYPE (val) == error_mark_node) + return this->error_expression (); - constructor_elt empty = {NULL, NULL}; - constructor_elt *elt = init->quick_push (empty); - elt->index = field; - elt->value = this->convert_tree (TREE_TYPE (field), val, location); - if (!TREE_CONSTANT (elt->value)) - is_constant = false; + if (int_size_in_bytes (TREE_TYPE (field)) == 0) + { + // GIMPLE cannot represent indices of zero-sized types so + // trying to construct a map with zero-sized keys might lead + // to errors. Instead, we evaluate each expression that + // would have been added as a map element for its + // side-effects and construct an empty map. + append_to_statement_list (val, &sink); + continue; + } + + constructor_elt empty = {NULL, NULL}; + constructor_elt *elt = init->quick_push (empty); + elt->index = field; + elt->value = this->convert_tree (TREE_TYPE (field), val, location); + if (!TREE_CONSTANT (elt->value)) + is_constant = false; + } + gcc_assert (field == NULL_TREE); } - gcc_assert (field == NULL_TREE); tree ret = build_constructor (type_tree, init); if (is_constant) TREE_CONSTANT (ret) = 1; @@ -2781,6 +2829,7 @@ Gcc_backend::convert_tree (tree type_tree, tree expr_tree, Location location) || SCALAR_FLOAT_TYPE_P (type_tree) || COMPLEX_FLOAT_TYPE_P (type_tree)) return fold_convert_loc (location.gcc_location (), type_tree, expr_tree); else if (TREE_CODE (type_tree) == RECORD_TYPE + || TREE_CODE (type_tree) == UNION_TYPE || TREE_CODE (type_tree) == ARRAY_TYPE) { gcc_assert (int_size_in_bytes (type_tree) diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h index 1b6f47c1595..77cbc0628ef 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h +++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h @@ -159,7 +159,8 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), true, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::TUPLE_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); @@ -209,13 +210,66 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), false, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::STRUCT_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); infered = type; } + void visit (HIR::Union &union_decl) override + { + std::vector substitutions; + if (union_decl.has_generics ()) + { + for (auto &generic_param : union_decl.get_generic_params ()) + { + switch (generic_param.get ()->get_kind ()) + { + case HIR::GenericParam::GenericKind::LIFETIME: + // Skipping Lifetime completely until better handling. + break; + + case HIR::GenericParam::GenericKind::TYPE: { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), + param_type); + + substitutions.push_back (TyTy::SubstitutionParamMapping ( + static_cast (*generic_param), + param_type)); + } + break; + } + } + } + + std::vector variants; + union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool { + TyTy::BaseType *variant_type + = TypeCheckType::Resolve (variant.get_field_type ().get ()); + TyTy::StructFieldType *ty_variant + = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), + variant.get_field_name (), variant_type); + variants.push_back (ty_variant); + context->insert_type (variant.get_mappings (), + ty_variant->get_field_type ()); + return true; + }); + + TyTy::BaseType *type + = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (), + mappings->get_next_hir_id (), + union_decl.get_identifier (), + TyTy::ADTType::ADTKind::UNION, std::move (variants), + std::move (substitutions)); + + context->insert_type (union_decl.get_mappings (), type); + infered = type; + } + void visit (HIR::Function &function) override { std::vector substitutions; diff --git a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h index 18f3e725416..5b9757f6519 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-toplevel.h +++ b/gcc/rust/typecheck/rust-hir-type-check-toplevel.h @@ -94,7 +94,8 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), true, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::TUPLE_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); @@ -143,12 +144,64 @@ public: TyTy::BaseType *type = new TyTy::ADTType (struct_decl.get_mappings ().get_hirid (), mappings->get_next_hir_id (), - struct_decl.get_identifier (), false, + struct_decl.get_identifier (), + TyTy::ADTType::ADTKind::STRUCT_STRUCT, std::move (fields), std::move (substitutions)); context->insert_type (struct_decl.get_mappings (), type); } + void visit (HIR::Union &union_decl) override + { + std::vector substitutions; + if (union_decl.has_generics ()) + { + for (auto &generic_param : union_decl.get_generic_params ()) + { + switch (generic_param.get ()->get_kind ()) + { + case HIR::GenericParam::GenericKind::LIFETIME: + // Skipping Lifetime completely until better handling. + break; + + case HIR::GenericParam::GenericKind::TYPE: { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), + param_type); + + substitutions.push_back (TyTy::SubstitutionParamMapping ( + static_cast (*generic_param), + param_type)); + } + break; + } + } + } + + std::vector variants; + union_decl.iterate ([&] (HIR::StructField &variant) mutable -> bool { + TyTy::BaseType *variant_type + = TypeCheckType::Resolve (variant.get_field_type ().get ()); + TyTy::StructFieldType *ty_variant + = new TyTy::StructFieldType (variant.get_mappings ().get_hirid (), + variant.get_field_name (), variant_type); + variants.push_back (ty_variant); + context->insert_type (variant.get_mappings (), + ty_variant->get_field_type ()); + return true; + }); + + TyTy::BaseType *type + = new TyTy::ADTType (union_decl.get_mappings ().get_hirid (), + mappings->get_next_hir_id (), + union_decl.get_identifier (), + TyTy::ADTType::ADTKind::UNION, std::move (variants), + std::move (substitutions)); + + context->insert_type (union_decl.get_mappings (), type); + } + void visit (HIR::StaticItem &var) override { TyTy::BaseType *type = TypeCheckType::Resolve (var.get_type ()); diff --git a/gcc/rust/typecheck/rust-hir-type-check.cc b/gcc/rust/typecheck/rust-hir-type-check.cc index cb2896c0bb4..66adfcb5131 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.cc +++ b/gcc/rust/typecheck/rust-hir-type-check.cc @@ -180,7 +180,17 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) // check the arguments are all assigned and fix up the ordering if (fields_assigned.size () != struct_path_resolved->num_fields ()) { - if (!struct_expr.has_struct_base ()) + if (struct_def->is_union ()) + { + if (fields_assigned.size () != 1 || struct_expr.has_struct_base ()) + { + rust_error_at ( + struct_expr.get_locus (), + "union must have exactly one field variant assigned"); + return; + } + } + else if (!struct_expr.has_struct_base ()) { rust_error_at (struct_expr.get_locus (), "constructor is missing fields"); @@ -236,23 +246,40 @@ TypeCheckStructExpr::visit (HIR::StructExprStructFields &struct_expr) } } - // everything is ok, now we need to ensure all field values are ordered - // correctly. The GIMPLE backend uses a simple algorithm that assumes each - // assigned field in the constructor is in the same order as the field in - // the type - - std::vector > expr_fields - = struct_expr.get_fields_as_owner (); - for (auto &f : expr_fields) - f.release (); - - std::vector > ordered_fields; - for (size_t i = 0; i < adtFieldIndexToField.size (); i++) + if (struct_def->is_union ()) + { + // There is exactly one field in this constructor, we need to + // figure out the field index to make sure we initialize the + // right union field. + for (size_t i = 0; i < adtFieldIndexToField.size (); i++) + { + if (adtFieldIndexToField[i]) + { + struct_expr.union_index = i; + break; + } + } + rust_assert (struct_expr.union_index != -1); + } + else { - ordered_fields.push_back ( - std::unique_ptr (adtFieldIndexToField[i])); + // everything is ok, now we need to ensure all field values are ordered + // correctly. The GIMPLE backend uses a simple algorithm that assumes each + // assigned field in the constructor is in the same order as the field in + // the type + std::vector > expr_fields + = struct_expr.get_fields_as_owner (); + for (auto &f : expr_fields) + f.release (); + + std::vector > ordered_fields; + for (size_t i = 0; i < adtFieldIndexToField.size (); i++) + { + ordered_fields.push_back ( + std::unique_ptr (adtFieldIndexToField[i])); + } + struct_expr.set_fields_as_owner (std::move (ordered_fields)); } - struct_expr.set_fields_as_owner (std::move (ordered_fields)); resolved = struct_def; } diff --git a/gcc/rust/typecheck/rust-tycheck-dump.h b/gcc/rust/typecheck/rust-tycheck-dump.h index b80372b2a9c..cc2e3c01110 100644 --- a/gcc/rust/typecheck/rust-tycheck-dump.h +++ b/gcc/rust/typecheck/rust-tycheck-dump.h @@ -48,6 +48,12 @@ public: + "\n"; } + void visit (HIR::Union &union_decl) override + { + dump + += indent () + "union " + type_string (union_decl.get_mappings ()) + "\n"; + } + void visit (HIR::ImplBlock &impl_block) override { dump += indent () + "impl " diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index 1ca28fae061..6bac7647ec6 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -537,7 +537,7 @@ ADTType::clone () for (auto &f : fields) cloned_fields.push_back ((StructFieldType *) f->clone ()); - return new ADTType (get_ref (), get_ty_ref (), identifier, get_is_tuple (), + return new ADTType (get_ref (), get_ty_ref (), identifier, get_adt_kind (), cloned_fields, clone_substs (), used_arguments, get_combined_refs ()); } @@ -1999,7 +1999,7 @@ PlaceholderType::clone () void TypeCheckCallExpr::visit (ADTType &type) { - if (!type.get_is_tuple ()) + if (!type.is_tuple_struct ()) { rust_error_at ( call.get_locus (), diff --git a/gcc/rust/typecheck/rust-tyty.h b/gcc/rust/typecheck/rust-tyty.h index 336d42b15f9..46110e4a9a7 100644 --- a/gcc/rust/typecheck/rust-tyty.h +++ b/gcc/rust/typecheck/rust-tyty.h @@ -855,7 +855,15 @@ protected: class ADTType : public BaseType, public SubstitutionRef { public: - ADTType (HirId ref, std::string identifier, bool is_tuple, + enum ADTKind + { + STRUCT_STRUCT, + TUPLE_STRUCT, + UNION, + // ENUM ? + }; + + ADTType (HirId ref, std::string identifier, ADTKind adt_kind, std::vector fields, std::vector subst_refs, SubstitutionArgumentMappings generic_arguments @@ -863,10 +871,10 @@ public: std::set refs = std::set ()) : BaseType (ref, ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)), - identifier (identifier), fields (fields), is_tuple (is_tuple) + identifier (identifier), fields (fields), adt_kind (adt_kind) {} - ADTType (HirId ref, HirId ty_ref, std::string identifier, bool is_tuple, + ADTType (HirId ref, HirId ty_ref, std::string identifier, ADTKind adt_kind, std::vector fields, std::vector subst_refs, SubstitutionArgumentMappings generic_arguments @@ -874,10 +882,12 @@ public: std::set refs = std::set ()) : BaseType (ref, ty_ref, TypeKind::ADT, refs), SubstitutionRef (std::move (subst_refs), std::move (generic_arguments)), - identifier (identifier), fields (fields), is_tuple (is_tuple) + identifier (identifier), fields (fields), adt_kind (adt_kind) {} - bool get_is_tuple () { return is_tuple; } + ADTKind get_adt_kind () { return adt_kind; } + bool is_tuple_struct () { return adt_kind == TUPLE_STRUCT; } + bool is_union () { return adt_kind == UNION; } bool is_unit () const override { return this->fields.empty (); } @@ -964,7 +974,7 @@ public: private: std::string identifier; std::vector fields; - bool is_tuple; + ADTType::ADTKind adt_kind; }; class FnType : public BaseType, public SubstitutionRef diff --git a/gcc/testsuite/rust/compile/torture/union.rs b/gcc/testsuite/rust/compile/torture/union.rs new file mode 100644 index 00000000000..393e59115a7 --- /dev/null +++ b/gcc/testsuite/rust/compile/torture/union.rs @@ -0,0 +1,35 @@ +// { dg-do compile } +// { dg-options "-w" } + +union U +{ + f1: u8 +} + +union V +{ + f1: u8, + f2: u16, + f3: i32, +} + +struct S +{ + f1: U, + f2: V +} + +fn main () +{ + let u = U { f1: 16 }; + let v = V { f2: 365 }; + let s = S { f1: u, f2: v }; + let _v125 = unsafe + { let mut uv: u64; + uv = s.f1.f1 as u64; + uv += s.f2.f1 as u64; + uv += s.f2.f2 as u64; + uv -= s.f2.f3 as u64; + uv + }; +} diff --git a/gcc/testsuite/rust/compile/torture/union_union.rs b/gcc/testsuite/rust/compile/torture/union_union.rs new file mode 100644 index 00000000000..9feb145a692 --- /dev/null +++ b/gcc/testsuite/rust/compile/torture/union_union.rs @@ -0,0 +1,27 @@ +union union +{ + union: u32, + inion: i32, + u8ion: u8, + i64on: i64, + u64on: u64 +} + +pub fn main () +{ + let union = union { union: 2 }; + let inion = union { inion: -2 }; + let mut mnion = union { inion: -16 }; + let m1 = unsafe { mnion.union }; + unsafe { mnion.union = union.union }; + let m2 = unsafe { mnion.inion }; + let u1 = unsafe { union.union }; + let i1 = unsafe { union.inion }; + let u2 = unsafe { inion.union }; + let i2 = unsafe { inion.inion }; + let _r1 = u2 - u1 - m1; + let _r2 = i1 + i2 + m2; + let _u8 = unsafe { union.u8ion }; + let _i64 = unsafe { union.i64on }; + let _u64 = unsafe { union.u64on }; +} -- 2.32.0