MLIR  19.0.0git
LowerVectorContract.cpp
Go to the documentation of this file.
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
36 #define DEBUG_TYPE "vector-contract-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 //===----------------------------------------------------------------------===//
42 // Helper functions
43 //===----------------------------------------------------------------------===//
44 
45 // Helper to find an index in an affine map.
46 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
47  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
48  int64_t idx = map.getDimPosition(i);
49  if (idx == index)
50  return i;
51  }
52  return std::nullopt;
53 }
54 
55 // Helper to construct iterator types with one index removed.
56 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
57  int64_t index) {
58  SmallVector<Attribute> results;
59  for (const auto &it : llvm::enumerate(iteratorTypes)) {
60  int64_t idx = it.index();
61  if (idx == index)
62  continue;
63  results.push_back(it.value());
64  }
65  return results;
66 }
67 
68 // Helper to construct an affine map with one index removed.
69 static AffineMap adjustMap(AffineMap map, int64_t index,
70  PatternRewriter &rewriter) {
71  auto *ctx = rewriter.getContext();
73  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
74  int64_t idx = map.getDimPosition(i);
75  if (idx == index)
76  continue;
77  // Re-insert remaining indices, but renamed when occurring
78  // after the removed index.
79  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
80  results.push_back(targetExpr);
81  }
82  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
83 }
84 
85 // Helper method to possibly drop a dimension in a load.
86 // TODO
87 static Value reshapeLoad(Location loc, Value val, VectorType type,
88  int64_t index, int64_t pos,
89  PatternRewriter &rewriter) {
90  if (index == -1)
91  return val;
92 
93  // At extraction dimension?
94  if (index == 0)
95  return rewriter.create<vector::ExtractOp>(loc, val, pos);
96 
97  // Unroll leading dimensions.
98  VectorType vType = VectorType::Builder(type).dropDim(0);
99  VectorType resType = VectorType::Builder(type).dropDim(index);
100  Value result = rewriter.create<arith::ConstantOp>(
101  loc, resType, rewriter.getZeroAttr(resType));
102  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
103  Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
104  Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
105  result = rewriter.create<vector::InsertOp>(loc, load, result, d);
106  }
107  return result;
108 }
109 
110 // Helper method to possibly drop a dimension in a store.
111 // TODO
112 static Value reshapeStore(Location loc, Value val, Value result,
113  VectorType type, int64_t index, int64_t pos,
114  PatternRewriter &rewriter) {
115  // Unmodified?
116  if (index == -1)
117  return val;
118  // At insertion dimension?
119  if (index == 0)
120  return rewriter.create<vector::InsertOp>(loc, val, result, pos);
121 
122  // Unroll leading dimensions.
123  VectorType vType = VectorType::Builder(type).dropDim(0);
124  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
125  Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
126  Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
127  Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
128  result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
129  }
130  return result;
131 }
132 
133 /// Helper to create arithmetic operation associated with a kind of contraction.
134 static std::optional<Value>
136  vector::CombiningKind kind, PatternRewriter &rewriter,
137  bool isInt, Value mask = Value()) {
138  using vector::CombiningKind;
139  Value mul;
140 
141  if (isInt) {
142  if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
143  kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
144  // Only valid for floating point types.
145  return std::nullopt;
146  mul = rewriter.create<arith::MulIOp>(loc, x, y);
147  } else {
148  // Float case.
149  if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
150  kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
151  kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
152  kind == CombiningKind::XOR)
153  // Only valid for integer types.
154  return std::nullopt;
155  // Special case for fused multiply-add.
156  if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
157  Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
158  if (mask)
159  // The fma op doesn't need explicit masking. However, fma ops used in
160  // reductions must preserve previous 'acc' values for masked-out lanes.
161  fma = selectPassthru(rewriter, mask, fma, acc);
162  return fma;
163  }
164  mul = rewriter.create<arith::MulFOp>(loc, x, y);
165  }
166 
167  if (!acc)
168  return std::optional<Value>(mul);
169 
170  return makeArithReduction(rewriter, loc, kind, mul, acc,
171  /*fastmath=*/nullptr, mask);
172 }
173 
174 /// Return the positions of the reductions in the given map.
176  ArrayAttr iteratorTypes) {
177  SmallVector<int64_t> dimsIdx;
178  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
179  if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
180  dimsIdx.push_back(i);
181  }
182  return dimsIdx;
183 }
184 
185 /// Look for a given dimension in an affine map and return its position. Return
186 /// std::nullopt if the dimension is not in the map results.
187 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
188  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
189  if (map.getDimPosition(i) == dim)
190  return i;
191  }
192  return std::nullopt;
193 }
194 
195 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
196 /// operands `x` and `y`.
197 static Value createAdd(Location loc, Value x, Value y, bool isInt,
198  PatternRewriter &rewriter) {
199  if (isInt)
200  return rewriter.create<arith::AddIOp>(loc, x, y);
201  return rewriter.create<arith::AddFOp>(loc, x, y);
202 }
203 
204 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
205 /// operands `x and `y`.
206 static Value createMul(Location loc, Value x, Value y, bool isInt,
207  PatternRewriter &rewriter) {
208  if (isInt)
209  return rewriter.create<arith::MulIOp>(loc, x, y);
210  return rewriter.create<arith::MulFOp>(loc, x, y);
211 }
212 
213 namespace {
214 
215 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
216 /// semantics to:
217 /// ```
218 /// %flattened_a = vector.shape_cast %a
219 /// %flattened_b = vector.shape_cast %b
220 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
221 /// %d = vector.shape_cast %%flattened_d
222 /// %e = add %c, %d
223 /// ```
224 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
225 //
226 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
227 /// the vector.contract op is a row-major matrix multiply.
228 class ContractionOpToMatmulOpLowering
229  : public OpRewritePattern<vector::ContractionOp> {
230 public:
232 
233  using FilterConstraintType =
234  std::function<LogicalResult(vector::ContractionOp op)>;
235 
236  static LogicalResult defaultFilter(vector::ContractionOp op) {
237  return success();
238  }
239 
240  ContractionOpToMatmulOpLowering(
241  vector::VectorTransformsOptions vectorTransformOptions,
242  MLIRContext *context, PatternBenefit benefit = 1,
243  FilterConstraintType constraint = defaultFilter)
244  : OpRewritePattern<vector::ContractionOp>(context, benefit),
245  vectorTransformOptions(vectorTransformOptions),
246  filter(std::move(constraint)) {}
247 
248  LogicalResult matchAndRewrite(vector::ContractionOp op,
249  PatternRewriter &rewriter) const override;
250 
251 private:
252  /// Options to control the vector patterns.
253  vector::VectorTransformsOptions vectorTransformOptions;
254  FilterConstraintType filter;
255 };
256 
257 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
258 /// semantics to a reduction_size-unrolled sequence:
259 /// ```
260 /// %at = vector.transpose %a, [1, 0]
261 /// %bRow0 = vector.extract %b[0]
262 /// %atRow0 = vector.extract %at[0]
263 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
264 /// ...
265 /// %bRowK = vector.extract %b[K]
266 /// %atRowK = vector.extract %at[K]
267 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
268 /// ```
269 ///
270 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
271 /// the vector.contract op is a row-major matrix multiply.
272 class ContractionOpToOuterProductOpLowering
273  : public OpRewritePattern<vector::ContractionOp> {
274 public:
276 
277  using FilterConstraintType =
278  std::function<LogicalResult(vector::ContractionOp op)>;
279 
280  static LogicalResult defaultFilter(vector::ContractionOp op) {
281  return success();
282  }
283 
284  ContractionOpToOuterProductOpLowering(
285  vector::VectorTransformsOptions vectorTransformOptions,
286  MLIRContext *context, PatternBenefit benefit = 1,
287  FilterConstraintType constraint = defaultFilter)
288  : OpRewritePattern<vector::ContractionOp>(context, benefit),
289  vectorTransformOptions(vectorTransformOptions),
290  filter(std::move(constraint)) {}
291 
292  LogicalResult matchAndRewrite(vector::ContractionOp op,
293  PatternRewriter &rewriter) const override;
294 
295 private:
296  /// Options to control the vector patterns.
297  vector::VectorTransformsOptions vectorTransformOptions;
298  FilterConstraintType filter;
299 };
300 
301 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
302 /// semantics to an output-size-unrolled sequence:
303 /// ```
304 /// %out = arith.constant ... : vector<MxNxelt_type>
305 /// %bt = vector.transpose %b, [1, 0]
306 /// %aRow0 = vector.extract %a[0]
307 /// %btRow0 = vector.extract %bt[0]
308 /// %c00 = vector.reduce %atRow0, %bRow0
309 /// %out00 = vector.insert %c00, %out[0, 0]
310 /// ...
311 /// %aRowLast = vector.extract %at[M-1]
312 /// %btRowLast = vector.extract %b[N-1]
313 /// %cLastLast = vector.reduce %atRowLast, %bRowLast
314 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
315 /// ```
316 ///
317 /// This only kicks in when VectorTransformsOptions is set to Dot and
318 /// the vector.contract op is a row-major matmul or matvec.
319 class ContractionOpToDotLowering
320  : public OpRewritePattern<vector::ContractionOp> {
321 public:
323 
324  using FilterConstraintType =
325  std::function<LogicalResult(vector::ContractionOp op)>;
326 
327  static LogicalResult defaultFilter(vector::ContractionOp op) {
328  return success();
329  }
330 
331  ContractionOpToDotLowering(
332  vector::VectorTransformsOptions vectorTransformOptions,
333  MLIRContext *context, PatternBenefit benefit = 1,
334  const FilterConstraintType &constraint = defaultFilter)
335  : OpRewritePattern<vector::ContractionOp>(context, benefit),
336  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
337 
338  LogicalResult matchAndRewrite(vector::ContractionOp op,
339  PatternRewriter &rewriter) const override;
340 
341 private:
342  /// Options to control the vector patterns.
343  vector::VectorTransformsOptions vectorTransformOptions;
344  FilterConstraintType filter;
345 };
346 
347 /// Progressive lowering of ContractionOp.
348 ///
349 /// One:
350 /// %x = vector.contract with at least one free/batch dimension
351 /// is replaced by:
352 /// %a = vector.contract with one less free/batch dimension
353 /// %b = vector.contract with one less free/batch dimension
354 /// ..
355 /// %x = combine %a %b ..
356 /// until a pure contraction is reached (no free/batch dimensions),
357 /// which is replaced by a dot-product.
358 ///
359 /// This only kicks in when either VectorTransformsOptions is set
360 /// to Dot or when other contraction patterns fail.
361 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
362 public:
364  using FilterConstraintType =
365  std::function<LogicalResult(vector::ContractionOp op)>;
366 
367  static LogicalResult defaultFilter(vector::ContractionOp op) {
368  return success();
369  }
370 
371  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
372  MLIRContext *context, PatternBenefit benefit = 1,
373  FilterConstraintType constraint = defaultFilter)
374  : OpRewritePattern<vector::ContractionOp>(context, benefit),
375  vectorTransformOptions(vectorTransformOptions),
376  filter(std::move(constraint)) {}
377 
378  LogicalResult matchAndRewrite(vector::ContractionOp op,
379  PatternRewriter &rewriter) const override;
380 
381 private:
382  /// Options to control the vector patterns.
383  vector::VectorTransformsOptions vectorTransformOptions;
384  FilterConstraintType filter;
385  // Lower one parallel dimension.
386  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
387  vector::ContractionOp op, int64_t lhsIndex,
388  int64_t rhsIndex, Value mask) const;
389  // Lower one reduction dimension.
390  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
391  vector::ContractionOp op, Value mask) const;
392 };
393 
394 /// Generate a vector implementation for matmat, matvec and tmatvec.
395 /// This unrolls outer-products along the reduction dimension.
396 struct UnrolledOuterProductGenerator
397  : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
398  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
399  : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
400  kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
401  res(op.getAcc()), lhsType(op.getLhsType()) {
402  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
403  if (maskableOp.isMasked())
404  mask = maskableOp.getMaskingOp().getMask();
405  }
406 
407  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
408  if (!v)
409  return v;
410  return rewriter.create<vector::TransposeOp>(loc, v, perm);
411  }
412 
413  Value promote(Value v, Type dstElementType) {
414  Type elementType = v.getType();
415  auto vecType = dyn_cast<VectorType>(elementType);
416  if (vecType)
417  elementType = vecType.getElementType();
418  if (elementType == dstElementType)
419  return v;
420  Type promotedType = dstElementType;
421  if (vecType)
422  promotedType = vecType.clone(promotedType);
423  if (isa<FloatType>(dstElementType))
424  return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
425  return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
426  }
427 
428  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
429  VectorType lhsType, int reductionSize,
430  std::optional<Value> maybeMask = std::nullopt) {
431  // Incremental support for masking.
432  if (mask && !maybeMask.has_value())
433  return failure();
434 
435  Type resElementType = cast<VectorType>(res.getType()).getElementType();
436  for (int64_t k = 0; k < reductionSize; ++k) {
437  Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
438  Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
439  extractA = promote(extractA, resElementType);
440  extractB = promote(extractB, resElementType);
441  Value extractMask;
442  if (maybeMask.has_value() && maybeMask.value())
443  extractMask =
444  rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
445 
446  Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
447  loc, res.getType(), extractA, extractB, res, kind);
448  res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
449  }
450  return res;
451  }
452 
453  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
454  /// dimension `reductionDim`. If the dimension is a scalable dimension,
455  /// returns "nullopt".
456  std::optional<int64_t> getReductionSize(VectorType vecType,
457  int64_t reductionDim) {
458  // Cannot unroll scalable dimension.
459  if (vecType.getScalableDims()[reductionDim])
460  return std::nullopt;
461  int64_t reductionSize = vecType.getDimSize(reductionDim);
462  assert(reductionSize > 0 &&
463  "Reduction dim must be a known static size to allow unrolling");
464  return reductionSize;
465  }
466 
467  /// Two outer parallel, one inner reduction (matmat flavor).
468  FailureOr<Value> matmat() {
469  if (!iters({Par(), Par(), Red()}))
470  return failure();
471  // Set up the parallel/reduction structure in the right form.
472  AffineExpr m, n, k;
473  bindDims(rewriter.getContext(), m, n, k);
474 
475  // Classical row-major matmul: Just permute the lhs.
476  if (layout({{m, k}, {k, n}, {m, n}})) {
477  if (auto reductionSize = getReductionSize(lhsType, 1)) {
478  // Note: `t` creates new IR. It must be nested within this `if` check
479  // so that no IR is created when then pattern returns "failure".
480  Value tLhs = t(lhs);
481  Value tMask = t(mask, {2, 0, 1});
482  return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
483  }
484  }
485  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
486  if (layout({{m, k}, {n, k}, {m, n}})) {
487  if (auto reductionSize = getReductionSize(lhsType, 1)) {
488  Value tLhs = t(lhs);
489  Value tRhs = t(rhs);
490  Value tMask = t(mask, {2, 0, 1});
491  return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
492  }
493  }
494  // No need to permute anything.
495  if (layout({{k, m}, {k, n}, {m, n}})) {
496  if (auto reductionSize = getReductionSize(lhsType, 0)) {
497  Value tMask = t(mask, {2, 0, 1});
498  return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
499  }
500  }
501  // Just permute the rhs.
502  if (layout({{k, m}, {n, k}, {m, n}})) {
503  if (auto reductionSize = getReductionSize(lhsType, 0)) {
504  Value tRhs = t(rhs);
505  Value tMask = t(mask, {2, 0, 1});
506  return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
507  }
508  }
509  // Transposed output: swap RHS and LHS.
510  // Classical row-major matmul: permute the lhs.
511  if (layout({{m, k}, {k, n}, {n, m}})) {
512  if (auto reductionSize = getReductionSize(lhsType, 1)) {
513  Value tLhs = t(lhs);
514  Value tMask = t(mask, {2, 0, 1});
515  return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
516  }
517  }
518  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
519  if (layout({{m, k}, {n, k}, {n, m}})) {
520  if (auto reductionSize = getReductionSize(lhsType, 1)) {
521  Value tRhs = t(rhs);
522  Value tLhs = t(lhs);
523  Value tMask = t(mask, {2, 0, 1});
524  return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
525  }
526  }
527  if (layout({{k, m}, {k, n}, {n, m}})) {
528  if (auto reductionSize = getReductionSize(lhsType, 0)) {
529  Value tMask = t(mask, {2, 0, 1});
530  return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
531  }
532  }
533  if (layout({{k, m}, {n, k}, {n, m}})) {
534  if (auto reductionSize = getReductionSize(lhsType, 0)) {
535  Value tRhs = t(rhs);
536  Value tMask = t(mask, {2, 0, 1});
537  return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
538  }
539  }
540  return failure();
541  }
542 
543  //
544  // One outer parallel, one inner reduction (matvec flavor).
545  // Mask needs to be transposed everywhere to turn the reduction dimension
546  // outermost as required by outerproduct.
547  //
548  FailureOr<Value> matvec() {
549  if (!iters({Par(), Red()}))
550  return failure();
551  AffineExpr m, k;
552  bindDims(rewriter.getContext(), m, k);
553 
554  // Case mat-vec: transpose.
555  if (layout({{m, k}, {k}, {m}})) {
556  if (auto reductionSize = getReductionSize(lhsType, 1)) {
557  Value tLhs = t(lhs);
558  Value tMask = t(mask);
559  return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
560  }
561  }
562  // Case mat-trans-vec: ready to go.
563  if (layout({{k, m}, {k}, {m}})) {
564  if (auto reductionSize = getReductionSize(lhsType, 0)) {
565  Value tMask = t(mask);
566  return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
567  }
568  }
569  // Case vec-mat: swap and transpose.
570  if (layout({{k}, {m, k}, {m}})) {
571  if (auto reductionSize = getReductionSize(lhsType, 0)) {
572  Value tRhs = t(rhs);
573  Value tMask = t(mask);
574  return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
575  }
576  }
577  // Case vec-mat-trans: swap and ready to go.
578  if (layout({{k}, {k, m}, {m}})) {
579  if (auto reductionSize = getReductionSize(lhsType, 0)) {
580  Value tMask = t(mask);
581  return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
582  }
583  }
584  return failure();
585  }
586 
587  //
588  // One outer reduction, one inner parallel (tmatvec flavor).
589  // Mask already has the shape of the outer product.
590  //
591  FailureOr<Value> tmatvec() {
592  if (!iters({Red(), Par()}))
593  return failure();
594  AffineExpr k, m;
595  bindDims(rewriter.getContext(), k, m);
596 
597  // Case mat-vec: transpose.
598  if (layout({{m, k}, {k}, {m}}))
599  if (auto reductionSize = getReductionSize(lhsType, 1))
600  return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
601  // Case mat-trans-vec: ready to go.
602  if (layout({{k, m}, {k}, {m}}))
603  if (auto reductionSize = getReductionSize(lhsType, 0))
604  return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
605  // Case vec-mat: swap and transpose.
606  if (layout({{k}, {m, k}, {m}}))
607  if (auto reductionSize = getReductionSize(lhsType, 0))
608  return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
609  // Case vec-mat-trans: swap and ready to go.
610  if (layout({{k}, {k, m}, {m}}))
611  if (auto reductionSize = getReductionSize(lhsType, 0))
612  return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
613  return failure();
614  }
615 
616 private:
617  vector::CombiningKind kind;
618  Value lhs, rhs, res, mask;
619  VectorType lhsType;
620 };
621 
622 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
623 /// semantics to a reduction_size-unrolled sequence:
624 /// ```
625 /// %at = vector.transpose %a, [1, 0]
626 /// %bRow0 = vector.extract %b[0]
627 /// %atRow0 = vector.extract %at[0]
628 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
629 /// ...
630 /// %bRowK = vector.extract %b[K]
631 /// %atRowK = vector.extract %at[K]
632 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
633 /// ```
634 ///
635 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
636 /// otherwise supports any layout permutation of the matrix-multiply.
637 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
638  vector::ContractionOp op, PatternRewriter &rewriter) const {
639  if (vectorTransformOptions.vectorContractLowering !=
640  vector::VectorContractLowering::OuterProduct)
641  return failure();
642 
643  if (failed(filter(op)))
644  return failure();
645 
646  // Vector mask setup.
647  OpBuilder::InsertionGuard guard(rewriter);
648  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
649  Operation *rootOp;
650  if (maskableOp.isMasked()) {
651  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
652  rootOp = maskableOp.getMaskingOp();
653  } else {
654  rootOp = op;
655  }
656 
657  UnrolledOuterProductGenerator e(rewriter, op);
658  FailureOr<Value> matmatRes = e.matmat();
659  if (succeeded(matmatRes)) {
660  rewriter.replaceOp(rootOp, *matmatRes);
661  return success();
662  }
663  FailureOr<Value> matvecRes = e.matvec();
664  if (succeeded(matvecRes)) {
665  rewriter.replaceOp(rootOp, *matvecRes);
666  return success();
667  }
668  FailureOr<Value> tmatvecRes = e.tmatvec();
669  if (succeeded(tmatvecRes)) {
670  rewriter.replaceOp(rootOp, *tmatvecRes);
671  return success();
672  }
673 
674  return failure();
675 }
676 
678 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
679  PatternRewriter &rewriter) const {
680  // TODO: Support vector.mask.
681  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
682  if (maskableOp.isMasked())
683  return failure();
684 
685  if (failed(filter(op)))
686  return failure();
687 
688  if (vectorTransformOptions.vectorContractLowering !=
689  vector::VectorContractLowering::Dot)
690  return failure();
691 
692  auto iteratorTypes = op.getIteratorTypes().getValue();
693  static constexpr std::array<int64_t, 2> perm = {1, 0};
694  Location loc = op.getLoc();
695  Value lhs = op.getLhs(), rhs = op.getRhs();
696 
697  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
698  auto infer = [&](MapList m) {
700  };
701  AffineExpr m, n, k;
702  bindDims(rewriter.getContext(), m, n, k);
703  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
704  //
705  // In the following we wish to make the reduction dimension innermost so we
706  // can load vectors and just fmul + reduce into a scalar.
707  //
708  if (isParallelIterator(iteratorTypes[0]) &&
709  isParallelIterator(iteratorTypes[1]) &&
710  isReductionIterator(iteratorTypes[2])) {
711  //
712  // Two outer parallel, one inner reduction (matmat flavor).
713  //
714  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
715  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
716  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
717  // No need to permute anything.
718  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
719  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
720  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
721  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
722  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
723  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
724  // This is the classical row-major matmul. Just permute the lhs.
725  Value tmp = lhs;
726  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
727  rhs = tmp;
728  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
729  std::swap(lhs, rhs);
730  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
731  Value tmp = lhs;
732  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
733  rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
734  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
735  Value tmp = rhs;
736  rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
737  lhs = tmp;
738  } else {
739  return failure();
740  }
741  } else if (isParallelIterator(iteratorTypes[0]) &&
742  isReductionIterator(iteratorTypes[1])) {
743  //
744  // One outer parallel, one inner reduction (matvec flavor)
745  //
746  if (maps == infer({{m, n}, {n}, {m}})) {
747  // No need to permute anything.
748  } else if (maps == infer({{n, m}, {n}, {m}})) {
749  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
750  } else if (maps == infer({{n}, {m, n}, {m}})) {
751  std::swap(lhs, rhs);
752  } else if (maps == infer({{n}, {n, m}, {m}})) {
753  std::swap(lhs, rhs);
754  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
755  } else {
756  return failure();
757  }
758  } else {
759  return failure();
760  }
761 
762  VectorType dstType = cast<VectorType>(op.getResultType());
763  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
764  "Expected dst type of rank 1 or 2");
765 
766  unsigned rank = dstType.getRank();
767  unsigned dstRows = dstType.getShape()[0];
768  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
769 
770  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
771  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
772  rewriter.getZeroAttr(dstType));
773  bool isInt = isa<IntegerType>(dstType.getElementType());
774  for (unsigned r = 0; r < dstRows; ++r) {
775  Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
776  for (unsigned c = 0; c < dstColumns; ++c) {
777  Value b = rank == 1
778  ? rhs
779  : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
780  Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
781  Value reduced = rewriter.create<vector::ReductionOp>(
782  op.getLoc(), vector::CombiningKind::ADD, m);
783 
785  : SmallVector<int64_t, 2>{r, c};
786  res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
787  }
788  }
789  if (auto acc = op.getAcc())
790  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
791  rewriter.replaceOp(op, res);
792  return success();
793 }
794 
795 /// Lower vector.contract with all size one reduction dimensions to
796 /// elementwise ops when possible.
797 struct ContractOpToElementwise
798  : public OpRewritePattern<vector::ContractionOp> {
800  using FilterConstraintType =
801  std::function<LogicalResult(vector::ContractionOp op)>;
802  static LogicalResult defaultFilter(vector::ContractionOp op) {
803  return success();
804  }
805  ContractOpToElementwise(
806  vector::VectorTransformsOptions vectorTransformOptions,
807  MLIRContext *context, PatternBenefit benefit = 1,
808  const FilterConstraintType &constraint = defaultFilter)
809  : OpRewritePattern<vector::ContractionOp>(context, benefit),
810  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
811 
812  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
813  PatternRewriter &rewriter) const override {
814  // TODO: Support vector.mask.
815  auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
816  if (maskableOp.isMasked())
817  return failure();
818 
819  if (failed(filter(contractOp)))
820  return failure();
821 
822  if (vectorTransformOptions.vectorContractLowering !=
823  vector::VectorContractLowering::ParallelArith)
824  return failure();
825 
826  ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
827  ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
828  AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
829  AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
830  SmallVector<int64_t> lhsReductionDims =
831  getReductionIndex(lhsMap, contractOp.getIteratorTypes());
832  SmallVector<int64_t> rhsReductionDims =
833  getReductionIndex(rhsMap, contractOp.getIteratorTypes());
834  // All the reduction dimensions must be a size 1.
835  for (int64_t dim : lhsReductionDims) {
836  if (lhsShape[dim] != 1)
837  return failure();
838  }
839  for (int64_t dim : rhsReductionDims) {
840  if (rhsShape[dim] != 1)
841  return failure();
842  }
843  AffineMap accMap = contractOp.getIndexingMapsArray()[2];
844  unsigned numParallelDims = accMap.getNumResults();
845  unsigned numLhsDimToBroadcast =
846  numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
847  unsigned numRhsDimToBroadcast =
848  numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
849  SmallVector<int64_t> lhsDims;
850  SmallVector<int64_t> lhsTranspose;
851  SmallVector<int64_t> rhsDims;
852  SmallVector<int64_t> rhsTranspose;
853  for (int64_t dim : lhsReductionDims)
854  lhsTranspose.push_back(numLhsDimToBroadcast + dim);
855  for (int64_t dim : rhsReductionDims)
856  rhsTranspose.push_back(numRhsDimToBroadcast + dim);
857  // Loop through the parallel dimensions to calculate the dimensions to
858  // broadcast and to permute in order to extract only parallel dimensions.
859  for (unsigned i = 0; i < numParallelDims; i++) {
860  std::optional<unsigned> lhsDim =
861  getDimPosition(lhsMap, accMap.getDimPosition(i));
862  if (lhsDim) {
863  lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
864  } else {
865  // If the parallel dimension doesn't exist we will have to broadcast it.
866  lhsDims.push_back(
867  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
868  lhsTranspose.push_back(lhsDims.size() - 1);
869  }
870  std::optional<unsigned> rhsDim =
871  getDimPosition(rhsMap, accMap.getDimPosition(i));
872  if (rhsDim) {
873  rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
874  } else {
875  // If the parallel dimension doesn't exist we will have to broadcast it.
876  rhsDims.push_back(
877  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
878  rhsTranspose.push_back(rhsDims.size() - 1);
879  }
880  }
881  Value newLhs = contractOp.getLhs();
882  Value newRhs = contractOp.getRhs();
883  Location loc = contractOp.getLoc();
884  if (!lhsDims.empty()) {
885  lhsDims.append(lhsShape.begin(), lhsShape.end());
886  auto expandedType =
887  VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
888  newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
889  }
890  if (!rhsDims.empty()) {
891  rhsDims.append(rhsShape.begin(), rhsShape.end());
892  auto expandedType =
893  VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
894  newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
895  }
896  bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
897  newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
898  newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
899  SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
900  SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
901  newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
902  newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
903  std::optional<Value> result =
904  createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
905  contractOp.getKind(), rewriter, isInt);
906  rewriter.replaceOp(contractOp, {*result});
907  return success();
908  }
909 
private:
  /// Options to control the vector patterns; only the `ParallelArith`
  /// contract-lowering mode enables this pattern.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-provided predicate restricting which contractions the pattern
  /// applies to; the pattern bails out when it returns failure.
  FilterConstraintType filter;
};
915 
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  // Operand element types must match the accumulator's element type.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  // Try the more specialized lowerings first; only fall through to the
  // recursive dimension-by-dimension unrolling below if all of them fail.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();
  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
    return success();

  // Vector mask setup.
  // When the contraction is masked, new ops must be created outside the
  // enclosing vector.mask region, so move the insertion point there and
  // treat the mask op as the op to replace. `mask` stays null otherwise.
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *rootOp = op;
  Value mask;
  if (op.isMasked()) {
    rewriter.setInsertionPoint(op.getMaskingOp());
    rootOp = op.getMaskingOp();
    mask = op.getMaskingOp().getMask();
  }

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  // A free dimension is any LHS dimension that is not contracting.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  // Only reduction dimensions are left at this point.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  return failure();
}
1034 
1035 // Lower one parallel dimension.
1036 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1037 // TODO: consider reusing existing contract unrolling
1038 FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
1039  vector::ContractionOp op,
1040  int64_t lhsIndex,
1041  int64_t rhsIndex,
1042  Value mask) const {
1043  VectorType lhsType = op.getLhsType();
1044  VectorType rhsType = op.getRhsType();
1045  VectorType resType = cast<VectorType>(op.getResultType());
1046  // Find the iterator type index and result index.
1047  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1048  int64_t iterIndex = -1;
1049  int64_t dimSize = -1;
1050  if (lhsIndex >= 0) {
1051  iterIndex = iMap[0].getDimPosition(lhsIndex);
1052  if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1053  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1054  diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1055  << " to map to the same dimension";
1056  });
1057  if (lhsType.getScalableDims()[lhsIndex])
1058  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1059  diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
1060  << ") is not supported yet";
1061  });
1062  dimSize = lhsType.getDimSize(lhsIndex);
1063  } else if (rhsIndex >= 0) {
1064  iterIndex = iMap[1].getDimPosition(rhsIndex);
1065  if (rhsType.getScalableDims()[rhsIndex])
1066  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1067  diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
1068  << ") is not supported yet";
1069  });
1070  dimSize = rhsType.getDimSize(rhsIndex);
1071  }
1072  if (iterIndex < 0)
1073  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1074  diag << "expected either lhsIndex=" << lhsIndex
1075  << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1076  });
1077  // value_or(-1) means that we tolerate a dimension not appearing
1078  // in the result map. That can't happen for actual parallel iterators, but
1079  // the caller ContractionOpLowering::matchAndRewrite is currently calling
1080  // lowerParallel also for the case of unit-size reduction dims appearing only
1081  // on one of LHS or RHS, not both. At the moment, such cases are created by
1082  // CastAwayContractionLeadingOneDim, so we need to either support that or
1083  // modify that pattern.
1084  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1085  if (resIndex == -1 && dimSize != 1)
1086  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1087  diag << "expected the dimension for iterIndex=" << iterIndex
1088  << " to either appear in the result map, or to be a unit dimension";
1089  });
1090 
1091  // Construct new iterator types and affine map array attribute.
1092  std::array<AffineMap, 3> lowIndexingMaps = {
1093  adjustMap(iMap[0], iterIndex, rewriter),
1094  adjustMap(iMap[1], iterIndex, rewriter),
1095  adjustMap(iMap[2], iterIndex, rewriter)};
1096  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1097  auto lowIter =
1098  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1099  // Unroll into a series of lower dimensional vector.contract ops.
1100  Location loc = op.getLoc();
1101  Value result = rewriter.create<arith::ConstantOp>(
1102  loc, resType, rewriter.getZeroAttr(resType));
1103 
1104  for (int64_t d = 0; d < dimSize; ++d) {
1105  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1106  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1107  auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1108 
1109  Value lowMask;
1110  if (mask)
1111  lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1112  iterIndex, d, rewriter);
1113 
1114  Operation *lowContract = rewriter.create<vector::ContractionOp>(
1115  loc, lhs, rhs, acc, lowAffine, lowIter);
1116  lowContract = maskOperation(rewriter, lowContract, lowMask);
1117  result = reshapeStore(loc, lowContract->getResult(0), result, resType,
1118  resIndex, d, rewriter);
1119  }
1120  return result;
1121 }
1122 
1123 // Lower one reduction dimension.
1124 FailureOr<Value> ContractionOpLowering::lowerReduction(
1125  PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1126  auto loc = op.getLoc();
1127  VectorType lhsType = op.getLhsType();
1128  VectorType rhsType = op.getRhsType();
1129  Type resType = op.getResultType();
1130  if (isa<VectorType>(resType))
1131  return rewriter.notifyMatchFailure(op,
1132  "did not expect a VectorType result");
1133  bool isInt = isa<IntegerType>(resType);
1134  // Use iterator index 0.
1135  int64_t iterIndex = 0;
1136  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1137  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1138  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1139  if (!lookupLhs.has_value())
1140  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1141  diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1142  });
1143  if (!lookupRhs.has_value())
1144  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1145  diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1146  });
1147  int64_t lhsIndex = *lookupLhs;
1148  int64_t rhsIndex = *lookupRhs;
1149  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1150  if (dimSize != rhsType.getDimSize(rhsIndex))
1151  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1152  diag << "expect LHS dimension " << lhsIndex
1153  << " to have the same size as RHS dimension " << rhsIndex;
1154  });
1155  // Base case.
1156  if (lhsType.getRank() == 1) {
1157  if (rhsType.getRank() != 1)
1158  return rewriter.notifyMatchFailure(
1159  op, "When LHS has rank 1, expected also RHS to have rank 1");
1160  Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1161  auto kind = vector::CombiningKind::ADD;
1162 
1163  Value acc = op.getAcc();
1164  Operation *reductionOp =
1165  acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1166  : rewriter.create<vector::ReductionOp>(loc, kind, m);
1167  return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1168  }
1169  // Construct new iterator types and affine map array attribute.
1170  std::array<AffineMap, 3> lowIndexingMaps = {
1171  adjustMap(iMap[0], iterIndex, rewriter),
1172  adjustMap(iMap[1], iterIndex, rewriter),
1173  adjustMap(iMap[2], iterIndex, rewriter)};
1174  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1175  auto lowIter =
1176  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1177  // Unroll into a series of lower dimensional vector.contract ops.
1178  // By feeding the initial accumulator into the first contraction,
1179  // and the result of each contraction into the next, eventually
1180  // the sum of all reductions is computed.
1181  Value result = op.getAcc();
1182  for (int64_t d = 0; d < dimSize; ++d) {
1183  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1184  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1185  Value newMask;
1186  if (mask)
1187  newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1188  iterIndex, d, rewriter);
1189 
1190  Operation *newContract = rewriter.create<vector::ContractionOp>(
1191  loc, lhs, rhs, result, lowAffine, lowIter);
1192  result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1193  }
1194  return result;
1195 }
1196 
/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    // The unrolling below iterates `resType.getDimSize(0)` times, which
    // requires a static leading dimension; bail out when a 2-D+ result is
    // scalable in every dimension.
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    // Null when the RHS is a scalar (AXPY case below).
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup: when masked, create new ops outside the vector.mask
    // region and replace the mask op as the root.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation (scalar RHS): broadcast the scalar to
      // the LHS type and combine elementwise.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    // General case: one row of the result per LHS element, built by
    // broadcast + combine, inserted into an initially-zero result.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
1276 
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Expect exactly the (m, n, k) matmul iterator structure: two parallel
  // iterators followed by one reduction iterator.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed-precision contractions are not supported: the destination element
  // type must match the operand element type.
  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // vector.matmul operates on flattened (1-D) operands.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  // Any other accumulator map would have been rejected by the contraction
  // verifier, hence the unreachable below.
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Final accumulate: integer add or float add depending on element type.
  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
1385 } // namespace
1386 
1389  PatternBenefit benefit, bool disableOuterProductLowering) {
1390  if (!disableOuterProductLowering)
1391  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1392  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1393  ContractionOpToOuterProductOpLowering>(
1394  options, patterns.getContext(), benefit);
1395 }
1396 
1398  RewritePatternSet &patterns, PatternBenefit benefit) {
1399  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1400 }
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true otherwise create an MulFOp using operands x andy`.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define MINUI(lhs, rhs)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:401
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:380
unsigned getNumResults() const
Definition: AffineMap.cpp:388
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:296
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
MLIRContext * getContext() const
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:685
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:635
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:140
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:135
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
Structure to control the behavior of vector transform patterns.