public inbox for gcc-patches@gcc.gnu.org
 help / color / mirror / Atom feed
* [PATCH] SSA MATH: Support COND_LEN_FMA for floating-point math optimization
@ 2023-07-13  5:17 juzhe.zhong
  2023-07-13  7:53 ` Richard Biener
  0 siblings, 1 reply; 3+ messages in thread
From: juzhe.zhong @ 2023-07-13  5:17 UTC (permalink / raw)
  To: gcc-patches; +Cc: richard.sandiford, rguenther, Ju-Zhe Zhong

From: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>

Hi, Richard and Richi.

The previous patch added support for COND_LEN_* binary operations, but not
for COND_LEN_* ternary operations.

This patch adds support for COND_LEN_* ternary operations. Consider the following case:

/* Generate ternop_<T>: for each of the first N elements, accumulate the
   product A[i] * B[i] into DST[i].  The extra pointer C is part of the
   signature but is never read.  */
#define TEST_TYPE(T)                                                          \
  __attribute__ ((noipa)) void                                                \
  ternop_##T (T *__restrict dst, T *__restrict a, T *__restrict b,            \
              T *__restrict c, int n)                                         \
  {                                                                           \
    for (int k = 0; k < n; ++k)                                               \
      dst[k] += a[k] * b[k];                                                  \
  }

/* Instantiate the test body for double only.  */
#define TEST_ALL() TEST_TYPE (double)

TEST_ALL ()

Before this patch:
...
COND_LEN_MUL
COND_LEN_ADD

After this patch:
...
COND_LEN_FMA

gcc/ChangeLog:

        * genmatch.cc (commutative_op): Add COND_LEN_*.
        * internal-fn.cc (first_commutative_argument): Ditto.
        (conditional_internal_fn_code): Ditto.
        (can_interpret_as_conditional_op_p): Ditto.
        (internal_fn_len_index): Ditto.
        * internal-fn.h (can_interpret_as_conditional_op_p): Ditto.
        * tree-ssa-math-opts.cc (convert_mult_to_fma_1): Ditto.
        (convert_mult_to_fma): Ditto.
        (math_opts_dom_walker::after_dom_children): Ditto.

---
 gcc/genmatch.cc           | 13 +++++++
 gcc/internal-fn.cc        | 82 +++++++++++++++++++++++++++++++++++----
 gcc/internal-fn.h         |  2 +-
 gcc/tree-ssa-math-opts.cc | 57 ++++++++++++++++++++++++---
 4 files changed, 139 insertions(+), 15 deletions(-)

diff --git a/gcc/genmatch.cc b/gcc/genmatch.cc
index 5fceeec9780..2302f2a7ff0 100644
--- a/gcc/genmatch.cc
+++ b/gcc/genmatch.cc
@@ -559,6 +559,19 @@ commutative_op (id_base *id)
       case CFN_COND_FMS:
       case CFN_COND_FNMA:
       case CFN_COND_FNMS:
+      case CFN_COND_LEN_ADD:
+      case CFN_COND_LEN_MUL:
+      case CFN_COND_LEN_MIN:
+      case CFN_COND_LEN_MAX:
+      case CFN_COND_LEN_FMIN:
+      case CFN_COND_LEN_FMAX:
+      case CFN_COND_LEN_AND:
+      case CFN_COND_LEN_IOR:
+      case CFN_COND_LEN_XOR:
+      case CFN_COND_LEN_FMA:
+      case CFN_COND_LEN_FMS:
+      case CFN_COND_LEN_FNMA:
+      case CFN_COND_LEN_FNMS:
 	return 1;
 
       default:
diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
index c11123a1173..e47b1377ff8 100644
--- a/gcc/internal-fn.cc
+++ b/gcc/internal-fn.cc
@@ -4191,6 +4191,19 @@ first_commutative_argument (internal_fn fn)
     case IFN_COND_FMS:
     case IFN_COND_FNMA:
     case IFN_COND_FNMS:
+    case IFN_COND_LEN_ADD:
+    case IFN_COND_LEN_MUL:
+    case IFN_COND_LEN_MIN:
+    case IFN_COND_LEN_MAX:
+    case IFN_COND_LEN_FMIN:
+    case IFN_COND_LEN_FMAX:
+    case IFN_COND_LEN_AND:
+    case IFN_COND_LEN_IOR:
+    case IFN_COND_LEN_XOR:
+    case IFN_COND_LEN_FMA:
+    case IFN_COND_LEN_FMS:
+    case IFN_COND_LEN_FNMA:
+    case IFN_COND_LEN_FNMS:
       return 1;
 
     default:
@@ -4330,11 +4343,15 @@ conditional_internal_fn_code (internal_fn ifn)
 {
   switch (ifn)
     {
-#define CASE(CODE, IFN) case IFN_COND_##IFN: return CODE;
-      FOR_EACH_CODE_MAPPING(CASE)
+#define CASE(CODE, IFN)                                                        \
+  case IFN_COND_##IFN:                                                         \
+    return CODE;                                                               \
+  case IFN_COND_LEN_##IFN:                                                     \
+    return CODE;
+      FOR_EACH_CODE_MAPPING (CASE)
 #undef CASE
-    default:
-      return ERROR_MARK;
+      default:
+	return ERROR_MARK;
     }
 }
 
@@ -4433,6 +4450,18 @@ get_unconditional_internal_fn (internal_fn ifn)
    operating elementwise if the operands are vectors.  This includes
    the case of an all-true COND, so that the operation always happens.
 
+   When the operands are vectors, STMT can alternatively be interpreted as
+   an operation predicated by both a conditional mask and a loop control
+   length.  The equivalent C code is:
+
+     for (int i = 0; i < NUNITS; i++)
+      {
+	if (i < LEN + BIAS && COND[i])
+	  LHS[i] = A[i] CODE B[i];
+	else
+	  LHS[i] = ELSE[i];
+      }
+
    When returning true, set:
 
    - *COND_OUT to the condition COND, or to NULL_TREE if the condition
@@ -4440,13 +4469,18 @@ get_unconditional_internal_fn (internal_fn ifn)
    - *CODE_OUT to the tree code
    - OPS[I] to operand I of *CODE_OUT
    - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
-     condition is known to be all true.  */
+     condition is known to be all true.
+   - *LEN to the len argument of a COND_LEN_* operation, or to NULL_TREE.
+   - *BIAS to the bias argument of a COND_LEN_* operation, or to NULL_TREE.  */
 
 bool
 can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
 				   tree_code *code_out,
-				   tree (&ops)[3], tree *else_out)
+				   tree (&ops)[3], tree *else_out,
+				   tree *len, tree *bias)
 {
+  *len = NULL_TREE;
+  *bias = NULL_TREE;
   if (gassign *assign = dyn_cast <gassign *> (stmt))
     {
       *cond_out = NULL_TREE;
@@ -4462,18 +4496,26 @@ can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
       {
 	internal_fn ifn = gimple_call_internal_fn (call);
 	tree_code code = conditional_internal_fn_code (ifn);
+	int len_index = internal_fn_len_index (ifn);
+	int cond_nargs = len_index >= 0 ? 4 : 2;
 	if (code != ERROR_MARK)
 	  {
 	    *cond_out = gimple_call_arg (call, 0);
 	    *code_out = code;
-	    unsigned int nops = gimple_call_num_args (call) - 2;
+	    unsigned int nops = gimple_call_num_args (call) - cond_nargs;
 	    for (unsigned int i = 0; i < 3; ++i)
 	      ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
 	    *else_out = gimple_call_arg (call, nops + 1);
+	    if (len_index >= 0)
+	      {
+		*len = gimple_call_arg (call, len_index);
+		*bias = gimple_call_arg (call, len_index + 1);
+	      }
 	    if (integer_truep (*cond_out))
 	      {
 		*cond_out = NULL_TREE;
-		*else_out = NULL_TREE;
+		if (len_index < 0)
+		  *else_out = NULL_TREE;
 	      }
 	    return true;
 	  }
@@ -4561,8 +4603,32 @@ internal_fn_len_index (internal_fn fn)
 
     case IFN_LEN_MASK_GATHER_LOAD:
     case IFN_LEN_MASK_SCATTER_STORE:
+    case IFN_COND_LEN_FMA:
+    case IFN_COND_LEN_FMS:
+    case IFN_COND_LEN_FNMA:
+    case IFN_COND_LEN_FNMS:
       return 5;
 
+    case IFN_COND_LEN_ADD:
+    case IFN_COND_LEN_SUB:
+    case IFN_COND_LEN_MUL:
+    case IFN_COND_LEN_DIV:
+    case IFN_COND_LEN_MOD:
+    case IFN_COND_LEN_RDIV:
+    case IFN_COND_LEN_MIN:
+    case IFN_COND_LEN_MAX:
+    case IFN_COND_LEN_FMIN:
+    case IFN_COND_LEN_FMAX:
+    case IFN_COND_LEN_AND:
+    case IFN_COND_LEN_IOR:
+    case IFN_COND_LEN_XOR:
+    case IFN_COND_LEN_SHL:
+    case IFN_COND_LEN_SHR:
+      return 4;
+
+    case IFN_COND_LEN_NEG:
+      return 3;
+
     default:
       return -1;
     }
diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
index dd1bab0bddf..a5c3f4765ff 100644
--- a/gcc/internal-fn.h
+++ b/gcc/internal-fn.h
@@ -229,7 +229,7 @@ extern tree_code conditional_internal_fn_code (internal_fn);
 extern internal_fn get_unconditional_internal_fn (internal_fn);
 extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
 					       tree_code *, tree (&)[3],
-					       tree *);
+					       tree *, tree *, tree *);
 
 extern bool internal_load_fn_p (internal_fn);
 extern bool internal_store_fn_p (internal_fn);
diff --git a/gcc/tree-ssa-math-opts.cc b/gcc/tree-ssa-math-opts.cc
index 68fc518b1ab..4563d1ccf7f 100644
--- a/gcc/tree-ssa-math-opts.cc
+++ b/gcc/tree-ssa-math-opts.cc
@@ -3099,10 +3099,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
 	  negate_p = true;
 	}
 
-      tree cond, else_value, ops[3];
+      tree cond, else_value, ops[3], len, bias;
       tree_code code;
       if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
-					      ops, &else_value))
+					      ops, &else_value,
+					      &len, &bias))
 	gcc_unreachable ();
       addop = ops[0] == result ? ops[1] : ops[0];
 
@@ -3122,7 +3123,23 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
       if (seq)
 	gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
 
-      if (cond)
+      if (len)
+	{
+	  gcc_assert (gimple_call_internal_p (use_stmt));
+	  gcc_assert (bias);
+	  if (!cond)
+	    {
+	      internal_fn ifn = gimple_call_internal_fn (use_stmt);
+	      int mask_index = internal_fn_mask_index (ifn);
+	      gcc_assert (mask_index >= 0);
+	      tree mask = gimple_call_arg (use_stmt, mask_index);
+	      cond = build_minus_one_cst (TREE_TYPE (mask));
+	    }
+	  fma_stmt
+	    = gimple_build_call_internal (IFN_COND_LEN_FMA, 7, cond, mulop1,
+					  op2, addop, else_value, len, bias);
+	}
+      else if (cond)
 	fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
 					       op2, addop, else_value);
       else
@@ -3420,10 +3437,10 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
 	  negate_p = seen_negate_p = true;
 	}
 
-      tree cond, else_value, ops[3];
+      tree cond, else_value, ops[3], len, bias;
       tree_code code;
       if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
-					      &else_value))
+					      &else_value, &len, &bias))
 	return false;
 
       switch (code)
@@ -3446,7 +3463,19 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
 	{
 	  if (cond == result || else_value == result)
 	    return false;
-	  if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
+	  if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type)
+	      && !direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
+						  opt_type))
+	    return false;
+	}
+
+      if (len)
+	{
+	  gcc_assert (bias);
+	  if (else_value == result)
+	    return false;
+	  if (!direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
+					       opt_type))
 	    return false;
 	}
 
@@ -5632,6 +5661,22 @@ math_opts_dom_walker::after_dom_children (basic_block bb)
 		}
 	      break;
 
+	    case CFN_COND_LEN_MUL:
+	      if (convert_mult_to_fma (stmt,
+				       gimple_call_arg (stmt, 1),
+				       gimple_call_arg (stmt, 2),
+				       &fma_state,
+				       integer_truep (gimple_call_arg (stmt, 0))
+					 ? NULL_TREE
+					 : gimple_call_arg (stmt, 0)))
+
+		{
+		  gsi_remove (&gsi, true);
+		  release_defs (stmt);
+		  continue;
+		}
+	      break;
+
 	    case CFN_LAST:
 	      cancel_fma_deferring (&fma_state);
 	      break;
-- 
2.36.3


^ permalink raw reply	[flat|nested] 3+ messages in thread

* Re: [PATCH] SSA MATH: Support COND_LEN_FMA for floating-point math optimization
  2023-07-13  5:17 [PATCH] SSA MATH: Support COND_LEN_FMA for floating-point math optimization juzhe.zhong
@ 2023-07-13  7:53 ` Richard Biener
  2023-07-13  8:56   ` juzhe.zhong
  0 siblings, 1 reply; 3+ messages in thread
From: Richard Biener @ 2023-07-13  7:53 UTC (permalink / raw)
  To: Ju-Zhe Zhong; +Cc: gcc-patches, richard.sandiford

On Thu, 13 Jul 2023, juzhe.zhong@rivai.ai wrote:

> From: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>
> 
> Hi, Richard and Richi.
> 
> Previous patch we support COND_LEN_* binary operations. However, we didn't
> support COND_LEN_* ternary.
> 
> Now, this patch support COND_LEN_* ternary. Consider this following case:
> 
> #define TEST_TYPE(TYPE)                                                        \
>   __attribute__ ((noipa)) void ternop_##TYPE (TYPE *__restrict dst,            \
> 					      TYPE *__restrict a,              \
> 					      TYPE *__restrict b,\
>                 TYPE *__restrict c, int n)       \
>   {                                                                            \
>     for (int i = 0; i < n; i++)                                                \
>       dst[i] += a[i] * b[i];                                                     \
>   }
> 
> #define TEST_ALL() TEST_TYPE (double)
> 
> TEST_ALL ()
> 
> Before this patch:
> ...
> COND_LEN_MUL
> COND_LEN_ADD
> 
> Afther this patch:
> ...
> COND_LEN_FMA
> 
> gcc/ChangeLog:
> 
>         * genmatch.cc (commutative_op): Add COND_LEN_*
>         * internal-fn.cc (first_commutative_argument): Ditto.
>         (CASE): Ditto.
>         (get_unconditional_internal_fn): Ditto.
>         (can_interpret_as_conditional_op_p): Ditto.
>         (internal_fn_len_index): Ditto.
>         * internal-fn.h (can_interpret_as_conditional_op_p): Ditt.
>         * tree-ssa-math-opts.cc (convert_mult_to_fma_1): Ditto.
>         (convert_mult_to_fma): Ditto.
>         (math_opts_dom_walker::after_dom_children): Ditto.
> 
> ---
>  gcc/genmatch.cc           | 13 +++++++
>  gcc/internal-fn.cc        | 82 +++++++++++++++++++++++++++++++++++----
>  gcc/internal-fn.h         |  2 +-
>  gcc/tree-ssa-math-opts.cc | 57 ++++++++++++++++++++++++---
>  4 files changed, 139 insertions(+), 15 deletions(-)
> 
> diff --git a/gcc/genmatch.cc b/gcc/genmatch.cc
> index 5fceeec9780..2302f2a7ff0 100644
> --- a/gcc/genmatch.cc
> +++ b/gcc/genmatch.cc
> @@ -559,6 +559,19 @@ commutative_op (id_base *id)
>        case CFN_COND_FMS:
>        case CFN_COND_FNMA:
>        case CFN_COND_FNMS:
> +      case CFN_COND_LEN_ADD:
> +      case CFN_COND_LEN_MUL:
> +      case CFN_COND_LEN_MIN:
> +      case CFN_COND_LEN_MAX:
> +      case CFN_COND_LEN_FMIN:
> +      case CFN_COND_LEN_FMAX:
> +      case CFN_COND_LEN_AND:
> +      case CFN_COND_LEN_IOR:
> +      case CFN_COND_LEN_XOR:
> +      case CFN_COND_LEN_FMA:
> +      case CFN_COND_LEN_FMS:
> +      case CFN_COND_LEN_FNMA:
> +      case CFN_COND_LEN_FNMS:
>  	return 1;
>  
>        default:
> diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
> index c11123a1173..e47b1377ff8 100644
> --- a/gcc/internal-fn.cc
> +++ b/gcc/internal-fn.cc
> @@ -4191,6 +4191,19 @@ first_commutative_argument (internal_fn fn)
>      case IFN_COND_FMS:
>      case IFN_COND_FNMA:
>      case IFN_COND_FNMS:
> +    case IFN_COND_LEN_ADD:
> +    case IFN_COND_LEN_MUL:
> +    case IFN_COND_LEN_MIN:
> +    case IFN_COND_LEN_MAX:
> +    case IFN_COND_LEN_FMIN:
> +    case IFN_COND_LEN_FMAX:
> +    case IFN_COND_LEN_AND:
> +    case IFN_COND_LEN_IOR:
> +    case IFN_COND_LEN_XOR:
> +    case IFN_COND_LEN_FMA:
> +    case IFN_COND_LEN_FMS:
> +    case IFN_COND_LEN_FNMA:
> +    case IFN_COND_LEN_FNMS:
>        return 1;
>  
>      default:
> @@ -4330,11 +4343,15 @@ conditional_internal_fn_code (internal_fn ifn)
>  {
>    switch (ifn)
>      {
> -#define CASE(CODE, IFN) case IFN_COND_##IFN: return CODE;
> -      FOR_EACH_CODE_MAPPING(CASE)
> +#define CASE(CODE, IFN)                                                        \
> +  case IFN_COND_##IFN:                                                         \
> +    return CODE;                                                               \
> +  case IFN_COND_LEN_##IFN:                                                     \
> +    return CODE;
> +      FOR_EACH_CODE_MAPPING (CASE)
>  #undef CASE
> -    default:
> -      return ERROR_MARK;
> +      default:
> +	return ERROR_MARK;

either before or after white-space seems broken.

>      }
>  }
>  
> @@ -4433,6 +4450,18 @@ get_unconditional_internal_fn (internal_fn ifn)
>     operating elementwise if the operands are vectors.  This includes
>     the case of an all-true COND, so that the operation always happens.
>  
> +   There is an alternative approach to interpret the STMT when the operands
> +   are vectors which is the operation predicated by both conditional mask
> +   and loop control length, the equivalent C code:
> +
> +     for (int i = 0; i < NUNTIS; i++)
> +      {
> +	if (i < LEN + BIAS && COND[i])
> +	  LHS[i] = A[i] CODE B[i];
> +	else
> +	  LHS[i] = ELSE[i];
> +      }
> +
>     When returning true, set:
>  
>     - *COND_OUT to the condition COND, or to NULL_TREE if the condition
> @@ -4440,13 +4469,18 @@ get_unconditional_internal_fn (internal_fn ifn)
>     - *CODE_OUT to the tree code
>     - OPS[I] to operand I of *CODE_OUT
>     - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
> -     condition is known to be all true.  */
> +     condition is known to be all true.
> +   - *LEN to the len argument if it COND_LEN_* operations or to NULL_TREE.
> +   - *BIAS to the bias argument if it COND_LEN_* operations or to NULL_TREE.  */
>  
>  bool
>  can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
>  				   tree_code *code_out,
> -				   tree (&ops)[3], tree *else_out)
> +				   tree (&ops)[3], tree *else_out,
> +				   tree *len, tree *bias)
>  {
> +  *len = NULL_TREE;
> +  *bias = NULL_TREE;
>    if (gassign *assign = dyn_cast <gassign *> (stmt))
>      {
>        *cond_out = NULL_TREE;
> @@ -4462,18 +4496,26 @@ can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
>        {
>  	internal_fn ifn = gimple_call_internal_fn (call);
>  	tree_code code = conditional_internal_fn_code (ifn);
> +	int len_index = internal_fn_len_index (ifn);
> +	int cond_nargs = len_index >= 0 ? 4 : 2;
>  	if (code != ERROR_MARK)
>  	  {
>  	    *cond_out = gimple_call_arg (call, 0);
>  	    *code_out = code;
> -	    unsigned int nops = gimple_call_num_args (call) - 2;
> +	    unsigned int nops = gimple_call_num_args (call) - cond_nargs;
>  	    for (unsigned int i = 0; i < 3; ++i)
>  	      ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
>  	    *else_out = gimple_call_arg (call, nops + 1);
> +	    if (len_index >= 0)
> +	      {
> +		*len = gimple_call_arg (call, len_index);
> +		*bias = gimple_call_arg (call, len_index + 1);
> +	      }
>  	    if (integer_truep (*cond_out))
>  	      {
>  		*cond_out = NULL_TREE;
> -		*else_out = NULL_TREE;
> +		if (len_index < 0)
> +		  *else_out = NULL_TREE;
>  	      }
>  	    return true;
>  	  }
> @@ -4561,8 +4603,32 @@ internal_fn_len_index (internal_fn fn)
>  
>      case IFN_LEN_MASK_GATHER_LOAD:
>      case IFN_LEN_MASK_SCATTER_STORE:
> +    case IFN_COND_LEN_FMA:
> +    case IFN_COND_LEN_FMS:
> +    case IFN_COND_LEN_FNMA:
> +    case IFN_COND_LEN_FNMS:
>        return 5;
>  
> +    case IFN_COND_LEN_ADD:
> +    case IFN_COND_LEN_SUB:
> +    case IFN_COND_LEN_MUL:
> +    case IFN_COND_LEN_DIV:
> +    case IFN_COND_LEN_MOD:
> +    case IFN_COND_LEN_RDIV:
> +    case IFN_COND_LEN_MIN:
> +    case IFN_COND_LEN_MAX:
> +    case IFN_COND_LEN_FMIN:
> +    case IFN_COND_LEN_FMAX:
> +    case IFN_COND_LEN_AND:
> +    case IFN_COND_LEN_IOR:
> +    case IFN_COND_LEN_XOR:
> +    case IFN_COND_LEN_SHL:
> +    case IFN_COND_LEN_SHR:
> +      return 4;
> +
> +    case IFN_COND_LEN_NEG:
> +      return 3;
> +
>      default:
>        return -1;
>      }
> diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
> index dd1bab0bddf..a5c3f4765ff 100644
> --- a/gcc/internal-fn.h
> +++ b/gcc/internal-fn.h
> @@ -229,7 +229,7 @@ extern tree_code conditional_internal_fn_code (internal_fn);
>  extern internal_fn get_unconditional_internal_fn (internal_fn);
>  extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
>  					       tree_code *, tree (&)[3],
> -					       tree *);
> +					       tree *, tree *, tree *);
>  
>  extern bool internal_load_fn_p (internal_fn);
>  extern bool internal_store_fn_p (internal_fn);
> diff --git a/gcc/tree-ssa-math-opts.cc b/gcc/tree-ssa-math-opts.cc
> index 68fc518b1ab..4563d1ccf7f 100644
> --- a/gcc/tree-ssa-math-opts.cc
> +++ b/gcc/tree-ssa-math-opts.cc
> @@ -3099,10 +3099,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
>  	  negate_p = true;
>  	}
>  
> -      tree cond, else_value, ops[3];
> +      tree cond, else_value, ops[3], len, bias;
>        tree_code code;
>        if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
> -					      ops, &else_value))
> +					      ops, &else_value,
> +					      &len, &bias))
>  	gcc_unreachable ();
>        addop = ops[0] == result ? ops[1] : ops[0];
>  
> @@ -3122,7 +3123,23 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
>        if (seq)
>  	gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
>  
> -      if (cond)
> +      if (len)
> +	{
> +	  gcc_assert (gimple_call_internal_p (use_stmt));
> +	  gcc_assert (bias);
> +	  if (!cond)

can len without cond happen here?

> +	    {
> +	      internal_fn ifn = gimple_call_internal_fn (use_stmt);
> +	      int mask_index = internal_fn_mask_index (ifn);
> +	      gcc_assert (mask_index >= 0);
> +	      tree mask = gimple_call_arg (use_stmt, mask_index);
> +	      cond = build_minus_one_cst (TREE_TYPE (mask));
> +	    }
> +	  fma_stmt
> +	    = gimple_build_call_internal (IFN_COND_LEN_FMA, 7, cond, mulop1,
> +					  op2, addop, else_value, len, bias);
> +	}
> +      else if (cond)
>  	fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
>  					       op2, addop, else_value);
>        else
> @@ -3420,10 +3437,10 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
>  	  negate_p = seen_negate_p = true;
>  	}
>  
> -      tree cond, else_value, ops[3];
> +      tree cond, else_value, ops[3], len, bias;
>        tree_code code;
>        if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
> -					      &else_value))
> +					      &else_value, &len, &bias))
>  	return false;
>  
>        switch (code)
> @@ -3446,7 +3463,19 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
>  	{
>  	  if (cond == result || else_value == result)
>  	    return false;
> -	  if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
> +	  if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type)
> +	      && !direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
> +						  opt_type))
> +	    return false;
> +	}
> +
> +      if (len)
> +	{
> +	  gcc_assert (bias);
> +	  if (else_value == result)
> +	    return false;
> +	  if (!direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
> +					       opt_type))
>  	    return false;
>  	}
>  
> @@ -5632,6 +5661,22 @@ math_opts_dom_walker::after_dom_children (basic_block bb)
>  		}
>  	      break;
>  
> +	    case CFN_COND_LEN_MUL:
> +	      if (convert_mult_to_fma (stmt,
> +				       gimple_call_arg (stmt, 1),
> +				       gimple_call_arg (stmt, 2),
> +				       &fma_state,
> +				       integer_truep (gimple_call_arg (stmt, 0))
> +					 ? NULL_TREE
> +					 : gimple_call_arg (stmt, 0)))

because of this?  why not pass it through?

Otherwise looks OK.

Thanks,
Richard.

> +
> +		{
> +		  gsi_remove (&gsi, true);
> +		  release_defs (stmt);
> +		  continue;
> +		}
> +	      break;
> +
>  	    case CFN_LAST:
>  	      cancel_fma_deferring (&fma_state);
>  	      break;
> 

-- 
Richard Biener <rguenther@suse.de>
SUSE Software Solutions Germany GmbH, Frankenstrasse 146, 90461 Nuernberg,
Germany; GF: Ivo Totev, Andrew Myers, Andrew McDonald, Boudien Moerman;
HRB 36809 (AG Nuernberg)

^ permalink raw reply	[flat|nested] 3+ messages in thread

* Re: Re: [PATCH] SSA MATH: Support COND_LEN_FMA for floating-point math optimization
  2023-07-13  7:53 ` Richard Biener
