// MLIR 19.0.0git — LowerVectorContract.cpp (doxygen export header).
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
36 #define DEBUG_TYPE "vector-contract-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 //===----------------------------------------------------------------------===//
42 // Helper functions
43 //===----------------------------------------------------------------------===//
44 // Helper to find an index in an affine map.
45 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
46  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
47  int64_t idx = map.getDimPosition(i);
48  if (idx == index)
49  return i;
50  }
51  return std::nullopt;
52 }
53 
54 // Helper to construct iterator types with one index removed.
55 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
56  int64_t index) {
57  SmallVector<Attribute> results;
58  for (const auto &it : llvm::enumerate(iteratorTypes)) {
59  int64_t idx = it.index();
60  if (idx == index)
61  continue;
62  results.push_back(it.value());
63  }
64  return results;
65 }
66 
67 // Helper to construct an affine map with one index removed.
68 static AffineMap adjustMap(AffineMap map, int64_t index,
69  PatternRewriter &rewriter) {
70  auto *ctx = rewriter.getContext();
72  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
73  int64_t idx = map.getDimPosition(i);
74  if (idx == index)
75  continue;
76  // Re-insert remaining indices, but renamed when occurring
77  // after the removed index.
78  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
79  results.push_back(targetExpr);
80  }
81  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
82 }
83 
// Helper method to possibly drop a dimension in a load.
// Recursively extracts position `pos` of dimension `index` from `val`:
// for index == -1 the value is returned unchanged; for index == 0 a direct
// vector.extract suffices; otherwise the leading dimensions are unrolled
// and the extraction is applied one level deeper.
// TODO: adopt a dedicated reshape op once available.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  // vType: element type one level down; resType: `type` with `index` dropped.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  // Start from an all-zero result and insert each recursively reshaped slice.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
108 
// Helper method to possibly drop a dimension in a store.
// Dual of reshapeLoad: inserts `val` into `result` at position `pos` of
// dimension `index`, unrolling leading dimensions until the insertion
// dimension is reached.
// TODO: adopt a dedicated reshape op once available.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    // Recurse on the d-th slices of both the destination and the value.
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
131 
132 /// Helper to create arithmetic operation associated with a kind of contraction.
133 static std::optional<Value>
135  vector::CombiningKind kind, PatternRewriter &rewriter,
136  bool isInt, Value mask = Value()) {
137  using vector::CombiningKind;
138  Value mul;
139 
140  if (isInt) {
141  if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
142  kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
143  // Only valid for floating point types.
144  return std::nullopt;
145  mul = rewriter.create<arith::MulIOp>(loc, x, y);
146  } else {
147  // Float case.
148  if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
149  kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
150  kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
151  kind == CombiningKind::XOR)
152  // Only valid for integer types.
153  return std::nullopt;
154  // Special case for fused multiply-add.
155  if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
156  Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
157  if (mask)
158  // The fma op doesn't need explicit masking. However, fma ops used in
159  // reductions must preserve previous 'acc' values for masked-out lanes.
160  fma = selectPassthru(rewriter, mask, fma, acc);
161  return fma;
162  }
163  mul = rewriter.create<arith::MulFOp>(loc, x, y);
164  }
165 
166  if (!acc)
167  return std::optional<Value>(mul);
168 
169  return makeArithReduction(rewriter, loc, kind, mul, acc,
170  /*fastmath=*/nullptr, mask);
171 }
172 
173 /// Return the positions of the reductions in the given map.
175  ArrayAttr iteratorTypes) {
176  SmallVector<int64_t> dimsIdx;
177  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178  if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
179  dimsIdx.push_back(i);
180  }
181  return dimsIdx;
182 }
183 
184 /// Look for a given dimension in an affine map and return its position. Return
185 /// std::nullopt if the dimension is not in the map results.
186 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
187  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
188  if (map.getDimPosition(i) == dim)
189  return i;
190  }
191  return std::nullopt;
192 }
193 
194 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
195 /// operands `x` and `y`.
196 static Value createAdd(Location loc, Value x, Value y, bool isInt,
197  PatternRewriter &rewriter) {
198  if (isInt)
199  return rewriter.create<arith::AddIOp>(loc, x, y);
200  return rewriter.create<arith::AddFOp>(loc, x, y);
201 }
202 
203 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
204 /// operands `x and `y`.
205 static Value createMul(Location loc, Value x, Value y, bool isInt,
206  PatternRewriter &rewriter) {
207  if (isInt)
208  return rewriter.create<arith::MulIOp>(loc, x, y);
209  return rewriter.create<arith::MulFOp>(loc, x, y);
210 }
211 
212 namespace {
213 
214 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
215 /// semantics to:
216 /// ```
217 /// %flattened_a = vector.shape_cast %a
218 /// %flattened_b = vector.shape_cast %b
219 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
220 /// %d = vector.shape_cast %%flattened_d
221 /// %e = add %c, %d
222 /// ```
223 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
224 //
225 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
226 /// the vector.contract op is a row-major matrix multiply.
227 class ContractionOpToMatmulOpLowering
228  : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
229 public:
230  using MaskableOpRewritePattern::MaskableOpRewritePattern;
231 
232  using FilterConstraintType =
233  std::function<LogicalResult(vector::ContractionOp op)>;
234 
235  static LogicalResult defaultFilter(vector::ContractionOp op) {
236  return success();
237  }
238 
239  ContractionOpToMatmulOpLowering(
240  vector::VectorTransformsOptions vectorTransformOptions,
241  MLIRContext *context, PatternBenefit benefit = 1,
242  FilterConstraintType constraint = defaultFilter)
243  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
244  vectorTransformOptions(vectorTransformOptions),
245  filter(std::move(constraint)) {}
246 
248  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
249  PatternRewriter &rewriter) const override;
250 
251 private:
252  /// Options to control the vector patterns.
253  vector::VectorTransformsOptions vectorTransformOptions;
254  FilterConstraintType filter;
255 };
256 
257 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
258 /// semantics to a reduction_size-unrolled sequence:
259 /// ```
260 /// %at = vector.transpose %a, [1, 0]
261 /// %bRow0 = vector.extract %b[0]
262 /// %atRow0 = vector.extract %at[0]
263 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
264 /// ...
265 /// %bRowK = vector.extract %b[K]
266 /// %atRowK = vector.extract %at[K]
267 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
268 /// ```
269 ///
270 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
271 /// the vector.contract op is a row-major matrix multiply.
272 class ContractionOpToOuterProductOpLowering
273  : public MaskableOpRewritePattern<vector::ContractionOp> {
274 public:
275  using MaskableOpRewritePattern::MaskableOpRewritePattern;
276 
277  using FilterConstraintType =
278  std::function<LogicalResult(vector::ContractionOp op)>;
279 
280  static LogicalResult defaultFilter(vector::ContractionOp op) {
281  return success();
282  }
283 
284  ContractionOpToOuterProductOpLowering(
285  vector::VectorTransformsOptions vectorTransformOptions,
286  MLIRContext *context, PatternBenefit benefit = 1,
287  FilterConstraintType constraint = defaultFilter)
288  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
289  vectorTransformOptions(vectorTransformOptions),
290  filter(std::move(constraint)) {}
291 
293  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
294  PatternRewriter &rewriter) const override;
295 
296 private:
297  /// Options to control the vector patterns.
298  vector::VectorTransformsOptions vectorTransformOptions;
299  FilterConstraintType filter;
300 };
301 
302 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
303 /// semantics to an output-size-unrolled sequence:
304 /// ```
305 /// %out = arith.constant ... : vector<MxNxelt_type>
306 /// %bt = vector.transpose %b, [1, 0]
307 /// %aRow0 = vector.extract %a[0]
308 /// %btRow0 = vector.extract %bt[0]
309 /// %c00 = vector.reduce %atRow0, %bRow0
310 /// %out00 = vector.insert %c00, %out[0, 0]
311 /// ...
312 /// %aRowLast = vector.extract %at[M-1]
313 /// %btRowLast = vector.extract %b[N-1]
314 /// %cLastLast = vector.reduce %atRowLast, %bRowLast
315 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
316 /// ```
317 ///
318 /// This only kicks in when VectorTransformsOptions is set to Dot and
319 /// the vector.contract op is a row-major matmul or matvec.
320 class ContractionOpToDotLowering
321  : public MaskableOpRewritePattern<vector::ContractionOp> {
322 public:
323  using MaskableOpRewritePattern::MaskableOpRewritePattern;
324 
325  using FilterConstraintType =
326  std::function<LogicalResult(vector::ContractionOp op)>;
327 
328  static LogicalResult defaultFilter(vector::ContractionOp op) {
329  return success();
330  }
331 
332  ContractionOpToDotLowering(
333  vector::VectorTransformsOptions vectorTransformOptions,
334  MLIRContext *context, PatternBenefit benefit = 1,
335  const FilterConstraintType &constraint = defaultFilter)
336  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
337  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
338 
340  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
341  PatternRewriter &rewriter) const override;
342 
343 private:
344  /// Options to control the vector patterns.
345  vector::VectorTransformsOptions vectorTransformOptions;
346  FilterConstraintType filter;
347 };
348 
349 /// Progressive lowering of ContractionOp.
350 ///
351 /// One:
352 /// %x = vector.contract with at least one free/batch dimension
353 /// is replaced by:
354 /// %a = vector.contract with one less free/batch dimension
355 /// %b = vector.contract with one less free/batch dimension
356 /// ..
357 /// %x = combine %a %b ..
358 /// until a pure contraction is reached (no free/batch dimensions),
359 /// which is replaced by a dot-product.
360 ///
361 /// This only kicks in when either VectorTransformsOptions is set
362 /// to Dot or when other contraction patterns fail.
363 class ContractionOpLowering
364  : public MaskableOpRewritePattern<vector::ContractionOp> {
365 public:
366  using MaskableOpRewritePattern::MaskableOpRewritePattern;
367  using FilterConstraintType =
368  std::function<LogicalResult(vector::ContractionOp op)>;
369 
370  static LogicalResult defaultFilter(vector::ContractionOp op) {
371  return success();
372  }
373 
374  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
375  MLIRContext *context, PatternBenefit benefit = 1,
376  FilterConstraintType constraint = defaultFilter)
377  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
378  vectorTransformOptions(vectorTransformOptions),
379  filter(std::move(constraint)) {}
380 
382  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
383  PatternRewriter &rewriter) const override;
384 
385 private:
386  /// Options to control the vector patterns.
387  vector::VectorTransformsOptions vectorTransformOptions;
388  FilterConstraintType filter;
389  // Lower one parallel dimension.
390  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
391  vector::ContractionOp op, int64_t lhsIndex,
392  int64_t rhsIndex, Value mask) const;
393  // Lower one reduction dimension.
394  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
395  vector::ContractionOp op, Value mask) const;
396 };
397 
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  /// Captures the contraction's operands/kind and, when the op is wrapped in
  /// a vector.mask, the mask value.
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm` (2-D swap by default).
  /// A null value is passed through unchanged, so `t(mask)` is safe when
  /// there is no mask.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Promote `v` (scalar or vector) to `dstElementType` via extf/extsi.
  /// Returns `v` unchanged when the element types already match.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit `reductionSize` chained vector.outerproduct ops accumulating into
  /// `res`. Operands are assumed to have the reduction dimension outermost;
  /// each step extracts slice `k`, promotes it to the result element type,
  /// and (when masked) masks the outer product with the k-th mask slice.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Dispatches on the contraction's indexing-map layout; each case
  /// transposes operands (and the 3-D mask, reduction dim to the front) as
  /// needed so outerProd sees the reduction dimension outermost.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when then pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  // Combining kind and operands captured from the contraction op; `mask` is
  // null unless the op was wrapped in a vector.mask.
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
625 
626 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
627 /// semantics to a reduction_size-unrolled sequence:
628 /// ```
629 /// %at = vector.transpose %a, [1, 0]
630 /// %bRow0 = vector.extract %b[0]
631 /// %atRow0 = vector.extract %at[0]
632 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
633 /// ...
634 /// %bRowK = vector.extract %b[K]
635 /// %atRowK = vector.extract %at[K]
636 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
637 /// ```
638 ///
639 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
640 /// otherwise supports any layout permutation of the matrix-multiply.
642 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
643  vector::ContractionOp op, MaskingOpInterface maskOp,
644  PatternRewriter &rewriter) const {
645  if (vectorTransformOptions.vectorContractLowering !=
646  vector::VectorContractLowering::OuterProduct)
647  return failure();
648 
649  if (failed(filter(op)))
650  return failure();
651 
652  UnrolledOuterProductGenerator e(rewriter, op);
653  FailureOr<Value> matmatRes = e.matmat();
654  if (succeeded(matmatRes)) {
655  return matmatRes;
656  }
657  FailureOr<Value> matvecRes = e.matvec();
658  if (succeeded(matvecRes)) {
659  return matvecRes;
660  }
661 
662  FailureOr<Value> tmatvecRes = e.tmatvec();
663  return tmatvecRes;
664 }
665 
666 FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
667  vector::ContractionOp op, MaskingOpInterface maskOp,
668  PatternRewriter &rewriter) const {
669  // TODO: Support vector.mask.
670  if (maskOp)
671  return failure();
672 
673  if (failed(filter(op)))
674  return failure();
675 
676  if (vectorTransformOptions.vectorContractLowering !=
677  vector::VectorContractLowering::Dot)
678  return failure();
679 
680  auto iteratorTypes = op.getIteratorTypes().getValue();
681  static constexpr std::array<int64_t, 2> perm = {1, 0};
682  Location loc = op.getLoc();
683  Value lhs = op.getLhs(), rhs = op.getRhs();
684 
685  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
686  auto infer = [&](MapList m) {
688  };
689  AffineExpr m, n, k;
690  bindDims(rewriter.getContext(), m, n, k);
691  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
692  //
693  // In the following we wish to make the reduction dimension innermost so we
694  // can load vectors and just fmul + reduce into a scalar.
695  //
696  if (isParallelIterator(iteratorTypes[0]) &&
697  isParallelIterator(iteratorTypes[1]) &&
698  isReductionIterator(iteratorTypes[2])) {
699  //
700  // Two outer parallel, one inner reduction (matmat flavor).
701  //
702  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
703  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
704  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
705  // No need to permute anything.
706  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
707  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
708  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
709  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
710  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
711  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
712  // This is the classical row-major matmul. Just permute the lhs.
713  Value tmp = lhs;
714  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
715  rhs = tmp;
716  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
717  std::swap(lhs, rhs);
718  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
719  Value tmp = lhs;
720  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
721  rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
722  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
723  Value tmp = rhs;
724  rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
725  lhs = tmp;
726  } else {
727  return failure();
728  }
729  } else if (isParallelIterator(iteratorTypes[0]) &&
730  isReductionIterator(iteratorTypes[1])) {
731  //
732  // One outer parallel, one inner reduction (matvec flavor)
733  //
734  if (maps == infer({{m, n}, {n}, {m}})) {
735  // No need to permute anything.
736  } else if (maps == infer({{n, m}, {n}, {m}})) {
737  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
738  } else if (maps == infer({{n}, {m, n}, {m}})) {
739  std::swap(lhs, rhs);
740  } else if (maps == infer({{n}, {n, m}, {m}})) {
741  std::swap(lhs, rhs);
742  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
743  } else {
744  return failure();
745  }
746  } else {
747  return failure();
748  }
749 
750  VectorType dstType = cast<VectorType>(op.getResultType());
751  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
752  "Expected dst type of rank 1 or 2");
753 
754  unsigned rank = dstType.getRank();
755  unsigned dstRows = dstType.getShape()[0];
756  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
757 
758  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
759  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
760  rewriter.getZeroAttr(dstType));
761  bool isInt = isa<IntegerType>(dstType.getElementType());
762  for (unsigned r = 0; r < dstRows; ++r) {
763  Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
764  for (unsigned c = 0; c < dstColumns; ++c) {
765  Value b = rank == 1
766  ? rhs
767  : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
768  Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
769  Value reduced = rewriter.create<vector::ReductionOp>(
770  op.getLoc(), vector::CombiningKind::ADD, m);
771 
773  : SmallVector<int64_t, 2>{r, c};
774  res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
775  }
776  }
777  if (auto acc = op.getAcc())
778  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
779  return res;
780 }
781 
782 /// Lower vector.contract with all size one reduction dimensions to
783 /// elementwise ops when possible.
784 struct ContractOpToElementwise
785  : public MaskableOpRewritePattern<vector::ContractionOp> {
786  using MaskableOpRewritePattern::MaskableOpRewritePattern;
787  using FilterConstraintType =
788  std::function<LogicalResult(vector::ContractionOp op)>;
789  static LogicalResult defaultFilter(vector::ContractionOp op) {
790  return success();
791  }
792  ContractOpToElementwise(
793  vector::VectorTransformsOptions vectorTransformOptions,
794  MLIRContext *context, PatternBenefit benefit = 1,
795  const FilterConstraintType &constraint = defaultFilter)
796  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
797  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
798 
800  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
801  MaskingOpInterface maskOp,
802  PatternRewriter &rewriter) const override {
803  // TODO: Support vector.mask.
804  if (maskOp)
805  return failure();
806 
807  if (failed(filter(contractOp)))
808  return failure();
809 
810  if (vectorTransformOptions.vectorContractLowering !=
811  vector::VectorContractLowering::ParallelArith)
812  return failure();
813 
814  ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
815  ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
816  AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
817  AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
818  SmallVector<int64_t> lhsReductionDims =
819  getReductionIndex(lhsMap, contractOp.getIteratorTypes());
820  SmallVector<int64_t> rhsReductionDims =
821  getReductionIndex(rhsMap, contractOp.getIteratorTypes());
822  // All the reduction dimensions must be a size 1.
823  for (int64_t dim : lhsReductionDims) {
824  if (lhsShape[dim] != 1)
825  return failure();
826  }
827  for (int64_t dim : rhsReductionDims) {
828  if (rhsShape[dim] != 1)
829  return failure();
830  }
831  AffineMap accMap = contractOp.getIndexingMapsArray()[2];
832  unsigned numParallelDims = accMap.getNumResults();
833  unsigned numLhsDimToBroadcast =
834  numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
835  unsigned numRhsDimToBroadcast =
836  numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
837  SmallVector<int64_t> lhsDims;
838  SmallVector<int64_t> lhsTranspose;
839  SmallVector<int64_t> rhsDims;
840  SmallVector<int64_t> rhsTranspose;
841  for (int64_t dim : lhsReductionDims)
842  lhsTranspose.push_back(numLhsDimToBroadcast + dim);
843  for (int64_t dim : rhsReductionDims)
844  rhsTranspose.push_back(numRhsDimToBroadcast + dim);
845  // Loop through the parallel dimensions to calculate the dimensions to
846  // broadcast and to permute in order to extract only parallel dimensions.
847  for (unsigned i = 0; i < numParallelDims; i++) {
848  std::optional<unsigned> lhsDim =
849  getDimPosition(lhsMap, accMap.getDimPosition(i));
850  if (lhsDim) {
851  lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
852  } else {
853  // If the parallel dimension doesn't exist we will have to broadcast it.
854  lhsDims.push_back(
855  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
856  lhsTranspose.push_back(lhsDims.size() - 1);
857  }
858  std::optional<unsigned> rhsDim =
859  getDimPosition(rhsMap, accMap.getDimPosition(i));
860  if (rhsDim) {
861  rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
862  } else {
863  // If the parallel dimension doesn't exist we will have to broadcast it.
864  rhsDims.push_back(
865  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
866  rhsTranspose.push_back(rhsDims.size() - 1);
867  }
868  }
869  Value newLhs = contractOp.getLhs();
870  Value newRhs = contractOp.getRhs();
871  Location loc = contractOp.getLoc();
872  if (!lhsDims.empty()) {
873  lhsDims.append(lhsShape.begin(), lhsShape.end());
874  auto expandedType =
875  VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
876  newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
877  }
878  if (!rhsDims.empty()) {
879  rhsDims.append(rhsShape.begin(), rhsShape.end());
880  auto expandedType =
881  VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
882  newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
883  }
884  bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
885  newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
886  newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
887  SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
888  SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
889  newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
890  newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
891  std::optional<Value> result =
892  createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
893  contractOp.getKind(), rewriter, isInt);
894  if (result)
895  return *result;
896 
897  return failure();
898  }
899 
900 private:
901  /// Options to control the vector patterns.
902  vector::VectorTransformsOptions vectorTransformOptions;
903  FilterConstraintType filter;
904 };
905 
906 /// Progressive lowering of ContractionOp.
907 /// One:
908 /// %x = vector.contract with at least one free/batch dimension
909 /// is replaced by:
910 /// %a = vector.contract with one less free/batch dimension
911 /// %b = vector.contract with one less free/batch dimension
912 /// ..
913 /// %x = combine %a %b ..
914 /// until a pure contraction is reached (no free/batch dimensions),
915 /// which is replaced by a dot-product.
916 ///
917 /// This only kicks in when either VectorTransformsOptions is set
918 /// to DOT or when other contraction patterns fail.
919 //
920 // TODO: break down into transpose/reshape/cast ops
921 // when they become available to avoid code dup
922 // TODO: investigate lowering order impact on performance
923 FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
924  vector::ContractionOp op, MaskingOpInterface maskOp,
925  PatternRewriter &rewriter) const {
926  if (failed(filter(op)))
927  return failure();
928 
929  // TODO: support mixed mode contract lowering.
930  if (op.getLhsType().getElementType() !=
931  getElementTypeOrSelf(op.getAccType()) ||
932  op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
933  return failure();
934 
935  // TODO: the code below assumes the default contraction, make sure it supports
936  // other kinds before enabling this lowering.
937  if (op.getKind() != vector::CombiningKind::ADD) {
938  return rewriter.notifyMatchFailure(
939  op, "contractions other than 'add' not supported");
940  }
941 
942  // TODO: implement benefits, cost models.
943  MLIRContext *ctx = op.getContext();
944 
945  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
946  FailureOr<Value> newVal1 =
947  pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
948  if (!failed(newVal1))
949  return newVal1;
950 
951  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
952  FailureOr<Value> newVal2 =
953  pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
954  if (!failed(newVal2))
955  return newVal2;
956 
957  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
958  FailureOr<Value> newVal3 =
959  pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
960  if (!failed(newVal3))
961  return newVal3;
962 
963  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
964  FailureOr<Value> newVal4 =
965  pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
966  if (!failed(newVal4))
967  return newVal4;
968 
969  // Vector mask setup.
970 
971  Value mask;
972  if (maskOp)
973  mask = maskOp.getMask();
974  // Find first batch dimension in LHS/RHS, and lower when found.
975  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
976  if (!batchDimMap.empty()) {
977  int64_t lhsIndex = batchDimMap[0].first;
978  int64_t rhsIndex = batchDimMap[0].second;
979  auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
980  if (failed(newOp))
981  return failure();
982  return newOp;
983  }
984 
985  // Collect contracting dimensions.
986  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
987  op.getContractingDimMap();
988  DenseSet<int64_t> lhsContractingDimSet;
989  DenseSet<int64_t> rhsContractingDimSet;
990  for (auto &dimPair : contractingDimMap) {
991  lhsContractingDimSet.insert(dimPair.first);
992  rhsContractingDimSet.insert(dimPair.second);
993  }
994 
995  // Find first free dimension in LHS, and lower when found.
996  VectorType lhsType = op.getLhsType();
997  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
998  if (lhsContractingDimSet.count(lhsIndex) == 0) {
999  auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
1000  if (failed(newOp))
1001  return failure();
1002  return newOp;
1003  }
1004  }
1005 
1006  // Find first free dimension in RHS, and lower when found.
1007  VectorType rhsType = op.getRhsType();
1008  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1009  if (rhsContractingDimSet.count(rhsIndex) == 0) {
1010  auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
1011  if (failed(newOp))
1012  return failure();
1013  return newOp;
1014  }
1015  }
1016 
1017  // Lower the first remaining reduction dimension.
1018  if (!contractingDimMap.empty()) {
1019  auto newOp = lowerReduction(rewriter, op, mask);
1020  if (failed(newOp))
1021  return failure();
1022  return newOp;
1023  }
1024 
1025  return failure();
1026 }
1027 
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
//
// Peels the parallel dimension identified by `lhsIndex`/`rhsIndex` (at least
// one must be nonnegative) off the contraction: for each position `d` along
// that dimension, slices of the operands (and of `mask`, when non-null) are
// extracted via reshapeLoad, a lower-rank vector.contract is created for the
// slice, and its result is written back into a zero-initialized accumulator
// via reshapeStore. Returns the fully assembled result vector.
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    // When the dimension appears on both sides, both operands must map it to
    // the same iterator.
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Unrolling requires a static trip count, which scalable dims lack.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // adjustMap/adjustIter drop the peeled iterator from each map and from the
  // iterator-type list so the per-slice contraction has one fewer dimension.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  // Start from an all-zero result and fill in one slice per iteration.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    // Slice the mask along the same iterator so masking stays aligned with
    // the operand slices.
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    // Re-wrap the smaller contraction in a vector.mask when a mask is present.
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
1115 
1116 // Lower one reduction dimension.
1117 FailureOr<Value> ContractionOpLowering::lowerReduction(
1118  PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1119  auto loc = op.getLoc();
1120  VectorType lhsType = op.getLhsType();
1121  VectorType rhsType = op.getRhsType();
1122  Type resType = op.getResultType();
1123  if (isa<VectorType>(resType))
1124  return rewriter.notifyMatchFailure(op,
1125  "did not expect a VectorType result");
1126  bool isInt = isa<IntegerType>(resType);
1127  // Use iterator index 0.
1128  int64_t iterIndex = 0;
1129  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1130  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1131  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1132  if (!lookupLhs.has_value())
1133  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1134  diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1135  });
1136  if (!lookupRhs.has_value())
1137  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1138  diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1139  });
1140  int64_t lhsIndex = *lookupLhs;
1141  int64_t rhsIndex = *lookupRhs;
1142  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1143  if (dimSize != rhsType.getDimSize(rhsIndex))
1144  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1145  diag << "expect LHS dimension " << lhsIndex
1146  << " to have the same size as RHS dimension " << rhsIndex;
1147  });
1148  // Base case.
1149  if (lhsType.getRank() == 1) {
1150  if (rhsType.getRank() != 1)
1151  return rewriter.notifyMatchFailure(
1152  op, "When LHS has rank 1, expected also RHS to have rank 1");
1153  Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1154  auto kind = vector::CombiningKind::ADD;
1155 
1156  Value acc = op.getAcc();
1157  Operation *reductionOp =
1158  acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1159  : rewriter.create<vector::ReductionOp>(loc, kind, m);
1160  return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1161  }
1162  // Construct new iterator types and affine map array attribute.
1163  std::array<AffineMap, 3> lowIndexingMaps = {
1164  adjustMap(iMap[0], iterIndex, rewriter),
1165  adjustMap(iMap[1], iterIndex, rewriter),
1166  adjustMap(iMap[2], iterIndex, rewriter)};
1167  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1168  auto lowIter =
1169  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1170  // Unroll into a series of lower dimensional vector.contract ops.
1171  // By feeding the initial accumulator into the first contraction,
1172  // and the result of each contraction into the next, eventually
1173  // the sum of all reductions is computed.
1174  Value result = op.getAcc();
1175  for (int64_t d = 0; d < dimSize; ++d) {
1176  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1177  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1178  Value newMask;
1179  if (mask)
1180  newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1181  iterIndex, d, rewriter);
1182 
1183  Operation *newContract = rewriter.create<vector::ContractionOp>(
1184  loc, lhs, rhs, result, lowAffine, lowIter);
1185  result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1186  }
1187  return result;
1188 }
1189 
1190 /// Progressive lowering of OuterProductOp.
1191 /// One:
1192 /// %x = vector.outerproduct %lhs, %rhs, %acc
1193 /// is replaced by:
1194 /// %z = zero-result
1195 /// %0 = vector.extract %lhs[0]
1196 /// %1 = vector.broadcast %0
1197 /// %2 = vector.extract %acc[0]
1198 /// %3 = vector.fma %1, %rhs, %2
1199 /// %4 = vector.insert %3, %z[0]
1200 /// ..
1201 /// %x = vector.insert %.., %..[N-1]
1202 ///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    // Unrolling the outer dimension below requires a static extent; bail out
    // when a >=2-D result is scalable in every dimension.
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    // RHS may be a scalar (AXPY case below); dyn_cast yields null then.
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup: if the op is wrapped in a vector.mask, rewrite (and
    // later replace) the masking op itself, inserting new IR before it.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation. Broadcast the scalar RHS to the LHS
      // vector type and combine in a single arithmetic op.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    // General case: unroll along the outer result dimension. Each LHS element
    // is broadcast against the whole RHS vector, combined with the matching
    // accumulator row, and inserted into a zero-initialized result.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      // Extract the matching row of the mask, if any.
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
1269 
1270 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1271 /// semantics to:
1272 /// ```
1273 /// %mta = maybe_transpose
1274 /// %mtb = maybe_transpose
1275 /// %flattened_a = vector.shape_cast %mta
1276 /// %flattened_b = vector.shape_cast %mtb
1277 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1278 /// %mtd = vector.shape_cast %flattened_d
1279 /// %d = maybe_untranspose %mtd
1280 /// %e = add %c, %d
1281 /// ```
1282 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1283 //
1284 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1285 /// vector.transpose operations are inserted if the vector.contract op is not a
1286 /// row-major matrix multiply.
1287 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1288  vector::ContractionOp op, MaskingOpInterface maskOp,
1289  PatternRewriter &rew) const {
1290  // TODO: Support vector.mask.
1291  if (maskOp)
1292  return failure();
1293 
1294  if (vectorTransformOptions.vectorContractLowering !=
1295  vector::VectorContractLowering::Matmul)
1296  return failure();
1297  if (failed(filter(op)))
1298  return failure();
1299 
1300  auto iteratorTypes = op.getIteratorTypes().getValue();
1301  if (!isParallelIterator(iteratorTypes[0]) ||
1302  !isParallelIterator(iteratorTypes[1]) ||
1303  !isReductionIterator(iteratorTypes[2]))
1304  return failure();
1305 
1306  Type elementType = op.getLhsType().getElementType();
1307  if (!elementType.isIntOrFloat())
1308  return failure();
1309 
1310  Type dstElementType = op.getType();
1311  if (auto vecType = dyn_cast<VectorType>(dstElementType))
1312  dstElementType = vecType.getElementType();
1313  if (elementType != dstElementType)
1314  return failure();
1315 
1316  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1317  // Bail out if the contraction cannot be put in this form.
1318  MLIRContext *ctx = op.getContext();
1319  Location loc = op.getLoc();
1320  AffineExpr m, n, k;
1321  bindDims(rew.getContext(), m, n, k);
1322  // LHS must be A(m, k) or A(k, m).
1323  Value lhs = op.getLhs();
1324  auto lhsMap = op.getIndexingMapsArray()[0];
1325  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1326  lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1327  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1328  return failure();
1329 
1330  // RHS must be B(k, n) or B(n, k).
1331  Value rhs = op.getRhs();
1332  auto rhsMap = op.getIndexingMapsArray()[1];
1333  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1334  rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1335  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1336  return failure();
1337 
1338  // At this point lhs and rhs are in row-major.
1339  VectorType lhsType = cast<VectorType>(lhs.getType());
1340  VectorType rhsType = cast<VectorType>(rhs.getType());
1341  int64_t lhsRows = lhsType.getDimSize(0);
1342  int64_t lhsColumns = lhsType.getDimSize(1);
1343  int64_t rhsColumns = rhsType.getDimSize(1);
1344 
1345  Type flattenedLHSType =
1346  VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1347  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1348 
1349  Type flattenedRHSType =
1350  VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1351  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1352 
1353  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1354  rhsColumns);
1355  mul = rew.create<vector::ShapeCastOp>(
1356  loc,
1357  VectorType::get({lhsRows, rhsColumns},
1358  getElementTypeOrSelf(op.getAcc().getType())),
1359  mul);
1360 
1361  // ACC must be C(m, n) or C(n, m).
1362  auto accMap = op.getIndexingMapsArray()[2];
1363  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1364  mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1365  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1366  llvm_unreachable("invalid contraction semantics");
1367 
1368  Value res =
1369  isa<IntegerType>(elementType)
1370  ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1371  : static_cast<Value>(
1372  rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1373 
1374  return res;
1375 }
1376 } // namespace
1377 
1380  PatternBenefit benefit, bool disableOuterProductLowering) {
1381  if (!disableOuterProductLowering)
1382  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1383  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1384  ContractionOpToOuterProductOpLowering>(
1385  options, patterns.getContext(), benefit);
1386 }
1387 
1389  RewritePatternSet &patterns, PatternBenefit benefit) {
1390  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1391 }
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true otherwise creates a MulFOp using operands x and y.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define MINUI(lhs, rhs)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:399
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:378
unsigned getNumResults() const
Definition: AffineMap.cpp:386
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:296
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:140
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:135
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:133
Structure to control the behavior of vector transform patterns.