Date: Fri, 8 Jan 2021 10:37:17 +0100 (CET)
From: Richard Biener
To: Tamar Christina
cc: gcc-patches@gcc.gnu.org, nd@arm.com, ook@ucw.cz
Subject: Re: [PATCH 5/8 v9]middle-end slp: support complex multiply and complex multiply conjugate
In-Reply-To: <20201228133722.GA27314@arm.com>

On Mon, 28 Dec 2020, Tamar Christina wrote:

> Hi All,
>
> This adds support for complex multiply and complex multiply and accumulate to
> the vect pattern detector.
>
> Bootstrapped Regtested on aarch64-none-linux-gnu, x86_64-pc-linux-gnu
> and no issues.
>
> Ok for master?
>
> Thanks,
> Tamar
>
> gcc/ChangeLog:
>
>     * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
>     * optabs.def (cmul_optab, cmul_conj_optab): New.
>     * doc/md.texi: Document them.
>     * tree-vect-slp-patterns.c (vect_match_call_complex_mla,
>     vect_normalize_conj_loc, is_eq_or_top, vect_validate_multiplication,
>     vect_build_combine_node, class complex_mul_pattern,
>     complex_mul_pattern::matches, complex_mul_pattern::recognize,
>     complex_mul_pattern::build): New.
>
> --- inline copy of patch --
> diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
> index ec6ec180b91fcf9f481b6754c044483787fd923c..b8cc90e1a75e402abbf8a8cf2efefc1a333f8b3a 100644
> --- a/gcc/doc/md.texi
> +++ b/gcc/doc/md.texi
> @@ -6202,6 +6202,50 @@ The operation is only supported for vector modes @var{m}.
>
> This pattern is not allowed to @code{FAIL}.
>
> +@cindex @code{cmul@var{m}4} instruction pattern
> +@item @samp{cmul@var{m}4}
> +Perform a vector multiply that is semantically the same as multiply of
> +complex numbers.
> +
> +@smallexample
> +  complex TYPE c[N];
> +  complex TYPE a[N];
> +  complex TYPE b[N];
> +  for (int i = 0; i < N; i += 1)
> +    @{
> +      c[i] = a[i] * b[i];
> +    @}
> +@end smallexample
> +
> +In GCC lane ordering the real part of the number must be in the even lanes with
> +the imaginary part in the odd lanes.
> +
> +The operation is only supported for vector modes @var{m}.
> +
> +This pattern is not allowed to @code{FAIL}.
> +
> +@cindex @code{cmul_conj@var{m}4} instruction pattern
> +@item @samp{cmul_conj@var{m}4}
> +Perform a vector multiply by conjugate that is semantically the same as a
> +multiply of complex numbers where the second multiply arguments is conjugated.
> +
> +@smallexample
> +  complex TYPE c[N];
> +  complex TYPE a[N];
> +  complex TYPE b[N];
> +  for (int i = 0; i < N; i += 1)
> +    @{
> +      c[i] = a[i] * conj (b[i]);
> +    @}
> +@end smallexample
> +
> +In GCC lane ordering the real part of the number must be in the even lanes with
> +the imaginary part in the odd lanes.
> +
> +The operation is only supported for vector modes @var{m}.
> +
> +This pattern is not allowed to @code{FAIL}.
> +
> @cindex @code{ffs@var{m}2} instruction pattern
> @item @samp{ffs@var{m}2}
> Store into operand 0 one plus the index of the least significant 1-bit
> diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
> index 511fe70162b5d9db3a61a5285d31c008f6835487..5a0bbe3fe5dee591d54130e60f6996b28164ae38 100644
> --- a/gcc/internal-fn.def
> +++ b/gcc/internal-fn.def
> @@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
> DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
> DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
> DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
> +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
> +DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
>
>
> /* FP scales. */
> diff --git a/gcc/optabs.def b/gcc/optabs.def
> index e9727def4dbf941bb9ac8b56f83f8ea0f52b262c..e82396bae1117c6de91304761a560b7fbcb69ce1 100644
> --- a/gcc/optabs.def
> +++ b/gcc/optabs.def
> @@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
> OPTAB_D (xorsign_optab, "xorsign$F$a3")
> OPTAB_D (cadd90_optab, "cadd90$a3")
> OPTAB_D (cadd270_optab, "cadd270$a3")
> +OPTAB_D (cmul_optab, "cmul$a3")
> +OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
> OPTAB_D (cos_optab, "cos$a2")
> OPTAB_D (cosh_optab, "cosh$a2")
> OPTAB_D (exp10_optab, "exp10$a2")
> diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
> index dbc58f7c53868ed431fc67de1f0162eb0d3b2c24..82721acbab8cf81c4d6f9954c98fb913a7bb6282 100644
> --- a/gcc/tree-vect-slp-patterns.c
> +++ b/gcc/tree-vect-slp-patterns.c
> @@ -719,6 +719,368 @@ complex_add_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
> return new complex_add_pattern (node, &ops, ifn);
> }
>
> +/*******************************************************************************
> + * complex_mul_pattern
> + ******************************************************************************/
> +
> +/* Helper function of that looks for a match in the CHILDth child of NODE. The
> +   child used is stored in RES.
> +
> +   If the match is successful then ARGS will contain the operands matched
> +   and the complex_operation_t type is returned. If match is not successful
> +   then CMPLX_NONE is returned and ARGS is left unmodified. */
> +
> +static inline complex_operation_t
> +vect_match_call_complex_mla (slp_tree node, unsigned child,
> +                             vec<slp_tree> *args = NULL, slp_tree *res = NULL)
> +{
> +  gcc_assert (child < SLP_TREE_CHILDREN (node).length ());
> +
> +  slp_tree data = SLP_TREE_CHILDREN (node)[child];
> +
> +  if (res)
> +    *res = data;
> +
> +  return vect_detect_pair_op (data, false, args);
> +}
> +
> +/* Check to see if either of the trees in ARGS are a NEGATE_EXPR. If the first
> +   child (args[0]) is a NEGATE_EXPR then NEG_FIRST_P is set to TRUE.
> +
> +   If a negate is found then the values in ARGS are reordered such that the
> +   negate node is always the second one and the entry is replaced by the child
> +   of the negate node. */
> +
> +static inline bool
> +vect_normalize_conj_loc (vec<slp_tree> args, bool *neg_first_p = NULL)
> +{
> +  gcc_assert (args.length () == 2);
> +  bool neg_found = false;
> +
> +  if (vect_match_expression_p (args[0], NEGATE_EXPR))
> +    {
> +      std::swap (args[0], args[1]);
> +      neg_found = true;
> +      if (neg_first_p)
> +        *neg_first_p = true;
> +    }
> +  else if (vect_match_expression_p (args[1], NEGATE_EXPR))
> +    {
> +      neg_found = true;
> +      if (neg_first_p)
> +        *neg_first_p = false;
> +    }
> +
> +  if (neg_found)
> +    args[1] = SLP_TREE_CHILDREN (args[1])[0];
> +
> +  return neg_found;
> +}
> +
> +/* Helper function to check if PERM is KIND or PERM_TOP. */
> +
> +static inline bool
> +is_eq_or_top (complex_load_perm_t perm, complex_perm_kinds_t kind)
> +{
> +  return perm.first == kind || perm.first == PERM_TOP;
> +}
> +
> +/* Helper function that checks to see if LEFT_OP and RIGHT_OP are both MULT_EXPR
> +   nodes but also that they represent an operation that is either a complex
> +   multiplication or a complex multiplication by conjugated value.
> +
> +   Of the negation is expected to be in the first half of the tree (As required
> +   by an FMS pattern) then NEG_FIRST is true. If the operation is a conjugate
> +   operation then CONJ_FIRST_OPERAND is set to indicate whether the first or
> +   second operand contains the conjugate operation. */
> +
> +static inline bool
> +vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
> +                              vec<slp_tree> left_op, vec<slp_tree> right_op,
> +                              bool neg_first, bool *conj_first_operand,
> +                              bool fms)
> +{
> +  /* The presence of a negation indicates that we have either a conjugate or a
> +     rotation. We need to distinguish which one. */
> +  *conj_first_operand = false;
> +  complex_perm_kinds_t kind;
> +
> +  /* Complex conjugates have the negation on the imaginary part of the
> +     number where rotations affect the real component. So check if the
> +     negation is on a dup of lane 1. */
> +  if (fms)
> +    {
> +      /* Canonicalization for fms is not consistent. So have to test both
> +         variants to be sure. This needs to be fixed in the mid-end so
> +         this part can be simpler. */
> +      kind = linear_loads_p (perm_cache, right_op[0]).first;
> +      if (!((kind == PERM_ODDODD
> +             && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
> +                              PERM_ODDEVEN))
> +            || (kind == PERM_ODDEVEN
> +                && is_eq_or_top (linear_loads_p (perm_cache, right_op[1]),
> +                                 PERM_ODDODD))))
> +        return false;
> +    }
> +  else
> +    {
> +      if (linear_loads_p (perm_cache, right_op[1]).first != PERM_ODDODD
> +          && !is_eq_or_top (linear_loads_p (perm_cache, right_op[0]),
> +                            PERM_ODDEVEN))
> +        return false;
> +    }
> +
> +  /* Deal with differences in indexes. */
> +  int index1 = fms ? 1 : 0;
> +  int index2 = fms ? 0 : 1;
> +
> +  /* Check if the conjugate is on the second first or second operand. The
> +     order of the node with the conjugate value determines this, and the dup
> +     node must be one of lane 0 of the same DR as the neg node. */
> +  kind = linear_loads_p (perm_cache, left_op[index1]).first;
> +  if (kind == PERM_TOP)
> +    {
> +      if (linear_loads_p (perm_cache, left_op[index2]).first == PERM_EVENODD)
> +        return true;
> +    }
> +  else if (kind == PERM_EVENODD)
> +    {
> +      if ((kind = linear_loads_p (perm_cache, left_op[index2]).first) == PERM_EVENODD)
> +        return false;
> +    }
> +  else if (!neg_first)
> +    *conj_first_operand = true;
> +  else
> +    return false;
> +
> +  if (kind != PERM_EVENEVEN)
> +    return false;
> +
> +  return true;
> +}
> +
> +/* Helper function to help distinguish between a conjugate and a rotation in a
> +   complex multiplication. The operations have similar shapes but the order of
> +   the load permutes are different. This function returns TRUE when the order
> +   is consistent with a multiplication or multiplication by conjugated
> +   operand but returns FALSE if it's a multiplication by rotated operand. */
> +
> +static inline bool
> +vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
> +                              vec<slp_tree> op, complex_perm_kinds_t permKind)
> +{
> +  /* The left node is the more common case, test it first. */
> +  if (!is_eq_or_top (linear_loads_p (perm_cache, op[0]), permKind))
> +    {
> +      if (!is_eq_or_top (linear_loads_p (perm_cache, op[1]), permKind))
> +        return false;
> +    }
> +  return true;
> +}
> +
> +/* This function combines two nodes containing only even and only odd lanes
> +   together into a single node which contains the nodes in even/odd order
> +   by using a lane permute. */
> +
> +static slp_tree
> +vect_build_combine_node (slp_tree even, slp_tree odd, slp_tree rep)
> +{
> +  auto_vec<slp_tree> nodes;
> +  nodes.create (2);
> +  vec<std::pair<unsigned, unsigned> > perm;
> +  perm.create (SLP_TREE_LANES (rep));
> +
> +  for (unsigned x = 0; x < SLP_TREE_LANES (rep); x+=2)
> +    {
> +      perm.quick_push (std::make_pair (0, x));
> +      perm.quick_push (std::make_pair (1, x));
> +    }

That looks wrong: it creates {0, 0}, {1, 0}, {0, 2}, {1, 2}, but you want
{0, 0}, {1, 0}, {0, 1}, {1, 1} AFAICS. At least I assume
SLP_TREE_LANES (odd/even) == SLP_TREE_LANES (rep) / 2? 'rep' isn't
documented; I assume it's supposed to be a "representative" for the result?

> +
> +  nodes.quick_push (even);
> +  nodes.quick_push (odd);

No need for this intermediate nodes array, just push to ...

> +  SLP_TREE_REF_COUNT (even)++;
> +  SLP_TREE_REF_COUNT (odd)++;
> +
> +  slp_tree vnode = vect_create_new_slp_node (2, SLP_TREE_CODE (even));
> +  SLP_TREE_CODE (vnode) = VEC_PERM_EXPR;
> +  SLP_TREE_LANE_PERMUTATION (vnode) = perm;
> +  SLP_TREE_CHILDREN (vnode).safe_splice (nodes);

... the children array directly (even with quick_push, we've already
allocated 2 elements for the children).

> +  SLP_TREE_REF_COUNT (vnode) = 1;
> +  SLP_TREE_LANES (vnode) = SLP_TREE_LANES (rep);
> +  gcc_assert (perm.length () == SLP_TREE_LANES (vnode));
> +  /* Representation is set to that of the current node as the vectorizer
> +     can't deal with VEC_PERMs with no representation, as would be the
> +     case with invariants. */

Yeah, I need to fix this ...
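To make the lane-permute remark on vect_build_combine_node concrete, here is
a minimal standalone sketch. It uses std::vector in place of GCC's vec, and
the names lane_perm, build_perm_as_posted and build_perm_expected are
hypothetical, introduced only for illustration. It also assumes, as asked
above, that the even and odd children each carry half the lanes of 'rep'.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Each entry is (child index, lane within that child), mirroring the pairs
// pushed onto the lane permutation in the quoted loop.  Illustration only:
// std::vector stands in for GCC's vec.
using lane_perm = std::vector<std::pair<unsigned, unsigned>>;

// Hypothetical helper: the loop as posted, stepping x by 2 and using x
// directly as the lane index into each child.
lane_perm build_perm_as_posted(unsigned lanes) {
  lane_perm perm;
  perm.reserve(lanes);
  for (unsigned x = 0; x < lanes; x += 2) {
    perm.emplace_back(0u, x);
    perm.emplace_back(1u, x);
  }
  return perm;
}

// Hypothetical helper: the ordering asked for in the review, assuming each
// child has lanes / 2 lanes, so the lane index runs over 0 .. lanes / 2 - 1.
lane_perm build_perm_expected(unsigned lanes) {
  lane_perm perm;
  perm.reserve(lanes);
  for (unsigned x = 0; x < lanes / 2; ++x) {
    perm.emplace_back(0u, x);
    perm.emplace_back(1u, x);
  }
  return perm;
}
```

For four result lanes the first helper yields {0,0}, {1,0}, {0,2}, {1,2},
which reads lane 2 from children that would only have lanes 0 and 1 under
the assumption above; the second yields {0,0}, {1,0}, {0,1}, {1,1}.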
> +  SLP_TREE_REPRESENTATIVE (vnode) = SLP_TREE_REPRESENTATIVE (rep);
> +  SLP_TREE_VECTYPE (vnode) = SLP_TREE_VECTYPE (rep);
> +  return vnode;
> +}
> +
> +class complex_mul_pattern : public complex_pattern
> +{
> +  protected:
> +    complex_mul_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
> +      : complex_pattern (node, m_ops, ifn)
> +    {
> +      this->m_num_args = 2;
> +    }
> +
> +  public:
> +    void build (vec_info *);
> +    static internal_fn
> +    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
> +             vec<slp_tree> *);
> +
> +    static vect_pattern*
> +    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
> +
> +    static vect_pattern*
> +    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
> +    {
> +      return new complex_mul_pattern (node, m_ops, ifn);
> +    }
> +
> +};
> +
> +/* Pattern matcher for trying to match complex multiply pattern in SLP tree
> +   If the operation matches then IFN is set to the operation it matched
> +   and the arguments to the two replacement statements are put in m_ops.
> +
> +   If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
> +
> +   This function matches the patterns shaped as:
> +
> +   double ax = (b[i+1] * a[i]);
> +   double bx = (a[i+1] * b[i]);
> +
> +   c[i] = c[i] - ax;
> +   c[i+1] = c[i+1] + bx;
> +
> +   If a match occurred then TRUE is returned, else FALSE. The initial match is
> +   expected to be in OP1 and the initial match operands in args0. */
> +
> +internal_fn
> +complex_mul_pattern::matches (complex_operation_t op,
> +                              slp_tree_to_load_perm_map_t *perm_cache,
> +                              slp_tree *node, vec<slp_tree> *ops)
> +{
> +  internal_fn ifn = IFN_LAST;
> +
> +  if (op != MINUS_PLUS)
> +    return IFN_LAST;
> +
> +  slp_tree root = *node;
> +  /* First two nodes must be a multiply. */
> +  auto_vec<slp_tree> muls;
> +  if (vect_match_call_complex_mla (root, 0) != MULT_MULT
> +      || vect_match_call_complex_mla (root, 1, &muls) != MULT_MULT)
> +    return IFN_LAST;
> +
> +  /* Now operand2+4 may lead to another expression. */
> +  auto_vec<slp_tree> left_op, right_op;
> +  left_op.safe_splice (SLP_TREE_CHILDREN (muls[0]));
> +  right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
> +
> +  if (linear_loads_p (perm_cache, left_op[1]).first == PERM_ODDEVEN)
> +    return IFN_LAST;
> +
> +  bool neg_first;
> +  bool is_neg = vect_normalize_conj_loc (right_op, &neg_first);
> +
> +  if (!is_neg)
> +    {
> +      /* A multiplication needs to multiply agains the real pair, otherwise
> +         the pattern matches that of FMS. */
> +      if (!vect_validate_multiplication (perm_cache, left_op, PERM_EVENEVEN)
> +          || vect_normalize_conj_loc (left_op))
> +        return IFN_LAST;
> +      ifn = IFN_COMPLEX_MUL;
> +    }
> +  else if (is_neg)
> +    {
> +      bool conj_first_operand;
> +      if (!vect_validate_multiplication (perm_cache, left_op, right_op,
> +                                         neg_first, &conj_first_operand,
> +                                         false))
> +        return IFN_LAST;
> +
> +      ifn = IFN_COMPLEX_MUL_CONJ;
> +    }
> +
> +  if (!vect_pattern_validate_optab (ifn, *node))
> +    return IFN_LAST;
> +
> +  ops->truncate (0);
> +  ops->create (3);
> +
> +  complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]).first;
> +  if (kind == PERM_EVENODD)
> +    {
> +      ops->quick_push (left_op[1]);
> +      ops->quick_push (right_op[1]);
> +      ops->quick_push (left_op[0]);
> +    }
> +  else if (kind == PERM_TOP)
> +    {
> +      ops->quick_push (left_op[1]);
> +      ops->quick_push (right_op[1]);
> +      ops->quick_push (left_op[0]);
> +    }
> +  else
> +    {
> +      ops->quick_push (left_op[0]);
> +      ops->quick_push (right_op[0]);
> +      ops->quick_push (left_op[1]);
> +    }
> +
> +  return ifn;
> +}
> +
> +/* Attempt to recognize a complex mul pattern. */
> +
> +vect_pattern*
> +complex_mul_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
> +                                slp_tree *node)
> +{
> +  auto_vec<slp_tree> ops;
> +  complex_operation_t op
> +    = vect_detect_pair_op (*node, true, &ops);
> +  internal_fn ifn
> +    = complex_mul_pattern::matches (op, perm_cache, node, &ops);
> +  if (ifn == IFN_LAST)
> +    return NULL;
> +
> +  return new complex_mul_pattern (node, &ops, ifn);
> +}
> +
> +/* Perform a replacement of the detected complex mul pattern with the new
> +   instruction sequences. */
> +
> +void
> +complex_mul_pattern::build (vec_info *vinfo)
> +{
> +  auto_vec<slp_tree> nodes;
> +
> +  /* First re-arrange the children. */
> +  nodes.create (2);
> +
> +  nodes.quick_push (this->m_ops[2]);
> +  nodes.quick_push (
> +    vect_build_combine_node (this->m_ops[0], this->m_ops[1], *this->m_node));
> +  SLP_TREE_REF_COUNT (this->m_ops[2])++;
> +
> +  slp_tree node;
> +  unsigned i;
> +  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (*this->m_node), i, node)
> +    vect_free_slp_tree (node);
> +
> +  SLP_TREE_CHILDREN (*this->m_node).truncate (0);
> +  SLP_TREE_CHILDREN (*this->m_node).safe_splice (nodes);

Please elide the nodes array. *this->m_node now has a "wrong"
representative, but I guess

> +  complex_pattern::build (vinfo);

will fix that up? I still find the structure of the pattern matching
& transform hard to follow. But well - I've settled on the idea of
refactoring it for next stage1 after the fact ;)

Thanks,
Richard.