diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi index b8cc90e1a75e402abbf8a8cf2efefc1a333f8b3a..6d5a98c4946d3ff4c2b8abea5c29caa6863fd3f7 100644 --- a/gcc/doc/md.texi +++ b/gcc/doc/md.texi @@ -6202,6 +6202,51 @@ The operation is only supported for vector modes @var{m}. This pattern is not allowed to @code{FAIL}. +@cindex @code{cmla@var{m}4} instruction pattern +@item @samp{cmla@var{m}4} +Perform a vector multiply and accumulate that is semantically the same as +a multiply and accumulate of complex numbers. + +@smallexample + complex TYPE c[N]; + complex TYPE a[N]; + complex TYPE b[N]; + for (int i = 0; i < N; i += 1) + @{ + c[i] += a[i] * b[i]; + @} +@end smallexample + +In GCC lane ordering the real part of the number must be in the even lanes with +the imaginary part in the odd lanes. + +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + +@cindex @code{cmla_conj@var{m}4} instruction pattern +@item @samp{cmla_conj@var{m}4} +Perform a vector multiply by conjugate and accumulate that is semantically +the same as a multiply and accumulate of complex numbers where the second +multiply arguments is conjugated. + +@smallexample + complex TYPE c[N]; + complex TYPE a[N]; + complex TYPE b[N]; + for (int i = 0; i < N; i += 1) + @{ + c[i] += a[i] * conj (b[i]); + @} +@end smallexample + +In GCC lane ordering the real part of the number must be in the even lanes with +the imaginary part in the odd lanes. + +The operation is only supported for vector modes @var{m}. + +This pattern is not allowed to @code{FAIL}. + @cindex @code{cmul@var{m}4} instruction pattern @item @samp{cmul@var{m}4} Perform a vector multiply that is semantically the same as multiply of diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def index 5a0bbe3fe5dee591d54130e60f6996b28164ae38..305450e026d4b94ab62ceb9ca719ec5570ff43eb 100644 --- a/gcc/internal-fn.def +++ b/gcc/internal-fn.def @@ -288,6 +288,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary) /* Ternary math functions. */ DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary) +DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary) /* Unary integer ops. */ DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary) diff --git a/gcc/optabs.def b/gcc/optabs.def index e82396bae1117c6de91304761a560b7fbcb69ce1..8e2758d685ed85e02df10dac571eb40d45a294ed 100644 --- a/gcc/optabs.def +++ b/gcc/optabs.def @@ -294,6 +294,8 @@ OPTAB_D (cadd90_optab, "cadd90$a3") OPTAB_D (cadd270_optab, "cadd270$a3") OPTAB_D (cmul_optab, "cmul$a3") OPTAB_D (cmul_conj_optab, "cmul_conj$a3") +OPTAB_D (cmla_optab, "cmla$a4") +OPTAB_D (cmla_conj_optab, "cmla_conj$a4") OPTAB_D (cos_optab, "cos$a2") OPTAB_D (cosh_optab, "cosh$a2") OPTAB_D (exp10_optab, "exp10$a2") diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c index 82721acbab8cf81c4d6f9954c98fb913a7bb6282..3625a80c08e3d70fd362fc52e17e65b3b2c7da83 100644 --- a/gcc/tree-vect-slp-patterns.c +++ b/gcc/tree-vect-slp-patterns.c @@ -325,6 +325,24 @@ vect_match_expression_p (slp_tree node, tree_code code) return true; } +/* Checks to see if the expression represented by NODE is a call to the internal + function FN. */ + +static inline bool +vect_match_call_p (slp_tree node, internal_fn fn) +{ + if (!node + || !SLP_TREE_REPRESENTATIVE (node)) + return false; + + gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node)); + if (!expr + || !gimple_call_internal_p (expr, fn)) + return false; + + return true; +} + /* Check if the given lane permute in PERMUTES matches an alternating sequence of {even odd even odd ...}. This to account for unrolled loops. Further mode there resulting permute must be linear. */ @@ -1081,6 +1099,161 @@ complex_mul_pattern::build (vec_info *vinfo) complex_pattern::build (vinfo); } +/******************************************************************************* + * complex_fma_pattern class + ******************************************************************************/ + +class complex_fma_pattern : public complex_pattern +{ + protected: + complex_fma_pattern (slp_tree *node, vec *m_ops, internal_fn ifn) + : complex_pattern (node, m_ops, ifn) + { + this->m_num_args = 3; + } + + public: + void build (vec_info *); + static internal_fn + matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *, + vec *); + + static vect_pattern* + recognize (slp_tree_to_load_perm_map_t *, slp_tree *); + + static vect_pattern* + mkInstance (slp_tree *node, vec *m_ops, internal_fn ifn) + { + return new complex_fma_pattern (node, m_ops, ifn); + } +}; + +/* Helper function to "reset" a previously matched node and undo the changes + made enough so that the node is treated as an irrelevant node. */ + +static inline void +vect_slp_reset_pattern (slp_tree node) +{ + stmt_vec_info stmt_info = vect_orig_stmt (SLP_TREE_REPRESENTATIVE (node)); + STMT_VINFO_IN_PATTERN_P (stmt_info) = false; + STMT_SLP_TYPE (stmt_info) = pure_slp; + SLP_TREE_REPRESENTATIVE (node) = stmt_info; +} + +/* Pattern matcher for trying to match complex multiply and accumulate + and multiply and subtract patterns in SLP tree. + If the operation matches then IFN is set to the operation it matched and + the arguments to the two replacement statements are put in m_ops. + + If no match is found then IFN is set to IFN_LAST and m_ops is unchanged. + + This function matches the patterns shaped as: + + double ax = (b[i+1] * a[i]) + (b[i] * a[i]); + double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]); + + c[i] = c[i] - ax; + c[i+1] = c[i+1] + bx; + + If a match occurred then TRUE is returned, else FALSE. The match is + performed after COMPLEX_MUL which would have done the majority of the work. + This function merely matches an ADD with a COMPLEX_MUL IFN. The initial + match is expected to be in OP1 and the initial match operands in args0. */ + +internal_fn +complex_fma_pattern::matches (complex_operation_t op, + slp_tree_to_load_perm_map_t * /* perm_cache */, + slp_tree *ref_node, vec *ops) +{ + internal_fn ifn = IFN_LAST; + + /* Find the two components. We match Complex MUL first which reduces the + amount of work this pattern has to do. After that we just match the + head node and we're done.: + + * FMA: + +. + + We need to ignore the two_operands nodes that may also match. + For that we can check if they have any scalar statements and also + check that it's not a permute node as we're looking for a normal + PLUS_EXPR operation. */ + if (op != CMPLX_NONE) + return IFN_LAST; + + /* Find the two components. We match Complex MUL first which reduces the + amount of work this pattern has to do. After that we just match the + head node and we're done.: + + * FMA: + + on a non-two_operands node. */ + slp_tree vnode = *ref_node; + if (SLP_TREE_LANE_PERMUTATION (vnode).exists () + /* Need to exclude the plus two-operands node. These are not marked + so we have to infer it based on conditions. */ + || !SLP_TREE_SCALAR_STMTS (vnode).exists () + || !vect_match_expression_p (vnode, PLUS_EXPR)) + return IFN_LAST; + + slp_tree node = SLP_TREE_CHILDREN (vnode)[1]; + + if (vect_match_call_p (node, IFN_COMPLEX_MUL)) + ifn = IFN_COMPLEX_FMA; + else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ)) + ifn = IFN_COMPLEX_FMA_CONJ; + else + return IFN_LAST; + + if (!vect_pattern_validate_optab (ifn, vnode)) + return IFN_LAST; + + vect_slp_reset_pattern (node); + ops->truncate (0); + ops->create (3); + + if (ifn == IFN_COMPLEX_FMA) + { + ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]); + ops->quick_push (SLP_TREE_CHILDREN (node)[1]); + ops->quick_push (SLP_TREE_CHILDREN (node)[0]); + } + else + { + ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]); + ops->quick_push (SLP_TREE_CHILDREN (node)[0]); + ops->quick_push (SLP_TREE_CHILDREN (node)[1]); + } + + return ifn; +} + +/* Attempt to recognize a complex mul pattern. */ + +vect_pattern* +complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache, + slp_tree *node) +{ + auto_vec ops; + complex_operation_t op + = vect_detect_pair_op (*node, true, &ops); + internal_fn ifn + = complex_fma_pattern::matches (op, perm_cache, node, &ops); + if (ifn == IFN_LAST) + return NULL; + + return new complex_fma_pattern (node, &ops, ifn); +} + +/* Perform a replacement of the detected complex mul pattern with the new + instruction sequences. */ + +void +complex_fma_pattern::build (vec_info *vinfo) +{ + SLP_TREE_CHILDREN (*this->m_node).truncate (0); + SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops); + + complex_pattern::build (vinfo); +} + /******************************************************************************* * Pattern matching definitions ******************************************************************************/