LowerVectorContract.cpp
//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}
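// For example, given the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0)
// returns 1 (d0 is the second result), while getResultIndex(map, 1) returns
// std::nullopt since d1 does not appear in the results.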

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renumbered when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
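// For example, removing index 1 from the map (d0, d1, d2) -> (d1, d2) yields
// (d0, d1) -> (d1): the d1 result is dropped and d2 is renumbered to d1.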

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
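// For example, reshapeLoad on a vector<2x3x4xf32> with index = 1 and pos = 0
// extracts position 0 of the middle dimension, producing a vector<2x4xf32>:
// it unrolls the leading dimension and recursively extracts from each slice.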

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
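// This is the inverse of reshapeLoad: e.g. with index = 1 and pos = 0 it
// inserts a vector<2x4xf32> value into position 0 of the middle dimension of
// a vector<2x3x4xf32> result, one leading-dimension slice at a time.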

/// Helper to create arithmetic operation associated with a kind of contraction.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}
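// For instance, a float ADD contraction with a vector accumulator folds into
// a single vector.fma, while an integer ADD becomes arith.muli followed by
// the accumulation emitted by makeArithReduction.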

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %aRow0, %btRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %a[M-1]
/// %btRowLast = vector.extract %bt[N-1]
/// %cLastLast = vector.reduce %aRowLast, %btRowLast
/// %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when vectorContractLowering is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
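///
/// For example (schematic): a contraction such as
/// ```
///   %x = vector.contract ... %a, %b, %c
///          : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
/// ```
/// has one free dimension left and is unrolled into two rank-reduced
/// contractions whose scalar results are re-inserted into a vector<2xf32>.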
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLoweringOption;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap lhs and rhs, then permute the new lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  llvm::SmallVector<Value> extractedCols;
  extractedCols.reserve(dstColumns);
  for (unsigned r = 0; r < dstRows; ++r) {
    Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      // Extract each respective row and column of the LHS and RHS once to
      // avoid having duplicate SSA values pointing to the same rows/columns.
      if (r == 0) {
        Value colRhs =
            rank == 1 ? rhs
                      : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
        extractedCols.push_back(colRhs);
      }
      Value extractedColRhs = extractedCols[c];
      Value product =
          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
      Value sum = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, product);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}

/// Lower vector.contract with all size-one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    if (maskOp)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must have size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    if (result)
      return *result;

    return failure();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();

  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal1 =
      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal1))
    return newVal1;

  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal2 =
      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal2))
    return newVal2;

  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal3 =
      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal3))
    return newVal3;

  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal4 =
      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal4))
    return newVal4;

  // Vector mask setup.
  Value mask;
  if (maskOp)
    mask = maskOp.getMask();
  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  return failure();
}

// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}

// Lower one reduction dimension.
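// For the rank-1 base case this turns, schematically,
//   %x = vector.contract %a, %b, %acc : vector<4xf32>, vector<4xf32> into f32
// into
//   %m = arith.mulf %a, %b : vector<4xf32>
//   %x = vector.reduction <add>, %m, %acc : vector<4xf32> into f32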
FailureOr<Value> ContractionOpLowering::lowerReduction(
    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected also RHS to have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
            : rewriter.create<vector::ReductionOp>(loc, kind, m);
    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value newMask;
    if (mask)
      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *newContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, result, lowAffine, lowIter);
    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
  }
  return result;
}

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %mta = maybe_transpose
/// %mtb = maybe_transpose
/// %flattened_a = vector.shape_cast %mta
/// %flattened_b = vector.shape_cast %mtb
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %mtd = vector.shape_cast %flattened_d
/// %d = maybe_untranspose %mtd
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  return res;
}
} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns,
    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
    bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      vectorContractLoweringOption, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}
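
// A minimal usage sketch (illustrative, not part of this file): clients
// typically populate these patterns into a RewritePatternSet and apply them
// greedily from within a pass, e.g.:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorContractLoweringPatterns(
//       patterns, vector::VectorContractLowering::OuterProduct);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();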