MLIR  20.0.0git
LowerVectorContract.cpp
Go to the documentation of this file.
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
34 
35 #define DEBUG_TYPE "vector-contract-lowering"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 //===----------------------------------------------------------------------===//
41 // Helper functions
42 //===----------------------------------------------------------------------===//
43 // Helper to find an index in an affine map.
44 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
45  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
46  int64_t idx = map.getDimPosition(i);
47  if (idx == index)
48  return i;
49  }
50  return std::nullopt;
51 }
52 
53 // Helper to construct iterator types with one index removed.
54 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
55  int64_t index) {
56  SmallVector<Attribute> results;
57  for (const auto &it : llvm::enumerate(iteratorTypes)) {
58  int64_t idx = it.index();
59  if (idx == index)
60  continue;
61  results.push_back(it.value());
62  }
63  return results;
64 }
65 
66 // Helper to construct an affine map with one index removed.
67 static AffineMap adjustMap(AffineMap map, int64_t index,
68  PatternRewriter &rewriter) {
69  auto *ctx = rewriter.getContext();
71  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
72  int64_t idx = map.getDimPosition(i);
73  if (idx == index)
74  continue;
75  // Re-insert remaining indices, but renamed when occurring
76  // after the removed index.
77  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
78  results.push_back(targetExpr);
79  }
80  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
81 }
82 
83 // Helper method to possibly drop a dimension in a load.
84 // TODO
85 static Value reshapeLoad(Location loc, Value val, VectorType type,
86  int64_t index, int64_t pos,
87  PatternRewriter &rewriter) {
88  if (index == -1)
89  return val;
90 
91  // At extraction dimension?
92  if (index == 0)
93  return rewriter.create<vector::ExtractOp>(loc, val, pos);
94 
95  // Unroll leading dimensions.
96  VectorType vType = VectorType::Builder(type).dropDim(0);
97  VectorType resType = VectorType::Builder(type).dropDim(index);
98  Value result = rewriter.create<arith::ConstantOp>(
99  loc, resType, rewriter.getZeroAttr(resType));
100  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
101  Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
102  Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
103  result = rewriter.create<vector::InsertOp>(loc, load, result, d);
104  }
105  return result;
106 }
107 
108 // Helper method to possibly drop a dimension in a store.
109 // TODO
110 static Value reshapeStore(Location loc, Value val, Value result,
111  VectorType type, int64_t index, int64_t pos,
112  PatternRewriter &rewriter) {
113  // Unmodified?
114  if (index == -1)
115  return val;
116  // At insertion dimension?
117  if (index == 0)
118  return rewriter.create<vector::InsertOp>(loc, val, result, pos);
119 
120  // Unroll leading dimensions.
121  VectorType vType = VectorType::Builder(type).dropDim(0);
122  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
123  Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
124  Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
125  Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
126  result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
127  }
128  return result;
129 }
130 
131 /// Helper to create arithmetic operation associated with a kind of contraction.
132 static std::optional<Value>
134  vector::CombiningKind kind, PatternRewriter &rewriter,
135  bool isInt, Value mask = Value()) {
136  using vector::CombiningKind;
137  Value mul;
138 
139  if (isInt) {
140  if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
141  kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
142  // Only valid for floating point types.
143  return std::nullopt;
144  mul = rewriter.create<arith::MulIOp>(loc, x, y);
145  } else {
146  // Float case.
147  if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
148  kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
149  kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
150  kind == CombiningKind::XOR)
151  // Only valid for integer types.
152  return std::nullopt;
153  // Special case for fused multiply-add.
154  if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
155  Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
156  if (mask)
157  // The fma op doesn't need explicit masking. However, fma ops used in
158  // reductions must preserve previous 'acc' values for masked-out lanes.
159  fma = selectPassthru(rewriter, mask, fma, acc);
160  return fma;
161  }
162  mul = rewriter.create<arith::MulFOp>(loc, x, y);
163  }
164 
165  if (!acc)
166  return std::optional<Value>(mul);
167 
168  return makeArithReduction(rewriter, loc, kind, mul, acc,
169  /*fastmath=*/nullptr, mask);
170 }
171 
172 /// Return the positions of the reductions in the given map.
174  ArrayAttr iteratorTypes) {
175  SmallVector<int64_t> dimsIdx;
176  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
177  if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
178  dimsIdx.push_back(i);
179  }
180  return dimsIdx;
181 }
182 
183 /// Look for a given dimension in an affine map and return its position. Return
184 /// std::nullopt if the dimension is not in the map results.
185 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
186  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
187  if (map.getDimPosition(i) == dim)
188  return i;
189  }
190  return std::nullopt;
191 }
192 
193 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
194 /// operands `x` and `y`.
195 static Value createAdd(Location loc, Value x, Value y, bool isInt,
196  PatternRewriter &rewriter) {
197  if (isInt)
198  return rewriter.create<arith::AddIOp>(loc, x, y);
199  return rewriter.create<arith::AddFOp>(loc, x, y);
200 }
201 
202 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
203 /// operands `x and `y`.
204 static Value createMul(Location loc, Value x, Value y, bool isInt,
205  PatternRewriter &rewriter) {
206  if (isInt)
207  return rewriter.create<arith::MulIOp>(loc, x, y);
208  return rewriter.create<arith::MulFOp>(loc, x, y);
209 }
210 
211 namespace {
212 
213 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
214 /// semantics to:
215 /// ```
216 /// %flattened_a = vector.shape_cast %a
217 /// %flattened_b = vector.shape_cast %b
218 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
219 /// %d = vector.shape_cast %%flattened_d
220 /// %e = add %c, %d
221 /// ```
222 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
223 //
224 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
225 /// the vector.contract op is a row-major matrix multiply.
226 class ContractionOpToMatmulOpLowering
227  : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
228 public:
229  using MaskableOpRewritePattern::MaskableOpRewritePattern;
230 
231  using FilterConstraintType =
232  std::function<LogicalResult(vector::ContractionOp op)>;
233 
234  static LogicalResult defaultFilter(vector::ContractionOp op) {
235  return success();
236  }
237 
238  ContractionOpToMatmulOpLowering(
239  vector::VectorTransformsOptions vectorTransformOptions,
240  MLIRContext *context, PatternBenefit benefit = 1,
241  FilterConstraintType constraint = defaultFilter)
242  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
243  vectorTransformOptions(vectorTransformOptions),
244  filter(std::move(constraint)) {}
245 
246  FailureOr<Value>
247  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
248  PatternRewriter &rewriter) const override;
249 
250 private:
251  /// Options to control the vector patterns.
252  vector::VectorTransformsOptions vectorTransformOptions;
253  FilterConstraintType filter;
254 };
255 
256 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
257 /// semantics to a reduction_size-unrolled sequence:
258 /// ```
259 /// %at = vector.transpose %a, [1, 0]
260 /// %bRow0 = vector.extract %b[0]
261 /// %atRow0 = vector.extract %at[0]
262 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
263 /// ...
264 /// %bRowK = vector.extract %b[K]
265 /// %atRowK = vector.extract %at[K]
266 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
267 /// ```
268 ///
269 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
270 /// the vector.contract op is a row-major matrix multiply.
271 class ContractionOpToOuterProductOpLowering
272  : public MaskableOpRewritePattern<vector::ContractionOp> {
273 public:
274  using MaskableOpRewritePattern::MaskableOpRewritePattern;
275 
276  using FilterConstraintType =
277  std::function<LogicalResult(vector::ContractionOp op)>;
278 
279  static LogicalResult defaultFilter(vector::ContractionOp op) {
280  return success();
281  }
282 
283  ContractionOpToOuterProductOpLowering(
284  vector::VectorTransformsOptions vectorTransformOptions,
285  MLIRContext *context, PatternBenefit benefit = 1,
286  FilterConstraintType constraint = defaultFilter)
287  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
288  vectorTransformOptions(vectorTransformOptions),
289  filter(std::move(constraint)) {}
290 
291  FailureOr<Value>
292  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
293  PatternRewriter &rewriter) const override;
294 
295 private:
296  /// Options to control the vector patterns.
297  vector::VectorTransformsOptions vectorTransformOptions;
298  FilterConstraintType filter;
299 };
300 
301 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
302 /// semantics to an output-size-unrolled sequence:
303 /// ```
304 /// %out = arith.constant ... : vector<MxNxelt_type>
305 /// %bt = vector.transpose %b, [1, 0]
306 /// %aRow0 = vector.extract %a[0]
307 /// %btRow0 = vector.extract %bt[0]
308 /// %c00 = vector.reduce %atRow0, %bRow0
309 /// %out00 = vector.insert %c00, %out[0, 0]
310 /// ...
311 /// %aRowLast = vector.extract %at[M-1]
312 /// %btRowLast = vector.extract %b[N-1]
313 /// %cLastLast = vector.reduce %atRowLast, %bRowLast
314 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
315 /// ```
316 ///
317 /// This only kicks in when VectorTransformsOptions is set to Dot and
318 /// the vector.contract op is a row-major matmul or matvec.
319 class ContractionOpToDotLowering
320  : public MaskableOpRewritePattern<vector::ContractionOp> {
321 public:
322  using MaskableOpRewritePattern::MaskableOpRewritePattern;
323 
324  using FilterConstraintType =
325  std::function<LogicalResult(vector::ContractionOp op)>;
326 
327  static LogicalResult defaultFilter(vector::ContractionOp op) {
328  return success();
329  }
330 
331  ContractionOpToDotLowering(
332  vector::VectorTransformsOptions vectorTransformOptions,
333  MLIRContext *context, PatternBenefit benefit = 1,
334  const FilterConstraintType &constraint = defaultFilter)
335  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
336  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
337 
338  FailureOr<Value>
339  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
340  PatternRewriter &rewriter) const override;
341 
342 private:
343  /// Options to control the vector patterns.
344  vector::VectorTransformsOptions vectorTransformOptions;
345  FilterConstraintType filter;
346 };
347 
348 /// Progressive lowering of ContractionOp.
349 ///
350 /// One:
351 /// %x = vector.contract with at least one free/batch dimension
352 /// is replaced by:
353 /// %a = vector.contract with one less free/batch dimension
354 /// %b = vector.contract with one less free/batch dimension
355 /// ..
356 /// %x = combine %a %b ..
357 /// until a pure contraction is reached (no free/batch dimensions),
358 /// which is replaced by a dot-product.
359 ///
360 /// This only kicks in when either VectorTransformsOptions is set
361 /// to Dot or when other contraction patterns fail.
362 class ContractionOpLowering
363  : public MaskableOpRewritePattern<vector::ContractionOp> {
364 public:
365  using MaskableOpRewritePattern::MaskableOpRewritePattern;
366  using FilterConstraintType =
367  std::function<LogicalResult(vector::ContractionOp op)>;
368 
369  static LogicalResult defaultFilter(vector::ContractionOp op) {
370  return success();
371  }
372 
373  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
374  MLIRContext *context, PatternBenefit benefit = 1,
375  FilterConstraintType constraint = defaultFilter)
376  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
377  vectorTransformOptions(vectorTransformOptions),
378  filter(std::move(constraint)) {}
379 
380  FailureOr<Value>
381  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
382  PatternRewriter &rewriter) const override;
383 
384 private:
385  /// Options to control the vector patterns.
386  vector::VectorTransformsOptions vectorTransformOptions;
387  FilterConstraintType filter;
388  // Lower one parallel dimension.
389  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
390  vector::ContractionOp op, int64_t lhsIndex,
391  int64_t rhsIndex, Value mask) const;
392  // Lower one reduction dimension.
393  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
394  vector::ContractionOp op, Value mask) const;
395 };
396 
397 /// Generate a vector implementation for matmat, matvec and tmatvec.
398 /// This unrolls outer-products along the reduction dimension.
399 struct UnrolledOuterProductGenerator
400  : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
401  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
402  : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
403  kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
404  res(op.getAcc()), lhsType(op.getLhsType()) {
405  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
406  if (maskableOp.isMasked())
407  mask = maskableOp.getMaskingOp().getMask();
408  }
409 
410  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
411  if (!v)
412  return v;
413  return rewriter.create<vector::TransposeOp>(loc, v, perm);
414  }
415 
416  Value promote(Value v, Type dstElementType) {
417  Type elementType = v.getType();
418  auto vecType = dyn_cast<VectorType>(elementType);
419  if (vecType)
420  elementType = vecType.getElementType();
421  if (elementType == dstElementType)
422  return v;
423  Type promotedType = dstElementType;
424  if (vecType)
425  promotedType = vecType.clone(promotedType);
426  if (isa<FloatType>(dstElementType))
427  return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
428  return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
429  }
430 
431  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
432  VectorType lhsType, int reductionSize,
433  std::optional<Value> maybeMask = std::nullopt) {
434  // Incremental support for masking.
435  if (mask && !maybeMask.has_value())
436  return failure();
437 
438  Type resElementType = cast<VectorType>(res.getType()).getElementType();
439  for (int64_t k = 0; k < reductionSize; ++k) {
440  Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
441  Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
442  extractA = promote(extractA, resElementType);
443  extractB = promote(extractB, resElementType);
444  Value extractMask;
445  if (maybeMask.has_value() && maybeMask.value())
446  extractMask =
447  rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
448 
449  Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
450  loc, res.getType(), extractA, extractB, res, kind);
451  res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
452  }
453  return res;
454  }
455 
456  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
457  /// dimension `reductionDim`. If the dimension is a scalable dimension,
458  /// returns "nullopt".
459  std::optional<int64_t> getReductionSize(VectorType vecType,
460  int64_t reductionDim) {
461  // Cannot unroll scalable dimension.
462  if (vecType.getScalableDims()[reductionDim])
463  return std::nullopt;
464  int64_t reductionSize = vecType.getDimSize(reductionDim);
465  assert(reductionSize > 0 &&
466  "Reduction dim must be a known static size to allow unrolling");
467  return reductionSize;
468  }
469 
470  /// Two outer parallel, one inner reduction (matmat flavor).
471  FailureOr<Value> matmat() {
472  if (!iters({Par(), Par(), Red()}))
473  return failure();
474  // Set up the parallel/reduction structure in the right form.
475  AffineExpr m, n, k;
476  bindDims(rewriter.getContext(), m, n, k);
477 
478  // Classical row-major matmul: Just permute the lhs.
479  if (layout({{m, k}, {k, n}, {m, n}})) {
480  if (auto reductionSize = getReductionSize(lhsType, 1)) {
481  // Note: `t` creates new IR. It must be nested within this `if` check
482  // so that no IR is created when then pattern returns "failure".
483  Value tLhs = t(lhs);
484  Value tMask = t(mask, {2, 0, 1});
485  return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
486  }
487  }
488  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
489  if (layout({{m, k}, {n, k}, {m, n}})) {
490  if (auto reductionSize = getReductionSize(lhsType, 1)) {
491  Value tLhs = t(lhs);
492  Value tRhs = t(rhs);
493  Value tMask = t(mask, {2, 0, 1});
494  return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
495  }
496  }
497  // No need to permute anything.
498  if (layout({{k, m}, {k, n}, {m, n}})) {
499  if (auto reductionSize = getReductionSize(lhsType, 0)) {
500  Value tMask = t(mask, {2, 0, 1});
501  return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
502  }
503  }
504  // Just permute the rhs.
505  if (layout({{k, m}, {n, k}, {m, n}})) {
506  if (auto reductionSize = getReductionSize(lhsType, 0)) {
507  Value tRhs = t(rhs);
508  Value tMask = t(mask, {2, 0, 1});
509  return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
510  }
511  }
512  // Transposed output: swap RHS and LHS.
513  // Classical row-major matmul: permute the lhs.
514  if (layout({{m, k}, {k, n}, {n, m}})) {
515  if (auto reductionSize = getReductionSize(lhsType, 1)) {
516  Value tLhs = t(lhs);
517  Value tMask = t(mask, {2, 0, 1});
518  return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
519  }
520  }
521  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
522  if (layout({{m, k}, {n, k}, {n, m}})) {
523  if (auto reductionSize = getReductionSize(lhsType, 1)) {
524  Value tRhs = t(rhs);
525  Value tLhs = t(lhs);
526  Value tMask = t(mask, {2, 0, 1});
527  return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
528  }
529  }
530  if (layout({{k, m}, {k, n}, {n, m}})) {
531  if (auto reductionSize = getReductionSize(lhsType, 0)) {
532  Value tMask = t(mask, {2, 0, 1});
533  return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
534  }
535  }
536  if (layout({{k, m}, {n, k}, {n, m}})) {
537  if (auto reductionSize = getReductionSize(lhsType, 0)) {
538  Value tRhs = t(rhs);
539  Value tMask = t(mask, {2, 0, 1});
540  return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
541  }
542  }
543  return failure();
544  }
545 
546  //
547  // One outer parallel, one inner reduction (matvec flavor).
548  // Mask needs to be transposed everywhere to turn the reduction dimension
549  // outermost as required by outerproduct.
550  //
551  FailureOr<Value> matvec() {
552  if (!iters({Par(), Red()}))
553  return failure();
554  AffineExpr m, k;
555  bindDims(rewriter.getContext(), m, k);
556 
557  // Case mat-vec: transpose.
558  if (layout({{m, k}, {k}, {m}})) {
559  if (auto reductionSize = getReductionSize(lhsType, 1)) {
560  Value tLhs = t(lhs);
561  Value tMask = t(mask);
562  return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
563  }
564  }
565  // Case mat-trans-vec: ready to go.
566  if (layout({{k, m}, {k}, {m}})) {
567  if (auto reductionSize = getReductionSize(lhsType, 0)) {
568  Value tMask = t(mask);
569  return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
570  }
571  }
572  // Case vec-mat: swap and transpose.
573  if (layout({{k}, {m, k}, {m}})) {
574  if (auto reductionSize = getReductionSize(lhsType, 0)) {
575  Value tRhs = t(rhs);
576  Value tMask = t(mask);
577  return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
578  }
579  }
580  // Case vec-mat-trans: swap and ready to go.
581  if (layout({{k}, {k, m}, {m}})) {
582  if (auto reductionSize = getReductionSize(lhsType, 0)) {
583  Value tMask = t(mask);
584  return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
585  }
586  }
587  return failure();
588  }
589 
590  //
591  // One outer reduction, one inner parallel (tmatvec flavor).
592  // Mask already has the shape of the outer product.
593  //
594  FailureOr<Value> tmatvec() {
595  if (!iters({Red(), Par()}))
596  return failure();
597  AffineExpr k, m;
598  bindDims(rewriter.getContext(), k, m);
599 
600  // Case mat-vec: transpose.
601  if (layout({{m, k}, {k}, {m}}))
602  if (auto reductionSize = getReductionSize(lhsType, 1))
603  return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
604  // Case mat-trans-vec: ready to go.
605  if (layout({{k, m}, {k}, {m}}))
606  if (auto reductionSize = getReductionSize(lhsType, 0))
607  return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
608  // Case vec-mat: swap and transpose.
609  if (layout({{k}, {m, k}, {m}}))
610  if (auto reductionSize = getReductionSize(lhsType, 0))
611  return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
612  // Case vec-mat-trans: swap and ready to go.
613  if (layout({{k}, {k, m}, {m}}))
614  if (auto reductionSize = getReductionSize(lhsType, 0))
615  return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
616  return failure();
617  }
618 
619 private:
620  vector::CombiningKind kind;
621  Value lhs, rhs, res, mask;
622  VectorType lhsType;
623 };
624 
625 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
626 /// semantics to a reduction_size-unrolled sequence:
627 /// ```
628 /// %at = vector.transpose %a, [1, 0]
629 /// %bRow0 = vector.extract %b[0]
630 /// %atRow0 = vector.extract %at[0]
631 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
632 /// ...
633 /// %bRowK = vector.extract %b[K]
634 /// %atRowK = vector.extract %at[K]
635 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
636 /// ```
637 ///
638 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
639 /// otherwise supports any layout permutation of the matrix-multiply.
640 FailureOr<Value>
641 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
642  vector::ContractionOp op, MaskingOpInterface maskOp,
643  PatternRewriter &rewriter) const {
644  if (vectorTransformOptions.vectorContractLowering !=
645  vector::VectorContractLowering::OuterProduct)
646  return failure();
647 
648  if (failed(filter(op)))
649  return failure();
650 
651  UnrolledOuterProductGenerator e(rewriter, op);
652  FailureOr<Value> matmatRes = e.matmat();
653  if (succeeded(matmatRes)) {
654  return matmatRes;
655  }
656  FailureOr<Value> matvecRes = e.matvec();
657  if (succeeded(matvecRes)) {
658  return matvecRes;
659  }
660 
661  FailureOr<Value> tmatvecRes = e.tmatvec();
662  return tmatvecRes;
663 }
664 
665 FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
666  vector::ContractionOp op, MaskingOpInterface maskOp,
667  PatternRewriter &rewriter) const {
668  // TODO: Support vector.mask.
669  if (maskOp)
670  return failure();
671 
672  if (failed(filter(op)))
673  return failure();
674 
675  if (vectorTransformOptions.vectorContractLowering !=
676  vector::VectorContractLowering::Dot)
677  return failure();
678 
679  auto iteratorTypes = op.getIteratorTypes().getValue();
680  static constexpr std::array<int64_t, 2> perm = {1, 0};
681  Location loc = op.getLoc();
682  Value lhs = op.getLhs(), rhs = op.getRhs();
683 
684  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
685  auto infer = [&](MapList m) {
686  return AffineMap::inferFromExprList(m, op.getContext());
687  };
688  AffineExpr m, n, k;
689  bindDims(rewriter.getContext(), m, n, k);
690  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
691  //
692  // In the following we wish to make the reduction dimension innermost so we
693  // can load vectors and just fmul + reduce into a scalar.
694  //
695  if (isParallelIterator(iteratorTypes[0]) &&
696  isParallelIterator(iteratorTypes[1]) &&
697  isReductionIterator(iteratorTypes[2])) {
698  //
699  // Two outer parallel, one inner reduction (matmat flavor).
700  //
701  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
702  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
703  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
704  // No need to permute anything.
705  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
706  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
707  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
708  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
709  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
710  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
711  // This is the classical row-major matmul. Just permute the lhs.
712  Value tmp = lhs;
713  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
714  rhs = tmp;
715  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
716  std::swap(lhs, rhs);
717  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
718  Value tmp = lhs;
719  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
720  rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
721  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
722  Value tmp = rhs;
723  rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
724  lhs = tmp;
725  } else {
726  return failure();
727  }
728  } else if (isParallelIterator(iteratorTypes[0]) &&
729  isReductionIterator(iteratorTypes[1])) {
730  //
731  // One outer parallel, one inner reduction (matvec flavor)
732  //
733  if (maps == infer({{m, n}, {n}, {m}})) {
734  // No need to permute anything.
735  } else if (maps == infer({{n, m}, {n}, {m}})) {
736  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
737  } else if (maps == infer({{n}, {m, n}, {m}})) {
738  std::swap(lhs, rhs);
739  } else if (maps == infer({{n}, {n, m}, {m}})) {
740  std::swap(lhs, rhs);
741  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
742  } else {
743  return failure();
744  }
745  } else {
746  return failure();
747  }
748 
749  VectorType dstType = cast<VectorType>(op.getResultType());
750  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
751  "Expected dst type of rank 1 or 2");
752 
753  unsigned rank = dstType.getRank();
754  unsigned dstRows = dstType.getShape()[0];
755  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
756 
757  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
758  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
759  rewriter.getZeroAttr(dstType));
760  bool isInt = isa<IntegerType>(dstType.getElementType());
761  for (unsigned r = 0; r < dstRows; ++r) {
762  Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
763  for (unsigned c = 0; c < dstColumns; ++c) {
764  Value b = rank == 1
765  ? rhs
766  : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
767  Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
768  Value reduced = rewriter.create<vector::ReductionOp>(
769  op.getLoc(), vector::CombiningKind::ADD, m);
770 
772  : SmallVector<int64_t, 2>{r, c};
773  res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
774  }
775  }
776  if (auto acc = op.getAcc())
777  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
778  return res;
779 }
780 
781 /// Lower vector.contract with all size one reduction dimensions to
782 /// elementwise ops when possible.
783 struct ContractOpToElementwise
784  : public MaskableOpRewritePattern<vector::ContractionOp> {
785  using MaskableOpRewritePattern::MaskableOpRewritePattern;
786  using FilterConstraintType =
787  std::function<LogicalResult(vector::ContractionOp op)>;
788  static LogicalResult defaultFilter(vector::ContractionOp op) {
789  return success();
790  }
791  ContractOpToElementwise(
792  vector::VectorTransformsOptions vectorTransformOptions,
793  MLIRContext *context, PatternBenefit benefit = 1,
794  const FilterConstraintType &constraint = defaultFilter)
795  : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
796  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
797 
798  FailureOr<Value>
799  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
800  MaskingOpInterface maskOp,
801  PatternRewriter &rewriter) const override {
802  // TODO: Support vector.mask.
803  if (maskOp)
804  return failure();
805 
806  if (failed(filter(contractOp)))
807  return failure();
808 
809  if (vectorTransformOptions.vectorContractLowering !=
810  vector::VectorContractLowering::ParallelArith)
811  return failure();
812 
813  ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
814  ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
815  AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
816  AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
817  SmallVector<int64_t> lhsReductionDims =
818  getReductionIndex(lhsMap, contractOp.getIteratorTypes());
819  SmallVector<int64_t> rhsReductionDims =
820  getReductionIndex(rhsMap, contractOp.getIteratorTypes());
821  // All the reduction dimensions must be a size 1.
822  for (int64_t dim : lhsReductionDims) {
823  if (lhsShape[dim] != 1)
824  return failure();
825  }
826  for (int64_t dim : rhsReductionDims) {
827  if (rhsShape[dim] != 1)
828  return failure();
829  }
830  AffineMap accMap = contractOp.getIndexingMapsArray()[2];
831  unsigned numParallelDims = accMap.getNumResults();
832  unsigned numLhsDimToBroadcast =
833  numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
834  unsigned numRhsDimToBroadcast =
835  numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
836  SmallVector<int64_t> lhsDims;
837  SmallVector<int64_t> lhsTranspose;
838  SmallVector<int64_t> rhsDims;
839  SmallVector<int64_t> rhsTranspose;
840  for (int64_t dim : lhsReductionDims)
841  lhsTranspose.push_back(numLhsDimToBroadcast + dim);
842  for (int64_t dim : rhsReductionDims)
843  rhsTranspose.push_back(numRhsDimToBroadcast + dim);
844  // Loop through the parallel dimensions to calculate the dimensions to
845  // broadcast and to permute in order to extract only parallel dimensions.
846  for (unsigned i = 0; i < numParallelDims; i++) {
847  std::optional<unsigned> lhsDim =
848  getDimPosition(lhsMap, accMap.getDimPosition(i));
849  if (lhsDim) {
850  lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
851  } else {
852  // If the parallel dimension doesn't exist we will have to broadcast it.
853  lhsDims.push_back(
854  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
855  lhsTranspose.push_back(lhsDims.size() - 1);
856  }
857  std::optional<unsigned> rhsDim =
858  getDimPosition(rhsMap, accMap.getDimPosition(i));
859  if (rhsDim) {
860  rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
861  } else {
862  // If the parallel dimension doesn't exist we will have to broadcast it.
863  rhsDims.push_back(
864  cast<VectorType>(contractOp.getResultType()).getDimSize(i));
865  rhsTranspose.push_back(rhsDims.size() - 1);
866  }
867  }
868  Value newLhs = contractOp.getLhs();
869  Value newRhs = contractOp.getRhs();
870  Location loc = contractOp.getLoc();
871  if (!lhsDims.empty()) {
872  lhsDims.append(lhsShape.begin(), lhsShape.end());
873  auto expandedType =
874  VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
875  newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
876  }
877  if (!rhsDims.empty()) {
878  rhsDims.append(rhsShape.begin(), rhsShape.end());
879  auto expandedType =
880  VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
881  newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
882  }
883  bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
884  newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
885  newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
886  SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
887  SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
888  newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
889  newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
890  std::optional<Value> result =
891  createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
892  contractOp.getKind(), rewriter, isInt);
893  if (result)
894  return *result;
895 
896  return failure();
897  }
898 
899 private:
900  /// Options to control the vector patterns.
901  vector::VectorTransformsOptions vectorTransformOptions;
902  FilterConstraintType filter;
903 };
904 
905 /// Progressive lowering of ContractionOp.
906 /// One:
907 /// %x = vector.contract with at least one free/batch dimension
908 /// is replaced by:
909 /// %a = vector.contract with one less free/batch dimension
910 /// %b = vector.contract with one less free/batch dimension
911 /// ..
912 /// %x = combine %a %b ..
913 /// until a pure contraction is reached (no free/batch dimensions),
914 /// which is replaced by a dot-product.
915 ///
916 /// This only kicks in when either VectorTransformsOptions is set
917 /// to DOT or when other contraction patterns fail.
918 //
919 // TODO: break down into transpose/reshape/cast ops
920 // when they become available to avoid code dup
921 // TODO: investigate lowering order impact on performance
922 FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
923  vector::ContractionOp op, MaskingOpInterface maskOp,
924  PatternRewriter &rewriter) const {
925  if (failed(filter(op)))
926  return failure();
927 
928  // TODO: support mixed mode contract lowering.
929  if (op.getLhsType().getElementType() !=
930  getElementTypeOrSelf(op.getAccType()) ||
931  op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
932  return failure();
933 
934  // TODO: the code below assumes the default contraction, make sure it supports
935  // other kinds before enabling this lowering.
936  if (op.getKind() != vector::CombiningKind::ADD) {
937  return rewriter.notifyMatchFailure(
938  op, "contractions other than 'add' not supported");
939  }
940 
941  // TODO: implement benefits, cost models.
942  MLIRContext *ctx = op.getContext();
943 
944  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
945  FailureOr<Value> newVal1 =
946  pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
947  if (!failed(newVal1))
948  return newVal1;
949 
950  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
951  FailureOr<Value> newVal2 =
952  pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
953  if (!failed(newVal2))
954  return newVal2;
955 
956  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
957  FailureOr<Value> newVal3 =
958  pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
959  if (!failed(newVal3))
960  return newVal3;
961 
962  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
963  FailureOr<Value> newVal4 =
964  pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
965  if (!failed(newVal4))
966  return newVal4;
967 
968  // Vector mask setup.
969 
970  Value mask;
971  if (maskOp)
972  mask = maskOp.getMask();
973  // Find first batch dimension in LHS/RHS, and lower when found.
974  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
975  if (!batchDimMap.empty()) {
976  int64_t lhsIndex = batchDimMap[0].first;
977  int64_t rhsIndex = batchDimMap[0].second;
978  auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
979  if (failed(newOp))
980  return failure();
981  return newOp;
982  }
983 
984  // Collect contracting dimensions.
985  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
986  op.getContractingDimMap();
987  DenseSet<int64_t> lhsContractingDimSet;
988  DenseSet<int64_t> rhsContractingDimSet;
989  for (auto &dimPair : contractingDimMap) {
990  lhsContractingDimSet.insert(dimPair.first);
991  rhsContractingDimSet.insert(dimPair.second);
992  }
993 
994  // Find first free dimension in LHS, and lower when found.
995  VectorType lhsType = op.getLhsType();
996  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
997  if (lhsContractingDimSet.count(lhsIndex) == 0) {
998  auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
999  if (failed(newOp))
1000  return failure();
1001  return newOp;
1002  }
1003  }
1004 
1005  // Find first free dimension in RHS, and lower when found.
1006  VectorType rhsType = op.getRhsType();
1007  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1008  if (rhsContractingDimSet.count(rhsIndex) == 0) {
1009  auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
1010  if (failed(newOp))
1011  return failure();
1012  return newOp;
1013  }
1014  }
1015 
1016  // Lower the first remaining reduction dimension.
1017  if (!contractingDimMap.empty()) {
1018  auto newOp = lowerReduction(rewriter, op, mask);
1019  if (failed(newOp))
1020  return failure();
1021  return newOp;
1022  }
1023 
1024  return failure();
1025 }
1026 
1027 // Lower one parallel dimension.
1028 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1029 // TODO: consider reusing existing contract unrolling
1030 FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
1031  vector::ContractionOp op,
1032  int64_t lhsIndex,
1033  int64_t rhsIndex,
1034  Value mask) const {
1035  VectorType lhsType = op.getLhsType();
1036  VectorType rhsType = op.getRhsType();
1037  VectorType resType = cast<VectorType>(op.getResultType());
1038  // Find the iterator type index and result index.
1039  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1040  int64_t iterIndex = -1;
1041  int64_t dimSize = -1;
1042  if (lhsIndex >= 0) {
1043  iterIndex = iMap[0].getDimPosition(lhsIndex);
1044  if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1045  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1046  diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1047  << " to map to the same dimension";
1048  });
1049  if (lhsType.getScalableDims()[lhsIndex])
1050  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1051  diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
1052  << ") is not supported yet";
1053  });
1054  dimSize = lhsType.getDimSize(lhsIndex);
1055  } else if (rhsIndex >= 0) {
1056  iterIndex = iMap[1].getDimPosition(rhsIndex);
1057  if (rhsType.getScalableDims()[rhsIndex])
1058  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1059  diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
1060  << ") is not supported yet";
1061  });
1062  dimSize = rhsType.getDimSize(rhsIndex);
1063  }
1064  if (iterIndex < 0)
1065  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1066  diag << "expected either lhsIndex=" << lhsIndex
1067  << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1068  });
1069  // value_or(-1) means that we tolerate a dimension not appearing
1070  // in the result map. That can't happen for actual parallel iterators, but
1071  // the caller ContractionOpLowering::matchAndRewrite is currently calling
1072  // lowerParallel also for the case of unit-size reduction dims appearing only
1073  // on one of LHS or RHS, not both. At the moment, such cases are created by
1074  // CastAwayContractionLeadingOneDim, so we need to either support that or
1075  // modify that pattern.
1076  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1077  if (resIndex == -1 && dimSize != 1)
1078  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1079  diag << "expected the dimension for iterIndex=" << iterIndex
1080  << " to either appear in the result map, or to be a unit dimension";
1081  });
1082 
1083  // Construct new iterator types and affine map array attribute.
1084  std::array<AffineMap, 3> lowIndexingMaps = {
1085  adjustMap(iMap[0], iterIndex, rewriter),
1086  adjustMap(iMap[1], iterIndex, rewriter),
1087  adjustMap(iMap[2], iterIndex, rewriter)};
1088  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1089  auto lowIter =
1090  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1091  // Unroll into a series of lower dimensional vector.contract ops.
1092  Location loc = op.getLoc();
1093  Value result = rewriter.create<arith::ConstantOp>(
1094  loc, resType, rewriter.getZeroAttr(resType));
1095 
1096  for (int64_t d = 0; d < dimSize; ++d) {
1097  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1098  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1099  auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1100 
1101  Value lowMask;
1102  if (mask)
1103  lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1104  iterIndex, d, rewriter);
1105 
1106  Operation *lowContract = rewriter.create<vector::ContractionOp>(
1107  loc, lhs, rhs, acc, lowAffine, lowIter);
1108  lowContract = maskOperation(rewriter, lowContract, lowMask);
1109  result = reshapeStore(loc, lowContract->getResult(0), result, resType,
1110  resIndex, d, rewriter);
1111  }
1112  return result;
1113 }
1114 
1115 // Lower one reduction dimension.
1116 FailureOr<Value> ContractionOpLowering::lowerReduction(
1117  PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1118  auto loc = op.getLoc();
1119  VectorType lhsType = op.getLhsType();
1120  VectorType rhsType = op.getRhsType();
1121  Type resType = op.getResultType();
1122  if (isa<VectorType>(resType))
1123  return rewriter.notifyMatchFailure(op,
1124  "did not expect a VectorType result");
1125  bool isInt = isa<IntegerType>(resType);
1126  // Use iterator index 0.
1127  int64_t iterIndex = 0;
1128  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1129  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1130  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1131  if (!lookupLhs.has_value())
1132  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1133  diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1134  });
1135  if (!lookupRhs.has_value())
1136  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1137  diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1138  });
1139  int64_t lhsIndex = *lookupLhs;
1140  int64_t rhsIndex = *lookupRhs;
1141  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1142  if (dimSize != rhsType.getDimSize(rhsIndex))
1143  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1144  diag << "expect LHS dimension " << lhsIndex
1145  << " to have the same size as RHS dimension " << rhsIndex;
1146  });
1147  // Base case.
1148  if (lhsType.getRank() == 1) {
1149  if (rhsType.getRank() != 1)
1150  return rewriter.notifyMatchFailure(
1151  op, "When LHS has rank 1, expected also RHS to have rank 1");
1152  Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1153  auto kind = vector::CombiningKind::ADD;
1154 
1155  Value acc = op.getAcc();
1156  Operation *reductionOp =
1157  acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1158  : rewriter.create<vector::ReductionOp>(loc, kind, m);
1159  return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1160  }
1161  // Construct new iterator types and affine map array attribute.
1162  std::array<AffineMap, 3> lowIndexingMaps = {
1163  adjustMap(iMap[0], iterIndex, rewriter),
1164  adjustMap(iMap[1], iterIndex, rewriter),
1165  adjustMap(iMap[2], iterIndex, rewriter)};
1166  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1167  auto lowIter =
1168  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1169  // Unroll into a series of lower dimensional vector.contract ops.
1170  // By feeding the initial accumulator into the first contraction,
1171  // and the result of each contraction into the next, eventually
1172  // the sum of all reductions is computed.
1173  Value result = op.getAcc();
1174  for (int64_t d = 0; d < dimSize; ++d) {
1175  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1176  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1177  Value newMask;
1178  if (mask)
1179  newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1180  iterIndex, d, rewriter);
1181 
1182  Operation *newContract = rewriter.create<vector::ContractionOp>(
1183  loc, lhs, rhs, result, lowAffine, lowIter);
1184  result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1185  }
1186  return result;
1187 }
1188 
1189 /// Progressive lowering of OuterProductOp.
1190 /// One:
1191 /// %x = vector.outerproduct %lhs, %rhs, %acc
1192 /// is replaced by:
1193 /// %z = zero-result
1194 /// %0 = vector.extract %lhs[0]
1195 /// %1 = vector.broadcast %0
1196 /// %2 = vector.extract %acc[0]
1197 /// %3 = vector.fma %1, %rhs, %2
1198 /// %4 = vector.insert %3, %z[0]
1199 /// ..
1200 /// %x = vector.insert %.., %..[N-1]
1201 ///
1202 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1203 public:
1205 
1206  LogicalResult matchAndRewrite(vector::OuterProductOp op,
1207  PatternRewriter &rewriter) const override {
1208  VectorType resType = op.getResultVectorType();
1209  if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1210  return failure();
1211 
1212  auto loc = op.getLoc();
1213 
1214  VectorType lhsType = op.getOperandVectorTypeLHS();
1215  VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1216  Type eltType = resType.getElementType();
1217  bool isInt = isa<IntegerType, IndexType>(eltType);
1218  Value acc = op.getAcc();
1219  vector::CombiningKind kind = op.getKind();
1220 
1221  // Vector mask setup.
1222  OpBuilder::InsertionGuard guard(rewriter);
1223  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1224  Operation *rootOp;
1225  Value mask;
1226  if (maskableOp.isMasked()) {
1227  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1228  rootOp = maskableOp.getMaskingOp();
1229  mask = maskableOp.getMaskingOp().getMask();
1230  } else {
1231  rootOp = op;
1232  }
1233 
1234  if (!rhsType) {
1235  // Special case: AXPY operation.
1236  Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1237  std::optional<Value> mult = createContractArithOp(
1238  loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1239  if (!mult.has_value())
1240  return failure();
1241  rewriter.replaceOp(rootOp, *mult);
1242  return success();
1243  }
1244 
1245  Value result = rewriter.create<arith::ConstantOp>(
1246  loc, resType, rewriter.getZeroAttr(resType));
1247  for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1248  Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
1249  Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1250  Value r = nullptr;
1251  if (acc)
1252  r = rewriter.create<vector::ExtractOp>(loc, acc, d);
1253  Value extrMask;
1254  if (mask)
1255  extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
1256 
1257  std::optional<Value> m = createContractArithOp(
1258  loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1259  if (!m.has_value())
1260  return failure();
1261  result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
1262  }
1263 
1264  rewriter.replaceOp(rootOp, result);
1265  return success();
1266  }
1267 };
1268 
1269 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1270 /// semantics to:
1271 /// ```
1272 /// %mta = maybe_transpose
1273 /// %mtb = maybe_transpose
1274 /// %flattened_a = vector.shape_cast %mta
1275 /// %flattened_b = vector.shape_cast %mtb
1276 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1277 /// %mtd = vector.shape_cast %flattened_d
1278 /// %d = maybe_untranspose %mtd
1279 /// %e = add %c, %d
1280 /// ```
1281 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1282 //
1283 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1284 /// vector.transpose operations are inserted if the vector.contract op is not a
1285 /// row-major matrix multiply.
1286 ///
1287 /// Scalable vectors are not supported.
1288 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1289  vector::ContractionOp op, MaskingOpInterface maskOp,
1290  PatternRewriter &rew) const {
1291  // TODO: Support vector.mask.
1292  if (maskOp)
1293  return failure();
1294 
1295  if (vectorTransformOptions.vectorContractLowering !=
1296  vector::VectorContractLowering::Matmul)
1297  return failure();
1298  if (failed(filter(op)))
1299  return failure();
1300 
1301  auto iteratorTypes = op.getIteratorTypes().getValue();
1302  if (!isParallelIterator(iteratorTypes[0]) ||
1303  !isParallelIterator(iteratorTypes[1]) ||
1304  !isReductionIterator(iteratorTypes[2]))
1305  return failure();
1306 
1307  Type opResType = op.getType();
1308  VectorType vecType = dyn_cast<VectorType>(opResType);
1309  if (vecType && vecType.isScalable()) {
1310  // Note - this is sufficient to reject all cases with scalable vectors.
1311  return failure();
1312  }
1313 
1314  Type elementType = op.getLhsType().getElementType();
1315  if (!elementType.isIntOrFloat())
1316  return failure();
1317 
1318  Type dstElementType = vecType ? vecType.getElementType() : opResType;
1319  if (elementType != dstElementType)
1320  return failure();
1321 
1322  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1323  // Bail out if the contraction cannot be put in this form.
1324  MLIRContext *ctx = op.getContext();
1325  Location loc = op.getLoc();
1326  AffineExpr m, n, k;
1327  bindDims(rew.getContext(), m, n, k);
1328  // LHS must be A(m, k) or A(k, m).
1329  Value lhs = op.getLhs();
1330  auto lhsMap = op.getIndexingMapsArray()[0];
1331  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1332  lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1333  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1334  return failure();
1335 
1336  // RHS must be B(k, n) or B(n, k).
1337  Value rhs = op.getRhs();
1338  auto rhsMap = op.getIndexingMapsArray()[1];
1339  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1340  rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1341  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1342  return failure();
1343 
1344  // At this point lhs and rhs are in row-major.
1345  VectorType lhsType = cast<VectorType>(lhs.getType());
1346  VectorType rhsType = cast<VectorType>(rhs.getType());
1347  int64_t lhsRows = lhsType.getDimSize(0);
1348  int64_t lhsColumns = lhsType.getDimSize(1);
1349  int64_t rhsColumns = rhsType.getDimSize(1);
1350 
1351  Type flattenedLHSType =
1352  VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1353  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1354 
1355  Type flattenedRHSType =
1356  VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1357  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1358 
1359  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1360  rhsColumns);
1361  mul = rew.create<vector::ShapeCastOp>(
1362  loc,
1363  VectorType::get({lhsRows, rhsColumns},
1364  getElementTypeOrSelf(op.getAcc().getType())),
1365  mul);
1366 
1367  // ACC must be C(m, n) or C(n, m).
1368  auto accMap = op.getIndexingMapsArray()[2];
1369  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1370  mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1371  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1372  llvm_unreachable("invalid contraction semantics");
1373 
1374  Value res =
1375  isa<IntegerType>(elementType)
1376  ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1377  : static_cast<Value>(
1378  rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1379 
1380  return res;
1381 }
1382 } // namespace
1383 
1386  PatternBenefit benefit, bool disableOuterProductLowering) {
1387  if (!disableOuterProductLowering)
1388  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1389  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1390  ContractionOpToOuterProductOpLowering>(
1391  options, patterns.getContext(), benefit);
1392 }
1393 
1395  RewritePatternSet &patterns, PatternBenefit benefit) {
1396  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1397 }
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true, otherwise creates a MulFOp, using operands x and y.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true, otherwise creates an arith::AddFOp, using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define MINUI(lhs, rhs)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:342
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:152
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:147
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:157
Structure to control the behavior of vector transform patterns.