1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LLVM.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/ADT/StringSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/ADT/bit.h"
40 
41 #include <cassert>
42 #include <cstdint>
43 #include <numeric>
44 
45 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
46 // Pull in all enum type and utility function definitions.
47 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
48 
49 using namespace mlir;
50 using namespace mlir::vector;
51 
52 /// Helper enum to classify a mask value.
53 enum class MaskFormat {
54  AllTrue = 0,
55  AllFalse = 1,
56  Unknown = 2,
57 };
58 
59 /// Helper method to classify a mask value. Currently, the method
60 /// looks "under the hood" of a constant value with dense attributes
61 /// and a constant mask operation (since the client may be called at
62 /// various stages during progressive lowering).
63 static MaskFormat getMaskFormat(Value mask) {
64  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
65  // Inspect constant dense values. We count up for bits that
66  // are set, count down for bits that are cleared, and bail
67  // when a mix is detected.
68  if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
69  int64_t val = 0;
70  for (bool b : denseElts.getValues<bool>())
71  if (b && val >= 0)
72  val++;
73  else if (!b && val <= 0)
74  val--;
75  else
76  return MaskFormat::Unknown;
77  if (val > 0)
78  return MaskFormat::AllTrue;
79  if (val < 0)
80  return MaskFormat::AllFalse;
81  }
82  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
83  // Inspect constant mask index. If the index exceeds the
84  // dimension size, all bits are set. If the index is zero
85  // or less, no bits are set.
86  ArrayAttr masks = m.getMaskDimSizes();
87  auto shape = m.getType().getShape();
88  bool allTrue = true;
89  bool allFalse = true;
90  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
91  int64_t i = maskIdx.cast<IntegerAttr>().getInt();
92  if (i < dimSize)
93  allTrue = false;
94  if (i > 0)
95  allFalse = false;
96  }
97  if (allTrue)
98  return MaskFormat::AllTrue;
99  if (allFalse)
100  return MaskFormat::AllFalse;
101  }
102  return MaskFormat::Unknown;
103 }
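// For illustration, a minimal sketch of how the classification above behaves
// on `vector.constant_mask` (example values, not from the original source):
//
//   %m0 = vector.constant_mask [4] : vector<4xi1>   // MaskFormat::AllTrue
//   %m1 = vector.constant_mask [0] : vector<4xi1>   // MaskFormat::AllFalse
//   %m2 = vector.constant_mask [2] : vector<4xi1>   // MaskFormat::Unknown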
104 
105 /// Default callback to build a region with a 'vector.yield' terminator with no
106 /// arguments.
107 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
108  builder.create<vector::YieldOp>(loc);
109 }
110 
111 // Helper for verifying combining kinds in contractions and reductions.
112 static bool isSupportedCombiningKind(CombiningKind combiningKind,
113  Type elementType) {
114  switch (combiningKind) {
115  case CombiningKind::ADD:
116  case CombiningKind::MUL:
117  return elementType.isIntOrIndexOrFloat();
118  case CombiningKind::MINUI:
119  case CombiningKind::MINSI:
120  case CombiningKind::MAXUI:
121  case CombiningKind::MAXSI:
122  case CombiningKind::AND:
123  case CombiningKind::OR:
124  case CombiningKind::XOR:
125  return elementType.isIntOrIndex();
126  case CombiningKind::MINF:
127  case CombiningKind::MAXF:
128  return elementType.isa<FloatType>();
129  }
130  return false;
131 }
132 
133 /// Return true if the last dimension of the MemRefType has unit stride. Also
134 /// return true for memrefs with no strides.
135 bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
136  int64_t offset;
137  SmallVector<int64_t> strides;
138  auto successStrides = getStridesAndOffset(type, strides, offset);
139  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
140 }
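// For illustration (assumed example types): `memref<4x8xf32>` and
// `memref<4x8xf32, strided<[16, 1]>>` have a unit-stride last dimension,
// whereas `memref<4x8xf32, strided<[1, 4]>>` does not.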
141 
142 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
143  VectorType vectorType) {
144  int64_t elementVectorRank = 0;
145  VectorType elementVectorType =
146  shapedType.getElementType().dyn_cast<VectorType>();
147  if (elementVectorType)
148  elementVectorRank += elementVectorType.getRank();
149  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
150  // TODO: replace once we have 0-d vectors.
151  if (shapedType.getRank() == 0 &&
152  vectorType.getShape() == ArrayRef<int64_t>{1})
153  return AffineMap::get(
154  /*numDims=*/0, /*numSymbols=*/0,
155  getAffineConstantExpr(0, shapedType.getContext()));
156  return AffineMap::getMinorIdentityMap(
157  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
158  shapedType.getContext());
159 }
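// For illustration (assumed shapes): for a rank-3 shaped type and a rank-1
// vector with a scalar element type, the minor identity map built above is
// `affine_map<(d0, d1, d2) -> (d2)>`.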
160 
161 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
162  vector::TransferReadOp read) {
163  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
164  !read.getMask() && defWrite.getIndices() == read.getIndices() &&
165  defWrite.getVectorType() == read.getVectorType() &&
166  defWrite.getPermutationMap() == read.getPermutationMap();
167 }
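// For illustration, a minimal sketch of the RAW case the helper above covers
// (names and shapes assumed):
//
//   vector.transfer_write %v, %A[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x8xf32>, memref<64x64xf32>
//   %r = vector.transfer_read %A[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<64x64xf32>, vector<4x8xf32>
//
// With identical indices, types and permutation maps and no masks, %r reads
// exactly the value written as %v.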
168 
169 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
170  vector::TransferWriteOp priorWrite) {
171  return priorWrite.getIndices() == write.getIndices() &&
172  priorWrite.getMask() == write.getMask() &&
173  priorWrite.getVectorType() == write.getVectorType() &&
174  priorWrite.getPermutationMap() == write.getPermutationMap();
175 }
176 
177 bool mlir::vector::isDisjointTransferIndices(
178  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
179  // For simplicity only look at transfer of same type.
180  if (transferA.getVectorType() != transferB.getVectorType())
181  return false;
182  unsigned rankOffset = transferA.getLeadingShapedRank();
183  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
184  auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
185  auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
186  // If any of the indices are dynamic we cannot prove anything.
187  if (!indexA || !indexB)
188  continue;
189 
190  if (i < rankOffset) {
191  // For leading dimensions, if we can prove that the indices are different
192  // we know we are accessing disjoint slices.
193  if (indexA.getValue().cast<IntegerAttr>().getInt() !=
194  indexB.getValue().cast<IntegerAttr>().getInt())
195  return true;
196  } else {
197  // For this dimension, we slice a part of the memref, so we need to make
198  // sure the intervals accessed don't overlap.
199  int64_t distance =
200  std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
201  indexB.getValue().cast<IntegerAttr>().getInt());
202  if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
203  return true;
204  }
205  }
206  return false;
207 }
208 
209 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
210  VectorTransferOpInterface transferB) {
211  if (transferA.source() != transferB.source())
212  return false;
213  return isDisjointTransferIndices(transferA, transferB);
214 }
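// For illustration (assumed example): two reads of vector<4xf32> from the same
// memref at constant indices 0 and 4 form a disjoint transfer set, since the
// index distance (4) is at least the vector size along that dimension:
//
//   %a = vector.transfer_read %A[%c0], %pad : memref<64xf32>, vector<4xf32>
//   %b = vector.transfer_read %A[%c4], %pad : memref<64xf32>, vector<4xf32>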
215 
216 // Helper to iterate over n-D vector slice elements. Calculate the next
217 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
218 // Modifies the `position` in place. Returns a failure when `position` becomes
219 // the end position.
220 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
221  ArrayRef<int64_t> shape,
222  ArrayRef<int64_t> offsets) {
223  for (auto [posInDim, dimSize, offsetInDim] :
224  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
225  ++posInDim;
226  if (posInDim < dimSize + offsetInDim)
227  return success();
228 
229  // Carry the overflow to the next loop iteration.
230  posInDim = offsetInDim;
231  }
232 
233  return failure();
234 }
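// For illustration (assumed values): with shape [2, 3] and offsets [0, 0],
// starting from position [0, 0], successive calls advance `position` to
// [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] and then return failure.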
235 
236 //===----------------------------------------------------------------------===//
237 // CombiningKindAttr
238 //===----------------------------------------------------------------------===//
239 
240 namespace mlir {
241 namespace vector {
242 namespace detail {
243 struct BitmaskEnumStorage : public AttributeStorage {
244  using KeyTy = uint64_t;
245 
246  BitmaskEnumStorage(KeyTy val) : value(val) {}
247 
248  bool operator==(const KeyTy &key) const { return value == key; }
249 
250  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
251  const KeyTy &key) {
252  return new (allocator.allocate<BitmaskEnumStorage>())
253  BitmaskEnumStorage(key);
254  }
255 
256  KeyTy value = 0;
257 };
258 } // namespace detail
259 } // namespace vector
260 } // namespace mlir
261 
262 //===----------------------------------------------------------------------===//
263 // VectorDialect
264 //===----------------------------------------------------------------------===//
265 
266 void VectorDialect::initialize() {
267  addAttributes<
268 #define GET_ATTRDEF_LIST
269 #include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
270  >();
271 
272  addOperations<
273 #define GET_OP_LIST
274 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
275  >();
276 }
277 
278 /// Materialize a single constant operation from a given attribute value with
279 /// the desired resultant type.
280 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
281  Attribute value, Type type,
282  Location loc) {
283  return builder.create<arith::ConstantOp>(loc, type, value);
284 }
285 
286 IntegerType vector::getVectorSubscriptType(Builder &builder) {
287  return builder.getIntegerType(64);
288 }
289 
290 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
291  ArrayRef<int64_t> values) {
292  return builder.getI64ArrayAttr(values);
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // MultiDimReductionOp
297 //===----------------------------------------------------------------------===//
298 
299 void vector::MultiDimReductionOp::build(OpBuilder &builder,
300  OperationState &result, Value source,
301  Value acc, ArrayRef<bool> reductionMask,
302  CombiningKind kind) {
303  SmallVector<int64_t> reductionDims;
304  for (const auto &en : llvm::enumerate(reductionMask))
305  if (en.value())
306  reductionDims.push_back(en.index());
307  build(builder, result, kind, source, acc,
308  builder.getI64ArrayAttr(reductionDims));
309 }
310 
311 OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
312  // Single parallel dim, this is a noop.
313  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
314  return getSource();
315  return {};
316 }
317 
318 Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
319  return llvm::to_vector<4>(getSourceVectorType().getShape());
320 }
321 
322 LogicalResult MultiDimReductionOp::verify() {
323  SmallVector<int64_t> targetShape;
324  Type inferredReturnType;
325  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
326  if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
327  return attr.cast<IntegerAttr>().getValue() == it.index();
328  }))
329  targetShape.push_back(it.value());
330  // TODO: update to also allow 0-d vectors when available.
331  if (targetShape.empty())
332  inferredReturnType = getSourceVectorType().getElementType();
333  else
334  inferredReturnType =
335  VectorType::get(targetShape, getSourceVectorType().getElementType());
336  if (getType() != inferredReturnType)
337  return emitOpError() << "destination type " << getType()
338  << " is incompatible with source type "
339  << getSourceVectorType();
340 
341  return success();
342 }
343 
344 namespace {
345 // Only unit dimensions that are being reduced are folded. If the dimension is
346 // unit, but not reduced, it is not folded, thereby keeping the output type the
347 // same. If not all dimensions which are reduced are of unit dimension, this
348 // transformation does nothing. This is just a generalization of
349 // ElideSingleElementReduction for ReduceOp.
350 struct ElideUnitDimsInMultiDimReduction
351  : public OpRewritePattern<MultiDimReductionOp> {
352  using OpRewritePattern::OpRewritePattern;
353 
354  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
355  PatternRewriter &rewriter) const override {
356  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
357  for (const auto &dim : enumerate(shape)) {
358  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
359  return failure();
360  }
361  Location loc = reductionOp.getLoc();
362  Value acc = reductionOp.getAcc();
363  Value cast;
364  if (reductionOp.getDestType().isa<VectorType>()) {
365  cast = rewriter.create<vector::ShapeCastOp>(
366  loc, reductionOp.getDestType(), reductionOp.getSource());
367  } else {
368  // This means we are reducing all the dimensions, and all reduction
369  // dimensions are of size 1. So a simple extraction would do.
370  cast = rewriter.create<vector::ExtractOp>(
371  loc, reductionOp.getDestType(), reductionOp.getSource(),
372  rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
373  }
374 
375  Value result = vector::makeArithReduction(rewriter, loc,
376  reductionOp.getKind(), acc, cast);
377  rewriter.replaceOp(reductionOp, result);
378  return success();
379  }
380 };
381 } // namespace
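// For illustration, a minimal sketch of the pattern above (types assumed):
//
//   %r = vector.multi_reduction <add>, %src, %acc [0]
//       : vector<1x8xf32> to vector<8xf32>
//
// Since the only reduced dimension has unit size, this can be rewritten as a
// shape_cast of %src to vector<8xf32> combined elementwise with %acc.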
382 
383 void MultiDimReductionOp::getCanonicalizationPatterns(
384  RewritePatternSet &results, MLIRContext *context) {
385  results.add<ElideUnitDimsInMultiDimReduction>(context);
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // ReductionOp
390 //===----------------------------------------------------------------------===//
391 
392 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
393  CombiningKind kind, Value vector) {
394  build(builder, result, kind, vector, /*acc=*/Value());
395 }
396 
397 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
398  CombiningKind kind, Value vector, Value acc) {
399  build(builder, result, vector.getType().cast<VectorType>().getElementType(),
400  kind, vector, acc);
401 }
402 
403 LogicalResult ReductionOp::verify() {
404  // Verify for 0-D and 1-D vector.
405  int64_t rank = getVectorType().getRank();
406  if (rank > 1)
407  return emitOpError("unsupported reduction rank: ") << rank;
408 
409  // Verify supported reduction kind.
410  Type eltType = getDest().getType();
411  if (!isSupportedCombiningKind(getKind(), eltType))
412  return emitOpError("unsupported reduction type '")
413  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
414  << "'";
415 
416  return success();
417 }
418 
419 ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
420  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
421  Type redType;
422  Type resType;
423  CombiningKindAttr kindAttr;
424  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
425  result.attributes) ||
426  parser.parseComma() || parser.parseOperandList(operandsInfo) ||
427  parser.parseColonType(redType) ||
428  parser.parseKeywordType("into", resType) ||
429  (!operandsInfo.empty() &&
430  parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
431  (operandsInfo.size() > 1 &&
432  parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
433  parser.addTypeToList(resType, result.types))
434  return failure();
435  if (operandsInfo.empty() || operandsInfo.size() > 2)
436  return parser.emitError(parser.getNameLoc(),
437  "unsupported number of operands");
438  return success();
439 }
440 
441 void ReductionOp::print(OpAsmPrinter &p) {
442  p << " ";
443  getKindAttr().print(p);
444  p << ", " << getVector();
445  if (getAcc())
446  p << ", " << getAcc();
447  p << " : " << getVector().getType() << " into " << getDest().getType();
448 }
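// For illustration, the custom form printed above looks like the following
// (operand names and types assumed):
//
//   %s = vector.reduction <add>, %v, %acc : vector<8xf32> into f32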
449 
450 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
451  OpBuilder &builder, Location loc,
452  Value vector) {
453  switch (op) {
454  case arith::AtomicRMWKind::addf:
455  case arith::AtomicRMWKind::addi:
456  return builder.create<vector::ReductionOp>(vector.getLoc(),
457  CombiningKind::ADD, vector);
458  case arith::AtomicRMWKind::mulf:
459  case arith::AtomicRMWKind::muli:
460  return builder.create<vector::ReductionOp>(vector.getLoc(),
461  CombiningKind::MUL, vector);
462  case arith::AtomicRMWKind::minf:
463  return builder.create<vector::ReductionOp>(vector.getLoc(),
464  CombiningKind::MINF, vector);
465  case arith::AtomicRMWKind::mins:
466  return builder.create<vector::ReductionOp>(vector.getLoc(),
467  CombiningKind::MINSI, vector);
468  case arith::AtomicRMWKind::minu:
469  return builder.create<vector::ReductionOp>(vector.getLoc(),
470  CombiningKind::MINUI, vector);
471  case arith::AtomicRMWKind::maxf:
472  return builder.create<vector::ReductionOp>(vector.getLoc(),
473  CombiningKind::MAXF, vector);
474  case arith::AtomicRMWKind::maxs:
475  return builder.create<vector::ReductionOp>(vector.getLoc(),
476  CombiningKind::MAXSI, vector);
477  case arith::AtomicRMWKind::maxu:
478  return builder.create<vector::ReductionOp>(vector.getLoc(),
479  CombiningKind::MAXUI, vector);
480  case arith::AtomicRMWKind::andi:
481  return builder.create<vector::ReductionOp>(vector.getLoc(),
482  CombiningKind::AND, vector);
483  case arith::AtomicRMWKind::ori:
484  return builder.create<vector::ReductionOp>(vector.getLoc(),
485  CombiningKind::OR, vector);
486  // TODO: Add remaining reduction operations.
487  default:
488  (void)emitOptionalError(loc, "Reduction operation type not supported");
489  break;
490  }
491  return nullptr;
492 }
493 
494 Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
495  return llvm::to_vector<4>(getVectorType().getShape());
496 }
497 
498 namespace {
499 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
500  using OpRewritePattern::OpRewritePattern;
501 
502  LogicalResult matchAndRewrite(ReductionOp reductionOp,
503  PatternRewriter &rewriter) const override {
504  if (reductionOp.getVectorType().getDimSize(0) != 1)
505  return failure();
506 
507  Location loc = reductionOp.getLoc();
508  Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
509  reductionOp.getVector(),
510  rewriter.getI64ArrayAttr(0));
511 
512  if (Value acc = reductionOp.getAcc())
513  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
514  result, acc);
515 
516  rewriter.replaceOp(reductionOp, result);
517  return success();
518  }
519 };
520 } // namespace
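// For illustration, a minimal sketch of the pattern above (types assumed):
//
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
//
// becomes an extract of %v[0] combined with %acc through arith.addf.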
521 
522 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
523  MLIRContext *context) {
524  results.add<ElideSingleElementReduction>(context);
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // ContractionOp
529 //===----------------------------------------------------------------------===//
530 
531 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
532  Value lhs, Value rhs, Value acc,
533  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
534  ArrayRef<IteratorType> iteratorTypes) {
535  result.addOperands({lhs, rhs, acc});
536  result.addTypes(acc.getType());
537  result.addAttribute(getIndexingMapsAttrName(result.name),
538  builder.getAffineMapArrayAttr(
539  AffineMap::inferFromExprList(indexingExprs)));
540  result.addAttribute(
541  getIteratorTypesAttrName(result.name),
542  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
543  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
544  return IteratorTypeAttr::get(builder.getContext(), t);
545  }))));
546 }
547 
548 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
549  Value lhs, Value rhs, Value acc,
550  ArrayAttr indexingMaps,
551  ArrayAttr iteratorTypes) {
552  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
553  ContractionOp::getDefaultKind());
554 }
555 
556 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
557  Value lhs, Value rhs, Value acc,
558  ArrayAttr indexingMaps,
559  ArrayAttr iteratorTypes, CombiningKind kind) {
560  result.addOperands({lhs, rhs, acc});
561  result.addTypes(acc.getType());
562  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
563  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
564  result.addAttribute(getKindAttrName(result.name),
565  CombiningKindAttr::get(builder.getContext(), kind));
566 }
567 
568 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
569  OpAsmParser::UnresolvedOperand lhsInfo;
570  OpAsmParser::UnresolvedOperand rhsInfo;
571  OpAsmParser::UnresolvedOperand accInfo;
572  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
573  SmallVector<Type, 2> types;
574  Type resultType;
575  auto loc = parser.getCurrentLocation();
576  DictionaryAttr dictAttr;
577  // TODO: Unify linalg op attribute parsing.
578  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
579  parser.parseOperand(lhsInfo) || parser.parseComma() ||
580  parser.parseOperand(rhsInfo) || parser.parseComma() ||
581  parser.parseOperand(accInfo) ||
582  parser.parseTrailingOperandList(masksInfo) ||
583  parser.parseOptionalAttrDict(result.attributes) ||
584  parser.parseColonTypeList(types) ||
585  parser.parseKeywordType("into", resultType) ||
586  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
587  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
588  parser.resolveOperand(accInfo, resultType, result.operands) ||
589  parser.addTypeToList(resultType, result.types))
590  return failure();
591  result.attributes.assign(dictAttr.getValue().begin(),
592  dictAttr.getValue().end());
593 
594  // Convert an array of strings into an array of IteratorType enums. This is
595  // needed because tests still use the old format, where the 'iterator_types'
596  // attribute is represented as an array of strings.
597  // TODO: Remove this conversion once tests are fixed.
598  ArrayAttr iteratorTypes =
599  result.attributes.get(getIteratorTypesAttrName(result.name))
600  .cast<ArrayAttr>();
601 
602  SmallVector<Attribute> iteratorTypeAttrs;
603 
604  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
605  auto maybeIteratorType = symbolizeIteratorType(s);
606  if (!maybeIteratorType.has_value())
607  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
608 
609  iteratorTypeAttrs.push_back(
610  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
611  }
612  result.attributes.set(getIteratorTypesAttrName(result.name),
613  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
614 
615  if (!result.attributes.get(getKindAttrName(result.name))) {
616  result.addAttribute(
617  getKindAttrName(result.name),
618  CombiningKindAttr::get(result.getContext(),
619  ContractionOp::getDefaultKind()));
620  }
621  if (masksInfo.empty())
622  return success();
623  if (masksInfo.size() != 2)
624  return parser.emitError(parser.getNameLoc(),
625  "expected zero or exactly 2 vector mask operands");
626  auto lhsType = types[0].cast<VectorType>();
627  auto rhsType = types[1].cast<VectorType>();
628  auto maskElementType = parser.getBuilder().getI1Type();
629  std::array<Type, 2> maskTypes = {
630  VectorType::Builder(lhsType).setElementType(maskElementType),
631  VectorType::Builder(rhsType).setElementType(maskElementType)};
632  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
633  return failure();
634  return success();
635 }
636 
637 void ContractionOp::print(OpAsmPrinter &p) {
638  // TODO: Unify printing code with linalg ops.
639  auto attrNames = getTraitAttrNames();
640  llvm::StringSet<> traitAttrsSet;
641  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
642  SmallVector<NamedAttribute, 8> attrs;
643  for (auto attr : (*this)->getAttrs()) {
644  if (attr.getName() == getIteratorTypesAttrName()) {
645  auto iteratorTypes =
646  attr.getValue()
647  .cast<ArrayAttr>()
648  .getAsValueRange<IteratorTypeAttr, IteratorType>();
649  // Convert IteratorType enums into the string representation. This is
650  // needed because tests still use the old format, where the 'iterator_types'
651  // attribute is represented as an array of strings.
652  // TODO: Remove this conversion once tests are fixed.
653  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
654  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
655  return StringAttr::get(getContext(), stringifyIteratorType(t));
656  }));
657 
658  attrs.emplace_back(getIteratorTypesAttrName(),
659  ArrayAttr::get(getContext(), iteratorTypeNames));
660  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
661  attrs.push_back(attr);
662  }
663 
664  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
665  p << " " << dictAttr << " " << getLhs() << ", ";
666  p << getRhs() << ", " << getAcc();
667  if (getMasks().size() == 2)
668  p << ", " << getMasks();
669 
670  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
671  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
672  << getResultType();
673 }
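// For illustration, a matmul-style contraction in the custom form handled by
// the parser and printer above (names, shapes and maps assumed):
//
//   %d = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                                           affine_map<(m, n, k) -> (k, n)>,
//                                           affine_map<(m, n, k) -> (m, n)>],
//                         iterator_types = ["parallel", "parallel", "reduction"],
//                         kind = #vector.kind<add>}
//       %a, %b, %c : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>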
674 
675 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
676  const std::vector<std::pair<int64_t, int64_t>> &map) {
677  for (auto &dimPair : map) {
678  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
679  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
680  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
681  return false;
682  }
683  return true;
684 }
685 
686 static LogicalResult verifyOutputShape(
687  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
688  Type resType,
689  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
690  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
691  DenseSet<int64_t> lhsContractingDimSet;
692  DenseSet<int64_t> rhsContractingDimSet;
693  for (auto &dimPair : contractingDimMap) {
694  lhsContractingDimSet.insert(dimPair.first);
695  rhsContractingDimSet.insert(dimPair.second);
696  }
697  DenseSet<int64_t> rhsBatchDimSet;
698  for (auto &dimPair : batchDimMap)
699  rhsBatchDimSet.insert(dimPair.second);
700 
701  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
702  SmallVector<int64_t, 4> expectedResultDims;
703  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
704  if (lhsContractingDimSet.count(i) > 0)
705  continue;
706  expectedResultDims.push_back(lhsType.getDimSize(i));
707  }
708 
709  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
710  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
711  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
712  continue;
713  expectedResultDims.push_back(rhsType.getDimSize(i));
714  }
715 
716  // Verify 'expectedResultDims'.
717  if (expectedResultDims.empty()) {
718  // No batch or free dimension implies a scalar result.
719  if (resType.isa<VectorType>() || accType.isa<VectorType>())
720  return op.emitOpError("invalid accumulator/result vector shape");
721  } else {
722  // At least one batch or free dimension implies a vector result.
723  auto resVectorType = resType.dyn_cast<VectorType>();
724  auto accVectorType = accType.dyn_cast<VectorType>();
725  if (!resVectorType || !accVectorType)
726  return op.emitOpError("invalid accumulator/result vector shape");
727 
728  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
729  // types fully define the result vector type. This assumes the affine maps
730  // are well-formed, which must have been verified already.
731  MLIRContext *ctx = op.getContext();
732  AffineMap lhsMap = op.getIndexingMapsArray()[0];
733  AffineMap rhsMap = op.getIndexingMapsArray()[1];
734  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
735  return op.emitOpError(
736  "expected all dimensions to be either a LHS or a RHS dimension");
737  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
738  for (auto pair :
739  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
740  VectorType v = pair.first;
741  auto map = pair.second;
742  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
743  unsigned pos = map.getDimPosition(idx);
744  if (!extents[pos])
745  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
746  }
747  }
748  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
749  return op.emitOpError("expected all dimensions to get an extent as "
750  "either a LHS or a RHS dimension");
751 
752  AffineMap resMap = op.getIndexingMapsArray()[2];
753  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
754  /*symCount=*/0, extents, ctx);
755  // Compose the resMap with the extentsMap, which is a constant map.
756  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
757  assert(llvm::all_of(
758  expectedMap.getResults(),
759  [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
760  "expected constant extent along all dimensions.");
761  // Extract the expected shape and build the type.
762  auto expectedShape = llvm::to_vector<4>(
763  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
764  return e.cast<AffineConstantExpr>().getValue();
765  }));
766  auto expected =
767  VectorType::get(expectedShape, resVectorType.getElementType());
768  if (resVectorType != expected || accVectorType != expected)
769  return op.emitOpError(
770  "invalid accumulator/result vector shape, expected: ")
771  << expected;
772  }
773  return success();
774 }
775 
776 LogicalResult ContractionOp::verify() {
777  auto lhsType = getLhsType();
778  auto rhsType = getRhsType();
779  auto accType = getAccType();
780  auto resType = getResultType();
781 
782  // Verify that an indexing map was specified for each vector operand.
783  if (getIndexingMapsArray().size() != 3)
784  return emitOpError("expected an indexing map for each vector operand");
785 
786  // Verify that each index map has 'numIterators' inputs, no symbols, and
787  // that the number of map outputs equals the rank of its associated
788  // vector operand.
789  unsigned numIterators = getIteratorTypes().getValue().size();
790  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
791  auto index = it.index();
792  auto map = it.value();
793  if (map.getNumSymbols() != 0)
794  return emitOpError("expected indexing map ")
795  << index << " to have no symbols";
796  auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
797  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
798  // Verify that the map has the right number of inputs, outputs, and indices.
799  // This also correctly accounts for (..) -> () for rank-0 results.
800  if (map.getNumDims() != numIterators)
801  return emitOpError("expected indexing map ")
802  << index << " to have " << numIterators << " number of inputs";
803  if (map.getNumResults() != rank)
804  return emitOpError("expected indexing map ")
805  << index << " to have " << rank << " number of outputs";
806  if (!map.isProjectedPermutation())
807  return emitOpError("expected indexing map ")
808  << index << " to be a projected permutation of its inputs";
809  }
810 
811  auto contractingDimMap = getContractingDimMap();
812  auto batchDimMap = getBatchDimMap();
813 
814  // Verify at least one contracting dimension pair was specified.
815  if (contractingDimMap.empty())
816  return emitOpError("expected at least one contracting dimension pair");
817 
818  // Verify contracting dimension map was properly constructed.
819  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
820  return emitOpError("invalid contracting dimension map");
821 
822  // Verify batch dimension map was properly constructed.
823  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
824  return emitOpError("invalid batch dimension map");
825 
826  // Verify 'accType' and 'resType' shape.
827  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
828  contractingDimMap, batchDimMap)))
829  return failure();
830 
831  // Verify that either two vector masks are set or none are set.
832  auto lhsMaskType = getLHSVectorMaskType();
833  auto rhsMaskType = getRHSVectorMaskType();
834  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
835  return emitOpError("invalid number of vector masks specified");
836  if (lhsMaskType && rhsMaskType) {
837  // Verify mask rank == argument rank.
838  if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
839  rhsMaskType.getShape().size() != rhsType.getShape().size())
840  return emitOpError("invalid vector mask rank");
841  }
842 
843  // Verify supported combining kind.
844  auto vectorType = resType.dyn_cast<VectorType>();
845  auto elementType = vectorType ? vectorType.getElementType() : resType;
846  if (!isSupportedCombiningKind(getKind(), elementType))
847  return emitOpError("unsupported contraction type");
848 
849  return success();
850 }
851 
852 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
853  return SmallVector<StringRef>{getIndexingMapsAttrName(),
854  getIteratorTypesAttrName(), getKindAttrName()};
855 }
856 
857 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
858  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
859  if (targetExpr == map.getResult(i))
860  return i;
861  return -1;
862 }
863 
864 static std::vector<std::pair<int64_t, int64_t>>
865 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
866  IteratorType targetIteratorType, MLIRContext *context) {
867  std::vector<std::pair<int64_t, int64_t>> dimMap;
868  for (const auto &it : llvm::enumerate(iteratorTypes)) {
869  auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
870  if (iteratorType != targetIteratorType)
871  continue;
872  // Search lhs/rhs map results for 'targetExpr'.
873  auto targetExpr = getAffineDimExpr(it.index(), context);
874  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
875  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
876  if (lhsDim >= 0 && rhsDim >= 0)
877  dimMap.emplace_back(lhsDim, rhsDim);
878  }
879  return dimMap;
880 }
881 
882 void ContractionOp::getIterationBounds(
883  SmallVectorImpl<int64_t> &iterationBounds) {
884  auto lhsShape = getLhsType().getShape();
885  auto resVectorType = getResultType().dyn_cast<VectorType>();
886  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
887  SmallVector<int64_t, 2> iterationShape;
888  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
889  // Search lhs/rhs map results for 'targetExpr'.
890  auto targetExpr = getAffineDimExpr(it.index(), getContext());
891  auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
892  if (iteratorType == IteratorType::reduction) {
893  // Get reduction dim size from lhs shape (same size in rhsShape).
894  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
895  assert(lhsDimIndex >= 0);
896  iterationBounds.push_back(lhsShape[lhsDimIndex]);
897  continue;
898  }
899  // Get parallel dimension size from result shape.
900  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
901  assert(resDimIndex >= 0);
902  assert(resVectorType != nullptr);
903  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
904  }
905 }
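// For illustration (assumed example): for a matmul-style contraction with
// lhs vector<4x8xf32>, rhs vector<8x16xf32> and iterators
// [parallel, parallel, reduction], the iteration bounds are [4, 16, 8]: the
// parallel bounds come from the result shape and the reduction bound from the
// lhs shape.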
906 
907 void ContractionOp::getIterationIndexMap(
908  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
909  unsigned numMaps = getIndexingMapsArray().size();
910  iterationIndexMap.resize(numMaps);
911  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
912  auto index = it.index();
913  auto map = it.value();
914  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
915  auto dim = map.getResult(i).cast<AffineDimExpr>();
916  iterationIndexMap[index][dim.getPosition()] = i;
917  }
918  }
919 }
920 
921 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
922  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
923  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
924  getContext());
925 }
926 
927 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
928  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
929  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
930  getContext());
931 }
932 
933 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
934  SmallVector<int64_t, 4> shape;
935  getIterationBounds(shape);
936  return shape;
937 }
938 
939 /// Return a fused vector::ContractionOp which represents a pattern such as:
940 ///
941 /// ```mlir
942 /// %c0 = vector.constant 0: ...
943 /// %c = vector.contract %a, %b, %c0: ...
944 /// %e = add %c, %d: ...
945 /// ```
946 ///
947 /// by:
948 ///
949 /// ```mlir
950 /// %e = vector.contract %a, %b, %d: ...
951 /// ```
952 ///
953 /// Return null if the canonicalization does not apply.
954 // TODO: This should be a folding of Add into Contract in core but while they
955 // live in different dialects, it is not possible without unnatural
956 // dependencies.
957 template <typename AddOpType>
958 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
959  using OpRewritePattern<AddOpType>::OpRewritePattern;
960 
961  LogicalResult matchAndRewrite(AddOpType addOp,
962  PatternRewriter &rewriter) const override {
963  auto canonicalize = [&](Value maybeContraction,
964  Value otherOperand) -> vector::ContractionOp {
965  vector::ContractionOp contractionOp =
966  dyn_cast_or_null<vector::ContractionOp>(
967  maybeContraction.getDefiningOp());
968  if (!contractionOp)
969  return vector::ContractionOp();
970  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
971  contractionOp.getAcc().getDefiningOp())) {
972  if (maybeZero.getValue() ==
973  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
975  bvm.map(contractionOp.getAcc(), otherOperand);
976  auto newContraction =
977  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
978  rewriter.replaceOp(addOp, newContraction.getResult());
979  return newContraction;
980  }
981  }
982  return vector::ContractionOp();
983  };
984 
985  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
986  vector::ContractionOp contract = canonicalize(a, b);
987  contract = contract ? contract : canonicalize(b, a);
988  return contract ? success() : failure();
989  }
990 };
991 
992 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
993  MLIRContext *context) {
994  results.add<CanonicalizeContractAdd<arith::AddIOp>,
995  CanonicalizeContractAdd<arith::AddFOp>>(context);
996 }
997 
998 //===----------------------------------------------------------------------===//
999 // ExtractElementOp
1000 //===----------------------------------------------------------------------===//
1001 
1002 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1003  Value source) {
1004  result.addOperands({source});
1005  result.addTypes(source.getType().cast<VectorType>().getElementType());
1006 }
1007 
1008 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1009  Value source, Value position) {
1010  result.addOperands({source, position});
1011  result.addTypes(source.getType().cast<VectorType>().getElementType());
1012 }
1013 
1014 LogicalResult vector::ExtractElementOp::verify() {
1015  VectorType vectorType = getVectorType();
1016  if (vectorType.getRank() == 0) {
1017  if (getPosition())
1018  return emitOpError("expected position to be empty with 0-D vector");
1019  return success();
1020  }
1021  if (vectorType.getRank() != 1)
1022  return emitOpError("unexpected >1 vector rank");
1023  if (!getPosition())
1024  return emitOpError("expected position for 1-D vector");
1025  return success();
1026 }
1027 
1028 OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
1029  // Skip the 0-D vector case for now.
1030  if (operands.size() < 2)
1031  return {};
1032 
1033  Attribute src = operands[0];
1034  Attribute pos = operands[1];
1035 
1036  // Fold extractelement (splat X) -> X.
1037  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1038  return splat.getInput();
1039 
1040  if (!pos || !src)
1041  return {};
1042 
1043  auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
1044 
1045  auto attr = pos.dyn_cast<IntegerAttr>();
1046  uint64_t posIdx = attr.getInt();
1047 
1048  return srcElements[posIdx];
1049 }
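// For illustration, a minimal sketch of the splat fold above (types assumed):
//
//   %s = vector.splat %x : vector<4xf32>
//   %e = vector.extractelement %s[%i : index] : vector<4xf32>
//
// folds %e to %x for any index %i.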
1050 
1051 //===----------------------------------------------------------------------===//
1052 // ExtractOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1056  Value source, ArrayRef<int64_t> position) {
1057  build(builder, result, source, getVectorSubscriptAttr(builder, position));
1058 }
1059 
1060 // Convenience builder which assumes the values are constant indices.
1061 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1062  Value source, ValueRange position) {
1063  SmallVector<int64_t, 4> positionConstants =
1064  llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1065  return pos.getDefiningOp<arith::ConstantIndexOp>().value();
1066  }));
1067  build(builder, result, source, positionConstants);
1068 }
1069 
1070 LogicalResult
1071 ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
1072  ValueRange operands, DictionaryAttr attributes,
1073  RegionRange,
1074  SmallVectorImpl<Type> &inferredReturnTypes) {
1075  ExtractOp::Adaptor op(operands, attributes);
1076  auto vectorType = op.getVector().getType().cast<VectorType>();
1077  if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
1078  inferredReturnTypes.push_back(vectorType.getElementType());
1079  } else {
1080  auto n =
1081  std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
1082  inferredReturnTypes.push_back(VectorType::get(
1083  vectorType.getShape().drop_front(n), vectorType.getElementType()));
1084  }
1085  return success();
1086 }
1087 
1088 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1089  // Allow extracting 1-element vectors instead of scalars.
1090  auto isCompatible = [](TypeRange l, TypeRange r) {
1091  auto vectorType = l.front().dyn_cast<VectorType>();
1092  return vectorType && vectorType.getShape().equals({1}) &&
1093  vectorType.getElementType() == r.front();
1094  };
1095  if (l.size() == 1 && r.size() == 1 &&
1096  (isCompatible(l, r) || isCompatible(r, l)))
1097  return true;
1098  return l == r;
1099 }
1100 
1101 LogicalResult vector::ExtractOp::verify() {
1102  auto positionAttr = getPosition().getValue();
1103  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
1104  return emitOpError(
1105  "expected position attribute of rank smaller than vector rank");
1106  for (const auto &en : llvm::enumerate(positionAttr)) {
1107  auto attr = en.value().dyn_cast<IntegerAttr>();
1108  if (!attr || attr.getInt() < 0 ||
1109  attr.getInt() >= getVectorType().getDimSize(en.index()))
1110  return emitOpError("expected position attribute #")
1111  << (en.index() + 1)
1112  << " to be a non-negative integer smaller than the corresponding "
1113  "vector dimension";
1114  }
1115  return success();
1116 }
1117 
1118 template <typename IntType>
1119 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1120  return llvm::to_vector<4>(llvm::map_range(
1121  arrayAttr.getAsRange<IntegerAttr>(),
1122  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1123 }
1124 
1125 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1126 /// positions.
1127 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1128  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1129  return failure();
1130 
1131  SmallVector<int64_t, 4> globalPosition;
1132  ExtractOp currentOp = extractOp;
1133  auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1134  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1135  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1136  currentOp = nextOp;
1137  auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1138  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1139  }
1140  extractOp.setOperand(currentOp.getVector());
1141  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1142  OpBuilder b(extractOp.getContext());
1143  std::reverse(globalPosition.begin(), globalPosition.end());
1144  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1145  b.getI64ArrayAttr(globalPosition));
1146  return success();
1147 }
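// For illustration, a minimal sketch of the chained-extract fold above (types
// assumed):
//
//   %a = vector.extract %v[1] : vector<2x3x4xf32>
//   %b = vector.extract %a[2] : vector<3x4xf32>
//
// folds in place to %b = vector.extract %v[1, 2] : vector<2x3x4xf32>.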
1148 
1149 namespace {
1150 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1151 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1152 /// Compose TransposeOp permutations as we walk back.
1153 /// This helper class keeps an updated extraction position `extractPosition`
1154 /// with extra trailing sentinels.
1155 /// The sentinels encode the internal transposition status of the result vector.
1156 /// As we iterate, extractPosition is permuted and updated.
1157 class ExtractFromInsertTransposeChainState {
1158 public:
1159  ExtractFromInsertTransposeChainState(ExtractOp e);
1160 
1161  /// Iterate over producing insert and transpose ops until we find a fold.
1162  Value fold();
1163 
1164 private:
1165  /// Return true if the vector at position `a` is contained within the vector
1166  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1167  /// is a prefix of `b`.
1168  template <typename ContainerA, typename ContainerB>
1169  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1170  return a.size() <= b.size() &&
1171  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1172  }
1173 
1174  /// Return true if the vector at position `a` intersects the vector at
1175  /// position `b`. Under insert/extract semantics, this is the same as equality
1176  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1177  /// Comparison is on the common prefix (i.e. zip).
1178  template <typename ContainerA, typename ContainerB>
1179  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1180  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1181  if (elemA < 0 || elemB < 0)
1182  continue;
1183  if (elemA != elemB)
1184  return false;
1185  }
1186  return true;
1187  }
1188 
1189  /// Folding is only possible in the absence of an internal permutation in the
1190  /// result vector.
1191  bool canFold() {
1192  return (sentinels ==
1193  makeArrayRef(extractPosition).drop_front(extractedRank));
1194  }
1195 
1196  // Helper to get the next defining op of interest.
1197  void updateStateForNextIteration(Value v) {
1198  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1199  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1200  };
1201 
1202  // Case 1. If we hit a transpose, just compose the map and iterate.
1203  // Invariant: insert + transpose do not change rank, we can always compose.
1204  LogicalResult handleTransposeOp();
1205 
1206  // Case 2: the insert position matches extractPosition exactly, early return.
1207  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1208 
1209  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1210  /// portion of the source of the insert.
1211  /// Example:
1212  /// ```
1213  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1214  /// // extractPosition == [1, 2, 3]
1215  /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
1216  /// // can fold to vector.extract %source[0, 3]
1217  /// %ext = vector.extract %source[3]: vector<5x6>
1218  /// ```
1219  /// To traverse through %source, we need to set the leading dims to 0 and
1220  /// drop the extra leading dims.
1221  /// This method updates the internal state.
1222  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1223 
1224  /// Try to fold in place to extract(source, extractPosition) and return the
1225  /// folded result. Return null if folding is not possible (e.g. due to an
1226  /// internal transposition in the result).
1227  Value tryToFoldExtractOpInPlace(Value source);
1228 
1229  ExtractOp extractOp;
1230  int64_t vectorRank;
1231  int64_t extractedRank;
1232 
1233  InsertOp nextInsertOp;
1234  TransposeOp nextTransposeOp;
1235 
1236  /// Sentinel values that encode the internal permutation status of the result.
1237  /// They are set to (-1, ... , -k) at the beginning and appended to
1238  /// `extractPosition`.
1239  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1240  /// ensure that there is no internal transposition.
1241  /// Internal transposition cannot be accounted for with a folding pattern.
1242  // TODO: We could relax the internal transposition with an extra transposition
1243  // operation in a future canonicalizer.
1244  SmallVector<int64_t> sentinels;
1245  SmallVector<int64_t> extractPosition;
1246 };
1247 } // namespace
1248 
1249 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1250  ExtractOp e)
1251  : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
1252  extractedRank(extractOp.getPosition().size()) {
1253  assert(vectorRank >= extractedRank && "extracted pos overflow");
1254  sentinels.reserve(vectorRank - extractedRank);
1255  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1256  sentinels.push_back(-(i + 1));
1257  extractPosition = extractVector<int64_t>(extractOp.getPosition());
1258  llvm::append_range(extractPosition, sentinels);
1259 }
1260 
1261 // Case 1. If we hit a transpose, just compose the map and iterate.
1262 // Invariant: insert + transpose do not change rank, we can always compose.
1263 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1264  if (!nextTransposeOp)
1265  return failure();
1266  auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
1267  AffineMap m = inversePermutation(
1268  AffineMap::getPermutationMap(permutation, extractOp.getContext()));
1269  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
1270  return success();
1271 }
1272 
1273 // Case 2: the insert position matches extractPosition exactly, early return.
1274 LogicalResult
1275 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1276  Value &res) {
1277  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1278  if (makeArrayRef(insertedPos) !=
1279  llvm::makeArrayRef(extractPosition).take_front(extractedRank))
1280  return failure();
1281  // Case 2.a. early-exit fold.
1282  res = nextInsertOp.getSource();
1283  // Case 2.b. if internal transposition is present, canFold will be false.
1284  return success();
1285 }
1286 
1287 /// Case 3: if inserted position is a prefix of extractPosition,
1288 /// extract a portion of the source of the insertion.
1289 /// This method updates the internal state.
1290 LogicalResult
1291 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1292  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1293  if (!isContainedWithin(insertedPos, extractPosition))
1294  return failure();
1295  // Set leading dims to zero.
1296  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1297  // Drop extra leading dims.
1298  extractPosition.erase(extractPosition.begin(),
1299  extractPosition.begin() + insertedPos.size());
1300  extractedRank = extractPosition.size() - sentinels.size();
1301  // Case 3.a. early-exit fold (break and delegate to post-while path).
1302  res = nextInsertOp.getSource();
1303  // Case 3.b. if internal transposition is present, canFold will be false.
1304  return success();
1305 }
1306 
1307 /// Try to fold in place to extract(source, extractPosition) and return the
1308 /// folded result. Return null if folding is not possible (e.g. due to an
1309 /// internal transposition in the result).
1310 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1311  Value source) {
1312  // If we can't fold (either internal transposition, or nothing to fold), bail.
1313  bool nothingToFold = (source == extractOp.getVector());
1314  if (nothingToFold || !canFold())
1315  return Value();
1316  // Otherwise, fold by updating the op inplace and return its result.
1317  OpBuilder b(extractOp.getContext());
1318  extractOp->setAttr(
1319  extractOp.getPositionAttrName(),
1320  b.getI64ArrayAttr(
1321  makeArrayRef(extractPosition).take_front(extractedRank)));
1322  extractOp.getVectorMutable().assign(source);
1323  return extractOp.getResult();
1324 }
1325 
1326 /// Iterate over producing insert and transpose ops until we find a fold.
1327 Value ExtractFromInsertTransposeChainState::fold() {
1328  Value valueToExtractFrom = extractOp.getVector();
1329  updateStateForNextIteration(valueToExtractFrom);
1330  while (nextInsertOp || nextTransposeOp) {
1331  // Case 1. If we hit a transpose, just compose the map and iterate.
1332  // Invariant: insert + transpose do not change rank, we can always compose.
1333  if (succeeded(handleTransposeOp())) {
1334  valueToExtractFrom = nextTransposeOp.getVector();
1335  updateStateForNextIteration(valueToExtractFrom);
1336  continue;
1337  }
1338 
1339  Value result;
1340  // Case 2: the position match exactly.
1341  if (succeeded(handleInsertOpWithMatchingPos(result)))
1342  return result;
1343 
1344  // Case 3: if the inserted position is a prefix of extractPosition, we can
1345  // just extract a portion of the source of the insert.
1346  if (succeeded(handleInsertOpWithPrefixPos(result)))
1347  return tryToFoldExtractOpInPlace(result);
1348 
1349  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1350  // values. This is a more difficult case and we bail.
1351  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1352  if (isContainedWithin(extractPosition, insertedPos) ||
1353  intersectsWhereNonNegative(extractPosition, insertedPos))
1354  return Value();
1355 
1356  // Case 5: No intersection, we forward the extract to insertOp.dest().
1357  valueToExtractFrom = nextInsertOp.getDest();
1358  updateStateForNextIteration(valueToExtractFrom);
1359  }
1360  // If after all this we can fold, go for it.
1361  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1362 }
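// For illustration, a minimal sketch of Case 2 above (types assumed):
//
//   %i = vector.insert %s, %d [3] : vector<4xf32> into vector<8x4xf32>
//   %e = vector.extract %i[3] : vector<8x4xf32>
//
// folds %e to %s, since the extract position matches the insert position.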
1363 
1364 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1365 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1366  Operation *defOp = extractOp.getVector().getDefiningOp();
1367  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1368  return Value();
1369  Value source = defOp->getOperand(0);
1370  if (extractOp.getType() == source.getType())
1371  return source;
1372  auto getRank = [](Type type) {
1373  return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1374  };
1375  // If splat or broadcast from a scalar, just return the source scalar.
1376  unsigned broadcastSrcRank = getRank(source.getType());
1377  if (broadcastSrcRank == 0)
1378  return source;
1379 
1380  unsigned extractResultRank = getRank(extractOp.getType());
1381  if (extractResultRank >= broadcastSrcRank)
1382  return Value();
1383  // Check that the dimensions of the result haven't been broadcasted.
1384  auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
1385  auto broadcastVecType = source.getType().dyn_cast<VectorType>();
1386  if (extractVecType && broadcastVecType &&
1387  extractVecType.getShape() !=
1388  broadcastVecType.getShape().take_back(extractResultRank))
1389  return Value();
1390 
1391  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1392  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1393  // Detect all the positions that come from "dim-1" broadcasting.
1394  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1395  // extract position to `0` when extracting from the source operand.
1396  llvm::SetVector<int64_t> broadcastedUnitDims =
1397  broadcastOp.computeBroadcastedUnitDims();
1398  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
1399  for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
1400  if (broadcastedUnitDims.contains(i))
1401  extractPos[i] = 0;
1402  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1403  // matching extract position when extracting from the source operand.
1404  extractPos.erase(extractPos.begin(),
1405  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1406  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1407  OpBuilder b(extractOp.getContext());
1408  extractOp.setOperand(source);
1409  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1410  b.getI64ArrayAttr(extractPos));
1411  return extractOp.getResult();
1412 }
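// For illustration, a minimal sketch of the broadcast fold above (types
// assumed):
//
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %e = vector.extract %b[1] : vector<4x8xf32>
//
// folds %e to %src, since only the broadcasted leading dimension is indexed.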
1413 
1414 // Fold extractOp with source coming from ShapeCast op.
1415 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1416  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1417  if (!shapeCastOp)
1418  return Value();
1419  // Get the nth dimension size starting from lowest dimension.
1420  auto getDimReverse = [](VectorType type, int64_t n) {
1421  return type.getShape().take_back(n + 1).front();
1422  };
1423  int64_t destinationRank =
1424  extractOp.getType().isa<VectorType>()
1425  ? extractOp.getType().cast<VectorType>().getRank()
1426  : 0;
1427  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1428  return Value();
1429  if (destinationRank > 0) {
1430  auto destinationType = extractOp.getResult().getType().cast<VectorType>();
1431  for (int64_t i = 0; i < destinationRank; i++) {
1432  // The lowest dimension of the destination must match the lowest
1433  // dimension of the shapecast op source.
1434  // TODO: This case could be supported in a canonicalization pattern.
1435  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1436  getDimReverse(destinationType, i))
1437  return Value();
1438  }
1439  }
1440  // Extract the strides associated with the extract op vector source. Then use
1441  // this to calculate a linearized position for the extract.
1442  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1443  std::reverse(extractedPos.begin(), extractedPos.end());
1444  SmallVector<int64_t, 4> strides;
1445  int64_t stride = 1;
1446  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1447  strides.push_back(stride);
1448  stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
1449  }
1450 
1451  int64_t position = linearize(extractedPos, strides);
1453  // Then extract the strides associated with the shapeCast op vector source and
1453  // delinearize the position using those strides.
1454  SmallVector<int64_t, 4> newStrides;
1455  int64_t numDimension =
1456  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1457  stride = 1;
1458  for (int64_t i = 0; i < numDimension; i++) {
1459  newStrides.push_back(stride);
1460  stride *=
1461  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1462  }
1463  std::reverse(newStrides.begin(), newStrides.end());
1464  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
1465  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1466  OpBuilder b(extractOp.getContext());
1467  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1468  b.getI64ArrayAttr(newPosition));
1469  extractOp.setOperand(shapeCastOp.getSource());
1470  return extractOp.getResult();
1471 }
1472 
1473 /// Fold an ExtractOp from ExtractStridedSliceOp.
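/// Example (illustrative):
///   %s = vector.extract_strided_slice %v
///          {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x8xf32> to vector<2x2xf32>
///   %e = vector.extract %s[0, 1] : vector<2x2xf32>
/// becomes %e = vector.extract %v[1, 3] : vector<4x8xf32>; the slice offsets
/// are simply added to the extract position.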
1474 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1475  auto extractStridedSliceOp =
1476  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1477  if (!extractStridedSliceOp)
1478  return Value();
1479  // Return if 'extractStridedSliceOp' has non-unit strides.
1480  if (extractStridedSliceOp.hasNonUnitStrides())
1481  return Value();
1482 
1483  // Trim offsets for dimensions fully extracted.
1484  auto sliceOffsets =
1485  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1486  while (!sliceOffsets.empty()) {
1487  size_t lastOffset = sliceOffsets.size() - 1;
1488  if (sliceOffsets.back() != 0 ||
1489  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1490  extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
1491  break;
1492  sliceOffsets.pop_back();
1493  }
1494  unsigned destinationRank = 0;
1495  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
1496  destinationRank = vecType.getRank();
1497  // The dimensions of the result need to be untouched by the
1498  // extractStridedSlice op.
1499  if (destinationRank >
1500  extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
1501  return Value();
1502  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1503  assert(extractedPos.size() >= sliceOffsets.size());
1504  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1505  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1506  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1507  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1508  OpBuilder b(extractOp.getContext());
1509  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1510  b.getI64ArrayAttr(extractedPos));
1511  return extractOp.getResult();
1512 }
1513 
1514 /// Fold extract_op fed from a chain of insertStridedSlice ops.
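/// Example (illustrative):
///   %i = vector.insert_strided_slice %s, %d {offsets = [2, 0], strides = [1, 1]}
///        : vector<2x8xf32> into vector<4x8xf32>
///   %e = vector.extract %i[3] : vector<4x8xf32>
/// reads a row that lies entirely inside the inserted chunk and becomes
///   %e = vector.extract %s[1] : vector<2x8xf32>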
1515 static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
1516  int64_t destinationRank = op.getType().isa<VectorType>()
1517  ? op.getType().cast<VectorType>().getRank()
1518  : 0;
1519  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
1520  while (insertOp) {
1521  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1522  insertOp.getSourceVectorType().getRank();
1523  if (destinationRank > insertOp.getSourceVectorType().getRank())
1524  return Value();
1525  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1526  auto extractOffsets = extractVector<int64_t>(op.getPosition());
1527 
1528  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1529  return attr.cast<IntegerAttr>().getInt() != 1;
1530  }))
1531  return Value();
1532  bool disjoint = false;
1533  SmallVector<int64_t, 4> offsetDiffs;
1534  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1535  int64_t start = insertOffsets[dim];
1536  int64_t size =
1537  (dim < insertRankDiff)
1538  ? 1
1539  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1540  int64_t end = start + size;
1541  int64_t offset = extractOffsets[dim];
1542  // Check if the start of the extract offset is in the interval inserted.
1543  if (start <= offset && offset < end) {
1544  if (dim >= insertRankDiff)
1545  offsetDiffs.push_back(offset - start);
1546  continue;
1547  }
1548  disjoint = true;
1549  break;
1550  }
1551  // The extracted chunk overlaps with the inserted vector.
1552  if (!disjoint) {
1553  // If any of the inner dimensions are only partially inserted we have a
1554  // partial overlap.
1555  int64_t srcRankDiff =
1556  insertOp.getSourceVectorType().getRank() - destinationRank;
1557  for (int64_t i = 0; i < destinationRank; i++) {
1558  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1559  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1560  insertRankDiff))
1561  return Value();
1562  }
1563  op.getVectorMutable().assign(insertOp.getSource());
1564  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1565  OpBuilder b(op.getContext());
1566  op->setAttr(ExtractOp::getPositionAttrStrName(),
1567  b.getI64ArrayAttr(offsetDiffs));
1568  return op.getResult();
1569  }
1570  // If the chunk extracted is disjoint from the chunk inserted, keep
1571  // looking in the insert chain.
1572  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1573  }
1574  return Value();
1575 }
1576 
1577 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1578  if (getPosition().empty())
1579  return getVector();
1580  if (succeeded(foldExtractOpFromExtractChain(*this)))
1581  return getResult();
1582  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1583  return res;
1584  if (auto res = foldExtractFromBroadcast(*this))
1585  return res;
1586  if (auto res = foldExtractFromShapeCast(*this))
1587  return res;
1588  if (auto val = foldExtractFromExtractStrided(*this))
1589  return val;
1590  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1591  return val;
1592  return OpFoldResult();
1593 }
1594 
1595 namespace {
1596 
1597 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1598 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1599 public:
1600  using OpRewritePattern::OpRewritePattern;
1601
1602  LogicalResult matchAndRewrite(ExtractOp extractOp,
1603  PatternRewriter &rewriter) const override {
1604  Operation *defOp = extractOp.getVector().getDefiningOp();
1605  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1606  return failure();
1607 
1608  Value source = defOp->getOperand(0);
1609  if (extractOp.getType() == source.getType())
1610  return failure();
1611  auto getRank = [](Type type) {
1612  return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1613  };
1614  unsigned broadcastSrcRank = getRank(source.getType());
1615  unsigned extractResultRank = getRank(extractOp.getType());
1616  // We only consider the case where the rank of the source is less than or
1617  // equal to the rank of the extract dst. The other cases are handled in the
1618  // folding patterns.
1619  if (extractResultRank < broadcastSrcRank)
1620  return failure();
1621  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1622  extractOp, extractOp.getType(), source);
1623  return success();
1624  }
1625 };
1626 
1627 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
1628 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1629 public:
1630  using OpRewritePattern::OpRewritePattern;
1631
1632  LogicalResult matchAndRewrite(ExtractOp extractOp,
1633  PatternRewriter &rewriter) const override {
1634  // Return if 'ExtractOp' operand is not defined by a splat vector
1635  // ConstantOp.
1636  Value sourceVector = extractOp.getVector();
1637  Attribute vectorCst;
1638  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1639  return failure();
1640  auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
1641  if (!splat)
1642  return failure();
1643  Attribute newAttr = splat.getSplatValue<Attribute>();
1644  if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1645  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1646  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1647  return success();
1648  }
1649 };
1650 
1651 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
1652 class ExtractOpNonSplatConstantFolder final
1653  : public OpRewritePattern<ExtractOp> {
1654 public:
1655  using OpRewritePattern::OpRewritePattern;
1656
1657  LogicalResult matchAndRewrite(ExtractOp extractOp,
1658  PatternRewriter &rewriter) const override {
1659  // Return if 'ExtractOp' operand is not defined by a compatible vector
1660  // ConstantOp.
1661  Value sourceVector = extractOp.getVector();
1662  Attribute vectorCst;
1663  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1664  return failure();
1665 
1666  auto vecTy = sourceVector.getType().cast<VectorType>();
1667  if (vecTy.isScalable())
1668  return failure();
1669 
1670  // The splat case is handled by `ExtractOpSplatConstantFolder`.
1671  auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
1672  if (!dense || dense.isSplat())
1673  return failure();
1674 
1675  // Calculate the linearized position of the contiguous chunk of elements to
1676  // extract.
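    // For instance (illustrative), extracting position [2] from a constant of
    // type vector<3x4xf32> starts the chunk at linearized element 2 * 4 = 8 and
    // copies 4 elements, while extracting [1, 3] reads the single element at
    // 1 * 4 + 3 = 7.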
1677  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
1678  copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
1679  int64_t elemBeginPosition =
1680  linearize(completePositions, computeStrides(vecTy.getShape()));
1681  auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
1682 
1683  Attribute newAttr;
1684  if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
1685  SmallVector<Attribute> elementValues(
1686  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
1687  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
1688  } else {
1689  newAttr = *denseValuesBegin;
1690  }
1691 
1692  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1693  return success();
1694  }
1695 };
1696 
1697 } // namespace
1698 
1699 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1700  MLIRContext *context) {
1701  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
1702  ExtractOpFromBroadcast>(context);
1703 }
1704 
1705 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1706  SmallVectorImpl<int64_t> &results) {
1707  for (auto attr : arrayAttr)
1708  results.push_back(attr.cast<IntegerAttr>().getInt());
1709 }
1710 
1711 //===----------------------------------------------------------------------===//
1712 // FMAOp
1713 //===----------------------------------------------------------------------===//
1714 
1715 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1716  return llvm::to_vector<4>(getVectorType().getShape());
1717 }
1718 
1719 //===----------------------------------------------------------------------===//
1720 // BroadcastOp
1721 //===----------------------------------------------------------------------===//
1722 
1723 /// Return the dimensions of the result vector that were formerly ones in the
1724 /// source vector and thus correspond to "dim-1" broadcasting.
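/// Example (illustrative): srcShape = 1x4x1 broadcast into dstShape = 7x2x4x3
/// yields the broadcasted unit dims {1, 3}; dimension 0 is a new leading
/// dimension and dimension 2 matches exactly.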
1725 static llvm::SetVector<int64_t>
1726 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
1727  ArrayRef<int64_t> dstShape) {
1728  int64_t rankDiff = dstShape.size() - srcShape.size();
1729  int64_t dstDim = rankDiff;
1730  llvm::SetVector<int64_t> res;
1731  for (auto [s1, s2] :
1732  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
1733  if (s1 != s2) {
1734  assert(s1 == 1 && "expected dim-1 broadcasting");
1735  res.insert(dstDim);
1736  }
1737  ++dstDim;
1738  }
1739  return res;
1740 }
1741 
1742 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
1743  // Scalar broadcast is without any unit dim broadcast.
1744  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
1745  if (!srcVectorType)
1746  return {};
1747  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
1748  getVectorType().getShape());
1749 }
1750 
1751 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
1752 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
1753 /// This requires (and asserts) that the broadcast is free of dim-1
1754 /// broadcasting.
1755 /// Since vector.broadcast only allows expanding leading dimensions, an extra
1756 /// vector.transpose may be inserted to make the broadcast possible.
1757 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
1758 /// the helper will assert. This means:
1759 /// 1. `dstShape` must not be empty.
1760 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
1761 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
1762 ///    must match the `value` shape.
1763 Value BroadcastOp::createOrFoldBroadcastOp(
1764  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
1765  const llvm::SetVector<int64_t> &broadcastedDims) {
1766  assert(!dstShape.empty() && "unexpected empty dst shape");
1767 
1768  // Step 1. Well-formedness check.
1769  SmallVector<int64_t> checkShape;
1770  for (int i = 0, e = dstShape.size(); i < e; ++i) {
1771  if (broadcastedDims.contains(i))
1772  continue;
1773  checkShape.push_back(dstShape[i]);
1774  }
1775  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
1776  "ill-formed broadcastedDims contains values not confined to "
1777  "destVectorShape");
1778 
1779  Location loc = value.getLoc();
1780  Type elementType = getElementTypeOrSelf(value.getType());
1781  VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
1782  VectorType dstVectorType = VectorType::get(dstShape, elementType);
1783 
1784  // Step 2. If scalar -> dstShape broadcast, just do it.
1785  if (!srcVectorType) {
1786  assert(checkShape.empty() &&
1787  "ill-formed createOrFoldBroadcastOp arguments");
1788  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
1789  }
1790 
1791  assert(srcVectorType.getShape().equals(checkShape) &&
1792  "ill-formed createOrFoldBroadcastOp arguments");
1793 
1794  // Step 3. Since vector.broadcast only allows creating leading dims,
1795  // vector -> dstShape broadcast may require a transpose.
1796  // Traverse the dims in order and construct:
1797  // 1. The leading entries of the broadcastShape that is guaranteed to be
1798  // achievable by a simple broadcast.
1799  // 2. The induced permutation for the subsequent vector.transpose that will
1800  // bring us from `broadcastShape` back to the desired `dstShape`.
1801  // If the induced permutation is not the identity, create a vector.transpose.
1802  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
1803  broadcastShape.reserve(dstShape.size());
1804  // Consider the example:
1805  // srcShape = 2x4
1806  // dstShape = 1x2x3x4x5
1807  // broadcastedDims = [0, 2, 4]
1808  //
1809  // We want to build:
1810  // broadcastShape = 1x3x5x2x4
1811  // permutation = [0, 2, 4, 1, 3]
1812  // ---V--- -----V-----
1813  // leading broadcast part src shape part
1814  //
1815  // Note that the trailing dims of broadcastShape are exactly the srcShape
1816  // by construction.
1817  // nextSrcShapeDim is used to keep track of where in the permutation the
1818  // "src shape part" occurs.
1819  int64_t nextSrcShapeDim = broadcastedDims.size();
1820  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
1821  if (broadcastedDims.contains(i)) {
1822  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
1823  // bring it to the head of the broadcastShape.
1824  // It will need to be permuted back from `broadcastShape.size() - 1` into
1825  // position `i`.
1826  broadcastShape.push_back(dstShape[i]);
1827  permutation[i] = broadcastShape.size() - 1;
1828  } else {
1829  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
1830  // shape and needs to be permuted into position `i`.
1831  // Don't touch `broadcastShape` here, the whole srcShape will be
1832  // appended after.
1833  permutation[i] = nextSrcShapeDim++;
1834  }
1835  }
1836  // 3.c. Append the srcShape.
1837  llvm::append_range(broadcastShape, srcVectorType.getShape());
1838 
1839  // Ensure there are no dim-1 broadcasts.
1840  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
1841  .empty() &&
1842  "unexpected dim-1 broadcast");
1843 
1844  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
1845  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
1846  vector::BroadcastableToResult::Success &&
1847  "must be broadcastable");
1848  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
1849  // Step 4. If we find any dimension that indeed needs to be permuted,
1850  // immediately return a new vector.transpose.
1851  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
1852  if (permutation[i] != i)
1853  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
1854  // Otherwise return res.
1855  return res;
1856 }
1857 
1858 BroadcastableToResult
1859 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1860  std::pair<int, int> *mismatchingDims) {
1861  // Broadcast scalar to vector of the same element type.
1862  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1863  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1864  return BroadcastableToResult::Success;
1865  // From now on, only vectors broadcast.
1866  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1867  if (!srcVectorType)
1868  return BroadcastableToResult::SourceTypeNotAVector;
1869 
1870  int64_t srcRank = srcVectorType.getRank();
1871  int64_t dstRank = dstVectorType.getRank();
1872  if (srcRank > dstRank)
1873  return BroadcastableToResult::SourceRankHigher;
1874  // Source has an exact match or singleton value for all trailing dimensions
1875  // (all leading dimensions are simply duplicated).
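  // For example (illustrative), vector<1x4xf32> is broadcastable to
  // vector<3x2x4xf32>: the leading dimension 3 is new, the unit dimension
  // stretches to 2, and the trailing 4 matches exactly.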
1876  int64_t lead = dstRank - srcRank;
1877  for (int64_t r = 0; r < srcRank; ++r) {
1878  int64_t srcDim = srcVectorType.getDimSize(r);
1879  int64_t dstDim = dstVectorType.getDimSize(lead + r);
1880  if (srcDim != 1 && srcDim != dstDim) {
1881  if (mismatchingDims) {
1882  mismatchingDims->first = srcDim;
1883  mismatchingDims->second = dstDim;
1884  }
1885  return BroadcastableToResult::DimensionMismatch;
1886  }
1887  }
1888 
1889  return BroadcastableToResult::Success;
1890 }
1891 
1892 LogicalResult BroadcastOp::verify() {
1893  std::pair<int, int> mismatchingDims;
1894  BroadcastableToResult res =
1895  isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
1896  if (res == BroadcastableToResult::Success)
1897  return success();
1898  if (res == BroadcastableToResult::SourceRankHigher)
1899  return emitOpError("source rank higher than destination rank");
1900  if (res == BroadcastableToResult::DimensionMismatch)
1901  return emitOpError("dimension mismatch (")
1902  << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1903  if (res == BroadcastableToResult::SourceTypeNotAVector)
1904  return emitOpError("source type is not a vector");
1905  llvm_unreachable("unexpected vector.broadcast op error");
1906 }
1907 
1908 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1909  if (getSourceType() == getVectorType())
1910  return getSource();
1911  if (!operands[0])
1912  return {};
1913  auto vectorType = getVectorType();
1914  if (operands[0].isa<IntegerAttr, FloatAttr>())
1915  return DenseElementsAttr::get(vectorType, operands[0]);
1916  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1917  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
1918  return {};
1919 }
1920 
1921 namespace {
1922 
1923 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
1924 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1925  using OpRewritePattern::OpRewritePattern;
1926
1927  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1928  PatternRewriter &rewriter) const override {
1929  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
1930  if (!srcBroadcast)
1931  return failure();
1932  rewriter.replaceOpWithNewOp<BroadcastOp>(
1933  broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
1934  return success();
1935  }
1936 };
1937 } // namespace
1938 
1939 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1940  MLIRContext *context) {
1941  // BroadcastToShapeCast is not a default canonicalization; it is opt-in and
1942  // enabled by calling `populateCastAwayVectorLeadingOneDimPatterns`.
1943  results.add<BroadcastFolder>(context);
1944 }
1945 
1946 //===----------------------------------------------------------------------===//
1947 // ShuffleOp
1948 //===----------------------------------------------------------------------===//
1949 
1950 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1951  Value v2, ArrayRef<int64_t> mask) {
1952  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
1953 }
1954 
1955 LogicalResult ShuffleOp::verify() {
1956  VectorType resultType = getVectorType();
1957  VectorType v1Type = getV1VectorType();
1958  VectorType v2Type = getV2VectorType();
1959  // Verify ranks.
1960  int64_t resRank = resultType.getRank();
1961  int64_t v1Rank = v1Type.getRank();
1962  int64_t v2Rank = v2Type.getRank();
1963  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
1964  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
1965  if (!wellFormed0DCase && !wellFormedNDCase)
1966  return emitOpError("rank mismatch");
1967 
1968  // Verify all but leading dimension sizes.
1969  for (int64_t r = 1; r < v1Rank; ++r) {
1970  int64_t resDim = resultType.getDimSize(r);
1971  int64_t v1Dim = v1Type.getDimSize(r);
1972  int64_t v2Dim = v2Type.getDimSize(r);
1973  if (resDim != v1Dim || v1Dim != v2Dim)
1974  return emitOpError("dimension mismatch");
1975  }
1976  // Verify mask length.
1977  auto maskAttr = getMask().getValue();
1978  int64_t maskLength = maskAttr.size();
1979  if (maskLength <= 0)
1980  return emitOpError("invalid mask length");
1981  if (maskLength != resultType.getDimSize(0))
1982  return emitOpError("mask length mismatch");
1983  // Verify all indices.
1984  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
1985  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
1986  for (const auto &en : llvm::enumerate(maskAttr)) {
1987  auto attr = en.value().dyn_cast<IntegerAttr>();
1988  if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1989  return emitOpError("mask index #") << (en.index() + 1) << " out of range";
1990  }
1991  return success();
1992 }
1993 
1994 LogicalResult
1995 ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
1996  ValueRange operands, DictionaryAttr attributes,
1997  RegionRange,
1998  SmallVectorImpl<Type> &inferredReturnTypes) {
1999  ShuffleOp::Adaptor op(operands, attributes);
2000  auto v1Type = op.getV1().getType().cast<VectorType>();
2001  auto v1Rank = v1Type.getRank();
2002  // Construct resulting type: leading dimension matches mask
2003  // length, all trailing dimensions match the operands.
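  // For example (illustrative), shuffling two vector<2x4xf32> operands with a
  // 6-entry mask infers the result type vector<6x4xf32>.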
2004  SmallVector<int64_t, 4> shape;
2005  shape.reserve(v1Rank);
2006  shape.push_back(std::max<size_t>(1, op.getMask().size()));
2007  // In the 0-D case there is no trailing shape to append.
2008  if (v1Rank > 0)
2009  llvm::append_range(shape, v1Type.getShape().drop_front());
2010  inferredReturnTypes.push_back(
2011  VectorType::get(shape, v1Type.getElementType()));
2012  return success();
2013 }
2014 
2015 static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2016  uint64_t expected = begin;
2017  return idxArr.size() == width &&
2018  llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2019  [&expected](auto attr) {
2020  return attr.getZExtValue() == expected++;
2021  });
2022 }
2023 
2024 OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
2025  VectorType v1Type = getV1VectorType();
2026  // For consistency: the 0-D shuffle return type is 1-D, so this cannot be a
2027  // folding and must instead be a canonicalization into a vector.broadcast.
2028  if (v1Type.getRank() == 0)
2029  return {};
2030 
2031  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2032  if (!v1Type.isScalable() &&
2033  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2034  return getV1();
2035  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2036  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2037  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2038  getV2VectorType().getDimSize(0)))
2039  return getV2();
2040 
2041  Attribute lhs = operands.front(), rhs = operands.back();
2042  if (!lhs || !rhs)
2043  return {};
2044 
2045  auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
2046  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2047  // manipulation.
2048  if (lhsType.getRank() != 1)
2049  return {};
2050  int64_t lhsSize = lhsType.getDimSize(0);
2051 
2052  SmallVector<Attribute> results;
2053  auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
2054  auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
2055  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2056  int64_t i = index.getZExtValue();
2057  if (i >= lhsSize) {
2058  results.push_back(rhsElements[i - lhsSize]);
2059  } else {
2060  results.push_back(lhsElements[i]);
2061  }
2062  }
2063 
2064  return DenseElementsAttr::get(getVectorType(), results);
2065 }
2066 
2067 namespace {
2068 
2069 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2070 // to a broadcast.
2071 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2072  using OpRewritePattern::OpRewritePattern;
2073
2074  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2075  PatternRewriter &rewriter) const override {
2076  VectorType v1VectorType = shuffleOp.getV1VectorType();
2077  ArrayAttr mask = shuffleOp.getMask();
2078  if (v1VectorType.getRank() > 0)
2079  return failure();
2080  if (mask.size() != 1)
2081  return failure();
2082  Type resType = VectorType::Builder(v1VectorType).setShape({1});
2083  if (mask[0].cast<IntegerAttr>().getInt() == 0)
2084  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2085  shuffleOp.getV1());
2086  else
2087  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2088  shuffleOp.getV2());
2089  return success();
2090  }
2091 };
2092 
2093 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2094 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2095 public:
2096  using OpRewritePattern::OpRewritePattern;
2097
2098  LogicalResult matchAndRewrite(ShuffleOp op,
2099  PatternRewriter &rewriter) const override {
2100  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2101  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2102 
2103  if (!v1Splat || !v2Splat)
2104  return failure();
2105 
2106  if (v1Splat.getInput() != v2Splat.getInput())
2107  return failure();
2108 
2109  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2110  return success();
2111  }
2112 };
2113 
2114 } // namespace
2115 
2116 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2117  MLIRContext *context) {
2118  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2119 }
2120 
2121 //===----------------------------------------------------------------------===//
2122 // InsertElementOp
2123 //===----------------------------------------------------------------------===//
2124 
2125 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2126  Value source, Value dest) {
2127  build(builder, result, source, dest, {});
2128 }
2129 
2130 LogicalResult InsertElementOp::verify() {
2131  auto dstVectorType = getDestVectorType();
2132  if (dstVectorType.getRank() == 0) {
2133  if (getPosition())
2134  return emitOpError("expected position to be empty with 0-D vector");
2135  return success();
2136  }
2137  if (dstVectorType.getRank() != 1)
2138  return emitOpError("unexpected >1 vector rank");
2139  if (!getPosition())
2140  return emitOpError("expected position for 1-D vector");
2141  return success();
2142 }
2143 
2144 OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
2145  // Skip the 0-D vector here.
2146  if (operands.size() < 3)
2147  return {};
2148 
2149  Attribute src = operands[0];
2150  Attribute dst = operands[1];
2151  Attribute pos = operands[2];
2152  if (!src || !dst || !pos)
2153  return {};
2154 
2155  auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();
2156 
2157  SmallVector<Attribute> results(dstElements);
2158 
2159  auto attr = pos.dyn_cast<IntegerAttr>();
2160  uint64_t posIdx = attr.getInt();
2161 
2162  results[posIdx] = src;
2163 
2164  return DenseElementsAttr::get(getDestVectorType(), results);
2165 }
2166 
2167 //===----------------------------------------------------------------------===//
2168 // InsertOp
2169 //===----------------------------------------------------------------------===//
2170 
2171 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
2172  Value dest, ArrayRef<int64_t> position) {
2173  result.addOperands({source, dest});
2174  auto positionAttr = getVectorSubscriptAttr(builder, position);
2175  result.addTypes(dest.getType());
2176  result.addAttribute(getPositionAttrStrName(), positionAttr);
2177 }
2178 
2179 // Convenience builder which assumes the values are constant indices.
2180 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
2181  Value dest, ValueRange position) {
2182  SmallVector<int64_t, 4> positionConstants =
2183  llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
2184  return pos.getDefiningOp<arith::ConstantIndexOp>().value();
2185  }));
2186  build(builder, result, source, dest, positionConstants);
2187 }
2188 
2189 LogicalResult InsertOp::verify() {
2190  auto positionAttr = getPosition().getValue();
2191  auto destVectorType = getDestVectorType();
2192  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
2193  return emitOpError(
2194  "expected position attribute of rank smaller than dest vector rank");
2195  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
2196  if (srcVectorType &&
2197  (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
2198  static_cast<unsigned>(destVectorType.getRank())))
2199  return emitOpError("expected position attribute rank + source rank to "
2200  "match dest vector rank");
2201  if (!srcVectorType &&
2202  (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
2203  return emitOpError(
2204  "expected position attribute rank to match the dest vector rank");
2205  for (const auto &en : llvm::enumerate(positionAttr)) {
2206  auto attr = en.value().dyn_cast<IntegerAttr>();
2207  if (!attr || attr.getInt() < 0 ||
2208  attr.getInt() >= destVectorType.getDimSize(en.index()))
2209  return emitOpError("expected position attribute #")
2210  << (en.index() + 1)
2211  << " to be a non-negative integer smaller than the corresponding "
2212  "dest vector dimension";
2213  }
2214  return success();
2215 }
2216 
2217 namespace {
2218 
2219 // If insertOp is only inserting unit dimensions it can be transformed to a
2220 // broadcast.
2221 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2222 public:
2223  using OpRewritePattern::OpRewritePattern;
2224
2225  LogicalResult matchAndRewrite(InsertOp insertOp,
2226  PatternRewriter &rewriter) const override {
2227  auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
2228  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2229  srcVecType.getNumElements())
2230  return failure();
2231  rewriter.replaceOpWithNewOp<BroadcastOp>(
2232  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2233  return success();
2234  }
2235 };
2236 
2237 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2238 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2239 public:
2240  using OpRewritePattern::OpRewritePattern;
2241
2242  LogicalResult matchAndRewrite(InsertOp op,
2243  PatternRewriter &rewriter) const override {
2244  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2245  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2246 
2247  if (!srcSplat || !dstSplat)
2248  return failure();
2249 
2250  if (srcSplat.getInput() != dstSplat.getInput())
2251  return failure();
2252 
2253  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2254  return success();
2255  }
2256 };
2257 
2258 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2259 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2260 public:
2261  using OpRewritePattern::OpRewritePattern;
2262
2263  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2264  // unless the destination vector constant has a single use.
2265  static constexpr int64_t vectorSizeFoldThreshold = 256;
2266 
2267  LogicalResult matchAndRewrite(InsertOp op,
2268  PatternRewriter &rewriter) const override {
2269  // Return if 'InsertOp' operand is not defined by a compatible vector
2270  // ConstantOp.
2271  TypedValue<VectorType> destVector = op.getDest();
2272  Attribute vectorDestCst;
2273  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2274  return failure();
2275 
2276  VectorType destTy = destVector.getType();
2277  if (destTy.isScalable())
2278  return failure();
2279 
2280  // Make sure we do not create too many large constants.
2281  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2282  !destVector.hasOneUse())
2283  return failure();
2284 
2285  auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
2286 
2287  Value sourceValue = op.getSource();
2288  Attribute sourceCst;
2289  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2290  return failure();
2291 
2292  // Calculate the linearized position of the contiguous chunk of elements to
2293  // insert.
2294  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2295  copy(getI64SubArray(op.getPosition()), completePositions.begin());
2296  int64_t insertBeginPosition =
2297  linearize(completePositions, computeStrides(destTy.getShape()));
2298 
2299  SmallVector<Attribute> insertedValues;
2300  if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
2301  llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
2302  else
2303  insertedValues.push_back(sourceCst);
2304 
2305  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2306  copy(insertedValues, allValues.begin() + insertBeginPosition);
2307  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2308 
2309  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2310  return success();
2311  }
2312 };
2313 
2314 } // namespace
2315 
2316 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2317  MLIRContext *context) {
2318  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2319  InsertOpConstantFolder>(context);
2320 }
2321 
2322 // Eliminates insert operations that produce values identical to their source
2323 // value. This happens when the source and destination vectors have identical
2324 // sizes.
2325 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
2326  if (getPosition().empty())
2327  return getSource();
2328  return {};
2329 }
2330 
2331 //===----------------------------------------------------------------------===//
2332 // InsertStridedSliceOp
2333 //===----------------------------------------------------------------------===//
2334 
2335 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2336  Value source, Value dest,
2337  ArrayRef<int64_t> offsets,
2338  ArrayRef<int64_t> strides) {
2339  result.addOperands({source, dest});
2340  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2341  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2342  result.addTypes(dest.getType());
2343  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2344  result.addAttribute(getStridesAttrStrName(), stridesAttr);
2345 }
2346 
2347 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
2348 template <typename OpType>
2349 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2350  ArrayAttr arrayAttr,
2351  ArrayRef<int64_t> shape,
2352  StringRef attrName) {
2353  if (arrayAttr.size() > shape.size())
2354  return op.emitOpError("expected ")
2355  << attrName << " attribute of rank smaller than vector rank";
2356  return success();
2357 }
2358 
2359 // Verifies that all integers in `arrayAttr` are in the admissible interval.
2360 // If `halfOpen` is true then the admissible interval is [min, max).
2361 // Otherwise, the admissible interval is [min, max].
2362 template <typename OpType>
2363 static LogicalResult
2364 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2365  int64_t max, StringRef attrName,
2366  bool halfOpen = true) {
2367  for (auto attr : arrayAttr) {
2368  auto val = attr.cast<IntegerAttr>().getInt();
2369  auto upper = max;
2370  if (!halfOpen)
2371  upper += 1;
2372  if (val < min || val >= upper)
2373  return op.emitOpError("expected ") << attrName << " to be confined to ["
2374  << min << ", " << upper << ")";
2375  }
2376  return success();
2377 }
2378 
2379 // Verifies that all integers in `arrayAttr` are confined to the corresponding
2380 // shape dimension. If `halfOpen` is true then the admissible interval for
2381 // dimension d is [min, shape[d]). Otherwise, it is [min, shape[d]].
2382 template <typename OpType>
2383 static LogicalResult
2384 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2385  ArrayRef<int64_t> shape, StringRef attrName,
2386  bool halfOpen = true, int64_t min = 0) {
2387  for (auto [index, attrDimPair] :
2388  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
2389  int64_t val =
2390  std::get<0>(attrDimPair).template cast<IntegerAttr>().getInt();
2391  int64_t max = std::get<1>(attrDimPair);
2392  if (!halfOpen)
2393  max += 1;
2394  if (val < min || val >= max)
2395  return op.emitOpError("expected ")
2396  << attrName << " dimension " << index << " to be confined to ["
2397  << min << ", " << max << ")";
2398  }
2399  return success();
2400 }
2401 
2402 // Verifies that, for each dimension, the sum of the corresponding entries in
2403 // `arrayAttr1` and `arrayAttr2` is confined to the shape. If `halfOpen` is
2404 // true the admissible interval is [min, shape[d]); otherwise [min, shape[d]].
2405 template <typename OpType>
2406 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2407  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2408  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2409  bool halfOpen = true, int64_t min = 1) {
2410  assert(arrayAttr1.size() <= shape.size());
2411  assert(arrayAttr2.size() <= shape.size());
2412  for (auto [index, it] :
2413  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
2414  auto val1 = std::get<0>(it).template cast<IntegerAttr>().getInt();
2415  auto val2 = std::get<1>(it).template cast<IntegerAttr>().getInt();
2416  int64_t max = std::get<2>(it);
2417  if (!halfOpen)
2418  max += 1;
2419  if (val1 + val2 < 0 || val1 + val2 >= max)
2420  return op.emitOpError("expected sum(")
2421  << attrName1 << ", " << attrName2 << ") dimension " << index
2422  << " to be confined to [" << min << ", " << max << ")";
2423  }
2424  return success();
2425 }
2426 
2427 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2428  MLIRContext *context) {
2429  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2430  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2431  });
2432  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2433 }
2434 
2435 LogicalResult InsertStridedSliceOp::verify() {
2436  auto sourceVectorType = getSourceVectorType();
2437  auto destVectorType = getDestVectorType();
2438  auto offsets = getOffsetsAttr();
2439  auto strides = getStridesAttr();
2440  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2441  return emitOpError(
2442  "expected offsets of same size as destination vector rank");
2443  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2444  return emitOpError("expected strides of same size as source vector rank");
2445  if (sourceVectorType.getRank() > destVectorType.getRank())
2446  return emitOpError(
2447  "expected source rank to be smaller than destination rank");
2448 
2449  auto sourceShape = sourceVectorType.getShape();
2450  auto destShape = destVectorType.getShape();
2451  SmallVector<int64_t, 4> sourceShapeAsDestShape(
2452  destShape.size() - sourceShape.size(), 0);
2453  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2454  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2455  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2456  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2457  offName)) ||
2458  failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2459  stridesName,
2460  /*halfOpen=*/false)) ||
2461  failed(isSumOfIntegerArrayAttrConfinedToShape(
2462  *this, offsets,
2463  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2464  offName, "source vector shape",
2465  /*halfOpen=*/false, /*min=*/1)))
2466  return failure();
2467 
2468  return success();
2469 }
2470 
2471 namespace {
2472 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2473 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
2474 class FoldInsertStridedSliceSplat final
2475  : public OpRewritePattern<InsertStridedSliceOp> {
2476 public:
2477  using OpRewritePattern::OpRewritePattern;
2478
2479  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2480  PatternRewriter &rewriter) const override {
2481  auto srcSplatOp =
2482  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2483  auto destSplatOp =
2484  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2485 
2486  if (!srcSplatOp || !destSplatOp)
2487  return failure();
2488 
2489  if (srcSplatOp.getInput() != destSplatOp.getInput())
2490  return failure();
2491 
2492  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2493  return success();
2494  }
2495 };
2496 
2497 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2498 /// to dst.
2499 class FoldInsertStridedSliceOfExtract final
2500  : public OpRewritePattern<InsertStridedSliceOp> {
2501 public:
2502  using OpRewritePattern::OpRewritePattern;
2503
2504  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2505  PatternRewriter &rewriter) const override {
2506  auto extractStridedSliceOp =
2507  insertStridedSliceOp.getSource()
2508  .getDefiningOp<vector::ExtractStridedSliceOp>();
2509 
2510  if (!extractStridedSliceOp)
2511  return failure();
2512 
2513  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2514  return failure();
2515 
2516  // Check if they have the same strides and offsets.
2517  if (extractStridedSliceOp.getStrides() !=
2518  insertStridedSliceOp.getStrides() ||
2519  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2520  return failure();
2521 
2522  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2523  return success();
2524  }
2525 };
2526 
2527 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
2528 // ConstantOp.
2529 class InsertStridedSliceConstantFolder final
2530  : public OpRewritePattern<InsertStridedSliceOp> {
2531 public:
2532  using OpRewritePattern::OpRewritePattern;
2533
2534  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2535  // unless the destination vector constant has a single use.
2536  static constexpr int64_t vectorSizeFoldThreshold = 256;
2537 
2538  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
2539  PatternRewriter &rewriter) const override {
2540  // Return if 'InsertOp' operand is not defined by a compatible vector
2541  // ConstantOp.
2542  TypedValue<VectorType> destVector = op.getDest();
2543  Attribute vectorDestCst;
2544  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2545  return failure();
2546 
2547  VectorType destTy = destVector.getType();
2548  if (destTy.isScalable())
2549  return failure();
2550 
2551  // Make sure we do not create too many large constants.
2552  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2553  !destVector.hasOneUse())
2554  return failure();
2555 
2556  auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
2557 
2558  TypedValue<VectorType> sourceValue = op.getSource();
2559  Attribute sourceCst;
2560  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2561  return failure();
2562 
2563  // TODO: Handle non-unit strides when they become available.
2564  if (op.hasNonUnitStrides())
2565  return failure();
2566 
2567  VectorType sliceVecTy = sourceValue.getType();
2568  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
2569  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
2570  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
2571  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
2572 
2573  // Calculate the destination element indices by enumerating all slice
2574  // positions within the destination and linearizing them. The enumeration
2575  // order is lexicographic which yields a sequence of monotonically
2576  // increasing linearized position indices.
2577  // Because the destination may have higher dimensionality than the slice,
2578  // we keep track of two overlapping sets of positions and offsets.
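    // For instance (illustrative), inserting a 2x2 slice at offsets [1, 1]
    // into a 4x4 destination visits the slice elements in order and overwrites
    // the linearized destination elements 5, 6, 9 and 10.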
2579  auto denseSlice = sourceCst.cast<DenseElementsAttr>();
2580  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
2581  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
2582  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
2583  MutableArrayRef<int64_t> currSlicePosition(
2584  currDestPosition.begin() + rankDifference, currDestPosition.end());
2585  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
2586  offsets.end());
2587  do {
2588  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
2589  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
2590  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
2591  "Invalid slice element");
2592  newValues[linearizedPosition] = *sliceValuesIt;
2593  ++sliceValuesIt;
2594  } while (succeeded(
2595  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
2596 
2597  auto newAttr = DenseElementsAttr::get(destTy, newValues);
2598  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2599  return success();
2600  }
2601 };
2602 
2603 } // namespace
2604 
2605 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
2606  RewritePatternSet &results, MLIRContext *context) {
2607  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
2608  InsertStridedSliceConstantFolder>(context);
2609 }
2610 
2611 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2612  if (getSourceVectorType() == getDestVectorType())
2613  return getSource();
2614  return {};
2615 }
2616 
2617 //===----------------------------------------------------------------------===//
2618 // OuterProductOp
2619 //===----------------------------------------------------------------------===//
2620 
2621 /// Build an op without mask; use the type of `acc` as the return type.
2622 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
2623  Value lhs, Value rhs, Value acc) {
2624  result.addOperands({lhs, rhs, acc});
2625  result.addTypes(acc.getType());
2626 }
2627 
2629  p << " " << getLhs() << ", " << getRhs();
2630  if (!getAcc().empty()) {
2631  p << ", " << getAcc();
2632  p.printOptionalAttrDict((*this)->getAttrs());
2633  }
2634  p << " : " << getLhs().getType() << ", " << getRhs().getType();
2635 }
2636 
2637 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
2638  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
2639  Type tLHS, tRHS;
2640  if (parser.parseOperandList(operandsInfo) ||
2641  parser.parseOptionalAttrDict(result.attributes) ||
2642  parser.parseColonType(tLHS) || parser.parseComma() ||
2643  parser.parseType(tRHS))
2644  return failure();
2645  if (operandsInfo.size() < 2)
2646  return parser.emitError(parser.getNameLoc(),
2647  "expected at least 2 operands");
2648  VectorType vLHS = tLHS.dyn_cast<VectorType>();
2649  VectorType vRHS = tRHS.dyn_cast<VectorType>();
2650  if (!vLHS)
2651  return parser.emitError(parser.getNameLoc(),
2652  "expected vector type for operand #1");
2653  VectorType resType =
2654  vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2655  vLHS.getElementType())
2656  : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2657 
2658  if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
2659  result.attributes.append(
2660  OuterProductOp::getKindAttrStrName(),
2661  CombiningKindAttr::get(result.getContext(),
2662  OuterProductOp::getDefaultKind()));
2663  }
2664 
2665  return failure(
2666  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
2667  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
2668  (operandsInfo.size() > 2 &&
2669  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
2670  parser.addTypeToList(resType, result.types));
2671 }
2672 
2673 LogicalResult OuterProductOp::verify() {
2674  Type tRHS = getOperandTypeRHS();
2675  VectorType vLHS = getOperandVectorTypeLHS(),
2676  vRHS = tRHS.dyn_cast<VectorType>(),
2677  vACC = getOperandVectorTypeACC(), vRES = getVectorType();
2678 
2679  if (vLHS.getRank() != 1)
2680  return emitOpError("expected 1-d vector for operand #1");
2681 
2682  if (vRHS) {
2683  // Proper OUTER operation.
2684  if (vRHS.getRank() != 1)
2685  return emitOpError("expected 1-d vector for operand #2");
2686  if (vRES.getRank() != 2)
2687  return emitOpError("expected 2-d vector result");
2688  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2689  return emitOpError("expected #1 operand dim to match result dim #1");
2690  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2691  return emitOpError("expected #2 operand dim to match result dim #2");
2692  } else {
2693  // An AXPY operation.
2694  if (vRES.getRank() != 1)
2695  return emitOpError("expected 1-d vector result");
2696  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2697  return emitOpError("expected #1 operand dim to match result dim #1");
2698  }
2699 
2700  if (vACC && vACC != vRES)
2701  return emitOpError("expected operand #3 of same type as result type");
2702 
2703  // Verify supported combining kind.
2704  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
2705  return emitOpError("unsupported outerproduct type");
2706 
2707  return success();
2708 }
2709 
2710 //===----------------------------------------------------------------------===//
2711 // ReshapeOp
2712 //===----------------------------------------------------------------------===//
2713 
2714 LogicalResult ReshapeOp::verify() {
2715  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
2716  auto inputVectorType = getInputVectorType();
2717  auto outputVectorType = getOutputVectorType();
2718  int64_t inputShapeRank = getNumInputShapeSizes();
2719  int64_t outputShapeRank = getNumOutputShapeSizes();
2720  SmallVector<int64_t, 4> fixedVectorSizes;
2721  getFixedVectorSizes(fixedVectorSizes);
2722  int64_t numFixedVectorSizes = fixedVectorSizes.size();
2723 
2724  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2725  return emitError("invalid input shape for vector type ") << inputVectorType;
2726 
2727  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2728  return emitError("invalid output shape for vector type ")
2729  << outputVectorType;
2730 
2731  // Verify that the 'fixedVectorSizes' match an input/output vector shape
2732  // suffix.
2733  unsigned inputVectorRank = inputVectorType.getRank();
2734  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2735  unsigned index = inputVectorRank - numFixedVectorSizes - i;
2736  if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2737  return emitError("fixed vector size must match input vector for dim ")
2738  << i;
2739  }
2740 
2741  unsigned outputVectorRank = outputVectorType.getRank();
2742  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2743  unsigned index = outputVectorRank - numFixedVectorSizes - i;
2744  if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2745  return emitError("fixed vector size must match output vector for dim ")
2746  << i;
2747  }
2748 
2749  // If all shape operands are produced by constant ops, verify that the
2750  // products of the input and output shape dimensions match.
2751  auto isDefByConstant = [](Value operand) {
2752  return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2753  };
2754  if (llvm::all_of(getInputShape(), isDefByConstant) &&
2755  llvm::all_of(getOutputShape(), isDefByConstant)) {
2756  int64_t numInputElements = 1;
2757  for (auto operand : getInputShape())
2758  numInputElements *=
2759  cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2760  int64_t numOutputElements = 1;
2761  for (auto operand : getOutputShape())
2762  numOutputElements *=
2763  cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2764  if (numInputElements != numOutputElements)
2765  return emitError("product of input and output shape sizes must match");
2766  }
2767  return success();
2768 }
2769 
2770 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
2771  populateFromInt64AttrArray(getFixedVectorSizes(), results);
2772 }
2773 
2774 //===----------------------------------------------------------------------===//
2775 // ExtractStridedSliceOp
2776 //===----------------------------------------------------------------------===//
2777 
2778 // Inference works as follows:
2779 // 1. Add 'sizes' from prefix of dims in 'offsets'.
2780 // 2. Add sizes from 'vectorType' for remaining dims.
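// For example (illustrative), offsets/sizes/strides of length 2 with
// sizes = [2, 4] applied to vector<4x8x16xf32> infer the result type
// vector<2x4x16xf32>.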
2781 static Type inferStridedSliceOpResultType(VectorType vectorType,
2782  ArrayAttr offsets, ArrayAttr sizes,
2783  ArrayAttr strides) {
2784  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2785  SmallVector<int64_t, 4> shape;
2786  shape.reserve(vectorType.getRank());
2787  unsigned idx = 0;
2788  for (unsigned e = offsets.size(); idx < e; ++idx)
2789  shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
2790  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2791  shape.push_back(vectorType.getShape()[idx]);
2792 
2793  return VectorType::get(shape, vectorType.getElementType());
2794 }
2795 
2796 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2797  Value source, ArrayRef<int64_t> offsets,
2798  ArrayRef<int64_t> sizes,
2799  ArrayRef<int64_t> strides) {
2800  result.addOperands(source);
2801  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2802  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
2803  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2804  result.addTypes(
2805  inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
2806  offsetsAttr, sizesAttr, stridesAttr));
2807  result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2808  result.addAttribute(getSizesAttrStrName(), sizesAttr);
2809  result.addAttribute(getStridesAttrStrName(), stridesAttr);
2810 }
2811 
2812 LogicalResult ExtractStridedSliceOp::verify() {
2813  auto type = getVectorType();
2814  auto offsets = getOffsetsAttr();
2815  auto sizes = getSizesAttr();
2816  auto strides = getStridesAttr();
2817  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
2818  return emitOpError(
2819  "expected offsets, sizes and strides attributes of same size");
2820 
2821  auto shape = type.getShape();
2822  auto offName = getOffsetsAttrName();
2823  auto sizesName = getSizesAttrName();
2824  auto stridesName = getStridesAttrName();
2825  if (failed(
2826  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
2827  failed(
2828  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
2829  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
2830  stridesName)) ||
2831  failed(
2832  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
2833  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
2834  /*halfOpen=*/false,
2835  /*min=*/1)) ||
2836  failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2837  stridesName,
2838  /*halfOpen=*/false)) ||
2839  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
2840  shape, offName, sizesName,
2841  /*halfOpen=*/false)))
2842  return failure();
2843 
2844  auto resultType =
2845  inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
2846  if (getResult().getType() != resultType)
2847  return emitOpError("expected result type to be ") << resultType;
2848 
2849  return success();
2850 }
2851 
 2852 // When the source of an ExtractStridedSliceOp comes from a chain of
 2853 // InsertStridedSliceOps, try to use the source of one of those inserts if we
 2854 // can detect that the extracted vector is a subset of an inserted vector.
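// Editorial sketch (not from the upstream source; names and shapes are
// illustrative): when the extracted chunk lies entirely inside an inserted
// chunk, e.g.
//
//   %i = vector.insert_strided_slice %small, %big
//       {offsets = [2, 2], strides = [1, 1]} : vector<2x4xf32> into vector<8x8xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [2, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<8x8xf32> to vector<2x4xf32>
//
// the extract can read directly from %small, with its offsets rebased against
// the insertion offsets (here they become [0, 0]).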
2855 static LogicalResult
2856 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2857  // Helper to extract integer out of ArrayAttr.
2858  auto getElement = [](ArrayAttr array, int idx) {
2859  return array[idx].cast<IntegerAttr>().getInt();
2860  };
2861  ArrayAttr extractOffsets = op.getOffsets();
2862  ArrayAttr extractStrides = op.getStrides();
2863  ArrayAttr extractSizes = op.getSizes();
2864  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
2865  while (insertOp) {
2866  if (op.getVectorType().getRank() !=
2867  insertOp.getSourceVectorType().getRank())
2868  return failure();
2869  ArrayAttr insertOffsets = insertOp.getOffsets();
2870  ArrayAttr insertStrides = insertOp.getStrides();
2871  // If the rank of extract is greater than the rank of insert, we are likely
2872  // extracting a partial chunk of the vector inserted.
2873  if (extractOffsets.size() > insertOffsets.size())
2874  return failure();
 2875  bool partialOverlap = false;
2876  bool disjoint = false;
2877  SmallVector<int64_t, 4> offsetDiffs;
2878  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2879  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2880  return failure();
2881  int64_t start = getElement(insertOffsets, dim);
2882  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2883  int64_t offset = getElement(extractOffsets, dim);
2884  int64_t size = getElement(extractSizes, dim);
2885  // Check if the start of the extract offset is in the interval inserted.
2886  if (start <= offset && offset < end) {
2887  // If the extract interval overlaps but is not fully included we may
2888  // have a partial overlap that will prevent any folding.
2889  if (offset + size > end)
 2890  partialOverlap = true;
2891  offsetDiffs.push_back(offset - start);
2892  continue;
2893  }
2894  disjoint = true;
2895  break;
2896  }
 2897  // The extracted chunk is a subset of the inserted chunk.
 2898  if (!disjoint && !partialOverlap) {
2899  op.setOperand(insertOp.getSource());
2900  // OpBuilder is only used as a helper to build an I64ArrayAttr.
2901  OpBuilder b(op.getContext());
2902  op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
2903  b.getI64ArrayAttr(offsetDiffs));
2904  return success();
2905  }
2906  // If the chunk extracted is disjoint from the chunk inserted, keep looking
2907  // in the insert chain.
2908  if (disjoint)
2909  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2910  else {
 2911  // The extracted vector partially overlaps the inserted vector; we cannot
 2912  // fold.
2913  return failure();
2914  }
2915  }
2916  return failure();
2917 }
2918 
2919 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2920  if (getVectorType() == getResult().getType())
2921  return getVector();
 2922  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
 2923  return getResult();
2924  return {};
2925 }
2926 
2927 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2928  populateFromInt64AttrArray(getOffsets(), results);
2929 }
2930 
2931 namespace {
2932 
2933 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2934 // ConstantMaskOp.
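// Editorial example (not from the upstream source; shapes are illustrative):
//
//   %m = vector.constant_mask [2, 3] : vector<4x4xi1>
//   %s = vector.extract_strided_slice %m
//       {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
//       : vector<4x4xi1> to vector<2x2xi1>
//
// would fold to vector.constant_mask [2, 2] : vector<2x2xi1>, i.e. each mask
// dim size is clamped to the sliced region.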
2935 class StridedSliceConstantMaskFolder final
2936  : public OpRewritePattern<ExtractStridedSliceOp> {
2937 public:
 2938  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
 2939 
2940  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2941  PatternRewriter &rewriter) const override {
2942  // Return if 'extractStridedSliceOp' operand is not defined by a
2943  // ConstantMaskOp.
2944  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
2945  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2946  if (!constantMaskOp)
2947  return failure();
2948  // Return if 'extractStridedSliceOp' has non-unit strides.
2949  if (extractStridedSliceOp.hasNonUnitStrides())
2950  return failure();
2951  // Gather constant mask dimension sizes.
2952  SmallVector<int64_t, 4> maskDimSizes;
2953  populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
2954  // Gather strided slice offsets and sizes.
2955  SmallVector<int64_t, 4> sliceOffsets;
2956  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
2957  sliceOffsets);
2958  SmallVector<int64_t, 4> sliceSizes;
2959  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
2960 
2961  // Compute slice of vector mask region.
2962  SmallVector<int64_t, 4> sliceMaskDimSizes;
2963  sliceMaskDimSizes.reserve(maskDimSizes.size());
2964  for (auto [maskDimSize, sliceOffset, sliceSize] :
2965  llvm::zip_equal(maskDimSizes, sliceOffsets, sliceSizes)) {
2966  int64_t sliceMaskDimSize = std::max(
2967  static_cast<int64_t>(0),
2968  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2969  sliceMaskDimSizes.push_back(sliceMaskDimSize);
2970  }
2971  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2972  // region is a conjunction of mask dim intervals).
2973  if (llvm::is_contained(sliceMaskDimSizes, 0))
2974  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2975 
2976  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2977  // region.
2978  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2979  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2980  vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2981  return success();
2982  }
2983 };
2984 
 2985 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
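// Editorial example (not from the upstream source): if the source is a splat
// constant such as arith.constant dense<1.0> : vector<4x4xf32>, any strided
// slice of it is itself a splat constant of the result type, e.g.
// arith.constant dense<1.0> : vector<2x2xf32>.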
2986 class StridedSliceSplatConstantFolder final
2987  : public OpRewritePattern<ExtractStridedSliceOp> {
2988 public:
 2989  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
 2990 
2991  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2992  PatternRewriter &rewriter) const override {
2993  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
2994  // ConstantOp.
2995  Value sourceVector = extractStridedSliceOp.getVector();
2996  Attribute vectorCst;
2997  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2998  return failure();
2999 
3000  auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
3001  if (!splat)
3002  return failure();
3003 
3004  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3005  splat.getSplatValue<Attribute>());
3006  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3007  newAttr);
3008  return success();
3009  }
3010 };
3011 
 3012 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
 3013 // ConstantOp.
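// Editorial example (not from the upstream source): for a non-splat constant
// such as arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>, a slice with
// offsets = [1], sizes = [2] and unit strides folds to
// arith.constant dense<[1, 2]> : vector<2xi32>; the pattern below gathers the
// sliced elements by linearizing each slice position into the source constant.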
3014 class StridedSliceNonSplatConstantFolder final
3015  : public OpRewritePattern<ExtractStridedSliceOp> {
3016 public:
 3017  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
 3018 
3019  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3020  PatternRewriter &rewriter) const override {
3021  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3022  // ConstantOp.
3023  Value sourceVector = extractStridedSliceOp.getVector();
3024  Attribute vectorCst;
3025  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3026  return failure();
3027 
3028  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3029  auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
3030  if (!dense || dense.isSplat())
3031  return failure();
3032 
3033  // TODO: Handle non-unit strides when they become available.
3034  if (extractStridedSliceOp.hasNonUnitStrides())
3035  return failure();
3036 
3037  auto sourceVecTy = sourceVector.getType().cast<VectorType>();
3038  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3039  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3040 
3041  VectorType sliceVecTy = extractStridedSliceOp.getType();
3042  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3043  int64_t sliceRank = sliceVecTy.getRank();
3044 
3045  // Expand offsets and sizes to match the vector rank.
3046  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3047  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3048 
3049  SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
3050  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3051 
3052  // Calculate the slice elements by enumerating all slice positions and
3053  // linearizing them. The enumeration order is lexicographic which yields a
3054  // sequence of monotonically increasing linearized position indices.
3055  auto denseValuesBegin = dense.value_begin<Attribute>();
3056  SmallVector<Attribute> sliceValues;
3057  sliceValues.reserve(sliceVecTy.getNumElements());
3058  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3059  do {
3060  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3061  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3062  "Invalid index");
3063  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3064  } while (
3065  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3066 
3067  assert(static_cast<int64_t>(sliceValues.size()) ==
3068  sliceVecTy.getNumElements() &&
3069  "Invalid number of slice elements");
3070  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3071  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3072  newAttr);
3073  return success();
3074  }
3075 };
3076 
 3077 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 3078 // BroadcastOp(ExtractStridedSliceOp).
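// Editorial example (not from the upstream source; shapes are illustrative):
// when the slice leaves the broadcast source's inner dims untouched, only the
// broadcast is rebuilt, e.g.
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<8x4xf32>
//   %e = vector.extract_strided_slice %b
//       {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]}
//       : vector<8x4xf32> to vector<2x4xf32>
//
// becomes vector.broadcast %src : vector<4xf32> to vector<2x4xf32>; otherwise
// the slice is first applied to %src and the result is then broadcast.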
3079 class StridedSliceBroadcast final
3080  : public OpRewritePattern<ExtractStridedSliceOp> {
3081 public:
 3082  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
 3083 
3084  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3085  PatternRewriter &rewriter) const override {
3086  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3087  if (!broadcast)
3088  return failure();
3089  auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
3090  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3091  auto dstVecType = op.getType().cast<VectorType>();
3092  unsigned dstRank = dstVecType.getRank();
3093  unsigned rankDiff = dstRank - srcRank;
 3094  // Check if the innermost dimensions of the source of the broadcast are the
 3095  // same as the innermost dimensions of the extract destination. If so, we can
 3096  // simply rebuild the broadcast as the original dimensions are untouched.
3097  bool lowerDimMatch = true;
3098  for (unsigned i = 0; i < srcRank; i++) {
3099  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3100  lowerDimMatch = false;
3101  break;
3102  }
3103  }
3104  Value source = broadcast.getSource();
 3105  // If the inner dimensions don't match, we need to extract from the source of
 3106  // the original broadcast and then broadcast the extracted value. We also need
 3107  // to handle degenerate cases where the source is effectively just a single
 3108  // scalar.
3109  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3110  if (!lowerDimMatch && !isScalarSrc) {
3111  source = rewriter.create<ExtractStridedSliceOp>(
3112  op->getLoc(), source,
3113  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3114  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3115  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3116  }
3117  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3118  return success();
3119  }
3120 };
3121 
3122 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
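// Editorial example (not from the upstream source): a strided slice of
// vector.splat %x : vector<8xf32> yielding vector<4xf32> simply becomes
// vector.splat %x : vector<4xf32>.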
3123 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3124 public:
 3125  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
 3126 
3127  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3128  PatternRewriter &rewriter) const override {
3129  auto splat = op.getVector().getDefiningOp<SplatOp>();
3130  if (!splat)
3131  return failure();
3132  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3133  return success();
3134  }
3135 };
3136 
3137 } // namespace
3138 
3139 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3140  RewritePatternSet &results, MLIRContext *context) {
 3141  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
 3142  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3143  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3144  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3145  StridedSliceSplat>(context);
3146 }
3147 
3148 //===----------------------------------------------------------------------===//
3149 // TransferReadOp
3150 //===----------------------------------------------------------------------===//
3151 
3152 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3153 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3154  VectorType vectorType, Value source,
3155  ValueRange indices, AffineMapAttr permutationMapAttr,
3156  /*optional*/ ArrayAttr inBoundsAttr) {
3157  Type elemType = source.getType().cast<ShapedType>().getElementType();
3158  Value padding = builder.create<arith::ConstantOp>(
3159  result.location, elemType, builder.getZeroAttr(elemType));
3160  build(builder, result, vectorType, source, indices, permutationMapAttr,
3161  padding, /*mask=*/Value(), inBoundsAttr);
3162 }
3163 
 3164 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3165 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3166  VectorType vectorType, Value source,
3167  ValueRange indices, AffineMap permutationMap,
3168  Optional<ArrayRef<bool>> inBounds) {
3169  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3170  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3171  ? builder.getBoolArrayAttr(inBounds.value())
3172  : ArrayAttr();
3173  build(builder, result, vectorType, source, indices, permutationMapAttr,
3174  inBoundsAttr);
3175 }
3176 
3177 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3178 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3179  VectorType vectorType, Value source,
3180  ValueRange indices, Value padding,
3181  Optional<ArrayRef<bool>> inBounds) {
3182  AffineMap permutationMap = getTransferMinorIdentityMap(
3183  source.getType().cast<ShapedType>(), vectorType);
3184  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3185  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3186  ? builder.getBoolArrayAttr(inBounds.value())
3187  : ArrayAttr();
3188  build(builder, result, vectorType, source, indices, permutationMapAttr,
3189  padding,
3190  /*mask=*/Value(), inBoundsAttr);
3191 }
3192 
3193 /// 4. Builder that sets padding to zero and permutation map to
3194 /// 'getMinorIdentityMap'.
3195 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3196  VectorType vectorType, Value source,
3197  ValueRange indices,
3198  Optional<ArrayRef<bool>> inBounds) {
3199  Type elemType = source.getType().cast<ShapedType>().getElementType();
3200  Value padding = builder.create<arith::ConstantOp>(
3201  result.location, elemType, builder.getZeroAttr(elemType));
3202  build(builder, result, vectorType, source, indices, padding, inBounds);
3203 }
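// Editorial usage sketch (not from the upstream source; identifiers such as
// `memref` and `c0` are illustrative): builder #4 above lets callers omit both
// the padding value and the permutation map, e.g. something like
//
//   Value read = builder.create<vector::TransferReadOp>(
//       loc, VectorType::get({4, 8}, builder.getF32Type()), memref,
//       ValueRange{c0, c0});
//
// which fills in a zero padding constant and a minor identity permutation map.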
3204 
3205 template <typename EmitFun>
 3206 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
 3207  EmitFun emitOpError) {
3208  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3209  for (auto expr : permutationMap.getResults()) {
3210  auto dim = expr.dyn_cast<AffineDimExpr>();
3211  auto zero = expr.dyn_cast<AffineConstantExpr>();
3212  if (zero) {
3213  if (zero.getValue() != 0) {
3214  return emitOpError(
3215  "requires a projected permutation_map (at most one dim or the zero "
3216  "constant can appear in each result)");
3217  }
3218  continue;
3219  }
3220  if (!dim) {
3221  return emitOpError("requires a projected permutation_map (at most one "
3222  "dim or the zero constant can appear in each result)");
3223  }
3224  if (seen[dim.getPosition()]) {
3225  return emitOpError(
3226  "requires a permutation_map that is a permutation (found one dim "
3227  "used more than once)");
3228  }
3229  seen[dim.getPosition()] = true;
3230  }
3231  return success();
3232 }
3233 
3234 static LogicalResult
3235 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3236  VectorType vectorType, VectorType maskType,
3237  VectorType inferredMaskType, AffineMap permutationMap,
3238  ArrayAttr inBounds) {
3239  if (op->hasAttr("masked")) {
3240  return op->emitOpError("masked attribute has been removed. "
3241  "Use in_bounds instead.");
3242  }
3243 
3244  if (!shapedType.isa<MemRefType, RankedTensorType>())
3245  return op->emitOpError(
3246  "requires source to be a memref or ranked tensor type");
3247 
3248  auto elementType = shapedType.getElementType();
3249  DataLayout dataLayout = DataLayout::closest(op);
3250  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
3251  // Memref or tensor has vector element type.
3252  unsigned sourceVecSize =
3253  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3254  vectorElementType.getShape().back();
3255  unsigned resultVecSize =
3256  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3257  vectorType.getShape().back();
3258  if (resultVecSize % sourceVecSize != 0)
3259  return op->emitOpError(
3260  "requires the bitwidth of the minor 1-D vector to be an integral "
3261  "multiple of the bitwidth of the minor 1-D vector of the source");
3262 
3263  unsigned sourceVecEltRank = vectorElementType.getRank();
3264  unsigned resultVecRank = vectorType.getRank();
3265  if (sourceVecEltRank > resultVecRank)
3266  return op->emitOpError(
3267  "requires source vector element and vector result ranks to match.");
3268  unsigned rankOffset = resultVecRank - sourceVecEltRank;
3269  // Check that permutation map results match 'rankOffset' of vector type.
3270  if (permutationMap.getNumResults() != rankOffset)
3271  return op->emitOpError("requires a permutation_map with result dims of "
3272  "the same rank as the vector type");
3273 
3274  if (maskType)
3275  return op->emitOpError("does not support masks with vector element type");
3276  } else {
3277  // Memref or tensor has scalar element type.
3278  unsigned minorSize =
3279  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3280  unsigned resultVecSize =
3281  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3282  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3283  return op->emitOpError(
3284  "requires the bitwidth of the minor 1-D vector to be an integral "
3285  "multiple of the bitwidth of the source element type");
3286 
3287  // Check that permutation map results match rank of vector type.
3288  if (permutationMap.getNumResults() != vectorType.getRank())
3289  return op->emitOpError("requires a permutation_map with result dims of "
3290  "the same rank as the vector type");
3291  }
3292 
3293  if (permutationMap.getNumSymbols() != 0)
3294  return op->emitOpError("requires permutation_map without symbols");
3295 
3296  if (permutationMap.getNumInputs() != shapedType.getRank())
3297  return op->emitOpError("requires a permutation_map with input dims of the "
3298  "same rank as the source type");
3299 
3300  if (maskType && maskType != inferredMaskType)
3301  return op->emitOpError("inferred mask type (")
3302  << inferredMaskType << ") and mask operand type (" << maskType
3303  << ") don't match";
3304 
3305  if (inBounds) {
3306  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3307  return op->emitOpError("expects the optional in_bounds attr of same rank "
3308  "as permutation_map results: ")
3309  << AffineMapAttr::get(permutationMap)
3310  << " vs inBounds of size: " << inBounds.size();
3311  for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
3312  if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
3313  !inBounds.getValue()[i].cast<BoolAttr>().getValue())
3314  return op->emitOpError("requires broadcast dimensions to be in-bounds");
3315  }
3316 
3317  return success();
3318 }
3319 
3320 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3321  SmallVector<StringRef, 3> elidedAttrs;
3322  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3323  if (op.permutation_map().isMinorIdentity())
3324  elidedAttrs.push_back(op.getPermutationMapAttrStrName());
3325  bool elideInBounds = true;
3326  if (auto inBounds = op.in_bounds()) {
3327  for (auto attr : *inBounds) {
3328  if (attr.template cast<BoolAttr>().getValue()) {
3329  elideInBounds = false;
3330  break;
3331  }
3332  }
3333  }
3334  if (elideInBounds)
3335  elidedAttrs.push_back(op.getInBoundsAttrStrName());
3336  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3337 }
3338 
3340  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3341  if (getMask())
3342  p << ", " << getMask();
3343  printTransferAttrs(p, *this);
3344  p << " : " << getShapedType() << ", " << getVectorType();
3345 }
3346 
3347 /// Infers the mask type for a transfer read given its vector type and
3348 /// permutation map. The mask in a transfer read operation applies to the
3349 /// tensor/buffer reading part of it and its type should match the shape read
3350 /// *before* any permutation or broadcasting.
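// Editorial example (not from the upstream source): for a read producing
// vector<2x4xf32> with permutation_map = affine_map<(d0, d1) -> (d1, d0)>, the
// inverse permutation maps the vector shape back onto the source dimensions, so
// the inferred mask type would be vector<4x2xi1>.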
3351 static VectorType inferTransferReadMaskType(VectorType vecType,
3352  AffineMap permMap) {
3353  auto i1Type = IntegerType::get(permMap.getContext(), 1);
3354  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3355  assert(invPermMap && "Inversed permutation map couldn't be computed");
3356  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3357  return VectorType::get(maskShape, i1Type);
3358 }
3359 
3360 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3361  auto &builder = parser.getBuilder();
3362  SMLoc typesLoc;
3363  OpAsmParser::UnresolvedOperand sourceInfo;
 3364  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
 3365  OpAsmParser::UnresolvedOperand paddingInfo;
3366  SmallVector<Type, 2> types;
 3367  OpAsmParser::UnresolvedOperand maskInfo;
 3368  // Parsing with support for paddingValue.
3369  if (parser.parseOperand(sourceInfo) ||
3370  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
3371  parser.parseComma() || parser.parseOperand(paddingInfo))
3372  return failure();
3373  ParseResult hasMask = parser.parseOptionalComma();
3374  if (hasMask.succeeded()) {
3375  if (parser.parseOperand(maskInfo))
3376  return failure();
3377  }
3378  if (parser.parseOptionalAttrDict(result.attributes) ||
3379  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3380  return failure();
3381  if (types.size() != 2)
3382  return parser.emitError(typesLoc, "requires two types");
3383  auto indexType = builder.getIndexType();
3384  auto shapedType = types[0].dyn_cast<ShapedType>();
3385  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3386  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3387  VectorType vectorType = types[1].dyn_cast<VectorType>();
3388  if (!vectorType)
3389  return parser.emitError(typesLoc, "requires vector type");
3390  auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
3391  Attribute permMapAttr = result.attributes.get(permMapAttrName);
3392  AffineMap permMap;
3393  if (!permMapAttr) {
3394  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3395  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3396  } else {
3397  permMap = permMapAttr.cast<AffineMapAttr>().getValue();
3398  }
3399  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3400  parser.resolveOperands(indexInfo, indexType, result.operands) ||
3401  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
3402  result.operands))
3403  return failure();
3404  if (hasMask.succeeded()) {
3405  if (shapedType.getElementType().dyn_cast<VectorType>())
3406  return parser.emitError(
3407  maskInfo.location, "does not support masks with vector element type");
3408  // Instead of adding the mask type as an op type, compute it based on the
3409  // vector type and the permutation map (to keep the type signature small).
3410  auto maskType = inferTransferReadMaskType(vectorType, permMap);
3411  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3412  return failure();
3413  }
3414  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3415  builder.getDenseI32ArrayAttr(
3416  {1, static_cast<int32_t>(indexInfo.size()), 1,
3417  static_cast<int32_t>(hasMask.succeeded())}));
3418  return parser.addTypeToList(vectorType, result.types);
3419 }
3420 
 3421 LogicalResult TransferReadOp::verify() {
 3422  // Consistency of elemental types in source and vector.
3423  ShapedType shapedType = getShapedType();
3424  VectorType vectorType = getVectorType();
3425  VectorType maskType = getMaskType();
3426  auto paddingType = getPadding().getType();
3427  auto permutationMap = getPermutationMap();
3428  VectorType inferredMaskType =
3429  maskType ? inferTransferReadMaskType(vectorType, permutationMap)
3430  : VectorType();
3431  auto sourceElementType = shapedType.getElementType();
3432 
3433  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3434  return emitOpError("requires ") << shapedType.getRank() << " indices";
3435 
3436  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3437  shapedType, vectorType, maskType,
3438  inferredMaskType, permutationMap,
3439  getInBounds() ? *getInBounds() : ArrayAttr())))
3440  return failure();
3441 
3442  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
3443  // Source has vector element type.
3444  // Check that 'sourceVectorElementType' and 'paddingType' types match.
3445  if (sourceVectorElementType != paddingType)
3446  return emitOpError(
3447  "requires source element type and padding type to match.");
3448 
3449  } else {
3450  // Check that 'paddingType' is valid to store in a vector type.
3451  if (!VectorType::isValidElementType(paddingType))
3452  return emitOpError("requires valid padding vector elemental type");
3453 
3454  // Check that padding type and vector element types match.
3455  if (paddingType != sourceElementType)
3456  return emitOpError(
3457  "requires formal padding and source of the same elemental type");
3458  }
3459 
3460  return verifyPermutationMap(permutationMap,
3461  [&](Twine t) { return emitOpError(t); });
3462 }
3463 
3464 template <typename TransferOp>
3465 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3466  // TODO: support more aggressive createOrFold on:
3467  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
3468  if (op.getShapedType().isDynamicDim(indicesIdx))
3469  return false;
3470  Value index = op.getIndices()[indicesIdx];
3471  auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
3472  if (!cstOp)
3473  return false;
3474 
3475  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3476  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3477 
3478  return cstOp.value() + vectorSize <= sourceSize;
3479 }
3480 
3481 template <typename TransferOp>
 3482 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
 3483  // TODO: support 0-d corner case.
3484  // TODO: Be less conservative.
3485  if (op.getTransferRank() == 0)
3486  return failure();
3487  AffineMap permutationMap = op.getPermutationMap();
3488  bool changed = false;
3489  SmallVector<bool, 4> newInBounds;
3490  newInBounds.reserve(op.getTransferRank());
3491  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
3492  // Already marked as in-bounds, nothing to see here.
3493  if (op.isDimInBounds(i)) {
3494  newInBounds.push_back(true);
3495  continue;
3496  }
3497  // Currently out-of-bounds, check whether we can statically determine it is
3498  // inBounds.
3499  auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
3500  assert(dimExpr && "Broadcast dims must be in-bounds");
3501  auto inBounds =
3502  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
3503  newInBounds.push_back(inBounds);
3504  // We commit the pattern if it is "more inbounds".
3505  changed |= inBounds;
3506  }
3507  if (!changed)
3508  return failure();
3509  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3510  OpBuilder b(op.getContext());
3511  op->setAttr(TransferOp::getInBoundsAttrStrName(),
3512  b.getBoolArrayAttr(newInBounds));
3513  return success();
3514 }
3515 
3516 /// ```
3517 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3518 /// : vector<1x4xf32>, tensor<4x4xf32>
3519 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
3520 /// : tensor<4x4xf32>, vector<1x4xf32>
3521 /// ```
3522 /// -> Folds into
3523 /// ```
3524 /// %v0
3525 /// ```
3526 static Value foldRAW(TransferReadOp readOp) {
3527  if (!readOp.getShapedType().isa<RankedTensorType>())
3528  return {};
3529  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3530  while (defWrite) {
3531  if (checkSameValueRAW(defWrite, readOp))
3532  return defWrite.getVector();
 3533  if (!isDisjointTransferIndices(
 3534  cast<VectorTransferOpInterface>(defWrite.getOperation()),
3535  cast<VectorTransferOpInterface>(readOp.getOperation())))
3536  break;
3537  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3538  }
3539  return {};
3540 }
3541 
3542 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
3543  if (Value vec = foldRAW(*this))
3544  return vec;
3545  /// transfer_read(memrefcast) -> transfer_read
 3546  if (succeeded(foldTransferInBoundsAttribute(*this)))
 3547  return getResult();
3548  if (succeeded(memref::foldMemRefCast(*this)))
3549  return getResult();
3550  if (succeeded(tensor::foldTensorCast(*this)))
3551  return getResult();
3552  return OpFoldResult();
3553 }
3554 
3555 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
3556  return llvm::to_vector<4>(getVectorType().getShape());
3557 }
3558 
3559 void TransferReadOp::getEffects(
 3560  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 3561  &effects) {
3562  if (getShapedType().isa<MemRefType>())
3563  effects.emplace_back(MemoryEffects::Read::get(), getSource(),
3564  SideEffects::DefaultResource::get());
3565 }
3566 
 3567 /// Returns true if all rank-reduced dimensions in the given `extractOp` occur
 3568 /// in leading dimensions, i.e. before the last `trailingRank` dimensions.
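// Editorial example (not from the upstream source): with trailingRank = 2, a
// rank-reducing slice tensor<1x6x8xf32> -> tensor<6x8xf32> keeps its trailing
// [6, 8] dims and passes this check, whereas tensor<6x1x8xf32> -> tensor<6x8xf32>
// drops a middle dim and does not.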
3569 static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
3570  unsigned trailingRank) {
 3571  // If no ranks are reduced at all, it's a degenerate case; always true.
3572  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
3573  return true;
3574 
3575  RankedTensorType inferredType = extractOp.inferResultType(
3576  extractOp.getSourceType(), extractOp.getMixedOffsets(),
3577  extractOp.getMixedSizes(), extractOp.getMixedStrides());
3578  return extractOp.getType().getShape().take_back(trailingRank) ==
3579  inferredType.getShape().take_back(trailingRank);
3580 }
3581 
3582 namespace {
3583 /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
3584 ///
3585 /// ```
3586 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
3587 /// : tensor<?x?xf32> to tensor<?x?xf32>
3588 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
3589 /// : tensor<?x?xf32>, vector<4x5xf32>
3590 /// ```
3591 /// is rewritten to:
3592 /// ```
3593 /// %p0 = arith.addi %a, %e : index
3594 /// %p1 = arith.addi %b, %f : index
3595 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
3596 /// : tensor<?x?xf32>, vector<4x5xf32>
3597 /// ```
3598 struct FoldExtractSliceIntoTransferRead
3599  : public OpRewritePattern<TransferReadOp> {
3600 public:
 3601  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
 3602 
3603  LogicalResult matchAndRewrite(TransferReadOp xferOp,
3604  PatternRewriter &rewriter) const override {
3605  // TODO: support 0-d corner case.
3606  if (xferOp.getTransferRank() == 0)
3607  return failure();
3608  if (xferOp.hasOutOfBoundsDim())
3609  return failure();
3610  if (!xferOp.getPermutationMap().isMinorIdentity())
3611  return failure();
3612  if (xferOp.getMask())
3613  return failure();
3614  auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
3615  if (!extractOp)
3616  return failure();
3617  if (!extractOp.hasUnitStride())
3618  return failure();
3619 
3620  // Bail on illegal rank-reduction: we need to check that the rank-reduced
3621  // dims are exactly the leading dims. I.e. the following is illegal:
3622  // ```
3623  // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
3624  // tensor<2x1x4xf32> to tensor<2x4xf32>
3625  // %1 = vector.transfer_read %0[0,0], %cst :
3626  // tensor<2x4xf32>, vector<2x4xf32>
3627  // ```
3628  //
3629  // Cannot fold into:
3630  // ```
3631  // %0 = vector.transfer_read %t[0,0,0], %cst :
3632  // tensor<2x1x4xf32>, vector<2x4xf32>
3633  // ```
3634  // For this, check the trailing `vectorRank` dims of the extract_slice
3635  // result tensor match the trailing dims of the inferred result tensor.
3636  if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
3637  return failure();
3638 
3639  int64_t rankReduced =
3640  extractOp.getSourceType().getRank() - extractOp.getType().getRank();
3641 
3642  SmallVector<Value> newIndices;
3643  // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
3644  // indices first.
3645  for (int64_t i = 0; i < rankReduced; ++i) {
3646  OpFoldResult offset = extractOp.getMixedOffsets()[i];
3647  newIndices.push_back(getValueOrCreateConstantIndexOp(
3648  rewriter, extractOp.getLoc(), offset));
3649  }
3650  for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
3651  OpFoldResult offset =
3652  extractOp.getMixedOffsets()[it.index() + rankReduced];
3653  newIndices.push_back(rewriter.create<arith::AddIOp>(
3654  xferOp->getLoc(), it.value(),
3655  getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
3656  offset)));
3657  }
3658  SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3659  rewriter.replaceOpWithNewOp<TransferReadOp>(
3660  xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
3661  xferOp.getPadding(), ArrayRef<bool>{inBounds});
3662 
3663  return success();
3664  }
3665 };
3666 
 3667 /// Store to load forwarding for transfer operations with permutation maps.
3668 /// Even if the permutation maps are different we can still propagate the store
3669 /// into the load if the size of the dimensions read and written match. Then we
3670 /// can replace the transfer_read + transfer_write by vector.broadcast and
3671 /// vector.transpose.
3672 /// Example:
3673 /// ```
3674 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
3675 /// {in_bounds = [true, true],
3676 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
3677 /// vector<4x1xf32>, tensor<4x4x4xf32>
3678 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
3679 /// {in_bounds = [true, true, true, true],
3680 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
3681 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
3682 /// ```
3683 /// To:
3684 /// ```
3685 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
3686 /// %r = vector.transpose %0, [3, 0, 2, 1] :
3687 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
3688 /// ```
3689 struct TransferReadAfterWriteToBroadcast
3690  : public OpRewritePattern<TransferReadOp> {
 3691  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
 3692 
3693  LogicalResult matchAndRewrite(TransferReadOp readOp,
3694  PatternRewriter &rewriter) const override {
3695  if (readOp.hasOutOfBoundsDim() ||
3696  !readOp.getShapedType().isa<RankedTensorType>())
3697  return failure();
3698  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3699  if (!defWrite)
3700  return failure();
3701 
3702  SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
3703  Value vec;
3704  if (readOp.getIndices() == defWrite.getIndices() &&
3705  readOp.getMask() == defWrite.getMask()) {
3706  SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
3707  // TODO: If the writeDim is a superset of the read dims we could do an
3708  // extract_strided_slice.
3709  if (writeDims == readDims)
3710  vec = defWrite.getVector();
3711  }
3712  // TODO: loop through the chain of transfer_write if we can prove that they
3713  // don't overlap with the transfer_read. This requires improving
3714  // `isDisjointTransferIndices` helper.
3715  if (!vec)
3716  return failure();
3717  SmallVector<unsigned> permutation;
3718  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
3719  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
3720  AffineMap map = readMap.compose(writeMap);
3721  if (map.getNumResults() == 0)
3722  return failure();
 3723  // Calculate the permutation to apply to go from the vector stored to the
 3724  // vector read.
3725  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
3726  return failure();
3727 
3728  Location loc = readOp.getLoc();
 3729  // Calculate the broadcast shape by applying the reverse permutation to the
 3730  // final shape we want.
3731  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
3732  SmallVector<int64_t> broadcastShape(destShape.size());
3733  for (const auto &pos : llvm::enumerate(permutation))
3734  broadcastShape[pos.value()] = destShape[pos.index()];
3735  VectorType broadcastedType = VectorType::get(
3736  broadcastShape, defWrite.getVectorType().getElementType());
3737  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
3738  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
3739  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
3740  transposePerm);
3741  return success();
3742  }
3743 };
3744 } // namespace
3745 
3746 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3747  MLIRContext *context) {
3748  results
3749  .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
3750  context);
3751 }
3752 
3753 //===----------------------------------------------------------------------===//
3754 // TransferWriteOp
3755 //===----------------------------------------------------------------------===//
3756 
3757 /// 1. Builder with type inference.
3758 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3759  Value vector, Value dest, ValueRange indices,
3760  AffineMapAttr permutationMapAttr,
3761  /*optional*/ Value mask,
3762  /*optional*/ ArrayAttr inBoundsAttr) {
3763  Type resultType = dest.getType().dyn_cast<RankedTensorType>();
3764  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3765  mask, inBoundsAttr);
3766 }
3767 
3768 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
3769 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3770  Value vector, Value dest, ValueRange indices,
3771  AffineMapAttr permutationMapAttr,
3772  /*optional*/ ArrayAttr inBoundsAttr) {
3773  build(builder, result, vector, dest, indices, permutationMapAttr,
3774  /*mask=*/Value(), inBoundsAttr);
3775 }
3776 
3777 /// 3. Builder with type inference that sets an empty mask (variant without
3778 /// attrs)
3779 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3780  Value vector, Value dest, ValueRange indices,
3781  AffineMap permutationMap,
3782  Optional<ArrayRef<bool>> inBounds) {
3783  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3784  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3785  ? builder.getBoolArrayAttr(inBounds.value())
3786  : ArrayAttr();
3787  build(builder, result, vector, dest, indices, permutationMapAttr,
3788  /*mask=*/Value(), inBoundsAttr);
3789 }
3790 
3791 /// 4. Builder with type inference that sets an empty mask and sets permutation
3792 /// map to 'getMinorIdentityMap'.
3793 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3794  Value vector, Value dest, ValueRange indices,
3795  Optional<ArrayRef<bool>> inBounds) {
3796  auto vectorType = vector.getType().cast<VectorType>();
3797  AffineMap permutationMap = getTransferMinorIdentityMap(
3798  dest.getType().cast<ShapedType>(), vectorType);
3799  build(builder, result, vector, dest, indices, permutationMap, inBounds);
3800 }
3801 
 3802 /// Infers the mask type for a transfer write given its vector type and
 3803 /// permutation map. The mask in a transfer write operation applies to the
 3804 /// tensor/buffer writing part of it and its type should match the shape written
 3805 /// *after* any permutation.
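// Editorial example (not from the upstream source): for a write of
// vector<2x4xf32> with permutation_map = affine_map<(d0, d1) -> (d1, d0)>, the
// mask covers the shape as it lands in memory (the permuted shape), so the
// inferred mask type would be vector<4x2xi1>.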
3806 static VectorType inferTransferWriteMaskType(VectorType vecType,
3807  AffineMap permMap) {
3808  auto i1Type = IntegerType::get(permMap.getContext(), 1);
3809  SmallVector<int64_t, 8> maskShape =
3810  compressUnusedDims(permMap).compose(vecType.getShape());
3811  return VectorType::get(maskShape, i1Type);
3812 }
3813 
3814 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3815  OperationState &result) {
3816  auto &builder = parser.getBuilder();
3817  SMLoc typesLoc;
3818  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
 3819  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
 3820  SmallVector<Type, 2> types;
 3821  OpAsmParser::UnresolvedOperand maskInfo;
3822  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
3823  parser.parseOperand(sourceInfo) ||
3824  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
3825  return failure();
3826  ParseResult hasMask = parser.parseOptionalComma();
3827  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
3828  return failure();
3829  if (parser.parseOptionalAttrDict(result.attributes) ||
3830  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3831  return failure();
3832  if (types.size() != 2)
3833  return parser.emitError(typesLoc, "requires two types");
3834  auto indexType = builder.getIndexType();
3835  VectorType vectorType = types[0].dyn_cast<VectorType>();
3836  if (!vectorType)
3837  return parser.emitError(typesLoc, "requires vector type");
3838  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
3839  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3840  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3841  auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3842  auto permMapAttr = result.attributes.get(permMapAttrName);
3843  AffineMap permMap;
3844  if (!permMapAttr) {
3845  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3846  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3847  } else {
3848  permMap = permMapAttr.cast<AffineMapAttr>().getValue();
3849  }
3850  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
3851  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3852  parser.resolveOperands(indexInfo, indexType, result.operands))
3853  return failure();
3854  if (hasMask.succeeded()) {
3855  if (shapedType.getElementType().dyn_cast<VectorType>())
3856  return parser.emitError(
3857  maskInfo.location, "does not support masks with vector element type");
3858  auto maskType = inferTransferWriteMaskType(vectorType, permMap);
3859  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3860  return failure();
3861  }
3862  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
3863  builder.getDenseI32ArrayAttr(
3864  {1, 1, static_cast<int32_t>(indexInfo.size()),
3865  static_cast<int32_t>(hasMask.succeeded())}));
3866  return failure(shapedType.isa<RankedTensorType>() &&
3867  parser.addTypeToList(shapedType, result.types));
3868 }
3869 
3871  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
3872  if (getMask())
3873  p << ", " << getMask();
3874  printTransferAttrs(p, *this);
3875  p << " : " << getVectorType() << ", " << getShapedType();
3876 }
3877 
 3878 LogicalResult TransferWriteOp::verify() {
 3879  // Consistency of elemental types in shape and vector.
3880  ShapedType shapedType = getShapedType();
3881  VectorType vectorType = getVectorType();
3882  VectorType maskType = getMaskType();
3883  auto permutationMap = getPermutationMap();
3884  VectorType inferredMaskType =
3885  maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
3886  : VectorType();
3887 
3888  if (llvm::size(getIndices()) != shapedType.getRank())
3889  return emitOpError("requires ") << shapedType.getRank() << " indices";
3890 
3891  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
3892  // as the semantics is unclear. This can be revisited later if necessary.
3893  if (hasBroadcastDim())
3894  return emitOpError("should not have broadcast dimensions");
3895 
3896  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3897  shapedType, vectorType, maskType,
3898  inferredMaskType, permutationMap,
3899  getInBounds() ? *getInBounds() : ArrayAttr())))
3900  return failure();
3901 
3902  return verifyPermutationMap(permutationMap,
3903  [&](Twine t) { return emitOpError(t); });
3904 }
3905 
3906 /// Fold:
3907 /// ```
3908 /// %t1 = ...
3909 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
3910 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3911 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
3912 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3913 /// ```
3914 ///
3915 /// into:
3916 ///
3917 /// ```
3918 /// %t0
3919 /// ```
3920 ///
3921 /// The producer of t1 may or may not be DCE'd depending on whether it is a
3922 /// block argument or has side effects.
3923 static LogicalResult foldReadInitWrite(TransferWriteOp write,
 3924  ArrayRef<Attribute>,
 3925  SmallVectorImpl<OpFoldResult> &results) {
3926  // TODO: support 0-d corner case.
3927  if (write.getTransferRank() == 0)
3928  return failure();
3929  auto rankedTensorType =
3930  write.getSource().getType().dyn_cast<RankedTensorType>();
3931  // If not operating on tensors, bail.
3932  if (!rankedTensorType)
3933  return failure();
3934  // If no read, bail.
3935  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
3936  if (!read)
3937  return failure();
3938  // TODO: support 0-d corner case.
3939  if (read.getTransferRank() == 0)
3940  return failure();
 3941  // For now, only accept minor identity maps. Future: allow any composition that is a minor identity.
3942  if (!read.getPermutationMap().isMinorIdentity() ||
3943  !write.getPermutationMap().isMinorIdentity())
3944  return failure();
3945  // Bail on mismatching ranks.
3946  if (read.getTransferRank() != write.getTransferRank())
3947  return failure();
3948  // Bail on potential out-of-bounds accesses.
3949  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
3950  return failure();
3951  // Tensor types must be the same.
3952  if (read.getSource().getType() != rankedTensorType)
3953  return failure();
3954  // Vector types must be the same.
3955  if (read.getVectorType() != write.getVectorType())
3956  return failure();
3957  // Vector and Tensor shapes must match.
3958  if (read.getVectorType().getShape() != rankedTensorType.getShape())
3959  return failure();
3960  // If any index is nonzero.
3961  auto isNotConstantZero = [](Value v) {
3962  auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3963  return !cstOp || cstOp.value() != 0;
3964  };
3965  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
3966  llvm::any_of(write.getIndices(), isNotConstantZero))
3967  return failure();
3968  // Success.
3969  results.push_back(read.getSource());
3970  return success();
3971 }
3972 
3973 static bool checkSameValueWAR(vector::TransferReadOp read,
3974  vector::TransferWriteOp write) {
3975  return read.getSource() == write.getSource() &&
3976  read.getIndices() == write.getIndices() &&
3977  read.getPermutationMap() == write.getPermutationMap() &&
3978  read.getVectorType() == write.getVectorType() && !read.getMask() &&
3979  !write.getMask();
3980 }
3981 /// Fold transfer_write write after read:
3982 /// ```
3983 /// %t0 = ...
3984 /// %v = vector.transfer_read %t0[%c0...] :
3985 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3986 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3987 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3988 /// ```
3989 ///
3990 /// into:
3991 ///
3992 /// ```
3993 /// %t0
3994 /// ```
3995 static LogicalResult foldWAR(TransferWriteOp write,
3996  SmallVectorImpl<OpFoldResult> &results) {
3997  if (!write.getSource().getType().isa<RankedTensorType>())
3998  return failure();
3999  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4000  if (!read)
4001  return failure();
4002 
4003  if (!checkSameValueWAR(read, write))
4004  return failure();
4005  results.push_back(read.getSource());
4006  return success();
4007 }
4008 
4009 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
4010  SmallVectorImpl<OpFoldResult> &results) {
4011  if (succeeded(foldReadInitWrite(*this, operands, results)))
4012  return success();
4013  if (succeeded(foldWAR(*this, results)))
4014  return success();
 4015  if (succeeded(foldTransferInBoundsAttribute(*this)))
 4016  return success();
4017  return memref::foldMemRefCast(*this);
4018 }
4019 
4020 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4021  return llvm::to_vector<4>(getVectorType().getShape());
4022 }
4023 
4024 void TransferWriteOp::getEffects(
 4025  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 4026  &effects) {
4027  if (getShapedType().isa<MemRefType>())
4028  effects.emplace_back(MemoryEffects::Write::get(), getSource(),
4029  SideEffects::DefaultResource::get());
4030 }
4031 
4032 namespace {
 4033 /// Remove a dead transfer write from the SSA chain so that it can be eliminated
 4034 /// by DCE
4035 /// ```
4036 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4037 /// : vector<1x4xf32>, tensor<4x4xf32>
4038 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4039 /// : vector<1x4xf32>, tensor<4x4xf32>
4040 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4041 /// : vector<1x4xf32>, tensor<4x4xf32>
4042 /// ```
4043 ///
4044 /// into:
4045 ///
4046 /// ```
4047 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4048 /// : vector<1x4xf32>, tensor<4x4xf32>
4049 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4050 /// : vector<1x4xf32>, tensor<4x4xf32>
4051 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4052 /// : vector<1x4xf32>, tensor<4x4xf32>
4053 /// ```
4054 ///
4055 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4056 /// any other uses.
4057 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4058 public:
 4059  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
 4060  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4061  PatternRewriter &rewriter) const override {
4062  if (!writeOp.getShapedType().isa<RankedTensorType>())
4063  return failure();
4064  vector::TransferWriteOp writeToModify = writeOp;
4065 
4066  auto defWrite =
4067  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4068  while (defWrite) {
4069  if (checkSameValueWAW(writeOp, defWrite)) {
4070  writeToModify.getSourceMutable().assign(defWrite.getSource());
4071  return success();
4072  }
 4073  if (!isDisjointTransferIndices(
 4074  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4075  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4076  break;
 4077  // If the previous write op doesn't have any other use we can safely look
 4078  // at the previous store to see if it can be removed.
4079  if (!defWrite->hasOneUse())
4080  break;
4081  writeToModify = defWrite;
4082  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4083  }
4084  return failure();
4085  }
4086 };
4087 
4088 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
4089 /// could directly write to the insert_slice's destination. E.g.:
4090 ///
4091 /// ```
4092 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
4093 /// : vector<4x5xf32>, tensor<4x5xf32>
4094 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
4095 /// : tensor<4x5xf32> into tensor<?x?xf32>
4096 /// ```
4097 /// is rewritten to:
4098 /// ```
4099 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
4100 /// : vector<4x5xf32>, tensor<?x?xf32>
4101 /// ```
4102 struct FoldInsertSliceIntoTransferWrite
4103  : public OpRewritePattern<tensor::InsertSliceOp> {
4104 public:
 4105  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
 4106 
4107  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4108  PatternRewriter &rewriter) const override {
4109  if (!insertOp.hasUnitStride())
4110  return failure();
4111 
4112  auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
4113  if (!xferOp)
4114  return failure();
4115  // TODO: support 0-d corner case.
4116  if (xferOp.getTransferRank() == 0)
4117  return failure();
4118 
4119  if (xferOp.hasOutOfBoundsDim())
4120  return failure();
4121  if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
4122  return failure();
4123  if (xferOp.getMask())
4124  return failure();
4125  // Fold only if the TransferWriteOp completely overwrites the `source` with
4126  // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
4127  // content is the data of the vector.
4128  if (!llvm::equal(xferOp.getVectorType().getShape(),
4129  xferOp.getShapedType().getShape()))
4130  return failure();
4131  if (!xferOp.getPermutationMap().isIdentity())
4132  return failure();
4133 
4134  // Bail on illegal rank-reduction: we need to check that the rank-reduced
4135  // dims are exactly the leading dims. I.e. the following is illegal:
4136  // ```
4137  // %0 = vector.transfer_write %v, %t[0,0], %cst :
4138  // vector<2x4xf32>, tensor<2x4xf32>
4139  // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
4140  // tensor<2x4xf32> into tensor<2x1x4xf32>
4141  // ```
4142  //
4143  // Cannot fold into:
4144  // ```
4145  // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
4146  // vector<2x4xf32>, tensor<2x1x4xf32>
4147  // ```
4148  // For this, check the trailing `vectorRank` dims of the insert_slice result
4149  // tensor match the trailing dims of the inferred result tensor.
4150  int64_t rankReduced =
4151  insertOp.getType().getRank() - insertOp.getSourceType().getRank();
4152  int64_t vectorRank = xferOp.getVectorType().getRank();
4153  RankedTensorType inferredSourceTensorType =
4154  tensor::ExtractSliceOp::inferResultType(
4155  insertOp.getType(), insertOp.getMixedOffsets(),
4156  insertOp.getMixedSizes(), insertOp.getMixedStrides());
4157  auto actualSourceTensorShape = insertOp.getSourceType().getShape();
4158  if (rankReduced > 0 &&
4159  actualSourceTensorShape.take_back(vectorRank) !=
4160  inferredSourceTensorType.getShape().take_back(vectorRank))
4161  return failure();
4162 
 4163  SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
 4164  rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
4165  SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
4166  rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
4167  insertOp.getDest(), indices,
4168  ArrayRef<bool>{inBounds});
4169  return success();
4170  }
4171 };
4172 
4173 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4174 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4175 /// overwritten and inserted into another tensor. After this rewrite, the
4176 /// operations bufferize in-place since all of them work on the same slice.
4177 ///
4178 /// For example:
4179 /// ```mlir
4180 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4181 /// : vector<8x16xf32>, tensor<8x16xf32>
4182 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4183 /// : tensor<8x16xf32> to tensor<?x?xf32>
4184 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4185 /// : tensor<?x?xf32> into tensor<27x37xf32>
4186 /// ```
4187 /// folds to
4188 /// ```mlir
4189 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4190 /// : tensor<27x37xf32> to tensor<?x?xf32>
4191 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4192 /// : vector<8x16xf32>, tensor<?x?xf32>
4193 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4194 /// : tensor<?x?xf32> into tensor<27x37xf32>
4195 /// ```
4196 struct SwapExtractSliceOfTransferWrite
4197  : public OpRewritePattern<tensor::InsertSliceOp> {
4198 public:
 4199  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
 4200 
4201  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4202  PatternRewriter &rewriter) const override {
4203  if (!insertOp.hasUnitStride())
4204  return failure();
4205  auto extractOp =
4206  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4207  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4208  return failure();
4209  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4210  if (!transferOp || !transferOp->hasOneUse())
4211  return failure();
4212 
4213  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4214  // rank-reducing.
4215  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4216  return rewriter.notifyMatchFailure(insertOp,
4217  "use-def chain is rank-reducing");
4218  }
4219 
4220  // Fail if tensor::ExtractSliceOp has non-zero offset.
4221  if (!extractOp.hasZeroOffset()) {
4222  return rewriter.notifyMatchFailure(insertOp,
4223  "ExtractSliceOp has non-zero offset");
4224  }
4225 
4226  // Fail if vector::TransferWriteOp has non-zero offset.
4227  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4228  return getConstantIntValue(value) == static_cast<int64_t>(0);
4229  })) {
4230  return rewriter.notifyMatchFailure(insertOp,
4231  "TransferWriteOp has non-zero offset");
4232  }
4233 
4234  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4235  for (auto [insertSize, extractSize] :
4236  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4237  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4238  return rewriter.notifyMatchFailure(
4239  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4240  }
4241  }
4242 
4243  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4244  assert(transferOp.getVectorType().hasStaticShape() &&
4245  "expected vector to have a static shape");
4246  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4247  SmallVector<int64_t> resultShape = applyPermutationMap(
4248  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4249  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4250  return rewriter.notifyMatchFailure(
4251  insertOp, "TransferWriteOp may not write the full tensor.");
4252  }
4253 
4254  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4255  SmallVector<int64_t> newResultShape = applyPermutationMap(
4256  transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
4257  SmallVector<bool> newInBounds;
4258  for (const auto &en : enumerate(newResultShape))
4259  newInBounds.push_back(en.value() == vectorShape[en.index()]);
4260  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4261  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4262  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4263  insertOp.getMixedStrides());
4264  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4265  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4266  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4267  rewriter.getBoolArrayAttr(newInBounds));
4268  rewriter.updateRootInPlace(insertOp, [&]() {
4269  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4270  });
4271  return success();
4272  }
4273 };
4274 
4275 } // namespace
4276 
4277 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4278  MLIRContext *context) {
4279  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
4280  SwapExtractSliceOfTransferWrite>(context);
4281 }
4282 
4283 //===----------------------------------------------------------------------===//
4284 // LoadOp
4285 //===----------------------------------------------------------------------===//
4286 
4287 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4288  MemRefType memRefTy) {
4289  if (!isLastMemrefDimUnitStride(memRefTy))
4290  return op->emitOpError("most minor memref dim must have unit stride");
4291  return success();
4292 }
4293 
4294 LogicalResult LoadOp::verify() {
4295  VectorType resVecTy = getVectorType();
4296  MemRefType memRefTy = getMemRefType();
4297 
4298  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4299  return failure();
4300 
4301  // Checks for vector memrefs.
4302  Type memElemTy = memRefTy.getElementType();
4303  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
4304  if (memVecTy != resVecTy)
4305  return emitOpError("base memref and result vector types should match");
4306  memElemTy = memVecTy.getElementType();
4307  }
4308 
4309  if (resVecTy.getElementType() != memElemTy)
4310  return emitOpError("base and result element types should match");
4311  if (llvm::size(getIndices()) != memRefTy.getRank())
4312  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4313  return success();
4314 }
4315 
4316 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
4317  if (succeeded(memref::foldMemRefCast(*this)))
4318  return getResult();
4319  return OpFoldResult();
4320 }
4321 
4322 //===----------------------------------------------------------------------===//
4323 // StoreOp
4324 //===----------------------------------------------------------------------===//
4325 
4326 LogicalResult StoreOp::verify() {
4327  VectorType valueVecTy = getVectorType();
4328  MemRefType memRefTy = getMemRefType();
4329 
4330  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4331  return failure();
4332 
4333  // Checks for vector memrefs.
4334  Type memElemTy = memRefTy.getElementType();
4335  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
4336  if (memVecTy != valueVecTy)
4337  return emitOpError(
4338  "base memref and valueToStore vector types should match");
4339  memElemTy = memVecTy.getElementType();
4340  }
4341 
4342  if (valueVecTy.getElementType() != memElemTy)
4343  return emitOpError("base and valueToStore element type should match");
4344  if (llvm::size(getIndices()) != memRefTy.getRank())
4345  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4346  return success();
4347 }
4348 
4349 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
4350  SmallVectorImpl<OpFoldResult> &results) {
4351  return memref::foldMemRefCast(*this);
4352 }
4353 
4354 //===----------------------------------------------------------------------===//
4355 // MaskedLoadOp
4356 //===----------------------------------------------------------------------===//
4357 
4358 LogicalResult MaskedLoadOp::verify() {
4359  VectorType maskVType = getMaskVectorType();
4360  VectorType passVType = getPassThruVectorType();
4361  VectorType resVType = getVectorType();
4362  MemRefType memType = getMemRefType();
4363 
4364  if (resVType.getElementType() != memType.getElementType())
4365  return emitOpError("base and result element type should match");
4366  if (llvm::size(getIndices()) != memType.getRank())
4367  return emitOpError("requires ") << memType.getRank() << " indices";
4368  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4369  return emitOpError("expected result dim to match mask dim");
4370  if (resVType != passVType)
4371  return emitOpError("expected pass_thru of same type as result type");
4372  return success();
4373 }
4374 
4375 namespace {
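/// Fold a masked load with an all-true mask into an unmasked vector.load, and
/// a masked load with an all-false mask into its pass-through value.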
4376 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4377 public:
4379  LogicalResult matchAndRewrite(MaskedLoadOp load,
4380  PatternRewriter &rewriter) const override {
4381  switch (getMaskFormat(load.getMask())) {
4382  case MaskFormat::AllTrue:
4383  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4384  load, load.getType(), load.getBase(), load.getIndices());
4385  return success();
4386  case MaskFormat::AllFalse:
4387  rewriter.replaceOp(load, load.getPassThru());
4388  return success();
4389  case MaskFormat::Unknown:
4390  return failure();
4391  }
4392  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4393  }
4394 };
4395 } // namespace
4396 
4397 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4398  MLIRContext *context) {
4399  results.add<MaskedLoadFolder>(context);
4400 }
4401 
4402 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
4403  if (succeeded(memref::foldMemRefCast(*this)))
4404  return getResult();
4405  return OpFoldResult();
4406 }
4407 
4408 //===----------------------------------------------------------------------===//
4409 // MaskedStoreOp
4410 //===----------------------------------------------------------------------===//
4411 
4412 LogicalResult MaskedStoreOp::verify() {
4413  VectorType maskVType = getMaskVectorType();
4414  VectorType valueVType = getVectorType();
4415  MemRefType memType = getMemRefType();
4416 
4417  if (valueVType.getElementType() != memType.getElementType())
4418  return emitOpError("base and valueToStore element type should match");
4419  if (llvm::size(getIndices()) != memType.getRank())
4420  return emitOpError("requires ") << memType.getRank() << " indices";
4421  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4422  return emitOpError("expected valueToStore dim to match mask dim");
4423  return success();
4424 }
4425 
4426 namespace {
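/// Fold a masked store with an all-true mask into an unmasked vector.store; a
/// masked store with an all-false mask writes nothing and is erased.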
4427 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4428 public:
4429  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
4430  LogicalResult matchAndRewrite(MaskedStoreOp store,
4431  PatternRewriter &rewriter) const override {
4432  switch (getMaskFormat(store.getMask())) {
4433  case MaskFormat::AllTrue:
4434  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4435  store, store.getValueToStore(), store.getBase(), store.getIndices());
4436  return success();
4437  case MaskFormat::AllFalse:
4438  rewriter.eraseOp(store);
4439  return success();
4440  case MaskFormat::Unknown:
4441  return failure();
4442  }
4443  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4444  }
4445 };
4446 } // namespace
4447 
4448 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4449  MLIRContext *context) {
4450  results.add<MaskedStoreFolder>(context);
4451 }
4452 
4453 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
4454  SmallVectorImpl<OpFoldResult> &results) {
4455  return memref::foldMemRefCast(*this);
4456 }
4457 
4458 //===----------------------------------------------------------------------===//
4459 // GatherOp
4460 //===----------------------------------------------------------------------===//
4461 
4462 LogicalResult GatherOp::verify() {
4463  VectorType indVType = getIndexVectorType();
4464  VectorType maskVType = getMaskVectorType();
4465  VectorType resVType = getVectorType();
4466  ShapedType baseType = getBaseType();
4467 
4468  if (!baseType.isa<MemRefType, RankedTensorType>())
4469  return emitOpError("requires base to be a memref or ranked tensor type");
4470 
4471  if (resVType.getElementType() != baseType.getElementType())
4472  return emitOpError("base and result element type should match");
4473  if (llvm::size(getIndices()) != baseType.getRank())
4474  return emitOpError("requires ") << baseType.getRank() << " indices";
4475  if (resVType.getShape() != indVType.getShape())
4476  return emitOpError("expected result dim to match indices dim");
4477  if (resVType.getShape() != maskVType.getShape())
4478  return emitOpError("expected result dim to match mask dim");
4479  if (resVType != getPassThruVectorType())
4480  return emitOpError("expected pass_thru of same type as result type");
4481  return success();
4482 }
4483 
4484 namespace {
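/// Fold a gather with an all-false mask into its pass-through value. An
/// all-true gather has no unmasked equivalent and is left untouched.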
4485 class GatherFolder final : public OpRewritePattern<GatherOp> {
4486 public:
4487  using OpRewritePattern<GatherOp>::OpRewritePattern;
4488  LogicalResult matchAndRewrite(GatherOp gather,
4489  PatternRewriter &rewriter) const override {
4490  switch (getMaskFormat(gather.getMask())) {
4491  case MaskFormat::AllTrue:
4492  return failure(); // no unmasked equivalent
4493  case MaskFormat::AllFalse:
4494  rewriter.replaceOp(gather, gather.getPassThru());
4495  return success();
4496  case MaskFormat::Unknown:
4497  return failure();
4498  }
4499  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4500  }
4501 };
4502 } // namespace
4503 
4504 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4505  MLIRContext *context) {
4506  results.add<GatherFolder>(context);
4507 }
4508 
4509 //===----------------------------------------------------------------------===//
4510 // ScatterOp
4511 //===----------------------------------------------------------------------===//
4512 
4513 LogicalResult ScatterOp::verify() {
4514  VectorType indVType = getIndexVectorType();
4515  VectorType maskVType = getMaskVectorType();
4516  VectorType valueVType = getVectorType();
4517  MemRefType memType = getMemRefType();
4518 
4519  if (valueVType.getElementType() != memType.getElementType())
4520  return emitOpError("base and valueToStore element type should match");
4521  if (llvm::size(getIndices()) != memType.getRank())
4522  return emitOpError("requires ") << memType.getRank() << " indices";
4523  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4524  return emitOpError("expected valueToStore dim to match indices dim");
4525  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4526  return emitOpError("expected valueToStore dim to match mask dim");
4527  return success();
4528 }
4529 
4530 namespace {
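/// Erase a scatter whose mask is all-false; like the gather case, an all-true
/// scatter has no unmasked equivalent.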
4531 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4532 public:
4533  using OpRewritePattern<ScatterOp>::OpRewritePattern;
4534  LogicalResult matchAndRewrite(ScatterOp scatter,
4535  PatternRewriter &rewriter) const override {
4536  switch (getMaskFormat(scatter.getMask())) {
4537  case MaskFormat::AllTrue:
4538  return failure(); // no unmasked equivalent
4539  case MaskFormat::AllFalse:
4540  rewriter.eraseOp(scatter);
4541  return success();
4542  case MaskFormat::Unknown:
4543  return failure();
4544  }
4545  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4546  }
4547 };
4548 } // namespace
4549 
4550 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4551  MLIRContext *context) {
4552  results.add<ScatterFolder>(context);
4553 }
4554 
4555 //===----------------------------------------------------------------------===//
4556 // ExpandLoadOp
4557 //===----------------------------------------------------------------------===//
4558 
4559 LogicalResult ExpandLoadOp::verify() {
4560  VectorType maskVType = getMaskVectorType();
4561  VectorType passVType = getPassThruVectorType();
4562  VectorType resVType = getVectorType();
4563  MemRefType memType = getMemRefType();
4564 
4565  if (resVType.getElementType() != memType.getElementType())
4566  return emitOpError("base and result element type should match");
4567  if (llvm::size(getIndices()) != memType.getRank())
4568  return emitOpError("requires ") << memType.getRank() << " indices";
4569  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4570  return emitOpError("expected result dim to match mask dim");
4571  if (resVType != passVType)
4572  return emitOpError("expected pass_thru of same type as result type");
4573  return success();
4574 }
4575 
4576 namespace {
4577 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4578 public:
4579  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
4580  LogicalResult matchAndRewrite(ExpandLoadOp expand,
4581  PatternRewriter &rewriter) const override {
4582  switch (getMaskFormat(expand.getMask())) {
4583  case MaskFormat::AllTrue:
4584  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4585  expand, expand.getType(), expand.getBase(), expand.getIndices());
4586  return success();
4587  case MaskFormat::AllFalse:
4588  rewriter.replaceOp(expand, expand.getPassThru());
4589  return success();
4590  case MaskFormat::Unknown:
4591  return failure();
4592  }
4593  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4594  }
4595 };
4596 } // namespace
4597 
4598 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4599  MLIRContext *context) {
4600  results.add<ExpandLoadFolder>(context);
4601 }
4602 
4603 //===----------------------------------------------------------------------===//
4604 // CompressStoreOp
4605 //===----------------------------------------------------------------------===//
4606 
4607 LogicalResult CompressStoreOp::verify() {
4608  VectorType maskVType = getMaskVectorType();
4609  VectorType valueVType = getVectorType();
4610  MemRefType memType = getMemRefType();
4611 
4612  if (valueVType.getElementType() != memType.getElementType())
4613  return emitOpError("base and valueToStore element type should match");
4614  if (llvm::size(getIndices()) != memType.getRank())
4615  return emitOpError("requires ") << memType.getRank() << " indices";
4616  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4617  return emitOpError("expected valueToStore dim to match mask dim");
4618  return success();
4619 }
4620 
4621 namespace {
4622 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
4623 public:
4624  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
4625  LogicalResult matchAndRewrite(CompressStoreOp compress,
4626  PatternRewriter &rewriter) const override {
4627  switch (getMaskFormat(compress.getMask())) {
4628  case MaskFormat::AllTrue:
4629  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4630  compress, compress.getValueToStore(), compress.getBase(),
4631  compress.getIndices());
4632  return success();
4633  case MaskFormat::AllFalse:
4634  rewriter.eraseOp(compress);
4635  return success();
4636  case MaskFormat::Unknown:
4637  return failure();
4638  }
4639  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
4640  }
4641 };
4642 } // namespace
4643 
4644 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4645  MLIRContext *context) {
4646  results.add<CompressStoreFolder>(context);
4647 }
4648 
4649 //===----------------------------------------------------------------------===//
4650 // ShapeCastOp
4651 //===----------------------------------------------------------------------===//
4652 
4653 /// Returns true if each element of 'a' is equal to the product of a contiguous
4654 /// sequence of the elements of 'b'. Returns false otherwise.
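/// For example, isValidShapeCast({4, 6}, {2, 2, 6}) returns true since
/// 4 == 2 * 2 and 6 == 6, whereas isValidShapeCast({4, 6}, {2, 3, 4}) returns
/// false.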
4655 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
4656  unsigned rankA = a.size();
4657  unsigned rankB = b.size();
4658  assert(rankA < rankB);
4659 
4660  unsigned i = 0;
4661  unsigned j = 0;
4662  while (i < rankA && j < rankB) {
4663  int64_t dimA = a[i];
4664  int64_t dimB = 1;
4665  while (dimB < dimA && j < rankB)
4666  dimB *= b[j++];
4667  if (dimA != dimB)
4668  break;
4669  ++i;
4670 
4671  // Handle the case when trailing dimensions are of size 1.
4672  // Include them into the contiguous sequence.
4673  auto isOne = [](int64_t v) { return v == 1; };
4674  if (i < rankA && llvm::all_of(a.slice(i), isOne))
4675  i = rankA;
4676  if (j < rankB && llvm::all_of(b.slice(j), isOne))
4677  j = rankB;
4678  }
4679 
4680  return i == rankA && j == rankB;
4681 }
4682 
4683 static LogicalResult verifyVectorShapeCast(Operation *op,
4684  VectorType sourceVectorType,
4685  VectorType resultVectorType) {
4686  // Check that element type is the same.
4687  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
4688  return op->emitOpError("source/result vectors must have same element type");
4689  auto sourceShape = sourceVectorType.getShape();
4690  auto resultShape = resultVectorType.getShape();
4691 
4692  // Check that product of source dim sizes matches product of result dim sizes.
4693  int64_t sourceDimProduct = std::accumulate(
4694  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
4695  int64_t resultDimProduct = std::accumulate(
4696  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
4697  if (sourceDimProduct != resultDimProduct)
4698  return op->emitOpError("source/result number of elements must match");
4699 
4700  // Check the expanding/contracting rank cases.
4701  unsigned sourceRank = sourceVectorType.getRank();
4702  unsigned resultRank = resultVectorType.getRank();
4703  if (sourceRank < resultRank) {
4704  if (!isValidShapeCast(sourceShape, resultShape))
4705  return op->emitOpError("invalid shape cast");
4706  } else if (sourceRank > resultRank) {
4707  if (!isValidShapeCast(resultShape, sourceShape))
4708  return op->emitOpError("invalid shape cast");
4709  }
4710  return success();
4711 }
4712 
4713 LogicalResult ShapeCastOp::verify() {
4714  auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
4715  auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
4716 
4717  // Check if source/result are of vector type.
4718  if (sourceVectorType && resultVectorType)
4719  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
4720 
4721  return success();
4722 }
4723 
4724 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
4725  // No-op shape cast.
4726  if (getSource().getType() == getResult().getType())
4727  return getSource();
4728 
4729  // Canceling shape casts.
4730  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
4731  if (getResult().getType() == otherOp.getSource().getType())
4732  return otherOp.getSource();
4733 
4734  // Only allow valid transitive folding.
4735  VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
4736  VectorType resultType = getResult().getType().cast<VectorType>();
4737  if (srcType.getRank() < resultType.getRank()) {
4738  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
4739  return {};
4740  } else if (srcType.getRank() > resultType.getRank()) {
4741  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
4742  return {};
4743  } else {
4744  return {};
4745  }
4746 
4747  setOperand(otherOp.getSource());
4748  return getResult();
4749  }
4750 
4751  // Cancelling broadcast and shape cast ops.
4752  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4753  if (bcastOp.getSourceType() == getType())
4754  return bcastOp.getSource();
4755  }
4756 
4757  return {};
4758 }
4759 
4760 namespace {
4761 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
4762 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
4763 public:
4764  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
4765 
4766  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4767  PatternRewriter &rewriter) const override {
4768  auto constantOp =
4769  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
4770  if (!constantOp)
4771  return failure();
4772  // Only handle splat for now.
4773  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
4774  if (!dense)
4775  return failure();
4776  auto newAttr =
4777  DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
4778  dense.getSplatValue<Attribute>());
4779  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
4780  return success();
4781  }
4782 };
4783 
4784 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
4785 /// This only applies when the shape of the broadcast source is a suffix of the
4786 /// shape of the result (i.e. when broadcast without reshape is expressive
4787 /// enough to capture the result in a single op).
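/// For example:
/// ```mlir
/// %b = vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>
/// %c = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
/// ```
/// folds to
/// ```mlir
/// %c = vector.broadcast %v : vector<4xf32> to vector<6x4xf32>
/// ```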
4788 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4789 public:
4790  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
4791 
4792  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4793  PatternRewriter &rewriter) const override {
4794  auto broadcastOp =
4795  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
4796  if (!broadcastOp)
4797  return failure();
4798 
4799  auto broadcastSourceVectorType =
4800  broadcastOp.getSourceType().dyn_cast<VectorType>();
4801  auto broadcastSourceShape = broadcastSourceVectorType
4802  ? broadcastSourceVectorType.getShape()
4803  : ArrayRef<int64_t>{};
4804  auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
4805 
4806  // Bail if `broadcastSourceShape` is not a suffix of the result.
4807  bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
4808  broadcastSourceShape.size()));
4809  if (!isSuffix)
4810  return failure();
4811 
4812  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4813  shapeCastOp, shapeCastOp.getResultVectorType(),
4814  broadcastOp.getSource());
4815  return success();
4816  }
4817 };
4818 
4819 } // namespace
4820 
4821 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
4822  MLIRContext *context) {
4823  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
4824 }
4825 
4826 //===----------------------------------------------------------------------===//
4827 // VectorBitCastOp
4828 //===----------------------------------------------------------------------===//
4829 
4830 LogicalResult BitCastOp::verify() {
4831  auto sourceVectorType = getSourceVectorType();
4832  auto resultVectorType = getResultVectorType();
4833 
4834  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
4835  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4836  return emitOpError("dimension size mismatch at: ") << i;
4837  }
4838 
4839  DataLayout dataLayout = DataLayout::closest(*this);
4840  auto sourceElementBits =
4841  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
4842  auto resultElementBits =
4843  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
4844 
4845  if (sourceVectorType.getRank() == 0) {
4846  if (sourceElementBits != resultElementBits)
4847  return emitOpError("source/result bitwidth of the 0-D vector element "
4848  "types must be equal");
4849  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
4850  resultElementBits * resultVectorType.getShape().back()) {
4851  return emitOpError(
4852  "source/result bitwidth of the minor 1-D vectors must be equal");
4853  }
4854 
4855  return success();
4856 }
4857 
4858 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
4859  // Nop cast.
4860  if (getSource().getType() == getResult().getType())
4861  return getSource();
4862 
4863  // Canceling bitcasts.
4864  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
4865  if (getResult().getType() == otherOp.getSource().getType())
4866  return otherOp.getSource();
4867 
4868  setOperand(otherOp.getSource());
4869  return getResult();
4870  }
4871 
4872  Attribute sourceConstant = operands.front();
4873  if (!sourceConstant)
4874  return {};
4875 
4876  Type srcElemType = getSourceVectorType().getElementType();
4877  Type dstElemType = getResultVectorType().getElementType();
4878 
4879  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
4880  if (floatPack.isSplat()) {
4881  auto splat = floatPack.getSplatValue<FloatAttr>();
4882 
4883  // Bitcast a splat of f16 into f32: each f32 lane packs two f16 lanes.
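  // For example, a splat of 1.0 : f16 (bit pattern 0x3C00) becomes the f32
  // value with bit pattern 0x3C003C00.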
4884  if (srcElemType.isF16() && dstElemType.isF32()) {
4885  uint32_t bits = static_cast<uint32_t>(
4886  splat.getValue().bitcastToAPInt().getZExtValue());
4887  // Duplicate the 16-bit pattern.
4888  bits = (bits << 16) | (bits & 0xffff);
4889  APInt intBits(32, bits);
4890  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
4891  return DenseElementsAttr::get(getResultVectorType(), floatBits);
4892  }
4893  }
4894  }
4895 
4896  return {};
4897 }
4898 
4899 //===----------------------------------------------------------------------===//
4900 // TypeCastOp
4901 //===----------------------------------------------------------------------===//
4902 
4903 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
4904  auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
4905  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
4906  memRefType.getShape().end());
4907  if (vectorType)
4908  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4909  return res;
4910 }
4911 
4912 /// Build the canonical memRefType with a single vector.
4913 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
4914 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
4915  Value source) {
4916  result.addOperands(source);
4917  MemRefType memRefType = source.getType().cast<MemRefType>();
4918  VectorType vectorType =
4919  VectorType::get(extractShape(memRefType),
4920  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
4921  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
4922  memRefType.getMemorySpace()));
4923 }
4924 
4925 LogicalResult TypeCastOp::verify() {
4926  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
4927  if (!canonicalType.getLayout().isIdentity())
4928  return emitOpError("expects operand to be a memref with identity layout");
4929  if (!getResultMemRefType().getLayout().isIdentity())
4930  return emitOpError("expects result to be a memref with identity layout");
4931  if (getResultMemRefType().getMemorySpace() !=
4932  getMemRefType().getMemorySpace())
4933  return emitOpError("expects result in same memory space");
4934 
4935  auto sourceType = getMemRefType();
4936  auto resultType = getResultMemRefType();
4937  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
4938  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4939  return emitOpError(
4940  "expects result and operand with same underlying scalar type: ")
4941  << resultType;
4942  if (extractShape(sourceType) != extractShape(resultType))
4943  return emitOpError(
4944  "expects concatenated result and operand shapes to be equal: ")
4945  << resultType;
4946  return success();
4947 }
4948 
4949 //===----------------------------------------------------------------------===//
4950 // TransposeOp
4951 //===----------------------------------------------------------------------===//
4952 
4953 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
4954  Value vector, ArrayRef<int64_t> transp) {
4955  VectorType vt = vector.getType().cast<VectorType>();
4956  SmallVector<int64_t, 4> transposedShape(vt.getRank());
4957  for (unsigned i = 0; i < transp.size(); ++i)
4958  transposedShape[i] = vt.getShape()[transp[i]];
4959 
4960  result.addOperands(vector);
4961  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
4962  result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
4963 }
4964 
4965 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4966  // Eliminate splat constant transpose ops.
4967  if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
4968  if (attr.isSplat())
4969  return attr.reshape(getResultType());
4970 
4971  // Eliminate identity transpose ops. This happens when the dimensions of the
4972  // input vector remain in their original order after the transpose operation.
4973  SmallVector<int64_t, 4> transp;
4974  getTransp(transp);
4975 
4976  // Check if the permutation of the dimensions contains sequential values:
4977  // {0, 1, 2, ...}.
4978  for (int64_t i = 0, e = transp.size(); i < e; i++) {
4979  if (transp[i] != i)
4980  return {};
4981  }
4982 
4983  return getVector();
4984 }
4985 
4986 LogicalResult vector::TransposeOp::verify() {
4987  VectorType vectorType = getVectorType();
4988  VectorType resultType = getResultType();
4989  int64_t rank = resultType.getRank();
4990  if (vectorType.getRank() != rank)
4991  return emitOpError("vector result rank mismatch: ") << rank;
4992  // Verify transposition array.
4993  auto transpAttr = getTransp().getValue();
4994  int64_t size = transpAttr.size();
4995  if (rank != size)
4996  return emitOpError("transposition length mismatch: ") << size;
4997  SmallVector<bool, 8> seen(rank, false);
4998  for (const auto &ta : llvm::enumerate(transpAttr)) {
4999  int64_t i = ta.value().cast<IntegerAttr>().getInt();
5000  if (i < 0 || i >= rank)
5001  return emitOpError("transposition index out of range: ") << i;
5002  if (seen[i])
5003  return emitOpError("duplicate position index: ") << i;
5004  seen[i] = true;
5005  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
5006  return emitOpError("dimension size mismatch at: ") << i;
5007  }
5008  return success();
5009 }
5010 
5011 Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5012  return llvm::to_vector<4>(getResultType().getShape());
5013 }
5014 
5015 namespace {
5016 
5017 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
5018 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5019 public:
5020  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
5021 
5022  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5023  PatternRewriter &rewriter) const override {
5024  // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
5025  auto getPermutation = [](vector::TransposeOp transpose) {
5026  SmallVector<int64_t, 4> permutation;
5027  transpose.getTransp(permutation);
5028  return permutation;
5029  };
5030 
5031  // Composes two permutations: result[i] = permutation1[permutation2[i]].
5032  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5033  ArrayRef<int64_t> permutation2) {
5034  SmallVector<int64_t, 4> result;
5035  for (auto index : permutation2)
5036  result.push_back(permutation1[index]);
5037  return result;
5038  };
5039 
5040  // Return if the input of 'transposeOp' is not defined by another transpose.
5041  vector::TransposeOp parentTransposeOp =
5042  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5043  if (!parentTransposeOp)
5044  return failure();
5045 
5046  SmallVector<int64_t, 4> permutation = composePermutations(
5047  getPermutation(parentTransposeOp), getPermutation(transposeOp));
5048  // Replace 'transposeOp' with a new transpose operation.
5049  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5050  transposeOp, transposeOp.getResult().getType(),
5051  parentTransposeOp.getVector(),
5052  vector::getVectorSubscriptAttr(rewriter, permutation));
5053  return success();
5054  }
5055 };
5056 
5057 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
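// For example, transpose(broadcast(%x : f32 to vector<2x4xf32>), [1, 0])
// becomes broadcast(%x : f32 to vector<4x2xf32>).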
5058 struct FoldTransposedScalarBroadcast final
5059  : public OpRewritePattern<vector::TransposeOp> {
5060  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
5061 
5062  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5063  PatternRewriter &rewriter) const override {
5064  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5065  if (!bcastOp)
5066  return failure();
5067 
5068  auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
5069  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5070  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5071  transposeOp, transposeOp.getResultType(), bcastOp.getSource());
5072  return success();
5073  }
5074 
5075  return failure();
5076  }
5077 };
5078 
5079 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5080 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5081 public:
5082  using OpRewritePattern<TransposeOp>::OpRewritePattern;
5083 
5084  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5085  PatternRewriter &rewriter) const override {
5086  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5087  if (!splatOp)
5088  return failure();
5089 
5090  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5091  transposeOp, transposeOp.getResultType(), splatOp.getInput());
5092  return success();
5093  }
5094 };
5095 
5096 } // namespace
5097 
5098 void vector::TransposeOp::getCanonicalizationPatterns(
5099  RewritePatternSet &results, MLIRContext *context) {
5100  results
5101  .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
5102  context);
5103 }
5104 
5105 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
5106  populateFromInt64AttrArray(getTransp(), results);
5107 }
5108 
5109 //===----------------------------------------------------------------------===//
5110 // ConstantMaskOp
5111 //===----------------------------------------------------------------------===//
5112 
5113 LogicalResult ConstantMaskOp::verify() {
5114  auto resultType = getResult().getType().cast<VectorType>();
5115  // Check the corner case of 0-D vectors first.
5116  if (resultType.getRank() == 0) {
5117  if (getMaskDimSizes().size() != 1)
5118  return emitError("array attr must have length 1 for 0-D vectors");
5119  auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
5120  if (dim != 0 && dim != 1)
5121  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5122  return success();
5123  }
5124 
5125  // Verify that array attr size matches the rank of the vector result.
5126  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5127  return emitOpError(
5128  "must specify array attr of size equal to vector result rank");
5129  // Verify that each array attr element is in bounds of corresponding vector
5130  // result dimension size.
5131  auto resultShape = resultType.getShape();
5132  SmallVector<int64_t, 4> maskDimSizes;
5133  for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
5134  int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
5135  if (attrValue < 0 || attrValue > resultShape[it.index()])
5136  return emitOpError(
5137  "array attr of size out of bounds of vector result dimension size");
5138  maskDimSizes.push_back(attrValue);
5139  }
5140  // Verify that if one mask dim size is zero, they all should be zero (because
5141  // the mask region is a conjunction of each mask dimension interval).
5142  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5143  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5144  if (anyZeros && !allZeros)
5145  return emitOpError("expected all mask dim sizes to be zeros, "
5146  "as a result of conjunction with zero mask dim");
5147  // Verify that, if the mask type is scalable, the mask dim sizes are zero, because
5148  // constant scalable masks can only be defined for the "none set" or "all set"
5149  // cases, and there is no VLA way to define an "all set" case for
5150  // `vector.constant_mask`. In the future, a convention could be established
5151  // to decide if a specific dimension value could be considered as "all set".
5152  if (resultType.isScalable() &&
5153  getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
5154  return emitOpError("expected mask dim sizes for scalable masks to be 0");
5155  return success();
5156 }
5157 
5158 //===----------------------------------------------------------------------===//
5159 // CreateMaskOp
5160 //===----------------------------------------------------------------------===//
5161 
5162 LogicalResult CreateMaskOp::verify() {
5163  auto vectorType = getResult().getType().cast<VectorType>();
5164  // Verify that an operand was specified for each result vector dimension.
5165  if (vectorType.getRank() == 0) {
5166  if (getNumOperands() != 1)
5167  return emitOpError(
5168  "must specify exactly one operand for 0-D create_mask");
5169  } else if (getNumOperands() !=
5170  getResult().getType().cast<VectorType>().getRank()) {
5171  return emitOpError(
5172  "must specify an operand for each result vector dimension");
5173  }
5174  return success();
5175 }
5176 
5177 namespace {
5178 
5179 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
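// For example, vector.create_mask %c3, %c2 : vector<4x3xi1> with constant
// index operands becomes vector.constant_mask [3, 2] : vector<4x3xi1>.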
5180 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5181 public:
5182  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
5183 
5184  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5185  PatternRewriter &rewriter) const override {
5186  // Return if any of 'createMaskOp' operands are not defined by a constant.
5187  auto isNotDefByConstant = [](Value operand) {
5188  return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
5189  };
5190  if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
5191  return failure();
5192 
5193  // CreateMaskOp for scalable vectors can be folded only if all dimensions
5194  // are negative or zero.
5195  if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
5196  if (vType.isScalable())
5197  for (auto opDim : createMaskOp.getOperands()) {
5198  APInt intVal;
5199  if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
5200  intVal.isStrictlyPositive())
5201  return failure();
5202  }
5203  }
5204 
5205  // Gather constant mask dimension sizes.
5206  SmallVector<int64_t, 4> maskDimSizes;
5207  maskDimSizes.reserve(createMaskOp->getNumOperands());
5208  for (auto [operand, maxDimSize] : llvm::zip_equal(
5209  createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5210  Operation *defOp = operand.getDefiningOp();
5211  int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
5212  dimSize = std::min(dimSize, maxDimSize);
5213  // If one of dim sizes is zero, set all dims to zero.
5214  if (dimSize <= 0) {
5215  maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5216  break;
5217  }
5218  maskDimSizes.push_back(dimSize);
5219  }
5220  // Replace 'createMaskOp' with ConstantMaskOp.
5221  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
5222  createMaskOp, createMaskOp.getResult().getType(),
5223  vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
5224  return success();
5225  }
5226 };
5227 
5228 } // namespace
5229 
5230 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
5231  MLIRContext *context) {
5232  results.add<CreateMaskFolder>(context);
5233 }
5234 
5235 //===----------------------------------------------------------------------===//
5236 // MaskOp
5237 //===----------------------------------------------------------------------===//
5238 
5239 void MaskOp::build(
5240  OpBuilder &builder, OperationState &result, Value mask,
5241  function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5242  assert(maskRegionBuilder &&
5243  "builder callback for 'maskRegion' must be present");
5244 
5245  result.addOperands(mask);
5246  OpBuilder::InsertionGuard guard(builder);
5247  Region *maskRegion = result.addRegion();
5248  builder.createBlock(maskRegion);
5249  maskRegionBuilder(builder, result.location);
5250 }
5251 
5252 void MaskOp::build(
5253  OpBuilder &builder, OperationState &result, Type resultType, Value mask,
5254  function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5255  build(builder, result, resultType, mask, /*passthru=*/Value(),
5256  maskRegionBuilder);
5257 }
5258 
5259 void MaskOp::build(
5260  OpBuilder &builder, OperationState &result, Type resultType, Value mask,
5261  Value passthru,
5262  function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5263  build(builder, result, mask, maskRegionBuilder);
5264  if (passthru)
5265  result.addOperands(passthru);
5266  result.addTypes(resultType);
5267 }
5268 
5269 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
5270  // Create the op region.
5271  result.regions.reserve(1);
5272  Region &maskRegion = *result.addRegion();
5273 
5274  auto &builder = parser.getBuilder();
5275 
5276  // Parse all the operands.
5277  OpAsmParser::UnresolvedOperand mask;
5278  if (parser.parseOperand(mask))
5279  return failure();
5280 
5281  // Optional passthru operand.
5282  OpAsmParser::UnresolvedOperand passthru;
5283  ParseResult parsePassthru = parser.parseOptionalComma();
5284  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
5285  return failure();
5286 
5287  // Parse op region.
5288  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
5289  return failure();
5290 
5291  MaskOp::ensureTerminator(maskRegion, builder, result.location);
5292 
5293  // Parse the optional attribute list.
5294  if (parser.parseOptionalAttrDict(result.attributes))
5295  return failure();
5296 
5297  // Parse all the types.
5298  Type maskType;
5299  if (parser.parseColonType(maskType))
5300  return failure();
5301 
5302  SmallVector<Type> resultTypes;
5303  if (parser.parseOptionalArrowTypeList(resultTypes))
5304  return failure();
5305  result.types.append(resultTypes);
5306 
5307  // Resolve operands.
5308  if (parser.resolveOperand(mask, maskType, result.operands))
5309  return failure();
5310 
5311  if (parsePassthru.succeeded())
5312  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
5313  return failure();
5314 
5315  return success();
5316 }
5317 
5318 void MaskOp::print(OpAsmPrinter &p) {
5319  p << " " << getMask();
5320  if (getPassthru())
5321  p << ", " << getPassthru();
5322 
5323  // Print single masked operation and skip terminator.
5324  p << " { ";
5325  Block *singleBlock = &getMaskRegion().getBlocks().front();
5326  if (singleBlock && singleBlock->getOperations().size() > 1)
5327  p.printCustomOrGenericOp(&singleBlock->front());
5328  p << " }";
5329 
5330  p.printOptionalAttrDict(getOperation()->getAttrs());
5331 
5332  p << " : " << getMask().getType();
5333  if (getNumResults() > 0)
5334  p << " -> " << getResultTypes();
5335 }
5336 
5337 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
5338  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
5339  MaskOp>::ensureTerminator(region, builder, loc);
5340  // Keep the default yield terminator if the number of masked operations is
5341  // not the expected one. This case will trigger a verification failure.
5342  if (region.front().getOperations().size() != 2)
5343  return;
5344 
5345  // Replace default yield terminator with a new one that returns the results
5346  // from the masked operation.
5347  OpBuilder opBuilder(builder.getContext());
5348  Operation *maskedOp = &region.front().front();
5349  Operation *oldYieldOp = &region.front().back();
5350  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
5351 
5352  opBuilder.setInsertionPoint(oldYieldOp);
5353  opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
5354  oldYieldOp->dropAllReferences();
5355  oldYieldOp->erase();
5356 }
5357 
5358 LogicalResult MaskOp::verify() {
5359  // Structural checks.
5360  Block &block = getMaskRegion().getBlocks().front();
5361  if (block.getOperations().size() < 2)
5362  return emitOpError("expects an operation to mask");
5363  if (block.getOperations().size() > 2)
5364  return emitOpError("expects only one operation to mask");
5365 
5366  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
5367  if (!maskableOp)
5368  return emitOpError("expects a maskable operation");
5369 
5370  // Result checks.
5371  if (maskableOp->getNumResults() != getNumResults())
5372  return emitOpError("expects number of results to match maskable operation "
5373  "number of results");
5374 
5375  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
5376  return emitOpError(
5377  "expects result type to match maskable operation result type");
5378 
5379  // Mask checks.
5380  if (getMask().getType() != maskableOp.getExpectedMaskType())
5381  return emitOpError("expects a ") << maskableOp.getExpectedMaskType()
5382  << " mask for the maskable operation";
5383 
5384  // Passthru checks.
5385  Value passthru = getPassthru();
5386  if (passthru) {
5387  if (!maskableOp.supportsPassthru())
5388  return emitOpError(
5389  "doesn't expect a passthru argument for this maskable operation");
5390 
5391  if (maskableOp->getNumResults() != 1)
5392  return emitOpError("expects result when passthru argument is provided");
5393 
5394  if (passthru.getType() != maskableOp->getResultTypes()[0])
5395  return emitOpError("expects passthru type to match result type");
5396  }
5397 
5398  return success();
5399 }
5400 
5401 // MaskingOpInterface definitions.
5402 
5403 /// Returns the operation masked by this 'vector.mask'.
5404 Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
5405 
5406 /// Returns true if 'vector.mask' has a passthru value.
5407 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
5408 
5409 //===----------------------------------------------------------------------===//
5410 // ScanOp
5411 //===----------------------------------------------------------------------===//
5412 
5413 LogicalResult ScanOp::verify() {
5414  VectorType srcType = getSourceType();
5415  VectorType initialType = getInitialValueType();
5416  // Check reduction dimension < rank.
5417  int64_t srcRank = srcType.getRank();
5418  int64_t reductionDim = getReductionDim();
5419  if (reductionDim >= srcRank)
5420  return emitOpError("reduction dimension ")
5421  << reductionDim << " has to be less than " << srcRank;
5422 
5423  // Check that rank(initial_value) = rank(src) - 1.
5424  int64_t initialValueRank = initialType.getRank();
5425  if (initialValueRank != srcRank - 1)
5426  return emitOpError("initial value rank ")
5427  << initialValueRank << " has to be equal to " << srcRank - 1;
5428 
5429  // Check shapes of initial value and src.
5430  ArrayRef<int64_t> srcShape = srcType.getShape();
5431  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
5432  SmallVector<int64_t> expectedShape;
5433  for (int i = 0; i < srcRank; i++) {
5434  if (i != reductionDim)
5435  expectedShape.push_back(srcShape[i]);
5436  }
5437  if (!llvm::equal(initialValueShapes, expectedShape)) {
5438  return emitOpError("incompatible input/initial value shapes");
5439  }
5440 
5441  // Verify supported reduction kind.
5442  Type eltType = getDestType().getElementType();
5443  if (!isSupportedCombiningKind(getKind(), eltType))
5444  return emitOpError("unsupported reduction type ")
5445  << eltType << " for kind '" << stringifyCombiningKind(getKind())
5446  << "'";
5447 
5448  return success();
5449 }
5450 
5451 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
5452  RewritePatternSet &patterns, PatternBenefit benefit) {
5453  patterns
5454  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
5455  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
5456  StridedSliceConstantMaskFolder, TransposeFolder>(
5457  patterns.getContext(), benefit);
5458 }
5459 
5460 //===----------------------------------------------------------------------===//
5461 // SplatOp
5462 //===----------------------------------------------------------------------===//
5463 
5464 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
5465  auto constOperand = operands.front();
5466  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
5467  return {};
5468 
5469  // SplatElementsAttr::get treats single value for second arg as being a splat.
5470  return SplatElementsAttr::get(getType(), {constOperand});
5471 }
5472 
5473 //===----------------------------------------------------------------------===//
5474 // WarpExecuteOnLane0Op
5475 //===----------------------------------------------------------------------===//
5476 
5477 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
5478  p << "(" << getLaneid() << ")";
5479 
5480  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
5481  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
5482  p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
5483 
5484  if (!getArgs().empty())
5485  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
5486  if (!getResults().empty())
5487  p << " -> (" << getResults().getTypes() << ')';
5488  p << " ";
5489  p.printRegion(getRegion(),
5490  /*printEntryBlockArgs=*/true,
5491  /*printBlockTerminators=*/!getResults().empty());
5492  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
5493 }
5494 
5495 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
5496  OperationState &result) {
5497  // Create the region.
5498  result.regions.reserve(1);
5499  Region *warpRegion = result.addRegion();
5500 
5501  auto &builder = parser.getBuilder();
5502  OpAsmParser::UnresolvedOperand laneId;
5503 
5504  // Parse predicate operand.
5505  if (parser.parseLParen() ||
5506  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
5507  parser.parseRParen())
5508  return failure();
5509 
5510  int64_t warpSize;
5511  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
5512  parser.parseRSquare())
5513  return failure();
5514  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
5515  builder.getContext())),
5516  builder.getI64IntegerAttr(warpSize));
5517 
5518  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
5519  return failure();
5520 
5521  llvm::SMLoc inputsOperandsLoc;
5522  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
5523  SmallVector<Type> inputTypes;
5524  if (succeeded(parser.parseOptionalKeyword("args"))) {
5525  if (parser.parseLParen())
5526  return failure();
5527 
5528  inputsOperandsLoc = parser.getCurrentLocation();
5529  if (parser.parseOperandList(inputsOperands) ||
5530  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
5531  return failure();
5532  }
5533  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
5534  result.operands))
5535  return failure();
5536 
5537  // Parse optional results type list.
5538  if (parser.parseOptionalArrowTypeList(result.types))
5539  return failure();
5540  // Parse the region.
5541  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
5542  /*argTypes=*/{}))
5543  return failure();
5544  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
5545 
5546  // Parse the optional attribute list.
5547  if (parser.parseOptionalAttrDict(result.attributes))
5548  return failure();
5549  return success();
5550 }
5551 
5552 void WarpExecuteOnLane0Op::getSuccessorRegions(
5553  Optional<unsigned> index, ArrayRef<Attribute> operands,
5554  SmallVectorImpl<RegionSuccessor> &regions) {
5555  if (index) {
5556  regions.push_back(RegionSuccessor(getResults()));
5557  return;
5558  }
5559 
5560  // The warp region is always executed
5561  regions.push_back(RegionSuccessor(&getWarpRegion()));
5562 }
5563 
5564 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
5565  TypeRange resultTypes, Value laneId,
5566  int64_t warpSize) {
5567  build(builder, result, resultTypes, laneId, warpSize,
5568  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
5569 }
5570 
5571 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
5572  TypeRange resultTypes, Value laneId,
5573  int64_t warpSize, ValueRange args,
5574  TypeRange blockArgTypes) {
5575  result.addOperands(laneId);
5576  result.addAttribute(getAttributeNames()[0],
5577  builder.getI64IntegerAttr(warpSize));
5578  result.addTypes(resultTypes);
5579  result.addOperands(args);
5580  assert(args.size() == blockArgTypes.size());
5581  OpBuilder::InsertionGuard guard(builder);
5582  Region *warpRegion = result.addRegion();
5583  Block *block = builder.createBlock(warpRegion);
5584  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
5585  block->addArgument(type, arg.getLoc());
5586 }
5587 
5588 /// Helper check if the distributed vector type is consistent with the expanded
5589 /// type and distributed size.
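/// For example, with warpSize = 32, an expanded type of vector<64x1xf32> is
/// compatible with a distributed type of vector<2x1xf32>: exactly one
/// dimension is divided by the warp size.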
5590 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
5591  int64_t warpSize, Operation *op) {
5592  // If the types match, there is no distribution.
5593  if (expanded == distributed)
5594  return success();
5595  auto expandedVecType = expanded.dyn_cast<VectorType>();
5596  auto distributedVecType = distributed.dyn_cast<VectorType>();
5597  if (!expandedVecType || !distributedVecType)
5598  return op->emitOpError("expected vector type for distributed operands.");
5599  if (expandedVecType.getRank() != distributedVecType.getRank() ||
5600  expandedVecType.getElementType() != distributedVecType.getElementType())
5601  return op->emitOpError(
5602  "expected distributed vectors to have same rank and element type.");
5603  bool foundDistributedDim = false;
5604  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
5605  if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
5606  continue;
5607  if (expandedVecType.getDimSize(i) ==
5608  distributedVecType.getDimSize(i) * warpSize) {
5609  if (foundDistributedDim)
5610  return op->emitOpError()
5611  << "expected only one dimension to be distributed from "
5612  << expandedVecType << " to " << distributedVecType;
5613  foundDistributedDim = true;
5614  continue;
5615  }
5616  return op->emitOpError() << "incompatible distribution dimensions from "
5617  << expandedVecType << " to " << distributedVecType;
5618  }
5619  return success();
5620 }
5621 
5622 LogicalResult WarpExecuteOnLane0Op::verify() {
5623  if (getArgs().size() != getWarpRegion().getNumArguments())
5624  return emitOpError(
5625  "expected same number of op arguments and block arguments.");
5626  auto yield =
5627  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
5628  if (yield.getNumOperands() != getNumResults())
5629  return emitOpError(
5630  "expected same number of yield operands and return values.");
5631  int64_t warpSize = getWarpSize();
5632  for (auto [regionArg, arg] :
5633  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
5634  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
5635  warpSize, getOperation())))
5636  return failure();
5637  }
5638  for (auto [yieldOperand, result] :
5639  llvm::zip_equal(yield.getOperands(), getResults())) {
5640  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
5641  warpSize, getOperation())))
5642  return failure();
5643  }
5644  return success();
5645 }
5646 
5647 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
5648  return succeeded(
5649  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
5650 }
5651 
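/// Creates one step of an elementwise arithmetic reduction for the given
/// combining kind, choosing the integer/index or floating-point arith op that
/// matches the operand types.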
5652 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
5653  CombiningKind kind, Value v1, Value v2) {
5654  Type t1 = getElementTypeOrSelf(v1.getType());
5655  Type t2 = getElementTypeOrSelf(v2.getType());
5656  switch (kind) {
5657  case CombiningKind::ADD:
5658  if (t1.isIntOrIndex() && t2.isIntOrIndex())
5659  return b.createOrFold<arith::AddIOp>(loc, v1, v2);
5660  else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5661  return b.createOrFold<arith::AddFOp>(loc, v1, v2);
5662  llvm_unreachable("invalid value types for ADD reduction");
5663  case CombiningKind::AND:
5664  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5665  return b.createOrFold<arith::AndIOp>(loc, v1, v2);
5666  case CombiningKind::MAXF:
5667  assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5668  "expected float values");
5669  return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
5670  case CombiningKind::MINF:
5671  assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5672  "expected float values");
5673  return b.createOrFold<arith::MinFOp>(loc, v1, v2);
5674  case CombiningKind::MAXSI:
5675  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5676  return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
5677  case CombiningKind::MINSI:
5678  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5679  return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
5680  case CombiningKind::MAXUI:
5681  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5682  return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
5683  case CombiningKind::MINUI:
5684  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5685  return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
5686  case CombiningKind::MUL:
5687  if (t1.isIntOrIndex() && t2.isIntOrIndex())
5688  return b.createOrFold<arith::MulIOp>(loc, v1, v2);
5689  else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5690  return b.createOrFold<arith::MulFOp>(loc, v1, v2);
5691  llvm_unreachable("invalid value types for MUL reduction");
5692  case CombiningKind::OR:
5693  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5694  return b.createOrFold<arith::OrIOp>(loc, v1, v2);
5695  case CombiningKind::XOR:
5696  assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5697  return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
5698  }
5699  llvm_unreachable("unknown CombiningKind");
5700 }
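// [Editorial example, not part of the original source] A hypothetical caller,
// e.g. a lowering pattern with an OpBuilder `b`, a Location `loc`, and two
// values `lhs`/`rhs` of matching type, might use the helper above as:
//
//   Value sum  = vector::makeArithReduction(b, loc, CombiningKind::ADD,
//                                           lhs, rhs);
//   Value peak = vector::makeArithReduction(b, loc, CombiningKind::MAXF,
//                                           lhs, rhs);
//
// For integer or index element types ADD folds to arith.addi; for float
// element types it folds to arith.addf, and MAXF requires float operands.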
5701 
5702 //===----------------------------------------------------------------------===//
5703 // TableGen'd op method definitions
5704 //===----------------------------------------------------------------------===//
5705 
5706 #define GET_ATTRDEF_CLASSES
5707 #include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
5708 
5709 #define GET_OP_CLASSES
5710 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"