@ 2023-07-13  8:56   ` juzhe.zhong
  0 siblings, 0 replies; 3+ messages in thread
From: juzhe.zhong @ 2023-07-13  8:56 UTC (permalink / raw)
  To: rguenther; +Cc: gcc-patches, richard.sandiford

[-- Attachment #1: Type: text/plain, Size: 12596 bytes --]

Hi, Richard.

>> either before or after white-space seems broken.
I used clang-format with the style from gcc/contrib/format.
I have since adjusted it manually; could you check whether the formatting issue is still there?

I have addressed all your comments in the V2 patch:
https://gcc.gnu.org/pipermail/gcc-patches/2023-July/624395.html 

Does it look more reasonable?

Thanks.


juzhe.zhong@rivai.ai
 
From: Richard Biener
Date: 2023-07-13 15:53
To: Ju-Zhe Zhong
CC: gcc-patches; richard.sandiford
Subject: Re: [PATCH] SSA MATH: Support COND_LEN_FMA for floating-point math optimization
On Thu, 13 Jul 2023, juzhe.zhong@rivai.ai wrote:
 
> From: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>
> 
> Hi, Richard and Richi.
> 
> Previous patch we support COND_LEN_* binary operations. However, we didn't
> support COND_LEN_* ternary.
> 
> Now, this patch support COND_LEN_* ternary. Consider this following case:
> 
> #define TEST_TYPE(TYPE)                                                        \
>   __attribute__ ((noipa)) void ternop_##TYPE (TYPE *__restrict dst,            \
>       TYPE *__restrict a,              \
>       TYPE *__restrict b,\
>                 TYPE *__restrict c, int n)       \
>   {                                                                            \
>     for (int i = 0; i < n; i++)                                                \
>       dst[i] += a[i] * b[i];                                                     \
>   }
> 
> #define TEST_ALL() TEST_TYPE (double)
> 
> TEST_ALL ()
> 
> Before this patch:
> ...
> COND_LEN_MUL
> COND_LEN_ADD
> 
> Afther this patch:
> ...
> COND_LEN_FMA
> 
> gcc/ChangeLog:
> 
>         * genmatch.cc (commutative_op): Add COND_LEN_*
>         * internal-fn.cc (first_commutative_argument): Ditto.
>         (CASE): Ditto.
>         (get_unconditional_internal_fn): Ditto.
>         (can_interpret_as_conditional_op_p): Ditto.
>         (internal_fn_len_index): Ditto.
>         * internal-fn.h (can_interpret_as_conditional_op_p): Ditt.
>         * tree-ssa-math-opts.cc (convert_mult_to_fma_1): Ditto.
>         (convert_mult_to_fma): Ditto.
>         (math_opts_dom_walker::after_dom_children): Ditto.
> 
> ---
>  gcc/genmatch.cc           | 13 +++++++
>  gcc/internal-fn.cc        | 82 +++++++++++++++++++++++++++++++++++----
>  gcc/internal-fn.h         |  2 +-
>  gcc/tree-ssa-math-opts.cc | 57 ++++++++++++++++++++++++---
>  4 files changed, 139 insertions(+), 15 deletions(-)
> 
> diff --git a/gcc/genmatch.cc b/gcc/genmatch.cc
> index 5fceeec9780..2302f2a7ff0 100644
> --- a/gcc/genmatch.cc
> +++ b/gcc/genmatch.cc
> @@ -559,6 +559,19 @@ commutative_op (id_base *id)
>        case CFN_COND_FMS:
>        case CFN_COND_FNMA:
>        case CFN_COND_FNMS:
> +      case CFN_COND_LEN_ADD:
> +      case CFN_COND_LEN_MUL:
> +      case CFN_COND_LEN_MIN:
> +      case CFN_COND_LEN_MAX:
> +      case CFN_COND_LEN_FMIN:
> +      case CFN_COND_LEN_FMAX:
> +      case CFN_COND_LEN_AND:
> +      case CFN_COND_LEN_IOR:
> +      case CFN_COND_LEN_XOR:
> +      case CFN_COND_LEN_FMA:
> +      case CFN_COND_LEN_FMS:
> +      case CFN_COND_LEN_FNMA:
> +      case CFN_COND_LEN_FNMS:
>  return 1;
>  
>        default:
> diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
> index c11123a1173..e47b1377ff8 100644
> --- a/gcc/internal-fn.cc
> +++ b/gcc/internal-fn.cc
> @@ -4191,6 +4191,19 @@ first_commutative_argument (internal_fn fn)
>      case IFN_COND_FMS:
>      case IFN_COND_FNMA:
>      case IFN_COND_FNMS:
> +    case IFN_COND_LEN_ADD:
> +    case IFN_COND_LEN_MUL:
> +    case IFN_COND_LEN_MIN:
> +    case IFN_COND_LEN_MAX:
> +    case IFN_COND_LEN_FMIN:
> +    case IFN_COND_LEN_FMAX:
> +    case IFN_COND_LEN_AND:
> +    case IFN_COND_LEN_IOR:
> +    case IFN_COND_LEN_XOR:
> +    case IFN_COND_LEN_FMA:
> +    case IFN_COND_LEN_FMS:
> +    case IFN_COND_LEN_FNMA:
> +    case IFN_COND_LEN_FNMS:
>        return 1;
>  
>      default:
> @@ -4330,11 +4343,15 @@ conditional_internal_fn_code (internal_fn ifn)
>  {
>    switch (ifn)
>      {
> -#define CASE(CODE, IFN) case IFN_COND_##IFN: return CODE;
> -      FOR_EACH_CODE_MAPPING(CASE)
> +#define CASE(CODE, IFN)                                                        \
> +  case IFN_COND_##IFN:                                                         \
> +    return CODE;                                                               \
> +  case IFN_COND_LEN_##IFN:                                                     \
> +    return CODE;
> +      FOR_EACH_CODE_MAPPING (CASE)
>  #undef CASE
> -    default:
> -      return ERROR_MARK;
> +      default:
> + return ERROR_MARK;
 
either before or after white-space seems broken.
 
>      }
>  }
>  
> @@ -4433,6 +4450,18 @@ get_unconditional_internal_fn (internal_fn ifn)
>     operating elementwise if the operands are vectors.  This includes
>     the case of an all-true COND, so that the operation always happens.
>  
> +   There is an alternative approach to interpret the STMT when the operands
> +   are vectors which is the operation predicated by both conditional mask
> +   and loop control length, the equivalent C code:
> +
> +     for (int i = 0; i < NUNTIS; i++)
> +      {
> + if (i < LEN + BIAS && COND[i])
> +   LHS[i] = A[i] CODE B[i];
> + else
> +   LHS[i] = ELSE[i];
> +      }
> +
>     When returning true, set:
>  
>     - *COND_OUT to the condition COND, or to NULL_TREE if the condition
> @@ -4440,13 +4469,18 @@ get_unconditional_internal_fn (internal_fn ifn)
>     - *CODE_OUT to the tree code
>     - OPS[I] to operand I of *CODE_OUT
>     - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
> -     condition is known to be all true.  */
> +     condition is known to be all true.
> +   - *LEN to the len argument if it COND_LEN_* operations or to NULL_TREE.
> +   - *BIAS to the bias argument if it COND_LEN_* operations or to NULL_TREE.  */
>  
>  bool
>  can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
>     tree_code *code_out,
> -    tree (&ops)[3], tree *else_out)
> +    tree (&ops)[3], tree *else_out,
> +    tree *len, tree *bias)
>  {
> +  *len = NULL_TREE;
> +  *bias = NULL_TREE;
>    if (gassign *assign = dyn_cast <gassign *> (stmt))
>      {
>        *cond_out = NULL_TREE;
> @@ -4462,18 +4496,26 @@ can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
>        {
>  internal_fn ifn = gimple_call_internal_fn (call);
>  tree_code code = conditional_internal_fn_code (ifn);
> + int len_index = internal_fn_len_index (ifn);
> + int cond_nargs = len_index >= 0 ? 4 : 2;
>  if (code != ERROR_MARK)
>    {
>      *cond_out = gimple_call_arg (call, 0);
>      *code_out = code;
> -     unsigned int nops = gimple_call_num_args (call) - 2;
> +     unsigned int nops = gimple_call_num_args (call) - cond_nargs;
>      for (unsigned int i = 0; i < 3; ++i)
>        ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
>      *else_out = gimple_call_arg (call, nops + 1);
> +     if (len_index >= 0)
> +       {
> + *len = gimple_call_arg (call, len_index);
> + *bias = gimple_call_arg (call, len_index + 1);
> +       }
>      if (integer_truep (*cond_out))
>        {
>  *cond_out = NULL_TREE;
> - *else_out = NULL_TREE;
> + if (len_index < 0)
> +   *else_out = NULL_TREE;
>        }
>      return true;
>    }
> @@ -4561,8 +4603,32 @@ internal_fn_len_index (internal_fn fn)
>  
>      case IFN_LEN_MASK_GATHER_LOAD:
>      case IFN_LEN_MASK_SCATTER_STORE:
> +    case IFN_COND_LEN_FMA:
> +    case IFN_COND_LEN_FMS:
> +    case IFN_COND_LEN_FNMA:
> +    case IFN_COND_LEN_FNMS:
>        return 5;
>  
> +    case IFN_COND_LEN_ADD:
> +    case IFN_COND_LEN_SUB:
> +    case IFN_COND_LEN_MUL:
> +    case IFN_COND_LEN_DIV:
> +    case IFN_COND_LEN_MOD:
> +    case IFN_COND_LEN_RDIV:
> +    case IFN_COND_LEN_MIN:
> +    case IFN_COND_LEN_MAX:
> +    case IFN_COND_LEN_FMIN:
> +    case IFN_COND_LEN_FMAX:
> +    case IFN_COND_LEN_AND:
> +    case IFN_COND_LEN_IOR:
> +    case IFN_COND_LEN_XOR:
> +    case IFN_COND_LEN_SHL:
> +    case IFN_COND_LEN_SHR:
> +      return 4;
> +
> +    case IFN_COND_LEN_NEG:
> +      return 3;
> +
>      default:
>        return -1;
>      }
> diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
> index dd1bab0bddf..a5c3f4765ff 100644
> --- a/gcc/internal-fn.h
> +++ b/gcc/internal-fn.h
> @@ -229,7 +229,7 @@ extern tree_code conditional_internal_fn_code (internal_fn);
>  extern internal_fn get_unconditional_internal_fn (internal_fn);
>  extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
>         tree_code *, tree (&)[3],
> -        tree *);
> +        tree *, tree *, tree *);
>  
>  extern bool internal_load_fn_p (internal_fn);
>  extern bool internal_store_fn_p (internal_fn);
> diff --git a/gcc/tree-ssa-math-opts.cc b/gcc/tree-ssa-math-opts.cc
> index 68fc518b1ab..4563d1ccf7f 100644
> --- a/gcc/tree-ssa-math-opts.cc
> +++ b/gcc/tree-ssa-math-opts.cc
> @@ -3099,10 +3099,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
>    negate_p = true;
>  }
>  
> -      tree cond, else_value, ops[3];
> +      tree cond, else_value, ops[3], len, bias;
>        tree_code code;
>        if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
> -       ops, &else_value))
> +       ops, &else_value,
> +       &len, &bias))
>  gcc_unreachable ();
>        addop = ops[0] == result ? ops[1] : ops[0];
>  
> @@ -3122,7 +3123,23 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
>        if (seq)
>  gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
>  
> -      if (cond)
> +      if (len)
> + {
> +   gcc_assert (gimple_call_internal_p (use_stmt));
> +   gcc_assert (bias);
> +   if (!cond)
 
Can len be set while cond is NULL here?
 
> +     {
> +       internal_fn ifn = gimple_call_internal_fn (use_stmt);
> +       int mask_index = internal_fn_mask_index (ifn);
> +       gcc_assert (mask_index >= 0);
> +       tree mask = gimple_call_arg (use_stmt, mask_index);
> +       cond = build_minus_one_cst (TREE_TYPE (mask));
> +     }
> +   fma_stmt
> +     = gimple_build_call_internal (IFN_COND_LEN_FMA, 7, cond, mulop1,
> +   op2, addop, else_value, len, bias);
> + }
> +      else if (cond)
>  fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
>         op2, addop, else_value);
>        else
> @@ -3420,10 +3437,10 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
>    negate_p = seen_negate_p = true;
>  }
>  
> -      tree cond, else_value, ops[3];
> +      tree cond, else_value, ops[3], len, bias;
>        tree_code code;
>        if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
> -       &else_value))
> +       &else_value, &len, &bias))
>  return false;
>  
>        switch (code)
> @@ -3446,7 +3463,19 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
>  {
>    if (cond == result || else_value == result)
>      return false;
> -   if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
> +   if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type)
> +       && !direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
> +   opt_type))
> +     return false;
> + }
> +
> +      if (len)
> + {
> +   gcc_assert (bias);
> +   if (else_value == result)
> +     return false;
> +   if (!direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
> +        opt_type))
>      return false;
>  }
>  
> @@ -5632,6 +5661,22 @@ math_opts_dom_walker::after_dom_children (basic_block bb)
>  }
>        break;
>  
> +     case CFN_COND_LEN_MUL:
> +       if (convert_mult_to_fma (stmt,
> +        gimple_call_arg (stmt, 1),
> +        gimple_call_arg (stmt, 2),
> +        &fma_state,
> +        integer_truep (gimple_call_arg (stmt, 0))
> + ? NULL_TREE
> + : gimple_call_arg (stmt, 0)))
 
Is that because of this (mapping an all-true mask to NULL_TREE)? Why not pass the mask through?
 
Otherwise looks OK.
 
Thanks,
Richard.
 
> +
> + {
> +   gsi_remove (&gsi, true);
> +   release_defs (stmt);
> +   continue;
> + }
> +       break;
> +
>      case CFN_LAST:
>        cancel_fma_deferring (&fma_state);
>        break;
> 
 
-- 
Richard Biener <rguenther@suse.de>
SUSE Software Solutions Germany GmbH, Frankenstrasse 146, 90461 Nuernberg,
Germany; GF: Ivo Totev, Andrew Myers, Andrew McDonald, Boudien Moerman;
HRB 36809 (AG Nuernberg)
 

^ permalink raw reply	[flat|nested] 3+ messages in thread

end of thread, other threads:[~2023-07-13  8:56 UTC | newest]

Thread overview: 3+ messages (download: mbox.gz / follow: Atom feed)
-- links below jump to the message on this page --
2023-07-13  5:17 [PATCH] SSA MATH: Support COND_LEN_FMA for floating-point math optimization juzhe.zhong
2023-07-13  7:53 ` Richard Biener
2023-07-13  8:56   ` juzhe.zhong

This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox;
as well as URLs for read-only IMAP folder(s) and NNTP newsgroup(s).