1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
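// Note: these rewrites are ordinary RewritePatterns; a typical driver gathers
// them into a RewritePatternSet (e.g. through helpers such as
// vector::populateVectorContractLoweringPatterns) and applies them with
// applyPatternsAndFoldGreedily.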
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
26 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/IR/PatternMatch.h"
31 
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/MapVector.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 
39 #define DEBUG_TYPE "vector-to-vector"
40 
41 using namespace mlir;
42 using namespace mlir::vector;
43 
44 // Helper to find the result position of a given dimension index in an affine map.
45 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
46  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
47  int64_t idx = map.getDimPosition(i);
48  if (idx == index)
49  return i;
50  }
51  return None;
52 }
53 
54 // Helper to construct iterator types with one index removed.
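// E.g., removing index 1 from ["parallel", "reduction", "parallel"] yields
// ["parallel", "parallel"].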
55 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
56  int64_t index) {
57  SmallVector<Attribute, 4> results;
58  for (const auto &it : llvm::enumerate(iteratorTypes)) {
59  int64_t idx = it.index();
60  if (idx == index)
61  continue;
62  results.push_back(it.value());
63  }
64  return results;
65 }
66 
67 // Helper to construct an affine map with one index removed.
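// E.g., removing dimension d1 from (d0, d1, d2) -> (d0, d2) yields
// (d0, d1) -> (d0, d1); dimensions past the removed one are renumbered.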
68 static AffineMap adjustMap(AffineMap map, int64_t index,
69  PatternRewriter &rewriter) {
70  auto *ctx = rewriter.getContext();
71  SmallVector<AffineExpr, 4> results;
72  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
73  int64_t idx = map.getDimPosition(i);
74  if (idx == index)
75  continue;
76  // Re-insert remaining indices, but renamed when occurring
77  // after the removed index.
78  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
79  results.push_back(targetExpr);
80  }
81  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
82 }
83 
84 // Helper method to possibly drop a dimension in a load.
85 // TODO
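// For example, given a vector<2x3x4xf32> value with index = 1 and pos = 0,
// this returns a vector<2x4xf32> built by extracting position 0 of the middle
// dimension within each leading slice.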
86 static Value reshapeLoad(Location loc, Value val, VectorType type,
87  int64_t index, int64_t pos,
88  PatternRewriter &rewriter) {
89  if (index == -1)
90  return val;
91  Type lowType = VectorType::Builder(type).dropDim(0);
92  // At extraction dimension?
93  if (index == 0) {
94  auto posAttr = rewriter.getI64ArrayAttr(pos);
95  return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
96  }
97  // Unroll leading dimensions.
98  VectorType vType = lowType.cast<VectorType>();
99  Type resType = VectorType::Builder(type).dropDim(index);
100  auto resVectorType = resType.cast<VectorType>();
101  Value result = rewriter.create<arith::ConstantOp>(
102  loc, resVectorType, rewriter.getZeroAttr(resVectorType));
103  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
104  auto posAttr = rewriter.getI64ArrayAttr(d);
105  Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
106  Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
107  result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
108  posAttr);
109  }
110  return result;
111 }
112 
113 // Helper method to possibly drop a dimension in a store.
114 // TODO
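// This is the inverse of reshapeLoad: `val` is inserted at position `pos`
// along dimension `index` of `result`, unrolling leading dimensions as needed.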
115 static Value reshapeStore(Location loc, Value val, Value result,
116  VectorType type, int64_t index, int64_t pos,
117  PatternRewriter &rewriter) {
118  // Unmodified?
119  if (index == -1)
120  return val;
121  // At insertion dimension?
122  if (index == 0) {
123  auto posAttr = rewriter.getI64ArrayAttr(pos);
124  return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
125  }
126  // Unroll leading dimensions.
127  Type lowType = VectorType::Builder(type).dropDim(0);
128  VectorType vType = lowType.cast<VectorType>();
129  Type insType = VectorType::Builder(vType).dropDim(0);
130  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
131  auto posAttr = rewriter.getI64ArrayAttr(d);
132  Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
133  Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
134  Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
135  result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
136  }
137  return result;
138 }
139 
140 template <typename IntType>
141 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
142  return llvm::to_vector<4>(llvm::map_range(
143  arrayAttr.getAsRange<IntegerAttr>(),
144  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
145 }
146 
147 /// Helper to create the arithmetic operation associated with a kind of contraction.
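/// E.g., for a floating-point ADD with a vector accumulator this emits a single
/// vector.fma; otherwise it emits the multiply and combines it with the
/// accumulator via makeArithReduction. It returns llvm::None for combining
/// kinds that are invalid for the element type (e.g. MINF on integers).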
148 static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
149  Value acc,
150  vector::CombiningKind kind,
151  PatternRewriter &rewriter,
152  bool isInt) {
153  using vector::CombiningKind;
154  Value mul;
155  if (isInt) {
156  if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
157  // Only valid for floating point types.
158  return Optional<Value>();
159  mul = rewriter.create<arith::MulIOp>(loc, x, y);
160  } else {
161  // Float case.
162  if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
163  kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
164  kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
165  kind == CombiningKind::XOR)
166  // Only valid for integer types.
167  return Optional<Value>();
168  // Special case for fused multiply-add.
169  if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
170  return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
171  }
172  mul = rewriter.create<arith::MulFOp>(loc, x, y);
173  }
174  if (!acc)
175  return Optional<Value>(mul);
176  return makeArithReduction(rewriter, loc, kind, mul, acc);
177 }
178 
179 /// Return the positions of the reductions in the given map.
180 static SmallVector<int64_t> getReductionIndex(AffineMap map,
181  ArrayAttr iteratorTypes) {
182  SmallVector<int64_t> dimsIdx;
183  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
184  if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
185  dimsIdx.push_back(i);
186  }
187  return dimsIdx;
188 }
189 
190 /// Look for a given dimension in an affine map and return its position. Return
191 /// llvm::None if the dimension is not in the map results.
192 static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
193  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
194  if (map.getDimPosition(i) == dim)
195  return i;
196  }
197  return llvm::None;
198 }
199 
200 namespace {
201 
202 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
203 //
204 // Example:
205 //
206 // The following MLIR with cancelling ShapeCastOps:
207 //
208 // %0 = source : vector<5x4x2xf32>
209 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
210 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
211 // %3 = user %2 : vector<5x4x2xf32>
212 //
213 // Should canonicalize to the following:
214 //
215 // %0 = source : vector<5x4x2xf32>
216 // %1 = user %0 : vector<5x4x2xf32>
217 //
218 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
219  using OpRewritePattern::OpRewritePattern;
220 
221  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
222  PatternRewriter &rewriter) const override {
223  // Check if 'shapeCastOp' has vector source/result type.
224  auto sourceVectorType =
225  shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
226  auto resultVectorType =
227  shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
228  if (!sourceVectorType || !resultVectorType)
229  return failure();
230 
231  // Check if shape cast op source operand is also a shape cast op.
232  auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
233  shapeCastOp.getSource().getDefiningOp());
234  if (!sourceShapeCastOp)
235  return failure();
236  auto operandSourceVectorType =
237  sourceShapeCastOp.getSource().getType().cast<VectorType>();
238  auto operandResultVectorType = sourceShapeCastOp.getType();
239 
240  // Check if shape cast operations invert each other.
241  if (operandSourceVectorType != resultVectorType ||
242  operandResultVectorType != sourceVectorType)
243  return failure();
244 
245  rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
246  return success();
247  }
248 };
249 
250 /// Progressive lowering of BroadcastOp.
251 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
252 public:
253  using OpRewritePattern::OpRewritePattern;
254 
255  LogicalResult matchAndRewrite(vector::BroadcastOp op,
256  PatternRewriter &rewriter) const override {
257  auto loc = op.getLoc();
258  VectorType dstType = op.getVectorType();
259  VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
260  Type eltType = dstType.getElementType();
261 
262  // Scalar to any vector can use splat.
263  if (!srcType) {
264  rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
265  return success();
266  }
267 
268  // Determine rank of source and destination.
269  int64_t srcRank = srcType.getRank();
270  int64_t dstRank = dstType.getRank();
271 
272  // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
273  if (srcRank <= 1 && dstRank == 1) {
274  Value ext;
275  if (srcRank == 0)
276  ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
277  else
278  ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
279  rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
280  return success();
281  }
282 
283  // Duplicate this rank.
284  // For example:
285  // %x = broadcast %y : k-D to n-D, k < n
286  // becomes:
287  // %b = broadcast %y : k-D to (n-1)-D
288  // %x = [%b,%b,%b,%b] : n-D
289  // becomes:
290  // %b = [%y,%y] : (n-1)-D
291  // %x = [%b,%b,%b,%b] : n-D
292  if (srcRank < dstRank) {
293  // Duplication.
294  VectorType resType =
295  VectorType::get(dstType.getShape().drop_front(), eltType);
296  Value bcst =
297  rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
298  Value result = rewriter.create<arith::ConstantOp>(
299  loc, dstType, rewriter.getZeroAttr(dstType));
300  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
301  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
302  rewriter.replaceOp(op, result);
303  return success();
304  }
305 
306  // Find non-matching dimension, if any.
307  assert(srcRank == dstRank);
308  int64_t m = -1;
309  for (int64_t r = 0; r < dstRank; r++)
310  if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
311  m = r;
312  break;
313  }
314 
315  // All trailing dimensions are the same. Simply pass through.
316  if (m == -1) {
317  rewriter.replaceOp(op, op.getSource());
318  return success();
319  }
320 
321  // Any non-matching dimension forces a stretch along this rank.
322  // For example:
323  // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
324  // becomes:
325  // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
326  // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
327  // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
328  // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
329  // %x = [%a,%b,%c,%d]
330  // becomes:
331  // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
332  // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
333  // %a = [%u, %v]
334  // ..
335  // %x = [%a,%b,%c,%d]
336  VectorType resType =
337  VectorType::get(dstType.getShape().drop_front(), eltType);
338  Value result = rewriter.create<arith::ConstantOp>(
339  loc, dstType, rewriter.getZeroAttr(dstType));
340  if (m == 0) {
341  // Stretch at start.
342  Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
343  Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
344  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
345  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
346  } else {
347  // Stretch not at start.
348  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
349  Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
350  Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
351  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
352  }
353  }
354  rewriter.replaceOp(op, result);
355  return success();
356  }
357 };
358 
359 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
360 /// transposed.
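/// E.g., the permutation [2, 0, 1, 3] is pruned to [2, 0, 1]; the trailing
/// dimension is already in place and is kept in vector form.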
361 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
362  SmallVectorImpl<int64_t> &result) {
363  size_t numTransposedDims = transpose.size();
364  for (size_t transpDim : llvm::reverse(transpose)) {
365  if (transpDim != numTransposedDims - 1)
366  break;
367  numTransposedDims--;
368  }
369 
370  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
371 }
372 
373 /// Progressive lowering of TransposeOp.
374 /// One:
375 /// %x = vector.transpose %y, [1, 0]
376 /// is replaced by:
377 /// %z = arith.constant dense<0.000000e+00>
378 /// %0 = vector.extract %y[0, 0]
379 /// %1 = vector.insert %0, %z [0, 0]
380 /// ..
381 /// %x = vector.insert .., .. [.., ..]
382 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
383 public:
384  using OpRewritePattern::OpRewritePattern;
385 
386  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
387  MLIRContext *context)
388  : OpRewritePattern<vector::TransposeOp>(context),
389  vectorTransformOptions(vectorTransformOptions) {}
390 
391  LogicalResult matchAndRewrite(vector::TransposeOp op,
392  PatternRewriter &rewriter) const override {
393  auto loc = op.getLoc();
394 
395  Value input = op.getVector();
396  VectorType inputType = op.getVectorType();
397  VectorType resType = op.getResultType();
398 
399  // Set up convenience transposition table.
400  SmallVector<int64_t, 4> transp;
401  for (auto attr : op.getTransp())
402  transp.push_back(attr.cast<IntegerAttr>().getInt());
403 
404  if (vectorTransformOptions.vectorTransposeLowering ==
405  vector::VectorTransposeLowering::Shuffle &&
406  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
407  return rewriter.notifyMatchFailure(
408  op, "Options specifies lowering to shuffle");
409 
410  // Handle a true 2-D matrix transpose differently when requested.
411  if (vectorTransformOptions.vectorTransposeLowering ==
412  vector::VectorTransposeLowering::Flat &&
413  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
414  Type flattenedType =
415  VectorType::get(resType.getNumElements(), resType.getElementType());
416  auto matrix =
417  rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
418  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
419  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
420  Value trans = rewriter.create<vector::FlatTransposeOp>(
421  loc, flattenedType, matrix, rows, columns);
422  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
423  return success();
424  }
425 
426  // Generate unrolled extract/insert ops. We do not unroll the rightmost
427  // (i.e., highest-order) dimensions that are not transposed and leave them
428  // in vector form to improve performance. Therefore, we prune those
429  // dimensions from the shape/transpose data structures used to generate the
430  // extract/insert ops.
431  SmallVector<int64_t, 4> prunedTransp;
432  pruneNonTransposedDims(transp, prunedTransp);
433  size_t numPrunedDims = transp.size() - prunedTransp.size();
434  auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
435  SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
436  auto prunedInStrides = computeStrides(prunedInShape, ones);
437 
438  // Generates the extract/insert operations for every scalar/vector element
439  // of the leftmost transposed dimensions. We traverse every transpose
440  // element using a linearized index that we delinearize to generate the
441  // appropriate indices for the extract/insert operations.
442  Value result = rewriter.create<arith::ConstantOp>(
443  loc, resType, rewriter.getZeroAttr(resType));
444  int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
445 
446  for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
447  ++linearIdx) {
448  auto extractIdxs = delinearize(prunedInStrides, linearIdx);
449  SmallVector<int64_t, 4> insertIdxs(extractIdxs);
450  applyPermutationToVector(insertIdxs, prunedTransp);
451  Value extractOp =
452  rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
453  result =
454  rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
455  }
456 
457  rewriter.replaceOp(op, result);
458  return success();
459  }
460 
461 private:
462  /// Options to control the vector patterns.
463  vector::VectorTransformsOptions vectorTransformOptions;
464 };
465 
466 /// Rewrite a 2-D vector.transpose as a sequence of:
467 /// vector.shape_cast 2D -> 1D
468 /// vector.shuffle
469 /// vector.shape_cast 1D -> 2D
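/// For example, transposing a vector<2x4xf32> produces (illustrative IR):
///   %0 = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
///   %1 = vector.shuffle %0, %0 [0, 4, 1, 5, 2, 6, 3, 7]
///       : vector<8xf32>, vector<8xf32>
///   %2 = vector.shape_cast %1 : vector<8xf32> to vector<4x2xf32>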
470 class TransposeOp2DToShuffleLowering
471  : public OpRewritePattern<vector::TransposeOp> {
472 public:
473  using OpRewritePattern::OpRewritePattern;
474 
475  TransposeOp2DToShuffleLowering(
476  vector::VectorTransformsOptions vectorTransformOptions,
477  MLIRContext *context)
478  : OpRewritePattern<vector::TransposeOp>(context),
479  vectorTransformOptions(vectorTransformOptions) {}
480 
481  LogicalResult matchAndRewrite(vector::TransposeOp op,
482  PatternRewriter &rewriter) const override {
483  auto loc = op.getLoc();
484 
485  VectorType srcType = op.getVectorType();
486  if (srcType.getRank() != 2)
487  return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
488 
490  for (auto attr : op.getTransp())
491  transp.push_back(attr.cast<IntegerAttr>().getInt());
492  if (transp[0] != 1 && transp[1] != 0)
493  return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
494 
495  if (vectorTransformOptions.vectorTransposeLowering !=
496  vector::VectorTransposeLowering::Shuffle)
497  return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
498 
499  int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
500  Value casted = rewriter.create<vector::ShapeCastOp>(
501  loc, VectorType::get({m * n}, srcType.getElementType()),
502  op.getVector());
503  SmallVector<int64_t> mask;
504  mask.reserve(m * n);
505  for (int64_t j = 0; j < n; ++j)
506  for (int64_t i = 0; i < m; ++i)
507  mask.push_back(i * n + j);
508 
509  Value shuffled =
510  rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
511  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
512  shuffled);
513 
514  return success();
515  }
516 
517 private:
518  /// Options to control the vector patterns.
519  vector::VectorTransformsOptions vectorTransformOptions;
520 };
521 
522 /// Progressive lowering of OuterProductOp.
523 /// One:
524 /// %x = vector.outerproduct %lhs, %rhs, %acc
525 /// is replaced by:
526 /// %z = zero-result
527 /// %0 = vector.extract %lhs[0]
528 /// %1 = vector.broadcast %0
529 /// %2 = vector.extract %acc[0]
530 /// %3 = vector.fma %1, %rhs, %2
531 /// %4 = vector.insert %3, %z[0]
532 /// ..
533 /// %x = vector.insert %.., %..[N-1]
534 ///
535 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
536 public:
537  using OpRewritePattern::OpRewritePattern;
538 
539  LogicalResult matchAndRewrite(vector::OuterProductOp op,
540  PatternRewriter &rewriter) const override {
541  auto loc = op.getLoc();
542 
543  VectorType lhsType = op.getOperandVectorTypeLHS();
544  VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
545  VectorType resType = op.getVectorType();
546  Type eltType = resType.getElementType();
547  bool isInt = eltType.isa<IntegerType, IndexType>();
548  Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
549  vector::CombiningKind kind = op.getKind();
550 
551  if (!rhsType) {
552  // Special case: AXPY operation.
553  Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
554  Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
555  kind, rewriter, isInt);
556  if (!mult.has_value())
557  return failure();
558  rewriter.replaceOp(op, mult.value());
559  return success();
560  }
561 
562  Value result = rewriter.create<arith::ConstantOp>(
563  loc, resType, rewriter.getZeroAttr(resType));
564  for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
565  auto pos = rewriter.getI64ArrayAttr(d);
566  Value x =
567  rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
568  Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
569  Value r = nullptr;
570  if (acc)
571  r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
572  Optional<Value> m =
573  createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
574  if (!m.has_value())
575  return failure();
576  result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
577  result, pos);
578  }
579  rewriter.replaceOp(op, result);
580  return success();
581  }
582 };
583 
584 /// Lower vector.contract with all size one reduction dimensions to
585 /// elementwise ops when possible.
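/// For example (illustrative IR), a contraction whose single reduction
/// dimension has size 1:
///   %0 = vector.contract ... %lhs, %rhs, %acc
///       : vector<4x1xf32>, vector<1x8xf32> into vector<4x8xf32>
/// can be rewritten as broadcasts and transposes of the operands followed by
/// an elementwise multiply-accumulate (a vector.fma here) on vector<4x8xf32>.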
586 struct ContractOpToElementwise
587  : public OpRewritePattern<vector::ContractionOp> {
588  using OpRewritePattern::OpRewritePattern;
589  using FilterConstraintType =
590  std::function<LogicalResult(vector::ContractionOp op)>;
591  static LogicalResult defaultFilter(vector::ContractionOp op) {
592  return success();
593  }
594  ContractOpToElementwise(
595  vector::VectorTransformsOptions vectorTransformOptions,
596  MLIRContext *context,
597  const FilterConstraintType &constraint = defaultFilter)
598  : OpRewritePattern<vector::ContractionOp>(context),
599  vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
600 
601  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
602  PatternRewriter &rewriter) const override {
603  // TODO: implement masks
604  if (llvm::size(contractOp.getMasks()) != 0)
605  return failure();
606 
607  if (failed(filter(contractOp)))
608  return failure();
609 
610  if (vectorTransformOptions.vectorContractLowering !=
611  vector::VectorContractLowering::ParallelArith)
612  return failure();
613  ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
614  ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
615  AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
616  AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
617  SmallVector<int64_t> lhsReductionDims =
618  getReductionIndex(lhsMap, contractOp.getIteratorTypes());
619  SmallVector<int64_t> rhsReductionDims =
620  getReductionIndex(rhsMap, contractOp.getIteratorTypes());
621  // All the reduction dimensions must be of size 1.
622  for (int64_t dim : lhsReductionDims) {
623  if (lhsShape[dim] != 1)
624  return failure();
625  }
626  for (int64_t dim : rhsReductionDims) {
627  if (rhsShape[dim] != 1)
628  return failure();
629  }
630  AffineMap accMap = contractOp.getIndexingMapsArray()[2];
631  unsigned numParallelDims = accMap.getNumResults();
632  unsigned numLhsDimToBroadcast =
633  numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
634  unsigned numRhsDimToBroadcast =
635  numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
636  SmallVector<int64_t> lhsDims;
637  SmallVector<int64_t> lhsTranspose;
638  SmallVector<int64_t> rhsDims;
639  SmallVector<int64_t> rhsTranspose;
640  for (int64_t dim : lhsReductionDims)
641  lhsTranspose.push_back(numLhsDimToBroadcast + dim);
642  for (int64_t dim : rhsReductionDims)
643  rhsTranspose.push_back(numRhsDimToBroadcast + dim);
644  // Loop through the parallel dimensions to calculate the dimensions to
645  // broadcast and to permute in order to extract only parallel dimensions.
646  for (unsigned i = 0; i < numParallelDims; i++) {
647  llvm::Optional<unsigned> lhsDim =
648  getDimPosition(lhsMap, accMap.getDimPosition(i));
649  if (lhsDim) {
650  lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
651  } else {
652  // If the parallel dimension doesn't exist we will have to broadcast it.
653  lhsDims.push_back(
654  contractOp.getResultType().cast<VectorType>().getDimSize(i));
655  lhsTranspose.push_back(lhsDims.size() - 1);
656  }
657  llvm::Optional<unsigned> rhsDim =
658  getDimPosition(rhsMap, accMap.getDimPosition(i));
659  if (rhsDim) {
660  rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
661  } else {
662  // If the parallel dimension doesn't exist we will have to broadcast it.
663  rhsDims.push_back(
664  contractOp.getResultType().cast<VectorType>().getDimSize(i));
665  rhsTranspose.push_back(rhsDims.size() - 1);
666  }
667  }
668  Value newLhs = contractOp.getLhs();
669  Value newRhs = contractOp.getRhs();
670  Location loc = contractOp.getLoc();
671  if (!lhsDims.empty()) {
672  lhsDims.append(lhsShape.begin(), lhsShape.end());
673  auto expandedType =
674  VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
675  newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
676  }
677  if (!rhsDims.empty()) {
678  rhsDims.append(rhsShape.begin(), rhsShape.end());
679  auto expandedType =
680  VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
681  newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
682  }
683  bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
684  newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
685  newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
686  SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
687  SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
688  newLhs = rewriter.create<vector::ExtractOp>(
689  loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
690  newRhs = rewriter.create<vector::ExtractOp>(
691  loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
692  Optional<Value> result =
693  createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
694  contractOp.getKind(), rewriter, isInt);
695  rewriter.replaceOp(contractOp, {*result});
696  return success();
697  }
698 
699 private:
700  /// Options to control the vector patterns.
701  vector::VectorTransformsOptions vectorTransformOptions;
702  FilterConstraintType filter;
703 };
704 
705 /// Progressive lowering of ConstantMaskOp.
706 /// One:
707 /// %x = vector.constant_mask [a,b]
708 /// is replaced by:
709 /// %z = zero-result
710 /// %l = vector.constant_mask [b]
711 /// %4 = vector.insert %l, %z[0]
712 /// ..
713 /// %x = vector.insert %l, %..[a-1]
714 /// until a one-dimensional vector is reached. All these operations
715 /// will be folded at LLVM IR level.
716 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
717 public:
718  using OpRewritePattern::OpRewritePattern;
719 
720  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
721  PatternRewriter &rewriter) const override {
722  auto loc = op.getLoc();
723  auto dstType = op.getType();
724  auto eltType = dstType.getElementType();
725  auto dimSizes = op.getMaskDimSizes();
726  int64_t rank = dstType.getRank();
727 
728  if (rank == 0) {
729  assert(dimSizes.size() == 1 &&
730  "Expected exactly one dim size for a 0-D vector");
731  bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
732  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
733  op, dstType,
734  DenseElementsAttr::get(
735  VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
736  ArrayRef<bool>{value}));
737  return success();
738  }
739 
740  // Scalable constant masks can only be lowered for the "none set" case.
741  if (dstType.cast<VectorType>().isScalable()) {
742  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
743  op, DenseElementsAttr::get(dstType, false));
744  return success();
745  }
746 
747  int64_t trueDim = std::min(dstType.getDimSize(0),
748  dimSizes[0].cast<IntegerAttr>().getInt());
749 
750  if (rank == 1) {
751  // Express constant 1-D case in explicit vector form:
752  // [T,..,T,F,..,F].
753  SmallVector<bool, 4> values(dstType.getDimSize(0));
754  for (int64_t d = 0; d < trueDim; d++)
755  values[d] = true;
756  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
757  op, dstType, rewriter.getBoolVectorAttr(values));
758  return success();
759  }
760 
761  VectorType lowType =
762  VectorType::get(dstType.getShape().drop_front(), eltType);
763  SmallVector<int64_t, 4> newDimSizes;
764  for (int64_t r = 1; r < rank; r++)
765  newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
766  Value trueVal = rewriter.create<vector::ConstantMaskOp>(
767  loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
768  Value result = rewriter.create<arith::ConstantOp>(
769  loc, dstType, rewriter.getZeroAttr(dstType));
770  for (int64_t d = 0; d < trueDim; d++) {
771  auto pos = rewriter.getI64ArrayAttr(d);
772  result =
773  rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
774  }
775  rewriter.replaceOp(op, result);
776  return success();
777  }
778 };
779 
780 /// Progressive lowering of CreateMaskOp.
781 /// One:
782 /// %x = vector.create_mask %a, ... : vector<dx...>
783 /// is replaced by:
784 /// %l = vector.create_mask ... : vector<...> ; one lower rank
785 /// %0 = arith.cmpi "slt", %ci, %a |
786 /// %1 = select %0, %l, %zeroes |
787 /// %r = vector.insert %1, %pr [i] | d-times
788 /// %x = ....
789 /// until a one-dimensional vector is reached.
790 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
791 public:
792  using OpRewritePattern::OpRewritePattern;
793 
794  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
795  PatternRewriter &rewriter) const override {
796  auto dstType = op.getResult().getType().cast<VectorType>();
797  int64_t rank = dstType.getRank();
798  if (rank <= 1)
799  return rewriter.notifyMatchFailure(
800  op, "0-D and 1-D vectors are handled separately");
801 
802  auto loc = op.getLoc();
803  auto eltType = dstType.getElementType();
804  int64_t dim = dstType.getDimSize(0);
805  Value idx = op.getOperand(0);
806 
807  VectorType lowType =
808  VectorType::get(dstType.getShape().drop_front(), eltType);
809  Value trueVal = rewriter.create<vector::CreateMaskOp>(
810  loc, lowType, op.getOperands().drop_front());
811  Value falseVal = rewriter.create<arith::ConstantOp>(
812  loc, lowType, rewriter.getZeroAttr(lowType));
813  Value result = rewriter.create<arith::ConstantOp>(
814  loc, dstType, rewriter.getZeroAttr(dstType));
815  for (int64_t d = 0; d < dim; d++) {
816  Value bnd =
817  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
818  Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
819  bnd, idx);
820  Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
821  auto pos = rewriter.getI64ArrayAttr(d);
822  result =
823  rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
824  }
825  rewriter.replaceOp(op, result);
826  return success();
827  }
828 };
829 
830 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
831 /// vectors progressively on the way to target llvm.matrix intrinsics.
832 /// This iterates over the most major dimension of the 2-D vector and performs
833 /// rewrites into:
834 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
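/// E.g., casting vector<2x4xf32> to vector<8xf32> extracts rows 0 and 1 and
/// inserts them at offsets 0 and 4 of the 1-D result.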
835 class ShapeCastOp2DDownCastRewritePattern
836  : public OpRewritePattern<vector::ShapeCastOp> {
837 public:
838  using OpRewritePattern::OpRewritePattern;
839 
840  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
841  PatternRewriter &rewriter) const override {
842  auto sourceVectorType = op.getSourceVectorType();
843  auto resultVectorType = op.getResultVectorType();
844  if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
845  return failure();
846 
847  auto loc = op.getLoc();
848  Value desc = rewriter.create<arith::ConstantOp>(
849  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
850  unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
851  for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
852  Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
853  desc = rewriter.create<vector::InsertStridedSliceOp>(
854  loc, vec, desc,
855  /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
856  }
857  rewriter.replaceOp(op, desc);
858  return success();
859  }
860 };
861 
862 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
863 /// vectors progressively.
864 /// This iterates over the most major dimension of the 2-D vector and performs
865 /// rewrites into:
866 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
867 /// Note that 1-D extract_strided_slice ops are lowered to efficient vector.shuffle ops.
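/// E.g., casting vector<8xf32> to vector<2x4xf32> extracts slices of size 4 at
/// offsets 0 and 4 and inserts them as rows 0 and 1 of the 2-D result.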
868 class ShapeCastOp2DUpCastRewritePattern
869  : public OpRewritePattern<vector::ShapeCastOp> {
870 public:
871  using OpRewritePattern::OpRewritePattern;
872 
873  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
874  PatternRewriter &rewriter) const override {
875  auto sourceVectorType = op.getSourceVectorType();
876  auto resultVectorType = op.getResultVectorType();
877  if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
878  return failure();
879 
880  auto loc = op.getLoc();
881  Value desc = rewriter.create<arith::ConstantOp>(
882  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
883  unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
884  for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
885  Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
886  loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
887  /*sizes=*/mostMinorVectorSize,
888  /*strides=*/1);
889  desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
890  }
891  rewriter.replaceOp(op, desc);
892  return success();
893  }
894 };
895 
896 // We typically should not lower general shape cast operations into data
897 // movement instructions, since the assumption is that these casts are
898 // optimized away during progressive lowering. For completeness, however,
899 // we fall back to a reference implementation that moves all elements
900 // into the right place if we get here.
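// For example, a vector<2x3xf32> -> vector<3x2xf32> shape cast falls through
// to this pattern and is unrolled into 6 scalar vector.extract/vector.insert
// pairs that copy each element to the position with the same row-major
// linear index in the result.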
901 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
902 public:
903  using OpRewritePattern::OpRewritePattern;
904 
905  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
906  PatternRewriter &rewriter) const override {
907  Location loc = op.getLoc();
908  auto sourceVectorType = op.getSourceVectorType();
909  auto resultVectorType = op.getResultVectorType();
910 
911  // Special case 2D/1D lowerings with better implementations.
912  // TODO: make this ND/1D to allow generic ND->1D->MD.
913  int64_t srcRank = sourceVectorType.getRank();
914  int64_t resRank = resultVectorType.getRank();
915  if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
916  return failure();
917 
918  // Generic ShapeCast lowering path goes all the way down to unrolled scalar
919  // extract/insert chains.
920  // TODO: consider evolving the semantics to only allow 1D source or dest and
921  // drop this potentially very expensive lowering.
922  // Compute number of elements involved in the reshape.
923  int64_t numElts = 1;
924  for (int64_t r = 0; r < srcRank; r++)
925  numElts *= sourceVectorType.getDimSize(r);
926  // Replace with data movement operations:
927  // x[0,0,0] = y[0,0]
928  // x[0,0,1] = y[0,1]
929  // x[0,1,0] = y[0,2]
930  // etc., incrementing the two index vectors "row-major"
931  // within the source and result shape.
932  SmallVector<int64_t, 4> srcIdx(srcRank);
933  SmallVector<int64_t, 4> resIdx(resRank);
934  Value result = rewriter.create<arith::ConstantOp>(
935  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
936  for (int64_t i = 0; i < numElts; i++) {
937  if (i != 0) {
938  incIdx(srcIdx, sourceVectorType, srcRank - 1);
939  incIdx(resIdx, resultVectorType, resRank - 1);
940  }
941  Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
942  result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
943  }
944  rewriter.replaceOp(op, result);
945  return success();
946  }
947 
948 private:
949  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
950  assert(0 <= r && r < tp.getRank());
951  if (++idx[r] == tp.getDimSize(r)) {
952  idx[r] = 0;
953  incIdx(idx, tp, r - 1);
954  }
955  }
956 };
957 
958 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
959 /// Ex:
960 /// ```
961 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
962 /// %1 = vector.multi_reduction add, %0 [2]
963 /// : vector<8x32x16xf32> to vector<8x32xf32>
964 /// ```
965 /// Gets converted to:
966 /// ```
967 /// %1 = vector.contract {indexing_maps = [
968 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
969 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
970 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
971 /// iterator_types = ["parallel", "parallel", "reduction"],
972 /// kind = add} %arg0, %arg1, %cst_f0
973 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
974 /// ```
975 struct MultiReduceToContract
976  : public OpRewritePattern<vector::MultiDimReductionOp> {
977  using OpRewritePattern::OpRewritePattern;
978 
979  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
980  PatternRewriter &rewriter) const override {
981  if (reduceOp.getKind() != vector::CombiningKind::ADD)
982  return failure();
983  Operation *mulOp = reduceOp.getSource().getDefiningOp();
984  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
985  return failure();
986  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
987  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
988  SmallVector<AffineExpr> exprs;
989  SmallVector<StringRef> iteratorTypes;
990  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
991  if (!isReduceDim.value()) {
992  iteratorTypes.push_back(getParallelIteratorTypeName());
993  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
994  } else {
995  iteratorTypes.push_back(getReductionIteratorTypeName());
996  }
997  }
998  auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
999  /*symCount=*/0, exprs, reduceOp.getContext());
1000  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
1001  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
1002  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
1003  rewriter.getStrArrayAttr(iteratorTypes));
1004  return success();
1005  }
1006 };
1007 
1008 /// Merge TransposeOp into ContractionOp user.
1009 /// Ex:
1010 /// ```
1011 /// %0 = vector.transpose %arg0, [2, 0, 1]
1012 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
1013 /// %1 = vector.contract {indexing_maps = [
1014 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1015 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1016 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
1017 /// iterator_types = ["parallel", "parallel", "reduction"],
1018 /// kind = add} %0, %arg1, %cst_f0
1019 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1020 /// ```
1021 /// Gets converted to:
1022 /// ```
1023 /// %1 = vector.contract {indexing_maps = [
1024 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
1025 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1026 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
1027 /// iterator_types = ["parallel", "parallel", "reduction"],
1028 /// kind = add} %arg0, %arg1, %cst_f0
1029 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1030 /// ```
1031 struct CombineContractTranspose
1032  : public OpRewritePattern<vector::ContractionOp> {
1033  using OpRewritePattern::OpRewritePattern;
1034 
1035  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1036  PatternRewriter &rewriter) const override {
1037  SmallVector<AffineMap, 4> maps =
1038  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
1039  Value lhs = contractOp.getLhs();
1040  Value rhs = contractOp.getRhs();
1041  size_t index = 0;
1042  bool changed = false;
1043  for (Value *operand : {&lhs, &rhs}) {
1044  AffineMap &map = maps[index++];
1045  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
1046  if (!transposeOp)
1047  continue;
1048  SmallVector<int64_t, 4> perm;
1049  transposeOp.getTransp(perm);
1050  AffineMap permutationMap = AffineMap::getPermutationMap(
1051  extractVector<unsigned>(transposeOp.getTransp()),
1052  contractOp.getContext());
1053  map = inversePermutation(permutationMap).compose(map);
1054  *operand = transposeOp.getVector();
1055  changed = true;
1056  }
1057  if (!changed)
1058  return failure();
1059  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1060  contractOp, lhs, rhs, contractOp.getAcc(),
1061  rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
1062  return success();
1063  }
1064 };
1065 
1066 /// Merge BroadcastOp into ContractionOp user.
1067 /// Ex:
1068 /// ```
1069 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
1070 /// %1 = vector.contract {indexing_maps = [
1071 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1072 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1073 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
1074 /// iterator_types = ["parallel", "parallel", "reduction"],
1075 /// kind = add} %0, %arg1, %cst_f0
1076 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1077 /// ```
1078 /// Gets converted to:
1079 /// ```
1080 /// %1 = vector.contract {indexing_maps = [
1081 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
1082 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1083 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
1084 /// iterator_types = ["parallel", "parallel", "reduction"],
1085 /// kind = add} %arg0, %arg1, %cst_f0
1086 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1087 /// ```
1088 struct CombineContractBroadcast
1089  : public OpRewritePattern<vector::ContractionOp> {
1090  using OpRewritePattern::OpRewritePattern;
1091 
1092  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1093  PatternRewriter &rewriter) const override {
1094  SmallVector<AffineMap, 4> maps =
1095  llvm::to_vector<4>(contractOp.getIndexingMapsArray());
1096  Value lhs = contractOp.getLhs();
1097  Value rhs = contractOp.getRhs();
1098  size_t index = 0;
1099  bool changed = false;
1100  for (Value *operand : {&lhs, &rhs}) {
1101  AffineMap &map = maps[index++];
1102  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
1103  if (!broadcast)
1104  continue;
1105  // contractionOp can only take vector as operands.
1106  auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
1107  if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
1108  continue;
1109  int64_t rankDiff =
1110  broadcast.getVectorType().getRank() - srcType.getRank();
1111  bool innerDimBroadcast = false;
1112  SmallVector<AffineExpr> originalDims;
1113  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
1114  if (dim.value() !=
1115  broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
1116  innerDimBroadcast = true;
1117  break;
1118  }
1119  originalDims.push_back(
1120  rewriter.getAffineDimExpr(dim.index() + rankDiff));
1121  }
1122  // Contract doesn't support inner dimension broadcast. Once this is
1123  // relaxed we can remove this case.
1124  if (innerDimBroadcast)
1125  continue;
1126 
1127  // It would be incorrect to fold a broadcast onto a reduction dimension
1128  // of non-unit size.
1129  bool nonUnitDimReductionBroadcast = false;
1130  for (int64_t i = 0; i < rankDiff; ++i) {
1131  if (broadcast.getVectorType().getDimSize(i) != 1 &&
1132  isReductionIterator(contractOp.getIteratorTypes()
1133  .getValue()[map.getDimPosition(i)])) {
1134  nonUnitDimReductionBroadcast = true;
1135  break;
1136  }
1137  }
1138  if (nonUnitDimReductionBroadcast)
1139  continue;
1140 
1141  AffineMap broadcastMap =
1142  AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1143  contractOp.getContext());
1144  map = broadcastMap.compose(map);
1145  *operand = broadcast.getSource();
1146  changed = true;
1147  }
1148 
1149  if (!changed)
1150  return failure();
1151 
1152  // Determine which dims are unused, now that the maps have been composed
1153  // with the broadcast maps.
1154  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
1155  // Compress unused dims.
1156  for (auto &m : maps)
1157  m = compressDims(m, unusedDimsBitVector);
1158  // Compute the combined iterators.
1159  SmallVector<Attribute, 4> iterators;
1160  for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
1161  if (!unusedDimsBitVector.test(i))
1162  iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
1163  }
1164  // Check that compressing unused dims isn't removing all reduction dimension
1165  // pairs. For example, if the vector.contract had only one reduction
1166  // iterator and that was a unit-dimension created by a broadcast,
1167  // then we should bail here, otherwise we would create a contract without
1168  // a reduction dimension pair.
1169  bool hasReductionIteratorApplyingOnBothSides = false;
1170  for (unsigned i = 0; i < iterators.size(); ++i) {
1171  if (!isReductionIterator(iterators[i]))
1172  continue;
1173  if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
1174  hasReductionIteratorApplyingOnBothSides = true;
1175  break;
1176  }
1177  }
1178  if (!hasReductionIteratorApplyingOnBothSides)
1179  return failure();
1180 
1181  // If the compressed maps have a dimension that is not used by either LHS or
1182  // RHS then the ContractionOp verifier would fail.
1183  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
1184  return failure();
1185  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1186  contractOp, lhs, rhs, contractOp.getAcc(),
1187  rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
1188  return success();
1189  }
1190 };
1191 
1192 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1193 /// contraction ops closer, which kicks in the CombineContractBroadcast pattern when
1194 /// casting ops are around these operations.
1195 /// Ex:
1196 /// ```
1197 /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1198 /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1199 /// ```
1200 /// Gets converted to:
1201 /// ```
1202 /// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
1203 /// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
1204 /// ```
1205 struct ReorderCastOpsOnBroadcast
1206  : public OpInterfaceRewritePattern<CastOpInterface> {
1207  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
1208 
1209  LogicalResult matchAndRewrite(CastOpInterface op,
1210  PatternRewriter &rewriter) const override {
1211  if (op->getNumOperands() != 1)
1212  return failure();
1213  auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1214  if (!bcastOp)
1215  return failure();
1216 
1217  Type castResTy = getElementTypeOrSelf(op->getResult(0));
1218  if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1219  castResTy = VectorType::get(vecTy.getShape(), castResTy);
1220  auto *castOp =
1221  rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1222  bcastOp.getSource(), castResTy, op->getAttrs());
1223  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1224  op, op->getResult(0).getType(), castOp->getResult(0));
1225  return success();
1226  }
1227 };
1228 
1229 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
1230 /// transpose ops and contraction ops closer, which kicks in the
1231 /// CombineContractTranspose pattern when elementwise ops are between these
1232 /// operations. Ex:
1233 /// ```
1234 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1235 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1236 /// %r = arith.addf %at, %bt : vector<2x4xf32>
1237 /// ```
1238 /// Gets converted to:
1239 /// ```
1240 /// %0 = arith.addf %a, %b : vector<4x2xf32>
1241 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
1242 /// ```
1243 struct ReorderElementwiseOpsOnTranspose final
1244  : public OpTraitRewritePattern<OpTrait::Elementwise> {
1245  using OpTraitRewritePattern::OpTraitRewritePattern;
1246  LogicalResult matchAndRewrite(Operation *op,
1247  PatternRewriter &rewriter) const override {
1248  if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1249  return failure();
1250 
1251  // Make sure all operands are transpose/constant ops and collect their
1252  // transposition maps.
1253  SmallVector<ArrayAttr, 4> transposeMaps;
1254  transposeMaps.reserve(op->getNumOperands());
1255  // Record the initial type before transposition. We'll use its shape later.
1256  // Any type will do here as we will check all transpose maps are the same.
1257  VectorType srcType;
1258  for (Value operand : op->getOperands()) {
1259  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1260  if (transposeOp) {
1261  transposeMaps.push_back(transposeOp.getTransp());
1262  srcType = transposeOp.getVectorType();
1263  } else if (!matchPattern(operand, m_Constant())) {
1264  return failure();
1265  }
1266  }
1267  if (transposeMaps.empty())
1268  return failure();
1269  // This is an elementwise op, so all transposed operands should have the
1270  // same type. We need to additionally check that all transposes use the
1271  // same map.
1272  if (!llvm::is_splat(transposeMaps))
1273  return rewriter.notifyMatchFailure(op, "different transpose map");
1274 
1275  SmallVector<Value, 4> srcValues;
1276  srcValues.reserve(op->getNumOperands());
1277 
1278  // If there are constant operands, we need to insert inverse transposes for
1279  // them. Calculate the inverse order first.
1280  auto order = extractVector<unsigned>(transposeMaps.front());
1281  SmallVector<int64_t> invOrder(order.size());
1282  for (int i = 0, e = order.size(); i < e; ++i)
1283  invOrder[order[i]] = i;
1284 
1285  for (Value operand : op->getOperands()) {
1286  auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1287  if (transposeOp) {
1288  srcValues.push_back(transposeOp.getVector());
1289  } else {
1290  // This is a constant. Create a reverse transpose op for it.
1291  auto vectorType = VectorType::get(
1292  srcType.getShape(),
1293  operand.getType().cast<VectorType>().getElementType());
1294  srcValues.push_back(rewriter.create<vector::TransposeOp>(
1295  operand.getLoc(), vectorType, operand,
1296  rewriter.getI64ArrayAttr(invOrder)));
1297  }
1298  }
1299 
1300  auto vectorType = VectorType::get(
1301  srcType.getShape(),
1302  op->getResultTypes()[0].cast<VectorType>().getElementType());
1303  Operation *elementwiseOp =
1304  rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1305  vectorType, op->getAttrs());
1306  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1307  op, op->getResultTypes()[0], elementwiseOp->getResult(0),
1308  transposeMaps.front());
1309  return success();
1310  }
1311 };
1312 
1313 } // namespace
1314 
1315 /// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp, using
1316 /// operands `x` and `y`.
1317 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1318  PatternRewriter &rewriter) {
1319  if (isInt)
1320  return rewriter.create<arith::AddIOp>(loc, x, y);
1321  return rewriter.create<arith::AddFOp>(loc, x, y);
1322 }
1323 
1324 /// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp, using
1325 /// operands `x` and `y`.
1326 static Value createMul(Location loc, Value x, Value y, bool isInt,
1327  PatternRewriter &rewriter) {
1328  if (isInt)
1329  return rewriter.create<arith::MulIOp>(loc, x, y);
1330  return rewriter.create<arith::MulFOp>(loc, x, y);
1331 }
1332 
1333 namespace mlir {
1334 
1335 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1336 /// semantics to:
1337 /// ```
1338 /// %mta = maybe_transpose
1339 /// %mtb = maybe_transpose
1340 /// %flattened_a = vector.shape_cast %mta
1341 /// %flattened_b = vector.shape_cast %mtb
1342 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1343 /// %mtd = vector.shape_cast %flattened_d
1344 /// %d = maybe_untranspose %mtd
1345 /// %e = add %c, %d
1346 /// ```
1347 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1348 //
1349 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1350 /// vector.transpose operations are inserted if the vector.contract op is not a
1351 /// row-major matrix multiply.
1352 LogicalResult
1353 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1354  PatternRewriter &rew) const {
1355  // TODO: implement masks
1356  if (llvm::size(op.getMasks()) != 0)
1357  return failure();
1358  if (vectorTransformOptions.vectorContractLowering !=
1359  vector::VectorContractLowering::Matmul)
1360  return failure();
1361  if (failed(filter(op)))
1362  return failure();
1363 
1364  auto iteratorTypes = op.getIteratorTypes().getValue();
1365  if (!isParallelIterator(iteratorTypes[0]) ||
1366  !isParallelIterator(iteratorTypes[1]) ||
1367  !isReductionIterator(iteratorTypes[2]))
1368  return failure();
1369 
1370  Type elementType = op.getLhsType().getElementType();
1371  if (!elementType.isIntOrFloat())
1372  return failure();
1373 
1374  Type dstElementType = op.getType();
1375  if (auto vecType = dstElementType.dyn_cast<VectorType>())
1376  dstElementType = vecType.getElementType();
1377  if (elementType != dstElementType)
1378  return failure();
1379 
1380  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1381  // Bail out if the contraction cannot be put in this form.
1382  MLIRContext *ctx = op.getContext();
1383  Location loc = op.getLoc();
1384  AffineExpr m, n, k;
1385  bindDims(rew.getContext(), m, n, k);
1386  // LHS must be A(m, k) or A(k, m).
1387  Value lhs = op.getLhs();
1388  auto lhsMap = op.getIndexingMapsArray()[0];
1389  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1390  lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1391  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1392  return failure();
1393 
1394  // RHS must be B(k, n) or B(n, k).
1395  Value rhs = op.getRhs();
1396  auto rhsMap = op.getIndexingMapsArray()[1];
1397  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1398  rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1399  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1400  return failure();
1401 
1402  // At this point lhs and rhs are in row-major.
1403  VectorType lhsType = lhs.getType().cast<VectorType>();
1404  VectorType rhsType = rhs.getType().cast<VectorType>();
1405  int64_t lhsRows = lhsType.getDimSize(0);
1406  int64_t lhsColumns = lhsType.getDimSize(1);
1407  int64_t rhsColumns = rhsType.getDimSize(1);
1408 
1409  Type flattenedLHSType =
1410  VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1411  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1412 
1413  Type flattenedRHSType =
1414  VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1415  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1416 
1417  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1418  rhsColumns);
1419  mul = rew.create<vector::ShapeCastOp>(
1420  loc,
1421  VectorType::get({lhsRows, rhsColumns},
1422  getElementTypeOrSelf(op.getAcc().getType())),
1423  mul);
1424 
1425  // ACC must be C(m, n) or C(n, m).
1426  auto accMap = op.getIndexingMapsArray()[2];
1427  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1428  mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1429  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1430  llvm_unreachable("invalid contraction semantics");
1431 
1432  Value res =
1433  elementType.isa<IntegerType>()
1434  ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1435  : static_cast<Value>(
1436  rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1437 
1438  rew.replaceOp(op, res);
1439  return success();
1440 }
1441 
1442 namespace {
1443 struct IteratorType {
1444  IteratorType(StringRef strRef) : strRef(strRef) {}
1445  bool isOfType(Attribute attr) const {
1446  auto sAttr = attr.dyn_cast<StringAttr>();
1447  return sAttr && sAttr.getValue() == strRef;
1448  }
1449  StringRef strRef;
1450 };
1451 struct Par : public IteratorType {
1452  Par() : IteratorType(getParallelIteratorTypeName()) {}
1453 };
1454 struct Red : public IteratorType {
1455  Red() : IteratorType(getReductionIteratorTypeName()) {}
1456 };
1457 
1458 /// Generate a vector implementation for matmat, matvec and tmatvec.
1459 /// This unrolls outer-products along the reduction dimension.
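/// E.g., a row-major matmat C(m, n) += A(m, k) * B(k, n) is emitted as k
/// vector.outerproduct ops: at reduction step i, column i of A (obtained after
/// transposing the LHS) is combined with row i of B into the accumulator.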
1460 struct UnrolledOuterProductGenerator
1461  : public StructuredGenerator<vector::ContractionOp> {
1462  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1463  : StructuredGenerator<vector::ContractionOp>(builder, op),
1464  kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
1465  res(op.getAcc()), lhsType(op.getLhsType()) {}
1466 
1467  Value t(Value v) {
1468  static constexpr std::array<int64_t, 2> perm = {1, 0};
1469  return builder.create<vector::TransposeOp>(loc, v, perm);
1470  }
1471 
1472  Value promote(Value v, Type dstElementType) {
1473  Type elementType = v.getType();
1474  auto vecType = elementType.dyn_cast<VectorType>();
1475  if (vecType)
1476  elementType = vecType.getElementType();
1477  if (elementType == dstElementType)
1478  return v;
1479  Type promotedType = dstElementType;
1480  if (vecType)
1481  promotedType = VectorType::get(vecType.getShape(), promotedType);
1482  if (dstElementType.isa<FloatType>())
1483  return builder.create<arith::ExtFOp>(loc, promotedType, v);
1484  return builder.create<arith::ExtSIOp>(loc, promotedType, v);
1485  }
1486 
1487  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1488  assert(reductionSize > 0);
1489  Type resElementType = res.getType().cast<VectorType>().getElementType();
1490  for (int64_t k = 0; k < reductionSize; ++k) {
1491  Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1492  Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1493  a = promote(a, resElementType);
1494  b = promote(b, resElementType);
1495  res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1496  res, kind);
1497  }
1498  return res;
1499  }
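// A minimal sketch of one unrolled reduction step, assuming a vector<4x8xf32>
// lhs and vector<4x16xf32> rhs already laid out with the reduction dimension
// outermost (names illustrative):
//   %a  = vector.extract %lhs[0] : vector<4x8xf32>
//   %b  = vector.extract %rhs[0] : vector<4x16xf32>
//   %r0 = vector.outerproduct %a, %b, %acc {kind = #vector.kind<add>}
//         : vector<8xf32>, vector<16xf32>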
1500 
1501  /// Two outer parallel, one inner reduction (matmat flavor).
1502  FailureOr<Value> matmat() {
1503  if (!iters({Par(), Par(), Red()}))
1504  return failure();
1505  // Set up the parallel/reduction structure in the right form.
1506  AffineExpr m, n, k;
1507  bindDims(builder.getContext(), m, n, k);
1508  // Classical row-major matmul: Just permute the lhs.
1509  if (layout({{m, k}, {k, n}, {m, n}}))
1510  return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1511  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1512  if (layout({{m, k}, {n, k}, {m, n}})) {
1513  Value tlhs = t(lhs);
1514  return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1515  }
1516  // No need to permute anything.
1517  if (layout({{k, m}, {k, n}, {m, n}}))
1518  return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1519  // Just permute the rhs.
1520  if (layout({{k, m}, {n, k}, {m, n}}))
1521  return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1522  // Transposed output: swap RHS and LHS.
1523  // Classical row-major matmul: permute the lhs.
1524  if (layout({{m, k}, {k, n}, {n, m}}))
1525  return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1526  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1527  if (layout({{m, k}, {n, k}, {n, m}})) {
1528  Value trhs = t(rhs);
1529  return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1530  }
1531  if (layout({{k, m}, {k, n}, {n, m}}))
1532  return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1533  if (layout({{k, m}, {n, k}, {n, m}}))
1534  return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1535  return failure();
1536  }
1537 
1538  /// One outer parallel, one inner reduction (matvec flavor)
1539  FailureOr<Value> matvec() {
1540  if (!iters({Par(), Red()}))
1541  return failure();
1542  AffineExpr m, k;
1543  bindDims(builder.getContext(), m, k);
1544 
1545  // Case mat-vec: transpose.
1546  if (layout({{m, k}, {k}, {m}}))
1547  return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1548  // Case mat-trans-vec: ready to go.
1549  if (layout({{k, m}, {k}, {m}}))
1550  return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1551  // Case vec-mat: swap and transpose.
1552  if (layout({{k}, {m, k}, {m}}))
1553  return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1554  // Case vec-mat-trans: swap and ready to go.
1555  if (layout({{k}, {k, m}, {m}}))
1556  return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1557  return failure();
1558  }
1559 
1560  //
1561  // One outer reduction, one inner parallel (tmatvec flavor)
1562  //
1563  FailureOr<Value> tmatvec() {
1564  if (!iters({Red(), Par()}))
1565  return failure();
1566  AffineExpr k, m;
1567  bindDims(builder.getContext(), k, m);
1568 
1569  // Case mat-vec: transpose.
1570  if (layout({{m, k}, {k}, {m}}))
1571  return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1572  // Case mat-trans-vec: ready to go.
1573  if (layout({{k, m}, {k}, {m}}))
1574  return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1575  // Case vec-mat: swap and transpose.
1576  if (layout({{k}, {m, k}, {m}}))
1577  return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1578  // Case vec-mat-trans: swap and ready to go.
1579  if (layout({{k}, {k, m}, {m}}))
1580  return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1581  return failure();
1582  }
1583 
1584 private:
1585  vector::CombiningKind kind;
1586  Value lhs, rhs, res;
1587  VectorType lhsType;
1588 };
1589 } // namespace
1590 
1591 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1592 /// semantics to a reduction_size-unrolled sequence:
1593 /// ```
1594 /// %at = vector.transpose %a, [1, 0]
1595 /// %bRow0 = vector.extract %b[0]
1596 /// %atRow0 = vector.extract %at[0]
1597 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
1598 /// ...
1599 /// %bRowK = vector.extract %b[K]
1600 /// %atRowK = vector.extract %at[K]
1601 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1602 /// ```
1603 ///
1604 /// This only kicks in when VectorTransformsOptions is set to OuterProduct;
1605 /// within that mode, any layout permutation of the matrix-multiply is supported.
1606 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1607  vector::ContractionOp op, PatternRewriter &rewriter) const {
1608  // TODO: implement masks
1609  if (llvm::size(op.getMasks()) != 0)
1610  return failure();
1611 
1612  if (vectorTransformOptions.vectorContractLowering !=
1613  vector::VectorContractLowering::OuterProduct)
1614  return failure();
1615 
1616  if (failed(filter(op)))
1617  return failure();
1618 
1619  UnrolledOuterProductGenerator e(rewriter, op);
1620  FailureOr<Value> matmatRes = e.matmat();
1621  if (succeeded(matmatRes)) {
1622  rewriter.replaceOp(op, *matmatRes);
1623  return success();
1624  }
1625  FailureOr<Value> matvecRes = e.matvec();
1626  if (succeeded(matvecRes)) {
1627  rewriter.replaceOp(op, *matvecRes);
1628  return success();
1629  }
1630  FailureOr<Value> tmatvecRes = e.tmatvec();
1631  if (succeeded(tmatvecRes)) {
1632  rewriter.replaceOp(op, *tmatvecRes);
1633  return success();
1634  }
1635 
1636  return failure();
1637 }
1638 
1639 LogicalResult
1640 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1641  PatternRewriter &rewriter) const {
1642  // TODO: implement masks
1643  if (llvm::size(op.getMasks()) != 0)
1644  return failure();
1645 
1646  if (failed(filter(op)))
1647  return failure();
1648 
1649  if (vectorTransformOptions.vectorContractLowering !=
1650  vector::VectorContractLowering::Dot)
1651  return failure();
1652 
1653  auto iteratorTypes = op.getIteratorTypes().getValue();
1654  static constexpr std::array<int64_t, 2> perm = {1, 0};
1655  Location loc = op.getLoc();
1656  Value lhs = op.getLhs(), rhs = op.getRhs();
1657 
1658  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1659  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1660  AffineExpr m, n, k;
1661  bindDims(rewriter.getContext(), m, n, k);
1662  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1663  //
1664  // In the following we wish to make the reduction dimension innermost so we
1665  // can load vectors and just fmul + reduce into a scalar.
1666  //
1667  if (isParallelIterator(iteratorTypes[0]) &&
1668  isParallelIterator(iteratorTypes[1]) &&
1669  isReductionIterator(iteratorTypes[2])) {
1670  //
1671  // Two outer parallel, one inner reduction (matmat flavor).
1672  //
1673  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1674  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1675  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1676  // No need to permute anything.
1677  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1678  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1679  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1680  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1681  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1682  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1683  // Transposed output: swap lhs/rhs and transpose the new lhs (the old rhs).
1684  Value tmp = lhs;
1685  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1686  rhs = tmp;
1687  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1688  std::swap(lhs, rhs);
1689  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1690  Value tmp = lhs;
1691  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1692  rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1693  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1694  Value tmp = rhs;
1695  rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1696  lhs = tmp;
1697  } else {
1698  return failure();
1699  }
1700  } else if (isParallelIterator(iteratorTypes[0]) &&
1701  isReductionIterator(iteratorTypes[1])) {
1702  //
1703  // One outer parallel, one inner reduction (matvec flavor)
1704  //
1705  if (maps == infer({{m, n}, {n}, {m}})) {
1706  // No need to permute anything.
1707  } else if (maps == infer({{n, m}, {n}, {m}})) {
1708  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1709  } else if (maps == infer({{n}, {m, n}, {m}})) {
1710  std::swap(lhs, rhs);
1711  } else if (maps == infer({{n}, {n, m}, {m}})) {
1712  std::swap(lhs, rhs);
1713  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1714  } else {
1715  return failure();
1716  }
1717  } else {
1718  return failure();
1719  }
1720 
1721  VectorType dstType = op.getResultType().cast<VectorType>();
1722  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1723  "Expected dst type of rank 1 or 2");
1724 
1725  unsigned rank = dstType.getRank();
1726  unsigned dstRows = dstType.getShape()[0];
1727  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1728 
1729  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1730  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1731  rewriter.getZeroAttr(dstType));
1732  bool isInt = dstType.getElementType().isa<IntegerType>();
1733  for (unsigned r = 0; r < dstRows; ++r) {
1734  Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1735  for (unsigned c = 0; c < dstColumns; ++c) {
1736  Value b = rank == 1
1737  ? rhs
1738  : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1739  Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1740  Value reduced = rewriter.create<vector::ReductionOp>(
1741  op.getLoc(), vector::CombiningKind::ADD, m);
1742 
1743  SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1744  : SmallVector<int64_t, 2>{r, c};
1745  res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1746  }
1747  }
1748  if (auto acc = op.getAcc())
1749  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1750  rewriter.replaceOp(op, res);
1751  return success();
1752 }
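// A minimal sketch of what the loop nest above emits per result element once
// the reduction dimension is innermost, for a 2x2 result with reduction size 4
// and f32 elements (names illustrative):
//   %a = vector.extract %lhs[0] : vector<2x4xf32>
//   %b = vector.extract %rhs[0] : vector<2x4xf32>
//   %m = arith.mulf %a, %b : vector<4xf32>
//   %s = vector.reduction <add>, %m : vector<4xf32> into f32
//   %r = vector.insert %s, %zero[0, 0] : f32 into vector<2x2xf32>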
1753 
1754 /// Progressive lowering of ContractionOp.
1755 /// One:
1756 /// %x = vector.contract with at least one free/batch dimension
1757 /// is replaced by:
1758 /// %a = vector.contract with one less free/batch dimension
1759 /// %b = vector.contract with one less free/batch dimension
1760 /// ..
1761 /// %x = combine %a %b ..
1762 /// until a pure contraction is reached (no free/batch dimensions),
1763 /// which is replaced by a dot-product.
1764 ///
1765 /// This only kicks in when either VectorTransformsOptions is set
1766 /// to DOT or when other contraction patterns fail.
1767 //
1768 // TODO: break down into transpose/reshape/cast ops
1769 // when they become available to avoid code dup
1770 // TODO: investigate lowering order impact on performance
1771 LogicalResult
1772 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1773  PatternRewriter &rewriter) const {
1774  // TODO: implement masks.
1775  if (llvm::size(op.getMasks()) != 0)
1776  return failure();
1777 
1778  if (failed(filter(op)))
1779  return failure();
1780 
1781  // TODO: support mixed mode contract lowering.
1782  if (op.getLhsType().getElementType() !=
1783  getElementTypeOrSelf(op.getAccType()) ||
1784  op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1785  return failure();
1786 
1787  // TODO: implement benefits, cost models.
1788  MLIRContext *ctx = op.getContext();
1789  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1790  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1791  return success();
1792  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1793  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1794  return success();
1795  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1796  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1797  return success();
1798  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
1799  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
1800  return success();
1801 
1802  // Find first batch dimension in LHS/RHS, and lower when found.
1803  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1804  if (!batchDimMap.empty()) {
1805  int64_t lhsIndex = batchDimMap[0].first;
1806  int64_t rhsIndex = batchDimMap[0].second;
1807  auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
1808  if (failed(newOp))
1809  return failure();
1810  rewriter.replaceOp(op, newOp.value());
1811  return success();
1812  }
1813 
1814  // Collect contracting dimensions.
1815  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1816  op.getContractingDimMap();
1817  DenseSet<int64_t> lhsContractingDimSet;
1818  DenseSet<int64_t> rhsContractingDimSet;
1819  for (auto &dimPair : contractingDimMap) {
1820  lhsContractingDimSet.insert(dimPair.first);
1821  rhsContractingDimSet.insert(dimPair.second);
1822  }
1823 
1824  // Find first free dimension in LHS, and lower when found.
1825  VectorType lhsType = op.getLhsType();
1826  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1827  if (lhsContractingDimSet.count(lhsIndex) == 0) {
1828  auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
1829  if (failed(newOp))
1830  return failure();
1831  rewriter.replaceOp(op, newOp.value());
1832  return success();
1833  }
1834  }
1835 
1836  // Find first free dimension in RHS, and lower when found.
1837  VectorType rhsType = op.getRhsType();
1838  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1839  if (rhsContractingDimSet.count(rhsIndex) == 0) {
1840  auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
1841  if (failed(newOp))
1842  return failure();
1843  rewriter.replaceOp(op, newOp.value());
1844  return success();
1845  }
1846  }
1847 
1848  // Lower the first remaining reduction dimension.
1849  if (!contractingDimMap.empty()) {
1850  auto newOp = lowerReduction(op, rewriter);
1851  if (failed(newOp))
1852  return failure();
1853  rewriter.replaceOp(op, newOp.value());
1854  return success();
1855  }
1856 
1857  return failure();
1858 }
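// A minimal sketch of one lowerParallel step on a matvec-like contract whose
// free dimension has size 2 (names, maps and iterator types elided and
// illustrative): the free dimension is peeled off and each slice becomes a
// smaller contraction, whose results are reassembled with vector.insert.
//   %l0 = vector.extract %lhs[0] : vector<2x3xf32>
//   %a0 = vector.extract %acc[0] : vector<2xf32>
//   %c0 = vector.contract {...} %l0, %rhs, %a0
//         : vector<3xf32>, vector<3xf32> into f32
//   %r0 = vector.insert %c0, %zero[0] : f32 into vector<2xf32>
//   // ... repeated for index 1, then combined into the final vector<2xf32>.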
1859 
1860 // Lower one parallel dimension.
1861 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1862 // TODO: consider reusing existing contract unrolling
1863 FailureOr<Value>
1864 ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
1865  int64_t rhsIndex,
1866  PatternRewriter &rewriter) const {
1867  VectorType lhsType = op.getLhsType();
1868  VectorType rhsType = op.getRhsType();
1869  VectorType resType = op.getResultType().cast<VectorType>();
1870  // Find the iterator type index and result index.
1871  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
1872  int64_t iterIndex = -1;
1873  int64_t dimSize = -1;
1874  if (lhsIndex >= 0) {
1875  iterIndex = iMap[0].getDimPosition(lhsIndex);
1876  if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1877  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1878  diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1879  << " to map to the same dimension";
1880  });
1881  dimSize = lhsType.getDimSize(lhsIndex);
1882  } else if (rhsIndex >= 0) {
1883  iterIndex = iMap[1].getDimPosition(rhsIndex);
1884  dimSize = rhsType.getDimSize(rhsIndex);
1885  }
1886  if (iterIndex < 0)
1887  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1888  diag << "expected either lhsIndex=" << lhsIndex
1889  << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1890  });
1891  // value_or(-1) means that we tolerate a dimension not appearing
1892  // in the result map. That can't happen for actual parallel iterators, but
1893  // the caller ContractionOpLowering::matchAndRewrite is currently calling
1894  // lowerParallel also for the case of unit-size reduction dims appearing only
1895  // on one of LHS or RHS, not both. At the moment, such cases are created by
1896  // CastAwayContractionLeadingOneDim, so we need to either support that or
1897  // modify that pattern.
1898  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1899  if (resIndex == -1 && dimSize != 1)
1900  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1901  diag << "expected the dimension for iterIndex=" << iterIndex
1902  << " to either appear in the result map, or to be a unit dimension";
1903  });
1904  // Construct new iterator types and affine map array attribute.
1905  std::array<AffineMap, 3> lowIndexingMaps = {
1906  adjustMap(iMap[0], iterIndex, rewriter),
1907  adjustMap(iMap[1], iterIndex, rewriter),
1908  adjustMap(iMap[2], iterIndex, rewriter)};
1909  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1910  auto lowIter =
1911  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1912  // Unroll into a series of lower dimensional vector.contract ops.
1913  Location loc = op.getLoc();
1914  Value result = rewriter.create<arith::ConstantOp>(
1915  loc, resType, rewriter.getZeroAttr(resType));
1916  for (int64_t d = 0; d < dimSize; ++d) {
1917  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1918  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1919  auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1920  Value lowContract = rewriter.create<vector::ContractionOp>(
1921  loc, lhs, rhs, acc, lowAffine, lowIter);
1922  result =
1923  reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1924  }
1925  return result;
1926 }
1927 
1928 // Lower one reduction dimension.
1929 FailureOr<Value>
1930 ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1931  PatternRewriter &rewriter) const {
1932  auto loc = op.getLoc();
1933  VectorType lhsType = op.getLhsType();
1934  VectorType rhsType = op.getRhsType();
1935  Type resType = op.getResultType();
1936  if (resType.isa<VectorType>())
1937  return rewriter.notifyMatchFailure(op,
1938  "did not expect a VectorType result");
1939  bool isInt = resType.isa<IntegerType>();
1940  // Use iterator index 0.
1941  int64_t iterIndex = 0;
1942  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
1943  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1944  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1945  if (!lookupLhs.has_value())
1946  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1947  diag << "expected iterIndex=" << iterIndex << " to map to a LHS dimension";
1948  });
1949  if (!lookupRhs.has_value())
1950  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1951  diag << "expected iterIndex=" << iterIndex << " to map to a RHS dimension";
1952  });
1953  int64_t lhsIndex = lookupLhs.value();
1954  int64_t rhsIndex = lookupRhs.value();
1955  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1956  if (dimSize != rhsType.getDimSize(rhsIndex))
1957  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1958  diag << "expected LHS dimension " << lhsIndex
1959  << " to have the same size as RHS dimension " << rhsIndex;
1960  });
1961  // Base case.
1962  if (lhsType.getRank() == 1) {
1963  if (rhsType.getRank() != 1)
1964  return rewriter.notifyMatchFailure(
1965  op, "When LHS has rank 1, expected also RHS to have rank 1");
1966  Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1967  auto kind = vector::CombiningKind::ADD;
1968  if (auto acc = op.getAcc())
1969  return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1970  .getResult();
1971  return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
1972  }
1973  // Construct new iterator types and affine map array attribute.
1974  std::array<AffineMap, 3> lowIndexingMaps = {
1975  adjustMap(iMap[0], iterIndex, rewriter),
1976  adjustMap(iMap[1], iterIndex, rewriter),
1977  adjustMap(iMap[2], iterIndex, rewriter)};
1978  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1979  auto lowIter =
1980  rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1981  // Unroll into a series of lower dimensional vector.contract ops.
1982  // By feeding the initial accumulator into the first contraction,
1983  // and the result of each contraction into the next, eventually
1984  // the sum of all reductions is computed.
1985  Value result = op.getAcc();
1986  for (int64_t d = 0; d < dimSize; ++d) {
1987  auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1988  auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1989  result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1990  lowAffine, lowIter);
1991  }
1992  return result;
1993 }
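// In the rank-1 base case above, the whole contraction collapses to a single
// multiply followed by a fused reduction; a minimal sketch, assuming f32
// operands and an accumulator (names illustrative):
//   %m = arith.mulf %lhs, %rhs : vector<4xf32>
//   %r = vector.reduction <add>, %m, %acc : vector<4xf32> into f32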
1994 
1995 } // namespace mlir
1996 
1998  OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1999  ArrayRef<int64_t> multiplicity, const AffineMap &map) {
2000  OpBuilder::InsertionGuard guard(builder);
2001  builder.setInsertionPointAfter(op);
2002  Location loc = op->getLoc();
2003  if (op->getNumResults() != 1)
2004  return {};
2005  Value result = op->getResult(0);
2006  VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2007  if (!type || map.getNumResults() != multiplicity.size())
2008  return {};
2009  // For each dimension being distributed check that the size is a multiple of
2010  // the multiplicity. To handle more sizes we would need to support masking.
2011  unsigned multiplicityCount = 0;
2012  for (auto exp : map.getResults()) {
2013  auto affineExp = exp.dyn_cast<AffineDimExpr>();
2014  if (!affineExp || affineExp.getPosition() >= type.getRank() ||
2015  type.getDimSize(affineExp.getPosition()) %
2016  multiplicity[multiplicityCount++] !=
2017  0)
2018  return {};
2019  }
2020  DistributeOps ops;
2021  ops.extract =
2022  builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
2023  ops.insert =
2024  builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
2025  return ops;
2026 }
2027 
2028 /// Progressive lowering of transfer_read. This pattern supports lowering of
2029 /// `vector.transfer_read` to a combination of `vector.load` and
2030 /// `vector.broadcast` if all of the following hold:
2031 /// - Stride of most minor memref dimension must be 1.
2032 /// - Out-of-bounds masking is not required.
2033 /// - If the memref's element type is a vector type then it coincides with the
2034 /// result type.
2035 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
2036 struct TransferReadToVectorLoadLowering
2037  : public OpRewritePattern<vector::TransferReadOp> {
2038  TransferReadToVectorLoadLowering(MLIRContext *context,
2039  llvm::Optional<unsigned> maxRank)
2040  : OpRewritePattern<vector::TransferReadOp>(context),
2041  maxTransferRank(maxRank) {}
2042 
2043  LogicalResult matchAndRewrite(vector::TransferReadOp read,
2044  PatternRewriter &rewriter) const override {
2045  if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
2046  return failure();
2047 
2048  SmallVector<unsigned, 4> broadcastedDims;
2049  // Permutations are handled by VectorToSCF or
2050  // populateVectorTransferPermutationMapLoweringPatterns.
2051  // We let the 0-d corner case pass-through as it is supported.
2052  if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
2053  &broadcastedDims))
2054  return failure();
2055 
2056  auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
2057  if (!memRefType)
2058  return failure();
2059 
2060  // Non-unit strides are handled by VectorToSCF.
2061  if (!vector::isLastMemrefDimUnitStride(memRefType))
2062  return failure();
2063 
2064  // If there is broadcasting involved then we first load the unbroadcasted
2065  // vector, and then broadcast it with `vector.broadcast`.
2066  ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
2067  SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
2068  vectorShape.end());
2069  for (unsigned i : broadcastedDims)
2070  unbroadcastedVectorShape[i] = 1;
2071  VectorType unbroadcastedVectorType = VectorType::get(
2072  unbroadcastedVectorShape, read.getVectorType().getElementType());
2073 
2074  // `vector.load` supports vector types as memref's elements only when the
2075  // resulting vector type is the same as the element type.
2076  auto memrefElTy = memRefType.getElementType();
2077  if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
2078  return failure();
2079 
2080  // Otherwise, element types of the memref and the vector must match.
2081  if (!memrefElTy.isa<VectorType>() &&
2082  memrefElTy != read.getVectorType().getElementType())
2083  return failure();
2084 
2085  // Out-of-bounds dims are handled by MaterializeTransferMask.
2086  if (read.hasOutOfBoundsDim())
2087  return failure();
2088 
2089  // Create vector load op.
2090  Operation *loadOp;
2091  if (read.getMask()) {
2092  Value fill = rewriter.create<vector::SplatOp>(
2093  read.getLoc(), unbroadcastedVectorType, read.getPadding());
2094  loadOp = rewriter.create<vector::MaskedLoadOp>(
2095  read.getLoc(), unbroadcastedVectorType, read.getSource(),
2096  read.getIndices(), read.getMask(), fill);
2097  } else {
2098  loadOp = rewriter.create<vector::LoadOp>(
2099  read.getLoc(), unbroadcastedVectorType, read.getSource(),
2100  read.getIndices());
2101  }
2102 
2103  // Insert a broadcasting op if required.
2104  if (!broadcastedDims.empty()) {
2105  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2106  read, read.getVectorType(), loadOp->getResult(0));
2107  } else {
2108  rewriter.replaceOp(read, loadOp->getResult(0));
2109  }
2110 
2111  return success();
2112  }
2113 
2114  llvm::Optional<unsigned> maxTransferRank;
2115 };
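// A minimal sketch of the rewrite above for an in-bounds, unmasked 1-D read
// with unit minor stride (names illustrative):
//   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
//        : memref<8x16xf32>, vector<16xf32>
// becomes
//   %v = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>
// and, when the permutation map broadcasts dimensions, the narrower loaded
// vector is additionally expanded with a trailing vector.broadcast.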
2116 
2117 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
2118 // TODO: we shouldn't cross the vector/scalar domains just for this
2119 // but atm we lack the infra to avoid it. Possible solutions include:
2120 // - go directly to LLVM + bitcast
2121 // - introduce a bitcast op and likely a new pointer dialect
2122 // - let memref.load/store additionally support the 0-d vector case
2123 // There are still deeper data layout issues lingering even in this
2124 // trivial case (for architectures for which this matters).
2125 struct VectorLoadToMemrefLoadLowering
2126  : public OpRewritePattern<vector::LoadOp> {
2127  using OpRewritePattern::OpRewritePattern;
2128 
2129  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
2130  PatternRewriter &rewriter) const override {
2131  auto vecType = loadOp.getVectorType();
2132  if (vecType.getNumElements() != 1)
2133  return failure();
2134  auto memrefLoad = rewriter.create<memref::LoadOp>(
2135  loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
2136  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
2137  memrefLoad);
2138  return success();
2139  }
2140 };
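// A minimal sketch of this single-element rewrite (names illustrative):
//   %v = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
// becomes
//   %s = memref.load %mem[%i] : memref<8xf32>
//   %v = vector.broadcast %s : f32 to vector<1xf32>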
2141 
2142 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
2143 struct VectorStoreToMemrefStoreLowering
2144  : public OpRewritePattern<vector::StoreOp> {
2145  using OpRewritePattern::OpRewritePattern;
2146 
2147  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
2148  PatternRewriter &rewriter) const override {
2149  auto vecType = storeOp.getVectorType();
2150  if (vecType.getNumElements() != 1)
2151  return failure();
2152  Value extracted;
2153  if (vecType.getRank() == 0) {
2154  // TODO: Unify once ExtractOp supports 0-d vectors.
2155  extracted = rewriter.create<vector::ExtractElementOp>(
2156  storeOp.getLoc(), storeOp.getValueToStore());
2157  } else {
2158  SmallVector<int64_t> indices(vecType.getRank(), 0);
2159  extracted = rewriter.create<vector::ExtractOp>(
2160  storeOp.getLoc(), storeOp.getValueToStore(), indices);
2161  }
2162 
2163  rewriter.replaceOpWithNewOp<memref::StoreOp>(
2164  storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
2165  return success();
2166  }
2167 };
2168 
2169 /// Progressive lowering of transfer_write. This pattern supports lowering of
2170 /// `vector.transfer_write` to `vector.store` if all of the following hold:
2171 /// - Stride of most minor memref dimension must be 1.
2172 /// - Out-of-bounds masking is not required.
2173 /// - If the memref's element type is a vector type then it coincides with the
2174 /// type of the written value.
2175 /// - The permutation map is the minor identity map (neither permutation nor
2176 /// broadcasting is allowed).
2177 struct TransferWriteToVectorStoreLowering
2178  : public OpRewritePattern<vector::TransferWriteOp> {
2179  TransferWriteToVectorStoreLowering(MLIRContext *context,
2180  llvm::Optional<unsigned> maxRank)
2181  : OpRewritePattern<vector::TransferWriteOp>(context),
2182  maxTransferRank(maxRank) {}
2183 
2184  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2185  PatternRewriter &rewriter) const override {
2186  if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
2187  return failure();
2188 
2189  // Permutations are handled by VectorToSCF or
2190  // populateVectorTransferPermutationMapLoweringPatterns.
2191  if ( // pass-through for the 0-d corner case.
2192  !write.getPermutationMap().isMinorIdentity())
2193  return failure();
2194 
2195  auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
2196  if (!memRefType)
2197  return failure();
2198 
2199  // Non-unit strides are handled by VectorToSCF.
2200  if (!vector::isLastMemrefDimUnitStride(memRefType))
2201  return failure();
2202 
2203  // `vector.store` supports vector types as memref's elements only when the
2204  // type of the vector value being written is the same as the element type.
2205  auto memrefElTy = memRefType.getElementType();
2206  if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
2207  return failure();
2208 
2209  // Otherwise, element types of the memref and the vector must match.
2210  if (!memrefElTy.isa<VectorType>() &&
2211  memrefElTy != write.getVectorType().getElementType())
2212  return failure();
2213 
2214  // Out-of-bounds dims are handled by MaterializeTransferMask.
2215  if (write.hasOutOfBoundsDim())
2216  return failure();
2217  if (write.getMask()) {
2218  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
2219  write, write.getSource(), write.getIndices(), write.getMask(),
2220  write.getVector());
2221  } else {
2222  rewriter.replaceOpWithNewOp<vector::StoreOp>(
2223  write, write.getVector(), write.getSource(), write.getIndices());
2224  }
2225  return success();
2226  }
2227 
2228  llvm::Optional<unsigned> maxTransferRank;
2229 };
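// A minimal sketch of the rewrite above for an in-bounds, unmasked write
// (names illustrative):
//   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true]}
//     : vector<16xf32>, memref<8x16xf32>
// becomes
//   vector.store %v, %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>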
2230 
2231 // Returns the values in `arrayAttr` as an integer vector.
2232 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
2233  return llvm::to_vector<4>(
2234  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
2235  [](IntegerAttr attr) { return attr.getInt(); }));
2236 }
2237 
2238 // Shuffles vector.bitcast op after vector.extract op.
2239 //
2240 // This transforms IR like:
2241 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
2242 // %1 = vector.extract %0[3] : vector<8xf16>
2243 // Into:
2244 // %0 = vector.extract %src[1] : vector<4xf32>
2245 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
2246 // %2 = vector.extract %1[1] : vector<2xf16>
2247 struct BubbleDownVectorBitCastForExtract
2248  : public OpRewritePattern<vector::ExtractOp> {
2249  using OpRewritePattern::OpRewritePattern;
2250 
2251  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2252  PatternRewriter &rewriter) const override {
2253  // Only support extracting scalars for now.
2254  if (extractOp.getVectorType().getRank() != 1)
2255  return failure();
2256 
2257  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2258  if (!castOp)
2259  return failure();
2260 
2261  VectorType castSrcType = castOp.getSourceVectorType();
2262  VectorType castDstType = castOp.getResultVectorType();
2263  assert(castSrcType.getRank() == castDstType.getRank());
2264 
2265  // Fail to match if we only have one element in the cast op source.
2267  // This is to avoid an infinite loop given that this pattern can generate
2267  // such cases.
2268  if (castSrcType.getNumElements() == 1)
2269  return failure();
2270 
2271  // Only support casting to a larger number of elements for now.
2272  // E.g., vector<4xf32> -> vector<8xf16>.
2273  if (castSrcType.getNumElements() > castDstType.getNumElements())
2274  return failure();
2275 
2276  unsigned expandRatio =
2277  castDstType.getNumElements() / castSrcType.getNumElements();
2278 
2279  auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
2280  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
2281  };
2282 
2283  uint64_t index = getFirstIntValue(extractOp.getPosition());
2284 
2285  // Get the single scalar (as a vector) in the source value that packs the
2286  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
2287  VectorType oneScalarType =
2288  VectorType::get({1}, castSrcType.getElementType());
2289  Value packedValue = rewriter.create<vector::ExtractOp>(
2290  extractOp.getLoc(), oneScalarType, castOp.getSource(),
2291  rewriter.getI64ArrayAttr(index / expandRatio));
2292 
2293  // Cast it to a vector with the desired scalar's type.
2294  // E.g. f32 -> vector<2xf16>
2295  VectorType packedType =
2296  VectorType::get({expandRatio}, castDstType.getElementType());
2297  Value castedValue = rewriter.create<vector::BitCastOp>(
2298  extractOp.getLoc(), packedType, packedValue);
2299 
2300  // Finally extract the desired scalar.
2301  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2302  extractOp, extractOp.getType(), castedValue,
2303  rewriter.getI64ArrayAttr(index % expandRatio));
2304 
2305  return success();
2306  }
2307 };
2308 
2309 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2310 //
2311 // This transforms IR like:
2312 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2313 // %0 = vector.extract_strided_slice %cast {
2314 // offsets = [4], sizes = [4], strides = [1]
2315 // } : vector<8xf16> to vector<4xf16>
2316 // Into:
2317 // %0 = vector.extract_strided_slice %src {
2318 // offsets = [2], sizes = [2], strides = [1]
2319 // } : vector<4xf32> to vector<2xf32>
2320 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2321 struct BubbleDownBitCastForStridedSliceExtract
2322  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2323  using OpRewritePattern::OpRewritePattern;
2324 
2325  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2326  PatternRewriter &rewriter) const override {
2327  auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2328  if (!castOp)
2329  return failure();
2330 
2331  VectorType castSrcType = castOp.getSourceVectorType();
2332  VectorType castDstType = castOp.getResultVectorType();
2333  assert(castSrcType.getRank() == castDstType.getRank());
2334 
2335  int64_t castSrcLastDim = castSrcType.getShape().back();
2336  int64_t castDstLastDim = castDstType.getShape().back();
2337  // Require casting to more elements for now; other cases to be implemented.
2338  if (castSrcLastDim > castDstLastDim)
2339  return failure();
2340 
2341  // Only accept all one strides for now.
2342  if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2343  [](const APInt &val) { return !val.isOneValue(); }))
2344  return failure();
2345 
2346  unsigned rank = extractOp.getVectorType().getRank();
2347  assert(castDstLastDim % castSrcLastDim == 0);
2348  int64_t expandRatio = castDstLastDim / castSrcLastDim;
2349 
2350  // If we have fewer offsets than the rank, then implicitly we are
2351  // selecting the full range for the last bitcasted dimension; other
2352  // dimensions aren't affected. Otherwise, we need to scale down the last
2353  // dimension's offset given we are extracting from fewer elements now.
2354  ArrayAttr newOffsets = extractOp.getOffsets();
2355  if (newOffsets.size() == rank) {
2356  SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2357  if (offsets.back() % expandRatio != 0)
2358  return failure();
2359  offsets.back() = offsets.back() / expandRatio;
2360  newOffsets = rewriter.getI64ArrayAttr(offsets);
2361  }
2362 
2363  // Similarly for sizes.
2364  ArrayAttr newSizes = extractOp.getSizes();
2365  if (newSizes.size() == rank) {
2366  SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2367  if (sizes.back() % expandRatio != 0)
2368  return failure();
2369  sizes.back() = sizes.back() / expandRatio;
2370  newSizes = rewriter.getI64ArrayAttr(sizes);
2371  }
2372 
2373  SmallVector<int64_t, 4> dims =
2374  llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2375  dims.back() = dims.back() / expandRatio;
2376  VectorType newExtractType =
2377  VectorType::get(dims, castSrcType.getElementType());
2378 
2379  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2380  extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2381  newSizes, extractOp.getStrides());
2382 
2383  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2384  extractOp, extractOp.getType(), newExtractOp);
2385 
2386  return success();
2387  }
2388 };
2389 
2390 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2391 //
2392 // This transforms IR like:
2393 // %0 = vector.insert_strided_slice %src, %dst {
2394 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2395 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2396 // Into:
2397 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2398 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2399 // %2 = vector.insert_strided_slice %src, %dst {
2400 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2401 struct BubbleUpBitCastForStridedSliceInsert
2402  : public OpRewritePattern<vector::BitCastOp> {
2403  using OpRewritePattern::OpRewritePattern;
2404  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2405  PatternRewriter &rewriter) const override {
2406  VectorType castSrcType = bitcastOp.getSourceVectorType();
2407  VectorType castDstType = bitcastOp.getResultVectorType();
2408  assert(castSrcType.getRank() == castDstType.getRank());
2409 
2410  int64_t castSrcLastDim = castSrcType.getShape().back();
2411  int64_t castDstLastDim = castDstType.getShape().back();
2412  // Require casting to less elements for now; other cases to be implemented.
2413  if (castSrcLastDim < castDstLastDim)
2414  return failure();
2415 
2416  assert(castSrcLastDim % castDstLastDim == 0);
2417  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2418 
2419  auto insertOp =
2420  bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2421  if (!insertOp)
2422  return failure();
2423 
2424  // Only accept all one strides for now.
2425  if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2426  [](const APInt &val) { return !val.isOneValue(); }))
2427  return failure();
2428 
2429  unsigned rank = insertOp.getSourceVectorType().getRank();
2430  // Require insert op to have the same rank for the source and destination
2431  // vector; other cases to be implemented.
2432  if (rank != insertOp.getDestVectorType().getRank())
2433  return failure();
2434 
2435  ArrayAttr newOffsets = insertOp.getOffsets();
2436  assert(newOffsets.size() == rank);
2437  SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2438  if (offsets.back() % shrinkRatio != 0)
2439  return failure();
2440  offsets.back() = offsets.back() / shrinkRatio;
2441  newOffsets = rewriter.getI64ArrayAttr(offsets);
2442 
2443  SmallVector<int64_t, 4> srcDims =
2444  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2445  srcDims.back() = srcDims.back() / shrinkRatio;
2446  VectorType newCastSrcType =
2447  VectorType::get(srcDims, castDstType.getElementType());
2448 
2449  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2450  bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
2451 
2452  SmallVector<int64_t, 4> dstDims =
2453  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2454  dstDims.back() = dstDims.back() / shrinkRatio;
2455  VectorType newCastDstType =
2456  VectorType::get(dstDims, castDstType.getElementType());
2457 
2458  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2459  bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
2460 
2461  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2462  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2463  insertOp.getStrides());
2464 
2465  return success();
2466  }
2467 };
2468 
2469 // Helper that returns a vector comparison that constructs a mask:
2470 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2471 //
2472 // If `dim == 0` then the result will be a 0-D vector.
2473 //
2474 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2475 // much more compact, IR for this operation, but LLVM eventually
2476 // generates more elaborate instructions for this intrinsic since it
2477 // is very conservative on the boundary conditions.
2478 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2479  bool force32BitVectorIndices, int64_t dim,
2480  Value b, Value *off = nullptr) {
2481  auto loc = op->getLoc();
2482  // If we can assume all indices fit in 32-bit, we perform the vector
2483  // comparison in 32-bit to get a higher degree of SIMD parallelism.
2484  // Otherwise we perform the vector comparison using 64-bit indices.
2485  Type idxType =
2486  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
2487  DenseIntElementsAttr indicesAttr;
2488  if (dim == 0 && force32BitVectorIndices) {
2489  indicesAttr = DenseIntElementsAttr::get(
2490  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2491  } else if (dim == 0) {
2492  indicesAttr = DenseIntElementsAttr::get(
2493  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2494  } else if (force32BitVectorIndices) {
2495  indicesAttr = rewriter.getI32VectorAttr(
2496  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2497  } else {
2498  indicesAttr = rewriter.getI64VectorAttr(
2499  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2500  }
2501  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2502  // Add in an offset if requested.
2503  if (off) {
2504  Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2505  Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2506  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2507  }
2508  // Construct the vector comparison.
2509  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2510  Value bounds =
2511  rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2512  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2513  bounds);
2514 }
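// A minimal sketch of the comparison built above for dim = 4 with 32-bit
// indices and an offset (names illustrative):
//   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %o   = vector.splat %off : vector<4xi32>
//   %i2  = arith.addi %o, %idx : vector<4xi32>
//   %bnd = vector.splat %b : vector<4xi32>
//   %msk = arith.cmpi slt, %i2, %bnd : vector<4xi32>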
2515 
2516 template <typename ConcreteOp>
2517 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2518 public:
2519  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2520  : mlir::OpRewritePattern<ConcreteOp>(context),
2521  force32BitVectorIndices(enableIndexOpt) {}
2522 
2523  LogicalResult matchAndRewrite(ConcreteOp xferOp,
2524  PatternRewriter &rewriter) const override {
2525  if (!xferOp.hasOutOfBoundsDim())
2526  return failure();
2527 
2528  if (xferOp.getVectorType().getRank() > 1 ||
2529  llvm::size(xferOp.getIndices()) == 0)
2530  return failure();
2531 
2532  Location loc = xferOp->getLoc();
2533  VectorType vtp = xferOp.getVectorType();
2534 
2535  // Create the in-bounds mask with all elements between [0 .. dim - offset)
2536  // set and [dim - offset .. vector_length) unset.
2537  //
2538  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2539  // dimensions here.
2540  unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
2541  Value off = xferOp.getIndices()[lastIndex];
2542  Value dim =
2543  vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2544  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2545  Value mask = rewriter.create<vector::CreateMaskOp>(
2546  loc,
2547  VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2548  vtp.getNumScalableDims()),
2549  b);
2550  if (xferOp.getMask()) {
2551  // Intersect the in-bounds with the mask specified as an op parameter.
2552  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
2553  }
2554 
2555  rewriter.updateRootInPlace(xferOp, [&]() {
2556  xferOp.getMaskMutable().assign(mask);
2557  xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
2558  });
2559 
2560  return success();
2561  }
2562 
2563 private:
2564  const bool force32BitVectorIndices;
2565 };
2566 
2567 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2568 class VectorCreateMaskOpConversion
2569  : public OpRewritePattern<vector::CreateMaskOp> {
2570 public:
2571  explicit VectorCreateMaskOpConversion(MLIRContext *context,
2572  bool enableIndexOpt)
2573  : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2574  force32BitVectorIndices(enableIndexOpt) {}
2575 
2576  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2577  PatternRewriter &rewriter) const override {
2578  auto dstType = op.getType();
2579  if (dstType.cast<VectorType>().isScalable())
2580  return failure();
2581  int64_t rank = dstType.getRank();
2582  if (rank > 1)
2583  return failure();
2584  rewriter.replaceOp(
2585  op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
2586  rank == 0 ? 0 : dstType.getDimSize(0),
2587  op.getOperand(0)));
2588  return success();
2589  }
2590 
2591 private:
2592  const bool force32BitVectorIndices;
2593 };
2594 
2595 // Drop innermost contiguous unit dimensions from the transfer_read operand.
2596 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2597  using OpRewritePattern::OpRewritePattern;
2598 
2599  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2600  PatternRewriter &rewriter) const override {
2601  // TODO: support 0-d corner case.
2602  if (readOp.getTransferRank() == 0)
2603  return failure();
2604 
2605  // TODO: support mask.
2606  if (readOp.getMask())
2607  return failure();
2608 
2609  auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
2610  if (!srcType || !srcType.hasStaticShape())
2611  return failure();
2612 
2613  if (!readOp.getPermutationMap().isMinorIdentity())
2614  return failure();
2615 
2616  auto targetType = readOp.getVectorType();
2617  if (targetType.getRank() <= 1)
2618  return failure();
2619 
2620  SmallVector<int64_t> srcStrides;
2621  int64_t srcOffset;
2622  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2623  return failure();
2624 
2625  size_t dimsToDrop = 0;
2626  for (size_t i = 1; i < srcStrides.size(); ++i) {
2627  int dim = srcType.getRank() - i - 1;
2628  if (srcStrides[dim] == 1) {
2629  dimsToDrop++;
2630  } else {
2631  break;
2632  }
2633  }
2634  if (dimsToDrop == 0)
2635  return failure();
2636 
2637  auto resultTargetVecType =
2638  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2639  targetType.getElementType());
2640 
2641  MemRefType resultMemrefType;
2642  if (srcType.getLayout().getAffineMap().isIdentity()) {
2643  resultMemrefType = MemRefType::get(
2644  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2645  {}, srcType.getMemorySpaceAsInt());
2646  } else {
2647  AffineMap map = srcType.getLayout().getAffineMap();
2648  int numSymbols = map.getNumSymbols();
2649  for (size_t i = 0; i < dimsToDrop; ++i) {
2650  int dim = srcType.getRank() - i - 1;
2651  map = map.replace(rewriter.getAffineDimExpr(dim),
2652  rewriter.getAffineConstantExpr(0),
2653  map.getNumDims() - 1, numSymbols);
2654  }
2655  resultMemrefType = MemRefType::get(
2656  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2657  map, srcType.getMemorySpaceAsInt());
2658  }
2659 
2660  auto loc = readOp.getLoc();
2661  SmallVector<int64_t> offsets(srcType.getRank(), 0);
2662  SmallVector<int64_t> strides(srcType.getRank(), 1);
2663 
2664  ArrayAttr inBoundsAttr =
2665  readOp.getInBounds()
2666  ? rewriter.getArrayAttr(
2667  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
2668  : ArrayAttr();
2669  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2670  loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
2671  strides);
2672  auto permMap = getTransferMinorIdentityMap(
2673  rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2674  Value result = rewriter.create<vector::TransferReadOp>(
2675  loc, resultTargetVecType, rankedReducedView,
2676  readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2677  readOp.getPadding(),
2678  // TODO: support mask.
2679  /*mask=*/Value(), inBoundsAttr);
2680  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2681  result);
2682  return success();
2683  }
2684 };
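// A minimal sketch of the rewrite above, dropping one trailing unit dimension
// (names illustrative; memref layouts simplified):
//   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true]}
//        : memref<8x4x1xf32>, vector<4x1xf32>
// becomes
//   %sv = memref.subview %src[0, 0, 0] [8, 4, 1] [1, 1, 1]
//         : memref<8x4x1xf32> to memref<8x4xf32>
//   %v0 = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true]}
//         : memref<8x4xf32>, vector<4xf32>
//   %v  = vector.shape_cast %v0 : vector<4xf32> to vector<4x1xf32>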
2685 
2686 namespace {
2687 
2688 /// This function checks to see if the vector combining kind
2689 /// is consistent with the integer or float element type.
2690 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2691  using vector::CombiningKind;
2692  enum class KindType { FLOAT, INT, INVALID };
2693  KindType type{KindType::INVALID};
2694  switch (kind) {
2695  case CombiningKind::MINF:
2696  case CombiningKind::MAXF:
2697  type = KindType::FLOAT;
2698  break;
2699  case CombiningKind::MINUI:
2700  case CombiningKind::MINSI:
2701  case CombiningKind::MAXUI:
2702  case CombiningKind::MAXSI:
2703  case CombiningKind::AND:
2704  case CombiningKind::OR:
2705  case CombiningKind::XOR:
2706  type = KindType::INT;
2707  break;
2708  case CombiningKind::ADD:
2709  case CombiningKind::MUL:
2710  type = isInt ? KindType::INT : KindType::FLOAT;
2711  break;
2712  }
2713  bool isValidIntKind = (type == KindType::INT) && isInt;
2714  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2715  return (isValidIntKind || isValidFloatKind);
2716 }
2717 
2718 /// This function constructs the appropriate integer or float
2719 /// operation given the vector combining kind and operands. The
2720 /// supported int operations are : add, mul, min (signed/unsigned),
2721 /// max(signed/unsigned), and, or, xor. The supported float
2722 /// operations are : add, mul, min and max.
2723 static Value genOperator(Location loc, Value x, Value y,
2724  vector::CombiningKind kind,
2725  PatternRewriter &rewriter) {
2726  using vector::CombiningKind;
2727 
2728  auto elType = x.getType().cast<VectorType>().getElementType();
2729  bool isInt = elType.isIntOrIndex();
2730 
2731  Value combinedResult{nullptr};
2732  switch (kind) {
2733  case CombiningKind::ADD:
2734  if (isInt)
2735  combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2736  else
2737  combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2738  break;
2739  case CombiningKind::MUL:
2740  if (isInt)
2741  combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2742  else
2743  combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2744  break;
2745  case CombiningKind::MINUI:
2746  combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2747  break;
2748  case CombiningKind::MINSI:
2749  combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2750  break;
2751  case CombiningKind::MAXUI:
2752  combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2753  break;
2754  case CombiningKind::MAXSI:
2755  combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2756  break;
2757  case CombiningKind::AND:
2758  combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2759  break;
2760  case CombiningKind::OR:
2761  combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2762  break;
2763  case CombiningKind::XOR:
2764  combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2765  break;
2766  case CombiningKind::MINF:
2767  combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2768  break;
2769  case CombiningKind::MAXF:
2770  combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2771  break;
2772  }
2773  return combinedResult;
2774 }
2775 
2776 /// Convert vector.scan op into arith ops and
2777 /// vector.insert_strided_slice/extract_strided_slice
2778 ///
2779 /// Ex:
2780 /// ```
2781 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 1}
2782 ///   : (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2784 /// ```
2785 /// Gets converted to:
2786 /// ```
2787 /// %cst = arith.constant dense<0> : vector<2x3xi32>
2788 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
2789 /// %1 = vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
2790 /// %2 = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
2791 /// %3 = arith.muli %0, %2 : vector<2x1xi32>
2792 /// %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
2793 /// %5 = vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
2794 /// %6 = arith.muli %3, %5 : vector<2x1xi32>
2795 /// %7 = vector.insert_strided_slice %6, %4 {offsets = [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
2796 /// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
2797 /// return %7, %8 : vector<2x3xi32>, vector<2xi32>
2802 /// ```
2803 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2804  using OpRewritePattern::OpRewritePattern;
2805 
2806  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2807  PatternRewriter &rewriter) const override {
2808  auto loc = scanOp.getLoc();
2809  VectorType destType = scanOp.getDestType();
2810  ArrayRef<int64_t> destShape = destType.getShape();
2811  auto elType = destType.getElementType();
2812  bool isInt = elType.isIntOrIndex();
2813  if (!isValidKind(isInt, scanOp.getKind()))
2814  return failure();
2815 
2816  VectorType resType = VectorType::get(destShape, elType);
2817  Value result = rewriter.create<arith::ConstantOp>(
2818  loc, resType, rewriter.getZeroAttr(resType));
2819  int64_t reductionDim = scanOp.getReductionDim();
2820  bool inclusive = scanOp.getInclusive();
2821  int64_t destRank = destType.getRank();
2822  VectorType initialValueType = scanOp.getInitialValueType();
2823  int64_t initialValueRank = initialValueType.getRank();
2824 
2825  SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2826  reductionShape[reductionDim] = 1;
2827  VectorType reductionType = VectorType::get(reductionShape, elType);
2828  SmallVector<int64_t> offsets(destRank, 0);
2829  SmallVector<int64_t> strides(destRank, 1);
2830  SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2831  sizes[reductionDim] = 1;
2832  ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2833  ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2834 
2835  Value lastOutput, lastInput;
2836  for (int i = 0; i < destShape[reductionDim]; i++) {
2837  offsets[reductionDim] = i;
2838  ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2839  Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2840  loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
2841  scanStrides);
2842  Value output;
2843  if (i == 0) {
2844  if (inclusive) {
2845  output = input;
2846  } else {
2847  if (initialValueRank == 0) {
2848  // ShapeCastOp cannot handle 0-D vectors
2849  output = rewriter.create<vector::BroadcastOp>(
2850  loc, input.getType(), scanOp.getInitialValue());
2851  } else {
2852  output = rewriter.create<vector::ShapeCastOp>(
2853  loc, input.getType(), scanOp.getInitialValue());
2854  }
2855  }
2856  } else {
2857  Value y = inclusive ? input : lastInput;
2858  output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
2859  assert(output != nullptr);
2860  }
2861  result = rewriter.create<vector::InsertStridedSliceOp>(
2862  loc, output, result, offsets, strides);
2863  lastOutput = output;
2864  lastInput = input;
2865  }
2866 
2867  Value reduction;
2868  if (initialValueRank == 0) {
2869  Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2870  reduction =
2871  rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2872  } else {
2873  reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2874  lastOutput);
2875  }
2876 
2877  rewriter.replaceOp(scanOp, {result, reduction});
2878  return success();
2879  }
2880 };
2881 
2882 } // namespace
2883 
2884 void mlir::vector::populateVectorMaskMaterializationPatterns(
2885  RewritePatternSet &patterns, bool force32BitVectorIndices) {
2886  patterns.add<VectorCreateMaskOpConversion,
2887  MaterializeTransferMask<vector::TransferReadOp>,
2888  MaterializeTransferMask<vector::TransferWriteOp>>(
2889  patterns.getContext(), force32BitVectorIndices);
2890 }
2891 
2892 void mlir::vector::populateShapeCastFoldingPatterns(
2893  RewritePatternSet &patterns) {
2894  patterns.add<ShapeCastOpFolder>(patterns.getContext());
2895 }
2896 
2897 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2898  RewritePatternSet &patterns) {
2899  patterns.add<BubbleDownVectorBitCastForExtract,
2900  BubbleDownBitCastForStridedSliceExtract,
2901  BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2902 }
2903 
2904 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2905  RewritePatternSet &patterns) {
2906  patterns.add<BroadcastOpLowering>(patterns.getContext());
2907 }
2908 
2909 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2910  RewritePatternSet &patterns) {
2911  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2912  patterns.getContext());
2913 }
2914 
2915 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2916  RewritePatternSet &patterns) {
2917  patterns.add<ShapeCastOp2DDownCastRewritePattern,
2918  ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2919  patterns.getContext());
2920 }
2921 
2922 void mlir::vector::populateVectorContractLoweringPatterns(
2923     RewritePatternSet &patterns, VectorTransformsOptions options) {
2924   patterns.add<OuterProductOpLowering>(patterns.getContext());
2925   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2926                ContractionOpToOuterProductOpLowering>(options,
2927                                                       patterns.getContext());
2928 }
2929 
2930 void mlir::vector::populateVectorTransposeLoweringPatterns(
2931     RewritePatternSet &patterns, VectorTransformsOptions options) {
2932   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2933  options, patterns.getContext());
2934 }
2935 
2936 void mlir::vector::populateVectorReductionToContractPatterns(
2937     RewritePatternSet &patterns) {
2938  patterns.add<MultiReduceToContract, CombineContractBroadcast,
2939  CombineContractTranspose, ReorderCastOpsOnBroadcast,
2940  ReorderElementwiseOpsOnTranspose>(patterns.getContext());
2941 }
2942 
2943 void mlir::vector::
2944     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2945         RewritePatternSet &patterns) {
2946  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2947 }
2948 
2949 void mlir::vector::populateVectorTransferLoweringPatterns(
2950     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2951   patterns.add<TransferReadToVectorLoadLowering,
2952                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2953                                                    maxTransferRank);
2954   patterns
2955       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2956           patterns.getContext());
2957 }
2958 
2959 void mlir::vector::populateVectorScanLoweringPatterns(
2960     RewritePatternSet &patterns) {
2961  patterns.add<ScanToArithOps>(patterns.getContext());
2962 }
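
// A minimal usage sketch (not part of this file; `ctx` and `op` are
// placeholder names): the populate* entry points above are typically
// gathered into a RewritePatternSet inside a pass and applied with the
// greedy pattern rewrite driver:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorScanLoweringPatterns(patterns);
//   vector::populateVectorContractLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));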