21 #include "llvm/ADT/TypeSwitch.h"
23 #define DEBUG_TYPE "arm-sme-outerproduct-fusion"
26 #define GEN_PASS_DEF_OUTERPRODUCTFUSION
27 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
36 static constexpr StringLiteral
37 kMatchFailureNoAccumulator(
"no accumulator operand");
38 static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
39 "defining op of accumulator must be 'arm_sme.outerproduct'");
40 static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
41 "combining kind (add or sub) of outer products must match");
42 static constexpr StringLiteral kMatchFailureInconsistentMasking(
43 "unsupported masking, either both outerproducts are masked "
45 static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
46 "outer product(s) not single use and cannot be removed, no benefit to "
55 template <
typename LhsExtOp,
typename RhsExtOp = LhsExtOp>
57 arm_sme::OuterProductOp op,
58 VectorType resultType, VectorType inputType) {
59 if (op.getResultType() != resultType)
61 diag <<
"unsupported result type, expected " << resultType;
64 auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
65 auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
67 if (!lhsDefOp || !rhsDefOp)
69 op,
"defining op of outerproduct operands must be one of: "
70 "'arith.extf' or 'arith.extsi' or 'arith.extui'");
72 auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
73 auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
75 if (lhsInType != inputType || rhsInType != inputType)
77 diag <<
"unsupported input type, expected " << inputType;
86 auto inputType = cast<VectorType>(lhs.
getType());
87 VectorType inputTypeX2 =
89 return rewriter.
create<LLVM::experimental_vector_interleave2>(
90 loc, inputTypeX2, lhs, rhs);
116 class OuterProductFusion2Way
123 Value acc = op.getAcc();
127 arm_sme::OuterProductOp op1 = acc.
getDefiningOp<arm_sme::OuterProductOp>();
128 arm_sme::OuterProductOp op2 = op;
131 op, kMatchFailureExpectedOuterProductDefOp);
133 if (op1.getKind() != op2.getKind())
135 op, kMatchFailureInconsistentCombiningKind);
137 if (!op1->hasOneUse()) {
162 kMatchFailureOuterProductNotSingleUse);
165 if (
bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
168 if (
failed(canFuseOuterProducts(rewriter, op1, op2)))
173 return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
176 auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
177 op2.getLhs().getDefiningOp()->getOperand(0));
178 auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
179 op2.getRhs().getDefiningOp()->getOperand(0));
181 Value lhsMask, rhsMask;
182 if (op1.getLhsMask() || op2.getLhsMask()) {
183 lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
184 rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
187 auto extOp = op.getLhs().getDefiningOp();
189 arm_sme::CombiningKind kind = op.getKind();
190 if (kind == arm_sme::CombiningKind::Add) {
192 .Case<arith::ExtFOp>([&](
auto) {
194 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
197 .Case<arith::ExtSIOp>([&](
auto) {
199 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
202 .Case<arith::ExtUIOp>([&](
auto) {
204 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
207 .Default([&](
auto) { llvm_unreachable(
"unexpected extend op!"); });
208 }
else if (kind == arm_sme::CombiningKind::Sub) {
210 .Case<arith::ExtFOp>([&](
auto) {
212 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
215 .Case<arith::ExtSIOp>([&](
auto) {
217 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
220 .Case<arith::ExtUIOp>([&](
auto) {
222 op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
225 .Default([&](
auto) { llvm_unreachable(
"unexpected extend op!"); });
227 llvm_unreachable(
"unexpected arm_sme::CombiningKind!");
245 arm_sme::OuterProductOp op1,
246 arm_sme::OuterProductOp op2)
const {
259 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
261 isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
263 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
264 failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
267 isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
268 failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
271 isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
273 isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
282 class OuterProductFusion4Way
// Seed the chain with the root op, then follow accumulator operands backwards
// for three more links. Each step takes the most recently pushed op
// (`outerProductChain.back()`), looks up the `arm_sme.outerproduct` defining its
// accumulator, and appends it, so the chain is ordered newest-first.
// NOTE(review): this listing omits some source lines (e.g. the null check and
// the `notifyMatchFailure` call heads), so only the visible checks are
// described here.
290 outerProductChain.push_back(op);
292 for (
int i = 0; i < 3; ++i) {
293 auto currentOp = outerProductChain.back();
294 auto acc = currentOp.getAcc();
297 auto previousOp = acc.
getDefiningOp<arm_sme::OuterProductOp>();
300 op, kMatchFailureExpectedOuterProductDefOp);
// The earlier outer product must have exactly one use; otherwise the match
// fails with the "not single use" diagnostic.
301 if (!previousOp->hasOneUse())
303 op, kMatchFailureOuterProductNotSingleUse);
// Adjacent outer products in the chain must share the same combining kind.
304 if (previousOp.getKind() != currentOp.getKind())
306 op, kMatchFailureInconsistentCombiningKind);
// Adjacent outer products must agree on whether an LHS mask is present.
// NOTE(review): this masking mismatch is reported with
// kMatchFailureInconsistentCombiningKind; kMatchFailureInconsistentMasking
// (declared near the top of this file) looks like the intended diagnostic,
// as the 2-way pattern performs the same mask comparison — confirm and fix.
307 if (
bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
309 op, kMatchFailureInconsistentCombiningKind);
310 outerProductChain.push_back(previousOp);
313 if (
failed(canFuseOuterProducts(rewriter, outerProductChain)))
316 arm_sme::OuterProductOp op1 = outerProductChain[3];
317 arm_sme::OuterProductOp op2 = outerProductChain[2];
318 arm_sme::OuterProductOp op3 = outerProductChain[1];
319 arm_sme::OuterProductOp op4 = outerProductChain[0];
323 return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
326 auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
327 op3.getLhs().getDefiningOp()->getOperand(0));
328 auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
329 op4.getLhs().getDefiningOp()->getOperand(0));
330 auto lhs = packInputs(lhs0, lhs1);
332 auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
333 op3.getRhs().getDefiningOp()->getOperand(0));
334 auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
335 op4.getRhs().getDefiningOp()->getOperand(0));
336 auto rhs = packInputs(rhs0, rhs1);
338 Value lhsMask, rhsMask;
339 if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
341 auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
342 auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
343 lhsMask = packInputs(lhs0Mask, lhs1Mask);
345 auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
346 auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
347 rhsMask = packInputs(rhs0Mask, rhs1Mask);
350 auto lhsExtOp = op.getLhs().getDefiningOp();
351 auto rhsExtOp = op.getRhs().getDefiningOp();
353 arm_sme::CombiningKind kind = op.getKind();
354 if (kind == arm_sme::CombiningKind::Add) {
355 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
358 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
359 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
360 isa<arith::ExtUIOp>(rhsExtOp)) {
363 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
364 }
else if (isa<arith::ExtSIOp>(lhsExtOp) &&
365 isa<arith::ExtUIOp>(rhsExtOp)) {
368 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
369 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
370 isa<arith::ExtSIOp>(rhsExtOp)) {
373 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
375 llvm_unreachable(
"unexpected extend op!");
377 }
else if (kind == arm_sme::CombiningKind::Sub) {
378 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
381 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
382 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
383 isa<arith::ExtUIOp>(rhsExtOp)) {
386 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
387 }
else if (isa<arith::ExtSIOp>(lhsExtOp) &&
388 isa<arith::ExtUIOp>(rhsExtOp)) {
391 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
392 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
393 isa<arith::ExtSIOp>(rhsExtOp)) {
396 op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
398 llvm_unreachable(
"unexpected extend op!");
401 llvm_unreachable(
"unexpected arm_sme::CombiningKind!");
435 auto failedToMatch = [&](VectorType resultType, VectorType inputType,
436 auto lhsExtendOp,
auto rhsExtendOp) {
437 using LhsExtendOpTy = decltype(lhsExtendOp);
438 using RhsExtendOpTy = decltype(rhsExtendOp);
439 for (
auto op : ops) {
440 if (
failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
441 rewriter, op, resultType, inputType)))
447 if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
448 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
449 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
450 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
451 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
452 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
453 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
454 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
472 struct SwapVectorExtractOfArithExtend
478 VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
481 "extracted type is not a vector type");
483 auto numScalableDims = llvm::count(resultType.getScalableDims(),
true);
484 if (numScalableDims != 1)
486 extractOp,
"extracted type is not a 1-D scalable vector type");
488 auto *extendOp = extractOp.getVector().getDefiningOp();
489 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
492 "extract not from extend op");
494 auto loc = extractOp.getLoc();
495 StringAttr extendOpName = extendOp->getName().getIdentifier();
496 Value extendSource = extendOp->getOperand(0);
499 Value newExtract = rewriter.
create<vector::ExtractOp>(
500 loc, extendSource, extractOp.getMixedPosition());
504 rewriter.
create(loc, extendOpName,
Value(newExtract), resultType);
506 rewriter.
replaceOp(extractOp, newExtend);
523 struct SwapVectorScalableExtractOfArithExtend
527 LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
529 auto *extendOp = extractOp.getSource().getDefiningOp();
530 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
533 "extract not from extend op");
535 auto loc = extractOp.getLoc();
536 VectorType resultType = extractOp.getResultVectorType();
538 Value extendSource = extendOp->getOperand(0);
539 StringAttr extendOpName = extendOp->getName().getIdentifier();
540 VectorType extendSourceVectorType =
541 cast<VectorType>(extendSource.
getType());
544 VectorType extractResultVectorType =
545 resultType.clone(extendSourceVectorType.getElementType());
546 Value newExtract = rewriter.
create<vector::ScalableExtractOp>(
547 loc, extractResultVectorType, extendSource, extractOp.getPos());
551 rewriter.
create(loc, extendOpName,
Value(newExtract), resultType);
553 rewriter.
replaceOp(extractOp, newExtend);
559 struct OuterProductFusionPass
560 :
public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
562 void runOnOperation()
override {
578 patterns.
add<SwapVectorExtractOfArithExtend,
579 SwapVectorScalableExtractOfArithExtend>(context, 1024);
580 patterns.
add<OuterProductFusion2Way, OuterProductFusion4Way>(context);
584 return std::make_unique<OuterProductFusionPass>();
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
std::unique_ptr< Pass > createOuterProductFusionPass()
Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening variants.
void populateOuterProductFusionPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...