MLIR  22.0.0git
LowerVectorContract.cpp
Go to the documentation of this file.
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
25 
26 #define DEBUG_TYPE "vector-contract-lowering"
27 
28 using namespace mlir;
29 using namespace mlir::vector;
30 
31 //===----------------------------------------------------------------------===//
32 // Helper functions
33 //===----------------------------------------------------------------------===//
34 // Helper to find an index in an affine map.
35 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
36  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
37  int64_t idx = map.getDimPosition(i);
38  if (idx == index)
39  return i;
40  }
41  return std::nullopt;
42 }
43 
44 // Helper to construct iterator types with one index removed.
45 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
46  int64_t index) {
47  SmallVector<Attribute> results;
48  for (const auto &it : llvm::enumerate(iteratorTypes)) {
49  int64_t idx = it.index();
50  if (idx == index)
51  continue;
52  results.push_back(it.value());
53  }
54  return results;
55 }
56 
57 // Helper to construct an affine map with one index removed.
58 static AffineMap adjustMap(AffineMap map, int64_t index,
59  PatternRewriter &rewriter) {
60  auto *ctx = rewriter.getContext();
62  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
63  int64_t idx = map.getDimPosition(i);
64  if (idx == index)
65  continue;
66  // Re-insert remaining indices, but renamed when occurring
67  // after the removed index.
68  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
69  results.push_back(targetExpr);
70  }
71  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
72 }
73 
74 // Helper method to possibly drop a dimension in a load.
75 // TODO
76 static Value reshapeLoad(Location loc, Value val, VectorType type,
77  int64_t index, int64_t pos,
78  PatternRewriter &rewriter) {
79  if (index == -1)
80  return val;
81 
82  // At extraction dimension?
83  if (index == 0)
84  return vector::ExtractOp::create(rewriter, loc, val, pos);
85 
86  // Unroll leading dimensions.
87  VectorType vType = VectorType::Builder(type).dropDim(0);
88  VectorType resType = VectorType::Builder(type).dropDim(index);
89  Value result = arith::ConstantOp::create(rewriter, loc, resType,
90  rewriter.getZeroAttr(resType));
91  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
92  Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
93  Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
94  result = vector::InsertOp::create(rewriter, loc, load, result, d);
95  }
96  return result;
97 }
98 
// Helper method to possibly drop a dimension in a store.
// Inverse of `reshapeLoad`: re-inserts `val` into `result` at position `pos`
// along dimension `index` of `type`, recursively peeling leading dimensions.
// An `index` of -1 means no dimension is dropped and `val` is returned as-is.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return vector::InsertOp::create(rewriter, loc, val, result, pos);

  // Unroll leading dimensions: insert each slice of `val` into the matching
  // slice of `result`, recursing with one less leading dimension.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = vector::ExtractOp::create(rewriter, loc, result, d);
    Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = vector::InsertOp::create(rewriter, loc, sto, result, d);
  }
  return result;
}
121 
122 /// Helper to create arithmetic operation associated with a kind of contraction.
123 static std::optional<Value>
125  vector::CombiningKind kind, PatternRewriter &rewriter,
126  bool isInt, Value mask = Value()) {
127  using vector::CombiningKind;
128  Value mul;
129 
130  if (isInt) {
131  if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
132  kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
133  // Only valid for floating point types.
134  return std::nullopt;
135  mul = arith::MulIOp::create(rewriter, loc, x, y);
136  } else {
137  // Float case.
138  if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
139  kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
140  kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
141  kind == CombiningKind::XOR)
142  // Only valid for integer types.
143  return std::nullopt;
144  // Special case for fused multiply-add.
145  if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
146  Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc);
147  if (mask)
148  // The fma op doesn't need explicit masking. However, fma ops used in
149  // reductions must preserve previous 'acc' values for masked-out lanes.
150  fma = selectPassthru(rewriter, mask, fma, acc);
151  return fma;
152  }
153  mul = arith::MulFOp::create(rewriter, loc, x, y);
154  }
155 
156  if (!acc)
157  return std::optional<Value>(mul);
158 
159  return makeArithReduction(rewriter, loc, kind, mul, acc,
160  /*fastmath=*/nullptr, mask);
161 }
162 
163 /// Return the positions of the reductions in the given map.
165  ArrayAttr iteratorTypes) {
166  SmallVector<int64_t> dimsIdx;
167  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
168  if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
169  dimsIdx.push_back(i);
170  }
171  return dimsIdx;
172 }
173 
174 /// Look for a given dimension in an affine map and return its position. Return
175 /// std::nullopt if the dimension is not in the map results.
176 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
177  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178  if (map.getDimPosition(i) == dim)
179  return i;
180  }
181  return std::nullopt;
182 }
183 
184 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
185 /// operands `x` and `y`.
186 static Value createAdd(Location loc, Value x, Value y, bool isInt,
187  PatternRewriter &rewriter) {
188  if (isInt)
189  return arith::AddIOp::create(rewriter, loc, x, y);
190  return arith::AddFOp::create(rewriter, loc, x, y);
191 }
192 
193 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
194 /// operands `x and `y`.
195 static Value createMul(Location loc, Value x, Value y, bool isInt,
196  PatternRewriter &rewriter) {
197  if (isInt)
198  return arith::MulIOp::create(rewriter, loc, x, y);
199  return arith::MulFOp::create(rewriter, loc, x, y);
200 }
201 
202 namespace {
203 
204 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
205 /// semantics to a reduction_size-unrolled sequence:
206 /// ```
207 /// %at = vector.transpose %a, [1, 0]
208 /// %bRow0 = vector.extract %b[0]
209 /// %atRow0 = vector.extract %at[0]
210 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
211 /// ...
212 /// %bRowK = vector.extract %b[K]
213 /// %atRowK = vector.extract %at[K]
214 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
215 /// ```
216 ///
217 /// This only kicks in when vectorContractLowering is set to OuterProduct and
218 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Predicate deciding whether a given contraction op should be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Filter used when no user-provided constraint is given: accept all ops.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  /// User-provided constraint limiting which ops this pattern applies to.
  FilterConstraintType filter;
};
248 
249 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
250 /// semantics to an output-size-unrolled sequence:
251 /// ```
252 /// %out = arith.constant ... : vector<MxNxelt_type>
253 /// %bt = vector.transpose %b, [1, 0]
254 /// %aRow0 = vector.extract %a[0]
255 /// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %aRow0, %btRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %a[M-1]
/// %btRowLast = vector.extract %bt[N-1]
/// %cLastLast = vector.reduce %aRowLast, %btRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
263 /// ```
264 ///
265 /// This only kicks in when VectorTransformsOptions is set to Dot and
266 /// the vector.contract op is a row-major matmul or matvec.
267 class ContractionOpToDotLowering
268  : public MaskableOpRewritePattern<vector::ContractionOp> {
269 public:
270  using MaskableOpRewritePattern::MaskableOpRewritePattern;
271 
272  using FilterConstraintType =
273  std::function<LogicalResult(vector::ContractionOp op)>;
274 
275  static LogicalResult defaultFilter(vector::ContractionOp op) {
276  return success();
277  }
278 
279  ContractionOpToDotLowering(
280  vector::VectorContractLowering vectorContractLowering,
281  MLIRContext *context, PatternBenefit benefit = 1,
282  const FilterConstraintType &constraint = defaultFilter)
283  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
284  vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
285 
286  FailureOr<Value>
287  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
288  PatternRewriter &rewriter) const override;
289 
290 private:
291  /// Options to control the vector patterns.
292  vector::VectorContractLowering vectorContractLowering;
293  FilterConstraintType filter;
294 };
295 
296 /// Progressive lowering of ContractionOp.
297 ///
298 /// One:
299 /// %x = vector.contract with at least one free/batch dimension
300 /// is replaced by:
301 /// %a = vector.contract with one less free/batch dimension
302 /// %b = vector.contract with one less free/batch dimension
303 /// ..
304 /// %x = combine %a %b ..
305 /// until a pure contraction is reached (no free/batch dimensions),
306 /// which is replaced by a dot-product.
307 ///
308 /// This only kicks in when either VectorTransformsOptions is set
309 /// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  /// Predicate deciding whether a given contraction op should be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Filter used when no user-provided constraint is given: accept all ops.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLoweringOption;
  /// User-provided constraint limiting which ops this pattern applies to.
  FilterConstraintType filter;
  // Lower one parallel dimension (a free or batch dimension), recursing into
  // smaller contractions.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
