// Doxygen-extracted listing of LowerVectorContract.cpp (MLIR 18.0.0git).
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
36 #define DEBUG_TYPE "vector-contract-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 //===----------------------------------------------------------------------===//
42 // Helper functions
43 //===----------------------------------------------------------------------===//
44 
45 // Helper to find an index in an affine map.
46 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
47  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
48  int64_t idx = map.getDimPosition(i);
49  if (idx == index)
50  return i;
51  }
52  return std::nullopt;
53 }
54 
55 // Helper to construct iterator types with one index removed.
56 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
57  int64_t index) {
58  SmallVector<Attribute> results;
59  for (const auto &it : llvm::enumerate(iteratorTypes)) {
60  int64_t idx = it.index();
61  if (idx == index)
62  continue;
63  results.push_back(it.value());
64  }
65  return results;
66 }
67 
68 // Helper to construct an affine map with one index removed.
69 static AffineMap adjustMap(AffineMap map, int64_t index,
70  PatternRewriter &rewriter) {
71  auto *ctx = rewriter.getContext();
73  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
74  int64_t idx = map.getDimPosition(i);
75  if (idx == index)
76  continue;
77  // Re-insert remaining indices, but renamed when occurring
78  // after the removed index.
79  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
80  results.push_back(targetExpr);
81  }
82  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
83 }
84 
// Helper method to possibly drop a dimension in a load.
// Recursively walks `index` leading dimensions of `val` and, at depth
// `index`, extracts position `pos`; the result type is `type` with dimension
// `index` dropped (see `resType` below). `index == -1` means "nothing to
// drop".
// TODO: consider a more descriptive name — this emits extract/insert ops,
// not memory loads.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Nothing to drop: return the value unchanged.
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions: extract each leading slice, recurse with
  // index - 1, and re-insert into a zero-initialized result of reduced type.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
109 
// Helper method to possibly drop a dimension in a store.
// Counterpart of reshapeLoad: recursively inserts `val` into `result` at
// position `pos` along dimension `index` of `type`. `index == -1` returns
// `val` unchanged.
// TODO: consider a more descriptive name — this emits extract/insert ops,
// not memory stores.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions: recurse into each leading slice of both
  // `result` and `val`, then rebuild `result` slice by slice.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
