public inbox for gcc-cvs@sourceware.org
help / color / mirror / Atom feed
* [gcc/devel/rust/master] Initial support operator overloading on [lang = "add"]
@ 2022-06-08 11:50 Thomas Schwinge
  0 siblings, 0 replies; only message in thread
From: Thomas Schwinge @ 2022-06-08 11:50 UTC (permalink / raw)
  To: gcc-cvs

https://gcc.gnu.org/g:c47d5cbdee9b701fb7753b44530fcb51f80b20fa

commit c47d5cbdee9b701fb7753b44530fcb51f80b20fa
Author: Philip Herron <philip.herron@embecosm.com>
Date:   Tue Nov 16 13:54:43 2021 +0000

    Initial support operator overloading on [lang = "add"]
    
    This change incorporates a few changes.
    
    1. Create new gcc/rust/backend/rust-compile-expr.cc to split out
       implementation code
    2. Create new type check context api calls:
       - TypeCheckContext::lookup_operator_overload
       - TypeCheckContext::insert_operator_overload
    3. Update type checking for ArithmeticOrLogicalExpr to look for any
       operator overloading
    
    When we are looking for operator overloads we must look up the associated
    lang item type for this particular operation, to resolve the operation to
    any known lang_items by looking up the specified lang_item to DefId. Then
    we must probe for the lang_item candidate for this particular lang_item
    DefId to see if we can resolve it to a method call. Then based on the
    autoderef rules in a MethodCallExpr we must verify that we don't end up
    in a recursive operator overload by checking that the current context
    is not the same as the actual operator overload for this type. Finally
    we mark this expression as an operator overload and set up everything as a
    resolved MethodCallExpr.
    
    Fixes #249

Diff:
---
 gcc/rust/Make-lang.in                              |   1 +
 gcc/rust/backend/rust-compile-expr.cc              | 316 +++++++++++++++++++++
 gcc/rust/backend/rust-compile-expr.h               |  26 +-
 gcc/rust/backend/rust-compile.cc                   | 219 +-------------
 gcc/rust/typecheck/rust-hir-type-check-expr.h      | 201 ++++++++++++-
 gcc/rust/typecheck/rust-hir-type-check.h           |  21 ++
 .../rust/execute/torture/operator_overload_1.rs    |  40 +++
 .../rust/execute/torture/operator_overload_2.rs    |  42 +++
 .../rust/execute/torture/operator_overload_3.rs    |  59 ++++
 9 files changed, 696 insertions(+), 229 deletions(-)

diff --git a/gcc/rust/Make-lang.in b/gcc/rust/Make-lang.in
index 15c0a8fa85d..f3302c2e3fb 100644
--- a/gcc/rust/Make-lang.in
+++ b/gcc/rust/Make-lang.in
@@ -90,6 +90,7 @@ GRS_OBJS = \
     rust/rust-hir-type-check-path.o \
     rust/rust-compile-intrinsic.o \
     rust/rust-base62.o \
+    rust/rust-compile-expr.o \
     $(END)
 # removed object files from here
 
