VectorTransforms.cpp (MLIR 14.0.0git)
1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
19 #include "mlir/Dialect/SCF/SCF.h"
22 
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
28 
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 #define DEBUG_TYPE "vector-to-vector"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 // Helper to find an index in an affine map.
42 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
43  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
44  int64_t idx = map.getDimPosition(i);
45  if (idx == index)
46  return i;
47  }
48  return None;
49 }
50 
51 // Helper to construct iterator types with one index removed.
52 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
53  int64_t index) {
54  SmallVector<Attribute, 4> results;
55  for (const auto &it : llvm::enumerate(iteratorTypes)) {
56  int64_t idx = it.index();
57  if (idx == index)
58  continue;
59  results.push_back(it.value());
60  }
61  return results;
62 }
63 
64 // Helper to construct an affine map with one index removed.
65 static AffineMap adjustMap(AffineMap map, int64_t index,
66  PatternRewriter &rewriter) {
67  auto *ctx = rewriter.getContext();
68  SmallVector<AffineExpr, 4> results;
69  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
70  int64_t idx = map.getDimPosition(i);
71  if (idx == index)
72  continue;
73  // Re-insert remaining indices, but renamed when occurring
74  // after the removed index.
75  auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
76  results.push_back(targetExpr);
77  }
78  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
79 }
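// As a worked example (hypothetical inputs, for illustration only): removing
// iterator index 1 from
//   map           = (d0, d1, d2) -> (d1, d2)
//   iteratorTypes = ["parallel", "reduction", "parallel"]
// gives
//   adjustMap(map, 1, rewriter)  = (d0, d1) -> (d1)   // d2 renumbered to d1
//   adjustIter(iteratorTypes, 1) = ["parallel", "parallel"]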
80 
81 // Helper method to possibly drop a dimension in a load.
82 // TODO
83 static Value reshapeLoad(Location loc, Value val, VectorType type,
84  int64_t index, int64_t pos,
85  PatternRewriter &rewriter) {
86  if (index == -1)
87  return val;
88  Type lowType = VectorType::Builder(type).dropDim(0);
89  // At extraction dimension?
90  if (index == 0) {
91  auto posAttr = rewriter.getI64ArrayAttr(pos);
92  return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
93  }
94  // Unroll leading dimensions.
95  VectorType vType = lowType.cast<VectorType>();
96  Type resType = VectorType::Builder(type).dropDim(index);
97  auto resVectorType = resType.cast<VectorType>();
98  Value result = rewriter.create<arith::ConstantOp>(
99  loc, resVectorType, rewriter.getZeroAttr(resVectorType));
100  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
101  auto posAttr = rewriter.getI64ArrayAttr(d);
102  Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
103  Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
104  result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
105  posAttr);
106  }
107  return result;
108 }
109 
110 // Helper method to possibly drop a dimension in a store.
111 // TODO
112 static Value reshapeStore(Location loc, Value val, Value result,
113  VectorType type, int64_t index, int64_t pos,
114  PatternRewriter &rewriter) {
115  // Unmodified?
116  if (index == -1)
117  return val;
118  // At insertion dimension?
119  if (index == 0) {
120  auto posAttr = rewriter.getI64ArrayAttr(pos);
121  return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
122  }
123  // Unroll leading dimensions.
124  Type lowType = VectorType::Builder(type).dropDim(0);
125  VectorType vType = lowType.cast<VectorType>();
126  Type insType = VectorType::Builder(vType).dropDim(0);
127  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
128  auto posAttr = rewriter.getI64ArrayAttr(d);
129  Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
130  Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
131  Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
132  result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
133  }
134  return result;
135 }
136 
137 template <typename IntType>
138 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
139  return llvm::to_vector<4>(llvm::map_range(
140  arrayAttr.getAsRange<IntegerAttr>(),
141  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
142 }
143 
144 namespace {
145 
146 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
147 //
148 // Example:
149 //
150 // The following MLIR with cancelling ShapeCastOps:
151 //
152 // %0 = source : vector<5x4x2xf32>
153 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
154 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
155 // %3 = user %2 : vector<5x4x2xf32>
156 //
157 // Should canonicalize to the following:
158 //
159 // %0 = source : vector<5x4x2xf32>
160 // %1 = user %0 : vector<5x4x2xf32>
161 //
162 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
163  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
164 
165  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
166  PatternRewriter &rewriter) const override {
167  // Check if 'shapeCastOp' has vector source/result type.
168  auto sourceVectorType =
169  shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
170  auto resultVectorType =
171  shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
172  if (!sourceVectorType || !resultVectorType)
173  return failure();
174 
175  // Check if shape cast op source operand is also a shape cast op.
176  auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
177  shapeCastOp.source().getDefiningOp());
178  if (!sourceShapeCastOp)
179  return failure();
180  auto operandSourceVectorType =
181  sourceShapeCastOp.source().getType().cast<VectorType>();
182  auto operandResultVectorType = sourceShapeCastOp.getType();
183 
184  // Check if shape cast operations invert each other.
185  if (operandSourceVectorType != resultVectorType ||
186  operandResultVectorType != sourceVectorType)
187  return failure();
188 
189  rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
190  return success();
191  }
192 };
193 
194 /// Progressive lowering of BroadcastOp.
195 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
196 public:
197  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
198 
199  LogicalResult matchAndRewrite(vector::BroadcastOp op,
200  PatternRewriter &rewriter) const override {
201  auto loc = op.getLoc();
202  VectorType dstType = op.getVectorType();
203  VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
204  Type eltType = dstType.getElementType();
205 
206  // Scalar to any vector can use splat.
207  if (!srcType) {
208  rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
209  return success();
210  }
211 
212  // Determine rank of source and destination.
213  int64_t srcRank = srcType.getRank();
214  int64_t dstRank = dstType.getRank();
215 
216  // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
217  if (srcRank <= 1 && dstRank == 1) {
218  Value ext;
219  if (srcRank == 0)
220  ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
221  else
222  ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
223  rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
224  return success();
225  }
226 
227  // Duplicate this rank.
228  // For example:
229  // %x = broadcast %y : k-D to n-D, k < n
230  // becomes:
231  // %b = broadcast %y : k-D to (n-1)-D
232  // %x = [%b,%b,%b,%b] : n-D
233  // becomes:
234  // %b = [%y,%y] : (n-1)-D
235  // %x = [%b,%b,%b,%b] : n-D
236  if (srcRank < dstRank) {
237  // Duplication.
238  VectorType resType =
239  VectorType::get(dstType.getShape().drop_front(), eltType);
240  Value bcst =
241  rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
242  Value result = rewriter.create<arith::ConstantOp>(
243  loc, dstType, rewriter.getZeroAttr(dstType));
244  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
245  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
246  rewriter.replaceOp(op, result);
247  return success();
248  }
249 
250  // Find non-matching dimension, if any.
251  assert(srcRank == dstRank);
252  int64_t m = -1;
253  for (int64_t r = 0; r < dstRank; r++)
254  if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
255  m = r;
256  break;
257  }
258 
259  // All trailing dimensions are the same. Simply pass through.
260  if (m == -1) {
261  rewriter.replaceOp(op, op.source());
262  return success();
263  }
264 
265  // Any non-matching dimension forces a stretch along this rank.
266  // For example:
267  // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
268  // becomes:
269  // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
270  // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
271  // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
272  // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
273  // %x = [%a,%b,%c,%d]
274  // becomes:
275  // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
276  // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
277  // %a = [%u, %v]
278  // ..
279  // %x = [%a,%b,%c,%d]
280  VectorType resType =
281  VectorType::get(dstType.getShape().drop_front(), eltType);
282  Value result = rewriter.create<arith::ConstantOp>(
283  loc, dstType, rewriter.getZeroAttr(dstType));
284  if (m == 0) {
285  // Stretch at start.
286  Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
287  Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
288  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
289  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
290  } else {
291  // Stretch not at start.
292  for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
293  Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
294  Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
295  result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
296  }
297  }
298  rewriter.replaceOp(op, result);
299  return success();
300  }
301 };
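// A sketch of one step of this lowering, on illustrative types: the
// rank-extending broadcast
//   %x = vector.broadcast %y : vector<2xf32> to vector<3x2xf32>
// becomes an (n-1)-D broadcast followed by one vector.insert per leading
// element:
//   %b = vector.broadcast %y : vector<2xf32> to vector<2xf32>
//   %z = arith.constant dense<0.0> : vector<3x2xf32>
//   %0 = vector.insert %b, %z [0] : vector<2xf32> into vector<3x2xf32>
//   %1 = vector.insert %b, %0 [1] : vector<2xf32> into vector<3x2xf32>
//   %x = vector.insert %b, %1 [2] : vector<2xf32> into vector<3x2xf32>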
302 
303 /// Progressive lowering of TransposeOp.
304 /// One:
305 /// %x = vector.transpose %y, [1, 0]
306 /// is replaced by:
307 /// %z = arith.constant dense<0.000000e+00>
308 /// %0 = vector.extract %y[0, 0]
309 /// %1 = vector.insert %0, %z [0, 0]
310 /// ..
311 /// %x = vector.insert .., .. [.., ..]
312 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
313 public:
314  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
315 
316  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
317  MLIRContext *context)
318  : OpRewritePattern<vector::TransposeOp>(context),
319  vectorTransformOptions(vectorTransformOptions) {}
320 
321  LogicalResult matchAndRewrite(vector::TransposeOp op,
322  PatternRewriter &rewriter) const override {
323  auto loc = op.getLoc();
324 
325  VectorType resType = op.getResultType();
326 
327  // Set up convenience transposition table.
328  SmallVector<int64_t, 4> transp;
329  for (auto attr : op.transp())
330  transp.push_back(attr.cast<IntegerAttr>().getInt());
331 
332  if (vectorTransformOptions.vectorTransposeLowering ==
333  vector::VectorTransposeLowering::Shuffle &&
334  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
335  return rewriter.notifyMatchFailure(
336  op, "Options specifies lowering to shuffle");
337 
338  // Handle a true 2-D matrix transpose differently when requested.
339  if (vectorTransformOptions.vectorTransposeLowering ==
340  vector::VectorTransposeLowering::Flat &&
341  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
342  Type flattenedType =
343  VectorType::get(resType.getNumElements(), resType.getElementType());
344  auto matrix =
345  rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
346  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
347  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
348  Value trans = rewriter.create<vector::FlatTransposeOp>(
349  loc, flattenedType, matrix, rows, columns);
350  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
351  return success();
352  }
353 
354  // Generate fully unrolled extract/insert ops.
355  Value result = rewriter.create<arith::ConstantOp>(
356  loc, resType, rewriter.getZeroAttr(resType));
357  SmallVector<int64_t, 4> lhs(transp.size(), 0);
358  SmallVector<int64_t, 4> rhs(transp.size(), 0);
359  rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
360  op.vector(), result, rewriter));
361  return success();
362  }
363 
364 private:
365  // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
366  // operation when all ranks are exhausted.
367  Value expandIndices(Location loc, VectorType resType, int64_t pos,
368  SmallVector<int64_t, 4> &transp,
369  SmallVector<int64_t, 4> &lhs,
370  SmallVector<int64_t, 4> &rhs, Value input, Value result,
371  PatternRewriter &rewriter) const {
372  if (pos >= resType.getRank()) {
373  auto ridx = rewriter.getI64ArrayAttr(rhs);
374  auto lidx = rewriter.getI64ArrayAttr(lhs);
375  Type eltType = resType.getElementType();
376  Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
377  return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
378  }
379  for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
380  lhs[pos] = d;
381  rhs[transp[pos]] = d;
382  result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
383  result, rewriter);
384  }
385  return result;
386  }
387 
388  /// Options to control the vector patterns.
389  vector::VectorTransformsOptions vectorTransformOptions;
390 };
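// A fully unrolled sketch on an illustrative 2x2 case:
//   %x = vector.transpose %y, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// becomes, with the default element-wise strategy,
//   %z = arith.constant dense<0.0> : vector<2x2xf32>
//   %e = vector.extract %y[0, 0] : vector<2x2xf32>
//   %0 = vector.insert %e, %z [0, 0] : f32 into vector<2x2xf32>
//   %f = vector.extract %y[1, 0] : vector<2x2xf32>
//   %1 = vector.insert %f, %0 [0, 1] : f32 into vector<2x2xf32>
//   ... one extract/insert pair per element ...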
391 
392 /// Rewrite a 2-D vector.transpose as a sequence of:
393 /// vector.shape_cast 2D -> 1D
394 /// vector.shuffle
395 /// vector.shape_cast 1D -> 2D
396 class TransposeOp2DToShuffleLowering
397  : public OpRewritePattern<vector::TransposeOp> {
398 public:
399  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
400 
401  TransposeOp2DToShuffleLowering(
402  vector::VectorTransformsOptions vectorTransformOptions,
403  MLIRContext *context)
404  : OpRewritePattern<vector::TransposeOp>(context),
405  vectorTransformOptions(vectorTransformOptions) {}
406 
407  LogicalResult matchAndRewrite(vector::TransposeOp op,
408  PatternRewriter &rewriter) const override {
409  auto loc = op.getLoc();
410 
411  VectorType srcType = op.getVectorType();
412  if (srcType.getRank() != 2)
413  return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
414 
415  SmallVector<int64_t, 4> transp;
416  for (auto attr : op.transp())
417  transp.push_back(attr.cast<IntegerAttr>().getInt());
418  if (transp[0] != 1 && transp[1] != 0)
419  return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
420 
421  if (vectorTransformOptions.vectorTransposeLowering !=
422  vector::VectorTransposeLowering::Shuffle)
423  return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
424 
425  int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
426  Value casted = rewriter.create<vector::ShapeCastOp>(
427  loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
428  SmallVector<int64_t> mask;
429  mask.reserve(m * n);
430  for (int64_t j = 0; j < n; ++j)
431  for (int64_t i = 0; i < m; ++i)
432  mask.push_back(i * n + j);
433 
434  Value shuffled =
435  rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
436  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
437  shuffled);
438 
439  return success();
440  }
441 
442 private:
443  /// Options to control the vector patterns.
444  vector::VectorTransformsOptions vectorTransformOptions;
445 };
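// A sketch of the shuffle-based form on an illustrative 2x3 case:
//   %x = vector.transpose %y, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
// becomes
//   %0 = vector.shape_cast %y : vector<2x3xf32> to vector<6xf32>
//   %1 = vector.shuffle %0, %0 [0, 3, 1, 4, 2, 5] : vector<6xf32>, vector<6xf32>
//   %x = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>
// The mask enumerates the source column by column (i * n + j for each column
// j), exactly as built by the loop nest above.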
446 
447 /// Progressive lowering of OuterProductOp.
448 /// One:
449 /// %x = vector.outerproduct %lhs, %rhs, %acc
450 /// is replaced by:
451 /// %z = zero-result
452 /// %0 = vector.extract %lhs[0]
453 /// %1 = vector.broadcast %0
454 /// %2 = vector.extract %acc[0]
455 /// %3 = vector.fma %1, %rhs, %2
456 /// %4 = vector.insert %3, %z[0]
457 /// ..
458 /// %x = vector.insert %.., %..[N-1]
459 ///
460 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
461 public:
462  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
463 
464  LogicalResult matchAndRewrite(vector::OuterProductOp op,
465  PatternRewriter &rewriter) const override {
466  auto loc = op.getLoc();
467 
468  VectorType lhsType = op.getOperandVectorTypeLHS();
469  VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
470  VectorType resType = op.getVectorType();
471  Type eltType = resType.getElementType();
472  bool isInt = eltType.isa<IntegerType, IndexType>();
473  Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
474  vector::CombiningKind kind = op.kind();
475 
476  if (!rhsType) {
477  // Special case: AXPY operation.
478  Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
479  Optional<Value> mult =
480  isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
481  : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
482  if (!mult.hasValue())
483  return failure();
484  rewriter.replaceOp(op, mult.getValue());
485  return success();
486  }
487 
488  Value result = rewriter.create<arith::ConstantOp>(
489  loc, resType, rewriter.getZeroAttr(resType));
490  for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
491  auto pos = rewriter.getI64ArrayAttr(d);
492  Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
493  Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
494  Value r = nullptr;
495  if (acc)
496  r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
497  Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
498  : genMultF(loc, a, op.rhs(), r, kind, rewriter);
499  if (!m.hasValue())
500  return failure();
501  result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
502  result, pos);
503  }
504  rewriter.replaceOp(op, result);
505  return success();
506  }
507 
508 private:
509  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
510  vector::CombiningKind kind,
511  PatternRewriter &rewriter) {
512  using vector::CombiningKind;
513 
514  auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
515  if (!acc)
516  return Optional<Value>(mul);
517 
518  Value combinedResult;
519  switch (kind) {
520  case CombiningKind::ADD:
521  combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
522  break;
523  case CombiningKind::MUL:
524  combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
525  break;
526  case CombiningKind::MINUI:
527  combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
528  break;
529  case CombiningKind::MINSI:
530  combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
531  break;
532  case CombiningKind::MAXUI:
533  combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
534  break;
535  case CombiningKind::MAXSI:
536  combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
537  break;
538  case CombiningKind::AND:
539  combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
540  break;
541  case CombiningKind::OR:
542  combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
543  break;
544  case CombiningKind::XOR:
545  combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
546  break;
547  case CombiningKind::MINF: // Only valid for floating point types.
548  case CombiningKind::MAXF: // Only valid for floating point types.
549  return Optional<Value>();
550  }
551  return Optional<Value>(combinedResult);
552  }
553 
554  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
555  vector::CombiningKind kind,
556  PatternRewriter &rewriter) {
557  using vector::CombiningKind;
558 
559  // Special case for fused multiply-add.
560  if (acc && kind == CombiningKind::ADD) {
561  return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
562  }
563 
564  auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
565 
566  if (!acc)
567  return Optional<Value>(mul);
568 
569  Value combinedResult;
570  switch (kind) {
571  case CombiningKind::MUL:
572  combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
573  break;
574  case CombiningKind::MINF:
575  combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
576  break;
577  case CombiningKind::MAXF:
578  combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
579  break;
580  case CombiningKind::ADD: // Already handled this special case above.
581  case CombiningKind::AND: // Only valid for integer types.
582  case CombiningKind::MINUI: // Only valid for integer types.
583  case CombiningKind::MINSI: // Only valid for integer types.
584  case CombiningKind::MAXUI: // Only valid for integer types.
585  case CombiningKind::MAXSI: // Only valid for integer types.
586  case CombiningKind::OR: // Only valid for integer types.
587  case CombiningKind::XOR: // Only valid for integer types.
588  return Optional<Value>();
589  }
590  return Optional<Value>(combinedResult);
591  }
592 };
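// A sketch on illustrative types, for kind = add with an accumulator
// (%lhs : vector<2xf32>, %rhs : vector<3xf32>, %acc : vector<2x3xf32>):
//   %z  = arith.constant dense<0.0> : vector<2x3xf32>
//   %l0 = vector.extract %lhs[0] : vector<2xf32>
//   %b0 = vector.broadcast %l0 : f32 to vector<3xf32>
//   %a0 = vector.extract %acc[0] : vector<2x3xf32>
//   %f0 = vector.fma %b0, %rhs, %a0 : vector<3xf32>
//   %r0 = vector.insert %f0, %z [0] : vector<3xf32> into vector<2x3xf32>
//   ... and the same for row 1, producing the final result ...
// For integer element types the vector.fma is replaced by arith.muli plus the
// requested combining operation.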
593 
594 /// Progressive lowering of ConstantMaskOp.
595 /// One:
596 /// %x = vector.constant_mask [a,b]
597 /// is replaced by:
598 /// %z = zero-result
599 /// %l = vector.constant_mask [b]
600 /// %4 = vector.insert %l, %z[0]
601 /// ..
602 /// %x = vector.insert %l, %..[a-1]
603 /// until a one-dimensional vector is reached. All these operations
604 /// will be folded at LLVM IR level.
605 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
606 public:
607  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
608 
609  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
610  PatternRewriter &rewriter) const override {
611  auto loc = op.getLoc();
612  auto dstType = op.getType();
613  auto eltType = dstType.getElementType();
614  auto dimSizes = op.mask_dim_sizes();
615  int64_t rank = dstType.getRank();
616 
617  if (rank == 0) {
618  assert(dimSizes.size() == 1 &&
619  "Expected exactly one dim size for a 0-D vector");
620  bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
621  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
622  op, dstType,
623  DenseIntElementsAttr::get(
624  VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
625  ArrayRef<bool>{value}));
626  return success();
627  }
628 
629  int64_t trueDim = std::min(dstType.getDimSize(0),
630  dimSizes[0].cast<IntegerAttr>().getInt());
631 
632  if (rank == 1) {
633  // Express constant 1-D case in explicit vector form:
634  // [T,..,T,F,..,F].
635  SmallVector<bool, 4> values(dstType.getDimSize(0));
636  for (int64_t d = 0; d < trueDim; d++)
637  values[d] = true;
638  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
639  op, dstType, rewriter.getBoolVectorAttr(values));
640  return success();
641  }
642 
643  VectorType lowType =
644  VectorType::get(dstType.getShape().drop_front(), eltType);
645  SmallVector<int64_t, 4> newDimSizes;
646  for (int64_t r = 1; r < rank; r++)
647  newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
648  Value trueVal = rewriter.create<vector::ConstantMaskOp>(
649  loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
650  Value result = rewriter.create<arith::ConstantOp>(
651  loc, dstType, rewriter.getZeroAttr(dstType));
652  for (int64_t d = 0; d < trueDim; d++) {
653  auto pos = rewriter.getI64ArrayAttr(d);
654  result =
655  rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
656  }
657  rewriter.replaceOp(op, result);
658  return success();
659  }
660 };
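// A sketch on an illustrative 2-D case:
//   %x = vector.constant_mask [2, 3] : vector<4x4xi1>
// is peeled one dimension at a time into
//   %l = vector.constant_mask [3] : vector<4xi1>
//   %z = arith.constant dense<false> : vector<4x4xi1>
//   %0 = vector.insert %l, %z [0] : vector<4xi1> into vector<4x4xi1>
//   %x = vector.insert %l, %0 [1] : vector<4xi1> into vector<4x4xi1>
// Rows 2 and 3 keep the all-false value of the zero constant.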
661 
662 /// Progressive lowering of CreateMaskOp.
663 /// One:
664 /// %x = vector.create_mask %a, ... : vector<dx...>
665 /// is replaced by:
666 /// %l = vector.create_mask ... : vector<...> ; one lower rank
667 /// %0 = arith.cmpi "slt", %ci, %a |
668 /// %1 = select %0, %l, %zeroes |
669 /// %r = vector.insert %1, %pr [i] | d-times
670 /// %x = ....
671 /// until a one-dimensional vector is reached.
672 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
673 public:
674  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
675 
676  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
677  PatternRewriter &rewriter) const override {
678  auto dstType = op.getResult().getType().cast<VectorType>();
679  int64_t rank = dstType.getRank();
680  if (rank <= 1)
681  return rewriter.notifyMatchFailure(
682  op, "0-D and 1-D vectors are handled separately");
683 
684  auto loc = op.getLoc();
685  auto eltType = dstType.getElementType();
686  int64_t dim = dstType.getDimSize(0);
687  Value idx = op.getOperand(0);
688 
689  VectorType lowType =
690  VectorType::get(dstType.getShape().drop_front(), eltType);
691  Value trueVal = rewriter.create<vector::CreateMaskOp>(
692  loc, lowType, op.getOperands().drop_front());
693  Value falseVal = rewriter.create<arith::ConstantOp>(
694  loc, lowType, rewriter.getZeroAttr(lowType));
695  Value result = rewriter.create<arith::ConstantOp>(
696  loc, dstType, rewriter.getZeroAttr(dstType));
697  for (int64_t d = 0; d < dim; d++) {
698  Value bnd =
699  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
700  Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
701  bnd, idx);
702  Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
703  auto pos = rewriter.getI64ArrayAttr(d);
704  result =
705  rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
706  }
707  rewriter.replaceOp(op, result);
708  return success();
709  }
710 };
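// A sketch on an illustrative 2-D case:
//   %x = vector.create_mask %a, %b : vector<3x4xi1>
// becomes a 1-D mask that is conditionally inserted at each outer position d:
//   %l  = vector.create_mask %b : vector<4xi1>
//   %f  = arith.constant dense<false> : vector<4xi1>
//   %cd = arith.constant d : index            // d = 0, 1, 2
//   %p  = arith.cmpi slt, %cd, %a : index
//   %s  = select %p, %l, %f : vector<4xi1>
//   %r  = vector.insert %s, %prev [d] : vector<4xi1> into vector<3x4xi1>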
711 
712 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
713 /// vectors progressively on the way to target llvm.matrix intrinsics.
714 /// This iterates over the most major dimension of the 2-D vector and performs
715 /// rewrites into:
716 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
717 class ShapeCastOp2DDownCastRewritePattern
718  : public OpRewritePattern<vector::ShapeCastOp> {
719 public:
720  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
721 
722  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
723  PatternRewriter &rewriter) const override {
724  auto sourceVectorType = op.getSourceVectorType();
725  auto resultVectorType = op.getResultVectorType();
726  if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
727  return failure();
728 
729  auto loc = op.getLoc();
730  Value desc = rewriter.create<arith::ConstantOp>(
731  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
732  unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
733  for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
734  Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
735  desc = rewriter.create<vector::InsertStridedSliceOp>(
736  loc, vec, desc,
737  /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
738  }
739  rewriter.replaceOp(op, desc);
740  return success();
741  }
742 };
743 
744 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
745 /// vectors progressively.
746 /// This iterates over the most major dimension of the 2-D vector and performs
747 /// rewrites into:
748 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
749 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
750 class ShapeCastOp2DUpCastRewritePattern
751  : public OpRewritePattern<vector::ShapeCastOp> {
752 public:
753  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
754 
755  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
756  PatternRewriter &rewriter) const override {
757  auto sourceVectorType = op.getSourceVectorType();
758  auto resultVectorType = op.getResultVectorType();
759  if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
760  return failure();
761 
762  auto loc = op.getLoc();
763  Value desc = rewriter.create<arith::ConstantOp>(
764  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
765  unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
766  for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
767  Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
768  loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
769  /*sizes=*/mostMinorVectorSize,
770  /*strides=*/1);
771  desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
772  }
773  rewriter.replaceOp(op, desc);
774  return success();
775  }
776 };
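// A sketch of the 1-D -> 2-D direction on an illustrative type:
//   %x = vector.shape_cast %y : vector<6xf32> to vector<2x3xf32>
// becomes
//   %z  = arith.constant dense<0.0> : vector<2x3xf32>
//   %s0 = vector.extract_strided_slice %y
//           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
//   %0  = vector.insert %s0, %z [0] : vector<3xf32> into vector<2x3xf32>
//   %s1 = vector.extract_strided_slice %y
//           {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
//   %x  = vector.insert %s1, %0 [1] : vector<3xf32> into vector<2x3xf32>
// The 2-D -> 1-D pattern above is the mirror image, built from vector.extract
// plus vector.insert_strided_slice.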
777 
778 // We typically should not lower general shape cast operations into data
779 // movement instructions, since the assumption is that these casts are
780 // optimized away during progressive lowering. For completeness, however,
781 // we fall back to a reference implementation that moves all elements
782 // into the right place if we get here.
783 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
784 public:
785  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
786 
787  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
788  PatternRewriter &rewriter) const override {
789  Location loc = op.getLoc();
790  auto sourceVectorType = op.getSourceVectorType();
791  auto resultVectorType = op.getResultVectorType();
792 
793  // Special case 2D/1D lowerings with better implementations.
794  // TODO: make this ND/1D to allow generic ND -> 1D -> MD.
795  int64_t srcRank = sourceVectorType.getRank();
796  int64_t resRank = resultVectorType.getRank();
797  if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
798  return failure();
799 
800  // Generic ShapeCast lowering path goes all the way down to unrolled scalar
801  // extract/insert chains.
802  // TODO: consider evolving the semantics to only allow 1D source or dest and
803  // drop this potentially very expensive lowering.
804  // Compute number of elements involved in the reshape.
805  int64_t numElts = 1;
806  for (int64_t r = 0; r < srcRank; r++)
807  numElts *= sourceVectorType.getDimSize(r);
808  // Replace with data movement operations:
809  // x[0,0,0] = y[0,0]
810  // x[0,0,1] = y[0,1]
811  // x[0,1,0] = y[0,2]
812  // etc., incrementing the two index vectors "row-major"
813  // within the source and result shape.
814  SmallVector<int64_t, 4> srcIdx(srcRank);
815  SmallVector<int64_t, 4> resIdx(resRank);
816  Value result = rewriter.create<arith::ConstantOp>(
817  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
818  for (int64_t i = 0; i < numElts; i++) {
819  if (i != 0) {
820  incIdx(srcIdx, sourceVectorType, srcRank - 1);
821  incIdx(resIdx, resultVectorType, resRank - 1);
822  }
823  Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
824  result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
825  }
826  rewriter.replaceOp(op, result);
827  return success();
828  }
829 
830 private:
831  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
832  assert(0 <= r && r < tp.getRank());
833  if (++idx[r] == tp.getDimSize(r)) {
834  idx[r] = 0;
835  incIdx(idx, tp, r - 1);
836  }
837  }
838 };
839 
840 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
841 /// Ex:
842 /// ```
843 /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
844 /// %1 = vector.multi_reduction add, %0 [1]
845 /// : vector<8x32x16xf32> to vector<8x16xf32>
846 /// ```
847 /// Gets converted to:
848 /// ```
849 /// %1 = vector.contract {indexing_maps = [
850 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
851 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
852 /// affine_map<(d0, d1, d2) -> (d0, d2)>],
853 /// iterator_types = ["parallel", "reduction", "parallel"],
854 /// kind = add} %arg0, %arg1, %cst_f0
855 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
856 /// ```
857 struct MultiReduceToContract
858  : public OpRewritePattern<vector::MultiDimReductionOp> {
859  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
860 
861  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
862  PatternRewriter &rewriter) const override {
863  if (reduceOp.kind() != vector::CombiningKind::ADD)
864  return failure();
865  Operation *mulOp = reduceOp.source().getDefiningOp();
866  if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
867  return failure();
868  SmallVector<bool> reductionMask = reduceOp.getReductionMask();
869  auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
870  SmallVector<AffineExpr> exprs;
871  SmallVector<StringRef> iteratorTypes;
872  for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
873  if (!isReduceDim.value()) {
874  iteratorTypes.push_back(getParallelIteratorTypeName());
875  exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
876  } else {
877  iteratorTypes.push_back(getReductionIteratorTypeName());
878  }
879  }
880  auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
881  /*symCount=*/0, exprs, reduceOp.getContext());
882  Value zero = rewriter.create<arith::ConstantOp>(
883  reduceOp.getLoc(), reduceOp.getDestType(),
884  rewriter.getZeroAttr(reduceOp.getDestType()));
885  rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
886  reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
887  rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
888  rewriter.getStrArrayAttr(iteratorTypes));
889  return success();
890  }
891 };
892 
893 /// Merge TransposeOp into ContractionOp user.
894 /// Ex:
895 /// ```
896 /// %0 = vector.transpose %arg0, [2, 0, 1]
897 /// : vector<32x16x8xf32> to vector<8x32x16xf32>
898 /// %1 = vector.contract {indexing_maps = [
899 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
900 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
901 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
902 /// iterator_types = ["parallel", "parallel", "reduction"],
903 /// kind = add} %0, %arg1, %cst_f0
904 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
905 /// ```
906 /// Gets converted to:
907 /// ```
908 /// %1 = vector.contract {indexing_maps = [
909 /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
910 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
911 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
912 /// iterator_types = ["parallel", "parallel", "reduction"],
913 /// kind = add} %arg0, %arg1, %cst_f0
914 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
915 /// ```
916 struct CombineContractTranspose
917  : public OpRewritePattern<vector::ContractionOp> {
918  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
919 
920  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
921  PatternRewriter &rewriter) const override {
922  SmallVector<AffineMap, 4> maps =
923  llvm::to_vector<4>(contractOp.getIndexingMaps());
924  Value lhs = contractOp.lhs();
925  Value rhs = contractOp.rhs();
926  size_t index = 0;
927  bool changed = false;
928  for (Value *operand : {&lhs, &rhs}) {
929  AffineMap &map = maps[index++];
930  auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
931  if (!transposeOp)
932  continue;
933  SmallVector<int64_t> perm;
934  transposeOp.getTransp(perm);
935  AffineMap permutationMap = AffineMap::getPermutationMap(
936  extractVector<unsigned>(transposeOp.transp()),
937  contractOp.getContext());
938  map = inversePermutation(permutationMap).compose(map);
939  *operand = transposeOp.vector();
940  changed = true;
941  }
942  if (!changed)
943  return failure();
944  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
945  contractOp, lhs, rhs, contractOp.acc(),
946  rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
947  return success();
948  }
949 };
950 
951 /// Merge BroadcastOp into ContractionOp user.
952 /// Ex:
953 /// ```
954 /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
955 /// %1 = vector.contract {indexing_maps = [
956 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
957 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
958 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
959 /// iterator_types = ["parallel", "parallel", "reduction"],
960 /// kind = add} %0, %arg1, %cst_f0
961 /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
962 /// ```
963 /// Gets converted to:
964 /// ```
965 /// %1 = vector.contract {indexing_maps = [
966 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
967 /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
968 /// affine_map<(d0, d1, d2) -> (d0, d1)>],
969 /// iterator_types = ["parallel", "parallel", "reduction"],
970 /// kind = add} %arg0, %arg1, %cst_f0
971 /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
972 /// ```
973 struct CombineContractBroadcast
974  : public OpRewritePattern<vector::ContractionOp> {
975  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
976 
977  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
978  PatternRewriter &rewriter) const override {
979  SmallVector<AffineMap, 4> maps =
980  llvm::to_vector<4>(contractOp.getIndexingMaps());
981  Value lhs = contractOp.lhs();
982  Value rhs = contractOp.rhs();
983  size_t index = 0;
984  bool changed = false;
985  for (Value *operand : {&lhs, &rhs}) {
986  AffineMap &map = maps[index++];
987  auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
988  if (!broadcast)
989  continue;
990  // ContractionOp can only take vectors as operands.
991  auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
992  if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
993  continue;
994  int64_t rankDiff =
995  broadcast.getVectorType().getRank() - srcType.getRank();
996  bool innerDimBroadcast = false;
997  SmallVector<AffineExpr> originalDims;
998  for (const auto &dim : llvm::enumerate(srcType.getShape())) {
999  if (dim.value() !=
1000  broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
1001  innerDimBroadcast = true;
1002  break;
1003  }
1004  originalDims.push_back(
1005  rewriter.getAffineDimExpr(dim.index() + rankDiff));
1006  }
1007  // Contract doesn't support inner dimension broadcast. Once this is
1008  // relaxed we can remove this case.
1009  if (innerDimBroadcast)
1010  continue;
1011  AffineMap broadcastMap =
1012  AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1013  contractOp.getContext());
1014  map = broadcastMap.compose(map);
1015  *operand = broadcast.source();
1016  changed = true;
1017  }
1018  if (!changed)
1019  return failure();
1020  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1021  contractOp, lhs, rhs, contractOp.acc(),
1022  rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
1023  return success();
1024  }
1025 };
1026 
1027 } // namespace
1028 
1029 /// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
1030 /// arith::AddFOp, using operands `x` and `y`.
1031 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1032  PatternRewriter &rewriter) {
1033  if (isInt)
1034  return rewriter.create<arith::AddIOp>(loc, x, y);
1035  return rewriter.create<arith::AddFOp>(loc, x, y);
1036 }
1037 
1038 /// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
1039 /// arith::MulFOp, using operands `x` and `y`.
1040 static Value createMul(Location loc, Value x, Value y, bool isInt,
1041  PatternRewriter &rewriter) {
1042  if (isInt)
1043  return rewriter.create<arith::MulIOp>(loc, x, y);
1044  return rewriter.create<arith::MulFOp>(loc, x, y);
1045 }
1046 
1047 namespace mlir {
1048 
1049 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1050 /// semantics to:
1051 /// ```
1052 /// %mta = maybe_transpose
1053 /// %mtb = maybe_transpose
1054 /// %flattened_a = vector.shape_cast %mta
1055 /// %flattened_b = vector.shape_cast %mtb
1056 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1057 /// %mtd = vector.shape_cast %flattened_d
1058 /// %d = maybe_untranspose %mtd
1059 /// %e = add %c, %d
1060 /// ```
1061 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1062 //
1063 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1064 /// vector.transpose operations are inserted if the vector.contract op is not a
1065 /// row-major matrix multiply.
1066 LogicalResult
1067 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1068  PatternRewriter &rew) const {
1069  // TODO: implement masks
1070  if (llvm::size(op.masks()) != 0)
1071  return failure();
1072  if (vectorTransformOptions.vectorContractLowering !=
1073  vector::VectorContractLowering::Matmul)
1074  return failure();
1075  if (failed(filter(op)))
1076  return failure();
1077 
1078  auto iteratorTypes = op.iterator_types().getValue();
1079  if (!isParallelIterator(iteratorTypes[0]) ||
1080  !isParallelIterator(iteratorTypes[1]) ||
1081  !isReductionIterator(iteratorTypes[2]))
1082  return failure();
1083 
1084  Type elementType = op.getLhsType().getElementType();
1085  if (!elementType.isIntOrFloat())
1086  return failure();
1087 
1088  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1089  // Bail out if the contraction cannot be put in this form.
1090  MLIRContext *ctx = op.getContext();
1091  Location loc = op.getLoc();
1092  AffineExpr m, n, k;
1093  bindDims(rew.getContext(), m, n, k);
1094  // LHS must be A(m, k) or A(k, m).
1095  Value lhs = op.lhs();
1096  auto lhsMap = op.indexing_maps()[0].cast<AffineMapAttr>().getValue();
1097  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1098  lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1099  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1100  return failure();
1101 
1102  // RHS must be B(k, n) or B(n, k).
1103  Value rhs = op.rhs();
1104  auto rhsMap = op.indexing_maps()[1].cast<AffineMapAttr>().getValue();
1105  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1106  rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1107  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1108  return failure();
1109 
1110  // At this point lhs and rhs are in row-major.
1111  VectorType lhsType = lhs.getType().cast<VectorType>();
1112  VectorType rhsType = rhs.getType().cast<VectorType>();
1113  int64_t lhsRows = lhsType.getDimSize(0);
1114  int64_t lhsColumns = lhsType.getDimSize(1);
1115  int64_t rhsColumns = rhsType.getDimSize(1);
1116 
1117  Type flattenedLHSType =
1118  VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1119  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1120 
1121  Type flattenedRHSType =
1122  VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1123  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1124 
1125  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1126  rhsColumns);
1127  mul = rew.create<vector::ShapeCastOp>(
1128  loc,
1129  VectorType::get({lhsRows, rhsColumns},
1130  getElementTypeOrSelf(op.acc().getType())),
1131  mul);
1132 
1133  // ACC must be C(m, n) or C(n, m).
1134  auto accMap = op.indexing_maps()[2].cast<AffineMapAttr>().getValue();
1135  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1136  mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1137  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1138  llvm_unreachable("invalid contraction semantics");
1139 
1140  Value res =
1141  elementType.isa<IntegerType>()
1142  ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
1143  : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
1144 
1145  rew.replaceOp(op, res);
1146  return success();
1147 }
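// End-to-end sketch of this lowering on illustrative row-major shapes
// (%a : vector<2x4xf32>, %b : vector<4x3xf32>, %c : vector<2x3xf32>):
//   %fa = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
//   %fb = vector.shape_cast %b : vector<4x3xf32> to vector<12xf32>
//   %fd = vector.matmul %fa, %fb
//           {lhs_rows = 2, lhs_columns = 4, rhs_columns = 3}
//           : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
//   %d  = vector.shape_cast %fd : vector<6xf32> to vector<2x3xf32>
//   %e  = arith.addf %c, %d : vector<2x3xf32>
// The vector.matmul in turn lowers to the llvm.matrix.multiply intrinsic.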
1148 
1149 namespace {
1150 struct IteratorType {
1151  IteratorType(StringRef strRef) : strRef(strRef) {}
1152  bool isOfType(Attribute attr) const {
1153  auto sAttr = attr.dyn_cast<StringAttr>();
1154  return sAttr && sAttr.getValue() == strRef;
1155  }
1156  StringRef strRef;
1157 };
1158 struct Par : public IteratorType {
1159  Par() : IteratorType(getParallelIteratorTypeName()) {}
1160 };
1161 struct Red : public IteratorType {
1162  Red() : IteratorType(getReductionIteratorTypeName()) {}
1163 };
1164 
1165 /// Generate a vector implementation for matmat, matvec and tmatvec.
1166 /// This unrolls outer-products along the reduction dimension.
1167 struct UnrolledOuterProductGenerator
1168  : public StructuredGenerator<vector::ContractionOp> {
1169 
1170  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1171  : StructuredGenerator<vector::ContractionOp>(builder, op),
1172  kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
1173  lhsType(op.getLhsType()) {}
1174 
1175  Value t(Value v) {
1176  static constexpr std::array<int64_t, 2> perm = {1, 0};
1177  return builder.create<vector::TransposeOp>(loc, v, perm);
1178  }
1179 
1180  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1181  assert(reductionSize > 0);
1182  for (int64_t k = 0; k < reductionSize; ++k) {
1183  Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1184  Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1185  res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1186  res, kind);
1187  }
1188  return res;
1189  }
1190 
1191  /// Two outer parallel, one inner reduction (matmat flavor).
1192  FailureOr<Value> matmat() {
1193  if (!iters({Par(), Par(), Red()}))
1194  return failure();
1195  // Set up the parallel/reduction structure in the right form.
1196  AffineExpr m, n, k;
1197  bindDims(builder.getContext(), m, n, k);
1198  // Classical row-major matmul: Just permute the lhs.
1199  if (layout({{m, k}, {k, n}, {m, n}}))
1200  return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1201  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1202  if (layout({{m, k}, {n, k}, {m, n}})) {
1203  Value tlhs = t(lhs);
1204  return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1205  }
1206  // No need to permute anything.
1207  if (layout({{k, m}, {k, n}, {m, n}}))
1208  return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1209  // Just permute the rhs.
1210  if (layout({{k, m}, {n, k}, {m, n}}))
1211  return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1212  // Transposed output: swap RHS and LHS.
1213  // Classical row-major matmul: permute the lhs.
1214  if (layout({{m, k}, {k, n}, {n, m}}))
1215  return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1216  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1217  if (layout({{m, k}, {n, k}, {n, m}})) {
1218  Value trhs = t(rhs);
1219  return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1220  }
1221  if (layout({{k, m}, {k, n}, {n, m}}))
1222  return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1223  if (layout({{k, m}, {n, k}, {n, m}}))
1224  return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1225  return failure();
1226  }
1227 
1228  /// One outer parallel, one inner reduction (matvec flavor)
1229  FailureOr<Value> matvec() {
1230  if (!iters({Par(), Red()}))
1231  return failure();
1232  AffineExpr m, k;
1233  bindDims(builder.getContext(), m, k);
1234 
1235  // Case mat-vec: transpose.
1236  if (layout({{m, k}, {k}, {m}}))
1237  return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1238  // Case mat-trans-vec: ready to go.
1239  if (layout({{k, m}, {k}, {m}}))
1240  return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1241  // Case vec-mat: swap and transpose.
1242  if (layout({{k}, {m, k}, {m}}))
1243  return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1244  // Case vec-mat-trans: swap and ready to go.
1245  if (layout({{k}, {k, m}, {m}}))
1246  return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1247  return failure();
1248  }
1249 
1250  //
1251  // One outer reduction, one inner parallel (tmatvec flavor)
1252  //
1253  FailureOr<Value> tmatvec() {
1254  if (!iters({Red(), Par()}))
1255  return failure();
1256  AffineExpr k, m;
1257  bindDims(builder.getContext(), k, m);
1258 
1259  // Case mat-vec: transpose.
1260  if (layout({{m, k}, {k}, {m}}))
1261  return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1262  // Case mat-trans-vec: ready to go.
1263  if (layout({{k, m}, {k}, {m}}))
1264  return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1265  // Case vec-mat: swap and transpose.
1266  if (layout({{k}, {m, k}, {m}}))
1267  return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1268  // Case vec-mat-trans: swap and ready to go.
1269  if (layout({{k}, {k, m}, {m}}))
1270  return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1271  return failure();
1272  }
1273 
1274 private:
1275  vector::CombiningKind kind;
1276  Value lhs, rhs, res;
1277  VectorType lhsType;
1278 };
1279 } // namespace
1280 
1281 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1282 /// semantics to a reduction_size-unrolled sequence:
1283 /// ```
1284 /// %at = vector.transpose %a, [1, 0]
1285 /// %bRow0 = vector.extract %b[0]
1286 /// %atRow0 = vector.extract %at[0]
1287 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
1288 /// ...
1289 /// %bRowK = vector.extract %b[K]
1290 /// %atRowK = vector.extract %at[K]
1291 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1292 /// ```
1293 ///
1294 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1295 /// otherwise supports any layout permutation of the matrix-multiply.
1296 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1297  vector::ContractionOp op, PatternRewriter &rewriter) const {
1298  // TODO: implement masks
1299  if (llvm::size(op.masks()) != 0)
1300  return failure();
1301 
1302  if (vectorTransformOptions.vectorContractLowering !=
1303  vector::VectorContractLowering::OuterProduct)
1304  return failure();
1305 
1306  if (failed(filter(op)))
1307  return failure();
1308 
1309  UnrolledOuterProductGenerator e(rewriter, op);
1310  FailureOr<Value> matmatRes = e.matmat();
1311  if (succeeded(matmatRes)) {
1312  rewriter.replaceOp(op, *matmatRes);
1313  return success();
1314  }
1315  FailureOr<Value> matvecRes = e.matvec();
1316  if (succeeded(matvecRes)) {
1317  rewriter.replaceOp(op, *matvecRes);
1318  return success();
1319  }
1320  FailureOr<Value> tmatvecRes = e.tmatvec();
1321  if (succeeded(tmatvecRes)) {
1322  rewriter.replaceOp(op, *tmatvecRes);
1323  return success();
1324  }
1325 
1326  return failure();
1327 }
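// A sketch of the unrolled form on illustrative row-major shapes
// (%a : vector<2x4xf32>, %b : vector<4x3xf32>, %c : vector<2x3xf32>), where
// the reduction size is k = 4:
//   %at = vector.transpose %a, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
//   %a0 = vector.extract %at[0] : vector<4x2xf32>
//   %b0 = vector.extract %b[0] : vector<4x3xf32>
//   %c0 = vector.outerproduct %a0, %b0, %c {kind = #vector.kind<add>}
//   ...
//   %a3 = vector.extract %at[3] : vector<4x2xf32>
//   %b3 = vector.extract %b[3] : vector<4x3xf32>
//   %x  = vector.outerproduct %a3, %b3, %c2 {kind = #vector.kind<add>}
// i.e. one rank-1 update of the accumulator per step of the reduction.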
1328 
1329 LogicalResult
1330 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1331  PatternRewriter &rewriter) const {
1332  // TODO: implement masks
1333  if (llvm::size(op.masks()) != 0)
1334  return failure();
1335 
1336  if (failed(filter(op)))
1337  return failure();
1338 
1339  if (vectorTransformOptions.vectorContractLowering !=
1340  vector::VectorContractLowering::Dot)
1341  return failure();
1342 
1343  auto iteratorTypes = op.iterator_types().getValue();
1344  static constexpr std::array<int64_t, 2> perm = {1, 0};
1345  Location loc = op.getLoc();
1346  Value lhs = op.lhs(), rhs = op.rhs();
1347 
1348  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1349  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1350  AffineExpr m, n, k;
1351  bindDims(rewriter.getContext(), m, n, k);
1352  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1353  //
1354  // In the following we wish to make the reduction dimension innermost so we
1355  // can load vectors and just fmul + reduce into a scalar.
1356  //
1357  if (isParallelIterator(iteratorTypes[0]) &&
1358  isParallelIterator(iteratorTypes[1]) &&
1359  isReductionIterator(iteratorTypes[2])) {
1360  //
1361  // Two outer parallel, one inner reduction (matmat flavor).
1362  //
1363  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1364  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1365  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1366  // No need to permute anything.
1367  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1368  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1369  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1370  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1371  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1372  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1373  // This is the classical row-major matmul. Just permute the lhs.
1374  Value tmp = lhs;
1375  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1376  rhs = tmp;
1377  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1378  std::swap(lhs, rhs);
1379  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1380  Value tmp = lhs;
1381  lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1382  rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1383  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1384  Value tmp = rhs;
1385  rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1386  lhs = tmp;
1387  } else {
1388  return failure();
1389  }
1390  } else if (isParallelIterator(iteratorTypes[0]) &&
1391  isReductionIterator(iteratorTypes[1])) {
1392  //
1393  // One outer parallel, one inner reduction (matvec flavor)
1394  //
1395  if (maps == infer({{m, n}, {n}, {m}})) {
1396  // No need to permute anything.
1397  } else if (maps == infer({{n, m}, {n}, {m}})) {
1398  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1399  } else if (maps == infer({{n}, {m, n}, {m}})) {
1400  std::swap(lhs, rhs);
1401  } else if (maps == infer({{n}, {n, m}, {m}})) {
1402  std::swap(lhs, rhs);
1403  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1404  } else {
1405  return failure();
1406  }
1407  } else {
1408  return failure();
1409  }
1410 
1411  VectorType dstType = op.getResultType().cast<VectorType>();
1412  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1413  "Expected dst type of rank 1 or 2");
1414 
1415  unsigned rank = dstType.getRank();
1416  unsigned dstRows = dstType.getShape()[0];
1417  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1418 
1419  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1420  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1421  rewriter.getZeroAttr(dstType));
1422  bool isInt = dstType.getElementType().isa<IntegerType>();
1423  for (unsigned r = 0; r < dstRows; ++r) {
1424  Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1425  for (unsigned c = 0; c < dstColumns; ++c) {
1426  Value b = rank == 1
1427  ? rhs
1428  : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1429  Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1430  Value reduced = rewriter.create<vector::ReductionOp>(
1431  op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
1432  m, ValueRange{});
1433 
1434  SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1435  : SmallVector<int64_t, 2>{r, c};
1436  res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1437  }
1438  }
1439  if (auto acc = op.acc())
1440  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1441  rewriter.replaceOp(op, res);
1442  return success();
1443 }
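// A sketch of the dot-product form on an illustrative matvec
// (%a : vector<2x4xf32>, %b : vector<4xf32>):
//   %z  = arith.constant dense<0.0> : vector<2xf32>
//   %r0 = vector.extract %a[0] : vector<2x4xf32>
//   %m0 = arith.mulf %r0, %b : vector<4xf32>
//   %d0 = vector.reduction "add", %m0 : vector<4xf32> into f32
//   %0  = vector.insert %d0, %z [0] : f32 into vector<2xf32>
//   %r1 = vector.extract %a[1] : vector<2x4xf32>
//   %m1 = arith.mulf %r1, %b : vector<4xf32>
//   %d1 = vector.reduction "add", %m1 : vector<4xf32> into f32
//   %x  = vector.insert %d1, %0 [1] : f32 into vector<2xf32>
// followed by an add with the accumulator when one is present.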
1444 
1445 /// Progressive lowering of ContractionOp.
1446 /// One:
1447 /// %x = vector.contract with at least one free/batch dimension
1448 /// is replaced by:
1449 /// %a = vector.contract with one less free/batch dimension
1450 /// %b = vector.contract with one less free/batch dimension
1451 /// ..
1452 /// %x = combine %a %b ..
1453 /// until a pure contraction is reached (no free/batch dimensions),
1454 /// which is replaced by a dot-product.
1455 ///
1456 /// This only kicks in when either VectorTransformsOptions is set
1457 /// to DOT or when other contraction patterns fail.
1458 //
1459 // TODO: break down into transpose/reshape/cast ops
1460 // when they become available to avoid code dup
1461 // TODO: investigate lowering order impact on performance
1462 LogicalResult
1463 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1464  PatternRewriter &rewriter) const {
1465  // TODO: implement masks.
1466  if (llvm::size(op.masks()) != 0)
1467  return failure();
1468 
1469  if (failed(filter(op)))
1470  return failure();
1471 
1472  // TODO: support mixed mode contract lowering.
1473  if (op.getLhsType().getElementType() !=
1474  getElementTypeOrSelf(op.getAccType()) ||
1475  op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1476  return failure();
1477 
1478  // TODO: implement benefits, cost models.
1479  MLIRContext *ctx = op.getContext();
1480  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1481  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1482  return success();
1483  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1484  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1485  return success();
1486  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1487  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1488  return success();
1489 
1490  // Find first batch dimension in LHS/RHS, and lower when found.
1491  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1492  if (!batchDimMap.empty()) {
1493  int64_t lhsIndex = batchDimMap[0].first;
1494  int64_t rhsIndex = batchDimMap[0].second;
1495  rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1496  return success();
1497  }
1498 
1499  // Collect contracting dimensions.
1500  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1501  op.getContractingDimMap();
1502  DenseSet<int64_t> lhsContractingDimSet;
1503  DenseSet<int64_t> rhsContractingDimSet;
1504  for (auto &dimPair : contractingDimMap) {
1505  lhsContractingDimSet.insert(dimPair.first);
1506  rhsContractingDimSet.insert(dimPair.second);
1507  }
1508 
1509  // Find first free dimension in LHS, and lower when found.
1510  VectorType lhsType = op.getLhsType();
1511  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1512  if (lhsContractingDimSet.count(lhsIndex) == 0) {
1513  rewriter.replaceOp(
1514  op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1515  return success();
1516  }
1517  }
1518 
1519  // Find first free dimension in RHS, and lower when found.
1520  VectorType rhsType = op.getRhsType();
1521  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1522  if (rhsContractingDimSet.count(rhsIndex) == 0) {
1523  rewriter.replaceOp(
1524  op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1525  return success();
1526  }
1527  }
1528 
1529  // Lower the first remaining reduction dimension.
1530  if (!contractingDimMap.empty()) {
1531  rewriter.replaceOp(op, lowerReduction(op, rewriter));
1532  return success();
1533  }
1534 
1535  return failure();
1536 }
1537 
1538 // Lower one parallel dimension.
1539 // TODO: consider reusing existing contract unrolling
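// Illustrative sketch of the unrolling performed below, assuming the chosen
// parallel dimension has size N:
//   result = zero-filled vector of the original result type
//   for d in [0, N):
//     result[d] = vector.contract(lhs slice d, rhs slice d, acc slice d)
// where each slice is taken with reshapeLoad and each partial result is
// written back with reshapeStore.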
1540 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1541  int64_t lhsIndex, int64_t rhsIndex,
1542  PatternRewriter &rewriter) const {
1543  VectorType lhsType = op.getLhsType();
1544  VectorType rhsType = op.getRhsType();
1545  VectorType resType = op.getResultType().cast<VectorType>();
1546  // Find the iterator type index and result index.
1547  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1548  int64_t iterIndex = -1;
1549  int64_t dimSize = -1;
1550  if (lhsIndex >= 0) {
1551  iterIndex = iMap[0].getDimPosition(lhsIndex);
1552  assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1553  "parallel index should be free in LHS or batch in LHS/RHS");
1554  dimSize = lhsType.getDimSize(lhsIndex);
1555  } else {
1556  assert(rhsIndex >= 0 && "missing parallel index");
1557  iterIndex = iMap[1].getDimPosition(rhsIndex);
1558  dimSize = rhsType.getDimSize(rhsIndex);
1559  }
1560  assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1561  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1562  assert(lookup.hasValue() && "parallel index not listed in reduction");
1563  int64_t resIndex = lookup.getValue();
1564  // Construct new iterator types and affine map array attribute.
1565  std::array<AffineMap, 3> lowIndexingMaps = {
1566  adjustMap(iMap[0], iterIndex, rewriter),
1567  adjustMap(iMap[1], iterIndex, rewriter),
1568  adjustMap(iMap[2], iterIndex, rewriter)};
1569  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1570  auto lowIter =
1571  rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1572  // Unroll into a series of lower dimensional vector.contract ops.
1573  Location loc = op.getLoc();
1574  Value result = rewriter.create<arith::ConstantOp>(
1575  loc, resType, rewriter.getZeroAttr(resType));
1576  for (int64_t d = 0; d < dimSize; ++d) {
1577  auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1578  auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1579  auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
1580  Value lowContract = rewriter.create<vector::ContractionOp>(
1581  loc, lhs, rhs, acc, lowAffine, lowIter);
1582  result =
1583  reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1584  }
1585  return result;
1586 }
1587 
1588 // Lower one reduction dimension.
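// Illustrative sketch of the lowering performed below, assuming the
// reduction dimension has size N:
//   acc_0 = acc
//   acc_{d+1} = vector.contract(lhs slice d, rhs slice d, acc_d)
//   result = acc_N
// with a rank-1 base case that becomes an elementwise multiply followed by
// an additive vector.reduction plus the optional accumulator.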
1589 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1590  PatternRewriter &rewriter) const {
1591  auto loc = op.getLoc();
1592  VectorType lhsType = op.getLhsType();
1593  VectorType rhsType = op.getRhsType();
1594  Type resType = op.getResultType();
1595  assert(!resType.isa<VectorType>());
1596  bool isInt = resType.isa<IntegerType>();
1597  // Use iterator index 0.
1598  int64_t iterIndex = 0;
1599  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1600  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1601  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1602  assert(lookupLhs.hasValue() && "missing LHS parallel index");
1603  assert(lookupRhs.hasValue() && "missing RHS parallel index");
1604  int64_t lhsIndex = lookupLhs.getValue();
1605  int64_t rhsIndex = lookupRhs.getValue();
1606  int64_t dimSize = lhsType.getDimSize(lhsIndex);
1607  assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1608  // Base case.
1609  if (lhsType.getRank() == 1) {
1610  assert(rhsType.getRank() == 1 && "corrupt contraction");
1611  Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
1612  StringAttr kind = rewriter.getStringAttr("add");
1613  Value res = rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
1614  ValueRange{});
1615  if (auto acc = op.acc())
1616  res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1617  return res;
1618  }
1619  // Construct new iterator types and affine map array attribute.
1620  std::array<AffineMap, 3> lowIndexingMaps = {
1621  adjustMap(iMap[0], iterIndex, rewriter),
1622  adjustMap(iMap[1], iterIndex, rewriter),
1623  adjustMap(iMap[2], iterIndex, rewriter)};
1624  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1625  auto lowIter =
1626  rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1627  // Unroll into a series of lower dimensional vector.contract ops.
1628  // By feeding the initial accumulator into the first contraction,
1629  // and the result of each contraction into the next, eventually
1630  // the sum of all reductions is computed.
1631  Value result = op.acc();
1632  for (int64_t d = 0; d < dimSize; ++d) {
1633  auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1634  auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1635  result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1636  lowAffine, lowIter);
1637  }
1638  return result;
1639 }
1640 
1641 } // namespace mlir
1642 
1643 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1644  OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1645  ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1646  OpBuilder::InsertionGuard guard(builder);
1647  builder.setInsertionPointAfter(op);
1648  Location loc = op->getLoc();
1649  if (op->getNumResults() != 1)
1650  return {};
1651  Value result = op->getResult(0);
1652  VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1653  if (!type || map.getNumResults() != multiplicity.size())
1654  return {};
1655  // For each dimension being distributed check that the size is a multiple of
1656  // the multiplicity. To handle more sizes we would need to support masking.
1657  unsigned multiplicityCount = 0;
1658  for (auto exp : map.getResults()) {
1659  auto affineExp = exp.dyn_cast<AffineDimExpr>();
1660  if (!affineExp || affineExp.getPosition() >= type.getRank() ||
1661  type.getDimSize(affineExp.getPosition()) %
1662  multiplicity[multiplicityCount++] !=
1663  0)
1664  return {};
1665  }
1666  DistributeOps ops;
1667  ops.extract =
1668  builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1669  ops.insert =
1670  builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1671  return ops;
1672 }
1673 
1674 /// Progressive lowering of transfer_read. This pattern supports lowering of
1675 /// `vector.transfer_read` to a combination of `vector.load` and
1676 /// `vector.broadcast` if all of the following hold:
1677 /// - Stride of most minor memref dimension must be 1.
1678 /// - Out-of-bounds masking is not required.
1679 /// - If the memref's element type is a vector type then it coincides with the
1680 /// result type.
1681 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
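///
/// An illustrative example (assuming an in-bounds, minor-identity read):
///   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true]}
///       : memref<?x?xf32>, vector<4xf32>
/// becomes roughly:
///   %v = vector.load %src[%i, %j] : memref<?x?xf32>, vector<4xf32>
/// and a read whose permutation map broadcasts a dimension additionally gets
/// a trailing vector.broadcast of the loaded value.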
1682 struct TransferReadToVectorLoadLowering
1683  : public OpRewritePattern<vector::TransferReadOp> {
1684  TransferReadToVectorLoadLowering(MLIRContext *context,
1685  llvm::Optional<unsigned> maxRank)
1686  : OpRewritePattern<vector::TransferReadOp>(context),
1687  maxTransferRank(maxRank) {}
1688 
1689  LogicalResult matchAndRewrite(vector::TransferReadOp read,
1690  PatternRewriter &rewriter) const override {
1691  if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1692  return failure();
1693 
1694  SmallVector<unsigned, 4> broadcastedDims;
1695  // Permutations are handled by VectorToSCF or
1696  // populateVectorTransferPermutationMapLoweringPatterns.
1697  // We let the 0-d corner case pass-through as it is supported.
1698  if (!read.permutation_map().isMinorIdentityWithBroadcasting(
1699  &broadcastedDims))
1700  return failure();
1701 
1702  auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1703  if (!memRefType)
1704  return failure();
1705 
1706  // Non-unit strides are handled by VectorToSCF.
1707  if (!vector::isLastMemrefDimUnitStride(memRefType))
1708  return failure();
1709 
1710  // If there is broadcasting involved then we first load the unbroadcasted
1711  // vector, and then broadcast it with `vector.broadcast`.
1712  ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1713  SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1714  vectorShape.end());
1715  for (unsigned i : broadcastedDims)
1716  unbroadcastedVectorShape[i] = 1;
1717  VectorType unbroadcastedVectorType = VectorType::get(
1718  unbroadcastedVectorShape, read.getVectorType().getElementType());
1719 
1720  // `vector.load` supports vector types as memref's elements only when the
1721  // resulting vector type is the same as the element type.
1722  auto memrefElTy = memRefType.getElementType();
1723  if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1724  return failure();
1725 
1726  // Otherwise, element types of the memref and the vector must match.
1727  if (!memrefElTy.isa<VectorType>() &&
1728  memrefElTy != read.getVectorType().getElementType())
1729  return failure();
1730 
1731  // Out-of-bounds dims are handled by MaterializeTransferMask.
1732  if (read.hasOutOfBoundsDim())
1733  return failure();
1734 
1735  // Create vector load op.
1736  Operation *loadOp;
1737  if (read.mask()) {
1738  Value fill = rewriter.create<SplatOp>(
1739  read.getLoc(), unbroadcastedVectorType, read.padding());
1740  loadOp = rewriter.create<vector::MaskedLoadOp>(
1741  read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
1742  read.mask(), fill);
1743  } else {
1744  loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
1745  unbroadcastedVectorType,
1746  read.source(), read.indices());
1747  }
1748 
1749  // Insert a broadcasting op if required.
1750  if (!broadcastedDims.empty()) {
1751  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1752  read, read.getVectorType(), loadOp->getResult(0));
1753  } else {
1754  rewriter.replaceOp(read, loadOp->getResult(0));
1755  }
1756 
1757  return success();
1758  }
1759 
1760  llvm::Optional<unsigned> maxTransferRank;
1761 };
1762 
1763 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1764 // TODO: we shouldn't cross the vector/scalar domains just for this
1765 // but atm we lack the infra to avoid it. Possible solutions include:
1766 // - go directly to LLVM + bitcast
1767 // - introduce a bitcast op and likely a new pointer dialect
1768 // - let memref.load/store additionally support the 0-d vector case
1769 // There are still deeper data layout issues lingering even in this
1770 // trivial case (for architectures for which this matters).
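// An illustrative example (single-element case):
//   %v = vector.load %m[%i] : memref<?xf32>, vector<1xf32>
// becomes roughly:
//   %s = memref.load %m[%i] : memref<?xf32>
//   %v = vector.broadcast %s : f32 to vector<1xf32>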
1771 struct VectorLoadToMemrefLoadLowering
1772  : public OpRewritePattern<vector::LoadOp> {
1773  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1774 
1775  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1776  PatternRewriter &rewriter) const override {
1777  auto vecType = loadOp.getVectorType();
1778  if (vecType.getNumElements() != 1)
1779  return failure();
1780  auto memrefLoad = rewriter.create<memref::LoadOp>(
1781  loadOp.getLoc(), loadOp.base(), loadOp.indices());
1782  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1783  memrefLoad);
1784  return success();
1785  }
1786 };
1787 
1788 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
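///
/// An illustrative example (single-element, rank-1 case):
///   vector.store %v, %m[%i] : memref<?xf32>, vector<1xf32>
/// becomes roughly:
///   %s = vector.extract %v[0] : vector<1xf32>
///   memref.store %s, %m[%i] : memref<?xf32>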
1789 struct VectorStoreToMemrefStoreLowering
1790  : public OpRewritePattern<vector::StoreOp> {
1791  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1792 
1793  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1794  PatternRewriter &rewriter) const override {
1795  auto vecType = storeOp.getVectorType();
1796  if (vecType.getNumElements() != 1)
1797  return failure();
1798  Value extracted;
1799  if (vecType.getRank() == 0) {
1800  // TODO: Unify once ExtractOp supports 0-d vectors.
1801  extracted = rewriter.create<vector::ExtractElementOp>(
1802  storeOp.getLoc(), storeOp.valueToStore());
1803  } else {
1804  SmallVector<int64_t> indices(vecType.getRank(), 0);
1805  extracted = rewriter.create<vector::ExtractOp>(
1806  storeOp.getLoc(), storeOp.valueToStore(), indices);
1807  }
1808 
1809  rewriter.replaceOpWithNewOp<memref::StoreOp>(
1810  storeOp, extracted, storeOp.base(), storeOp.indices());
1811  return success();
1812  }
1813 };
1814 
1815 /// Progressive lowering of transfer_write. This pattern supports lowering of
1816 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1817 /// - Stride of most minor memref dimension must be 1.
1818 /// - Out-of-bounds masking is not required.
1819 /// - If the memref's element type is a vector type then it coincides with the
1820 /// type of the written value.
1821 /// - The permutation map is the minor identity map (neither permutation nor
1822 /// broadcasting is allowed).
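///
/// An illustrative example (assuming an in-bounds, minor-identity write):
///   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true]}
///       : vector<4xf32>, memref<?x?xf32>
/// becomes roughly:
///   vector.store %v, %dst[%i, %j] : memref<?x?xf32>, vector<4xf32>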
1823 struct TransferWriteToVectorStoreLowering
1824  : public OpRewritePattern<vector::TransferWriteOp> {
1825  TransferWriteToVectorStoreLowering(MLIRContext *context,
1826  llvm::Optional<unsigned> maxRank)
1827  : OpRewritePattern<vector::TransferWriteOp>(context),
1828  maxTransferRank(maxRank) {}
1829 
1830  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1831  PatternRewriter &rewriter) const override {
1832  if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1833  return failure();
1834 
1835  // Permutations are handled by VectorToSCF or
1836  // populateVectorTransferPermutationMapLoweringPatterns.
1837  if ( // pass-through for the 0-d corner case.
1838  !write.permutation_map().isMinorIdentity())
1839  return failure();
1840 
1841  auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1842  if (!memRefType)
1843  return failure();
1844 
1845  // Non-unit strides are handled by VectorToSCF.
1846  if (!vector::isLastMemrefDimUnitStride(memRefType))
1847  return failure();
1848 
1849  // `vector.store` supports vector types as memref's elements only when the
1850  // type of the vector value being written is the same as the element type.
1851  auto memrefElTy = memRefType.getElementType();
1852  if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1853  return failure();
1854 
1855  // Otherwise, element types of the memref and the vector must match.
1856  if (!memrefElTy.isa<VectorType>() &&
1857  memrefElTy != write.getVectorType().getElementType())
1858  return failure();
1859 
1860  // Out-of-bounds dims are handled by MaterializeTransferMask.
1861  if (write.hasOutOfBoundsDim())
1862  return failure();
1863  if (write.mask()) {
1864  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1865  write, write.source(), write.indices(), write.mask(), write.vector());
1866  } else {
1867  rewriter.replaceOpWithNewOp<vector::StoreOp>(
1868  write, write.vector(), write.source(), write.indices());
1869  }
1870  return success();
1871  }
1872 
1873  llvm::Optional<unsigned> maxTransferRank;
1874 };
1875 
1876 // Returns the values in `arrayAttr` as an integer vector.
1877 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1878  return llvm::to_vector<4>(
1879  llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1880  [](IntegerAttr attr) { return attr.getInt(); }));
1881 }
1882 
1883 // Shuffles vector.bitcast op after vector.extract op.
1884 //
1885 // This transforms IR like:
1886 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1887 // %1 = vector.extract %0[3] : vector<8xf16>
1888 // Into:
1889 // %0 = vector.extract %src[1] : vector<4xf32>
1890 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1891 // %2 = vector.extract %1[1] : vector<2xf16>
1892 struct BubbleDownVectorBitCastForExtract
1893  : public OpRewritePattern<vector::ExtractOp> {
1894  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
1895 
1896  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1897  PatternRewriter &rewriter) const override {
1898  // Only support extracting scalars for now.
1899  if (extractOp.getVectorType().getRank() != 1)
1900  return failure();
1901 
1902  auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1903  if (!castOp)
1904  return failure();
1905 
1906  VectorType castSrcType = castOp.getSourceVectorType();
1907  VectorType castDstType = castOp.getResultVectorType();
1908  assert(castSrcType.getRank() == castDstType.getRank());
1909 
1910  // Fail to match if we only have one element in the cast op source.
1911  // This is to avoid an infinite loop given that this pattern can generate
1912  // such cases.
1913  if (castSrcType.getNumElements() == 1)
1914  return failure();
1915 
1916  // Only support casting to a larger number of elements for now.
1917  // E.g., vector<4xf32> -> vector<8xf16>.
1918  if (castSrcType.getNumElements() > castDstType.getNumElements())
1919  return failure();
1920 
1921  unsigned expandRatio =
1922  castDstType.getNumElements() / castSrcType.getNumElements();
1923 
1924  auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1925  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1926  };
1927 
1928  uint64_t index = getFirstIntValue(extractOp.position());
1929 
1930  // Get the single scalar (as a vector) in the source value that packs the
1931  // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1932  VectorType oneScalarType =
1933  VectorType::get({1}, castSrcType.getElementType());
1934  Value packedValue = rewriter.create<vector::ExtractOp>(
1935  extractOp.getLoc(), oneScalarType, castOp.source(),
1936  rewriter.getI64ArrayAttr(index / expandRatio));
1937 
1938  // Cast it to a vector with the desired scalar's type.
1939  // E.g. f32 -> vector<2xf16>
1940  VectorType packedType =
1941  VectorType::get({expandRatio}, castDstType.getElementType());
1942  Value castedValue = rewriter.create<vector::BitCastOp>(
1943  extractOp.getLoc(), packedType, packedValue);
1944 
1945  // Finally extract the desired scalar.
1946  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1947  extractOp, extractOp.getType(), castedValue,
1948  rewriter.getI64ArrayAttr(index % expandRatio));
1949 
1950  return success();
1951  }
1952 };
1953 
1954 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
1955 //
1956 // This transforms IR like:
1957 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
1958 // %0 = vector.extract_strided_slice %cast {
1959 // offsets = [4], sizes = [4], strides = [1]
1960 // } : vector<8xf16> to vector<4xf16>
1961 // Into:
1962 // %0 = vector.extract_strided_slice %src {
1963 // offsets = [2], sizes = [2], strides = [1]
1964 // } : vector<4xf32> to vector<2xf32>
1965 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
1966 struct BubbleDownBitCastForStridedSliceExtract
1967  : public OpRewritePattern<vector::ExtractStridedSliceOp> {
1968  using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;
1969 
1970  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
1971  PatternRewriter &rewriter) const override {
1972  auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1973  if (!castOp)
1974  return failure();
1975 
1976  VectorType castSrcType = castOp.getSourceVectorType();
1977  VectorType castDstType = castOp.getResultVectorType();
1978  assert(castSrcType.getRank() == castDstType.getRank());
1979 
1980  int64_t castSrcLastDim = castSrcType.getShape().back();
1981  int64_t castDstLastDim = castDstType.getShape().back();
1982  // Require casting to more elements for now; other cases to be implemented.
1983  if (castSrcLastDim > castDstLastDim)
1984  return failure();
1985 
1986  // Only accept all-one strides for now.
1987  if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
1988  [](const APInt &val) { return !val.isOneValue(); }))
1989  return failure();
1990 
1991  unsigned rank = extractOp.getVectorType().getRank();
1992  assert(castDstLastDim % castSrcLastDim == 0);
1993  int64_t expandRatio = castDstLastDim / castSrcLastDim;
1994 
1995  // If we have fewer offsets than the rank, then implicitly we are
1996  // selecting the full range for the last bitcasted dimension; other
1997  // dimensions aren't affected. Otherwise, we need to scale down the last
1998  // dimension's offset given we are extracting from fewer elements now.
1999  ArrayAttr newOffsets = extractOp.offsets();
2000  if (newOffsets.size() == rank) {
2001  SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2002  if (offsets.back() % expandRatio != 0)
2003  return failure();
2004  offsets.back() = offsets.back() / expandRatio;
2005  newOffsets = rewriter.getI64ArrayAttr(offsets);
2006  }
2007 
2008  // Similarly for sizes.
2009  ArrayAttr newSizes = extractOp.sizes();
2010  if (newSizes.size() == rank) {
2011  SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2012  if (sizes.back() % expandRatio != 0)
2013  return failure();
2014  sizes.back() = sizes.back() / expandRatio;
2015  newSizes = rewriter.getI64ArrayAttr(sizes);
2016  }
2017 
2018  SmallVector<int64_t, 4> dims =
2019  llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2020  dims.back() = dims.back() / expandRatio;
2021  VectorType newExtractType =
2022  VectorType::get(dims, castSrcType.getElementType());
2023 
2024  auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2025  extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
2026  newSizes, extractOp.strides());
2027 
2028  rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2029  extractOp, extractOp.getType(), newExtractOp);
2030 
2031  return success();
2032  }
2033 };
2034 
2035 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2036 //
2037 // This transforms IR like:
2038 // %0 = vector.insert_strided_slice %src, %dst {
2039 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2040 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2041 // Into:
2042 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2043 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2044 // %2 = vector.insert_strided_slice %src, %dst {
2045 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2046 struct BubbleUpBitCastForStridedSliceInsert
2047  : public OpRewritePattern<vector::BitCastOp> {
2048  using OpRewritePattern<vector::BitCastOp>::OpRewritePattern;
2049  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2050  PatternRewriter &rewriter) const override {
2051  VectorType castSrcType = bitcastOp.getSourceVectorType();
2052  VectorType castDstType = bitcastOp.getResultVectorType();
2053  assert(castSrcType.getRank() == castDstType.getRank());
2054 
2055  int64_t castSrcLastDim = castSrcType.getShape().back();
2056  int64_t castDstLastDim = castDstType.getShape().back();
2057  // Require casting to less elements for now; other cases to be implemented.
2058  if (castSrcLastDim < castDstLastDim)
2059  return failure();
2060 
2061  assert(castSrcLastDim % castDstLastDim == 0);
2062  int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2063 
2064  auto insertOp =
2065  bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
2066  if (!insertOp)
2067  return failure();
2068 
2069  // Only accept all-one strides for now.
2070  if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
2071  [](const APInt &val) { return !val.isOneValue(); }))
2072  return failure();
2073 
2074  unsigned rank = insertOp.getSourceVectorType().getRank();
2075  // Require insert op to have the same rank for the source and destination
2076  // vector; other cases to be implemented.
2077  if (rank != insertOp.getDestVectorType().getRank())
2078  return failure();
2079 
2080  ArrayAttr newOffsets = insertOp.offsets();
2081  assert(newOffsets.size() == rank);
2082  SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2083  if (offsets.back() % shrinkRatio != 0)
2084  return failure();
2085  offsets.back() = offsets.back() / shrinkRatio;
2086  newOffsets = rewriter.getI64ArrayAttr(offsets);
2087 
2088  SmallVector<int64_t, 4> srcDims =
2089  llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2090  srcDims.back() = srcDims.back() / shrinkRatio;
2091  VectorType newCastSrcType =
2092  VectorType::get(srcDims, castDstType.getElementType());
2093 
2094  auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2095  bitcastOp.getLoc(), newCastSrcType, insertOp.source());
2096 
2097  SmallVector<int64_t, 4> dstDims =
2098  llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2099  dstDims.back() = dstDims.back() / shrinkRatio;
2100  VectorType newCastDstType =
2101  VectorType::get(dstDims, castDstType.getElementType());
2102 
2103  auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2104  bitcastOp.getLoc(), newCastDstType, insertOp.dest());
2105 
2106  rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2107  bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2108  insertOp.strides());
2109 
2110  return success();
2111  }
2112 };
2113 
2114 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
2115  Type targetType, Value value) {
2116  if (targetType == value.getType())
2117  return value;
2118 
2119  bool targetIsIndex = targetType.isIndex();
2120  bool valueIsIndex = value.getType().isIndex();
2121  if (targetIsIndex ^ valueIsIndex)
2122  return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
2123 
2124  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
2125  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
2126  assert(targetIntegerType && valueIntegerType &&
2127  "unexpected cast between types other than integers and index");
2128  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
2129 
2130  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
2131  return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
2132  return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
2133 }
2134 
2135 // Helper that returns a vector comparison that constructs a mask:
2136 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2137 //
2138 // If `dim == 0` then the result will be a 0-D vector.
2139 //
2140 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2141 // much more compact, IR for this operation, but LLVM eventually
2142 // generates more elaborate instructions for this intrinsic since it
2143 // is very conservative on the boundary conditions.
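// Worked example (illustrative): for dim = 4 with offset o and bound b, the
// mask is computed as
//   [0, 1, 2, 3] + [o, o, o, o] < [b, b, b, b]
// i.e. an arith.constant index vector, two splats, an arith.addi, and an
// arith.cmpi slt yielding a vector<4xi1>.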
2144 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2145  bool indexOptimizations, int64_t dim,
2146  Value b, Value *off = nullptr) {
2147  auto loc = op->getLoc();
2148  // If we can assume all indices fit in 32-bit, we perform the vector
2149  // comparison in 32-bit to get a higher degree of SIMD parallelism.
2150  // Otherwise we perform the vector comparison using 64-bit indices.
2151  Type idxType =
2152  indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
2153  DenseIntElementsAttr indicesAttr;
2154  if (dim == 0 && indexOptimizations) {
2155  indicesAttr = DenseIntElementsAttr::get(
2156  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2157  } else if (dim == 0) {
2158  indicesAttr = DenseIntElementsAttr::get(
2159  VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2160  } else if (indexOptimizations) {
2161  indicesAttr = rewriter.getI32VectorAttr(
2162  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2163  } else {
2164  indicesAttr = rewriter.getI64VectorAttr(
2165  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2166  }
2167  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2168  // Add in an offset if requested.
2169  if (off) {
2170  Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
2171  Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
2172  indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2173  }
2174  // Construct the vector comparison.
2175  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
2176  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
2177  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2178  bounds);
2179 }
2180 
2181 template <typename ConcreteOp>
2182 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2183 public:
2184  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2185  : mlir::OpRewritePattern<ConcreteOp>(context),
2186  indexOptimizations(enableIndexOpt) {}
2187 
2188  LogicalResult matchAndRewrite(ConcreteOp xferOp,
2189  PatternRewriter &rewriter) const override {
2190  if (!xferOp.hasOutOfBoundsDim())
2191  return failure();
2192 
2193  if (xferOp.getVectorType().getRank() > 1 ||
2194  llvm::size(xferOp.indices()) == 0)
2195  return failure();
2196 
2197  Location loc = xferOp->getLoc();
2198  VectorType vtp = xferOp.getVectorType();
2199 
2200  // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
2201  // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
2202  // * Let dim the memref dimension, compute the vector comparison mask
2203  // (in-bounds mask):
2204  // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
2205  //
2206  // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2207  // dimensions here.
2208  unsigned vecWidth = vtp.getNumElements();
2209  unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2210  Value off = xferOp.indices()[lastIndex];
2211  Value dim =
2212  vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2213  Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations,
2214  vecWidth, dim, &off);
2215 
2216  if (xferOp.mask()) {
2217  // Intersect the in-bounds with the mask specified as an op parameter.
2218  mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2219  }
2220 
2221  rewriter.updateRootInPlace(xferOp, [&]() {
2222  xferOp.maskMutable().assign(mask);
2223  xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2224  });
2225 
2226  return success();
2227  }
2228 
2229 private:
2230  const bool indexOptimizations;
2231 };
2232 
2233 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
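/// For example (illustrative), %m = vector.create_mask %b : vector<4xi1> is
/// rewritten into the comparison [0, 1, 2, 3] < [b, b, b, b] built by
/// buildVectorComparison above.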
2234 class VectorCreateMaskOpConversion
2235  : public OpRewritePattern<vector::CreateMaskOp> {
2236 public:
2237  explicit VectorCreateMaskOpConversion(MLIRContext *context,
2238  bool enableIndexOpt)
2239  : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2240  indexOptimizations(enableIndexOpt) {}
2241 
2242  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2243  PatternRewriter &rewriter) const override {
2244  auto dstType = op.getType();
2245  int64_t rank = dstType.getRank();
2246  if (rank > 1)
2247  return failure();
2248  rewriter.replaceOp(
2249  op, buildVectorComparison(rewriter, op, indexOptimizations,
2250  rank == 0 ? 0 : dstType.getDimSize(0),
2251  op.getOperand(0)));
2252  return success();
2253  }
2254 
2255 private:
2256  const bool indexOptimizations;
2257 };
2258 
2259 // Drop innermost contiguous unit dimensions from transfer_read operand.
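// An illustrative example (assuming a statically shaped memref whose
// innermost dimension is a contiguous unit dimension):
//   %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<8x1xf32>, vector<4x1xf32>
// becomes roughly:
//   %sv = memref.subview %src[0, 0] [8, 1] [1, 1]
//       : memref<8x1xf32> to memref<8xf32>
//   %r  = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
//       : memref<8xf32>, vector<4xf32>
//   %v  = vector.shape_cast %r : vector<4xf32> to vector<4x1xf32>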
2260 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2261  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2262 
2263  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2264  PatternRewriter &rewriter) const override {
2265  // TODO: support 0-d corner case.
2266  if (readOp.getTransferRank() == 0)
2267  return failure();
2268 
2269  // TODO: support mask.
2270  if (readOp.mask())
2271  return failure();
2272 
2273  auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
2274  if (!srcType || !srcType.hasStaticShape())
2275  return failure();
2276 
2277  if (!readOp.permutation_map().isMinorIdentity())
2278  return failure();
2279 
2280  auto targetType = readOp.getVectorType();
2281  if (targetType.getRank() <= 1)
2282  return failure();
2283 
2284  SmallVector<int64_t> srcStrides;
2285  int64_t srcOffset;
2286  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2287  return failure();
2288 
2289  size_t dimsToDrop = 0;
2290  for (size_t i = 1; i < srcStrides.size(); ++i) {
2291  int dim = srcType.getRank() - i - 1;
2292  if (srcStrides[dim] == 1) {
2293  dimsToDrop++;
2294  } else {
2295  break;
2296  }
2297  }
2298  if (dimsToDrop == 0)
2299  return failure();
2300 
2301  auto resultTargetVecType =
2302  VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2303  targetType.getElementType());
2304 
2305  MemRefType resultMemrefType;
2306  if (srcType.getLayout().getAffineMap().isIdentity()) {
2307  resultMemrefType = MemRefType::get(
2308  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2309  {}, srcType.getMemorySpaceAsInt());
2310  } else {
2311  AffineMap map = srcType.getLayout().getAffineMap();
2312  int numResultDims = map.getNumDims() - dimsToDrop;
2313  int numSymbols = map.getNumSymbols();
2314  for (size_t i = 0; i < dimsToDrop; ++i) {
2315  int dim = srcType.getRank() - i - 1;
2316  map = map.replace(rewriter.getAffineDimExpr(dim),
2317  rewriter.getAffineConstantExpr(0), numResultDims,
2318  numSymbols);
2319  }
2320  resultMemrefType = MemRefType::get(
2321  srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2322  map, srcType.getMemorySpaceAsInt());
2323  }
2324 
2325  auto loc = readOp.getLoc();
2326  SmallVector<int64_t> offsets(srcType.getRank(), 0);
2327  SmallVector<int64_t> strides(srcType.getRank(), 1);
2328 
2329  ArrayAttr inBoundsAttr =
2330  readOp.in_bounds()
2331  ? rewriter.getArrayAttr(
2332  readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
2333  : ArrayAttr();
2334  Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2335  loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
2336  strides);
2337  auto permMap = getTransferMinorIdentityMap(
2338  rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2339  Value result = rewriter.create<vector::TransferReadOp>(
2340  loc, resultTargetVecType, rankedReducedView,
2341  readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2342  readOp.padding(),
2343  // TODO: support mask.
2344  /*mask=*/Value(), inBoundsAttr);
2345  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2346  result);
2347  return success();
2348  }
2349 };
2350 
2351 void mlir::vector::populateVectorMaskMaterializationPatterns(
2352  RewritePatternSet &patterns, bool indexOptimizations) {
2353  patterns.add<VectorCreateMaskOpConversion,
2354  MaterializeTransferMask<vector::TransferReadOp>,
2355  MaterializeTransferMask<vector::TransferWriteOp>>(
2356  patterns.getContext(), indexOptimizations);
2357 }
2358 
2359 void mlir::vector::populateShapeCastFoldingPatterns(
2360  RewritePatternSet &patterns) {
2361  patterns.add<ShapeCastOpFolder>(patterns.getContext());
2362 }
2363 
2364 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2365  RewritePatternSet &patterns) {
2366  patterns.add<BubbleDownVectorBitCastForExtract,
2367  BubbleDownBitCastForStridedSliceExtract,
2368  BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2369 }
2370 
2371 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2372  RewritePatternSet &patterns) {
2373  patterns.add<BroadcastOpLowering>(patterns.getContext());
2374 }
2375 
2376 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2377  RewritePatternSet &patterns) {
2378  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2379  patterns.getContext());
2380 }
2381 
2382 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2383  RewritePatternSet &patterns) {
2384  patterns.add<ShapeCastOp2DDownCastRewritePattern,
2385  ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2386  patterns.getContext());
2387 }
2388 
2389 void mlir::vector::populateVectorContractLoweringPatterns(
2390  RewritePatternSet &patterns, VectorTransformsOptions options) {
2391  patterns.add<OuterProductOpLowering>(patterns.getContext());
2392  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2393  ContractionOpToOuterProductOpLowering>(options,
2394  patterns.getContext());
2395 }
2396 
2397 void mlir::vector::populateVectorTransposeLoweringPatterns(
2398  RewritePatternSet &patterns, VectorTransformsOptions options) {
2399  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2400  options, patterns.getContext());
2401 }
2402 
2403 void mlir::vector::populateVectorReductionToContractPatterns(
2404  RewritePatternSet &patterns) {
2405  patterns.add<MultiReduceToContract, CombineContractBroadcast,
2406  CombineContractTranspose>(patterns.getContext());
2407 }
2408 
2409 void mlir::vector::
2410  populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2411  RewritePatternSet &patterns) {
2412  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2413 }
2414 
2415 void mlir::vector::populateVectorTransferLoweringPatterns(
2416  RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2417  patterns.add<TransferReadToVectorLoadLowering,
2418  TransferWriteToVectorStoreLowering>(patterns.getContext(),
2419  maxTransferRank);
2420  patterns
2421  .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2422  patterns.getContext());
2423 }