MLIR  19.0.0git
OuterProductFusion.cpp
Go to the documentation of this file.
1 //===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites that fuse 'arm_sme.outerproduct' operations
10 // into the 2-way or 4-way widening outerproduct operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "arm-sme-outerproduct-fusion"
24 
25 namespace mlir::arm_sme {
26 #define GEN_PASS_DEF_OUTERPRODUCTFUSION
27 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
28 } // namespace mlir::arm_sme
29 
30 using namespace mlir;
31 using namespace mlir::arm_sme;
32 
33 namespace {
34 
35 // Common match failure reasons.
36 static constexpr StringLiteral
37  kMatchFailureNoAccumulator("no accumulator operand");
38 static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
39  "defining op of accumulator must be 'arm_sme.outerproduct'");
40 static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
41  "combining kind (add or sub) of outer products must match");
42 static constexpr StringLiteral kMatchFailureInconsistentMasking(
43  "unsupported masking, either both outerproducts are masked "
44  "or neither");
45 static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
46  "outer product(s) not single use and cannot be removed, no benefit to "
47  "fusing");
48 
49 // An outer product is compatible if all of the following are true:
50 // - the result type matches `resultType`.
51 // - the defining operation of LHS is of the type `LhsExtOp`.
52 // - the defining operation of RHS is of the type `RhsExtOp`.
53 // - the input types of the defining operations are identical and match
54 // `inputType`.
55 template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
56 static LogicalResult isCompatible(PatternRewriter &rewriter,
57  arm_sme::OuterProductOp op,
58  VectorType resultType, VectorType inputType) {
59  if (op.getResultType() != resultType)
60  return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
61  diag << "unsupported result type, expected " << resultType;
62  });
63 
64  auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
65  auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
66 
67  if (!lhsDefOp || !rhsDefOp)
68  return rewriter.notifyMatchFailure(
69  op, "defining op of outerproduct operands must be one of: "
70  "'arith.extf' or 'arith.extsi' or 'arith.extui'");
71 
72  auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
73  auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
74 
75  if (lhsInType != inputType || rhsInType != inputType)
76  return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
77  diag << "unsupported input type, expected " << inputType;
78  });
79 
80  return success();
81 }
82 
83 // Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
84 static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
85  Value lhs, Value rhs) {
86  auto inputType = cast<VectorType>(lhs.getType());
87  VectorType inputTypeX2 =
88  VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
89  return rewriter.create<LLVM::experimental_vector_interleave2>(
90  loc, inputTypeX2, lhs, rhs);
91 }
92 
93 // Fuse two 'arm_sme.outerproduct' operations that are chained via the
94 // accumulator into a single 2-way outer product operation.
95 //
96 // For example:
97 //
98 // %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
99 // %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
100 // %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
101 // vector<[4]xf32>
102 //
103 // %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
104 // %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
105 // %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
106 // vector<[4]xf32>
107 //
108 // Becomes:
109 //
110 // %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
111 // : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
112 // %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
113 // : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
114 // %0 = arm_sme.fmopa_2way %a_packed, %b_packed
115 // : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
116 class OuterProductFusion2Way
117  : public OpRewritePattern<arm_sme::OuterProductOp> {
118 public:
120 
121  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
122  PatternRewriter &rewriter) const override {
123  Value acc = op.getAcc();
124  if (!acc)
125  return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
126 
127  arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
128  arm_sme::OuterProductOp op2 = op;
129  if (!op1)
130  return rewriter.notifyMatchFailure(
131  op, kMatchFailureExpectedOuterProductDefOp);
132 
133  if (op1.getKind() != op2.getKind())
134  return rewriter.notifyMatchFailure(
135  op, kMatchFailureInconsistentCombiningKind);
136 
137  if (!op1->hasOneUse()) {
138  // If the first outer product has uses other than as the input to another
139  // outer product, it can't be erased after fusion. This is a problem when
140  // it also has an accumulator as this will be used as the root for tile
141  // allocation and since the widening outer product uses the same
142  // accumulator it will get assigned the same tile ID, resulting in 3
143  // outer products accumulating to the same tile and incorrect results.
144  //
145  // Example:
146  //
147  // %acc = arith.constant dense<0.0> ; root for tile allocation
148  // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
149  // vector.print %0 ; intermediary use, can't erase %0
150  // %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
151  //
152  // After fusion and tile allocation
153  //
154  // %0 = arm_sme.zero {tile_id = 0 : i32}
155  // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
156  // vector.print %1
157  // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
158  //
159  // No accumulator would be ok, but it's simpler to prevent this
160  // altogether, since it has no benefit.
161  return rewriter.notifyMatchFailure(op,
162  kMatchFailureOuterProductNotSingleUse);
163  }
164 
165  if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
166  return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking);
167 
168  if (failed(canFuseOuterProducts(rewriter, op1, op2)))
169  return failure();
170 
171  auto loc = op.getLoc();
172  auto packInputs = [&](Value lhs, Value rhs) {
173  return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
174  };
175 
176  auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
177  op2.getLhs().getDefiningOp()->getOperand(0));
178  auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
179  op2.getRhs().getDefiningOp()->getOperand(0));
180 
181  Value lhsMask, rhsMask;
182  if (op1.getLhsMask() || op2.getLhsMask()) {
183  lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask());
184  rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
185  }
186 
187  auto extOp = op.getLhs().getDefiningOp();
188 
189  arm_sme::CombiningKind kind = op.getKind();
190  if (kind == arm_sme::CombiningKind::Add) {
192  .Case<arith::ExtFOp>([&](auto) {
193  rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
194  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
195  op1.getAcc());
196  })
197  .Case<arith::ExtSIOp>([&](auto) {
198  rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
199  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
200  op1.getAcc());
201  })
202  .Case<arith::ExtUIOp>([&](auto) {
203  rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
204  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
205  op1.getAcc());
206  })
207  .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
208  } else if (kind == arm_sme::CombiningKind::Sub) {
210  .Case<arith::ExtFOp>([&](auto) {
211  rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
212  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
213  op1.getAcc());
214  })
215  .Case<arith::ExtSIOp>([&](auto) {
216  rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
217  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
218  op1.getAcc());
219  })
220  .Case<arith::ExtUIOp>([&](auto) {
221  rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
222  op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
223  op1.getAcc());
224  })
225  .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
226  } else {
227  llvm_unreachable("unexpected arm_sme::CombiningKind!");
228  }
229 
230  rewriter.eraseOp(op1);
231 
232  return success();
233  }
234 
235 private:
236  // A pair of outer product can be fused if all of the following are true:
237  // - input and result types match.
238  // - the defining operations of the inputs are identical extensions,
239  // specifically either:
240  // - a signed or unsigned extension for integer types.
241  // - a floating-point extension for floating-point types.
242  // - the types and extension are supported, i.e. there's a 2-way operation
243  // they can be fused into.
244  LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
245  arm_sme::OuterProductOp op1,
246  arm_sme::OuterProductOp op2) const {
247  // Supported result types.
248  auto nxnxv4i32 =
249  VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
250  auto nxnxv4f32 =
251  VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
252  // Supported input types.
253  // Note: this is before packing so these have half the number of elements
254  // of the input vector types of the 2-way operations.
255  auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
256  auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
257  auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
258  if ((failed(
259  isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
260  failed(
261  isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
262  (failed(
263  isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
264  failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
265  nxv4bf16))) &&
266  (failed(
267  isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
268  failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
269  nxv4i16))) &&
270  (failed(
271  isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
272  failed(
273  isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
274  return failure();
275 
276  return success();
277  }
278 };
279 
280 // Fuse four 'arm_sme.outerproduct' operations that are chained via the
281 // accumulator into a single 4-way outer product operation.
282 class OuterProductFusion4Way
283  : public OpRewritePattern<arm_sme::OuterProductOp> {
284 public:
286 
287  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
288  PatternRewriter &rewriter) const override {
289  SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
290  outerProductChain.push_back(op);
291 
292  for (int i = 0; i < 3; ++i) {
293  auto currentOp = outerProductChain.back();
294  auto acc = currentOp.getAcc();
295  if (!acc)
296  return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
297  auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
298  if (!previousOp)
299  return rewriter.notifyMatchFailure(
300  op, kMatchFailureExpectedOuterProductDefOp);
301  if (!previousOp->hasOneUse())
302  return rewriter.notifyMatchFailure(
303  op, kMatchFailureOuterProductNotSingleUse);
304  if (previousOp.getKind() != currentOp.getKind())
305  return rewriter.notifyMatchFailure(
306  op, kMatchFailureInconsistentCombiningKind);
307  if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
308  return rewriter.notifyMatchFailure(
309  op, kMatchFailureInconsistentCombiningKind);
310  outerProductChain.push_back(previousOp);
311  }
312 
313  if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
314  return failure();
315 
316  arm_sme::OuterProductOp op1 = outerProductChain[3];
317  arm_sme::OuterProductOp op2 = outerProductChain[2];
318  arm_sme::OuterProductOp op3 = outerProductChain[1];
319  arm_sme::OuterProductOp op4 = outerProductChain[0];
320 
321  auto loc = op.getLoc();
322  auto packInputs = [&](Value lhs, Value rhs) {
323  return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
324  };
325 
326  auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
327  op3.getLhs().getDefiningOp()->getOperand(0));
328  auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
329  op4.getLhs().getDefiningOp()->getOperand(0));
330  auto lhs = packInputs(lhs0, lhs1);
331 
332  auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
333  op3.getRhs().getDefiningOp()->getOperand(0));
334  auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
335  op4.getRhs().getDefiningOp()->getOperand(0));
336  auto rhs = packInputs(rhs0, rhs1);
337 
338  Value lhsMask, rhsMask;
339  if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
340  op4.getLhsMask()) {
341  auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
342  auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
343  lhsMask = packInputs(lhs0Mask, lhs1Mask);
344 
345  auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
346  auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
347  rhsMask = packInputs(rhs0Mask, rhs1Mask);
348  }
349 
350  auto lhsExtOp = op.getLhs().getDefiningOp();
351  auto rhsExtOp = op.getRhs().getDefiningOp();
352 
353  arm_sme::CombiningKind kind = op.getKind();
354  if (kind == arm_sme::CombiningKind::Add) {
355  if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
356  // signed
357  rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
358  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
359  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
360  isa<arith::ExtUIOp>(rhsExtOp)) {
361  // unsigned
362  rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
363  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
364  } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
365  isa<arith::ExtUIOp>(rhsExtOp)) {
366  // signed by unsigned
367  rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
368  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
369  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
370  isa<arith::ExtSIOp>(rhsExtOp)) {
371  // unsigned by signed
372  rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
373  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
374  } else {
375  llvm_unreachable("unexpected extend op!");
376  }
377  } else if (kind == arm_sme::CombiningKind::Sub) {
378  if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
379  // signed
380  rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
381  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
382  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
383  isa<arith::ExtUIOp>(rhsExtOp)) {
384  // unsigned
385  rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
386  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
387  } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
388  isa<arith::ExtUIOp>(rhsExtOp)) {
389  // signed by unsigned
390  rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
391  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
392  } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
393  isa<arith::ExtSIOp>(rhsExtOp)) {
394  // unsigned by signed
395  rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
396  op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
397  } else {
398  llvm_unreachable("unexpected extend op!");
399  }
400  } else {
401  llvm_unreachable("unexpected arm_sme::CombiningKind!");
402  }
403 
404  rewriter.eraseOp(op3);
405  rewriter.eraseOp(op2);
406  rewriter.eraseOp(op1);
407 
408  return success();
409  }
410 
411 private:
412  // Four outer products can be fused if all of the following are true:
413  // - input and result types match.
414  // - the defining operations of the inputs are integer extensions; the LHS
415  // and RHS may each be signed or unsigned (including the mixed
416  // signed-by-unsigned and unsigned-by-signed combinations), but every outer
417  // product in the chain must use the same LHS/RHS combination.
418  // - the types and extensions are supported, i.e. there's a 4-way operation
419  // they can be fused into.
421  canFuseOuterProducts(PatternRewriter &rewriter,
423  // Supported result types.
424  auto nxnxv4i32 =
425  VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
426  auto nxnxv2i64 =
427  VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
428 
429  // Supported input types.
430  // Note: this is before packing so these have 1/4 the number of elements
431  // of the input vector types of the 4-way operations.
432  auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
433  auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
434 
435  auto failedToMatch = [&](VectorType resultType, VectorType inputType,
436  auto lhsExtendOp, auto rhsExtendOp) {
437  using LhsExtendOpTy = decltype(lhsExtendOp);
438  using RhsExtendOpTy = decltype(rhsExtendOp);
439  for (auto op : ops) {
440  if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
441  rewriter, op, resultType, inputType)))
442  return true;
443  }
444  return false;
445  };
446 
447  if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
448  failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
449  failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
450  failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
451  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
452  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
453  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
454  failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
455  return failure();
456 
457  return success();
458  }
459 };
460 
461 // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
462 //
463 // This transforms IR like:
464 // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
465 // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
466 // Into:
467 // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
468 // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
469 //
470 // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
471 // pass when the result is the input to an outer product.
472 struct SwapVectorExtractOfArithExtend
473  : public OpRewritePattern<vector::ExtractOp> {
475 
476  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
477  PatternRewriter &rewriter) const override {
478  VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
479  if (!resultType)
480  return rewriter.notifyMatchFailure(extractOp,
481  "extracted type is not a vector type");
482 
483  auto numScalableDims = llvm::count(resultType.getScalableDims(), true);
484  if (numScalableDims != 1)
485  return rewriter.notifyMatchFailure(
486  extractOp, "extracted type is not a 1-D scalable vector type");
487 
488  auto *extendOp = extractOp.getVector().getDefiningOp();
489  if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
490  extendOp))
491  return rewriter.notifyMatchFailure(extractOp,
492  "extract not from extend op");
493 
494  auto loc = extractOp.getLoc();
495  StringAttr extendOpName = extendOp->getName().getIdentifier();
496  Value extendSource = extendOp->getOperand(0);
497 
498  // Create new extract from source of extend.
499  Value newExtract = rewriter.create<vector::ExtractOp>(
500  loc, extendSource, extractOp.getMixedPosition());
501 
502  // Extend new extract to original result type.
503  Operation *newExtend =
504  rewriter.create(loc, extendOpName, Value(newExtract), resultType);
505 
506  rewriter.replaceOp(extractOp, newExtend);
507 
508  return success();
509  }
510 };
511 
512 // Same as above, but for vector.scalable.extract.
513 //
514 // This transforms IR like:
515 // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
516 // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
517 // Into:
518 // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
519 // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
520 //
521 // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
522 // pass when the result is the input to an outer product.
523 struct SwapVectorScalableExtractOfArithExtend
524  : public OpRewritePattern<vector::ScalableExtractOp> {
526 
527  LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
528  PatternRewriter &rewriter) const override {
529  auto *extendOp = extractOp.getSource().getDefiningOp();
530  if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
531  extendOp))
532  return rewriter.notifyMatchFailure(extractOp,
533  "extract not from extend op");
534 
535  auto loc = extractOp.getLoc();
536  VectorType resultType = extractOp.getResultVectorType();
537 
538  Value extendSource = extendOp->getOperand(0);
539  StringAttr extendOpName = extendOp->getName().getIdentifier();
540  VectorType extendSourceVectorType =
541  cast<VectorType>(extendSource.getType());
542 
543  // Create new extract from source of extend.
544  VectorType extractResultVectorType =
545  resultType.clone(extendSourceVectorType.getElementType());
546  Value newExtract = rewriter.create<vector::ScalableExtractOp>(
547  loc, extractResultVectorType, extendSource, extractOp.getPos());
548 
549  // Extend new extract to original result type.
550  Operation *newExtend =
551  rewriter.create(loc, extendOpName, Value(newExtract), resultType);
552 
553  rewriter.replaceOp(extractOp, newExtend);
554 
555  return success();
556  }
557 };
558 
559 struct OuterProductFusionPass
560  : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
561 
562  void runOnOperation() override {
563  RewritePatternSet patterns(&getContext());
565 
566  if (failed(
567  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
568  signalPassFailure();
569  }
570 };
571 
572 } // namespace
573 
575  RewritePatternSet &patterns) {
576  MLIRContext *context = patterns.getContext();
577  // Note: High benefit to ensure extract(extend) are swapped first.
578  patterns.add<SwapVectorExtractOfArithExtend,
579  SwapVectorScalableExtractOfArithExtend>(context, 1024);
580  patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(context);
581 }
582 
584  return std::make_unique<OuterProductFusionPass>();
585 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
IntegerType getI16Type()
Definition: Builders.cpp:81
FloatType getF32Type()
Definition: Builders.cpp:63
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
FloatType getF16Type()
Definition: Builders.cpp:59
FloatType getBF16Type()
Definition: Builders.cpp:57
IntegerType getI8Type()
Definition: Builders.cpp:79
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:339
std::unique_ptr< Pass > createOuterProductFusionPass()
Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening variants.
void populateOuterProductFusionPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362