diff --git a/gcc/rust/backend/rust-compile-expr.cc b/gcc/rust/backend/rust-compile-expr.cc
new file mode 100644
index 00000000000..c7941bc2014
--- /dev/null
+++ b/gcc/rust/backend/rust-compile-expr.cc
@@ -0,0 +1,316 @@
+// Copyright (C) 2020-2021 Free Software Foundation, Inc.
+
+// This file is part of GCC.
+
+// GCC is free software; you can redistribute it and/or modify it under
+// the terms of the GNU General Public License as published by the Free
+// Software Foundation; either version 3, or (at your option) any later
+// version.
+
+// GCC is distributed in the hope that it will be useful, but WITHOUT ANY
+// WARRANTY; without even the implied warranty of MERCHANTABILITY or
+// FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+// for more details.
+
+// You should have received a copy of the GNU General Public License
+// along with GCC; see the file COPYING3.  If not see
+// <http://www.gnu.org/licenses/>.
+
+#include "rust-compile.h"
+#include "rust-compile-item.h"
+#include "rust-compile-expr.h"
+#include "rust-compile-struct-field-expr.h"
+#include "rust-hir-trait-resolve.h"
+#include "rust-hir-path-probe.h"
+#include "rust-hir-type-bounds.h"
+#include "rust-hir-dot-operator.h"
+
+namespace Rust {
+namespace Compile {
+
+void
+CompileExpr::visit (HIR::ArithmeticOrLogicalExpr &expr)
+{
+  auto op = expr.get_expr_type ();
+  auto lhs = CompileExpr::Compile (expr.get_lhs (), ctx);
+  auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx);
+
+  // this might be an operator overload situation lets check
+  TyTy::FnType *fntype;
+  bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
+    expr.get_mappings ().get_hirid (), &fntype);
+  if (!is_op_overload)
+    {
+      translated = ctx->get_backend ()->arithmetic_or_logical_expression (
+	op, lhs, rhs, expr.get_locus ());
+      return;
+    }
+
+  // lookup the resolved name
+  NodeId resolved_node_id = UNKNOWN_NODEID;
+  if (!ctx->get_resolver ()->lookup_resolved_name (
+	expr.get_mappings ().get_nodeid (), &resolved_node_id))
+    {
+      rust_error_at (expr.get_locus (), "failed to lookup resolved MethodCall");
+      return;
+    }
+
+  // reverse lookup
+  HirId ref;
+  if (!ctx->get_mappings ()->lookup_node_to_hir (
+	expr.get_mappings ().get_crate_num (), resolved_node_id, &ref))
+    {
+      rust_fatal_error (expr.get_locus (), "reverse lookup failure");
+      return;
+    }
+
+  TyTy::BaseType *receiver = nullptr;
+  bool ok
+    = ctx->get_tyctx ()->lookup_receiver (expr.get_mappings ().get_hirid (),
+					  &receiver);
+  rust_assert (ok);
+
+  bool is_dyn_dispatch
+    = receiver->get_root ()->get_kind () == TyTy::TypeKind::DYNAMIC;
+  bool is_generic_receiver = receiver->get_kind () == TyTy::TypeKind::PARAM;
+  if (is_generic_receiver)
+    {
+      TyTy::ParamType *p = static_cast<TyTy::ParamType *> (receiver);
+      receiver = p->resolve ();
+    }
+
+  if (is_dyn_dispatch)
+    {
+      const TyTy::DynamicObjectType *dyn
+	= static_cast<const TyTy::DynamicObjectType *> (receiver->get_root ());
+
+      std::vector<HIR::Expr *> arguments;
+      arguments.push_back (expr.get_rhs ());
+
+      translated = compile_dyn_dispatch_call (dyn, receiver, fntype, lhs,
+					      arguments, expr.get_locus ());
+      return;
+    }
+
+  // lookup compiled functions since it may have already been compiled
+  HIR::PathIdentSegment segment_name ("add");
+  Bexpression *fn_expr
+    = resolve_method_address (fntype, ref, receiver, segment_name,
+			      expr.get_mappings (), expr.get_locus ());
+
+  // lookup the autoderef mappings
+  std::vector<Resolver::Adjustment> *adjustments = nullptr;
+  ok = ctx->get_tyctx ()->lookup_autoderef_mappings (
+    expr.get_mappings ().get_hirid (), &adjustments);
+  rust_assert (ok);
+
+  Bexpression *self = lhs;
+  for (auto &adjustment : *adjustments)
+    {
+      switch (adjustment.get_type ())
+	{
+	case Resolver::Adjustment::AdjustmentType::IMM_REF:
+	case Resolver::Adjustment::AdjustmentType::MUT_REF:
+	  self = ctx->get_backend ()->address_expression (
+	    self, expr.get_lhs ()->get_locus ());
+	  break;
+
+	case Resolver::Adjustment::AdjustmentType::DEREF_REF:
+	  Btype *expected_type
+	    = TyTyResolveCompile::compile (ctx, adjustment.get_expected ());
+	  self = ctx->get_backend ()->indirect_expression (
+	    expected_type, self, true, /* known_valid*/
+	    expr.get_lhs ()->get_locus ());
+	  break;
+	}
+    }
+
+  std::vector<Bexpression *> args;
+  args.push_back (self); // adjusted self
+  args.push_back (rhs);
+
+  auto fncontext = ctx->peek_fn ();
+  translated
+    = ctx->get_backend ()->call_expression (fncontext.fndecl, fn_expr, args,
+					    nullptr, expr.get_locus ());
+}
+
+Bexpression *
+CompileExpr::compile_dyn_dispatch_call (const TyTy::DynamicObjectType *dyn,
+					TyTy::BaseType *receiver,
+					TyTy::FnType *fntype,
+					Bexpression *receiver_ref,
+					std::vector<HIR::Expr *> &arguments,
+					Location expr_locus)
+{
+  size_t offs = 0;
+  const Resolver::TraitItemReference *ref = nullptr;
+  for (auto &bound : dyn->get_object_items ())
+    {
+      const Resolver::TraitItemReference *item = bound.first;
+      auto t = item->get_tyty ();
+      rust_assert (t->get_kind () == TyTy::TypeKind::FNDEF);
+      auto ft = static_cast<TyTy::FnType *> (t);
+
+      if (ft->get_id () == fntype->get_id ())
+	{
+	  ref = item;
+	  break;
+	}
+      offs++;
+    }
+
+  if (ref == nullptr)
+    return ctx->get_backend ()->error_expression ();
+
+  // get any indirection sorted out
+  if (receiver->get_kind () == TyTy::TypeKind::REF)
+    {
+      TyTy::ReferenceType *r = static_cast<TyTy::ReferenceType *> (receiver);
+      auto indirect_ty = r->get_base ();
+      Btype *indrect_compiled_tyty
+	= TyTyResolveCompile::compile (ctx, indirect_ty);
+
+      Bexpression *indirect
+	= ctx->get_backend ()->indirect_expression (indrect_compiled_tyty,
+						    receiver_ref, true,
+						    expr_locus);
+      receiver_ref = indirect;
+    }
+
+  // access the offs + 1 for the fnptr and offs=0 for the reciever obj
+  Bexpression *self_argument
+    = ctx->get_backend ()->struct_field_expression (receiver_ref, 0,
+						    expr_locus);
+
+  // access the vtable for the fn
+  Bexpression *fn_vtable_access
+    = ctx->get_backend ()->struct_field_expression (receiver_ref, offs + 1,
+						    expr_locus);
+
+  // cast it to the correct fntype
+  Btype *expected_fntype = TyTyResolveCompile::compile (ctx, fntype, true);
+  Bexpression *fn_convert_expr
+    = ctx->get_backend ()->convert_expression (expected_fntype,
+					       fn_vtable_access, expr_locus);
+
+  fncontext fnctx = ctx->peek_fn ();
+  Bblock *enclosing_scope = ctx->peek_enclosing_scope ();
+  bool is_address_taken = false;
+  Bstatement *ret_var_stmt = nullptr;
+  Bvariable *fn_convert_expr_tmp
+    = ctx->get_backend ()->temporary_variable (fnctx.fndecl, enclosing_scope,
+					       expected_fntype, fn_convert_expr,
+					       is_address_taken, expr_locus,
+					       &ret_var_stmt);
+  ctx->add_statement (ret_var_stmt);
+
+  std::vector<Bexpression *> args;
+  args.push_back (self_argument);
+  for (auto &argument : arguments)
+    {
+      Bexpression *compiled_expr = CompileExpr::Compile (argument, ctx);
+      args.push_back (compiled_expr);
+    }
+
+  Bexpression *fn_expr
+    = ctx->get_backend ()->var_expression (fn_convert_expr_tmp, expr_locus);
+
+  return ctx->get_backend ()->call_expression (fnctx.fndecl, fn_expr, args,
+					       nullptr, expr_locus);
+}
+
+Bexpression *
+CompileExpr::resolve_method_address (TyTy::FnType *fntype, HirId ref,
+				     TyTy::BaseType *receiver,
+				     HIR::PathIdentSegment &segment,
+				     Analysis::NodeMapping expr_mappings,
+				     Location expr_locus)
+{
+  // lookup compiled functions since it may have already been compiled
+  Bfunction *fn = nullptr;
+  if (ctx->lookup_function_decl (fntype->get_ty_ref (), &fn))
+    {
+      return ctx->get_backend ()->function_code_expression (fn, expr_locus);
+    }
+
+  // Now we can try and resolve the address since this might be a forward
+  // declared function, generic function which has not be compiled yet or
+  // its an not yet trait bound function
+  HIR::ImplItem *resolved_item
+    = ctx->get_mappings ()->lookup_hir_implitem (expr_mappings.get_crate_num (),
+						 ref, nullptr);
+  if (resolved_item != nullptr)
+    {
+      if (!fntype->has_subsititions_defined ())
+	return CompileInherentImplItem::Compile (receiver, resolved_item, ctx,
+						 true);
+
+      return CompileInherentImplItem::Compile (receiver, resolved_item, ctx,
+					       true, fntype);
+    }
+
+  // it might be resolved to a trait item
+  HIR::TraitItem *trait_item = ctx->get_mappings ()->lookup_hir_trait_item (
+    expr_mappings.get_crate_num (), ref);
+  HIR::Trait *trait = ctx->get_mappings ()->lookup_trait_item_mapping (
+    trait_item->get_mappings ().get_hirid ());
+
+  Resolver::TraitReference *trait_ref
+    = &Resolver::TraitReference::error_node ();
+  bool ok = ctx->get_tyctx ()->lookup_trait_reference (
+    trait->get_mappings ().get_defid (), &trait_ref);
+  rust_assert (ok);
+
+  // the type resolver can only resolve type bounds to their trait
+  // item so its up to us to figure out if this path should resolve
+  // to an trait-impl-block-item or if it can be defaulted to the
+  // trait-impl-item's definition
+
+  auto root = receiver->get_root ();
+  std::vector<Resolver::PathProbeCandidate> candidates
+    = Resolver::PathProbeType::Probe (root, segment, true, false, true);
+
+  if (candidates.size () == 0)
+    {
+      // this means we are defaulting back to the trait_item if
+      // possible
+      Resolver::TraitItemReference *trait_item_ref = nullptr;
+      bool ok = trait_ref->lookup_hir_trait_item (*trait_item, &trait_item_ref);
+      rust_assert (ok);				    // found
+      rust_assert (trait_item_ref->is_optional ()); // has definition
+
+      // FIXME Optional means it has a definition and an associated
+      // block which can be a default implementation, if it does not
+      // contain an implementation we should actually return
+      // error_mark_node
+
+      return CompileTraitItem::Compile (receiver,
+					trait_item_ref->get_hir_trait_item (),
+					ctx, fntype, true, expr_locus);
+    }
+  else
+    {
+      std::vector<Resolver::Adjustment> adjustments;
+      Resolver::PathProbeCandidate *candidate
+	= Resolver::MethodResolution::Select (candidates, root, adjustments);
+
+      // FIXME this will be a case to return error_mark_node, there is
+      // an error scenario where a Trait Foo has a method Bar, but this
+      // receiver does not implement this trait or has an incompatible
+      // implementation and we should just return error_mark_node
+      rust_assert (candidate != nullptr);
+      rust_assert (candidate->is_impl_candidate ());
+
+      HIR::ImplItem *impl_item = candidate->item.impl.impl_item;
+      if (!fntype->has_subsititions_defined ())
+	return CompileInherentImplItem::Compile (receiver, impl_item, ctx,
+						 true);
+
+      return CompileInherentImplItem::Compile (receiver, impl_item, ctx, true,
+					       fntype);
+    }
+}
+
+} // namespace Compile
+} // namespace Rust
diff --git a/gcc/rust/backend/rust-compile-expr.h b/gcc/rust/backend/rust-compile-expr.h
index f43db501a56..c9d3c304aab 100644
--- a/gcc/rust/backend/rust-compile-expr.h
+++ b/gcc/rust/backend/rust-compile-expr.h
@@ -448,17 +448,7 @@ public:
       constructor.push_back (translated_expr);
   }
 