345 
346 /// Generate a vector implementation for matmat, matvec and tmatvec.
347 /// This unrolls outer-products along the reduction dimension.
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask operand when the contraction is wrapped in vector.mask.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` by `perm` (defaults to a 2-D swap). A null value passes
  /// through unchanged, which lets callers transpose an absent mask.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return vector::TransposeOp::create(rewriter, loc, v, perm);
  }

  /// Promote `v` (scalar or vector) to element type `dstElementType` via
  /// extf/extsi. Returns `v` unchanged when the element types already match.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return arith::ExtFOp::create(rewriter, loc, promotedType, v);
    return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
  }

  /// Emit `reductionSize` chained vector.outerproduct ops accumulating into
  /// `res`. Both operands (and the optional mask) must already have the
  /// reduction dimension outermost. Fails when the op is masked but no
  /// transposed mask was provided.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k);
      Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k);
      // Operands are promoted to the accumulator's element type so the
      // outer product combines homogeneous types.
      extractA = promote(extractA, resElementType)
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);

      Operation *outerProdOp = vector::OuterProductOp::create(
          rewriter, loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Each `layout` case below normalizes operands (and the 3-D mask, moved to
  /// reduction-outermost via the {2, 0, 1} permutation) so the reduction
  /// dimension is leading, then unrolls outer products over it.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  /// Combining kind of the contraction (add, mul, min, ...).
  vector::CombiningKind kind;
  /// Operands, accumulator, and (optional) mask of the contraction.
  Value lhs, rhs, res, mask;
  /// Type of the LHS operand, used to size the reduction unrolling.
  VectorType lhsType;
};
573 
574 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
575 /// semantics to a reduction_size-unrolled sequence:
576 /// ```
577 /// %at = vector.transpose %a, [1, 0]
578 /// %bRow0 = vector.extract %b[0]
579 /// %atRow0 = vector.extract %at[0]
580 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
581 /// ...
582 /// %bRowK = vector.extract %b[K]
583 /// %atRowK = vector.extract %at[K]
584 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
585 /// ```
586 ///
587 /// This only kicks in when vectorContractLowering is set to OuterProduct but
588 /// otherwise supports any layout permutation of the matrix-multiply.
589 FailureOr<Value>
590 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
591  vector::ContractionOp op, MaskingOpInterface maskOp,
592  PatternRewriter &rewriter) const {
593  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
594  return failure();
595 
596  if (failed(filter(op)))
597  return failure();
598 
599  UnrolledOuterProductGenerator e(rewriter, op);
600  FailureOr<Value> matmatRes = e.matmat();
601  if (succeeded(matmatRes)) {
602  return matmatRes;
603  }
604  FailureOr<Value> matvecRes = e.matvec();
605  if (succeeded(matvecRes)) {
606  return matvecRes;
607  }
608 
609  FailureOr<Value> tmatvecRes = e.tmatvec();
610  return tmatvecRes;
611 }
612 
613 FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
614  vector::ContractionOp op, MaskingOpInterface maskOp,
615  PatternRewriter &rewriter) const {
616  // TODO: Support vector.mask.
617  if (maskOp)
618  return failure();
619 
620  if (failed(filter(op)))
621  return failure();
622 
623  if (vectorContractLowering != vector::VectorContractLowering::Dot)
624  return failure();
625 
626  auto iteratorTypes = op.getIteratorTypes().getValue();
627  static constexpr std::array<int64_t, 2> perm = {1, 0};
628  Location loc = op.getLoc();
629  Value lhs = op.getLhs(), rhs = op.getRhs();
630 
631  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
632  auto infer = [&](MapList m) {
633  return AffineMap::inferFromExprList(m, op.getContext());
634  };
635  AffineExpr m, n, k;
636  bindDims(rewriter.getContext(), m, n, k);
637  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
638  //
639  // In the following we wish to make the reduction dimension innermost so we
640  // can load vectors and just fmul + reduce into a scalar.
641  //
642  if (isParallelIterator(iteratorTypes[0]) &&
643  isParallelIterator(iteratorTypes[1]) &&
644  isReductionIterator(iteratorTypes[2])) {
645  //
646  // Two outer parallel, one inner reduction (matmat flavor).
647  //
648  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
649  rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
650  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
651  // No need to permute anything.
652  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
653  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
654  rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
655  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
656  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
657  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
658  // This is the classical row-major matmul. Just permute the lhs.
659  Value tmp = lhs;
660  lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
661  rhs = tmp;
662  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
663  std::swap(lhs, rhs);
664  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
665  Value tmp = lhs;
666  lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
667  rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
668  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
669  Value tmp = rhs;
670  rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
671  lhs = tmp;
672  } else {
673  return failure();
674  }
675  } else if (isParallelIterator(iteratorTypes[0]) &&
676  isReductionIterator(iteratorTypes[1])) {
677  //
678  // One outer parallel, one inner reduction (matvec flavor)
679  //
680  if (maps == infer({{m, n}, {n}, {m}})) {
681  // No need to permute anything.
682  } else if (maps == infer({{n, m}, {n}, {m}})) {
683  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
684  } else if (maps == infer({{n}, {m, n}, {m}})) {
685  std::swap(lhs, rhs);
686  } else if (maps == infer({{n}, {n, m}, {m}})) {
687  std::swap(lhs, rhs);
688  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
689  } else {
690  return failure();
691  }
692  } else {
693  return failure();
694  }
695 
696  VectorType dstType = cast<VectorType>(op.getResultType());
697  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
698  "Expected dst type of rank 1 or 2");
699 
700  unsigned rank = dstType.getRank();
701  unsigned dstRows = dstType.getShape()[0];
702  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
703 
704  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
705  Value res = arith::ConstantOp::create(rewriter, loc, dstType,
706  rewriter.getZeroAttr(dstType));
707  bool isInt = isa<IntegerType>(dstType.getElementType());
708  llvm::SmallVector<Value> extractedCols;
709  extractedCols.reserve(dstColumns);
710  for (unsigned r = 0; r < dstRows; ++r) {
711  Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r);
712  for (unsigned c = 0; c < dstColumns; ++c) {
713  // Extract each respective row and column of the LHS and RHS once to
714  // avoid having duplicate SSA values pointing to the same rows/columns.
715  if (r == 0) {
716  Value colRhs =
717  rank == 1
718  ? rhs
719  : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c);
720  extractedCols.push_back(colRhs);
721  }
722  Value extractedColRhs = extractedCols[c];
723  Value product =
724  createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
725  Value sum = vector::ReductionOp::create(
726  rewriter, op.getLoc(), vector::CombiningKind::ADD, product);
727 
729  : SmallVector<int64_t, 2>{r, c};
730  res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
731  }
732  }
733  if (auto acc = op.getAcc())
734  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
735  return res;
736 }
737 
738 /// Lower vector.contract with all size one reduction dimensions to
739 /// elementwise ops when possible.
740 struct ContractOpToElementwise
741  : public MaskableOpRewritePattern<vector::ContractionOp> {
742  using MaskableOpRewritePattern::MaskableOpRewritePattern;
743  using FilterConstraintType =
744  std::function<LogicalResult(vector::ContractionOp op)>;
745  static LogicalResult defaultFilter(vector::ContractionOp op) {
746  return success();
747  }
748  ContractOpToElementwise(
749  vector::VectorContractLowering vectorContractLowering,
750  MLIRContext *context, PatternBenefit benefit = 1,
751  const FilterConstraintType &constraint = defaultFilter)
752  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
753  vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
754 
755  FailureOr<Value>
756  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
757  MaskingOpInterface maskOp,
758  PatternRewriter &rewriter) const override {
759  // TODO: Support vector.mask.
760  if (maskOp)
761  return failure();
762 
763  if (failed(filter(contractOp)))
764  return failure();
765 
766  if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
767  return failure();
768 
769  ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
770  ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
771  AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
772  AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
773  SmallVector<int64_t> lhsReductionDims =
774  getReductionIndex(lhsMap, contractOp.getIteratorTypes());
775  SmallVector<int64_t> rhsReductionDims =
776  getReductionIndex(rhsMap, contractOp.getIteratorTypes());
777  // All the reduction dimensions must be a size 1.
778  for (int64_t dim : lhsReductionDims) {
779  if (lhsShape[dim] != 1)
780  return failure();
781  }
782  for (int64_t dim : rhsReductionDims) {
783  if (rhsShape[dim] != 1)
784  return failure();
785  }
786  AffineMap accMap = contractOp.getIndexingMapsArray()[2];
787  unsigned numParallelDims = accMap.getNumResults();
788  unsigned numLhsDimToBroadcast =
789  numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
790  unsigned numRhsDimToBroadcast =
791  numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
792  SmallVector<int64_t> lhsDims;
793  SmallVector<int64_t> lhsTranspose;
794  SmallVector<int64_t> rhsDims;
795  SmallVector<int64_t> rhsTranspose;
796  for (int64_t dim : lhsReductionDims)
797  lhsTranspose.push_back(numLhsDimToBroadcast + dim);
798  for (int64_t dim : rhsReductionDims)
799  rhsTranspose.push_back(numRhsDimToBroadcast + dim);
800  // Loop through the parallel dimensions to calculate the dimensions to
801  // broadcast and to permute in order to extract only parallel dimensions.
802  for (unsigned i = 0; i < numParallelDims; i++) {
803  std::optional<unsigned> lhsDim =
804  getDimPosition(lhsMap, accMap.getDimPosition(i));
805  if (lhsDim) {
806  lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
807  } else {
808  // If the parallel dimension doesn't exist we will have to broadcast it.
809  lhsDims.push_back(
810  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
811  lhsTranspose.push_back(lhsDims.size() - 1);
812  }
813  std::optional<unsigned> rhsDim =
814  getDimPosition(rhsMap, accMap.getDimPosition(i));
815  if (rhsDim) {
816  rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
817  } else {
818  // If the parallel dimension doesn't exist we will have to broadcast it.
819  rhsDims.push_back(
820  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
821  rhsTranspose.push_back(rhsDims.size() - 1);
822  }
823  }
824  Value newLhs = contractOp.getLhs();
825  Value newRhs = contractOp.getRhs();
826  Location loc = contractOp.getLoc();
827  if (!lhsDims.empty()) {
828  lhsDims.append(lhsShape.begin(), lhsShape.end());
829  auto expandedType =
830  VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
831  newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
832  }
833  if (!rhsDims.empty()) {
834  rhsDims.append(rhsShape.begin(), rhsShape.end());
835  auto expandedType =
836  VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
837  newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
838  }
839  bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
840  newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
841  newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
842  SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
843  SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
844  newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
845  newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
846  std::optional<Value> result =
847  createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
848  contractOp.getKind(), rewriter, isInt);
849  if (result)
850  return *result;
851 
852  return failure();
853  }
854 
855 private:
856  /// Options to control the vector patterns.
857  vector::VectorContractLowering vectorContractLowering;
858  FilterConstraintType filter;
859 };
860 
861 /// Progressive lowering of ContractionOp.
862 /// One:
863 /// %x = vector.contract with at least one free/batch dimension
864 /// is replaced by:
865 /// %a = vector.contract with one less free/batch dimension
866 /// %b = vector.contract with one less free/batch dimension
867 /// ..
868 /// %x = combine %a %b ..
869 /// until a pure contraction is reached (no free/batch dimensions),
870 /// which is replaced by a dot-product.
871 ///
872 /// This only kicks in when either vectorContractLoweringOption is set
873 /// to DOT or when other contraction patterns fail.
874 //
875 // TODO: break down into transpose/reshape/cast ops
876 // when they become available to avoid code dup
877 // TODO: investigate lowering order impact on performance
878 FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
879  vector::ContractionOp op, MaskingOpInterface maskOp,
880  PatternRewriter &rewriter) const {
881  if (failed(filter(op)))
882  return failure();
883 
884  // TODO: support mixed mode contract lowering.
885  if (op.getLhsType().getElementType() !=
886  getElementTypeOrSelf(op.getAccType()) ||
887  op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
888  return failure();
889 
890  // TODO: the code below assumes the default contraction, make sure it supports
891  // other kinds before enabling this lowering.
892  if (op.getKind() != vector::CombiningKind::ADD) {
893  return rewriter.notifyMatchFailure(
894  op, "contractions other than 'add' not supported");
895  }
896 
897  // TODO: implement benefits, cost models.
898  MLIRContext *ctx = op.getContext();
899 
900  ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
901  FailureOr<Value> newVal1 =
902  pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
903  if (!failed(newVal1))
904  return newVal1;
905 
906  ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
907  FailureOr<Value> newVal2 =
908  pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
909  if (!failed(newVal2))
910  return newVal2;
911 
912  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
913  FailureOr<Value> newVal4 =
914  pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
915  if (!failed(newVal4))
916  return newVal4;
917 
918  // Vector mask setup.
919 
920  Value mask;
921  if (maskOp)
922  mask = maskOp.getMask();
923  // Find first batch dimension in LHS/RHS, and lower when found.
924  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
925  if (!batchDimMap.empty()) {
926  int64_t lhsIndex = batchDimMap[0].first;
927  int64_t rhsIndex = batchDimMap[0].second;
928  auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
929  if (failed(newOp))
930  return failure();
931  return newOp;
932  }
933 
934  // Collect contracting dimensions.
935  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
936  op.getContractingDimMap();
937  DenseSet<int64_t> lhsContractingDimSet;
938  DenseSet<int64_t> rhsContractingDimSet;
939  for (auto &dimPair : contractingDimMap) {
940  lhsContractingDimSet.insert(dimPair.first);
941  rhsContractingDimSet.insert(dimPair.second);
942  }
943 
944  // Find first free dimension in LHS, and lower when found.
945  VectorType lhsType = op.getLhsType();
946  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
947  if (lhsContractingDimSet.count(lhsIndex) == 0) {
948  auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
949  if (failed(newOp))
950  return failure();
951  return newOp;
952  }
953  }
954 
955  // Find first free dimension in RHS, and lower when found.
956  VectorType rhsType = op.getRhsType();
957  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
958  if (rhsContractingDimSet.count(rhsIndex) == 0) {
959  auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
960  if (failed(newOp))
961  return failure();
962  return newOp;
963  }
964  }
965 
966  // Lower the first remaining reduction dimension.
967  if (!contractingDimMap.empty()) {
968  auto newOp = lowerReduction(rewriter, op, mask);
969  if (failed(newOp))
970  return failure();
971  return newOp;
972  }
973 
974  return failure();
975 }
976 
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  // Resolve the unrolled iterator position and its trip count from whichever
  // operand (LHS preferred) carries the dimension; a negative index means the
  // dimension does not appear on that side.
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Unrolling requires a static trip count, so scalable dims are rejected.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // Each lowered contract drops the unrolled iterator from all three maps and
  // from the iterator-type list.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // Start from an all-zero result and insert each slice's contraction result.
  Location loc = op.getLoc();
  Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                           rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    // Slice LHS/RHS/acc at position d along the unrolled dimension.
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    // If masked, slice the mask along the same iterator as well.
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
    // Re-apply masking (no-op when lowMask is null), then store the slice's
    // result back into the aggregate.
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
1064 
1065 // Lower one reduction dimension.
1066 FailureOr<Value> ContractionOpLowering::lowerReduction(
1067  PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1068  auto loc = op.getLoc();
1069  VectorType lhsType = op.getLhsType();
1070  VectorType rhsType = op.getRhsType();
1071  Type resType = op.getResultType();
1072  if (isa<VectorType>(resType))
1073  return rewriter.notifyMatchFailure(op,
1074  "did not expect a VectorType result");
1075  bool isInt = isa<IntegerType>(resType);
1076  // Use iterator index 0.
1077  int64_t iterIndex = 0;
1078  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1079  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1080  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1081  if (!lookupLhs.has_value())
1082  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1083  diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1084  });
1085  if (!lookupRhs.has_value())
1086  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1087  diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1088  });
1089  int64_t lhsIndex = *lookupLhs;
1090  int64_t rhsIndex = *lookupRhs;
1091  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1092  if (dimSize != rhsType.getDimSize(rhsIndex))
1093  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1094  diag << "expect LHS dimension " << lhsIndex
1095  << " to have the same size as RHS dimension " << rhsIndex;
1096  });
1097  // Base case.
1098  if (lhsType.getRank() == 1) {
1099  if (rhsType.getRank() != 1)
1100  return rewriter.notifyMatchFailure(
1101  op, "When LHS has rank 1, expected also RHS to have rank 1");
1102  Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1103  auto kind = vector::CombiningKind::ADD;
1104 
1105  Value acc = op.getAcc();
1106  Operation *reductionOp =
1107  acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc)
1108  : vector::ReductionOp::create(rewriter, loc, kind, m);
1109  return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1110  }
1111  // Construct new iterator types and affine map array attribute.
1112  std::array<AffineMap, 3> lowIndexingMaps = {
1113  adjustMap(iMap[0], iterIndex, rewriter),
1114  adjustMap(iMap[1], iterIndex, rewriter),
1115  adjustMap(iMap[2], iterIndex, rewriter)};
1116  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1117  auto lowIter =
1118  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1119  // Unroll into a series of lower dimensional vector.contract ops.
1120  // By feeding the initial accumulator into the first contraction,
1121  // and the result of each contraction into the next, eventually
1122  // the sum of all reductions is computed.
1123  Value result = op.getAcc();
1124  for (int64_t d = 0; d < dimSize; ++d) {
1125  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1126  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1127  Value newMask;
1128  if (mask)
1129  newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1130  iterIndex, d, rewriter);
1131 
1132  Operation *newContract = vector::ContractionOp::create(
1133  rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
1134  result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1135  }
1136  return result;
1137 }
1138 
1139 /// Progressive lowering of OuterProductOp.
1140 /// One:
1141 /// %x = vector.outerproduct %lhs, %rhs, %acc
1142 /// is replaced by:
1143 /// %z = zero-result
1144 /// %0 = vector.extract %lhs[0]
1145 /// %1 = vector.broadcast %0
1146 /// %2 = vector.extract %acc[0]
1147 /// %3 = vector.fma %1, %rhs, %2
1148 /// %4 = vector.insert %3, %z[0]
1149 /// ..
1150 /// %x = vector.insert %.., %..[N-1]
1151 ///
1152 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1153 public:
1155 
1156  LogicalResult matchAndRewrite(vector::OuterProductOp op,
1157  PatternRewriter &rewriter) const override {
1158  VectorType resType = op.getResultVectorType();
1159  if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1160  return failure();
1161 
1162  auto loc = op.getLoc();
1163 
1164  VectorType lhsType = op.getOperandVectorTypeLHS();
1165  VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1166  Type eltType = resType.getElementType();
1167  bool isInt = isa<IntegerType, IndexType>(eltType);
1168  Value acc = op.getAcc();
1169  vector::CombiningKind kind = op.getKind();
1170 
1171  // Vector mask setup.
1172  OpBuilder::InsertionGuard guard(rewriter);
1173  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1174  Operation *rootOp;
1175  Value mask;
1176  if (maskableOp.isMasked()) {
1177  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1178  rootOp = maskableOp.getMaskingOp();
1179  mask = maskableOp.getMaskingOp().getMask();
1180  } else {
1181  rootOp = op;
1182  }
1183 
1184  if (!rhsType) {
1185  // Special case: AXPY operation.
1186  Value b =
1187  vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1188  std::optional<Value> mult = createContractArithOp(
1189  loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1190  if (!mult.has_value())
1191  return failure();
1192  rewriter.replaceOp(rootOp, *mult);
1193  return success();
1194  }
1195 
1196  Value result = arith::ConstantOp::create(rewriter, loc, resType,
1197  rewriter.getZeroAttr(resType));
1198  for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1199  Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1200  Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1201  Value r = nullptr;
1202  if (acc)
1203  r = vector::ExtractOp::create(rewriter, loc, acc, d);
1204  Value extrMask;
1205  if (mask)
1206  extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1207 
1208  std::optional<Value> m = createContractArithOp(
1209  loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1210  if (!m.has_value())
1211  return failure();
1212  result = vector::InsertOp::create(rewriter, loc, *m, result, d);
1213  }
1214 
1215  rewriter.replaceOp(rootOp, result);
1216  return success();
1217  }
1218 };
1219 
1220 } // namespace
1221 
1224  VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1225  bool disableOuterProductLowering) {
1226  if (!disableOuterProductLowering)
1227  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1228  patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1229  vectorContractLoweringOption, patterns.getContext(), benefit);
1230 }
1231 
1234  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1235 }
static int64_t product(ArrayRef< int64_t > vals)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true otherwise create an MulFOp using operands x andy`.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
#define MINUI(lhs, rhs)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:411
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:313
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:311
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:664
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:154
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:149
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:163