132 
133 /// Helper to create arithmetic operation associated with a kind of contraction.
134 static std::optional<Value>
136  vector::CombiningKind kind, PatternRewriter &rewriter,
137  bool isInt, Value mask = Value()) {
138  using vector::CombiningKind;
139  Value mul;
140 
141  if (isInt) {
142  if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF ||
143  kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
144  // Only valid for floating point types.
145  return std::nullopt;
146  mul = rewriter.create<arith::MulIOp>(loc, x, y);
147  } else {
148  // Float case.
149  if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
150  kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
151  kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
152  kind == CombiningKind::XOR)
153  // Only valid for integer types.
154  return std::nullopt;
155  // Special case for fused multiply-add.
156  if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
157  Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
158  if (mask)
159  // The fma op doesn't need explicit masking. However, fma ops used in
160  // reductions must preserve previous 'acc' values for masked-out lanes.
161  fma = selectPassthru(rewriter, mask, fma, acc);
162  return fma;
163  }
164  mul = rewriter.create<arith::MulFOp>(loc, x, y);
165  }
166 
167  if (!acc)
168  return std::optional<Value>(mul);
169 
170  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
171 }
172 
173 /// Return the positions of the reductions in the given map.
175  ArrayAttr iteratorTypes) {
176  SmallVector<int64_t> dimsIdx;
177  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178  if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
179  dimsIdx.push_back(i);
180  }
181  return dimsIdx;
182 }
183 
184 /// Look for a given dimension in an affine map and return its position. Return
185 /// std::nullopt if the dimension is not in the map results.
186 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
187  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
188  if (map.getDimPosition(i) == dim)
189  return i;
190  }
191  return std::nullopt;
192 }
193 
194 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
195 /// operands `x` and `y`.
196 static Value createAdd(Location loc, Value x, Value y, bool isInt,
197  PatternRewriter &rewriter) {
198  if (isInt)
199  return rewriter.create<arith::AddIOp>(loc, x, y);
200  return rewriter.create<arith::AddFOp>(loc, x, y);
201 }
202 
203 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
204 /// operands `x and `y`.
205 static Value createMul(Location loc, Value x, Value y, bool isInt,
206  PatternRewriter &rewriter) {
207  if (isInt)
208  return rewriter.create<arith::MulIOp>(loc, x, y);
209  return rewriter.create<arith::MulFOp>(loc, x, y);
210 }
211 
212 namespace {
213 
214 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
215 /// semantics to:
216 /// ```
217 /// %flattened_a = vector.shape_cast %a
218 /// %flattened_b = vector.shape_cast %b
219 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
220 /// %d = vector.shape_cast %%flattened_d
221 /// %e = add %c, %d
222 /// ```
223 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
224 //
225 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
226 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  // NOTE(review): the Doxygen extraction dropped a line here — presumably
  // `using OpRewritePattern::OpRewritePattern;` — confirm against upstream.

  /// Predicate deciding whether a given contraction may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `constraint` lets callers restrict which contractions are lowered.
  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-supplied constraint applied before rewriting.
  FilterConstraintType filter;
};
255 
256 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
257 /// semantics to a reduction_size-unrolled sequence:
258 /// ```
259 /// %at = vector.transpose %a, [1, 0]
260 /// %bRow0 = vector.extract %b[0]
261 /// %atRow0 = vector.extract %at[0]
262 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
263 /// ...
264 /// %bRowK = vector.extract %b[K]
265 /// %atRowK = vector.extract %at[K]
266 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
267 /// ```
268 ///
269 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
270 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  // NOTE(review): the Doxygen extraction dropped a line here — presumably
  // `using OpRewritePattern::OpRewritePattern;` — confirm against upstream.

  /// Predicate deciding whether a given contraction may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `constraint` lets callers restrict which contractions are lowered.
  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-supplied constraint applied before rewriting.
  FilterConstraintType filter;
};
299 
300 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
301 /// semantics to an output-size-unrolled sequence:
302 /// ```
303 /// %out = arith.constant ... : vector<MxNxelt_type>
304 /// %bt = vector.transpose %b, [1, 0]
305 /// %aRow0 = vector.extract %a[0]
306 /// %btRow0 = vector.extract %bt[0]
307 /// %c00 = vector.reduce %atRow0, %bRow0
308 /// %out00 = vector.insert %c00, %out[0, 0]
309 /// ...
310 /// %aRowLast = vector.extract %at[M-1]
311 /// %btRowLast = vector.extract %b[N-1]
312 /// %cLastLast = vector.reduce %atRowLast, %bRowLast
313 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
314 /// ```
315 ///
316 /// This only kicks in when VectorTransformsOptions is set to Dot and
317 /// the vector.contract op is a row-major matmul or matvec.
318 class ContractionOpToDotLowering
319  : public OpRewritePattern<vector::ContractionOp> {
320 public:
322 
323  using FilterConstraintType =
324  std::function<LogicalResult(vector::ContractionOp op)>;
325 
326  static LogicalResult defaultFilter(vector::ContractionOp op) {
327  return success();
328  }
329 
330  ContractionOpToDotLowering(
331  vector::VectorTransformsOptions vectorTransformOptions,
332  MLIRContext *context, PatternBenefit benefit = 1,
333  const FilterConstraintType &constraint = defaultFilter)
334  : OpRewritePattern<vector::ContractionOp>(context, benefit),
335  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
336 
337  LogicalResult matchAndRewrite(vector::ContractionOp op,
338  PatternRewriter &rewriter) const override;
339 
340 private:
341  /// Options to control the vector patterns.
342  vector::VectorTransformsOptions vectorTransformOptions;
343  FilterConstraintType filter;
344 };
345 
346 /// Progressive lowering of ContractionOp.
347 ///
348 /// One:
349 /// %x = vector.contract with at least one free/batch dimension
350 /// is replaced by:
351 /// %a = vector.contract with one less free/batch dimension
352 /// %b = vector.contract with one less free/batch dimension
353 /// ..
354 /// %x = combine %a %b ..
355 /// until a pure contraction is reached (no free/batch dimensions),
356 /// which is replaced by a dot-product.
357 ///
358 /// This only kicks in when either VectorTransformsOptions is set
359 /// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  // NOTE(review): the Doxygen extraction dropped a line here — presumably
  // `using OpRewritePattern::OpRewritePattern;` — confirm against upstream.
  /// Predicate deciding whether a given contraction may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `constraint` lets callers restrict which contractions are lowered.
  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-supplied constraint applied before rewriting.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