-  void visit (HIR::ArithmeticOrLogicalExpr &expr) override
-  {
-    auto op = expr.get_expr_type ();
-    auto lhs = CompileExpr::Compile (expr.get_lhs (), ctx);
-    auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx);
-    auto location = expr.get_locus ();
-
-    translated
-      = ctx->get_backend ()->arithmetic_or_logical_expression (op, lhs, rhs,
-							       location);
-  }
+  void visit (HIR::ArithmeticOrLogicalExpr &expr) override;
 
   void visit (HIR::ComparisonExpr &expr) override
   {
@@ -999,6 +989,20 @@ public:
 						  expr.get_locus ());
   }
 
+protected:
+  Bexpression *compile_dyn_dispatch_call (const TyTy::DynamicObjectType *dyn,
+					  TyTy::BaseType *receiver,
+					  TyTy::FnType *fntype,
+					  Bexpression *receiver_ref,
+					  std::vector<HIR::Expr *> &arguments,
+					  Location expr_locus);
+
+  Bexpression *resolve_method_address (TyTy::FnType *fntype, HirId ref,
+				       TyTy::BaseType *receiver,
+				       HIR::PathIdentSegment &segment,
+				       Analysis::NodeMapping expr_mappings,
+				       Location expr_locus);
+
 private:
   CompileExpr (Context *ctx)
     : HIRCompileBase (ctx), translated (nullptr), capacity_expr (nullptr)
diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc
index e9aca2c34f1..e53993a8cbe 100644
--- a/gcc/rust/backend/rust-compile.cc
+++ b/gcc/rust/backend/rust-compile.cc
@@ -242,220 +242,21 @@ CompileExpr::visit (HIR::MethodCallExpr &expr)
       const TyTy::DynamicObjectType *dyn
 	= static_cast<const TyTy::DynamicObjectType *> (receiver->get_root ());
 
-      size_t offs = 0;
-      const Resolver::TraitItemReference *ref = nullptr;
-      for (auto &bound : dyn->get_object_items ())
-	{
-	  const Resolver::TraitItemReference *item = bound.first;
-	  auto t = item->get_tyty ();
-	  rust_assert (t->get_kind () == TyTy::TypeKind::FNDEF);
-	  auto ft = static_cast<TyTy::FnType *> (t);
-
-	  if (ft->get_id () == fntype->get_id ())
-	    {
-	      ref = item;
-	      break;
-	    }
-	  offs++;
-	}
-
-      if (ref == nullptr)
-	{
-	  translated = ctx->get_backend ()->error_expression ();
-	  return;
-	}
-
-      // get any indirection sorted out
-      auto receiver_ref = self;
-      if (receiver->get_kind () == TyTy::TypeKind::REF)
-	{
-	  TyTy::ReferenceType *r
-	    = static_cast<TyTy::ReferenceType *> (receiver);
-	  auto indirect_ty = r->get_base ();
-	  Btype *indrect_compiled_tyty
-	    = TyTyResolveCompile::compile (ctx, indirect_ty);
-
-	  Bexpression *indirect
-	    = ctx->get_backend ()->indirect_expression (indrect_compiled_tyty,
-							receiver_ref, true,
-							expr.get_locus ());
-	  receiver_ref = indirect;
-	}
+      std::vector<HIR::Expr *> arguments;
+      for (auto &arg : expr.get_arguments ())
+	arguments.push_back (arg.get ());
 
-      // access the offs + 1 for the fnptr and offs=0 for the reciever obj
-      Bexpression *self_argument
-	= ctx->get_backend ()->struct_field_expression (receiver_ref, 0,
-							expr.get_locus ());
-
-      // access the vtable for the fn
-      Bexpression *fn_vtable_access
-	= ctx->get_backend ()->struct_field_expression (receiver_ref, offs + 1,
-							expr.get_locus ());
-
-      // cast it to the correct fntype
-      Btype *expected_fntype = TyTyResolveCompile::compile (ctx, fntype, true);
-      Bexpression *fn_convert_expr
-	= ctx->get_backend ()->convert_expression (expected_fntype,
-						   fn_vtable_access,
-						   expr.get_locus ());
-
-      fncontext fnctx = ctx->peek_fn ();
-      Bblock *enclosing_scope = ctx->peek_enclosing_scope ();
-      bool is_address_taken = false;
-      Bstatement *ret_var_stmt = nullptr;
-
-      Bvariable *fn_convert_expr_tmp = ctx->get_backend ()->temporary_variable (
-	fnctx.fndecl, enclosing_scope, expected_fntype, fn_convert_expr,
-	is_address_taken, expr.get_locus (), &ret_var_stmt);
-      ctx->add_statement (ret_var_stmt);
-
-      std::vector<Bexpression *> args;
-      args.push_back (self_argument);
-      for (auto &argument : expr.get_arguments ())
-	{
-	  Bexpression *compiled_expr
-	    = CompileExpr::Compile (argument.get (), ctx);
-	  args.push_back (compiled_expr);
-	}
-
-      Bexpression *fn_expr
-	= ctx->get_backend ()->var_expression (fn_convert_expr_tmp,
-					       expr.get_locus ());
-
-      translated
-	= ctx->get_backend ()->call_expression (fnctx.fndecl, fn_expr, args,
-						nullptr, expr.get_locus ());
+      translated = compile_dyn_dispatch_call (dyn, receiver, fntype, self,
+					      arguments, expr.get_locus ());
       return;
     }
 
