MLIR 23.0.0git
ShuffleVectorFMAOps.cpp
Go to the documentation of this file.
1//===- ShuffleVectorFMAOps.cpp --------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
14
15#include "mlir/Pass/Pass.h"
17
18using namespace mlir;
19using namespace mlir::vector;
20using namespace mlir::x86vector;
21
22namespace {
23
24// Validates whether the given operation is an x86vector operation and has only
25// one consumer.
26static bool validateFMAOperands(Value op) {
27 if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
28 return cvt.getResult().hasOneUse();
29
30 if (auto bcst = op.getDefiningOp<x86vector::BcstToPackedF32Op>())
31 return bcst.getResult().hasOneUse();
32
33 return false;
34}
35
36// Validates the vector.fma operation on the following conditions:
37// (i) one of the lhs or rhs defining operation should be
38// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
39// an x86vector operation and has only one consumer, (iii) all operations
40// are in the same block, and (iv) ths FMA has only one user.
41static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
42 Value lhs = fmaOp.getLhs();
43 Value rhs = fmaOp.getRhs();
44
45 if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
46 !isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
47 return false;
48
49 if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
50 return false;
51
52 if (lhs.getDefiningOp()->getBlock() != rhs.getDefiningOp()->getBlock())
53 return false;
54
55 if (lhs.getDefiningOp()->getBlock() != fmaOp->getBlock())
56 return false;
57
58 if (!fmaOp.getResult().hasOneUse())
59 return false;
60
61 Operation *consumer = *fmaOp.getResult().getUsers().begin();
62 if (consumer->getBlock() != fmaOp->getBlock())
63 return false;
64
65 return true;
66}
67
68// Moves vector.fma along with the lhs and rhs defining operation before its
69// consumer. If the consumer is vector.ShapeCastOp and has only one user then
70// move before the consumer of vector.ShapeCastOp.
71// TODO: Move before first consumer, if there are multiple.
72static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
73 Operation *consumer = *fmaOp.getResult().getUsers().begin();
74
75 if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(consumer)) {
76 if (shapeCastOp.getResult().hasOneUse()) {
77 Operation *nxtConsumer = *shapeCastOp.getResult().getUsers().begin();
78 if (nxtConsumer->getBlock() == fmaOp->getBlock()) {
79 consumer = *shapeCastOp.getResult().getUsers().begin();
80 rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
81 rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
82 rewriter.moveOpBefore(fmaOp.getOperation(), consumer);
83 rewriter.moveOpBefore(shapeCastOp.getOperation(), consumer);
84 return;
85 }
86 }
87 }
88
89 rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
90 rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
91 rewriter.moveOpBefore(fmaOp.getOperation(), consumer);
92
93 return;
94}
95
96// Shuffle FMAs with x86vector operations as operands such that
97// FMAs are grouped with respect to odd/even packed index.
98//
99// For example:
100// ```
101// %1 = x86vector.avx.bcst_to_f32.packed
102// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
103// %3 = vector.fma %1, %2, %arg1
104// %4 = x86vector.avx.bcst_to_f32.packed
105// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
106// %6 = vector.fma %4, %5, %3
107// %7 = x86vector.avx.bcst_to_f32.packed
108// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
109// %9 = vector.fma %7, %8, %arg2
110// %10 = x86vector.avx.bcst_to_f32.packed
111// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
112// %12 = vector.fma %10, %11, %9
113// yield %6, %12
114// ```
115// to
116// ```
117// %1 = x86vector.avx.bcst_to_f32.packed
118// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
119// %3 = vector.fma %1, %2, %arg1
120// %7 = x86vector.avx.bcst_to_f32.packed
121// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
122// %9 = vector.fma %7, %8, %arg2
123// %4 = x86vector.avx.bcst_to_f32.packed
124// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
125// %6 = vector.fma %4, %5, %3
126// %10 = x86vector.avx.bcst_to_f32.packed
127// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
128// %12 = vector.fma %10, %11, %9
129// yield %9, %12
130// ```
131// TODO: Shuffling supported only if the FMA, lhs/rhs defining operations
132// have only one consumer. Have to extend this pass for multiple consumers.
133struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
134 using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
135
136 LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
137 PatternRewriter &rewriter) const override {
138
139 if (!validateVectorFMAOp(fmaOp))
140 return failure();
141
142 llvm::SmallVector<vector::FMAOp> fmaOps;
143 Operation *nextOp = fmaOp;
144 bool stopAtNextDependentFMA = true;
145
146 // Break the loop and return failure if the immediate next FMA op
147 // have CvtPackedEvenIndexedToF32Op in it's lhs/rhs defining ops.
148 while ((nextOp = nextOp->getNextNode())) {
149 auto fma = dyn_cast<vector::FMAOp>(nextOp);
150 if (!fma)
151 continue;
152
153 bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>(
154 fma.getLhs().getDefiningOp()) ||
155 isa<x86vector::CvtPackedEvenIndexedToF32Op>(
156 fma.getRhs().getDefiningOp());
157
158 if (hasX86CvtOperand && stopAtNextDependentFMA)
159 break;
160
161 if (validateVectorFMAOp(fma))
162 fmaOps.push_back(fma);
163
164 stopAtNextDependentFMA = false;
165 }
166
167 if (fmaOps.empty())
168 return rewriter.notifyMatchFailure(
169 fmaOp, "No eligible FMA operations were found: the operation may "
170 "already be shuffled, there may be no following FMAs, or the "
171 "following FMAs do not satisfy the shuffle conditions.");
172
173 fmaOps.push_back(fmaOp);
174 for (auto fmaOp : fmaOps)
175 moveFMA(rewriter, fmaOp);
176
177 return success();
178 }
179};
180
181} // namespace
182
185 patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
186}
return success()
lhs
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...