392 
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  /// Captures the contraction's kind/operands and, when the op is masked,
  /// the mask value of the enclosing masking op.
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm`; a null Value passes through
  /// unchanged (used for the optional mask).
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Promote scalar or vector `v` to element type `dstElementType` via
  /// arith.extf (float targets) or arith.extsi (other targets); no-op when
  /// the element type already matches.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = VectorType::get(vecType.getShape(), promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit `reductionSize` chained vector.outerproduct ops accumulating into
  /// `res`. Callers must have arranged the reduction dimension outermost on
  /// `lhs`/`rhs` (and on the mask, when present).
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    assert(reductionSize > 0);
    // Incremental support for masking: a masked contract without a prepared
    // mask slice is rejected rather than mis-lowered.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      // Operands are promoted to the accumulator's element type.
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Dispatches on the indexing-map layout, permuting operands so the
  /// reduction dimension ends up outermost before calling outerProd.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    Value transposedMask = t(mask, {2, 0, 1});
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
                       transposedMask);
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
                       transposedMask);
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);
    Value transposedMask = t(mask);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  // Contraction operands; `mask` is null when the op is unmasked.
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
549 
550 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
551 /// semantics to a reduction_size-unrolled sequence:
552 /// ```
553 /// %at = vector.transpose %a, [1, 0]
554 /// %bRow0 = vector.extract %b[0]
555 /// %atRow0 = vector.extract %at[0]
556 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
557 /// ...
558 /// %bRowK = vector.extract %b[K]
559 /// %atRowK = vector.extract %at[K]
560 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
561 /// ```
562 ///
563 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
564 /// otherwise supports any layout permutation of the matrix-multiply.
565 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
566  vector::ContractionOp op, PatternRewriter &rewriter) const {
567  if (vectorTransformOptions.vectorContractLowering !=
568  vector::VectorContractLowering::OuterProduct)
569  return failure();
570 
571  if (failed(filter(op)))
572  return failure();
573 
574  // Vector mask setup.
575  OpBuilder::InsertionGuard guard(rewriter);
576  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
577  Operation *rootOp;
578  if (maskableOp.isMasked()) {
579  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
580  rootOp = maskableOp.getMaskingOp();
581  } else {
582  rootOp = op;
583  }
584 
585  UnrolledOuterProductGenerator e(rewriter, op);
586  FailureOr<Value> matmatRes = e.matmat();
587  if (succeeded(matmatRes)) {
588  rewriter.replaceOp(rootOp, *matmatRes);
589  return success();
590  }
591  FailureOr<Value> matvecRes = e.matvec();
592  if (succeeded(matvecRes)) {
593  rewriter.replaceOp(rootOp, *matvecRes);
594  return success();
595  }
596  FailureOr<Value> tmatvecRes = e.tmatvec();
597  if (succeeded(tmatvecRes)) {
598  rewriter.replaceOp(rootOp, *tmatvecRes);
599  return success();
600  }
601 
602  return failure();
603 }
604 
606 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
607  PatternRewriter &rewriter) const {
608  // TODO: Support vector.mask.
609  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
610  if (maskableOp.isMasked())
611  return failure();
612 
613  if (failed(filter(op)))
614  return failure();
615 
616  if (vectorTransformOptions.vectorContractLowering !=
617  vector::VectorContractLowering::Dot)
618  return failure();
619 
620  auto iteratorTypes = op.getIteratorTypes().getValue();
621  static constexpr std::array<int64_t, 2> perm = {1, 0};
622  Location loc = op.getLoc();
623  Value lhs = op.getLhs(), rhs = op.getRhs();
624 
625  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
626  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
627  AffineExpr m, n, k;
628  bindDims(rewriter.getContext(), m, n, k);
629  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
630  //
631  // In the following we wish to make the reduction dimension innermost so we
632  // can load vectors and just fmul + reduce into a scalar.
633  //
634  if (isParallelIterator(iteratorTypes[0]) &&
635  isParallelIterator(iteratorTypes[1]) &&
636  isReductionIterator(iteratorTypes[2])) {
637  //
638  // Two outer parallel, one inner reduction (matmat flavor).
639  //
640  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
641  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
642  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
643  // No need to permute anything.
644  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
645  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
646  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
647  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
648  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
649  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
650  // This is the classical row-major matmul. Just permute the lhs.
651  Value tmp = lhs;
652  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
653  rhs = tmp;
654  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
655  std::swap(lhs, rhs);
656  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
657  Value tmp = lhs;
658  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
659  rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
660  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
661  Value tmp = rhs;
662  rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
663  lhs = tmp;
664  } else {
665  return failure();
666  }
667  } else if (isParallelIterator(iteratorTypes[0]) &&
668  isReductionIterator(iteratorTypes[1])) {
669  //
670  // One outer parallel, one inner reduction (matvec flavor)
671  //
672  if (maps == infer({{m, n}, {n}, {m}})) {
673  // No need to permute anything.
674  } else if (maps == infer({{n, m}, {n}, {m}})) {
675  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
676  } else if (maps == infer({{n}, {m, n}, {m}})) {
677  std::swap(lhs, rhs);
678  } else if (maps == infer({{n}, {n, m}, {m}})) {
679  std::swap(lhs, rhs);
680  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
681  } else {
682  return failure();
683  }
684  } else {
685  return failure();
686  }
687 
688  VectorType dstType = cast<VectorType>(op.getResultType());
689  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
690  "Expected dst type of rank 1 or 2");
691 
692  unsigned rank = dstType.getRank();
693  unsigned dstRows = dstType.getShape()[0];
694  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
695 
696  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
697  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
698  rewriter.getZeroAttr(dstType));
699  bool isInt = isa<IntegerType>(dstType.getElementType());
700  for (unsigned r = 0; r < dstRows; ++r) {
701  Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
702  for (unsigned c = 0; c < dstColumns; ++c) {
703  Value b = rank == 1
704  ? rhs
705  : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
706  Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
707  Value reduced = rewriter.create<vector::ReductionOp>(
708  op.getLoc(), vector::CombiningKind::ADD, m);
709 
711  : SmallVector<int64_t, 2>{r, c};
712  res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
713  }
714  }
715  if (auto acc = op.getAcc())
716  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
717  rewriter.replaceOp(op, res);
718  return success();
719 }
720 
721 /// Lower vector.contract with all size one reduction dimensions to
722 /// elementwise ops when possible.
723 struct ContractOpToElementwise
724  : public OpRewritePattern<vector::ContractionOp> {
726  using FilterConstraintType =
727  std::function<LogicalResult(vector::ContractionOp op)>;
728  static LogicalResult defaultFilter(vector::ContractionOp op) {
729  return success();
730  }
731  ContractOpToElementwise(
732  vector::VectorTransformsOptions vectorTransformOptions,
733  MLIRContext *context, PatternBenefit benefit = 1,
734  const FilterConstraintType &constraint = defaultFilter)
735  : OpRewritePattern<vector::ContractionOp>(context, benefit),
736  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
737 
738  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
739  PatternRewriter &rewriter) const override {
740  // TODO: Support vector.mask.
741  auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
742  if (maskableOp.isMasked())
743  return failure();
744 
745  if (failed(filter(contractOp)))
746  return failure();
747 
748  if (vectorTransformOptions.vectorContractLowering !=
749  vector::VectorContractLowering::ParallelArith)
750  return failure();
751 
752  ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
753  ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
754  AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
755  AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
756  SmallVector<int64_t> lhsReductionDims =
757  getReductionIndex(lhsMap, contractOp.getIteratorTypes());
758  SmallVector<int64_t> rhsReductionDims =
759  getReductionIndex(rhsMap, contractOp.getIteratorTypes());
760  // All the reduction dimensions must be a size 1.
761  for (int64_t dim : lhsReductionDims) {
762  if (lhsShape[dim] != 1)
763  return failure();
764  }
765  for (int64_t dim : rhsReductionDims) {
766  if (rhsShape[dim] != 1)
767  return failure();
768  }
769  AffineMap accMap = contractOp.getIndexingMapsArray()[2];
770  unsigned numParallelDims = accMap.getNumResults();
771  unsigned numLhsDimToBroadcast =
772  numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
773  unsigned numRhsDimToBroadcast =
774  numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
775  SmallVector<int64_t> lhsDims;
776  SmallVector<int64_t> lhsTranspose;
777  SmallVector<int64_t> rhsDims;
778  SmallVector<int64_t> rhsTranspose;
779  for (int64_t dim : lhsReductionDims)
780  lhsTranspose.push_back(numLhsDimToBroadcast + dim);
781  for (int64_t dim : rhsReductionDims)
782  rhsTranspose.push_back(numRhsDimToBroadcast + dim);
783  // Loop through the parallel dimensions to calculate the dimensions to
784  // broadcast and to permute in order to extract only parallel dimensions.
785  for (unsigned i = 0; i < numParallelDims; i++) {
786  std::optional<unsigned> lhsDim =
787  getDimPosition(lhsMap, accMap.getDimPosition(i));
788  if (lhsDim) {
789  lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
790  } else {
791  // If the parallel dimension doesn't exist we will have to broadcast it.
792  lhsDims.push_back(
793  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
794  lhsTranspose.push_back(lhsDims.size() - 1);
795  }
796  std::optional<unsigned> rhsDim =
797  getDimPosition(rhsMap, accMap.getDimPosition(i));
798  if (rhsDim) {
799  rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
800  } else {
801  // If the parallel dimension doesn't exist we will have to broadcast it.
802  rhsDims.push_back(
803  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
804  rhsTranspose.push_back(rhsDims.size() - 1);
805  }
806  }
807  Value newLhs = contractOp.getLhs();
808  Value newRhs = contractOp.getRhs();
809  Location loc = contractOp.getLoc();
810  if (!lhsDims.empty()) {
811  lhsDims.append(lhsShape.begin(), lhsShape.end());
812  auto expandedType =
813  VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
814  newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
815  }
816  if (!rhsDims.empty()) {
817  rhsDims.append(rhsShape.begin(), rhsShape.end());
818  auto expandedType =
819  VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
820  newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
821  }
822  bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
823  newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
824  newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
825  SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
826  SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
827  newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
828  newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
829  std::optional<Value> result =
830  createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
831  contractOp.getKind(), rewriter, isInt);
832  rewriter.replaceOp(contractOp, {*result});
833  return success();
834  }
835 
836 private:
837  /// Options to control the vector patterns.
838  vector::VectorTransformsOptions vectorTransformOptions;
839  FilterConstraintType filter;
840 };
841 
842 /// Progressive lowering of ContractionOp.
843 /// One:
844 /// %x = vector.contract with at least one free/batch dimension
845 /// is replaced by:
846 /// %a = vector.contract with one less free/batch dimension
847 /// %b = vector.contract with one less free/batch dimension
848 /// ..
849 /// %x = combine %a %b ..
850 /// until a pure contraction is reached (no free/batch dimensions),
851 /// which is replaced by a dot-product.
852 ///
853 /// This only kicks in when either VectorTransformsOptions is set
854 /// to DOT or when other contraction patterns fail.
855 //
856 // TODO: break down into transpose/reshape/cast ops
857 // when they become available to avoid code dup
858 // TODO: investigate lowering order impact on performance
860 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
861  PatternRewriter &rewriter) const {
862  if (failed(filter(op)))
863  return failure();
864 
865  // TODO: support mixed mode contract lowering.
866  if (op.getLhsType().getElementType() !=
867  getElementTypeOrSelf(op.getAccType()) ||
868  op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
869  return failure();
870 
871  // TODO: the code below assumes the default contraction, make sure it supports
872  // other kinds before enabling this lowering.
873  if (op.getKind() != vector::CombiningKind::ADD) {
874  return rewriter.notifyMatchFailure(
875  op, "contractions other than 'add' not supported");
876  }
877 
878  // TODO: implement benefits, cost models.
879  MLIRContext *ctx = op.getContext();
880  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
881  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
882  return success();
883  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
884  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
885  return success();
886  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
887  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
888  return success();
889  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
890  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
891  return success();
892 
893  // Vector mask setup.
894  OpBuilder::InsertionGuard guard(rewriter);
895  Operation *rootOp = op;
896  Value mask;
897  if (op.isMasked()) {
898  rewriter.setInsertionPoint(op.getMaskingOp());
899  rootOp = op.getMaskingOp();
900  mask = op.getMaskingOp().getMask();
901  }
902 
903  // Find first batch dimension in LHS/RHS, and lower when found.
904  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
905  if (!batchDimMap.empty()) {
906  int64_t lhsIndex = batchDimMap[0].first;
907  int64_t rhsIndex = batchDimMap[0].second;
908  auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
909  if (failed(newOp))
910  return failure();
911  rewriter.replaceOp(rootOp, *newOp);
912  return success();
913  }
914 
915  // Collect contracting dimensions.
916  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
917  op.getContractingDimMap();
918  DenseSet<int64_t> lhsContractingDimSet;
919  DenseSet<int64_t> rhsContractingDimSet;
920  for (auto &dimPair : contractingDimMap) {
921  lhsContractingDimSet.insert(dimPair.first);
922  rhsContractingDimSet.insert(dimPair.second);
923  }
924 
925  // Find first free dimension in LHS, and lower when found.
926  VectorType lhsType = op.getLhsType();
927  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
928  if (lhsContractingDimSet.count(lhsIndex) == 0) {
929  auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
930  if (failed(newOp))
931  return failure();
932  rewriter.replaceOp(rootOp, *newOp);
933  return success();
934  }
935  }
936 
937  // Find first free dimension in RHS, and lower when found.
938  VectorType rhsType = op.getRhsType();
939  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
940  if (rhsContractingDimSet.count(rhsIndex) == 0) {
941  auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
942  if (failed(newOp))
943  return failure();
944  rewriter.replaceOp(rootOp, *newOp);
945  return success();
946  }
947  }
948 
949  // Lower the first remaining reduction dimension.
950  if (!contractingDimMap.empty()) {
951  auto newOp = lowerReduction(rewriter, op, mask);
952  if (failed(newOp))
953  return failure();
954  rewriter.replaceOp(rootOp, *newOp);
955  return success();
956  }
957 
958  return failure();
959 }
960 
961 // Lower one parallel dimension.
962 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
963 // TODO: consider reusing existing contract unrolling
964 FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
965  vector::ContractionOp op,
966  int64_t lhsIndex,
967  int64_t rhsIndex,
968  Value mask) const {
969  VectorType lhsType = op.getLhsType();
970  VectorType rhsType = op.getRhsType();
971  VectorType resType = cast<VectorType>(op.getResultType());
972  // Find the iterator type index and result index.
973  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
974  int64_t iterIndex = -1;
975  int64_t dimSize = -1;
976  if (lhsIndex >= 0) {
977  iterIndex = iMap[0].getDimPosition(lhsIndex);
978  if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
979  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
980  diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
981  << " to map to the same dimension";
982  });
983  dimSize = lhsType.getDimSize(lhsIndex);
984  } else if (rhsIndex >= 0) {
985  iterIndex = iMap[1].getDimPosition(rhsIndex);
986  dimSize = rhsType.getDimSize(rhsIndex);
987  }
988  if (iterIndex < 0)
989  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
990  diag << "expected either lhsIndex=" << lhsIndex
991  << " or rhsIndex=" << rhsIndex << " to be nonnegative";
992  });
993  // value_or(-1) means that we tolerate a dimension not appearing
994  // in the result map. That can't happen for actual parallel iterators, but
995  // the caller ContractionOpLowering::matchAndRewrite is currently calling
996  // lowerParallel also for the case of unit-size reduction dims appearing only
997  // on one of LHS or RHS, not both. At the moment, such cases are created by
998  // CastAwayContractionLeadingOneDim, so we need to either support that or
999  // modify that pattern.
1000  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1001  if (resIndex == -1 && dimSize != 1)
1002  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1003  diag << "expected the dimension for iterIndex=" << iterIndex
1004  << " to either appear in the result map, or to be a unit dimension";
1005  });
1006 
1007  // Construct new iterator types and affine map array attribute.
1008  std::array<AffineMap, 3> lowIndexingMaps = {
1009  adjustMap(iMap[0], iterIndex, rewriter),
1010  adjustMap(iMap[1], iterIndex, rewriter),
1011  adjustMap(iMap[2], iterIndex, rewriter)};
1012  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1013  auto lowIter =
1014  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1015  // Unroll into a series of lower dimensional vector.contract ops.
1016  Location loc = op.getLoc();
1017  Value result = rewriter.create<arith::ConstantOp>(
1018  loc, resType, rewriter.getZeroAttr(resType));
1019 
1020  for (int64_t d = 0; d < dimSize; ++d) {
1021  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1022  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1023  auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1024 
1025  Value lowMask;
1026  if (mask)
1027  lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1028  iterIndex, d, rewriter);
1029 
1030  Operation *lowContract = rewriter.create<vector::ContractionOp>(
1031  loc, lhs, rhs, acc, lowAffine, lowIter);
1032  lowContract = maskOperation(rewriter, lowContract, lowMask);
1033  result = reshapeStore(loc, lowContract->getResult(0), result, resType,
1034  resIndex, d, rewriter);
1035  }
1036  return result;
1037 }
1038 
1039 // Lower one reduction dimension.
1040 FailureOr<Value> ContractionOpLowering::lowerReduction(
1041  PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1042  auto loc = op.getLoc();
1043  VectorType lhsType = op.getLhsType();
1044  VectorType rhsType = op.getRhsType();
1045  Type resType = op.getResultType();
1046  if (isa<VectorType>(resType))
1047  return rewriter.notifyMatchFailure(op,
1048  "did not expect a VectorType result");
1049  bool isInt = isa<IntegerType>(resType);
1050  // Use iterator index 0.
1051  int64_t iterIndex = 0;
1052  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1053  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1054  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1055  if (!lookupLhs.has_value())
1056  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1057  diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1058  });
1059  if (!lookupRhs.has_value())
1060  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1061  diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1062  });
1063  int64_t lhsIndex = *lookupLhs;
1064  int64_t rhsIndex = *lookupRhs;
1065  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1066  if (dimSize != rhsType.getDimSize(rhsIndex))
1067  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1068  diag << "expect LHS dimension " << lhsIndex
1069  << " to have the same size as RHS dimension " << rhsIndex;
1070  });
1071  // Base case.
1072  if (lhsType.getRank() == 1) {
1073  if (rhsType.getRank() != 1)
1074  return rewriter.notifyMatchFailure(
1075  op, "When LHS has rank 1, expected also RHS to have rank 1");
1076  Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1077  auto kind = vector::CombiningKind::ADD;
1078 
1079  Value acc = op.getAcc();
1080  Operation *reductionOp =
1081  acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1082  : rewriter.create<vector::ReductionOp>(loc, kind, m);
1083  return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1084  }
1085  // Construct new iterator types and affine map array attribute.
1086  std::array<AffineMap, 3> lowIndexingMaps = {
1087  adjustMap(iMap[0], iterIndex, rewriter),
1088  adjustMap(iMap[1], iterIndex, rewriter),
1089  adjustMap(iMap[2], iterIndex, rewriter)};
1090  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1091  auto lowIter =
1092  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1093  // Unroll into a series of lower dimensional vector.contract ops.
1094  // By feeding the initial accumulator into the first contraction,
1095  // and the result of each contraction into the next, eventually
1096  // the sum of all reductions is computed.
1097  Value result = op.getAcc();
1098  for (int64_t d = 0; d < dimSize; ++d) {
1099  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1100  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1101  Value newMask;
1102  if (mask)
1103  newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1104  iterIndex, d, rewriter);
1105 
1106  Operation *newContract = rewriter.create<vector::ContractionOp>(
1107  loc, lhs, rhs, result, lowAffine, lowIter);
1108  result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1109  }
1110  return result;
1111 }
1112 
1113 /// Progressive lowering of OuterProductOp.
1114 /// One:
1115 /// %x = vector.outerproduct %lhs, %rhs, %acc
1116 /// is replaced by:
1117 /// %z = zero-result
1118 /// %0 = vector.extract %lhs[0]
1119 /// %1 = vector.broadcast %0
1120 /// %2 = vector.extract %acc[0]
1121 /// %3 = vector.fma %1, %rhs, %2
1122 /// %4 = vector.insert %3, %z[0]
1123 /// ..
1124 /// %x = vector.insert %.., %..[N-1]
1125 ///
1126 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1127 public:
1129 
1130  LogicalResult matchAndRewrite(vector::OuterProductOp op,
1131  PatternRewriter &rewriter) const override {
1132  VectorType resType = op.getResultVectorType();
1133  if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1134  return failure();
1135 
1136  auto loc = op.getLoc();
1137 
1138  VectorType lhsType = op.getOperandVectorTypeLHS();
1139  VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1140  Type eltType = resType.getElementType();
1141  bool isInt = isa<IntegerType, IndexType>(eltType);
1142  Value acc = op.getAcc();
1143  vector::CombiningKind kind = op.getKind();
1144 
1145  // Vector mask setup.
1146  OpBuilder::InsertionGuard guard(rewriter);
1147  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1148  Operation *rootOp;
1149  Value mask;
1150  if (maskableOp.isMasked()) {
1151  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1152  rootOp = maskableOp.getMaskingOp();
1153  mask = maskableOp.getMaskingOp().getMask();
1154  } else {
1155  rootOp = op;
1156  }
1157 
1158  if (!rhsType) {
1159  // Special case: AXPY operation.
1160  Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1161  std::optional<Value> mult = createContractArithOp(
1162  loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1163  if (!mult.has_value())
1164  return failure();
1165  rewriter.replaceOp(rootOp, *mult);
1166  return success();
1167  }
1168 
1169  Value result = rewriter.create<arith::ConstantOp>(
1170  loc, resType, rewriter.getZeroAttr(resType));
1171  for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1172  Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
1173  Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1174  Value r = nullptr;
1175  if (acc)
1176  r = rewriter.create<vector::ExtractOp>(loc, acc, d);
1177  Value extrMask;
1178  if (mask)
1179  extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
1180 
1181  std::optional<Value> m = createContractArithOp(
1182  loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1183  if (!m.has_value())
1184  return failure();
1185  result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
1186  }
1187 
1188  rewriter.replaceOp(rootOp, result);
1189  return success();
1190  }
1191 };
1192 
1193 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1194 /// semantics to:
1195 /// ```
1196 /// %mta = maybe_transpose
1197 /// %mtb = maybe_transpose
1198 /// %flattened_a = vector.shape_cast %mta
1199 /// %flattened_b = vector.shape_cast %mtb
1200 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1201 /// %mtd = vector.shape_cast %flattened_d
1202 /// %d = maybe_untranspose %mtd
1203 /// %e = add %c, %d
1204 /// ```
1205 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1206 //
1207 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1208 /// vector.transpose operations are inserted if the vector.contract op is not a
1209 /// row-major matrix multiply.
1211 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1212  PatternRewriter &rew) const {
1213  // TODO: Support vector.mask.
1214  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
1215  if (maskableOp.isMasked())
1216  return failure();
1217 
1218  if (vectorTransformOptions.vectorContractLowering !=
1219  vector::VectorContractLowering::Matmul)
1220  return failure();
1221  if (failed(filter(op)))
1222  return failure();
1223 
1224  auto iteratorTypes = op.getIteratorTypes().getValue();
1225  if (!isParallelIterator(iteratorTypes[0]) ||
1226  !isParallelIterator(iteratorTypes[1]) ||
1227  !isReductionIterator(iteratorTypes[2]))
1228  return failure();
1229 
1230  Type elementType = op.getLhsType().getElementType();
1231  if (!elementType.isIntOrFloat())
1232  return failure();
1233 
1234  Type dstElementType = op.getType();
1235  if (auto vecType = dyn_cast<VectorType>(dstElementType))
1236  dstElementType = vecType.getElementType();
1237  if (elementType != dstElementType)
1238  return failure();
1239 
1240  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1241  // Bail out if the contraction cannot be put in this form.
1242  MLIRContext *ctx = op.getContext();
1243  Location loc = op.getLoc();
1244  AffineExpr m, n, k;
1245  bindDims(rew.getContext(), m, n, k);
1246  // LHS must be A(m, k) or A(k, m).
1247  Value lhs = op.getLhs();
1248  auto lhsMap = op.getIndexingMapsArray()[0];
1249  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1250  lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1251  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1252  return failure();
1253 
1254  // RHS must be B(k, n) or B(n, k).
1255  Value rhs = op.getRhs();
1256  auto rhsMap = op.getIndexingMapsArray()[1];
1257  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1258  rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1259  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1260  return failure();
1261 
1262  // At this point lhs and rhs are in row-major.
1263  VectorType lhsType = cast<VectorType>(lhs.getType());
1264  VectorType rhsType = cast<VectorType>(rhs.getType());
1265  int64_t lhsRows = lhsType.getDimSize(0);
1266  int64_t lhsColumns = lhsType.getDimSize(1);
1267  int64_t rhsColumns = rhsType.getDimSize(1);
1268 
1269  Type flattenedLHSType =
1270  VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1271  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1272 
1273  Type flattenedRHSType =
1274  VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1275  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1276 
1277  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1278  rhsColumns);
1279  mul = rew.create<vector::ShapeCastOp>(
1280  loc,
1281  VectorType::get({lhsRows, rhsColumns},
1282  getElementTypeOrSelf(op.getAcc().getType())),
1283  mul);
1284 
1285  // ACC must be C(m, n) or C(n, m).
1286  auto accMap = op.getIndexingMapsArray()[2];
1287  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1288  mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1289  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1290  llvm_unreachable("invalid contraction semantics");
1291 
1292  Value res =
1293  isa<IntegerType>(elementType)
1294  ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1295  : static_cast<Value>(
1296  rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1297 
1298  rew.replaceOp(op, res);
1299  return success();
1300 }
1301 } // namespace
1302 
1305  PatternBenefit benefit, bool disableOuterProductLowering) {
1306  if (!disableOuterProductLowering)
1307  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1308  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1309  ContractionOpToOuterProductOpLowering>(
1310  options, patterns.getContext(), benefit);
1311 }
1312 
1314  RewritePatternSet &patterns, PatternBenefit benefit) {
1315  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1316 }
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a `MulIOp` if `isInt` is true, otherwise creates a `MulFOp`, using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:44
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:358
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:337
unsigned getNumResults() const
Definition: AffineMap.cpp:345
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:255
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:312
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:346
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:624
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:130
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:125
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:331
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:502
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
Structure to control the behavior of vector transform patterns.