-  // address of compiled function
-  Bexpression *fn_expr = ctx->get_backend ()->error_expression ();
-
   // lookup compiled functions since it may have already been compiled
-  Bfunction *fn = nullptr;
-  if (ctx->lookup_function_decl (fntype->get_ty_ref (), &fn))
-    {
-      fn_expr
-	= ctx->get_backend ()->function_code_expression (fn, expr.get_locus ());
-    }
-  else
-    {
-      // Now we can try and resolve the address since this might be a forward
-      // declared function, generic function which has not be compiled yet or
-      // its an not yet trait bound function
-      HIR::ImplItem *resolved_item = ctx->get_mappings ()->lookup_hir_implitem (
-	expr.get_mappings ().get_crate_num (), ref, nullptr);
-      if (resolved_item == nullptr)
-	{
-	  // it might be resolved to a trait item
-	  HIR::TraitItem *trait_item
-	    = ctx->get_mappings ()->lookup_hir_trait_item (
-	      expr.get_mappings ().get_crate_num (), ref);
-	  HIR::Trait *trait = ctx->get_mappings ()->lookup_trait_item_mapping (
-	    trait_item->get_mappings ().get_hirid ());
-
-	  Resolver::TraitReference *trait_ref
-	    = &Resolver::TraitReference::error_node ();
-	  bool ok = ctx->get_tyctx ()->lookup_trait_reference (
-	    trait->get_mappings ().get_defid (), &trait_ref);
-	  rust_assert (ok);
-
-	  // the type resolver can only resolve type bounds to their trait
-	  // item so its up to us to figure out if this path should resolve
-	  // to an trait-impl-block-item or if it can be defaulted to the
-	  // trait-impl-item's definition
-
-	  auto root = receiver->get_root ();
-	  std::vector<Resolver::PathProbeCandidate> candidates
-	    = Resolver::PathProbeType::Probe (
-	      root, expr.get_method_name ().get_segment (), true, false, true);
-
-	  if (candidates.size () == 0)
-	    {
-	      // this means we are defaulting back to the trait_item if
-	      // possible
-	      Resolver::TraitItemReference *trait_item_ref = nullptr;
-	      bool ok = trait_ref->lookup_hir_trait_item (*trait_item,
-							  &trait_item_ref);
-	      rust_assert (ok);				    // found
-	      rust_assert (trait_item_ref->is_optional ()); // has definition
-
-	      // FIXME Optional means it has a definition and an associated
-	      // block which can be a default implementation, if it does not
-	      // contain an implementation we should actually return
-	      // error_mark_node
-
-	      TyTy::BaseType *self_type = nullptr;
-	      if (!ctx->get_tyctx ()->lookup_type (
-		    expr.get_receiver ()->get_mappings ().get_hirid (),
-		    &self_type))
-		{
-		  rust_error_at (expr.get_locus (),
-				 "failed to resolve type for self param");
-		  return;
-		}
-
-	      fn_expr = CompileTraitItem::Compile (
-		self_type, trait_item_ref->get_hir_trait_item (), ctx, fntype,
-		true, expr.get_locus ());
-	    }
-	  else
-	    {
-	      std::vector<Resolver::Adjustment> adjustments;
-	      Resolver::PathProbeCandidate *candidate
-		= Resolver::MethodResolution::Select (candidates, root,
-						      adjustments);
-
-	      // FIXME this will be a case to return error_mark_node, there is
-	      // an error scenario where a Trait Foo has a method Bar, but this
-	      // receiver does not implement this trait or has an incompatible
-	      // implementation and we should just return error_mark_node
-	      rust_assert (candidate != nullptr);
-	      rust_assert (candidate->is_impl_candidate ());
-
-	      HIR::ImplItem *impl_item = candidate->item.impl.impl_item;
-
-	      TyTy::BaseType *self_type = nullptr;
-	      if (!ctx->get_tyctx ()->lookup_type (
-		    expr.get_receiver ()->get_mappings ().get_hirid (),
-		    &self_type))
-		{
-		  rust_error_at (expr.get_locus (),
-				 "failed to resolve type for self param");
-		  return;
-		}
-
-	      if (!fntype->has_subsititions_defined ())
-		fn_expr
-		  = CompileInherentImplItem::Compile (self_type, impl_item, ctx,
-						      true);
-	      else
-		fn_expr
-		  = CompileInherentImplItem::Compile (self_type, impl_item, ctx,
-						      true, fntype);
-	    }
-	}
-      else
-	{
-	  TyTy::BaseType *self_type = nullptr;
-	  if (!ctx->get_tyctx ()->lookup_type (
-		expr.get_receiver ()->get_mappings ().get_hirid (), &self_type))
-	    {
-	      rust_error_at (expr.get_locus (),
-			     "failed to resolve type for self param");
-	      return;
-	    }
-
-	  if (!fntype->has_subsititions_defined ())
-	    fn_expr
-	      = CompileInherentImplItem::Compile (self_type, resolved_item, ctx,
-						  true);
-	  else
-	    fn_expr
-	      = CompileInherentImplItem::Compile (self_type, resolved_item, ctx,
-						  true, fntype);
-	}
-    }
+  HIR::PathExprSegment method_name = expr.get_method_name ();
+  HIR::PathIdentSegment segment_name = method_name.get_segment ();
+  Bexpression *fn_expr
+    = resolve_method_address (fntype, ref, receiver, segment_name,
+			      expr.get_mappings (), expr.get_locus ());
 
   // lookup the autoderef mappings
   std::vector<Resolver::Adjustment> *adjustments = nullptr;
diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h
index 2a6bae9ca7d..b332a684ad4 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-expr.h
+++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h
@@ -699,20 +699,203 @@ public:
     auto lhs = TypeCheckExpr::Resolve (expr.get_lhs (), false);
     auto rhs = TypeCheckExpr::Resolve (expr.get_rhs (), false);
 
-    bool valid_lhs = validate_arithmetic_type (lhs, expr.get_expr_type ());
-    bool valid_rhs = validate_arithmetic_type (rhs, expr.get_expr_type ());
-    bool valid = valid_lhs && valid_rhs;
-    if (!valid)
+    // in order to probe of the correct type paths we need the root type, which
+    // strips any references
+    const TyTy::BaseType *root = lhs->get_root ();
+
+    // look up lang item for arithmetic type
+    std::vector<PathProbeCandidate> candidates;
+    auto lang_item_type
+      = Analysis::RustLangItem::OperatorToLangItem (expr.get_expr_type ());
+    std::string associated_item_name
+      = Analysis::RustLangItem::ToString (lang_item_type);
+    DefId respective_lang_item_id = UNKNOWN_DEFID;
+    bool lang_item_defined
+      = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id);
+
+    // handle the case where we are within the impl block for this lang_item
+    // otherwise we end up with a recursive operator overload such as the i32
+    // operator overload trait
+    if (lang_item_defined)
+      {
+	TypeCheckContextItem &fn_context = context->peek_context ();
+	if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
+	  {
+	    auto &impl_item = fn_context.get_impl_item ();
+	    HIR::ImplBlock *parent = impl_item.first;
+	    HIR::Function *fn = impl_item.second;
+
+	    if (parent->has_trait_ref ()
+		&& fn->get_function_name ().compare (associated_item_name) == 0)
+	      {
+		TraitReference *trait_reference
+		  = TraitResolver::Lookup (*parent->get_trait_ref ().get ());
+		if (!trait_reference->is_error ())
+		  {
+		    TyTy::BaseType *lookup = nullptr;
+		    bool ok
+		      = context->lookup_type (fn->get_mappings ().get_hirid (),
+					      &lookup);
+		    rust_assert (ok);
+		    rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF);
+
+		    TyTy::FnType *fntype = static_cast<TyTy::FnType *> (lookup);
+		    rust_assert (fntype->is_method ());
+
+		    bool is_lang_item_impl
+		      = trait_reference->get_mappings ().get_defid ()
+			== respective_lang_item_id;
+		    bool self_is_lang_item_self
+		      = fntype->get_self_type ()->is_equal (*lhs);
+
+		    bool recursive_operator_overload
+		      = is_lang_item_impl && self_is_lang_item_self;
+		    lang_item_defined = !recursive_operator_overload;
+		  }
+	      }
+	  }
+      }
+
+    // probe for the lang-item
+    if (lang_item_defined)
+      {
+	bool receiver_is_type_param
+	  = root->get_kind () == TyTy::TypeKind::PARAM;
+	bool receiver_is_dyn = root->get_kind () == TyTy::TypeKind::DYNAMIC;
+
+	bool receiver_is_generic = receiver_is_type_param || receiver_is_dyn;
+	bool probe_bounds = true;
+	bool probe_impls = !receiver_is_generic;
+	bool ignore_mandatory_trait_items = !receiver_is_generic;
+
+	candidates = PathProbeType::Probe (
+	  root, HIR::PathIdentSegment (associated_item_name), probe_impls,
+	  probe_bounds, ignore_mandatory_trait_items, respective_lang_item_id);
+      }
+
+    bool have_implementation_for_lang_item = candidates.size () > 0;
+    if (!lang_item_defined || !have_implementation_for_lang_item)
+      {
+	bool valid_lhs = validate_arithmetic_type (lhs, expr.get_expr_type ());
+	bool valid_rhs = validate_arithmetic_type (rhs, expr.get_expr_type ());
+	bool valid = valid_lhs && valid_rhs;
+	if (!valid)
+	  {
+	    rust_error_at (expr.get_locus (),
+			   "cannot apply this operator to types %s and %s",
+			   lhs->as_string ().c_str (),
+			   rhs->as_string ().c_str ());
+	    return;
+	  }
+
+	infered = lhs->unify (rhs);
+	return;
+      }
+
+    // now its just like a method-call-expr
+    context->insert_receiver (expr.get_mappings ().get_hirid (), lhs);
+
+    // autoderef
+    std::vector<Adjustment> adjustments;
+    PathProbeCandidate *resolved_candidate
+      = MethodResolution::Select (candidates, lhs, adjustments);
+    rust_assert (resolved_candidate != nullptr);
+
+    // store the adjustments for code-generation to know what to do
+    context->insert_autoderef_mappings (expr.get_mappings ().get_hirid (),
+					std::move (adjustments));
+
+    TyTy::BaseType *lookup_tyty = resolved_candidate->ty;
+    NodeId resolved_node_id
+      = resolved_candidate->is_impl_candidate ()
+	  ? resolved_candidate->item.impl.impl_item->get_impl_mappings ()
+	      .get_nodeid ()
+	  : resolved_candidate->item.trait.item_ref->get_mappings ()
+	      .get_nodeid ();
+
+    rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
+    TyTy::BaseType *lookup = lookup_tyty;
+    TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
+    rust_assert (fn->is_method ());
+
+    if (root->get_kind () == TyTy::TypeKind::ADT)
+      {
+	const TyTy::ADTType *adt = static_cast<const TyTy::ADTType *> (root);
+	if (adt->has_substitutions () && fn->needs_substitution ())
+	  {
+	    // consider the case where we have:
+	    //
+	    // struct Foo<X,Y>(X,Y);
+	    //
+	    // impl<T> Foo<T, i32> {
+	    //   fn test<X>(self, a:X) -> (T,X) { (self.0, a) }
+	    // }
+	    //
+	    // In this case we end up with an fn type of:
+	    //
+	    // fn <T,X> test(self:Foo<T,i32>, a:X) -> (T,X)
+	    //
+	    // This means the instance or self we are calling this method for
+	    // will be substituted such that we can get the inherited type
+	    // arguments but then need to use the turbo fish if available or
+	    // infer the remaining arguments. Luckily rust does not allow for
+	    // default types GenericParams on impl blocks since these must
+	    // always be at the end of the list
+
+	    auto s = fn->get_self_type ()->get_root ();
+	    rust_assert (s->can_eq (adt, false, false));
+	    rust_assert (s->get_kind () == TyTy::TypeKind::ADT);
+	    const TyTy::ADTType *self_adt
+	      = static_cast<const TyTy::ADTType *> (s);
+
+	    // we need to grab the Self substitutions as the inherit type
+	    // parameters for this
+	    if (self_adt->needs_substitution ())
+	      {
+		rust_assert (adt->was_substituted ());
+
+		TyTy::SubstitutionArgumentMappings used_args_in_prev_segment
+		  = GetUsedSubstArgs::From (adt);
+
+		TyTy::SubstitutionArgumentMappings inherit_type_args
+		  = self_adt->solve_mappings_from_receiver_for_self (
+		    used_args_in_prev_segment);
+
+		// there may or may not be inherited type arguments
+		if (!inherit_type_args.is_error ())
+		  {
+		    // need to apply the inherited type arguments to the
+		    // function
+		    lookup = fn->handle_substitions (inherit_type_args);
+		  }
+	      }
+	  }
+      }
+
+    // type check the arguments
+    TyTy::FnType *type = static_cast<TyTy::FnType *> (lookup);
+    rust_assert (type->num_params () == 2);
+    auto fnparam = type->param_at (1);
+    auto resolved_argument_type = fnparam.second->unify (rhs);
+    if (resolved_argument_type->get_kind () == TyTy::TypeKind::ERROR)
       {
 	rust_error_at (expr.get_locus (),
-		       "cannot apply this operator to types %s and %s",
-		       lhs->as_string ().c_str (), rhs->as_string ().c_str ());
+		       "Type Resolution failure on parameter");
 	return;
       }
 
