MLIR  20.0.0git
OuterProductFusion.cpp
Go to the documentation of this file.
1 //===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites that fuse 'arm_sme.outerproduct' operations
10 // into the 2-way or 4-way widening outerproduct operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 #define DEBUG_TYPE "arm-sme-outerproduct-fusion"
23 
24 namespace mlir::arm_sme {
25 #define GEN_PASS_DEF_OUTERPRODUCTFUSION
26 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
27 } // namespace mlir::arm_sme
28 
29 using namespace mlir;
30 using namespace mlir::arm_sme;
31 
32 namespace {
33 
34 // Common match failure reasons.
35 static constexpr StringLiteral
36  kMatchFailureNoAccumulator("no accumulator operand");
37 static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
38  "defining op of accumulator must be 'arm_sme.outerproduct'");
39 static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
40  "combining kind (add or sub) of outer products must match");
41 static constexpr StringLiteral kMatchFailureInconsistentMasking(
42  "unsupported masking, either both outerproducts are masked "
43  "or neither");
44 static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
45  "outer product(s) not single use and cannot be removed, no benefit to "
46  "fusing");
47 
48 // An outer product is compatible if all of the following are true:
49 // - the result type matches `resultType`.
50 // - the defining operation of LHS is of the type `LhsExtOp`.
51 // - the defining operation of RHS is of the type `RhsExtOp`.
52 // - the input types of the defining operations are identical and match
53 // `inputType`.
54 template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
55 static LogicalResult isCompatible(PatternRewriter &rewriter,
56  arm_sme::OuterProductOp op,
57  VectorType resultType, VectorType inputType) {
58  if (op.getResultType() != resultType)
59  return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
60  diag << "unsupported result type, expected " << resultType;
61  });
62 
63  auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
64  auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
65 
66  if (!lhsDefOp || !rhsDefOp)
67  return rewriter.notifyMatchFailure(
68  op, "defining op of outerproduct operands must be one of: "
69  "'arith.extf' or 'arith.extsi' or 'arith.extui'");
70 
71  auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
72  auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
73 
74  if (lhsInType != inputType || rhsInType != inputType)
75  return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
76  diag << "unsupported input type, expected " << inputType;
77  });
78 
79  return success();
80 }
81 
82 // Fuse two 'arm_sme.outerproduct' operations that are chained via the
83 // accumulator into 2-way outer product operation.
84 //
85 // For example:
86 //
87 // %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
88 // %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
89 // %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
90 // vector<[4]xf32>
91 //
92 // %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
93 // %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
94 // %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
95 // vector<[4]xf32>
96 //
97 // Becomes:
98 //
99 // %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
100 // %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
101 // %0 = arm_sme.fmopa_2way %a_packed, %b_packed
102 // : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
103 class OuterProductFusion2Way
104  : public OpRewritePattern<arm_sme::OuterProductOp> {
105 public:
107 
108  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
109  PatternRewriter &rewriter) const override {
110  Value acc = op.getAcc();
111  if (!acc)
112  return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
113 
114  arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
115  arm_sme::OuterProductOp op2 = op;
116  if (!op1)
117  return rewriter.notifyMatchFailure(
118  op, kMatchFailureExpectedOuterProductDefOp);
119 
120  if (op1.getKind() != op2.getKind())
121  return rewriter.notifyMatchFailure(
122  op, kMatchFailureInconsistentCombiningKind);
123 
124  if (!op1->hasOneUse()) {
125  // If the first outer product has uses other than as the input to another
126  // outer product, it can't be erased after fusion.
127  return rewriter.notifyMatchFailure(op,
128  kMatchFailureOuterProductNotSingleUse);
129  }
130 
131  if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
132  return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking);
133 
134  if (failed(canFuseOuterProducts(rewriter, op1, op2)))
135  return failure();
136 
137  auto loc = op.getLoc();
138  auto packInputs = [&](Value lhs, Value rhs) {
139  return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
140  };
141 
142  auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
143  op2.getLhs().getDefiningOp()->getOperand(0));
144  auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
145  op2.getRhs().getDefiningOp()->getOperand(0));
146 
147  Value lhsMask, rhsMask;
148  if (op1.getLhsMask() || op2.getLhsMask()) {
149  lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
150  rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
151  }
152 
153  auto extOp = op.getLhs().getDefiningOp();
154 
155  arm_sme::CombiningKind kind = op.getKind();
156  if (kind == arm_sme::CombiningKind::Add) {
158  .Case<arith::ExtFOp>([&](auto) {
159  rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
160  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
161  op1.getAcc());
162  })
163  .Case<arith::ExtSIOp>([&](auto) {
164  rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
165  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
166  op1.getAcc());
167  })
168  .Case<arith::ExtUIOp>([&](auto) {
169  rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
170  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
171  op1.getAcc());
172  })
173  .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
174  } else if (kind == arm_sme::CombiningKind::Sub) {
176  .Case<arith::ExtFOp>([&](auto) {
177  rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
178  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
179  op1.getAcc());
180  })
181  .Case<arith::ExtSIOp>([&](auto) {
182  rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
183  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
184  op1.getAcc());
185  })
186  .Case<arith::ExtUIOp>([&](auto) {
187  rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
188  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
189  op1.getAcc());
190  })
191  .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
192  } else {
193  llvm_unreachable("unexpected arm_sme::CombiningKind!");
194  }
195 
196  return success();
197  }
198 
199 private:
200  // A pair of outer product can be fused if all of the following are true:
201  // - input and result types match.
202  // - the defining operations of the inputs are identical extensions,
203  // specifically either:
204  // - a signed or unsigned extension for integer types.
205  // - a floating-point extension for floating-point types.
206  // - the types and extension are supported, i.e. there's a 2-way operation
207  // they can be fused into.
208  LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
209  arm_sme::OuterProductOp op1,
210  arm_sme::OuterProductOp op2) const {
211  // Supported result types.
212  auto nxnxv4i32 =
213  VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
214  auto nxnxv4f32 =
215  VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
216  // Supported input types.
217  // Note: this is before packing so these have half the number of elements
218  // of the input vector types of the 2-way operations.
219  auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
220  auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
221  auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
222  if ((failed(
223  isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
224  failed(
225  isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
226  (failed(
227  isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
228  failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
229  nxv4bf16))) &&
230  (failed(
231  isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
232  failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
233  nxv4i16))) &&
234  (failed(
235  isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
236  failed(
237  isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
238  return failure();
239 
240  return success();
241  }
242 };
243 
244 // Fuse four 'arm_sme.outerproduct' operations that are chained via the
245 // accumulator into 4-way outer product operation.
246 class OuterProductFusion4Way
247  : public OpRewritePattern<arm_sme::OuterProductOp> {
248 public:
250 
251  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
252  PatternRewriter &rewriter) const override {
253  SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
254  outerProductChain.push_back(op);
255 
256  for (int i = 0; i < 3; ++i) {
257  auto currentOp = outerProductChain.back();
258  auto acc = currentOp.getAcc();
259  if (!acc)
260  return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
261  auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
262  if (!previousOp)
263  return rewriter.notifyMatchFailure(
264  op, kMatchFailureExpectedOuterProductDefOp);
265  if (!previousOp->hasOneUse())
266  return rewriter.notifyMatchFailure(
267  op, kMatchFailureOuterProductNotSingleUse);
268  if (previousOp.getKind() != currentOp.getKind())
269  return rewriter.notifyMatchFailure(
270  op, kMatchFailureInconsistentCombiningKind);
271  if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
272  return rewriter.notifyMatchFailure(
273  op, kMatchFailureInconsistentCombiningKind);
274  outerProductChain.push_back(previousOp);
275  }
276 
277  if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
278  return failure();
279 
280  arm_sme::OuterProductOp op1 = outerProductChain[3];
281  arm_sme::OuterProductOp op2 = outerProductChain[2];
282  arm_sme::OuterProductOp op3 = outerProductChain[1];
283  arm_sme::OuterProductOp op4 = outerProductChain[0];
284 
285  auto loc = op.getLoc();
286  auto packInputs = [&](Value lhs, Value rhs) {
287  return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
288  };
289 
290  auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
291  op3.getLhs().getDefiningOp()->getOperand(0));
292  auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
293  op4.getLhs().getDefiningOp()->getOperand(0));
294  auto lhs = packInputs(lhs0, lhs1);
295 
296  auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
297  op3.getRhs().getDefiningOp()->getOperand(0));
298  auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
299  op4.getRhs().getDefiningOp()->getOperand(0));
300  auto rhs = packInputs(rhs0, rhs1);
301 
302  Value lhsMask, rhsMask;
303  if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
304  op4.getLhsMask()) {
305  auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
306  auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
307  lhsMask = packInputs(lhs0Mask, lhs1Mask);
308 
309  auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
310  auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
311  rhsMask = packInputs(rhs0Mask, rhs1Mask);
312  }
313 
314  auto lhsExtOp = op.getLhs().getDefiningOp();
315  auto rhsExtOp = op.getRhs().getDefiningOp();
316 
317  arm_sme::CombiningKind kind = op.getKind();
318  if (kind == arm_sme::CombiningKind::Add) {
319  if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
320  // signed
321  rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
322  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
323  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
324  isa<arith::ExtUIOp>(rhsExtOp)) {
325  // unsigned
326  rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
327  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
328  } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
329  isa<arith::ExtUIOp>(rhsExtOp)) {
330  // signed by unsigned
331  rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
332  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
333  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
334  isa<arith::ExtSIOp>(rhsExtOp)) {
335  // unsigned by signed
336  rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
337  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
338  } else {
339  llvm_unreachable("unexpected extend op!");
340  }
341  } else if (kind == arm_sme::CombiningKind::Sub) {
342  if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
343  // signed
344  rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
345  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
346  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
347  isa<arith::ExtUIOp>(rhsExtOp)) {
348  // unsigned
349  rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
350  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
351  } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
352  isa<arith::ExtUIOp>(rhsExtOp)) {
353  // signed by unsigned
354  rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
355  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
356  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
357  isa<arith::ExtSIOp>(rhsExtOp)) {
358  // unsigned by signed
359  rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
360  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
361  } else {
362  llvm_unreachable("unexpected extend op!");
363  }
364  } else {
365  llvm_unreachable("unexpected arm_sme::CombiningKind!");
366  }
367 
368  return success();
369  }
370 
371 private:
372  // Four outer products can be fused if all of the following are true:
373  // - input and result types match.
374  // - the defining operations of the inputs are identical extensions,
375  // specifically either:
376  // - a signed or unsigned extension for integer types.
377  // - a floating-point extension for floating-point types.
378  // - the types and extension are supported, i.e. there's a 4-way operation
379  // they can be fused into.
380  LogicalResult
381  canFuseOuterProducts(PatternRewriter &rewriter,
383  // Supported result types.
384  auto nxnxv4i32 =
385  VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
386  auto nxnxv2i64 =
387  VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
388 
389  // Supported input types.
390  // Note: this is before packing so these have 1/4 the number of elements
391  // of the input vector types of the 4-way operations.
392  auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
393  auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
394 
395  auto failedToMatch = [&](VectorType resultType, VectorType inputType,
396  auto lhsExtendOp, auto rhsExtendOp) {
397  using LhsExtendOpTy = decltype(lhsExtendOp);
398  using RhsExtendOpTy = decltype(rhsExtendOp);
399  for (auto op : ops) {
400  if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
401  rewriter, op, resultType, inputType)))
402  return true;
403  }
404  return false;
405  };
406 
407  if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
408  failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
409  failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
410  failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
411  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
412  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
413  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
414  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
415  return failure();
416 
417  return success();
418  }
419 };
420 
421 // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
422 //
423 // This transforms IR like:
424 // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
425 // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
426 // Into:
427 // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
428 // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
429 //
430 // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
431 // pass when the result is the input to an outer product.
432 struct SwapVectorExtractOfArithExtend
433  : public OpRewritePattern<vector::ExtractOp> {
435 
436  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
437  PatternRewriter &rewriter) const override {
438  VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
439  if (!resultType)
440  return rewriter.notifyMatchFailure(extractOp,
441  "extracted type is not a vector type");
442 
443  auto numScalableDims = resultType.getNumScalableDims();
444  if (numScalableDims != 1)
445  return rewriter.notifyMatchFailure(
446  extractOp, "extracted type is not a 1-D scalable vector type");
447 
448  auto *extendOp = extractOp.getVector().getDefiningOp();
449  if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
450  extendOp))
451  return rewriter.notifyMatchFailure(extractOp,
452  "extract not from extend op");
453 
454  auto loc = extractOp.getLoc();
455  StringAttr extendOpName = extendOp->getName().getIdentifier();
456  Value extendSource = extendOp->getOperand(0);
457 
458  // Create new extract from source of extend.
459  Value newExtract = rewriter.create<vector::ExtractOp>(
460  loc, extendSource, extractOp.getMixedPosition());
461 
462  // Extend new extract to original result type.
463  Operation *newExtend =
464  rewriter.create(loc, extendOpName, Value(newExtract), resultType);
465 
466  rewriter.replaceOp(extractOp, newExtend);
467 
468  return success();
469  }
470 };
471 
472 // Same as above, but for vector.scalable.extract.
473 //
474 // This transforms IR like:
475 // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
476 // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
477 // Into:
478 // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
479 // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
480 //
481 // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
482 // pass when the result is the input to an outer product.
483 struct SwapVectorScalableExtractOfArithExtend
484  : public OpRewritePattern<vector::ScalableExtractOp> {
486 
487  LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
488  PatternRewriter &rewriter) const override {
489  auto *extendOp = extractOp.getSource().getDefiningOp();
490  if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
491  extendOp))
492  return rewriter.notifyMatchFailure(extractOp,
493  "extract not from extend op");
494 
495  auto loc = extractOp.getLoc();
496  VectorType resultType = extractOp.getResultVectorType();
497 
498  Value extendSource = extendOp->getOperand(0);
499  StringAttr extendOpName = extendOp->getName().getIdentifier();
500  VectorType extendSourceVectorType =
501  cast<VectorType>(extendSource.getType());
502 
503  // Create new extract from source of extend.
504  VectorType extractResultVectorType =
505  resultType.clone(extendSourceVectorType.getElementType());
506  Value newExtract = rewriter.create<vector::ScalableExtractOp>(
507  loc, extractResultVectorType, extendSource, extractOp.getPos());
508 
509  // Extend new extract to original result type.
510  Operation *newExtend =
511  rewriter.create(loc, extendOpName, Value(newExtract), resultType);
512 
513  rewriter.replaceOp(extractOp, newExtend);
514 
515  return success();
516  }
517 };
518 
519 struct OuterProductFusionPass
520  : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
521 
522  void runOnOperation() override {
525 
526  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
527  signalPassFailure();
528  }
529 };
530 
531 } // namespace
532 
535  MLIRContext *context = patterns.getContext();
536  // Note: High benefit to ensure extract(extend) are swapped first.
537  patterns.add<SwapVectorExtractOfArithExtend,
538  SwapVectorScalableExtractOfArithExtend>(context, 1024);
539  patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(context);
540 }
541 
543  return std::make_unique<OuterProductFusionPass>();
544 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
IntegerType getI16Type()
Definition: Builders.cpp:105
FloatType getF32Type()
Definition: Builders.cpp:87
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
FloatType getF16Type()
Definition: Builders.cpp:83
FloatType getBF16Type()
Definition: Builders.cpp:81
IntegerType getI8Type()
Definition: Builders.cpp:103
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
std::unique_ptr< Pass > createOuterProductFusionPass()
Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening variants.
void populateOuterProductFusionPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362