//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
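/// For instance (illustrative), `vector.constant_mask [4] : vector<4xi1>`
/// classifies as AllTrue, `[0]` as AllFalse, and `[2]` as Unknown.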
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.getMaskDimSizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

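/// Example (illustrative): for a transfer between memref<10x20x30xf32> and
/// vector<30xf32>, the minor identity map built below is (d0, d1, d2) -> (d2).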
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      shapedType.getElementType().dyn_cast<VectorType>();
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap();
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}

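// Example (illustrative): two transfers of vector<4xf32> from the same memref
// at constant minor indices 0 and 4 are disjoint, since the index distance (4)
// is no smaller than the vector dimension size (4).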
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity, only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any of the indices are dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different we know we are accessing disjoint slices.
      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
          indexB.getValue().cast<IntegerAttr>().getInt())
        return true;
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t distance =
          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
                   indexB.getValue().cast<IntegerAttr>().getInt());
      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
        return true;
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB) {
  if (transferA.source() != transferB.source())
    return false;
  return isDisjointTransferIndices(transferA, transferB);
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

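// Printed form (illustrative): an attribute holding CombiningKind::ADD prints
// as `kind<add>`, i.e. `#vector.kind<add>` once the dialect prefix is added.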
void CombiningKindAttr::print(AsmPrinter &printer) const {
  printer << "<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}

Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(*kind, parser.getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser, {});

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void VectorDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
    os << "kind";
    ck.print(os);
    return;
  }
  llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

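// An illustrative multi_reduction reducing dim 1 (assembly shown
// approximately; see the op documentation for the exact format):
//
// ```mlir
// %1 = vector.multi_reduction <add>, %0, %acc [1]
//     : vector<4x8xf32> to vector<4xf32>
// ```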
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc,
        builder.getI64ArrayAttr(reductionDims));
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  Type inferredReturnType;
  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
          return attr.cast<IntegerAttr>().getValue() == it.index();
        }))
      targetShape.push_back(it.value());
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType =
        VectorType::get(targetShape, getSourceVectorType().getElementType());
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector) {
  build(builder, result, kind, vector, /*acc=*/Value());
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc) {
  build(builder, result, vector.getType().cast<VectorType>().getElementType(),
        kind, vector, acc);
}

LogicalResult ReductionOp::verify() {
  // Verify for 1-D vector.
  int64_t rank = getVectorType().getRank();
  if (rank != 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

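/// Textual form handled by the custom parser and printer below
/// (illustrative):
///
/// ```mlir
/// %1 = vector.reduction <add>, %0 : vector<16xf32> into f32
/// ```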
ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
  Type redType;
  Type resType;
  CombiningKindAttr kindAttr;
  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
                                              result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

void ReductionOp::print(OpAsmPrinter &p) {
  p << " ";
  getKindAttr().print(p);
  p << ", " << getVector();
  if (getAcc())
    p << ", " << getAcc();
  p << " : " << getVector().getType() << " into " << getDest().getType();
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

namespace {
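/// Elide a reduction over a single-element vector by replacing it with an
/// extract (combined with the accumulator when one is present). Illustrative
/// example: `vector.reduction <add>, %v : vector<1xf32> into f32` becomes
/// `vector.extract %v[0] : vector<1xf32>`.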
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    if (reductionOp.getVectorType().getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
                                              reductionOp.getVector(),
                                              rewriter.getI64ArrayAttr(0));

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc);

    rewriter.replaceOp(reductionOp, result);
    return success();
  }
};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

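// A typical vector.contract, here a matrix multiply (illustrative):
//
// ```mlir
// %c = vector.contract
//     {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
//                       affine_map<(i, j, k) -> (k, j)>,
//                       affine_map<(i, j, k) -> (i, j)>],
//      iterator_types = ["parallel", "parallel", "reduction"]}
//     %a, %b, %acc : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
// ```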
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(::mlir::getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(::mlir::getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrStrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
    result.addAttribute(ContractionOp::getKindAttrStrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs())
    if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();
  if (getMasks().size() == 2)
    p << ", " << getMasks();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  auto lhsType = getLhsType();
  auto rhsType = getRhsType();
  auto accType = getAccType();
  auto resType = getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = getLHSVectorMaskType();
  auto rhsMaskType = getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
                                         ::mlir::getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrStrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(),
                   getParallelIteratorTypeName(), getContext());
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
/// %c0 = vector.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// rewritten as:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

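/// Fold examples (illustrative): `extractelement` of a `vector.splat` folds to
/// the splatted scalar, and extracting position 1 from a constant
/// `dense<[1, 2, 3, 4]>` folds to the constant 2.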
OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
  // Skip the 0-D vector case for now.
  if (operands.size() < 2)
    return {};

  Attribute src = operands[0];
  Attribute pos = operands[1];

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  if (!pos || !src)
    return {};

  auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();

  auto attr = pos.dyn_cast<IntegerAttr>();
  uint64_t posIdx = attr.getInt();

  return srcElements[posIdx];
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, getVectorSubscriptAttr(builder, position));
}

// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, positionConstants);
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
                            ValueRange operands, DictionaryAttr attributes,
                            RegionRange,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  ExtractOp::Adaptor op(operands, attributes);
  auto vectorType = op.getVector().getType().cast<VectorType>();
  if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n =
        std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType()));
  }
  return success();
}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = l.front().dyn_cast<VectorType>();
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  auto positionAttr = getPosition().getValue();
  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= getVectorType().getDimSize(en.index()))
      return emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
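/// Example (illustrative):
/// ```
/// %0 = vector.extract %v[0, 1] : vector<4x8x16xf32>
/// %1 = vector.extract %0[2] : vector<16xf32>
/// // folds in place to:
/// %1 = vector.extract %v[0, 1, 2] : vector<4x8x16xf32>
/// ```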
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extrPos = extractVector<int64_t>(currentOp.getPosition());
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extrPos = extractVector<int64_t>(currentOp.getPosition());
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(currentOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >= 0 with the corresponding
  /// entries of `b`. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]
  ///     : vector<3x4x5> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 2, 3] : vector<2x3x4x5>
  /// // can fold to:
  /// %ext = vector.extract %source[2, 3] : vector<3x4x5>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the
  /// result. They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.getPosition().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.getPosition());
  llvm::append_range(extractPosition, sentinels);
}

// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (!nextTransposeOp)
    return failure();
  auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
  AffineMap m = inversePermutation(
      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
  return success();
}

// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.getSource();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Case 3: if inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.getSource();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If we can't fold (either internal transposition, or nothing to fold), bail.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op in place and return its result.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(
      extractOp.getPositionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
}

/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
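/// Example (illustrative):
/// ```
/// %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
/// %e = vector.extract %b[1] : vector<4x8xf32>
/// // folds to %src
/// ```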
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result haven't been broadcast.
  auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
  auto broadcastVecType = source.getType().dyn_cast<VectorType>();
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();
  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
  unsigned rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  extractOp.setOperand(source);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractPos));
  return extractOp.getResult();
}

// Fold extractOp with source coming from ShapeCast op.
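// Example (illustrative):
// ```
// %0 = vector.shape_cast %src : vector<2x2x4xf32> to vector<4x4xf32>
// %1 = vector.extract %0[2] : vector<4x4xf32>
// // folds in place to:
// %1 = vector.extract %src[1, 0] : vector<2x2x4xf32>
// ```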
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.getSource());
  return extractOp.getResult();
}

/// Fold an ExtractOp from ExtractStridedSliceOp.
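/// Example (illustrative):
/// ```
/// %0 = vector.extract_strided_slice %v
///     {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
///     : vector<4x8x16xf32> to vector<2x2x16xf32>
/// %1 = vector.extract %0[0, 1] : vector<2x2x16xf32>
/// // folds in place to:
/// %1 = vector.extract %v[1, 3] : vector<4x8x16xf32>
/// ```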
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}

1448 /// Fold extract_op fed from a chain of insertStridedSlice ops.
1450  int64_t destinationRank = op.getType().isa<VectorType>()
1451  ? op.getType().cast<VectorType>().getRank()
1452  : 0;
1453  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
1454  while (insertOp) {
1455  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1456  insertOp.getSourceVectorType().getRank();
1457  if (destinationRank > insertOp.getSourceVectorType().getRank())
1458  return Value();
1459  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1460  auto extractOffsets = extractVector<int64_t>(op.getPosition());
1461 
1462  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1463  return attr.cast<IntegerAttr>().getInt() != 1;
1464  }))
1465  return Value();
1466  bool disjoint = false;
1467  SmallVector<int64_t, 4> offsetDiffs;
1468  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1469  int64_t start = insertOffsets[dim];
1470  int64_t size =
1471  (dim < insertRankDiff)
1472  ? 1
1473  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1474  int64_t end = start + size;
1475  int64_t offset = extractOffsets[dim];
1476  // Check whether the start of the extract offset lies in the inserted interval.
1477  if (start <= offset && offset < end) {
1478  if (dim >= insertRankDiff)
1479  offsetDiffs.push_back(offset - start);
1480  continue;
1481  }
1482  disjoint = true;
1483  break;
1484  }
1485  // The extracted chunk overlaps the inserted vector.
1486  if (!disjoint) {
1487  // If any of the inner dimensions are only partially inserted we have a
1488  // partial overlap.
1489  int64_t srcRankDiff =
1490  insertOp.getSourceVectorType().getRank() - destinationRank;
1491  for (int64_t i = 0; i < destinationRank; i++) {
1492  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1493  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1494  insertRankDiff))
1495  return Value();
1496  }
1497  op.getVectorMutable().assign(insertOp.getSource());
1498  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1499  OpBuilder b(op.getContext());
1500  op->setAttr(ExtractOp::getPositionAttrStrName(),
1501  b.getI64ArrayAttr(offsetDiffs));
1502  return op.getResult();
1503  }
1504  // If the chunk extracted is disjoint from the chunk inserted, keep
1505  // looking in the insert chain.
1506  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1507  }
1508  return Value();
1509 }
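// Illustrative example (editorial note, not part of the upstream source):
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2, 2], strides = [1, 1]}
//        : vector<2x4xf32> into vector<8x16xf32>
//   %1 = vector.extract %0[3, 4] : vector<8x16xf32>
// folds to
//   %1 = vector.extract %src[1, 2] : vector<2x4xf32>
// because position [3, 4] lies entirely inside the inserted chunk.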
1510 
1511 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1512  if (getPosition().empty())
1513  return getVector();
1514  if (succeeded(foldExtractOpFromExtractChain(*this)))
1515  return getResult();
1516  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1517  return res;
1518  if (auto res = foldExtractFromBroadcast(*this))
1519  return res;
1520  if (auto res = foldExtractFromShapeCast(*this))
1521  return res;
1522  if (auto val = foldExtractFromExtractStrided(*this))
1523  return val;
1524  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1525  return val;
1526  return OpFoldResult();
1527 }
1528 
1529 namespace {
1530 
1531 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1532 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1533 public:
1534  using OpRewritePattern<ExtractOp>::OpRewritePattern;
1535 
1536  LogicalResult matchAndRewrite(ExtractOp extractOp,
1537  PatternRewriter &rewriter) const override {
1538  Operation *defOp = extractOp.getVector().getDefiningOp();
1539  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1540  return failure();
1541 
1542  Value source = defOp->getOperand(0);
1543  if (extractOp.getType() == source.getType())
1544  return failure();
1545  auto getRank = [](Type type) {
1546  return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1547  };
1548  unsigned broadcastSrcRank = getRank(source.getType());
1549  unsigned extractResultRank = getRank(extractOp.getType());
1550  // We only consider the case where the rank of the source is less than or
1551  // equal to the rank of the extract dst. The other cases are handled in the
1552  // folding patterns.
1553  if (extractResultRank < broadcastSrcRank)
1554  return failure();
1555  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1556  extractOp, extractOp.getType(), source);
1557  return success();
1558  }
1559 };
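// Illustrative example of this pattern (editorial note, not part of the
// upstream source):
//   %0 = vector.broadcast %s : f32 to vector<4x8xf32>
//   %1 = vector.extract %0[1] : vector<4x8xf32>
// is rewritten to
//   %1 = vector.broadcast %s : f32 to vector<8xf32>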
1560 
1561 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
1562 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
1563 public:
1564  using OpRewritePattern<ExtractOp>::OpRewritePattern;
1565 
1566  LogicalResult matchAndRewrite(ExtractOp extractOp,
1567  PatternRewriter &rewriter) const override {
1568  // Return if 'extractOp' operand is not defined by a
1569  // ConstantOp.
1570  auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
1571  if (!constantOp)
1572  return failure();
1573  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
1574  if (!dense)
1575  return failure();
1576  Attribute newAttr = dense.getSplatValue<Attribute>();
1577  if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1578  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1579  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1580  return success();
1581  }
1582 };
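// Illustrative example of this pattern (editorial note, not part of the
// upstream source):
//   %0 = arith.constant dense<1.0> : vector<4x8xf32>
//   %1 = vector.extract %0[2] : vector<4x8xf32>
// is rewritten to
//   %1 = arith.constant dense<1.0> : vector<8xf32>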
1583 
1584 } // namespace
1585 
1586 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1587  MLIRContext *context) {
1588  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
1589 }
1590 
1591 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1592  SmallVectorImpl<int64_t> &results) {
1593  for (auto attr : arrayAttr)
1594  results.push_back(attr.cast<IntegerAttr>().getInt());
1595 }
1596 
1597 //===----------------------------------------------------------------------===//
1598 // ExtractMapOp
1599 //===----------------------------------------------------------------------===//
1600 
1601 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1602  Value vector, ValueRange ids,
1603  ArrayRef<int64_t> multiplicity,
1604  AffineMap permutationMap) {
1605  assert(ids.size() == multiplicity.size() &&
1606  ids.size() == permutationMap.getNumResults());
1607  assert(permutationMap.isProjectedPermutation());
1608  VectorType type = vector.getType().cast<VectorType>();
1609  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1610  type.getShape().end());
1611  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1612  AffineExpr expr = permutationMap.getResult(i);
1613  auto dim = expr.cast<AffineDimExpr>();
1614  newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1615  }
1616  VectorType resultType = VectorType::get(newShape, type.getElementType());
1617  ExtractMapOp::build(builder, result, resultType, vector, ids);
1618 }
1619 
1620 LogicalResult ExtractMapOp::verify() {
1621  if (getSourceVectorType().getRank() != getResultType().getRank())
1622  return emitOpError("expected source and destination vectors of same rank");
1623  unsigned numId = 0;
1624  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) {
1625  if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) !=
1626  0)
1627  return emitOpError("source vector dimensions must be a multiple of "
1628  "destination vector dimensions");
1629  if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1630  numId++;
1631  }
1632  if (numId != getIds().size())
1633  return emitOpError("expected number of ids must match the number of "
1634  "dimensions distributed");
1635  return success();
1636 }
1637 
1638 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1639  auto insert = getVector().getDefiningOp<vector::InsertMapOp>();
1640  if (insert == nullptr || getType() != insert.getVector().getType() ||
1641  getIds() != insert.getIds())
1642  return {};
1643  return insert.getVector();
1644 }
1645 
1646 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1647  assert(multiplicity.empty());
1648  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1649  if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1650  multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1651  getResultType().getDimSize(i));
1652  }
1653 }
1654 
1655 template <typename MapOp>
1656 static AffineMap calculateImplicitMap(MapOp op) {
1657  SmallVector<AffineExpr, 4> perm;
1658  // Check which dimensions have a multiplicity greater than 1 and associate
1659  // them with the IDs in order.
1660  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
1661  if (op.getSourceVectorType().getDimSize(i) !=
1662  op.getResultType().getDimSize(i))
1663  perm.push_back(getAffineDimExpr(i, op.getContext()));
1664  }
1665  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
1666  op.getContext());
1667  return map;
1668 }
1669 
1670 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1671 
1672 //===----------------------------------------------------------------------===//
1673 // FMAOp
1674 //===----------------------------------------------------------------------===//
1675 
1676 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1677  return llvm::to_vector<4>(getVectorType().getShape());
1678 }
1679 
1680 //===----------------------------------------------------------------------===//
1681 // BroadcastOp
1682 //===----------------------------------------------------------------------===//
1683 
1684 BroadcastableToResult
1685 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1686  std::pair<int, int> *mismatchingDims) {
1687  // Broadcast scalar to vector of the same element type.
1688  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1689  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1690  return BroadcastableToResult::Success;
1691  // From now on, only vectors broadcast.
1692  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1693  if (!srcVectorType)
1694  return BroadcastableToResult::SourceTypeNotAVector;
1695 
1696  int64_t srcRank = srcVectorType.getRank();
1697  int64_t dstRank = dstVectorType.getRank();
1698  if (srcRank > dstRank)
1699  return BroadcastableToResult::SourceRankHigher;
1700  // Source has an exact match or singleton value for all trailing dimensions
1701  // (all leading dimensions are simply duplicated).
1702  int64_t lead = dstRank - srcRank;
1703  for (int64_t r = 0; r < srcRank; ++r) {
1704  int64_t srcDim = srcVectorType.getDimSize(r);
1705  int64_t dstDim = dstVectorType.getDimSize(lead + r);
1706  if (srcDim != 1 && srcDim != dstDim) {
1707  if (mismatchingDims) {
1708  mismatchingDims->first = srcDim;
1709  mismatchingDims->second = dstDim;
1710  }
1711  return BroadcastableToResult::DimensionMismatch;
1712  }
1713  }
1714 
1715  return BroadcastableToResult::Success;
1716 }
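// For instance (editorial note, not part of the upstream source):
// f32 -> vector<4x8xf32> and vector<1x8xf32> -> vector<4x8xf32> both succeed
// (unit dims stretch and leading dims duplicate), vector<8xf32> ->
// vector<4xf32> yields DimensionMismatch, and vector<4x8xf32> ->
// vector<8xf32> yields SourceRankHigher.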
1717 
1718 LogicalResult BroadcastOp::verify() {
1719  std::pair<int, int> mismatchingDims;
1720  BroadcastableToResult res =
1721  isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
1722  if (res == BroadcastableToResult::Success)
1723  return success();
1724  if (res == BroadcastableToResult::SourceRankHigher)
1725  return emitOpError("source rank higher than destination rank");
1726  if (res == BroadcastableToResult::DimensionMismatch)
1727  return emitOpError("dimension mismatch (")
1728  << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1729  if (res == BroadcastableToResult::SourceTypeNotAVector)
1730  return emitOpError("source type is not a vector");
1731  llvm_unreachable("unexpected vector.broadcast op error");
1732 }
1733 
1734 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1735  if (getSourceType() == getVectorType())
1736  return getSource();
1737  if (!operands[0])
1738  return {};
1739  auto vectorType = getVectorType();
1740  if (operands[0].isa<IntegerAttr, FloatAttr>())
1741  return DenseElementsAttr::get(vectorType, operands[0]);
1742  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1743  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
1744  return {};
1745 }
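// For instance (editorial note, not part of the upstream source):
// broadcasting a constant 1.0 : f32 to vector<4xf32> constant-folds to
// dense<1.0> : vector<4xf32>, and broadcasting a splat constant vector
// folds to a splat constant of the result shape.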
1746 
1747 namespace {
1748 
1749 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
1750 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1751  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1752 
1753  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1754  PatternRewriter &rewriter) const override {
1755  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
1756  if (!srcBroadcast)
1757  return failure();
1758  rewriter.replaceOpWithNewOp<BroadcastOp>(
1759  broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
1760  return success();
1761  }
1762 };
1763 } // namespace
1764 
1765 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1766  MLIRContext *context) {
1767  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
1768  // calling `populateCastAwayVectorLeadingOneDimPatterns`
1769  results.add<BroadcastFolder>(context);
1770 }
1771 
1772 //===----------------------------------------------------------------------===//
1773 // ShuffleOp
1774 //===----------------------------------------------------------------------===//
1775 
1776 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1777  Value v2, ArrayRef<int64_t> mask) {
1778  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
1779 }
1780 
1781 LogicalResult ShuffleOp::verify() {
1782  VectorType resultType = getVectorType();
1783  VectorType v1Type = getV1VectorType();
1784  VectorType v2Type = getV2VectorType();
1785  // Verify ranks.
1786  int64_t resRank = resultType.getRank();
1787  int64_t v1Rank = v1Type.getRank();
1788  int64_t v2Rank = v2Type.getRank();
1789  if (resRank != v1Rank || v1Rank != v2Rank)
1790  return emitOpError("rank mismatch");
1791  // Verify all but leading dimension sizes.
1792  for (int64_t r = 1; r < v1Rank; ++r) {
1793  int64_t resDim = resultType.getDimSize(r);
1794  int64_t v1Dim = v1Type.getDimSize(r);
1795  int64_t v2Dim = v2Type.getDimSize(r);
1796  if (resDim != v1Dim || v1Dim != v2Dim)
1797  return emitOpError("dimension mismatch");
1798  }
1799  // Verify mask length.
1800  auto maskAttr = getMask().getValue();
1801  int64_t maskLength = maskAttr.size();
1802  if (maskLength <= 0)
1803  return emitOpError("invalid mask length");
1804  if (maskLength != resultType.getDimSize(0))
1805  return emitOpError("mask length mismatch");
1806  // Verify all indices.
1807  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
1808  for (const auto &en : llvm::enumerate(maskAttr)) {
1809  auto attr = en.value().dyn_cast<IntegerAttr>();
1810  if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1811  return emitOpError("mask index #") << (en.index() + 1) << " out of range";
1812  }
1813  return success();
1814 }
1815 
1816 LogicalResult
1817 ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
1818  ValueRange operands, DictionaryAttr attributes,
1819  RegionRange,
1820  SmallVectorImpl<Type> &inferredReturnTypes) {
1821  ShuffleOp::Adaptor op(operands, attributes);
1822  auto v1Type = op.getV1().getType().cast<VectorType>();
1823  // Construct resulting type: leading dimension matches mask length,
1824  // all trailing dimensions match the operands.
1825  SmallVector<int64_t, 4> shape;
1826  shape.reserve(v1Type.getRank());
1827  shape.push_back(std::max<size_t>(1, op.getMask().size()));
1828  llvm::append_range(shape, v1Type.getShape().drop_front());
1829  inferredReturnTypes.push_back(
1830  VectorType::get(shape, v1Type.getElementType()));
1831  return success();
1832 }
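// For instance (editorial note, not part of the upstream source): shuffling
// two vector<4x2xf32> operands with mask [0, 4, 1] infers the result type
// vector<3x2xf32>.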
1833 
1834 static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
1835  uint64_t expected = begin;
1836  return idxArr.size() == width &&
1837  llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
1838  [&expected](auto attr) {
1839  return attr.getZExtValue() == expected++;
1840  });
1841 }
1842 
1843 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
1844  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
1845  if (!getV1VectorType().isScalable() &&
1846  isStepIndexArray(getMask(), 0, getV1VectorType().getDimSize(0)))
1847  return getV1();
1848  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
1849  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
1850  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
1851  getV2VectorType().getDimSize(0)))
1852  return getV2();
1853 
1854  Attribute lhs = operands.front(), rhs = operands.back();
1855  if (!lhs || !rhs)
1856  return {};
1857 
1858  auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
1859  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
1860  // manipulation.
1861  if (lhsType.getRank() != 1)
1862  return {};
1863  int64_t lhsSize = lhsType.getDimSize(0);
1864 
1865  SmallVector<Attribute> results;
1866  auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
1867  auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
1868  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
1869  int64_t i = index.getZExtValue();
1870  if (i >= lhsSize) {
1871  results.push_back(rhsElements[i - lhsSize]);
1872  } else {
1873  results.push_back(lhsElements[i]);
1874  }
1875  }
1876 
1877  return DenseElementsAttr::get(getVectorType(), results);
1878 }
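// Illustrative constant fold (editorial note, not part of the upstream
// source): with %a = dense<[0, 1]> and %b = dense<[2, 3]>,
//   %0 = vector.shuffle %a, %b [3, 1, 0] : vector<2xi32>, vector<2xi32>
// folds to dense<[3, 1, 0]> : vector<3xi32>.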
1879 
1880 namespace {
1881 
1882 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
1883 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
1884 public:
1885  using OpRewritePattern<ShuffleOp>::OpRewritePattern;
1886 
1887  LogicalResult matchAndRewrite(ShuffleOp op,
1888  PatternRewriter &rewriter) const override {
1889  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
1890  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
1891 
1892  if (!v1Splat || !v2Splat)
1893  return failure();
1894 
1895  if (v1Splat.getInput() != v2Splat.getInput())
1896  return failure();
1897 
1898  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
1899  return success();
1900  }
1901 };
1902 
1903 } // namespace
1904 
1905 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
1906  MLIRContext *context) {
1907  results.add<ShuffleSplat>(context);
1908 }
1909 
1910 //===----------------------------------------------------------------------===//
1911 // InsertElementOp
1912 //===----------------------------------------------------------------------===//
1913 
1914 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1915  Value source, Value dest) {
1916  build(builder, result, source, dest, {});
1917 }
1918 
1919 LogicalResult InsertElementOp::verify() {
1920  auto dstVectorType = getDestVectorType();
1921  if (dstVectorType.getRank() == 0) {
1922  if (getPosition())
1923  return emitOpError("expected position to be empty with 0-D vector");
1924  return success();
1925  }
1926  if (dstVectorType.getRank() != 1)
1927  return emitOpError("unexpected >1 vector rank");
1928  if (!getPosition())
1929  return emitOpError("expected position for 1-D vector");
1930  return success();
1931 }
1932 
1933 OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
1934  // Skip the 0-D vector here.
1935  if (operands.size() < 3)
1936  return {};
1937 
1938  Attribute src = operands[0];
1939  Attribute dst = operands[1];
1940  Attribute pos = operands[2];
1941  if (!src || !dst || !pos)
1942  return {};
1943 
1944  auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();
1945 
1946  SmallVector<Attribute> results(dstElements);
1947 
1948  auto attr = pos.dyn_cast<IntegerAttr>();
1949  uint64_t posIdx = attr.getInt();
1950 
1951  results[posIdx] = src;
1952 
1953  return DenseElementsAttr::get(getDestVectorType(), results);
1954 }
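// For instance (editorial note, not part of the upstream source): inserting
// a constant 9 into dense<[1, 2, 3, 4]> : vector<4xi32> at position 2 folds
// to dense<[1, 2, 9, 4]> : vector<4xi32>.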
1955 
1956 //===----------------------------------------------------------------------===//
1957 // InsertOp
1958 //===----------------------------------------------------------------------===//
1959 
1960 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1961  Value dest, ArrayRef<int64_t> position) {
1962  result.addOperands({source, dest});
1963  auto positionAttr = getVectorSubscriptAttr(builder, position);
1964  result.addTypes(dest.getType());
1965  result.addAttribute(getPositionAttrStrName(), positionAttr);
1966 }
1967 
1968 // Convenience builder which assumes the values are constant indices.
1969 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1970  Value dest, ValueRange position) {
1971  SmallVector<int64_t, 4> positionConstants =
1972  llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1973  return pos.getDefiningOp<arith::ConstantIndexOp>().value();
1974  }));
1975  build(builder, result, source, dest, positionConstants);
1976 }
1977 
1978 LogicalResult InsertOp::verify() {
1979  auto positionAttr = getPosition().getValue();
1980  auto destVectorType = getDestVectorType();
1981  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1982  return emitOpError(
1983  "expected position attribute of rank smaller than dest vector rank");
1984  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
1985  if (srcVectorType &&
1986  (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1987  static_cast<unsigned>(destVectorType.getRank())))
1988  return emitOpError("expected position attribute rank + source rank to "
1989  "match dest vector rank");
1990  if (!srcVectorType &&
1991  (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
1992  return emitOpError(
1993  "expected position attribute rank to match the dest vector rank");
1994  for (const auto &en : llvm::enumerate(positionAttr)) {
1995  auto attr = en.value().dyn_cast<IntegerAttr>();
1996  if (!attr || attr.getInt() < 0 ||
1997  attr.getInt() >= destVectorType.getDimSize(en.index()))
1998  return emitOpError("expected position attribute #")
1999  << (en.index() + 1)
2000  << " to be a non-negative integer smaller than the corresponding "
2001  "dest vector dimension";
2002  }
2003  return success();
2004 }
2005 
2006 namespace {
2007 
2008 // If insertOp is only inserting unit dimensions it can be transformed to a
2009 // broadcast.
2010 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2011 public:
2012  using OpRewritePattern<InsertOp>::OpRewritePattern;
2013 
2014  LogicalResult matchAndRewrite(InsertOp insertOp,
2015  PatternRewriter &rewriter) const override {
2016  auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
2017  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2018  srcVecType.getNumElements())
2019  return failure();
2020  rewriter.replaceOpWithNewOp<BroadcastOp>(
2021  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2022  return success();
2023  }
2024 };
2025 
2026 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2027 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2028 public:
2029  using OpRewritePattern<InsertOp>::OpRewritePattern;
2030 
2031  LogicalResult matchAndRewrite(InsertOp op,
2032  PatternRewriter &rewriter) const override {
2033  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2034  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2035 
2036  if (!srcSplat || !dstSplat)
2037  return failure();
2038 
2039  if (srcSplat.getInput() != dstSplat.getInput())
2040  return failure();
2041 
2042  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2043  return success();
2044  }
2045 };
2046 
2047 } // namespace
2048 
2049 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2050  MLIRContext *context) {
2051  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
2052 }
2053 
2054 // Eliminates insert operations that produce values identical to their source
2055 // value. This happens when the source and destination vectors have identical
2056 // sizes.
2057 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
2058  if (getPosition().empty())
2059  return getSource();
2060  return {};
2061 }
2062 
2063 //===----------------------------------------------------------------------===//
2064 // InsertMapOp
2065 //===----------------------------------------------------------------------===//
2066 
2067 LogicalResult InsertMapOp::verify() {
2068  if (getSourceVectorType().getRank() != getResultType().getRank())
2069  return emitOpError("expected source and destination vectors of same rank");
2070  unsigned numId = 0;
2071  for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) {
2072  if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) !=
2073  0)
2074  return emitOpError(
2075  "destination vector size must be a multiple of source vector size");
2076  if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
2077  numId++;
2078  }
2079  if (numId != getIds().size())
2080  return emitOpError("expected number of ids must match the number of "
2081  "dimensions distributed");
2082  return success();
2083 }
2084 
2085 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
2086 
2087 //===----------------------------------------------------------------------===//
2088 // InsertStridedSliceOp
2089 //===----------------------------------------------------------------------===//
2090 
2091 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2092  Value source, Value dest,
2093  ArrayRef<int64_t> offsets,
2094  ArrayRef<int64_t> strides) {
2095  result.addOperands({source, dest});
2096  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2097  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2098  result.addTypes(dest.getType());
2099  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2100  result.addAttribute(getStridesAttrStrName(), stridesAttr);
2101 }
2102 
2103 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
2104 template <typename OpType>
2105 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2106  ArrayAttr arrayAttr,
2107  ArrayRef<int64_t> shape,
2108  StringRef attrName) {
2109  if (arrayAttr.size() > shape.size())
2110  return op.emitOpError("expected ")
2111  << attrName << " attribute of rank smaller than vector rank";
2112  return success();
2113 }
2114 
2115 // Returns success if all integers in `arrayAttr` are in the interval
2116 // [min, max). If `halfOpen` is false, the admissible interval is instead
2117 // the closed interval [min, max].
2118 template <typename OpType>
2119 static LogicalResult
2120 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2121  int64_t max, StringRef attrName,
2122  bool halfOpen = true) {
2123  for (auto attr : arrayAttr) {
2124  auto val = attr.cast<IntegerAttr>().getInt();
2125  auto upper = max;
2126  if (!halfOpen)
2127  upper += 1;
2128  if (val < min || val >= upper)
2129  return op.emitOpError("expected ") << attrName << " to be confined to ["
2130  << min << ", " << upper << ")";
2131  }
2132  return success();
2133 }
2134 
2135 // Returns success if all integers in `arrayAttr` are in the interval
2136 // [min, shape[i]) for each dimension i. If `halfOpen` is false, the
2137 // admissible interval is the closed interval [min, shape[i]].
2138 template <typename OpType>
2139 static LogicalResult
2140 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2141  ArrayRef<int64_t> shape, StringRef attrName,
2142  bool halfOpen = true, int64_t min = 0) {
2143  assert(arrayAttr.size() <= shape.size());
2144  unsigned index = 0;
2145  for (auto it : llvm::zip(arrayAttr, shape)) {
2146  auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
2147  auto max = std::get<1>(it);
2148  if (!halfOpen)
2149  max += 1;
2150  if (val < min || val >= max)
2151  return op.emitOpError("expected ")
2152  << attrName << " dimension " << index << " to be confined to ["
2153  << min << ", " << max << ")";
2154  ++index;
2155  }
2156  return success();
2157 }
2158 
2159 // Returns success if, for each dimension i, the sum of the corresponding
2160 // entries of `arrayAttr1` and `arrayAttr2` is in the interval [min, shape[i]).
2161 // If `halfOpen` is false, the admissible interval is [min, shape[i]].
2162 template <typename OpType>
2163 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2164  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2165  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2166  bool halfOpen = true, int64_t min = 1) {
2167  assert(arrayAttr1.size() <= shape.size());
2168  assert(arrayAttr2.size() <= shape.size());
2169  unsigned index = 0;
2170  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
2171  auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
2172  auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
2173  auto max = std::get<2>(it);
2174  if (!halfOpen)
2175  max += 1;
2176  if (val1 + val2 < 0 || val1 + val2 >= max)
2177  return op.emitOpError("expected sum(")
2178  << attrName1 << ", " << attrName2 << ") dimension " << index
2179  << " to be confined to [" << min << ", " << max << ")";
2180  ++index;
2181  }
2182  return success();
2183 }
2184 
2185 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2186  MLIRContext *context) {
2187  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2188  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2189  });
2190  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2191 }
2192 
2193 LogicalResult InsertStridedSliceOp::verify() {
2194  auto sourceVectorType = getSourceVectorType();
2195  auto destVectorType = getDestVectorType();
2196  auto offsets = getOffsetsAttr();
2197  auto strides = getStridesAttr();
2198  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2199  return emitOpError(
2200  "expected offsets of same size as destination vector rank");
2201  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2202  return emitOpError("expected strides of same size as source vector rank");
2203  if (sourceVectorType.getRank() > destVectorType.getRank())
2204  return emitOpError(
2205  "expected source rank to be smaller than destination rank");
2206 
2207  auto sourceShape = sourceVectorType.getShape();
2208  auto destShape = destVectorType.getShape();
2209  SmallVector<int64_t, 4> sourceShapeAsDestShape(
2210  destShape.size() - sourceShape.size(), 0);
2211  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2212  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2213  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2214  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2215  offName)) ||
2216  failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2217  stridesName,
2218  /*halfOpen=*/false)) ||
2219  failed(isSumOfIntegerArrayAttrConfinedToShape(
2220  *this, offsets,
2221  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2222  offName, "source vector shape",
2223  /*halfOpen=*/false, /*min=*/1)))
2224  return failure();
2225 
2226  return success();
2227 }
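// A valid example under these rules (editorial note, not part of the
// upstream source): offsets rank equals the dest rank, strides rank equals
// the source rank, and offset + source size fits each dest dimension:
//   vector.insert_strided_slice %src, %dst {offsets = [1, 4, 8], strides = [1, 1]}
//     : vector<2x4xf32> into vector<4x8x16xf32>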
2228 
2229 namespace {
2230 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2231 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
2232 class FoldInsertStridedSliceSplat final
2233  : public OpRewritePattern<InsertStridedSliceOp> {
2234 public:
2235  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2236 
2237  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2238  PatternRewriter &rewriter) const override {
2239  auto srcSplatOp =
2240  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2241  auto destSplatOp =
2242  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2243 
2244  if (!srcSplatOp || !destSplatOp)
2245  return failure();
2246 
2247  if (srcSplatOp.getInput() != destSplatOp.getInput())
2248  return failure();
2249 
2250  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2251  return success();
2252  }
2253 };
2254 
2255 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2256 /// to dst.
2257 class FoldInsertStridedSliceOfExtract final
2258  : public OpRewritePattern<InsertStridedSliceOp> {
2259 public:
2260  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2261 
2262  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2263  PatternRewriter &rewriter) const override {
2264  auto extractStridedSliceOp =
2265  insertStridedSliceOp.getSource()
2266  .getDefiningOp<vector::ExtractStridedSliceOp>();
2267 
2268  if (!extractStridedSliceOp)
2269  return failure();
2270 
2271  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2272  return failure();
2273 
2274  // Check that both ops have the same strides and offsets.
2275  if (extractStridedSliceOp.getStrides() !=
2276  insertStridedSliceOp.getStrides() ||
2277  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2278  return failure();
2279 
2280  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2281  return success();
2282  }
2283 };
2284 
2285 } // namespace
2286 
2287 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
2288  RewritePatternSet &results, MLIRContext *context) {
2289  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
2290  context);
2291 }
2292 
2293 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2294  if (getSourceVectorType() == getDestVectorType())
2295  return getSource();
2296  return {};
2297 }
2298 
2299 //===----------------------------------------------------------------------===//
2300 // OuterProductOp
2301 //===----------------------------------------------------------------------===//
2302 
2303 /// Build an op without mask, use the type of `acc` as the return type.
2304 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
2305  Value lhs, Value rhs, Value acc) {
2306  result.addOperands({lhs, rhs, acc});
2307  result.addTypes(acc.getType());
2308 }
2309 
2310 void OuterProductOp::print(OpAsmPrinter &p) {
2311  p << " " << getLhs() << ", " << getRhs();
2312  if (!getAcc().empty()) {
2313  p << ", " << getAcc();
2314  p.printOptionalAttrDict((*this)->getAttrs());
2315  }
2316  p << " : " << getLhs().getType() << ", " << getRhs().getType();
2317 }
2318 
2319 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
2320  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
2321  Type tLHS, tRHS;
2322  if (parser.parseOperandList(operandsInfo) ||
2323  parser.parseOptionalAttrDict(result.attributes) ||
2324  parser.parseColonType(tLHS) || parser.parseComma() ||
2325  parser.parseType(tRHS))
2326  return failure();
2327  if (operandsInfo.size() < 2)
2328  return parser.emitError(parser.getNameLoc(),
2329  "expected at least 2 operands");
2330  VectorType vLHS = tLHS.dyn_cast<VectorType>();
2331  VectorType vRHS = tRHS.dyn_cast<VectorType>();
2332  if (!vLHS)
2333  return parser.emitError(parser.getNameLoc(),
2334  "expected vector type for operand #1");
2335  VectorType resType =
2336  vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2337  vLHS.getElementType())
2338  : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2339 
2340  if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
2341  result.attributes.append(
2342  OuterProductOp::getKindAttrStrName(),
2343  CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
2344  result.getContext()));
2345  }
2346 
2347  return failure(
2348  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
2349  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
2350  (operandsInfo.size() > 2 &&
2351  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
2352  parser.addTypeToList(resType, result.types));
2353 }
2354 
2355 LogicalResult OuterProductOp::verify() {
2356  Type tRHS = getOperandTypeRHS();
2357  VectorType vLHS = getOperandVectorTypeLHS(),
2358  vRHS = tRHS.dyn_cast<VectorType>(),
2359  vACC = getOperandVectorTypeACC(), vRES = getVectorType();
2360 
2361  if (vLHS.getRank() != 1)
2362  return emitOpError("expected 1-d vector for operand #1");
2363 
2364  if (vRHS) {
2365  // Proper OUTER operation.
2366  if (vRHS.getRank() != 1)
2367  return emitOpError("expected 1-d vector for operand #2");
2368  if (vRES.getRank() != 2)
2369  return emitOpError("expected 2-d vector result");
2370  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2371  return emitOpError("expected #1 operand dim to match result dim #1");
2372  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2373  return emitOpError("expected #2 operand dim to match result dim #2");
2374  } else {
2375  // An AXPY operation.
2376  if (vRES.getRank() != 1)
2377  return emitOpError("expected 1-d vector result");
2378  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2379  return emitOpError("expected #1 operand dim to match result dim #1");
2380  }
2381 
2382  if (vACC && vACC != vRES)
2383  return emitOpError("expected operand #3 of same type as result type");
2384 
2385  // Verify supported combining kind.
2386  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
2387  return emitOpError("unsupported outerproduct type");
2388 
2389  return success();
2390 }
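// The two accepted forms, for reference (editorial note, not part of the
// upstream source):
//   %0 = vector.outerproduct %lhs, %rhs, %acc
//        : vector<4xf32>, vector<8xf32>   // proper outer product -> vector<4x8xf32>
//   %1 = vector.outerproduct %lhs, %scalar, %acc
//        : vector<4xf32>, f32             // AXPY -> vector<4xf32>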
2391 
2392 //===----------------------------------------------------------------------===//
2393 // ReshapeOp
2394 //===----------------------------------------------------------------------===//
2395 
2396 LogicalResult ReshapeOp::verify() {
2397  // Verify that input/output shape rank + the number of fixed vector sizes matches the input/output vector rank.
2398  auto inputVectorType = getInputVectorType();
2399  auto outputVectorType = getOutputVectorType();
2400  int64_t inputShapeRank = getNumInputShapeSizes();
2401  int64_t outputShapeRank = getNumOutputShapeSizes();
2402  SmallVector<int64_t, 4> fixedVectorSizes;
2403  getFixedVectorSizes(fixedVectorSizes);
2404  int64_t numFixedVectorSizes = fixedVectorSizes.size();
2405 
2406  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2407  return emitError("invalid input shape for vector type ") << inputVectorType;
2408 
2409  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2410  return emitError("invalid output shape for vector type ")
2411  << outputVectorType;
2412 
2413  // Verify that the 'fixedVectorSizes' match an input/output vector shape
2414  // suffix.
2415  unsigned inputVectorRank = inputVectorType.getRank();
2416  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2417  unsigned index = inputVectorRank - numFixedVectorSizes - i;
2418  if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2419  return emitError("fixed vector size must match input vector for dim ")
2420  << i;
2421  }
2422 
2423  unsigned outputVectorRank = outputVectorType.getRank();
2424  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2425  unsigned index = outputVectorRank - numFixedVectorSizes - i;
2426  if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2427  return emitError("fixed vector size must match output vector for dim ")
2428  << i;
2429  }
2430 
2431  // If all shape operands are produced by constant ops, verify that product
2432  // of dimensions for input/output shape match.
2433  auto isDefByConstant = [](Value operand) {
2434  return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2435  };
2436  if (llvm::all_of(getInputShape(), isDefByConstant) &&
2437  llvm::all_of(getOutputShape(), isDefByConstant)) {
2438  int64_t numInputElements = 1;
2439  for (auto operand : getInputShape())
2440  numInputElements *=
2441  cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2442  int64_t numOutputElements = 1;
2443  for (auto operand : getOutputShape())
2444  numOutputElements *=
2445  cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2446  if (numInputElements != numOutputElements)
2447  return emitError("product of input and output shape sizes must match");
2448  }
2449  return success();
2450 }
2451 
2452 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
2453  populateFromInt64AttrArray(getFixedVectorSizes(), results);
2454 }
2455 
2456 //===----------------------------------------------------------------------===//
2457 // ExtractStridedSliceOp
2458 //===----------------------------------------------------------------------===//
2459 
2460 // Inference works as follows:
2461 // 1. Add 'sizes' from prefix of dims in 'offsets'.
2462 // 2. Add sizes from 'vectorType' for remaining dims.
2463 static Type inferStridedSliceOpResultType(VectorType vectorType,
2464  ArrayAttr offsets, ArrayAttr sizes,
2465  ArrayAttr strides) {
2466  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2467  SmallVector<int64_t, 4> shape;
2468  shape.reserve(vectorType.getRank());
2469  unsigned idx = 0;
2470  for (unsigned e = offsets.size(); idx < e; ++idx)
2471  shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
2472  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2473  shape.push_back(vectorType.getShape()[idx]);
2474 
2475  return VectorType::get(shape, vectorType.getElementType());
2476 }
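// For instance (editorial note, not part of the upstream source): for
// vector<4x8x16xf32> with offsets = [0, 2], sizes = [2, 4], and
// strides = [1, 1], the inferred result type is vector<2x4x16xf32>.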
2477 
2478 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2479  Value source, ArrayRef<int64_t> offsets,
2480  ArrayRef<int64_t> sizes,
2481  ArrayRef<int64_t> strides) {
2482  result.addOperands(source);
2483  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2484  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
2485  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2486  result.addTypes(
2487  inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
2488  offsetsAttr, sizesAttr, stridesAttr));
2489  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2490  result.addAttribute(getSizesAttrStrName(), sizesAttr);
2491  result.addAttribute(getStridesAttrStrName(), stridesAttr);
2492 }
2493 
2494 LogicalResult ExtractStridedSliceOp::verify() {
2495  auto type = getVectorType();
2496  auto offsets = getOffsetsAttr();
2497  auto sizes = getSizesAttr();
2498  auto strides = getStridesAttr();
2499  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
2500  return emitOpError(
2501  "expected offsets, sizes and strides attributes of same size");
2502 
2503  auto shape = type.getShape();
2504  auto offName = getOffsetsAttrName();
2505  auto sizesName = getSizesAttrName();
2506  auto stridesName = getStridesAttrName();
2507  if (failed(
2508  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
2509  failed(
2510  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
2511  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
2512  stridesName)) ||
2513  failed(
2514  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
2515  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
2516  /*halfOpen=*/false,
2517  /*min=*/1)) ||
2518  failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2519  stridesName,
2520  /*halfOpen=*/false)) ||
2521  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
2522  shape, offName, sizesName,
2523  /*halfOpen=*/false)))
2524  return failure();
2525 
2526  auto resultType =
2527  inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
2528  if (getResult().getType() != resultType)
2529  return emitOpError("expected result type to be ") << resultType;
2530 
2531  return success();
2532 }
2533 
2534 // When the source of an ExtractStridedSliceOp comes from a chain of
2535 // InsertStridedSliceOp ops, try to use the source of one of those inserts
2536 // if we can detect that the extracted vector is a subset of an inserted one.
2537 static LogicalResult
2538 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2539  // Helper to extract integer out of ArrayAttr.
2540  auto getElement = [](ArrayAttr array, int idx) {
2541  return array[idx].cast<IntegerAttr>().getInt();
2542  };
2543  ArrayAttr extractOffsets = op.getOffsets();
2544  ArrayAttr extractStrides = op.getStrides();
2545  ArrayAttr extractSizes = op.getSizes();
2546  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
2547  while (insertOp) {
2548  if (op.getVectorType().getRank() !=
2549  insertOp.getSourceVectorType().getRank())
2550  return failure();
2551  ArrayAttr insertOffsets = insertOp.getOffsets();
2552  ArrayAttr insertStrides = insertOp.getStrides();
2553  // If the rank of the extract is greater than the rank of the insert, we
2554  // are likely extracting a partial chunk of the inserted vector.
2555  if (extractOffsets.size() > insertOffsets.size())
2556  return failure();
2557  bool partialOverlap = false;
2558  bool disjoint = false;
2559  SmallVector<int64_t, 4> offsetDiffs;
2560  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2561  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2562  return failure();
2563  int64_t start = getElement(insertOffsets, dim);
2564  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2565  int64_t offset = getElement(extractOffsets, dim);
2566  int64_t size = getElement(extractSizes, dim);
2567  // Check whether the start of the extract offset lies in the inserted interval.
2568  if (start <= offset && offset < end) {
2569  // If the extract interval overlaps but is not fully included we may
2570  // have a partial overlap that will prevent any folding.
2571  if (offset + size > end)
2572  partialOverlap = true;
2573  offsetDiffs.push_back(offset - start);
2574  continue;
2575  }
2576  disjoint = true;
2577  break;
2578  }
2579  // The extracted chunk is a subset of the inserted vector.
2580  if (!disjoint && !partialOverlap) {
2581  op.setOperand(insertOp.getSource());
2582  // OpBuilder is only used as a helper to build an I64ArrayAttr.
2583  OpBuilder b(op.getContext());
2584  op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
2585  b.getI64ArrayAttr(offsetDiffs));
2586  return success();
2587  }
2588  // If the chunk extracted is disjoint from the chunk inserted, keep looking
2589  // in the insert chain.
2590  if (disjoint)
2591  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2592  else {
2593  // The extracted vector partially overlaps the inserted vector, so we
2594  // cannot fold.
2595  return failure();
2596  }
2597  }
2598  return failure();
2599 }
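// Illustrative fold (editorial note, not part of the upstream source):
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2, 2], strides = [1, 1]}
//        : vector<2x4xf32> into vector<8x16xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [3, 4], sizes = [1, 2], strides = [1, 1]}
//        : vector<8x16xf32> to vector<1x2xf32>
// is rewritten to extract directly from %src with offsets = [1, 2].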
2600 
2601 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2602  if (getVectorType() == getResult().getType())
2603  return getVector();
2604  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2605  return getResult();
2606  return {};
2607 }
2608 
2609 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2610  populateFromInt64AttrArray(getOffsets(), results);
2611 }
2612 
2613 namespace {
2614 
2615 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2616 // ConstantMaskOp.
2617 class StridedSliceConstantMaskFolder final
2618  : public OpRewritePattern<ExtractStridedSliceOp> {
2619 public:
2620  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2621 
2622  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2623  PatternRewriter &rewriter) const override {
2624  // Return if 'extractStridedSliceOp' operand is not defined by a
2625  // ConstantMaskOp.
2626  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
2627  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2628  if (!constantMaskOp)
2629  return failure();
2630  // Return if 'extractStridedSliceOp' has non-unit strides.
2631  if (extractStridedSliceOp.hasNonUnitStrides())
2632  return failure();
2633  // Gather constant mask dimension sizes.
2634  SmallVector<int64_t, 4> maskDimSizes;
2635  populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
2636  // Gather strided slice offsets and sizes.
2637  SmallVector<int64_t, 4> sliceOffsets;
2638  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
2639  sliceOffsets);
2640  SmallVector<int64_t, 4> sliceSizes;
2641  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
2642 
2643  // Compute slice of vector mask region.
2644  SmallVector<int64_t, 4> sliceMaskDimSizes;
2645  assert(sliceOffsets.size() == maskDimSizes.size());
2646  for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2647  int64_t maskDimSize = std::get<0>(it);
2648  int64_t sliceOffset = std::get<1>(it);
2649  int64_t sliceSize = std::get<2>(it);
2650  int64_t sliceMaskDimSize = std::max(
2651  static_cast<int64_t>(0),
2652  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2653  sliceMaskDimSizes.push_back(sliceMaskDimSize);
2654  }
2655  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2656  // region is a conjunction of mask dim intervals).
2657  if (llvm::is_contained(sliceMaskDimSizes, 0))
2658  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2659 
2660  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2661  // region.
2662  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2663  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2664  vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2665  return success();
2666  }
2667 };
2668 
2669 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
2670 class StridedSliceConstantFolder final
2671  : public OpRewritePattern<ExtractStridedSliceOp> {
2672 public:
2673  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2674 
2675  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2676  PatternRewriter &rewriter) const override {
2677  // Return if 'extractStridedSliceOp' operand is not defined by a
2678  // ConstantOp.
2679  auto constantOp =
2680  extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
2681  if (!constantOp)
2682  return failure();
2683  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
2684  if (!dense)
2685  return failure();
2686  auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2687  dense.getSplatValue<Attribute>());
2688  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
2689  newAttr);
2690  return success();
2691  }
2692 };
2693 
2694 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2695 // BroadcastOp(ExtractStridedSliceOp).
2696 class StridedSliceBroadcast final
2697  : public OpRewritePattern<ExtractStridedSliceOp> {
2698 public:
2699  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2700 
2701  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2702  PatternRewriter &rewriter) const override {
2703  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
2704  if (!broadcast)
2705  return failure();
2706  auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
2707  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
2708  auto dstVecType = op.getType().cast<VectorType>();
2709  unsigned dstRank = dstVecType.getRank();
2710  unsigned rankDiff = dstRank - srcRank;
2711  // Check if the innermost dimensions of the broadcast source match the
2712  // corresponding dimensions of the extract destination. If so, we can just
2713  // use a broadcast, as the original dimensions are untouched.
2714  bool lowerDimMatch = true;
2715  for (unsigned i = 0; i < srcRank; i++) {
2716  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2717  lowerDimMatch = false;
2718  break;
2719  }
2720  }
2721  Value source = broadcast.getSource();
2722  // If the inner dimensions don't match, it means we need to extract from the
2723  // source of the original broadcast and then broadcast the extracted value.
2724  // We also need to handle degenerate cases where the source is effectively
2725  // just a single scalar.
2726  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
2727  if (!lowerDimMatch && !isScalarSrc) {
2728  source = rewriter.create<ExtractStridedSliceOp>(
2729  op->getLoc(), source,
2730  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
2731  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
2732  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
2733  }
2734  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2735  return success();
2736  }
2737 };
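// Illustrative example of this pattern (editorial note, not part of the
// upstream source): with a scalar broadcast source,
//   %0 = vector.broadcast %s : f32 to vector<4x8xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
// is rewritten to
//   %1 = vector.broadcast %s : f32 to vector<2x2xf32>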
2738 
2739 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
2740 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
2741 public:
2742  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2743 
2744  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2745  PatternRewriter &rewriter) const override {
2746  auto splat = op.getVector().getDefiningOp<SplatOp>();
2747  if (!splat)
2748  return failure();
2749  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
2750  return success();
2751  }
2752 };
2753 
2754 } // namespace
2755 
2756 void ExtractStridedSliceOp::getCanonicalizationPatterns(
2757  RewritePatternSet &results, MLIRContext *context) {
2758  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
2759  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2760  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
2761  StridedSliceBroadcast, StridedSliceSplat>(context);
2762 }
2763 
2764 //===----------------------------------------------------------------------===//
2765 // TransferReadOp
2766 //===----------------------------------------------------------------------===//
2767 
2768 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
2769 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2770  VectorType vectorType, Value source,
2771  ValueRange indices, AffineMapAttr permutationMapAttr,
2772  /*optional*/ ArrayAttr inBoundsAttr) {
2773  Type elemType = source.getType().cast<ShapedType>().getElementType();
2774  Value padding = builder.create<arith::ConstantOp>(
2775  result.location, elemType, builder.getZeroAttr(elemType));
2776  build(builder, result, vectorType, source, indices, permutationMapAttr,
2777  padding, /*mask=*/Value(), inBoundsAttr);
2778 }
2779 
2780 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
2781 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2782  VectorType vectorType, Value source,
2783  ValueRange indices, AffineMap permutationMap,
2784  Optional<ArrayRef<bool>> inBounds) {
2785  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2786  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
2787  ? builder.getBoolArrayAttr(inBounds.value())
2788  : ArrayAttr();
2789  build(builder, result, vectorType, source, indices, permutationMapAttr,
2790  inBoundsAttr);
2791 }
2792 
2793 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
2794 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2795  VectorType vectorType, Value source,
2796  ValueRange indices, Value padding,
2797  Optional<ArrayRef<bool>> inBounds) {
2798  AffineMap permutationMap = getTransferMinorIdentityMap(
2799  source.getType().cast<ShapedType>(), vectorType);
2800  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2801  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
2802  ? builder.getBoolArrayAttr(inBounds.value())
2803  : ArrayAttr();
2804  build(builder, result, vectorType, source, indices, permutationMapAttr,
2805  padding,
2806  /*mask=*/Value(), inBoundsAttr);
2807 }
2808 
2809 /// 4. Builder that sets padding to zero and permutation map to
2810 /// 'getMinorIdentityMap'.
2811 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2812  VectorType vectorType, Value source,
2813  ValueRange indices,
2814  Optional<ArrayRef<bool>> inBounds) {
2815  Type elemType = source.getType().cast<ShapedType>().getElementType();
2816  Value padding = builder.create<arith::ConstantOp>(
2817  result.location, elemType, builder.getZeroAttr(elemType));
2818  build(builder, result, vectorType, source, indices, padding, inBounds);
2819 }
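// Minimal usage sketch of builder 4 (editorial note, not part of the
// upstream source; `builder`, `loc`, `memref`, `i`, and `j` are assumed to
// exist in the caller):
//   auto vecTy = VectorType::get({4, 8}, builder.getF32Type());
//   Value v = builder.create<vector::TransferReadOp>(
//       loc, vecTy, memref, ValueRange{i, j}, /*inBounds=*/llvm::None);
// This creates a read with zero padding and a minor identity permutation map.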
2820 
2821 template <typename EmitFun>
2822 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2823  EmitFun emitOpError) {
2824  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
2825  for (auto expr : permutationMap.getResults()) {
2826  auto dim = expr.dyn_cast<AffineDimExpr>();
2827  auto zero = expr.dyn_cast<AffineConstantExpr>();
2828  if (zero) {
2829  if (zero.getValue() != 0) {
2830  return emitOpError(
2831  "requires a projected permutation_map (at most one dim or the zero "
2832  "constant can appear in each result)");
2833  }
2834  continue;
2835  }
2836  if (!dim) {
2837  return emitOpError("requires a projected permutation_map (at most one "
2838  "dim or the zero constant can appear in each result)");
2839  }
2840  if (seen[dim.getPosition()]) {
2841  return emitOpError(
2842  "requires a permutation_map that is a permutation (found one dim "
2843  "used more than once)");
2844  }
2845  seen[dim.getPosition()] = true;
2846  }
2847  return success();
2848 }
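// For instance (editorial note, not part of the upstream source):
// (d0, d1, d2) -> (d2, 0) is accepted as a projected permutation with a
// broadcast dimension, while (d0, d1) -> (d0 + d1) and (d0, d1) -> (d0, d0)
// are both rejected by the checks above.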
2849 
2850 static LogicalResult
2851 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2852  VectorType vectorType, VectorType maskType,
2853  AffineMap permutationMap, ArrayAttr inBounds) {
2854  if (op->hasAttr("masked")) {
2855  return op->emitOpError("masked attribute has been removed. "
2856  "Use in_bounds instead.");
2857  }
2858 
2859  if (!shapedType.isa<MemRefType, RankedTensorType>())
2860  return op->emitOpError(
2861  "requires source to be a memref or ranked tensor type");
2862 
2863  auto elementType = shapedType.getElementType();
2864  DataLayout dataLayout = DataLayout::closest(op);
2865  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
2866  // Memref or tensor has vector element type.
2867  unsigned sourceVecSize =
2868  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
2869  vectorElementType.getShape().back();
2870  unsigned resultVecSize =
2871  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
2872  vectorType.getShape().back();
2873  if (resultVecSize % sourceVecSize != 0)
2874  return op->emitOpError(
2875  "requires the bitwidth of the minor 1-D vector to be an integral "
2876  "multiple of the bitwidth of the minor 1-D vector of the source");
2877 
2878  unsigned sourceVecEltRank = vectorElementType.getRank();
2879  unsigned resultVecRank = vectorType.getRank();
2880  if (sourceVecEltRank > resultVecRank)
2881  return op->emitOpError(
2882  "requires source vector element and vector result ranks to match.");
2883  unsigned rankOffset = resultVecRank - sourceVecEltRank;
2884  // Check that permutation map results match 'rankOffset' of vector type.
2885  if (permutationMap.getNumResults() != rankOffset)
2886  return op->emitOpError("requires a permutation_map with result dims of "
2887  "the same rank as the vector type");
2888 
2889  if (maskType)
2890  return op->emitOpError("does not support masks with vector element type");
2891  } else {
2892  // Memref or tensor has scalar element type.
2893  unsigned minorSize =
2894  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
2895  unsigned resultVecSize =
2896  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
2897  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
2898  return op->emitOpError(
2899  "requires the bitwidth of the minor 1-D vector to be an integral "
2900  "multiple of the bitwidth of the source element type");
2901 
2902  // Check that permutation map results match rank of vector type.
2903  if (permutationMap.getNumResults() != vectorType.getRank())
2904  return op->emitOpError("requires a permutation_map with result dims of "
2905  "the same rank as the vector type");
2906 
2907  VectorType expectedMaskType =
2908  vector::detail::transferMaskType(vectorType, permutationMap);
2909  if (maskType && expectedMaskType != maskType)
2910  return op->emitOpError("expects mask type consistent with permutation "
2911  "map: ")
2912  << maskType;
2913  }
2914 
2915  if (permutationMap.getNumSymbols() != 0)
2916  return op->emitOpError("requires permutation_map without symbols");
2917 
2918  if (permutationMap.getNumInputs() != shapedType.getRank())
2919  return op->emitOpError("requires a permutation_map with input dims of the "
2920  "same rank as the source type");
2921 
2922  if (inBounds) {
2923  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
2924  return op->emitOpError("expects the optional in_bounds attr of same rank "
2925  "as permutation_map results: ")
2926  << AffineMapAttr::get(permutationMap)
2927  << " vs inBounds of size: " << inBounds.size();
2928  for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
2929  if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
2930  !inBounds.getValue()[i].cast<BoolAttr>().getValue())
2931  return op->emitOpError("requires broadcast dimensions to be in-bounds");
2932  }
2933 
2934  return success();
2935 }
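// A worked instance of the scalar-element bitwidth rule above (illustrative
// types): reading vector<4xi8> (4 * 8 = 32 bits) from memref<?xi32> passes
// since 32 % 32 == 0, whereas vector<1xi8> (8 bits) would be rejected because
// 8 % 32 != 0.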
2936 
2937 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
2938  SmallVector<StringRef, 3> elidedAttrs;
2939  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
2940  if (op.permutation_map().isMinorIdentity())
2941  elidedAttrs.push_back(op.getPermutationMapAttrStrName());
2942  bool elideInBounds = true;
2943  if (auto inBounds = op.in_bounds()) {
2944  for (auto attr : *inBounds) {
2945  if (attr.template cast<BoolAttr>().getValue()) {
2946  elideInBounds = false;
2947  break;
2948  }
2949  }
2950  }
2951  if (elideInBounds)
2952  elidedAttrs.push_back(op.getInBoundsAttrStrName());
2953  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2954 }
2955 
2956 void TransferReadOp::print(OpAsmPrinter &p) {
2957  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
2958  if (getMask())
2959  p << ", " << getMask();
2960  printTransferAttrs(p, *this);
2961  p << " : " << getShapedType() << ", " << getVectorType();
2962 }
2963 
2964 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
2965  auto &builder = parser.getBuilder();
2966  SMLoc typesLoc;
2967  OpAsmParser::UnresolvedOperand sourceInfo;
2968  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
2969  OpAsmParser::UnresolvedOperand paddingInfo;
2970  SmallVector<Type, 2> types;
2971  OpAsmParser::UnresolvedOperand maskInfo;
2972  // Parsing with support for paddingValue.
2973  if (parser.parseOperand(sourceInfo) ||
2974  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2975  parser.parseComma() || parser.parseOperand(paddingInfo))
2976  return failure();
2977  ParseResult hasMask = parser.parseOptionalComma();
2978  if (hasMask.succeeded()) {
2979  if (parser.parseOperand(maskInfo))
2980  return failure();
2981  }
2982  if (parser.parseOptionalAttrDict(result.attributes) ||
2983  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2984  return failure();
2985  if (types.size() != 2)
2986  return parser.emitError(typesLoc, "requires two types");
2987  auto indexType = builder.getIndexType();
2988  auto shapedType = types[0].dyn_cast<ShapedType>();
2989  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
2990  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
2991  VectorType vectorType = types[1].dyn_cast<VectorType>();
2992  if (!vectorType)
2993  return parser.emitError(typesLoc, "requires vector type");
2994  auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
2995  Attribute mapAttr = result.attributes.get(permutationAttrName);
2996  if (!mapAttr) {
2997  auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
2998  // Update `mapAttr` that is used later to determine mask type.
2999  mapAttr = AffineMapAttr::get(permMap);
3000  result.attributes.set(permutationAttrName, mapAttr);
3001  }
3002  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3003  parser.resolveOperands(indexInfo, indexType, result.operands) ||
3004  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
3005  result.operands))
3006  return failure();
3007  if (hasMask.succeeded()) {
3008  if (shapedType.getElementType().dyn_cast<VectorType>())
3009  return parser.emitError(
3010  maskInfo.location, "does not support masks with vector element type");
3011  auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
3012  // Instead of adding the mask type as an op type, compute it based on the
3013  // vector type and the permutation map (to keep the type signature small).
3014  auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
3015  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3016  return failure();
3017  }
3018  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3019  builder.getDenseI32ArrayAttr(
3020  {1, static_cast<int32_t>(indexInfo.size()), 1,
3021  static_cast<int32_t>(hasMask.succeeded())}));
3022  return parser.addTypeToList(vectorType, result.types);
3023 }
3024 
3025 LogicalResult TransferReadOp::verify() {
3026  // Consistency of elemental types in source and vector.
3027  ShapedType shapedType = getShapedType();
3028  VectorType vectorType = getVectorType();
3029  VectorType maskType = getMaskType();
3030  auto paddingType = getPadding().getType();
3031  auto permutationMap = getPermutationMap();
3032  auto sourceElementType = shapedType.getElementType();
3033 
3034  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3035  return emitOpError("requires ") << shapedType.getRank() << " indices";
3036 
3037  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3038  shapedType, vectorType, maskType, permutationMap,
3039  getInBounds() ? *getInBounds() : ArrayAttr())))
3040  return failure();
3041 
3042  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
3043  // Source has vector element type.
3044  // Check that 'sourceVectorElementType' and 'paddingType' types match.
3045  if (sourceVectorElementType != paddingType)
3046  return emitOpError(
3047  "requires source element type and padding type to match.");
3048 
3049  } else {
3050  // Check that 'paddingType' is valid to store in a vector type.
3051  if (!VectorType::isValidElementType(paddingType))
3052  return emitOpError("requires valid padding vector elemental type");
3053 
3054  // Check that padding type and vector element types match.
3055  if (paddingType != sourceElementType)
3056  return emitOpError(
3057  "requires formal padding and source of the same elemental type");
3058  }
3059 
3060  return verifyPermutationMap(permutationMap,
3061  [&](Twine t) { return emitOpError(t); });
3062 }
3063 
3064 /// This is a common helper used for patterns of the form
3065 /// ```
3066 /// someop(memrefcast) -> someop
3067 /// ```
3068 /// It folds the source of the memref.cast into the root operation directly.
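/// For instance (an illustrative sketch; it assumes the cast merely erases
/// static shape information, which memref::CastOp::canFoldIntoConsumerOp
/// checks):
/// ```
/// %0 = memref.cast %A : memref<4x8xf32> to memref<?x?xf32>
/// %1 = vector.load %0[%c0, %c0] : memref<?x?xf32>, vector<8xf32>
/// ```
/// folds to
/// ```
/// %1 = vector.load %A[%c0, %c0] : memref<4x8xf32>, vector<8xf32>
/// ```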
3069 static LogicalResult foldMemRefCast(Operation *op) {
3070  bool folded = false;
3071  for (OpOperand &operand : op->getOpOperands()) {
3072  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
3073  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
3074  operand.set(castOp.getOperand());
3075  folded = true;
3076  }
3077  }
3078  return success(folded);
3079 }
3080 
3081 static LogicalResult foldTensorCast(Operation *op) {
3082  bool folded = false;
3083  for (OpOperand &operand : op->getOpOperands()) {
3084  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
3085  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
3086  operand.set(castOp.getOperand());
3087  folded = true;
3088  }
3089  }
3090  return success(folded);
3091 }
3092 
3093 template <typename TransferOp>
3094 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3095  // TODO: support more aggressive createOrFold on:
3096  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
3097  if (op.getShapedType().isDynamicDim(indicesIdx))
3098  return false;
3099  Value index = op.getIndices()[indicesIdx];
3100  auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
3101  if (!cstOp)
3102  return false;
3103 
3104  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3105  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3106 
3107  return cstOp.value() + vectorSize <= sourceSize;
3108 }
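// Illustrative arithmetic for the check above: a transfer of vector<4xf32>
// from memref<8xf32> at constant index 2 is provably in-bounds (2 + 4 <= 8),
// while the same transfer at index 5 is not (5 + 4 > 8) and the dimension
// stays conservatively out-of-bounds.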
3109 
3110 template <typename TransferOp>
3111 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
3112  // TODO: support 0-d corner case.
3113  // TODO: Be less conservative.
3114  if (op.getTransferRank() == 0)
3115  return failure();
3116  AffineMap permutationMap = op.getPermutationMap();
3117  bool changed = false;
3118  SmallVector<bool, 4> newInBounds;
3119  newInBounds.reserve(op.getTransferRank());
3120  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
3121  // Already marked as in-bounds, nothing to see here.
3122  if (op.isDimInBounds(i)) {
3123  newInBounds.push_back(true);
3124  continue;
3125  }
3126  // Currently out-of-bounds, check whether we can statically determine it is
3127  // inBounds.
3128  auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
3129  assert(dimExpr && "Broadcast dims must be in-bounds");
3130  auto inBounds =
3131  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
3132  newInBounds.push_back(inBounds);
3133  // We commit the pattern if it is "more inbounds".
3134  changed |= inBounds;
3135  }
3136  if (!changed)
3137  return failure();
3138  // OpBuilder is only used as a helper to build a BoolArrayAttr.
3139  OpBuilder b(op.getContext());
3140  op->setAttr(TransferOp::getInBoundsAttrStrName(),
3141  b.getBoolArrayAttr(newInBounds));
3142  return success();
3143 }
3144 
3145 /// ```
3146 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3147 /// : vector<1x4xf32>, tensor<4x4xf32>
3148 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
3149 /// : tensor<4x4xf32>, vector<1x4xf32>
3150 /// ```
3151 /// -> Folds into
3152 /// ```
3153 /// %v0
3154 /// ```
3155 static Value foldRAW(TransferReadOp readOp) {
3156  if (!readOp.getShapedType().isa<RankedTensorType>())
3157  return {};
3158  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3159  while (defWrite) {
3160  if (checkSameValueRAW(defWrite, readOp))
3161  return defWrite.getVector();
3162  if (!isDisjointTransferIndices(
3163  cast<VectorTransferOpInterface>(defWrite.getOperation()),
3164  cast<VectorTransferOpInterface>(readOp.getOperation())))
3165  break;
3166  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3167  }
3168  return {};
3169 }
3170 
3171 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
3172  if (Value vec = foldRAW(*this))
3173  return vec;
3174  /// transfer_read(memrefcast) -> transfer_read
3175  if (succeeded(foldTransferInBoundsAttribute(*this)))
3176  return getResult();
3177  if (succeeded(foldMemRefCast(*this)))
3178  return getResult();
3179  if (succeeded(foldTensorCast(*this)))
3180  return getResult();
3181  return OpFoldResult();
3182 }
3183 
3184 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
3185  return llvm::to_vector<4>(getVectorType().getShape());
3186 }
3187 
3188 void TransferReadOp::getEffects(
3189  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3190  &effects) {
3191  if (getShapedType().isa<MemRefType>())
3192  effects.emplace_back(MemoryEffects::Read::get(), getSource(),
3193  SideEffects::DefaultResource::get());
3194 }
3195 
3196 namespace {
3197 /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
3198 ///
3199 /// ```
3200 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
3201 /// : tensor<?x?xf32> to tensor<?x?xf32>
3202 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
3203 /// : tensor<?x?xf32>, vector<4x5xf32>
3204 /// ```
3205 /// is rewritten to:
3206 /// ```
3207 /// %p0 = arith.addi %a, %e : index
3208 /// %p1 = arith.addi %b, %f : index
3209 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
3210 /// : tensor<?x?xf32>, vector<4x5xf32>
3211 /// ```
3212 struct FoldExtractSliceIntoTransferRead
3213  : public OpRewritePattern<TransferReadOp> {
3214 public:
3215  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
3216 
3217  LogicalResult matchAndRewrite(TransferReadOp xferOp,
3218  PatternRewriter &rewriter) const override {
3219  // TODO: support 0-d corner case.
3220  if (xferOp.getTransferRank() == 0)
3221  return failure();
3222  if (xferOp.hasOutOfBoundsDim())
3223  return failure();
3224  if (!xferOp.getPermutationMap().isIdentity())
3225  return failure();
3226  if (xferOp.getMask())
3227  return failure();
3228  auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
3229  if (!extractOp)
3230  return failure();
3231  if (!extractOp.hasUnitStride())
3232  return failure();
3233 
3234  // Bail on illegal rank-reduction: we need to check that the rank-reduced
3235  // dims are exactly the leading dims. I.e. the following is illegal:
3236  // ```
3237  // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
3238  // tensor<2x1x4xf32> to tensor<2x4xf32>
3239  // %1 = vector.transfer_read %0[0,0], %cst :
3240  // tensor<2x4xf32>, vector<2x4xf32>
3241  // ```
3242  //
3243  // Cannot fold into:
3244  // ```
3245  // %0 = vector.transfer_read %t[0,0,0], %cst :
3246  // tensor<2x1x4xf32>, vector<2x4xf32>
3247  // ```
3248  // For this, check the trailing `vectorRank` dims of the extract_slice
3249  // result tensor match the trailing dims of the inferred result tensor.
3250  int64_t rankReduced =
3251  extractOp.getSourceType().getRank() - extractOp.getType().getRank();
3252  int64_t vectorRank = xferOp.getVectorType().getRank();
3253  RankedTensorType inferredDestTensorType =
3254  tensor::ExtractSliceOp::inferResultType(
3255  extractOp.getSourceType(), extractOp.getMixedOffsets(),
3256  extractOp.getMixedSizes(), extractOp.getMixedStrides());
3257  auto actualDestTensorShape = extractOp.getType().getShape();
3258  if (rankReduced > 0 &&
3259  actualDestTensorShape.take_back(vectorRank) !=
3260  inferredDestTensorType.getShape().take_back(vectorRank))
3261  return failure();
3262 
3263  SmallVector<Value> newIndices;
3264  // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
3265  // indices first.
3266  for (int64_t i = 0; i < rankReduced; ++i) {
3267  OpFoldResult offset = extractOp.getMixedOffsets()[i];
3268  newIndices.push_back(getValueOrCreateConstantIndexOp(
3269  rewriter, extractOp.getLoc(), offset));
3270  }
3271  for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
3272  OpFoldResult offset =
3273  extractOp.getMixedOffsets()[it.index() + rankReduced];
3274  newIndices.push_back(rewriter.create<arith::AddIOp>(
3275  xferOp->getLoc(), it.value(),
3276  getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
3277  offset)));
3278  }
3279  SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3280  rewriter.replaceOpWithNewOp<TransferReadOp>(
3281  xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
3282  xferOp.getPadding(), ArrayRef<bool>{inBounds});
3283 
3284  return success();
3285  }
3286 };
3287 
3288 /// Store to load forwarding for transfer operations with permutation maps.
3289 /// Even if the permutation maps are different we can still propagate the store
3290 /// into the load if the size of the dimensions read and written match. Then we
3291 /// can replace the transfer_read + transfer_write by vector.broadcast and
3292 /// vector.transpose.
3293 /// Example:
3294 /// ```
3295 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
3296 /// {in_bounds = [true, true],
3297 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
3298 /// vector<4x1xf32>, tensor<4x4x4xf32>
3299 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
3300 /// {in_bounds = [true, true, true, true],
3301 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
3302 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
3303 /// ```
3304 /// To:
3305 /// ```
3306 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
3307 /// %r = vector.transpose %0, [3, 0, 2, 1] :
3308 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
3309 /// ```
3310 struct TransferReadAfterWriteToBroadcast
3311  : public OpRewritePattern<TransferReadOp> {
3312  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
3313 
3314  LogicalResult matchAndRewrite(TransferReadOp readOp,
3315  PatternRewriter &rewriter) const override {
3316  if (readOp.hasOutOfBoundsDim() ||
3317  !readOp.getShapedType().isa<RankedTensorType>())
3318  return failure();
3319  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3320  if (!defWrite)
3321  return failure();
3322 
3323  SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
3324  Value vec;
3325  if (readOp.getIndices() == defWrite.getIndices() &&
3326  readOp.getMask() == defWrite.getMask()) {
3327  SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
3328  // TODO: If the writeDim is a superset of the read dims we could do an
3329  // extract_strided_slice.
3330  if (writeDims == readDims)
3331  vec = defWrite.getVector();
3332  }
3333  // TODO: loop through the chain of transfer_write if we can prove that they
3334  // don't overlap with the transfer_read. This requires improving
3335  // `isDisjointTransferIndices` helper.
3336  if (!vec)
3337  return failure();
3338  SmallVector<unsigned> permutation;
3339  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
3340  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
3341  AffineMap map = readMap.compose(writeMap);
3342  if (map.getNumResults() == 0)
3343  return failure();
3344  // Calculate the permutation to apply to go from the vector stored to the
3345  // vector read.
3346  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
3347  return failure();
3348 
3349  Location loc = readOp.getLoc();
3350  // Calculate the broadcast shape by applying the reverse permutation to the
3351  // final shape we want.
3352  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
3353  SmallVector<int64_t> broadcastShape(destShape.size());
3354  for (const auto &pos : llvm::enumerate(permutation))
3355  broadcastShape[pos.value()] = destShape[pos.index()];
3356  VectorType broadcastedType = VectorType::get(
3357  broadcastShape, defWrite.getVectorType().getElementType());
3358  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
3359  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
3360  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
3361  transposePerm);
3362  return success();
3363  }
3364 };
3365 } // namespace
3366 
3367 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3368  MLIRContext *context) {
3369  results
3370  .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
3371  context);
3372 }
3373 
3374 //===----------------------------------------------------------------------===//
3375 // TransferWriteOp
3376 //===----------------------------------------------------------------------===//
3377 
3378 /// 1. Builder with type inference.
3379 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3380  Value vector, Value dest, ValueRange indices,
3381  AffineMapAttr permutationMapAttr,
3382  /*optional*/ Value mask,
3383  /*optional*/ ArrayAttr inBoundsAttr) {
3384  Type resultType = dest.getType().dyn_cast<RankedTensorType>();
3385  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3386  mask, inBoundsAttr);
3387 }
3388 
3389 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
3390 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3391  Value vector, Value dest, ValueRange indices,
3392  AffineMapAttr permutationMapAttr,
3393  /*optional*/ ArrayAttr inBoundsAttr) {
3394  build(builder, result, vector, dest, indices, permutationMapAttr,
3395  /*mask=*/Value(), inBoundsAttr);
3396 }
3397 
3398 /// 3. Builder with type inference that sets an empty mask (variant without
3399 /// attrs).
3400 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3401  Value vector, Value dest, ValueRange indices,
3402  AffineMap permutationMap,
3403  Optional<ArrayRef<bool>> inBounds) {
3404  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3405  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3406  ? builder.getBoolArrayAttr(inBounds.value())
3407  : ArrayAttr();
3408  build(builder, result, vector, dest, indices, permutationMapAttr,
3409  /*mask=*/Value(), inBoundsAttr);
3410 }
3411 
3412 /// 4. Builder with type inference that sets an empty mask and sets permutation
3413 /// map to 'getMinorIdentityMap'.
3414 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3415  Value vector, Value dest, ValueRange indices,
3416  Optional<ArrayRef<bool>> inBounds) {
3417  auto vectorType = vector.getType().cast<VectorType>();
3418  AffineMap permutationMap = getTransferMinorIdentityMap(
3419  dest.getType().cast<ShapedType>(), vectorType);
3420  build(builder, result, vector, dest, indices, permutationMap, inBounds);
3421 }
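// A minimal usage sketch of builder 4 above (illustrative only; assumes
// `builder`, `loc`, `vec`, `dest`, and `indices` are in scope):
//   auto write = builder.create<vector::TransferWriteOp>(
//       loc, vec, dest, indices, /*inBounds=*/llvm::None);
// The permutation map is synthesized as the minor identity map of the
// destination's shaped type.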
3422 
3423 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3424  OperationState &result) {
3425  auto &builder = parser.getBuilder();
3426  SMLoc typesLoc;
3427  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
3428  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3429  SmallVector<Type, 2> types;
3430  OpAsmParser::UnresolvedOperand maskInfo;
3431  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
3432  parser.parseOperand(sourceInfo) ||
3433  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
3434  return failure();
3435  ParseResult hasMask = parser.parseOptionalComma();
3436  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
3437  return failure();
3438  if (parser.parseOptionalAttrDict(result.attributes) ||
3439  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3440  return failure();
3441  if (types.size() != 2)
3442  return parser.emitError(typesLoc, "requires two types");
3443  auto indexType = builder.getIndexType();
3444  VectorType vectorType = types[0].dyn_cast<VectorType>();
3445  if (!vectorType)
3446  return parser.emitError(typesLoc, "requires vector type");
3447  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
3448  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3449  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3450  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3451  auto attr = result.attributes.get(permutationAttrName);
3452  if (!attr) {
3453  auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3454  result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
3455  }
3456  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
3457  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3458  parser.resolveOperands(indexInfo, indexType, result.operands))
3459  return failure();
3460  if (hasMask.succeeded()) {
3461  if (shapedType.getElementType().dyn_cast<VectorType>())
3462  return parser.emitError(
3463  maskInfo.location, "does not support masks with vector element type");
3464  auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
3465  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3466  return failure();
3467  }
3468  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
3469  builder.getDenseI32ArrayAttr(
3470  {1, 1, static_cast<int32_t>(indexInfo.size()),
3471  static_cast<int32_t>(hasMask.succeeded())}));
3472  return failure(shapedType.isa<RankedTensorType>() &&
3473  parser.addTypeToList(shapedType, result.types));
3474 }
3475 
3476 void TransferWriteOp::print(OpAsmPrinter &p) {
3477  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
3478  if (getMask())
3479  p << ", " << getMask();
3480  printTransferAttrs(p, *this);
3481  p << " : " << getVectorType() << ", " << getShapedType();
3482 }
3483 
3484 LogicalResult TransferWriteOp::verify() {
3485  // Consistency of elemental types in shape and vector.
3486  ShapedType shapedType = getShapedType();
3487  VectorType vectorType = getVectorType();
3488  VectorType maskType = getMaskType();
3489  auto permutationMap = getPermutationMap();
3490 
3491  if (llvm::size(getIndices()) != shapedType.getRank())
3492  return emitOpError("requires ") << shapedType.getRank() << " indices";
3493 
3494  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
3495  // as the semantics is unclear. This can be revisited later if necessary.
3496  if (hasBroadcastDim())
3497  return emitOpError("should not have broadcast dimensions");
3498 
3499  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3500  shapedType, vectorType, maskType, permutationMap,
3501  getInBounds() ? *getInBounds() : ArrayAttr())))
3502  return failure();
3503 
3504  return verifyPermutationMap(permutationMap,
3505  [&](Twine t) { return emitOpError(t); });
3506 }
3507 
3508 /// Fold:
3509 /// ```
3510 /// %t1 = ...
3511 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
3512 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3513 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
3514 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3515 /// ```
3516 ///
3517 /// into:
3518 ///
3519 /// ```
3520 /// %t0
3521 /// ```
3522 ///
3523 /// The producer of t1 may or may not be DCE'd depending on whether it is a
3524 /// block argument or has side effects.
3525 static LogicalResult foldReadInitWrite(TransferWriteOp write,
3526  ArrayRef<Attribute>,
3527  SmallVectorImpl<OpFoldResult> &results) {
3528  // TODO: support 0-d corner case.
3529  if (write.getTransferRank() == 0)
3530  return failure();
3531  auto rankedTensorType =
3532  write.getSource().getType().dyn_cast<RankedTensorType>();
3533  // If not operating on tensors, bail.
3534  if (!rankedTensorType)
3535  return failure();
3536  // If no read, bail.
3537  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
3538  if (!read)
3539  return failure();
3540  // TODO: support 0-d corner case.
3541  if (read.getTransferRank() == 0)
3542  return failure();
3543  // For now, only accept minor identity. Future: composition is minor identity.
3544  if (!read.getPermutationMap().isMinorIdentity() ||
3545  !write.getPermutationMap().isMinorIdentity())
3546  return failure();
3547  // Bail on mismatching ranks.
3548  if (read.getTransferRank() != write.getTransferRank())
3549  return failure();
3550  // Bail on potential out-of-bounds accesses.
3551  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
3552  return failure();
3553  // Tensor types must be the same.
3554  if (read.getSource().getType() != rankedTensorType)
3555  return failure();
3556  // Vector types must be the same.
3557  if (read.getVectorType() != write.getVectorType())
3558  return failure();
3559  // Vector and Tensor shapes must match.
3560  if (read.getVectorType().getShape() != rankedTensorType.getShape())
3561  return failure();
3562  // If any index is nonzero.
3563  auto isNotConstantZero = [](Value v) {
3564  auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3565  return !cstOp || cstOp.value() != 0;
3566  };
3567  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
3568  llvm::any_of(write.getIndices(), isNotConstantZero))
3569  return failure();
3570  // Success.
3571  results.push_back(read.getSource());
3572  return success();
3573 }
3574 
3575 static bool checkSameValueWAR(vector::TransferReadOp read,
3576  vector::TransferWriteOp write) {
3577  return read.getSource() == write.getSource() &&
3578  read.getIndices() == write.getIndices() &&
3579  read.getPermutationMap() == write.getPermutationMap() &&
3580  read.getVectorType() == write.getVectorType() && !read.getMask() &&
3581  !write.getMask();
3582 }
3583 /// Fold transfer_write write after read:
3584 /// ```
3585 /// %t0 = ...
3586 /// %v = vector.transfer_read %t0[%c0...] :
3587 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3588 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3589 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3590 /// ```
3591 ///
3592 /// into:
3593 ///
3594 /// ```
3595 /// %t0
3596 /// ```
3597 static LogicalResult foldWAR(TransferWriteOp write,
3598  SmallVectorImpl<OpFoldResult> &results) {
3599  if (!write.getSource().getType().isa<RankedTensorType>())
3600  return failure();
3601  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
3602  if (!read)
3603  return failure();
3604 
3605  if (!checkSameValueWAR(read, write))
3606  return failure();
3607  results.push_back(read.getSource());
3608  return success();
3609 }
3610 
3611 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3612  SmallVectorImpl<OpFoldResult> &results) {
3613  if (succeeded(foldReadInitWrite(*this, operands, results)))
3614  return success();
3615  if (succeeded(foldWAR(*this, results)))
3616  return success();
3617  if (succeeded(foldTransferInBoundsAttribute(*this)))
3618  return success();
3619  return foldMemRefCast(*this);
3620 }
3621 
3622 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3623  return llvm::to_vector<4>(getVectorType().getShape());
3624 }
3625 
3626 void TransferWriteOp::getEffects(
3627  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3628  &effects) {
3629  if (getShapedType().isa<MemRefType>())
3630  effects.emplace_back(MemoryEffects::Write::get(), getSource(),
3631  SideEffects::DefaultResource::get());
3632 }
3633 
3634 namespace {
3635 /// Remove a dead transfer write from the SSA chain so that it can be
3636 /// eliminated by DCE.
3637 /// ```
3638 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3639 /// : vector<1x4xf32>, tensor<4x4xf32>
3640 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3641 /// : vector<1x4xf32>, tensor<4x4xf32>
3642 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3643 /// : vector<1x4xf32>, tensor<4x4xf32>
3644 /// ```
3645 ///
3646 /// into:
3647 ///
3648 /// ```
3649 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3650 /// : vector<1x4xf32>, tensor<4x4xf32>
3651 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3652 /// : vector<1x4xf32>, tensor<4x4xf32>
3653 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3654 /// : vector<1x4xf32>, tensor<4x4xf32>
3655 /// ```
3656 ///
3657 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3658 /// any other uses.
3659 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3660 public:
3661  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3662  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3663  PatternRewriter &rewriter) const override {
3664  if (!writeOp.getShapedType().isa<RankedTensorType>())
3665  return failure();
3666  vector::TransferWriteOp writeToModify = writeOp;
3667 
3668  auto defWrite =
3669  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3670  while (defWrite) {
3671  if (checkSameValueWAW(writeOp, defWrite)) {
3672  writeToModify.getSourceMutable().assign(defWrite.getSource());
3673  return success();
3674  }
3675  if (!isDisjointTransferIndices(
3676  cast<VectorTransferOpInterface>(defWrite.getOperation()),
3677  cast<VectorTransferOpInterface>(writeOp.getOperation())))
3678  break;
3679  // If the previous write op doesn't have any other use we can safely look
3680  // at the previous store to see if it can be removed.
3681  if (!defWrite->hasOneUse())
3682  break;
3683  writeToModify = defWrite;
3684  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3685  }
3686  return failure();
3687  }
3688 };
3689 
3690 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3691 /// could directly write to the insert_slice's destination. E.g.:
3692 ///
3693 /// ```
3694 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3695 /// : vector<4x5xf32>, tensor<4x5xf32>
3696 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3697 /// : tensor<4x5xf32> into tensor<?x?xf32>
3698 /// ```
3699 /// is rewritten to:
3700 /// ```
3701 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3702 /// : vector<4x5xf32>, tensor<?x?xf32>
3703 /// ```
3704 struct FoldInsertSliceIntoTransferWrite
3705  : public OpRewritePattern<tensor::InsertSliceOp> {
3706 public:
3707  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3708 
3709  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3710  PatternRewriter &rewriter) const override {
3711  if (!insertOp.hasUnitStride())
3712  return failure();
3713 
3714  auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
3715  if (!xferOp)
3716  return failure();
3717  // TODO: support 0-d corner case.
3718  if (xferOp.getTransferRank() == 0)
3719  return failure();
3720 
3721  if (xferOp.hasOutOfBoundsDim())
3722  return failure();
3723  if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3724  return failure();
3725  if (xferOp.getMask())
3726  return failure();
3727  // Fold only if the TransferWriteOp completely overwrites the `source` with
3728  // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3729  // content is the data of the vector.
3730  if (!llvm::equal(xferOp.getVectorType().getShape(),
3731  xferOp.getShapedType().getShape()))
3732  return failure();
3733  if (!xferOp.getPermutationMap().isIdentity())
3734  return failure();
3735 
3736  // Bail on illegal rank-reduction: we need to check that the rank-reduced
3737  // dims are exactly the leading dims. I.e. the following is illegal:
3738  // ```
3739  // %0 = vector.transfer_write %v, %t[0,0], %cst :
3740  // vector<2x4xf32>, tensor<2x4xf32>
3741  // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
3742  // tensor<2x4xf32> into tensor<2x1x4xf32>
3743  // ```
3744  //
3745  // Cannot fold into:
3746  // ```
3747  // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
3748  // vector<2x4xf32>, tensor<2x1x4xf32>
3749  // ```
3750  // For this, check the trailing `vectorRank` dims of the insert_slice result
3751  // tensor match the trailing dims of the inferred result tensor.
3752  int64_t rankReduced =
3753  insertOp.getType().getRank() - insertOp.getSourceType().getRank();
3754  int64_t vectorRank = xferOp.getVectorType().getRank();
3755  RankedTensorType inferredSourceTensorType =
3756  tensor::ExtractSliceOp::inferResultType(
3757  insertOp.getType(), insertOp.getMixedOffsets(),
3758  insertOp.getMixedSizes(), insertOp.getMixedStrides());
3759  auto actualSourceTensorShape = insertOp.getSourceType().getShape();
3760  if (rankReduced > 0 &&
3761  actualSourceTensorShape.take_back(vectorRank) !=
3762  inferredSourceTensorType.getShape().take_back(vectorRank))
3763  return failure();
3764 
3765  SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
3766  rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
3767  SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3768  rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
3769  insertOp.getDest(), indices,
3770  ArrayRef<bool>{inBounds});
3771  return success();
3772  }
3773 };
3774 
3775 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
3776 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
3777 /// overwritten and inserted into another tensor. After this rewrite, the
3778 /// operations bufferize in-place since all of them work on the same slice.
3779 ///
3780 /// For example:
3781 /// ```mlir
3782 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
3783 /// : vector<8x16xf32>, tensor<8x16xf32>
3784 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
3785 /// : tensor<8x16xf32> to tensor<?x?xf32>
3786 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
3787 /// : tensor<?x?xf32> into tensor<27x37xf32>
3788 /// ```
3789 /// folds to
3790 /// ```mlir
3791 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
3792 /// : tensor<27x37xf32> to tensor<?x?xf32>
3793 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
3794 /// : vector<8x16xf32>, tensor<?x?xf32>
3795 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
3796 /// : tensor<?x?xf32> into tensor<27x37xf32>
3797 /// ```
3798 struct SwapExtractSliceOfTransferWrite
3799  : public OpRewritePattern<tensor::InsertSliceOp> {
3800 public:
3801  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3802 
3803  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3804  PatternRewriter &rewriter) const override {
3805  if (!insertOp.hasUnitStride())
3806  return failure();
3807  auto extractOp =
3808  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
3809  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
3810  return failure();
3811  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
3812  if (!transferOp || !transferOp->hasOneUse())
3813  return failure();
3814 
3815  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
3816  // rank-reducing.
3817  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
3818  return rewriter.notifyMatchFailure(insertOp,
3819  "use-def chain is rank-reducing");
3820  }
3821 
3822  // Fail if tensor::ExtractSliceOp has non-zero offset.
3823  if (!extractOp.hasZeroOffset()) {
3824  return rewriter.notifyMatchFailure(insertOp,
3825  "ExtractSliceOp has non-zero offset");
3826  }
3827 
3828  // Fail if vector::TransferWriteOp has non-zero offset.
3829  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
3830  return getConstantIntValue(value) == static_cast<int64_t>(0);
3831  })) {
3832  return rewriter.notifyMatchFailure(insertOp,
3833  "TransferWriteOp has non-zero offset");
3834  }
3835 
3836  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
3837  for (const auto &it :
3838  llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
3839  if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
3840  return rewriter.notifyMatchFailure(
3841  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
3842  }
3843  }
3844 
3845  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
3846  assert(transferOp.getVectorType().hasStaticShape() &&
3847  "expected vector to have a static shape");
3848  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
3849  SmallVector<int64_t> resultShape = applyPermutationMap(
3850  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
3851  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
3852  return rewriter.notifyMatchFailure(
3853  insertOp, "TransferWriteOp may not write the full tensor.");
3854  }
3855 
3856  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
3857  SmallVector<int64_t> newResultShape = applyPermutationMap(
3858  transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
3859  SmallVector<bool> newInBounds;
3860  for (const auto &en : enumerate(newResultShape))
3861  newInBounds.push_back(en.value() == vectorShape[en.index()]);
3862  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
3863  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
3864  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
3865  insertOp.getMixedStrides());
3866  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
3867  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
3868  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
3869  rewriter.getBoolArrayAttr(newInBounds));
3870  rewriter.updateRootInPlace(insertOp, [&]() {
3871  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
3872  });
3873  return success();
3874  }
3875 };
3876 
3877 } // namespace
3878 
3879 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
3880  MLIRContext *context) {
3881  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
3882  SwapExtractSliceOfTransferWrite>(context);
3883 }
3884 
3885 //===----------------------------------------------------------------------===//
3886 // LoadOp
3887 //===----------------------------------------------------------------------===//
3888 
3889 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
3890  MemRefType memRefTy) {
3891  if (!isLastMemrefDimUnitStride(memRefTy))
3892  return op->emitOpError("most minor memref dim must have unit stride");
3893  return success();
3894 }
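// Illustrative layouts (hypothetical): an identity-layout memref<4x8xf32>
// passes this check, while memref<4x8xf32, strided<[16, 2]>> is rejected
// because its most minor dimension has stride 2.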
3895 
3896 LogicalResult vector::LoadOp::verify() {
3897  VectorType resVecTy = getVectorType();
3898  MemRefType memRefTy = getMemRefType();
3899 
3900  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
3901  return failure();
3902 
3903  // Checks for vector memrefs.
3904  Type memElemTy = memRefTy.getElementType();
3905  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3906  if (memVecTy != resVecTy)
3907  return emitOpError("base memref and result vector types should match");
3908  memElemTy = memVecTy.getElementType();
3909  }
3910 
3911  if (resVecTy.getElementType() != memElemTy)
3912  return emitOpError("base and result element types should match");
3913  if (llvm::size(getIndices()) != memRefTy.getRank())
3914  return emitOpError("requires ") << memRefTy.getRank() << " indices";
3915  return success();
3916 }
3917 
3918 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
3919  if (succeeded(foldMemRefCast(*this)))
3920  return getResult();
3921  return OpFoldResult();
3922 }
3923 
3924 //===----------------------------------------------------------------------===//
3925 // StoreOp
3926 //===----------------------------------------------------------------------===//
3927 
3928 LogicalResult vector::StoreOp::verify() {
3929  VectorType valueVecTy = getVectorType();
3930  MemRefType memRefTy = getMemRefType();
3931 
3932  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
3933  return failure();
3934 
3935  // Checks for vector memrefs.
3936  Type memElemTy = memRefTy.getElementType();
3937  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3938  if (memVecTy != valueVecTy)
3939  return emitOpError(
3940  "base memref and valueToStore vector types should match");
3941  memElemTy = memVecTy.getElementType();
3942  }
3943 
3944  if (valueVecTy.getElementType() != memElemTy)
3945  return emitOpError("base and valueToStore element type should match");
3946  if (llvm::size(getIndices()) != memRefTy.getRank())
3947  return emitOpError("requires ") << memRefTy.getRank() << " indices";
3948  return success();
3949 }
3950 
3951 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
3952  SmallVectorImpl<OpFoldResult> &results) {
3953  return foldMemRefCast(*this);
3954 }
3955 
3956 //===----------------------------------------------------------------------===//
3957 // MaskedLoadOp
3958 //===----------------------------------------------------------------------===//
3959 
3960 LogicalResult MaskedLoadOp::verify() {
3961  VectorType maskVType = getMaskVectorType();
3962  VectorType passVType = getPassThruVectorType();
3963  VectorType resVType = getVectorType();
3964  MemRefType memType = getMemRefType();
3965 
3966  if (resVType.getElementType() != memType.getElementType())
3967  return emitOpError("base and result element type should match");
3968  if (llvm::size(getIndices()) != memType.getRank())
3969  return emitOpError("requires ") << memType.getRank() << " indices";
3970  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3971  return emitOpError("expected result dim to match mask dim");
3972  if (resVType != passVType)
3973  return emitOpError("expected pass_thru of same type as result type");
3974  return success();
3975 }
3976 
3977 namespace {
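/// Canonicalize a masked load whose mask is statically known. A hedged
/// sketch with illustrative operands (all-true case):
/// ```
/// %l = vector.maskedload %base[%c0], %alltrue, %pass
///     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes a plain `vector.load %base[%c0]`; with an all-false mask the
/// result folds to `%pass` directly.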
3978 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
3979 public:
3980  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
3981  LogicalResult matchAndRewrite(MaskedLoadOp load,
3982  PatternRewriter &rewriter) const override {
3983  switch (get1DMaskFormat(load.getMask())) {
3984  case MaskFormat::AllTrue:
3985  rewriter.replaceOpWithNewOp<vector::LoadOp>(
3986  load, load.getType(), load.getBase(), load.getIndices());
3987  return success();
3988  case MaskFormat::AllFalse:
3989  rewriter.replaceOp(load, load.getPassThru());
3990  return success();
3991  case MaskFormat::Unknown:
3992  return failure();
3993  }
3994  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
3995  }
3996 };
3997 } // namespace
3998 
3999 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4000  MLIRContext *context) {
4001  results.add<MaskedLoadFolder>(context);
4002 }
4003 
4004 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
4005  if (succeeded(foldMemRefCast(*this)))
4006  return getResult();
4007  return OpFoldResult();
4008 }
4009 
4010 //===----------------------------------------------------------------------===//
4011 // MaskedStoreOp
4012 //===----------------------------------------------------------------------===//
4013 
4014 LogicalResult MaskedStoreOp::verify() {
4015  VectorType maskVType = getMaskVectorType();
4016  VectorType valueVType = getVectorType();
4017  MemRefType memType = getMemRefType();
4018 
4019  if (valueVType.getElementType() != memType.getElementType())
4020  return emitOpError("base and valueToStore element type should match");
4021  if (llvm::size(getIndices()) != memType.getRank())
4022  return emitOpError("requires ") << memType.getRank() << " indices";
4023  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4024  return emitOpError("expected valueToStore dim to match mask dim");
4025  return success();
4026 }
4027 
4028 namespace {
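/// The analogous canonicalization for masked stores (illustrative): an
/// all-true mask degrades vector.maskedstore to a plain vector.store, and an
/// all-false mask erases the store entirely, since no lane is written.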
4029 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4030 public:
4031  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
4032  LogicalResult matchAndRewrite(MaskedStoreOp store,
4033  PatternRewriter &rewriter) const override {
4034  switch (get1DMaskFormat(store.getMask())) {
4035  case MaskFormat::AllTrue:
4036  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4037  store, store.getValueToStore(), store.getBase(), store.getIndices());
4038  return success();
4039  case MaskFormat::AllFalse:
4040  rewriter.eraseOp(store);
4041  return success();
4042  case MaskFormat::Unknown:
4043  return failure();
4044  }
4045  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4046  }
4047 };
4048 } // namespace
4049 
4050 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4051  MLIRContext *context) {
4052  results.add<MaskedStoreFolder>(context);
4053 }
4054 
4055 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
4056  SmallVectorImpl<OpFoldResult> &results) {
4057  return foldMemRefCast(*this);
4058 }
4059 
4060 //===----------------------------------------------------------------------===//
4061 // GatherOp
4062 //===----------------------------------------------------------------------===//
4063 
4064 LogicalResult GatherOp::verify() {
4065  VectorType indVType = getIndexVectorType();
4066  VectorType maskVType = getMaskVectorType();
4067  VectorType resVType = getVectorType();
4068  ShapedType baseType = getBaseType();
4069 
4070  if (!baseType.isa<MemRefType, RankedTensorType>())
4071  return emitOpError("requires base to be a memref or ranked tensor type");
4072 
4073  if (resVType.getElementType() != baseType.getElementType())
4074  return emitOpError("base and result element type should match");
4075  if (llvm::size(getIndices()) != baseType.getRank())
4076  return emitOpError("requires ") << baseType.getRank() << " indices";
4077  if (resVType.getDimSize(0) != indVType.getDimSize(0))
4078  return emitOpError("expected result dim to match indices dim");
4079  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4080  return emitOpError("expected result dim to match mask dim");
4081  if (resVType != getPassThruVectorType())
4082  return emitOpError("expected pass_thru of same type as result type");
4083  return success();
4084 }
4085 
4086 namespace {
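/// Folding for vector.gather (illustrative): unlike the masked load case,
/// there is no unmasked gather equivalent, so an all-true mask is left
/// untouched; an all-false mask still folds the result to the pass-through
/// vector.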
4087 class GatherFolder final : public OpRewritePattern<GatherOp> {
4088 public:
4089  using OpRewritePattern<GatherOp>::OpRewritePattern;
4090  LogicalResult matchAndRewrite(GatherOp gather,
4091  PatternRewriter &rewriter) const override {
4092  switch (get1DMaskFormat(gather.getMask())) {
4093  case MaskFormat::AllTrue:
4094  return failure(); // no unmasked equivalent
4095  case MaskFormat::AllFalse:
4096  rewriter.replaceOp(gather, gather.getPassThru());
4097  return success();
4098  case MaskFormat::Unknown:
4099  return failure();
4100  }
4101  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4102  }
4103 };
4104 } // namespace
4105 
4106 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4107  MLIRContext *context) {
4108  results.add<GatherFolder>(context);
4109 }
4110 
4111 //===----------------------------------------------------------------------===//
4112 // ScatterOp
4113 //===----------------------------------------------------------------------===//
4114 
4115 LogicalResult ScatterOp::verify() {
4116  VectorType indVType = getIndexVectorType();
4117  VectorType maskVType = getMaskVectorType();
4118  VectorType valueVType = getVectorType();
4119  MemRefType memType = getMemRefType();
4120 
4121  if (valueVType.getElementType() != memType.getElementType())
4122  return emitOpError("base and valueToStore element type should match");
4123  if (llvm::size(getIndices()) != memType.getRank())
4124  return emitOpError("requires ") << memType.getRank() << " indices";
4125  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4126  return emitOpError("expected valueToStore dim to match indices dim");
4127  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4128  return emitOpError("expected valueToStore dim to match mask dim");
4129  return success();
4130 }
4131 
4132 namespace {
4133 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4134 public:
4135  using OpRewritePattern<ScatterOp>::OpRewritePattern;
4136  LogicalResult matchAndRewrite(ScatterOp scatter,
4137  PatternRewriter &rewriter) const override {
4138  switch (get1DMaskFormat(scatter.getMask())) {
4139  case MaskFormat::AllTrue:
4140  return failure(); // no unmasked equivalent
4141  case MaskFormat::AllFalse:
4142  rewriter.eraseOp(scatter);
4143  return success();
4144  case MaskFormat::Unknown:
4145  return failure();
4146  }
4147  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4148  }
4149 };
4150 } // namespace
4151 
4152 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4153  MLIRContext *context) {
4154  results.add<ScatterFolder>(context);
4155 }
4156 
4157 //===----------------------------------------------------------------------===//
4158 // ExpandLoadOp
4159 //===----------------------------------------------------------------------===//
4160 
4161 LogicalResult ExpandLoadOp::verify() {
4162  VectorType maskVType = getMaskVectorType();
4163  VectorType passVType = getPassThruVectorType();
4164  VectorType resVType = getVectorType();
4165  MemRefType memType = getMemRefType();
4166 
4167  if (resVType.getElementType() != memType.getElementType())
4168  return emitOpError("base and result element type should match");
4169  if (llvm::size(getIndices()) != memType.getRank())
4170  return emitOpError("requires ") << memType.getRank() << " indices";
4171  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4172  return emitOpError("expected result dim to match mask dim");
4173  if (resVType != passVType)
4174  return emitOpError("expected pass_thru of same type as result type");
4175  return success();
4176 }
4177 
4178 namespace {
4179 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4180 public:
4181  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
4182  LogicalResult matchAndRewrite(ExpandLoadOp expand,
4183  PatternRewriter &rewriter) const override {
4184  switch (get1DMaskFormat(expand.getMask())) {
4185  case MaskFormat::AllTrue:
4186  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4187  expand, expand.getType(), expand.getBase(), expand.getIndices());
4188  return success();
4189  case MaskFormat::AllFalse:
4190  rewriter.replaceOp(expand, expand.getPassThru());
4191  return success();
4192  case MaskFormat::Unknown:
4193  return failure();
4194  }
4195  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4196  }
4197 };
4198 } // namespace
4199 
4200 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4201  MLIRContext *context) {
4202  results.add<ExpandLoadFolder>(context);
4203 }
4204 
4205 //===----------------------------------------------------------------------===//
4206 // CompressStoreOp
4207 //===----------------------------------------------------------------------===//
4208 
4209 LogicalResult CompressStoreOp::verify() {
4210  VectorType maskVType = getMaskVectorType();
4211  VectorType valueVType = getVectorType();
4212  MemRefType memType = getMemRefType();
4213 
4214  if (valueVType.getElementType() != memType.getElementType())
4215  return emitOpError("base and valueToStore element type should match");
4216  if (llvm::size(getIndices()) != memType.getRank())
4217  return emitOpError("requires ") << memType.getRank() << " indices";
4218  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4219  return emitOpError("expected valueToStore dim to match mask dim");
4220  return success();
4221 }
4222 
4223 namespace {
4224 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
4225 public:
4226  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
4227  LogicalResult matchAndRewrite(CompressStoreOp compress,
4228  PatternRewriter &rewriter) const override {
4229  switch (get1DMaskFormat(compress.getMask())) {
4230  case MaskFormat::AllTrue:
4231  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4232  compress, compress.getValueToStore(), compress.getBase(),
4233  compress.getIndices());
4234  return success();
4235  case MaskFormat::AllFalse:
4236  rewriter.eraseOp(compress);
4237  return success();
4238  case MaskFormat::Unknown:
4239  return failure();
4240  }
4241  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
4242  }
4243 };
4244 } // namespace
4245 
4246 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4247  MLIRContext *context) {
4248  results.add<CompressStoreFolder>(context);
4249 }
4250 
4251 //===----------------------------------------------------------------------===//
4252 // ShapeCastOp
4253 //===----------------------------------------------------------------------===//
4254 
4255 /// Returns true if each element of 'a' is equal to the product of a contiguous
4256 /// sequence of the elements of 'b'. Returns false otherwise.
4257 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
4258  unsigned rankA = a.size();
4259  unsigned rankB = b.size();
4260  assert(rankA < rankB);
4261 
4262  unsigned i = 0;
4263  unsigned j = 0;
4264  while (i < rankA && j < rankB) {
4265  int64_t dimA = a[i];
4266  int64_t dimB = 1;
4267  while (dimB < dimA && j < rankB)
4268  dimB *= b[j++];
4269  if (dimA != dimB)
4270  break;
4271  ++i;
4272 
4273  // Handle the case when trailing dimensions are of size 1.
4274  // Include them into the contiguous sequence.
4275  auto isOne = [](int64_t v) { return v == 1; };
4276  if (i < rankA && llvm::all_of(a.slice(i), isOne))
4277  i = rankA;
4278  if (j < rankB && llvm::all_of(b.slice(j), isOne))
4279  j = rankB;
4280  }
4281 
4282  return i == rankA && j == rankB;
4283 }
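// Illustrative shapes (hypothetical): a = [4, 6] and b = [2, 2, 6] is a
// valid cast since 4 == 2 * 2 and 6 == 6; a = [4, 6] and b = [2, 3, 4] is
// not, because no contiguous run of b's leading elements multiplies to 4.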
4284 
4285 static LogicalResult verifyVectorShapeCast(Operation *op,
4286  VectorType sourceVectorType,
4287  VectorType resultVectorType) {
4288  // Check that element type is the same.
4289  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
4290  return op->emitOpError("source/result vectors must have same element type");
4291  auto sourceShape = sourceVectorType.getShape();
4292  auto resultShape = resultVectorType.getShape();
4293 
4294  // Check that product of source dim sizes matches product of result dim sizes.
4295  int64_t sourceDimProduct = std::accumulate(
4296  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
4297  int64_t resultDimProduct = std::accumulate(
4298  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
4299  if (sourceDimProduct != resultDimProduct)
4300  return op->emitOpError("source/result number of elements must match");
4301 
4302  // Check the expanding/contracting rank cases.
4303  unsigned sourceRank = sourceVectorType.getRank();
4304  unsigned resultRank = resultVectorType.getRank();
4305  if (sourceRank < resultRank) {
4306  if (!isValidShapeCast(sourceShape, resultShape))
4307  return op->emitOpError("invalid shape cast");
4308  } else if (sourceRank > resultRank) {
4309  if (!isValidShapeCast(resultShape, sourceShape))
4310  return op->emitOpError("invalid shape cast");
4311  }
4312  return success();
4313 }
4314 
4315 LogicalResult ShapeCastOp::verify() {
4316  auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
4317  auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
4318 
4319  // Check if source/result are of vector type.
4320  if (sourceVectorType && resultVectorType)
4321  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
4322 
4323  return success();
4324 }
4325 
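// Illustrative folds performed below (hypothetical IR, for exposition):
//   %1 = vector.shape_cast %0 : vector<2x3xf32> to vector<6xf32>
//   %2 = vector.shape_cast %1 : vector<6xf32> to vector<2x3xf32>
// folds %2 to %0 (canceling casts); a cast whose source and result types
// already match folds to its operand directly (no-op cast).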
4326 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
4327  // No-op shape cast.
4328  if (getSource().getType() == getResult().getType())
4329  return getSource();
4330 
4331  // Canceling shape casts.
4332  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
4333  if (getResult().getType() == otherOp.getSource().getType())
4334  return otherOp.getSource();
4335 
4336  // Only allow valid transitive folding.
4337  VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
4338  VectorType resultType = getResult().getType().cast<VectorType>();
4339  if (srcType.getRank() < resultType.getRank()) {
4340  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
4341  return {};
4342  } else if (srcType.getRank() > resultType.getRank()) {
4343  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
4344  return {};
4345  } else {
4346  return {};
4347  }
4348 
4349  setOperand(otherOp.getSource());
4350  return getResult();
4351  }
4352 
4353  // Canceling broadcast and shape cast ops.
4354  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4355  if (bcastOp.getSourceType() == getType())
4356  return bcastOp.getSource();
4357  }
4358 
4359  return {};
4360 }
4361 
4362 namespace {
4363 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
4364 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
4365 public:
4366  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
4367 
4368  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4369  PatternRewriter &rewriter) const override {
4370  auto constantOp =
4371  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
4372  if (!constantOp)
4373  return failure();
4374  // Only handle splat for now.
4375  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
4376  if (!dense)
4377  return failure();
4378  auto newAttr =
4379  DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
4380  dense.getSplatValue<Attribute>());
4381  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
4382  return success();
4383  }
4384 };
4385 
4386 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
4387 /// This only applies when the shape of the broadcast source is a suffix of the
4388 /// shape of the result (i.e. when broadcast without reshape is expressive
4389 /// enough to capture the result in a single op).
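/// Illustrative example (not from the original source):
///   %b = vector.broadcast %s : vector<4xf32> to vector<2x3x4xf32>
///   %c = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
/// becomes a single broadcast to vector<6x4xf32>, since (4) is a suffix
/// of (6, 4).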
4390 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4391 public:
4392  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
4393 
4394  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4395  PatternRewriter &rewriter) const override {
4396  auto broadcastOp =
4397  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
4398  if (!broadcastOp)
4399  return failure();
4400 
4401  auto broadcastSourceVectorType =
4402  broadcastOp.getSourceType().dyn_cast<VectorType>();
4403  auto broadcastSourceShape = broadcastSourceVectorType
4404  ? broadcastSourceVectorType.getShape()
4405  : ArrayRef<int64_t>{};
4406  auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
4407 
4408  // Bail if `broadcastSourceShape` is not a suffix of the result.
4409  bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
4410  broadcastSourceShape.size()));
4411  if (!isSuffix)
4412  return failure();
4413 
4414  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4415  shapeCastOp, shapeCastOp.getResultVectorType(),
4416  broadcastOp.getSource());
4417  return success();
4418  }
4419 };
4420 
4421 } // namespace
4422 
4423 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
4424  MLIRContext *context) {
4425  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
4426 }
4427 
4428 //===----------------------------------------------------------------------===//
4429 // VectorBitCastOp
4430 //===----------------------------------------------------------------------===//
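// Illustrative checks performed by the verifier below (hypothetical IR):
//   vector.bitcast %v : vector<4x8xi8> to vector<4x2xi32>  // ok: 8x8 == 2x32
//   vector.bitcast %v : vector<4x8xi8> to vector<4x3xi32>  // error: 64 != 96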
4431 
4432 LogicalResult BitCastOp::verify() {
4433  auto sourceVectorType = getSourceVectorType();
4434  auto resultVectorType = getResultVectorType();
4435 
4436  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
4437  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4438  return emitOpError("dimension size mismatch at: ") << i;
4439  }
4440 
4441  DataLayout dataLayout = DataLayout::closest(*this);
4442  auto sourceElementBits =
4443  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
4444  auto resultElementBits =
4445  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
4446 
4447  if (sourceVectorType.getRank() == 0) {
4448  if (sourceElementBits != resultElementBits)
4449  return emitOpError("source/result bitwidth of the 0-D vector element "
4450  "types must be equal");
4451  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
4452  resultElementBits * resultVectorType.getShape().back()) {
4453  return emitOpError(
4454  "source/result bitwidth of the minor 1-D vectors must be equal");
4455  }
4456 
4457  return success();
4458 }
4459 
4460 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
4461  // Nop cast.
4462  if (getSource().getType() == getResult().getType())
4463  return getSource();
4464 
4465  // Canceling bitcasts.
4466  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
4467  if (getResult().getType() == otherOp.getSource().getType())
4468  return otherOp.getSource();
4469 
4470  setOperand(otherOp.getSource());
4471  return getResult();
4472  }
4473 
4474  Attribute sourceConstant = operands.front();
4475  if (!sourceConstant)
4476  return {};
4477 
4478  Type srcElemType = getSourceVectorType().getElementType();
4479  Type dstElemType = getResultVectorType().getElementType();
4480 
4481  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
4482  if (floatPack.isSplat()) {
4483  auto splat = floatPack.getSplatValue<FloatAttr>();
4484 
4485  // Bitcast of an f16 splat to f32: pack two copies of the 16-bit pattern.
4486  if (srcElemType.isF16() && dstElemType.isF32()) {
4487  uint32_t bits = static_cast<uint32_t>(
4488  splat.getValue().bitcastToAPInt().getZExtValue());
4489  // Duplicate the 16-bit pattern.
4490  bits = (bits << 16) | (bits & 0xffff);
4491  APInt intBits(32, bits);
4492  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
4493  return DenseElementsAttr::get(getResultVectorType(), floatBits);
4494  }
4495  }
4496  }
4497 
4498  return {};
4499 }
4500 
4501 //===----------------------------------------------------------------------===//
4502 // TypeCastOp
4503 //===----------------------------------------------------------------------===//
4504 
4505 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
4506  auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
4507  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
4508  memRefType.getShape().end());
4509  if (vectorType)
4510  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4511  return res;
4512 }
4513 
4514 /// Build the canonical memRefType with a single vector.
4515 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
4516 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
4517  Value source) {
4518  result.addOperands(source);
4519  MemRefType memRefType = source.getType().cast<MemRefType>();
4520  VectorType vectorType =
4521  VectorType::get(extractShape(memRefType),
4522  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
4523  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
4524  memRefType.getMemorySpace()));
4525 }
4526 
4527 LogicalResult TypeCastOp::verify() {
4528  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
4529  if (!canonicalType.getLayout().isIdentity())
4530  return emitOpError("expects operand to be a memref with identity layout");
4531  if (!getResultMemRefType().getLayout().isIdentity())
4532  return emitOpError("expects result to be a memref with identity layout");
4533  if (getResultMemRefType().getMemorySpace() !=
4534  getMemRefType().getMemorySpace())
4535  return emitOpError("expects result in same memory space");
4536 
4537  auto sourceType = getMemRefType();
4538  auto resultType = getResultMemRefType();
4539  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
4540  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4541  return emitOpError(
4542  "expects result and operand with same underlying scalar type: ")
4543  << resultType;
4544  if (extractShape(sourceType) != extractShape(resultType))
4545  return emitOpError(
4546  "expects concatenated result and operand shapes to be equal: ")
4547  << resultType;
4548  return success();
4549 }
4550 
4551 //===----------------------------------------------------------------------===//
4552 // TransposeOp
4553 //===----------------------------------------------------------------------===//
4554 
4555 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
4556  Value vector, ArrayRef<int64_t> transp) {
4557  VectorType vt = vector.getType().cast<VectorType>();
4558  SmallVector<int64_t, 4> transposedShape(vt.getRank());
4559  for (unsigned i = 0; i < transp.size(); ++i)
4560  transposedShape[i] = vt.getShape()[transp[i]];
4561 
4562  result.addOperands(vector);
4563  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
4564  result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
4565 }
4566 
4567 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4568  // Eliminate splat constant transpose ops.
4569  if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
4570  if (attr.isSplat())
4571  return attr.reshape(getResultType());
4572 
4573  // Eliminate identity transpose ops. This happens when the dimensions of the
4574  // input vector remain in their original order after the transpose operation.
4575  SmallVector<int64_t, 4> transp;
4576  getTransp(transp);
4577 
4578  // Check if the permutation of the dimensions contains sequential values:
4579  // {0, 1, 2, ...}.
4580  for (int64_t i = 0, e = transp.size(); i < e; i++) {
4581  if (transp[i] != i)
4582  return {};
4583  }
4584 
4585  return getVector();
4586 }
4587 
4588 LogicalResult vector::TransposeOp::verify() {
4589  VectorType vectorType = getVectorType();
4590  VectorType resultType = getResultType();
4591  int64_t rank = resultType.getRank();
4592  if (vectorType.getRank() != rank)
4593  return emitOpError("vector result rank mismatch: ") << rank;
4594  // Verify transposition array.
4595  auto transpAttr = getTransp().getValue();
4596  int64_t size = transpAttr.size();
4597  if (rank != size)
4598  return emitOpError("transposition length mismatch: ") << size;
4599  SmallVector<bool, 8> seen(rank, false);
4600  for (const auto &ta : llvm::enumerate(transpAttr)) {
4601  int64_t i = ta.value().cast<IntegerAttr>().getInt();
4602  if (i < 0 || i >= rank)
4603  return emitOpError("transposition index out of range: ") << i;
4604  if (seen[i])
4605  return emitOpError("duplicate position index: ") << i;
4606  seen[i] = true;
4607  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
4608  return emitOpError("dimension size mismatch at: ") << i;
4609  }
4610  return success();
4611 }
4612 
4613 Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
4614  return llvm::to_vector<4>(getResultType().getShape());
4615 }
4616 
4617 namespace {
4618 
4619 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
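// Illustrative example (hypothetical permutations): a transpose with
// permutation [1, 2, 0] whose input is a transpose with permutation
// [2, 0, 1] composes to [0, 1, 2], i.e. an identity transpose that the
// fold method above then eliminates.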
4620 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
4621 public:
4622  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
4623 
4624  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4625  PatternRewriter &rewriter) const override {
4626  // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
4627  auto getPermutation = [](vector::TransposeOp transpose) {
4628  SmallVector<int64_t, 4> permutation;
4629  transpose.getTransp(permutation);
4630  return permutation;
4631  };
4632 
4633  // Composes two permutations: result[i] = permutation1[permutation2[i]].
4634  auto composePermutations = [](ArrayRef<int64_t> permutation1,
4635  ArrayRef<int64_t> permutation2) {
4636  SmallVector<int64_t, 4> result;
4637  for (auto index : permutation2)
4638  result.push_back(permutation1[index]);
4639  return result;
4640  };
4641 
4642  // Fail if the input of 'transposeOp' is not defined by another transpose.
4643  vector::TransposeOp parentTransposeOp =
4644  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
4645  if (!parentTransposeOp)
4646  return failure();
4647 
4648  SmallVector<int64_t, 4> permutation = composePermutations(
4649  getPermutation(parentTransposeOp), getPermutation(transposeOp));
4650  // Replace 'transposeOp' with a new transpose operation.
4651  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4652  transposeOp, transposeOp.getResult().getType(),
4653  parentTransposeOp.getVector(),
4654  vector::getVectorSubscriptAttr(rewriter, permutation));
4655  return success();
4656  }
4657 };
4658 
4659 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
4660 struct FoldTransposedScalarBroadcast final
4661  : public OpRewritePattern<vector::TransposeOp> {
4662  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
4663 
4664  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4665  PatternRewriter &rewriter) const override {
4666  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
4667  if (!bcastOp)
4668  return failure();
4669 
4670  auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
4671  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
4672  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4673  transposeOp, transposeOp.getResultType(), bcastOp.getSource());
4674  return success();
4675  }
4676 
4677  return failure();
4678  }
4679 };
4680 
4681 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
4682 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
4683 public:
4684  using OpRewritePattern<TransposeOp>::OpRewritePattern;
4685 
4686  LogicalResult matchAndRewrite(TransposeOp transposeOp,
4687  PatternRewriter &rewriter) const override {
4688  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
4689  if (!splatOp)
4690  return failure();
4691 
4692  rewriter.replaceOpWithNewOp<vector::SplatOp>(
4693  transposeOp, transposeOp.getResultType(), splatOp.getInput());
4694  return success();
4695  }
4696 };
4697 
4698 } // namespace
4699 
4700 void vector::TransposeOp::getCanonicalizationPatterns(
4701  RewritePatternSet &results, MLIRContext *context) {
4702  results
4703  .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
4704  context);
4705 }
4706 
4707 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4708  populateFromInt64AttrArray(getTransp(), results);
4709 }
4710 
4711 //===----------------------------------------------------------------------===//
4712 // ConstantMaskOp
4713 //===----------------------------------------------------------------------===//
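// Illustrative examples (hypothetical IR, for exposition only):
//   vector.constant_mask [3, 2] : vector<4x3xi1>  // ok: 3 <= 4 and 2 <= 3
//   vector.constant_mask [0, 2] : vector<4x3xi1>  // error: mixes zero and
//                                                 // nonzero dim sizes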
4714 
4715 LogicalResult ConstantMaskOp::verify() {
4716  auto resultType = getResult().getType().cast<VectorType>();
4717  // Check the corner case of 0-D vectors first.
4718  if (resultType.getRank() == 0) {
4719  if (getMaskDimSizes().size() != 1)
4720  return emitError("array attr must have length 1 for 0-D vectors");
4721  auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
4722  if (dim != 0 && dim != 1)
4723  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
4724  return success();
4725  }
4726 
4727  // Verify that array attr size matches the rank of the vector result.
4728  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
4729  return emitOpError(
4730  "must specify array attr of size equal vector result rank");
4731  // Verify that each array attr element is in bounds of corresponding vector
4732  // result dimension size.
4733  auto resultShape = resultType.getShape();
4734  SmallVector<int64_t, 4> maskDimSizes;
4735  for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
4736  int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4737  if (attrValue < 0 || attrValue > resultShape[it.index()])
4738  return emitOpError(
4739  "array attr of size out of bounds of vector result dimension size");
4740  maskDimSizes.push_back(attrValue);
4741  }
4742  // Verify that if one mask dim size is zero, they all should be zero (because
4743  // the mask region is a conjunction of each mask dimension interval).
4744  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4745  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4746  if (anyZeros && !allZeros)
4747  return emitOpError("expected all mask dim sizes to be zeros, "
4748  "as a result of conjunction with zero mask dim");
4749  // Verify that if the mask type is scalable, all dim sizes are zero, because
4750  // constant scalable masks can only be defined for the "none set" or "all set"
4751  // cases, and there is no VLA way to define an "all set" case for
4752  // `vector.constant_mask`. In the future, a convention could be established
4753  // to decide if a specific dimension value could be considered as "all set".
4754  if (resultType.isScalable() &&
4755  getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
4756  return emitOpError("expected mask dim sizes for scalable masks to be 0");
4757  return success();
4758 }
4759 
4760 //===----------------------------------------------------------------------===//
4761 // CreateMaskOp
4762 //===----------------------------------------------------------------------===//
4763 
4764 LogicalResult CreateMaskOp::verify() {
4765  auto vectorType = getResult().getType().cast<VectorType>();
4766  // Verify that an operand was specified for each result vector dimension.
4767  if (vectorType.getRank() == 0) {
4768  if (getNumOperands() != 1)
4769  return emitOpError(
4770  "must specify exactly one operand for 0-D create_mask");
4771  } else if (getNumOperands() !=
4772  getResult().getType().cast<VectorType>().getRank()) {
4773  return emitOpError(
4774  "must specify an operand for each result vector dimension");
4775  }
4776  return success();
4777 }
4778 
4779 namespace {
4780 
4781 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
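// Illustrative example (hypothetical IR): with %c2 = arith.constant 2 : index,
//   %m = vector.create_mask %c2, %c2 : vector<4x3xi1>
// rewrites to
//   %m = vector.constant_mask [2, 2] : vector<4x3xi1>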
4782 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
4783 public:
4785 
4786  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
4787  PatternRewriter &rewriter) const override {
4788  // Return if any of 'createMaskOp' operands are not defined by a constant.
4789  auto isNotDefByConstant = [](Value operand) {
4790  return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
4791  };
4792  if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
4793  return failure();
4794 
4795  // CreateMaskOp for scalable vectors can be folded only if all dimensions
4796  // are negative or zero.
4797  if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
4798  if (vType.isScalable())
4799  for (auto opDim : createMaskOp.getOperands()) {
4800  APInt intVal;
4801  if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
4802  intVal.isStrictlyPositive())
4803  return failure();
4804  }
4805  }
4806 
4807  // Gather constant mask dimension sizes.
4808  SmallVector<int64_t, 4> maskDimSizes;
4809  for (auto it : llvm::zip(createMaskOp.operands(),
4810  createMaskOp.getType().getShape())) {
4811  auto *defOp = std::get<0>(it).getDefiningOp();
4812  int64_t maxDimSize = std::get<1>(it);
4813  int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
4814  dimSize = std::min(dimSize, maxDimSize);
4815  // If one of dim sizes is zero, set all dims to zero.
4816  if (dimSize <= 0) {
4817  maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
4818  break;
4819  }
4820  maskDimSizes.push_back(dimSize);
4821  }
4822  // Replace 'createMaskOp' with ConstantMaskOp.
4823  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4824  createMaskOp, createMaskOp.getResult().getType(),
4825  vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
4826  return success();
4827  }
4828 };
4829 
4830 } // namespace
4831 
4832 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
4833  MLIRContext *context) {
4834  results.add<CreateMaskFolder>(context);
4835 }
4836 
4837 //===----------------------------------------------------------------------===//
4838 // ScanOp
4839 //===----------------------------------------------------------------------===//
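// Illustrative shapes (for exposition only): scanning a vector<4x8x16xf32>
// along reduction dim 1 requires an initial value of type vector<4x16xf32>,
// i.e. the source shape with the reduction dimension removed.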
4840 
4841 LogicalResult ScanOp::verify() {
4842  VectorType srcType = getSourceType();
4843  VectorType initialType = getInitialValueType();
4844  // Check reduction dimension < rank.
4845  int64_t srcRank = srcType.getRank();
4846  int64_t reductionDim = getReductionDim();
4847  if (reductionDim >= srcRank)
4848  return emitOpError("reduction dimension ")
4849  << reductionDim << " has to be less than " << srcRank;
4850 
4851  // Check that rank(initial_value) = rank(src) - 1.
4852  int64_t initialValueRank = initialType.getRank();
4853  if (initialValueRank != srcRank - 1)
4854  return emitOpError("initial value rank ")
4855  << initialValueRank << " has to be equal to " << srcRank - 1;
4856 
4857  // Check shapes of initial value and src.
4858  ArrayRef<int64_t> srcShape = srcType.getShape();
4859  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
4860  SmallVector<int64_t> expectedShape;
4861  for (int i = 0; i < srcRank; i++) {
4862  if (i != reductionDim)
4863  expectedShape.push_back(srcShape[i]);
4864  }
4865  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
4866  [](std::tuple<int64_t, int64_t> s) {
4867  return std::get<0>(s) != std::get<1>(s);
4868  })) {
4869  return emitOpError("incompatible input/initial value shapes");
4870  }
4871 
4872  // Verify supported reduction kind.
4873  Type eltType = getDestType().getElementType();
4874  if (!isSupportedCombiningKind(getKind(), eltType))
4875  return emitOpError("unsupported reduction type ")
4876  << eltType << " for kind '" << stringifyCombiningKind(getKind())
4877  << "'";
4878 
4879  return success();
4880 }
4881 
4882 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
4883  RewritePatternSet &patterns) {
4884  patterns
4885  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
4886  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
4887  StridedSliceConstantMaskFolder, TransposeFolder>(
4888  patterns.getContext());
4889 }
4890 
4891 //===----------------------------------------------------------------------===//
4892 // SplatOp
4893 //===----------------------------------------------------------------------===//
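// Illustrative fold (hypothetical IR): a vector.splat of an
// arith.constant 1.0 : f32 with result type vector<4xf32> folds to the
// splat attribute dense<1.0> : vector<4xf32>.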
4894 
4895 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
4896  auto constOperand = operands.front();
4897  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
4898  return {};
4899 
4900  // SplatElementsAttr::get treats a single value for the second arg as a splat.
4901  return SplatElementsAttr::get(getType(), {constOperand});
4902 }
4903 
4904 //===----------------------------------------------------------------------===//
4905 // WarpExecuteOnLane0Op
4906 //===----------------------------------------------------------------------===//
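// Illustrative textual form handled by the custom print/parse below
// (hypothetical IR, for exposition only):
//   %r = vector.warp_execute_on_lane_0(%laneid)[32]
//        args(%v : vector<128xf32>) -> (vector<4xf32>) { ... }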
4907 
4908 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
4909  p << "(" << getLaneid() << ")";
4910 
4911  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
4912  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
4913  p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
4914 
4915  if (!getArgs().empty())
4916  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
4917  if (!getResults().empty())
4918  p << " -> (" << getResults().getTypes() << ')';
4919  p << " ";
4920  p.printRegion(getRegion(),
4921  /*printEntryBlockArgs=*/true,
4922  /*printBlockTerminators=*/!getResults().empty());
4923  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
4924 }
4925 
4926 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
4927  OperationState &result) {
4928  // Create the region.
4929  result.regions.reserve(1);
4930  Region *warpRegion = result.addRegion();
4931 
4932  auto &builder = parser.getBuilder();
4933  OpAsmParser::UnresolvedOperand laneId;
4934 
4935  // Parse the laneid operand.
4936  if (parser.parseLParen() ||
4937  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
4938  parser.parseRParen())
4939  return failure();
4940 
4941  int64_t warpSize;
4942  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
4943  parser.parseRSquare())
4944  return failure();
4945  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
4946  builder.getContext())),
4947  builder.getI64IntegerAttr(warpSize));
4948 
4949  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
4950  return failure();
4951 
4952  llvm::SMLoc inputsOperandsLoc;
4953  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
4954  SmallVector<Type> inputTypes;
4955  if (succeeded(parser.parseOptionalKeyword("args"))) {
4956  if (parser.parseLParen())
4957  return failure();
4958 
4959  inputsOperandsLoc = parser.getCurrentLocation();
4960  if (parser.parseOperandList(inputsOperands) ||
4961  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
4962  return failure();
4963  }
4964  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
4965  result.operands))
4966  return failure();
4967 
4968  // Parse optional results type list.
4969  if (parser.parseOptionalArrowTypeList(result.types))
4970  return failure();
4971  // Parse the region.
4972  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
4973  /*argTypes=*/{}))
4974  return failure();
4975  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
4976 
4977  // Parse the optional attribute list.
4978  if (parser.parseOptionalAttrDict(result.attributes))
4979  return failure();
4980  return success();
4981 }
4982 
4983 void WarpExecuteOnLane0Op::getSuccessorRegions(
4984  Optional<unsigned> index, ArrayRef<Attribute> operands,
4985  SmallVectorImpl<RegionSuccessor> &regions) {
4986  if (index) {
4987  regions.push_back(RegionSuccessor(getResults()));
4988  return;
4989  }
4990 
4991  // The warp region is always executed.
4992  regions.push_back(RegionSuccessor(&getWarpRegion()));
4993 }
4994 
4995 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
4996  TypeRange resultTypes, Value laneId,
4997  int64_t warpSize) {
4998  build(builder, result, resultTypes, laneId, warpSize,
4999  /*operands=*/llvm::None, /*argTypes=*/llvm::None);
5000 }
5001 
5002 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
5003  TypeRange resultTypes, Value laneId,
5004  int64_t warpSize, ValueRange args,
5005  TypeRange blockArgTypes) {
5006  result.addOperands(laneId);
5007  result.addAttribute(getAttributeNames()[0],
5008  builder.getI64IntegerAttr(warpSize));
5009  result.addTypes(resultTypes);
5010  result.addOperands(args);
5011  assert(args.size() == blockArgTypes.size());
5012  OpBuilder::InsertionGuard guard(builder);
5013  Region *warpRegion = result.addRegion();
5014  Block *block = builder.createBlock(warpRegion);
5015  for (auto it : llvm::zip(blockArgTypes, args))
5016  block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
5017 }
5018 
5019 /// Helper to check that the distributed vector type is consistent with the
5020 /// expanded type and the distributed size.
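/// Illustrative example (not from the original source): with warpSize = 32,
/// an expanded type vector<2x64xf32> may distribute to vector<2x2xf32>
/// (one dimension divided by 32); distributing more than one dimension is
/// rejected.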
5021 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
5022  int64_t warpSize, Operation *op) {
5023  // If the types match, there is no distribution.
5024  if (expanded == distributed)
5025  return success();
5026  auto expandedVecType = expanded.dyn_cast<VectorType>();
5027  auto distributedVecType = distributed.dyn_cast<VectorType>();
5028  if (!expandedVecType || !distributedVecType)
5029  return op->emitOpError("expected vector type for distributed operands.");
5030  if (expandedVecType.getRank() != distributedVecType.getRank() ||
5031  expandedVecType.getElementType() != distributedVecType.getElementType())
5032  return op->emitOpError(
5033  "expected distributed vectors to have same rank and element type.");
5034  bool foundDistributedDim = false;
5035  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
5036  if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
5037  continue;
5038  if (expandedVecType.getDimSize(i) ==
5039  distributedVecType.getDimSize(i) * warpSize) {
5040  if (foundDistributedDim)
5041  return op->emitOpError()
5042  << "expected only one dimension to be distributed from "
5043  << expandedVecType << " to " << distributedVecType;
5044  foundDistributedDim = true;
5045  continue;
5046  }
5047  return op->emitOpError() << "incompatible distribution dimensions from "
5048  << expandedVecType << " to " << distributedVecType;
5049  }
5050  return success();
5051 }
5052 
5053 LogicalResult WarpExecuteOnLane0Op::verify() {
5054  if (getArgs().size() != getWarpRegion().getNumArguments())
5055  return emitOpError(
5056  "expected same number of op arguments and block arguments.");
5057  auto yield =
5058  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
5059  if (yield.getNumOperands() != getNumResults())
5060  return emitOpError(
5061  "expected same number of yield operands and return values.");
5062  int64_t warpSize = getWarpSize();
5063  for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
5064  if (failed(verifyDistributedType(std::get<0>(it).getType(),
5065  std::get<1>(it).getType(), warpSize,
5066  getOperation())))
5067  return failure();
5068  }
5069  for (auto it : llvm::zip(yield.getOperands(), getResults())) {
5070  if (failed(verifyDistributedType(std::get<0>(it).getType(),
5071  std::get<1>(it).getType(), warpSize,
5072  getOperation())))
5073  return failure();
5074  }
5075  return success();
5076 }
5077 
5078 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
5079  return succeeded(
5080  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
5081 }
5082 
5083 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
5084  CombiningKind kind, Value v1, Value v2) {
5085  Type t1 = getElementTypeOrSelf(v1.getType());
5086  Type t2 = getElementTypeOrSelf(v2.getType());
5087  switch (kind) {
5088  case CombiningKind::ADD:
5089  if (t1.isIntOrIndex() && t2.isIntOrIndex())
5090  return b.createOrFold<arith::AddIOp>(loc, v1, v2);
5091  else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5092  return b.createOrFold<arith::AddFOp>(loc, v1, v2);
5093  llvm_unreachable("invalid value types for ADD reduction");
5094  case CombiningKind::AND:
5095  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5096  return b.createOrFold<arith::AndIOp>(loc, v1, v2);
5097  case CombiningKind::MAXF:
5098  assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5099  "expected float values");
5100  return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
5101  case CombiningKind::MINF:
5102  assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5103  "expected float values");
5104  return b.createOrFold<arith::MinFOp>(loc, v1, v2);
5105  case CombiningKind::MAXSI:
5106  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5107  return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
5108  case CombiningKind::MINSI:
5109  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5110  return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
5111  case CombiningKind::MAXUI:
5112  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5113  return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
5114  case CombiningKind::MINUI:
5115  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5116  return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
5117  case CombiningKind::MUL:
5118  if (t1.isIntOrIndex() && t2.isIntOrIndex())
5119  return b.createOrFold<arith::MulIOp>(loc, v1, v2);
5120  else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5121  return b.createOrFold<arith::MulFOp>(loc, v1, v2);
5122  llvm_unreachable("invalid value types for MUL reduction");
5123  case CombiningKind::OR:
5124  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5125  return b.createOrFold<arith::OrIOp>(loc, v1, v2);
5126  case CombiningKind::XOR:
5127  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5128  return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
5129  };
5130  llvm_unreachable("unknown CombiningKind");
5131 }
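// Illustrative use (hypothetical builder call): makeArithReduction(b, loc,
// CombiningKind::ADD, x, y) materializes arith.addi for integer/index
// operands and arith.addf for float operands.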
5132 
5133 //===----------------------------------------------------------------------===//
5134 // TableGen'd op method definitions
5135 //===----------------------------------------------------------------------===//
5136 
5137 #define GET_OP_CLASSES
5138 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:442
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
bool isF32() const
Definition: Types.cpp:23
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1349
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:653
virtual ParseResult parseLParen()=0
Parse a ( token.
An attribute that represents a reference to a dense float vector or tensor object.
MLIRContext * getContext() const
Definition: Builders.h:54
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
U cast() const
Definition: Attributes.h:135
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:466
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:369
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:439
U dyn_cast_or_null() const
Definition: Attributes.h:131
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:355
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
unsigned getNumSymbols() const
Definition: AffineMap.cpp:298
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:514
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
int64_t getValue() const
Definition: AffineExpr.cpp:506
Block represents an ordered list of Operations.
Definition: Block.h:29
CombiningKind getKind() const
Definition: VectorOps.cpp:227
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:492
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
Value getOperand(unsigned idx)
Definition: Operation.h:267
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:911
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:3094
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
This is a utility allocator used to allocate memory for instances of derived types.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:205
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, StringRef targetIteratorTypeName, MLIRContext *context)
Definition: VectorOps.cpp:815
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
SmallVector< int64_t, 4 > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper that returns a subset of arrayAttr as a vector of int64_t.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1408
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:623
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:807
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:562
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d)
Helper method to apply dimension ordering permutation.
bool isLastMemrefDimUnitStride(MemRefType type)
Return true if the last dimension of the MemRefType has unit stride.
Definition: VectorOps.cpp:115
This is the representation of an operand reference.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:244
constexpr StringRef getIteratorTypesAttrName()
Attribute name for the StrArrayAttr which encodes the type of a structured op&#39;s iterators.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:92
virtual ParseResult parseCustomAttributeWithFallback(Attribute &result, Type type, function_ref< ParseResult(Attribute &result, Type type)> parseAttribute)=0
Parse a custom attribute with the provided callback, unless the next token is #, in which case the ge...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
static ArrayRef< int64_t > vectorShape(Type type)
virtual ParseResult parseComma()=0
Parse a , token.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
static constexpr const bool value
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:311
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:1685
void map(Block *from, Block *to)
Inserts a new mapping for &#39;from&#39; to &#39;to&#39;.
unsigned getNumInputs() const
Definition: AffineMap.cpp:303
static DefaultResource * get()
Returns a unique instance for the given effect class.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:282
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:300
virtual ParseResult parseLSquare()=0
Parse a [ token.
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector...
Definition: VectorOps.h:51
T * allocate()
Allocate an instance of the provided type.
U dyn_cast() const
Definition: AffineExpr.h:281
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:2140
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:2822
An attribute that represents a reference to a dense vector or tensor object.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseGreater()=0
Parse a &#39;>&#39; token.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
void addOperands(ValueRange newOperands)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:1591
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:720
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:548
U dyn_cast() const
Definition: Types.h:270
static LogicalResult foldMemRefCast(Operation *op)
This is a common class used for patterns of the form someop(memrefcast) -> someop It folds the source...
Definition: VectorOps.cpp:3069
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:478
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
bool isF16() const
Definition: Types.cpp:22
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, unsigned memorySpace=0)
Return a MemRefType to which the type of the given value can be bufferized.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Operation::operand_range getIndices(Operation *op)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1077
virtual ParseResult parseRParen()=0
Parse a ) token.
Base type for affine expression.
Definition: AffineExpr.h:68
SmallVector< int64_t, 4 > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
static LogicalResult foldTensorCast(Operation *op)
Definition: VectorOps.cpp:3081
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:2185
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:2851
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
virtual ParseResult parseLess()=0
Parse a &#39;<&#39; token.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:1834
void print(AsmPrinter &p) const
Definition: VectorOps.cpp:247
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
unsigned getNumResults() const
Definition: AffineMap.cpp:302
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:46
This represents an operation in an abstracted form, suitable for use with the builder APIs...
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns)
Collect a set of vector-to-vector canonicalization patterns.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
A multi-dimensional affine map Affine map&#39;s are immutable like Type&#39;s, and they are uniqued...
Definition: AffineMap.h:42
U cast() const
Definition: AffineExpr.h:291
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1069
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:102
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
This class represents a specific instance of an effect.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
virtual ParseResult parseRSquare()=0
Parse a ] token.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:93
LogicalResult emitOptionalError(Optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:489
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:210
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1315
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:233
This base class exposes generic asm parser hooks, usable across the various derived parsers...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:135
bool isa() const
Definition: AffineExpr.h:270
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:315
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:122
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
An attribute that specifies the combining function for vector.contract, and vector.reduction.
Definition: VectorOps.h:118
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:2105
MLIRContext * getContext() const
Get the context held by this operation state.
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:279
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:294
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
constexpr StringRef getIndexingMapsAttrName()
Attribute name for the AffineArrayAttr which encodes the relationship between a structured op iterato...
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:634
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static constexpr const CombiningKind combiningKindsList[]
Definition: VectorOps.cpp:231
Do not split vector transfer operations.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:3111
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:2463
U dyn_cast() const
Definition: Attributes.h:127
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:332
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:2163
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
constexpr StringRef getStridesAttrName()
Attribute name for the StrArrayAttr which encodes the value of strides.
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:2120
virtual ParseResult parseType(Type &result)=0
Parse a type.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:326
static Value foldExtractStridedOpFromInsertChain(ExtractOp op)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1449
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineMap calculateImplicitMap(MapOp op)
Definition: VectorOps.cpp:1656
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class represents an operand of an operation.
Definition: Value.h:251
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:634
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:328
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:85
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Same behavior as isDisjointTransferSet but doesn't require the operations to have the same tensor/memref operand.
Definition: VectorOps.cpp:157
Base storage class appearing in an attribute.
U cast() const
Definition: Value.h:108
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure; a callback populates a diagnostic with the reason for the failure.
Definition: PatternMatch.h:512
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:322
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar type.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
static CombiningKindAttr get(CombiningKind kind, MLIRContext *context)
Definition: VectorOps.cpp:222
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:377
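A hedged sketch (helper and argument choice are illustrative): append a block with one index-typed argument to `region`, leaving the builder positioned inside it.

#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

static mlir::Block *makeBlockWithIndexArg(mlir::OpBuilder &builder,
                                          mlir::Region &region,
                                          mlir::Location loc) {
  llvm::SmallVector<mlir::Type> argTypes{builder.getIndexType()};
  llvm::SmallVector<mlir::Location> argLocs{loc};
  // The builder's insertion point is moved to the end of the new block.
  return builder.createBlock(&region, region.end(), argTypes, argLocs);
}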
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op ", which is convenient for verifiers.
Definition: Operation.cpp:508
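For illustration, the usual shape of a verifier; MyOp and its constraint are hypothetical:

#include "mlir/Support/LogicalResult.h"

// emitOpError prefixes the op name, so on failure this reads e.g.
// "'my.op' op expected at least one operand".
mlir::LogicalResult MyOp::verify() {
  if (getOperation()->getNumOperands() == 0)
    return emitOpError("expected at least one operand");
  return mlir::success();
}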
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:299
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
bool isa() const
Definition: Types.h:254
constexpr StringRef getReductionIteratorTypeName()
Used to encode that a particular iterator type has reduction semantics.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in the expression lists, and as many symbols as the largest symbol in the expression lists.
Definition: AffineMap.cpp:235
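A worked sketch of the rule (names illustrative): d1 is the largest dim appearing in either list, so every resulting map is built with two dims.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"

static llvm::SmallVector<mlir::AffineMap, 4>
exampleInferredMaps(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, ctx);
  mlir::AffineExpr d1 = mlir::getAffineDimExpr(1, ctx);
  llvm::SmallVector<mlir::AffineExpr> listA = {d0, d1};
  llvm::SmallVector<mlir::AffineExpr> listB = {d1};
  // Yields (d0, d1) -> (d0, d1) and (d0, d1) -> (d1).
  return mlir::AffineMap::inferFromExprList(
      llvm::ArrayRef<llvm::ArrayRef<mlir::AffineExpr>>{listA, listB});
}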
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with ||.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
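A hedged sketch combining this with getValueOrCreateConstantIndexOp (listed earlier in this index); the helper name and dynamic-size protocol are illustrative, and the utility header paths are assumptions that vary by MLIR version:

#include "mlir/Dialect/Utils/StaticValueUtils.h" // getConstantIntValue (assumed path)
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

// Return the static extent when `ofr` wraps a known integer; otherwise
// materialize an index Value and return the dynamic-size sentinel.
static int64_t staticSizeOrMaterialize(mlir::OpBuilder &b, mlir::Location loc,
                                       mlir::OpFoldResult ofr,
                                       mlir::Value &dynamicSize) {
  if (mlir::Optional<int64_t> cst = mlir::getConstantIntValue(ofr))
    return *cst;
  dynamicSize = mlir::getValueOrCreateConstantIndexOp(b, loc, ofr);
  return mlir::ShapedType::kDynamicSize;
}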
static Attribute parse(AsmParser &parser, Type type)
Definition: VectorOps.cpp:257
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:908
VectorType transferMaskType(VectorType vecType, AffineMap map)
Given the vector type and the permutation map of a vector transfer op, compute the expected mask type.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:208
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
MLIRContext * getContext() const
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Square brackets surrounding zero or more operands.
An attribute that represents a reference to a dense integer vector or tensor object.
U cast() const
Definition: Types.h:278
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:2937
The main mechanism for performing data layout queries.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
static MaskFormat get1DMaskFormat(Value mask)
Helper method to classify a 1-D mask value.
Definition: VectorOps.cpp:56
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Return true if we can prove that the transfer operations access disjoint memory.
Definition: VectorOps.cpp:189
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
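A worked example under the usual interpretation (basis = row-major strides; the header path is an assumption):

#include "mlir/Dialect/Utils/IndexingUtils.h"

// Row-major strides of a 2x3x4 shape are {12, 4, 1}, so offsets {1, 2, 3}
// linearize to 1*12 + 2*4 + 3*1 == 23.
static int64_t exampleLinearIndex() {
  return mlir::linearize({1, 2, 3}, {12, 4, 1});
}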
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:270