-    infered = lhs->unify (rhs);
-    infered->append_reference (lhs->get_ref ());
-    infered->append_reference (rhs->get_ref ());
+    // get the return type
+    TyTy::BaseType *function_ret_tyty = fn->get_return_type ()->clone ();
+
+    // store the expected fntype
+    context->insert_operator_overload (expr.get_mappings ().get_hirid (), type);
+
+    // set up the resolved name on the path
+    resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (),
+				    resolved_node_id);
+
+    // return the result of the function back
+    infered = function_ret_tyty;
   }
 
   void visit (HIR::ComparisonExpr &expr) override
diff --git a/gcc/rust/typecheck/rust-hir-type-check.h b/gcc/rust/typecheck/rust-hir-type-check.h
index b165f9cb041..1add4faa59a 100644
--- a/gcc/rust/typecheck/rust-hir-type-check.h
+++ b/gcc/rust/typecheck/rust-hir-type-check.h
@@ -296,6 +296,24 @@ public:
     return true;
   }
 
+  void insert_operator_overload (HirId id, TyTy::FnType *call_site)
+  {
+    auto it = operator_overloads.find (id);
+    rust_assert (it == operator_overloads.end ());
+
+    operator_overloads[id] = call_site;
+  }
+
+  bool lookup_operator_overload (HirId id, TyTy::FnType **call)
+  {
+    auto it = operator_overloads.find (id);
+    if (it == operator_overloads.end ())
+      return false;
+
+    *call = it->second;
+    return true;
+  }
+
 private:
   TypeCheckContext ();
 
@@ -318,6 +336,9 @@ private:
   // adjustment mappings
   std::map<HirId, std::vector<Adjustment>> autoderef_mappings;
 
