26static bool validateFMAOperands(
Value op) {
27 if (
auto cvt = op.
getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
28 return cvt.getResult().hasOneUse();
30 if (
auto bcst = op.
getDefiningOp<x86vector::BcstToPackedF32Op>())
31 return bcst.getResult().hasOneUse();
41static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
45 if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(
lhs.getDefiningOp()) &&
46 !isa<x86vector::CvtPackedEvenIndexedToF32Op>(
rhs.getDefiningOp()))
49 if (!validateFMAOperands(
lhs) || !validateFMAOperands(
rhs))
52 if (
lhs.getDefiningOp()->getBlock() !=
rhs.getDefiningOp()->getBlock())
55 if (
lhs.getDefiningOp()->getBlock() != fmaOp->getBlock())
58 if (!fmaOp.getResult().hasOneUse())
62 if (consumer->
getBlock() != fmaOp->getBlock())
75 if (
auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(consumer)) {
76 if (shapeCastOp.getResult().hasOneUse()) {
78 if (nxtConsumer->
getBlock() == fmaOp->getBlock()) {
80 rewriter.
moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
81 rewriter.
moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
83 rewriter.
moveOpBefore(shapeCastOp.getOperation(), consumer);
89 rewriter.
moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
90 rewriter.
moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
134 using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
136 LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
137 PatternRewriter &rewriter)
const override {
139 if (!validateVectorFMAOp(fmaOp))
142 llvm::SmallVector<vector::FMAOp> fmaOps;
143 Operation *nextOp = fmaOp;
144 bool stopAtNextDependentFMA =
true;
148 while ((nextOp = nextOp->getNextNode())) {
149 auto fma = dyn_cast<vector::FMAOp>(nextOp);
153 bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>(
154 fma.getLhs().getDefiningOp()) ||
155 isa<x86vector::CvtPackedEvenIndexedToF32Op>(
156 fma.getRhs().getDefiningOp());
158 if (hasX86CvtOperand && stopAtNextDependentFMA)
161 if (validateVectorFMAOp(fma))
162 fmaOps.push_back(fma);
164 stopAtNextDependentFMA =
false;
169 fmaOp,
"No eligible FMA operations were found: the operation may "
170 "already be shuffled, there may be no following FMAs, or the "
171 "following FMAs do not satisfy the shuffle conditions.");
173 fmaOps.push_back(fmaOp);
174 for (
auto fmaOp : fmaOps)
175 moveFMA(rewriter, fmaOp);
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...