20#include "llvm/ADT/TypeSwitch.h"
22#define DEBUG_TYPE "arm-sme-outerproduct-fusion"
25#define GEN_PASS_DEF_OUTERPRODUCTFUSION
26#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
35static constexpr StringLiteral
36 kMatchFailureNoAccumulator(
"no accumulator operand");
37static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
38 "defining op of accumulator must be 'arm_sme.outerproduct'");
39static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
40 "combining kind (add or sub) of outer products must match");
41static constexpr StringLiteral kMatchFailureInconsistentMasking(
42 "unsupported masking, either both outerproducts are masked "
44static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
45 "outer product(s) not single use and cannot be removed, no benefit to "
54template <
typename LhsExtOp,
typename RhsExtOp = LhsExtOp>
56 arm_sme::OuterProductOp op,
57 VectorType resultType, VectorType inputType) {
58 if (op.getResultType() != resultType)
60 diag <<
"unsupported result type, expected " << resultType;
63 auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
64 auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
66 if (!lhsDefOp || !rhsDefOp)
68 op,
"defining op of outerproduct operands must be one of: "
69 "'arith.extf' or 'arith.extsi' or 'arith.extui'");
71 auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
72 auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
74 if (lhsInType != inputType || rhsInType != inputType)
76 diag <<
"unsupported input type, expected " << inputType;
103class OuterProductFusion2Way
108 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
109 PatternRewriter &rewriter)
const override {
110 Value acc = op.getAcc();
114 arm_sme::OuterProductOp op1 = acc.
getDefiningOp<arm_sme::OuterProductOp>();
115 arm_sme::OuterProductOp op2 = op;
118 op, kMatchFailureExpectedOuterProductDefOp);
120 if (op1.getKind() != op2.getKind())
122 op, kMatchFailureInconsistentCombiningKind);
124 if (!op1->hasOneUse()) {
128 kMatchFailureOuterProductNotSingleUse);
131 if (
bool(op1.getLhsMask()) !=
bool(op2.getLhsMask()))
134 if (
failed(canFuseOuterProducts(rewriter, op1, op2)))
137 auto loc = op.getLoc();
138 auto packInputs = [&](Value
lhs, Value
rhs) {
139 return vector::InterleaveOp::create(rewriter, loc,
lhs,
rhs);
142 auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
143 op2.getLhs().getDefiningOp()->getOperand(0));
144 auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
145 op2.getRhs().getDefiningOp()->getOperand(0));
148 if (op1.getLhsMask() || op2.getLhsMask()) {
149 lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
150 rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
155 arm_sme::CombiningKind kind = op.getKind();
156 if (kind == arm_sme::CombiningKind::Add) {
158 .Case<arith::ExtFOp>([&](
auto) {
160 op2, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask,
163 .Case<arith::ExtSIOp>([&](
auto) {
165 op2, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask,
168 .Case<arith::ExtUIOp>([&](
auto) {
170 op2, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask,
173 .DefaultUnreachable(
"unexpected extend op!");
174 }
else if (kind == arm_sme::CombiningKind::Sub) {
176 .Case<arith::ExtFOp>([&](
auto) {
178 op2, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask,
181 .Case<arith::ExtSIOp>([&](
auto) {
183 op2, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask,
186 .Case<arith::ExtUIOp>([&](
auto) {
188 op2, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask,
191 .DefaultUnreachable(
"unexpected extend op!");
193 llvm_unreachable(
"unexpected arm_sme::CombiningKind!");
208 LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
209 arm_sme::OuterProductOp op1,
210 arm_sme::OuterProductOp op2)
const {
213 VectorType::get({4, 4}, rewriter.
getI32Type(), {
true,
true});
215 VectorType::get({4, 4}, rewriter.
getF32Type(), {
true,
true});
219 auto nxv4i16 = VectorType::get({4}, rewriter.
getI16Type(),
true);
220 auto nxv4f16 = VectorType::get({4}, rewriter.
getF16Type(),
true);
221 auto nxv4bf16 = VectorType::get({4}, rewriter.
getBF16Type(),
true);
223 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
225 isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
227 isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
228 failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
231 isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
232 failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
235 isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
237 isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
246class OuterProductFusion4Way
251 LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
252 PatternRewriter &rewriter)
const override {
253 SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
254 outerProductChain.push_back(op);
256 for (
int i = 0; i < 3; ++i) {
257 auto currentOp = outerProductChain.back();
258 auto acc = currentOp.getAcc();
261 auto previousOp = acc.
getDefiningOp<arm_sme::OuterProductOp>();
264 op, kMatchFailureExpectedOuterProductDefOp);
265 if (!previousOp->hasOneUse())
267 op, kMatchFailureOuterProductNotSingleUse);
268 if (previousOp.getKind() != currentOp.getKind())
270 op, kMatchFailureInconsistentCombiningKind);
// Masked and unmasked outer products cannot be mixed in one chain. Report
// this as a masking mismatch (not a combining-kind mismatch) so the
// match-failure diagnostic names the real cause.
271 if (
bool(previousOp.getLhsMask()) !=
bool(currentOp.getLhsMask()))
273 op, kMatchFailureInconsistentMasking);
274 outerProductChain.push_back(previousOp);
277 if (
failed(canFuseOuterProducts(rewriter, outerProductChain)))
280 arm_sme::OuterProductOp op1 = outerProductChain[3];
281 arm_sme::OuterProductOp op2 = outerProductChain[2];
282 arm_sme::OuterProductOp op3 = outerProductChain[1];
283 arm_sme::OuterProductOp op4 = outerProductChain[0];
285 auto loc = op.getLoc();
286 auto packInputs = [&](Value
lhs, Value
rhs) {
287 return vector::InterleaveOp::create(rewriter, loc,
lhs,
rhs);
290 auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
291 op3.getLhs().getDefiningOp()->getOperand(0));
292 auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
293 op4.getLhs().getDefiningOp()->getOperand(0));
294 auto lhs = packInputs(lhs0, lhs1);
296 auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
297 op3.getRhs().getDefiningOp()->getOperand(0));
298 auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
299 op4.getRhs().getDefiningOp()->getOperand(0));
300 auto rhs = packInputs(rhs0, rhs1);
302 Value lhsMask, rhsMask;
303 if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
305 auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
306 auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
307 lhsMask = packInputs(lhs0Mask, lhs1Mask);
309 auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
310 auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
311 rhsMask = packInputs(rhs0Mask, rhs1Mask);
315 auto rhsExtOp = op.getRhs().getDefiningOp();
317 arm_sme::CombiningKind kind = op.getKind();
318 if (kind == arm_sme::CombiningKind::Add) {
319 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
322 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
323 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
324 isa<arith::ExtUIOp>(rhsExtOp)) {
327 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
328 }
else if (isa<arith::ExtSIOp>(lhsExtOp) &&
329 isa<arith::ExtUIOp>(rhsExtOp)) {
332 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
333 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
334 isa<arith::ExtSIOp>(rhsExtOp)) {
337 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
339 llvm_unreachable(
"unexpected extend op!");
341 }
else if (kind == arm_sme::CombiningKind::Sub) {
342 if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
345 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
346 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
347 isa<arith::ExtUIOp>(rhsExtOp)) {
350 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
351 }
else if (isa<arith::ExtSIOp>(lhsExtOp) &&
352 isa<arith::ExtUIOp>(rhsExtOp)) {
355 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
356 }
else if (isa<arith::ExtUIOp>(lhsExtOp) &&
357 isa<arith::ExtSIOp>(rhsExtOp)) {
360 op4, op.getResultType(),
lhs,
rhs, lhsMask, rhsMask, op1.getAcc());
362 llvm_unreachable(
"unexpected extend op!");
365 llvm_unreachable(
"unexpected arm_sme::CombiningKind!");
381 canFuseOuterProducts(PatternRewriter &rewriter,
382 ArrayRef<arm_sme::OuterProductOp> ops)
const {
385 VectorType::get({4, 4}, rewriter.
getI32Type(), {
true,
true});
387 VectorType::get({2, 2}, rewriter.
getI64Type(), {
true,
true});
392 auto nxv4i8 = VectorType::get({4}, rewriter.
getI8Type(),
true);
393 auto nxv2i16 = VectorType::get({2}, rewriter.
getI16Type(),
true);
395 auto failedToMatch = [&](VectorType resultType, VectorType inputType,
396 auto lhsExtendOp,
auto rhsExtendOp) {
397 using LhsExtendOpTy =
decltype(lhsExtendOp);
398 using RhsExtendOpTy =
decltype(rhsExtendOp);
399 for (
auto op : ops) {
400 if (
failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
401 rewriter, op, resultType, inputType)))
407 if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
408 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
409 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
410 failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
411 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
412 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
413 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
414 failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
432struct SwapVectorExtractOfArithExtend
436 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
437 PatternRewriter &rewriter)
const override {
438 VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
441 "extracted type is not a vector type");
443 auto numScalableDims = resultType.getNumScalableDims();
444 if (numScalableDims != 1)
446 extractOp,
"extracted type is not a 1-D scalable vector type");
448 auto *extendOp = extractOp.getSource().getDefiningOp();
449 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
452 "extract not from extend op");
454 auto loc = extractOp.getLoc();
455 StringAttr extendOpName = extendOp->getName().getIdentifier();
456 Value extendSource = extendOp->getOperand(0);
459 Value newExtract = vector::ExtractOp::create(rewriter, loc, extendSource,
460 extractOp.getMixedPosition());
463 Operation *newExtend =
464 rewriter.
create(loc, extendOpName, Value(newExtract), resultType);
466 rewriter.
replaceOp(extractOp, newExtend);
483struct SwapVectorScalableExtractOfArithExtend
487 LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
488 PatternRewriter &rewriter)
const override {
489 auto *extendOp = extractOp.getSource().getDefiningOp();
490 if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
493 "extract not from extend op");
495 auto loc = extractOp.getLoc();
496 VectorType resultType = extractOp.getResultVectorType();
498 Value extendSource = extendOp->getOperand(0);
499 StringAttr extendOpName = extendOp->getName().getIdentifier();
500 VectorType extendSourceVectorType =
501 cast<VectorType>(extendSource.
getType());
504 VectorType extractResultVectorType =
505 resultType.clone(extendSourceVectorType.getElementType());
506 Value newExtract = vector::ScalableExtractOp::create(
507 rewriter, loc, extractResultVectorType, extendSource,
511 Operation *newExtend =
512 rewriter.
create(loc, extendOpName, Value(newExtract), resultType);
514 rewriter.
replaceOp(extractOp, newExtend);
520struct OuterProductFusionPass
523 void runOnOperation()
override {
538 patterns.add<SwapVectorExtractOfArithExtend,
539 SwapVectorScalableExtractOfArithExtend>(context, 1024);
540 patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(context);
544 return std::make_unique<OuterProductFusionPass>();
static std::string diag(const llvm::Value &value)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::unique_ptr< Pass > createOuterProductFusionPass()
Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening variants.
void populateOuterProductFusionPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...