+  // operator overloads
+  std::map<HirId, TyTy::FnType *> operator_overloads;
+
   // variants
   std::map<HirId, HirId> variants;
 };
diff --git a/gcc/testsuite/rust/execute/torture/operator_overload_1.rs b/gcc/testsuite/rust/execute/torture/operator_overload_1.rs
new file mode 100644
index 00000000000..e52b3947980
--- /dev/null
+++ b/gcc/testsuite/rust/execute/torture/operator_overload_1.rs
@@ -0,0 +1,40 @@
+/* { dg-output "3\n" } */
+extern "C" {
+    fn printf(s: *const i8, ...);
+}
+
+#[lang = "add"]
+pub trait Add<Rhs = Self> {
+    type Output;
+    // { dg-warning "unused name" "" { target *-*-* } .-1 }
+
+    fn add(self, rhs: Rhs) -> Self::Output;
+    // { dg-warning "unused name .self." "" { target *-*-* } .-1 }
+    // { dg-warning "unused name .rhs." "" { target *-*-* } .-2 }
+    // { dg-warning "unused name .Add::add." "" { target *-*-* } .-3 }
+}
+
+impl Add for i32 {
+    type Output = i32;
+
+    fn add(self, other: i32) -> i32 {
+        let res = self + other;
+
+        unsafe {
+            let a = "%i\n\0";
+            let b = a as *const str;
+            let c = b as *const i8;
+
+            printf(c, res);
+        }
+
+        res
+    }
+}
+
+fn main() -> i32 {
+    let a;
+    a = 1 + 2;
+
+    0
+}
diff --git a/gcc/testsuite/rust/execute/torture/operator_overload_2.rs b/gcc/testsuite/rust/execute/torture/operator_overload_2.rs
new file mode 100644
index 00000000000..9d5615d1381
--- /dev/null
+++ b/gcc/testsuite/rust/execute/torture/operator_overload_2.rs
@@ -0,0 +1,42 @@
+/* { dg-output "3\n" } */
+extern "C" {
+    fn printf(s: *const i8, ...);
+}
+
+#[lang = "add"]
+pub trait Add<Rhs = Self> {
+    type Output;
+    // { dg-warning "unused name" "" { target *-*-* } .-1 }
+
+    fn add(self, rhs: Rhs) -> Self::Output;
+    // { dg-warning "unused name .self." "" { target *-*-* } .-1 }
+    // { dg-warning "unused name .rhs." "" { target *-*-* } .-2 }
+    // { dg-warning "unused name .Add::add." "" { target *-*-* } .-3 }
+}
+
+struct Foo(i32);
+
+impl Add for Foo {
+    type Output = Foo;
+
+    fn add(self, other: Foo) -> Foo {
+        let res = Foo(self.0 + other.0);
+
+        unsafe {
+            let a = "%i\n\0";
+            let b = a as *const str;
+            let c = b as *const i8;
+
+            printf(c, res.0);
+        }
+
+        res
+    }
+}
+
+fn main() -> i32 {
+    let a;
+    a = Foo(1) + Foo(2);
+
+    0
+}
diff --git a/gcc/testsuite/rust/execute/torture/operator_overload_3.rs b/gcc/testsuite/rust/execute/torture/operator_overload_3.rs
new file mode 100644
index 00000000000..bd99b50a4fd
--- /dev/null
+++ b/gcc/testsuite/rust/execute/torture/operator_overload_3.rs
@@ -0,0 +1,59 @@
+/* { dg-output "3\n3\n" } */
+extern "C" {
+    fn printf(s: *const i8, ...);
+}
+
+#[lang = "add"]
+pub trait Add<Rhs = Self> {
+    type Output;
+    // { dg-warning "unused name" "" { target *-*-* } .-1 }
+
+    fn add(self, rhs: Rhs) -> Self::Output;
+    // { dg-warning "unused name .self." "" { target *-*-* } .-1 }
+    // { dg-warning "unused name .rhs." "" { target *-*-* } .-2 }
+    // { dg-warning "unused name .Add::add." "" { target *-*-* } .-3 }
+}
+
+impl Add for i32 {
+    type Output = i32;
+
+    fn add(self, other: i32) -> i32 {
+        let res = self + other;
+
+        unsafe {
+            let a = "%i\n\0";
+            let b = a as *const str;
+            let c = b as *const i8;
+
+            printf(c, res);
+        }
+
+        res
+    }
+}
+
+struct Foo(i32);
+impl Add for Foo {
+    type Output = Foo;
+
+    fn add(self, other: Foo) -> Foo {
+        let res = Foo(self.0 + other.0);
+
+        unsafe {
+            let a = "%i\n\0";
+            let b = a as *const str;
+            let c = b as *const i8;
+
+            printf(c, res.0);
+        }
+
+        res
+    }
+}
+
+fn main() -> i32 {
+    let a;
+    a = Foo(1) + Foo(2);
+
+    0
+}


^ permalink raw reply	[flat|nested] only message in thread

only message in thread, other threads:[~2022-06-08 11:50 UTC | newest]

Thread overview: (only message) (download: mbox.gz / follow: Atom feed)
-- links below jump to the message on this page --
2022-06-08 11:50 [gcc/devel/rust/master] Initial support operator overloading on [lang = "add"] Thomas Schwinge

This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox;
as well as URLs for read-only IMAP folder(s) and NNTP newsgroup(s).