1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinAttributes.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/IRMapping.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/TypeUtilities.h"
35 #include "mlir/Support/LLVM.h"
36 #include "mlir/Transforms/InliningUtils.h"
37 #include "llvm/ADT/ArrayRef.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/StringSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/ADT/bit.h"
43 
44 #include <cassert>
45 #include <cstdint>
46 #include <numeric>
47 
48 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
49 // Pull in all enum type and utility function definitions.
50 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
51 
52 using namespace mlir;
53 using namespace mlir::vector;
54 
55 /// Helper enum to classify a mask value.
56 enum class MaskFormat {
57  AllTrue = 0,
58  AllFalse = 1,
59  Unknown = 2,
60 };
61 
62 /// Helper method to classify a mask value. Currently, the method
63 /// looks "under the hood" of a constant value with dense attributes
64 /// and a constant mask operation (since the client may be called at
65 /// various stages during progressive lowering).
66 static MaskFormat getMaskFormat(Value mask) {
67  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
68  // Inspect constant dense values. We count up for bits that
69  // are set, count down for bits that are cleared, and bail
70  // when a mix is detected.
71  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
72  int64_t val = 0;
73  for (bool b : denseElts.getValues<bool>())
74  if (b && val >= 0)
75  val++;
76  else if (!b && val <= 0)
77  val--;
78  else
79  return MaskFormat::Unknown;
80  if (val > 0)
81  return MaskFormat::AllTrue;
82  if (val < 0)
83  return MaskFormat::AllFalse;
84  }
85  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
86  // Inspect constant mask index. If the index exceeds the
87  // dimension size, all bits are set. If the index is zero
88  // or less, no bits are set.
89  ArrayAttr masks = m.getMaskDimSizes();
90  auto shape = m.getType().getShape();
91  bool allTrue = true;
92  bool allFalse = true;
93  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
94  int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
95  if (i < dimSize)
96  allTrue = false;
97  if (i > 0)
98  allFalse = false;
99  }
100  if (allTrue)
101  return MaskFormat::AllTrue;
102  if (allFalse)
103  return MaskFormat::AllFalse;
104  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
105  // Finds all-false create_masks. An all-true create_mask requires all
106  // dims to be constants, so that'll be folded to a constant_mask, then
107  // detected in the constant_mask case.
108  auto maskOperands = m.getOperands();
109  for (Value operand : maskOperands) {
110  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
111  int64_t dimSize =
112  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
113  if (dimSize <= 0)
114  return MaskFormat::AllFalse;
115  }
116  }
117  return MaskFormat::Unknown;
118  }
119  return MaskFormat::Unknown;
120 }
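// For illustration (example added here, not part of the original file): with
// the classification above, a mask such as
//
//   %m = vector.constant_mask [4] : vector<4xi1>
//
// is AllTrue, %m = vector.constant_mask [0] : vector<4xi1> is AllFalse, and a
// mask whose bits cannot be proven constant is Unknown.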
121 
122 /// Default callback to build a region with a 'vector.yield' terminator with no
123 /// arguments.
124 static void buildTerminatedBody(OpBuilder &builder, Location loc) {
125  builder.create<vector::YieldOp>(loc);
126 }
127 
128 // Helper for verifying combining kinds in contractions and reductions.
129 static bool isSupportedCombiningKind(CombiningKind combiningKind,
130  Type elementType) {
131  switch (combiningKind) {
132  case CombiningKind::ADD:
133  case CombiningKind::MUL:
134  return elementType.isIntOrIndexOrFloat();
135  case CombiningKind::MINUI:
136  case CombiningKind::MINSI:
137  case CombiningKind::MAXUI:
138  case CombiningKind::MAXSI:
139  case CombiningKind::AND:
140  case CombiningKind::OR:
141  case CombiningKind::XOR:
142  return elementType.isIntOrIndex();
143  case CombiningKind::MINNUMF:
144  case CombiningKind::MAXNUMF:
145  case CombiningKind::MINIMUMF:
146  case CombiningKind::MAXIMUMF:
147  return llvm::isa<FloatType>(elementType);
148  }
149  return false;
150 }
151 
152 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
153  VectorType vectorType) {
154  int64_t elementVectorRank = 0;
155  VectorType elementVectorType =
156  llvm::dyn_cast<VectorType>(shapedType.getElementType());
157  if (elementVectorType)
158  elementVectorRank += elementVectorType.getRank();
159  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
160  // TODO: replace once we have 0-d vectors.
161  if (shapedType.getRank() == 0 &&
162  vectorType.getShape() == ArrayRef<int64_t>{1})
163  return AffineMap::get(
164  /*numDims=*/0, /*numSymbols=*/0,
165  getAffineConstantExpr(0, shapedType.getContext()));
166  return AffineMap::getMinorIdentityMap(
167  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
168  shapedType.getContext());
169 }
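// Worked example (illustrative, not from the source): for a transfer between
// memref<2x4x8xf32> and vector<8xf32>, the minor identity map returned above is
//
//   (d0, d1, d2) -> (d2)
//
// i.e. the vector dimensions line up with the innermost dimensions of the
// shaped type.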
170 
171 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
172  vector::TransferReadOp read) {
173  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
174  !read.getMask() && defWrite.getIndices() == read.getIndices() &&
175  defWrite.getVectorType() == read.getVectorType() &&
176  defWrite.getPermutationMap() == read.getPermutationMap();
177 }
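// Illustrative RAW pair recognized by the check above (hypothetical values):
//
//   vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
//       : vector<4xf32>, memref<8xf32>
//   %r = vector.transfer_read %mem[%i], %pad
//       : memref<8xf32>, vector<4xf32>
//
// Both transfers are unmasked, share indices, vector type, and permutation
// map, and the write is fully in-bounds, so %r is known to equal %v.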
178 
179 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
180  vector::TransferWriteOp priorWrite) {
181  return priorWrite.getIndices() == write.getIndices() &&
182  priorWrite.getMask() == write.getMask() &&
183  priorWrite.getVectorType() == write.getVectorType() &&
184  priorWrite.getPermutationMap() == write.getPermutationMap();
185 }
186 
187 bool mlir::vector::isDisjointTransferIndices(
188  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
189  bool testDynamicValueUsingBounds) {
190  // For simplicity only look at transfer of same type.
191  if (transferA.getVectorType() != transferB.getVectorType())
192  return false;
193  unsigned rankOffset = transferA.getLeadingShapedRank();
194  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
195  Value indexA = transferA.getIndices()[i];
196  Value indexB = transferB.getIndices()[i];
197  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
198  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
199 
200  if (i < rankOffset) {
201  // For leading dimensions, if we can prove that the indices are different,
202  // we know we are accessing disjoint slices.
203  if (cstIndexA.has_value() && cstIndexB.has_value()) {
204  if (*cstIndexA != *cstIndexB)
205  return true;
206  continue;
207  }
208  if (testDynamicValueUsingBounds) {
209  // First try to see if we can fully compose and simplify the affine
210  // expression as a fast track.
211  FailureOr<uint64_t> delta =
212  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
213  if (succeeded(delta) && *delta != 0)
214  return true;
215 
216  FailureOr<bool> testEqual =
217  ValueBoundsConstraintSet::areEqual(indexA, indexB);
218  if (succeeded(testEqual) && !testEqual.value())
219  return true;
220  }
221  } else {
222  // For this dimension, we slice a part of the memref, so we need to make
223  // sure the intervals accessed don't overlap.
224  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
225  if (cstIndexA.has_value() && cstIndexB.has_value()) {
226  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
227  if (distance >= vectorDim)
228  return true;
229  continue;
230  }
231  if (testDynamicValueUsingBounds) {
232  // First try to see if we can fully compose and simplify the affine
233  // expression as a fast track.
234  FailureOr<int64_t> delta =
235  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
236  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
237  return true;
238 
239  FailureOr<int64_t> computeDelta =
240  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
241  if (succeeded(computeDelta)) {
242  if (std::abs(computeDelta.value()) >= vectorDim)
243  return true;
244  }
245  }
246  }
247  }
248  return false;
249 }
250 
251 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
252  VectorTransferOpInterface transferB,
253  bool testDynamicValueUsingBounds) {
254  if (transferA.getSource() != transferB.getSource())
255  return false;
256  return isDisjointTransferIndices(transferA, transferB,
257  testDynamicValueUsingBounds);
258 }
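// Sketch of what the disjointness test detects (hypothetical example):
//
//   %r = vector.transfer_read %mem[%c0, %c0], %pad
//       : memref<8x8xf32>, vector<4xf32>
//   vector.transfer_write %v, %mem[%c0, %c4]
//       : vector<4xf32>, memref<8x8xf32>
//
// The leading indices are equal, and the trailing indices 0 and 4 differ by at
// least the vector size (4), so the accessed intervals cannot overlap and the
// transfers are reported as disjoint.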
259 
260 // Helper to iterate over n-D vector slice elements. Calculate the next
261 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
262 // Modifies the `position` in place. Returns a failure when `position` becomes
263 // the end position.
264 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
265  ArrayRef<int64_t> shape,
266  ArrayRef<int64_t> offsets) {
267  for (auto [posInDim, dimSize, offsetInDim] :
268  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
269  ++posInDim;
270  if (posInDim < dimSize + offsetInDim)
271  return success();
272 
273  // Carry the overflow to the next loop iteration.
274  posInDim = offsetInDim;
275  }
276 
277  return failure();
278 }
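// Illustrative walk (assuming shape = [2, 3] and offsets = [0, 0]): starting
// from position [0, 0], successive calls yield [0, 1], [0, 2], [1, 0], [1, 1],
// [1, 2], and the next call returns failure() once the end is reached.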
279 
280 /// Returns the integer numbers in `values`. `values` are expected to be
281 /// constant operations.
282 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
283  SmallVector<int64_t> ints;
284  llvm::transform(values, std::back_inserter(ints), [](Value value) {
285  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
286  assert(constOp && "Unexpected non-constant index");
287  return constOp.value();
288  });
289  return ints;
290 }
291 
292 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
293 /// be constant operations.
294 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
295  SmallVector<int64_t> ints;
296  llvm::transform(
297  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
298  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
299  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
300  });
301  return ints;
302 }
303 
304 /// Convert `foldResults` into Values. Integer attributes are converted to
305 /// constant ops.
306 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
307  ArrayRef<OpFoldResult> foldResults) {
308  SmallVector<Value> values;
309  llvm::transform(foldResults, std::back_inserter(values),
310  [&](OpFoldResult foldResult) {
311  if (auto attr = foldResult.dyn_cast<Attribute>())
312  return builder
313  .create<arith::ConstantIndexOp>(
314  loc, cast<IntegerAttr>(attr).getInt())
315  .getResult();
316 
317  return foldResult.get<Value>();
318  });
319  return values;
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // CombiningKindAttr
324 //===----------------------------------------------------------------------===//
325 
326 namespace mlir {
327 namespace vector {
328 namespace detail {
329 struct BitmaskEnumStorage : public AttributeStorage {
330  using KeyTy = uint64_t;
331 
332  BitmaskEnumStorage(KeyTy val) : value(val) {}
333 
334  bool operator==(const KeyTy &key) const { return value == key; }
335 
336  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
337  const KeyTy &key) {
338  return new (allocator.allocate<BitmaskEnumStorage>())
339  BitmaskEnumStorage(key);
340  }
341 
342  KeyTy value = 0;
343 };
344 } // namespace detail
345 } // namespace vector
346 } // namespace mlir
347 
348 //===----------------------------------------------------------------------===//
349 // VectorDialect
350 //===----------------------------------------------------------------------===//
351 
352 namespace {
353 /// This class defines the interface for handling inlining with vector dialect
354 /// operations.
355 struct VectorInlinerInterface : public DialectInlinerInterface {
356  using DialectInlinerInterface::DialectInlinerInterface;
357 
358  /// All vector dialect ops can be inlined.
359  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
360  return true;
361  }
362 };
363 } // namespace
364 
365 void VectorDialect::initialize() {
366  addAttributes<
367 #define GET_ATTRDEF_LIST
368 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
369  >();
370 
371  addOperations<
372 #define GET_OP_LIST
373 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
374  >();
375 
376  addInterfaces<VectorInlinerInterface>();
377 }
378 
379 /// Materialize a single constant operation from a given attribute value with
380 /// the desired resultant type.
381 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
382  Attribute value, Type type,
383  Location loc) {
384  return arith::ConstantOp::materialize(builder, value, type, loc);
385 }
386 
387 Type vector::getVectorSubscriptType(Builder &builder) {
388  return builder.getIntegerType(64);
389 }
390 
391 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
392  ArrayRef<int64_t> values) {
393  return builder.getI64ArrayAttr(values);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // MultiDimReductionOp
398 //===----------------------------------------------------------------------===//
399 
400 void vector::MultiDimReductionOp::build(OpBuilder &builder,
401  OperationState &result, Value source,
402  Value acc, ArrayRef<bool> reductionMask,
403  CombiningKind kind) {
404  SmallVector<int64_t> reductionDims;
405  for (const auto &en : llvm::enumerate(reductionMask))
406  if (en.value())
407  reductionDims.push_back(en.index());
408  build(builder, result, kind, source, acc,
409  builder.getI64ArrayAttr(reductionDims));
410 }
411 
412 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
413  // Single parallel dim, this is a noop.
414  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
415  return getSource();
416  return {};
417 }
418 
419 std::optional<SmallVector<int64_t, 4>>
420 MultiDimReductionOp::getShapeForUnroll() {
421  return llvm::to_vector<4>(getSourceVectorType().getShape());
422 }
423 
424 LogicalResult MultiDimReductionOp::verify() {
425  SmallVector<int64_t> targetShape;
426  SmallVector<bool> scalableDims;
427  Type inferredReturnType;
428  auto sourceScalableDims = getSourceVectorType().getScalableDims();
429  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
430  if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
431  return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
432  })) {
433  targetShape.push_back(it.value());
434  scalableDims.push_back(sourceScalableDims[it.index()]);
435  }
436  // TODO: update to also allow 0-d vectors when available.
437  if (targetShape.empty())
438  inferredReturnType = getSourceVectorType().getElementType();
439  else
440  inferredReturnType = VectorType::get(
441  targetShape, getSourceVectorType().getElementType(), scalableDims);
442  if (getType() != inferredReturnType)
443  return emitOpError() << "destination type " << getType()
444  << " is incompatible with source type "
445  << getSourceVectorType();
446 
447  return success();
448 }
449 
450 /// Returns the mask type expected by this operation.
451 Type MultiDimReductionOp::getExpectedMaskType() {
452  auto vecType = getSourceVectorType();
453  return VectorType::get(vecType.getShape(),
454  IntegerType::get(vecType.getContext(), /*width=*/1),
455  vecType.getScalableDims());
456 }
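// For illustration: reducing a vector<2x[4]xf32> source expects a mask of type
// vector<2x[4]xi1>, i.e. the mask mirrors the source shape, including its
// scalable dimensions.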
457 
458 namespace {
459 // Only unit dimensions that are being reduced are folded. If the dimension is
460 // unit, but not reduced, it is not folded, thereby keeping the output type the
461 // same. If not all dimensions which are reduced are of unit dimension, this
462 // transformation does nothing. This is just a generalization of
463 // ElideSingleElementReduction for ReduceOp.
464 struct ElideUnitDimsInMultiDimReduction
465  : public OpRewritePattern<MultiDimReductionOp> {
466  using OpRewritePattern::OpRewritePattern;
467 
468  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
469  PatternRewriter &rewriter) const override {
470  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
471  for (const auto &dim : enumerate(shape)) {
472  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
473  return failure();
474  }
475 
476  // Vector mask setup.
477  OpBuilder::InsertionGuard guard(rewriter);
478  Operation *rootOp;
479  Value mask;
480  if (reductionOp.isMasked()) {
481  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
482  rootOp = reductionOp.getMaskingOp();
483  mask = reductionOp.getMaskingOp().getMask();
484  } else {
485  rootOp = reductionOp;
486  }
487 
488  Location loc = reductionOp.getLoc();
489  Value acc = reductionOp.getAcc();
490  Value cast;
491  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
492  if (mask) {
493  VectorType newMaskType =
494  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
495  dstVecType.getScalableDims());
496  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
497  }
498  cast = rewriter.create<vector::ShapeCastOp>(
499  loc, reductionOp.getDestType(), reductionOp.getSource());
500  } else {
501  // This means we are reducing all the dimensions, and all reduction
502  // dimensions are of size 1. So a simple extraction would do.
503  SmallVector<int64_t> zeroIdx(shape.size(), 0);
504  if (mask)
505  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
506  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
507  zeroIdx);
508  }
509 
510  Value result =
511  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
512  cast, /*fastmath=*/nullptr, mask);
513  rewriter.replaceOp(rootOp, result);
514  return success();
515  }
516 };
517 } // namespace
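// Illustrative effect of ElideUnitDimsInMultiDimReduction (hypothetical IR):
//
//   %r = vector.multi_reduction <add>, %src, %acc [0]
//       : vector<1x8xf32> to vector<8xf32>
//
// becomes a shape cast feeding the combining op:
//
//   %c = vector.shape_cast %src : vector<1x8xf32> to vector<8xf32>
//   %r = arith.addf %c, %acc : vector<8xf32>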
518 
519 void MultiDimReductionOp::getCanonicalizationPatterns(
520  RewritePatternSet &results, MLIRContext *context) {
521  results.add<ElideUnitDimsInMultiDimReduction>(context);
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // ReductionOp
526 //===----------------------------------------------------------------------===//
527 
528 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
529  CombiningKind kind, Value vector,
530  arith::FastMathFlags fastMathFlags) {
531  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
532 }
533 
534 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
535  CombiningKind kind, Value vector, Value acc,
536  arith::FastMathFlags fastMathFlags) {
537  build(builder, result,
538  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
539  acc, fastMathFlags);
540 }
541 
542 LogicalResult ReductionOp::verify() {
543  // Verify for 0-D and 1-D vector.
544  int64_t rank = getSourceVectorType().getRank();
545  if (rank > 1)
546  return emitOpError("unsupported reduction rank: ") << rank;
547 
548  // Verify supported reduction kind.
549  Type eltType = getDest().getType();
550  if (!isSupportedCombiningKind(getKind(), eltType))
551  return emitOpError("unsupported reduction type '")
552  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
553  << "'";
554 
555  return success();
556 }
557 
558 // MaskableOpInterface methods.
559 
560 /// Returns the mask type expected by this operation.
561 Type ReductionOp::getExpectedMaskType() {
562  auto vecType = getSourceVectorType();
563  return VectorType::get(vecType.getShape(),
564  IntegerType::get(vecType.getContext(), /*width=*/1),
565  vecType.getScalableDims());
566 }
567 
568 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
569  OpBuilder &builder, Location loc,
570  Value vector) {
571  switch (op) {
572  case arith::AtomicRMWKind::addf:
573  case arith::AtomicRMWKind::addi:
574  return builder.create<vector::ReductionOp>(vector.getLoc(),
575  CombiningKind::ADD, vector);
576  case arith::AtomicRMWKind::mulf:
577  case arith::AtomicRMWKind::muli:
578  return builder.create<vector::ReductionOp>(vector.getLoc(),
579  CombiningKind::MUL, vector);
580  case arith::AtomicRMWKind::minimumf:
581  return builder.create<vector::ReductionOp>(vector.getLoc(),
582  CombiningKind::MINIMUMF, vector);
583  case arith::AtomicRMWKind::mins:
584  return builder.create<vector::ReductionOp>(vector.getLoc(),
585  CombiningKind::MINSI, vector);
586  case arith::AtomicRMWKind::minu:
587  return builder.create<vector::ReductionOp>(vector.getLoc(),
588  CombiningKind::MINUI, vector);
589  case arith::AtomicRMWKind::maximumf:
590  return builder.create<vector::ReductionOp>(vector.getLoc(),
591  CombiningKind::MAXIMUMF, vector);
592  case arith::AtomicRMWKind::maxs:
593  return builder.create<vector::ReductionOp>(vector.getLoc(),
594  CombiningKind::MAXSI, vector);
595  case arith::AtomicRMWKind::maxu:
596  return builder.create<vector::ReductionOp>(vector.getLoc(),
597  CombiningKind::MAXUI, vector);
598  case arith::AtomicRMWKind::andi:
599  return builder.create<vector::ReductionOp>(vector.getLoc(),
600  CombiningKind::AND, vector);
601  case arith::AtomicRMWKind::ori:
602  return builder.create<vector::ReductionOp>(vector.getLoc(),
603  CombiningKind::OR, vector);
604  // TODO: Add remaining reduction operations.
605  default:
606  (void)emitOptionalError(loc, "Reduction operation type not supported");
607  break;
608  }
609  return nullptr;
610 }
611 
612 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
613  return llvm::to_vector<4>(getSourceVectorType().getShape());
614 }
615 
616 namespace {
617 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
618  using OpRewritePattern::OpRewritePattern;
619 
620  LogicalResult matchAndRewrite(ReductionOp reductionOp,
621  PatternRewriter &rewriter) const override {
622  // Vector mask setup.
623  OpBuilder::InsertionGuard guard(rewriter);
624  auto maskableOp =
625  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
626  Operation *rootOp;
627  Value mask;
628  if (maskableOp.isMasked()) {
629  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
630  rootOp = maskableOp.getMaskingOp();
631  mask = maskableOp.getMaskingOp().getMask();
632  } else {
633  rootOp = reductionOp;
634  }
635 
636  auto vectorType = reductionOp.getSourceVectorType();
637  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
638  return failure();
639 
640  Location loc = reductionOp.getLoc();
641  Value result;
642  if (vectorType.getRank() == 0) {
643  if (mask)
644  mask = rewriter.create<ExtractElementOp>(loc, mask);
645  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
646  } else {
647  if (mask)
648  mask = rewriter.create<ExtractOp>(loc, mask, 0);
649  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
650  }
651 
652  if (Value acc = reductionOp.getAcc())
653  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
654  result, acc,
655  reductionOp.getFastmathAttr(), mask);
656 
657  rewriter.replaceOp(rootOp, result);
658  return success();
659  }
660 };
661 } // namespace
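// Illustrative effect of ElideSingleElementReduction (hypothetical IR):
//
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
//
// becomes an extract of the single element combined with the accumulator:
//
//   %e = vector.extract %v[0] : f32 from vector<1xf32>
//   %r = arith.addf %e, %acc : f32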
662 
663 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
664  MLIRContext *context) {
665  results.add<ElideSingleElementReduction>(context);
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // ContractionOp
670 //===----------------------------------------------------------------------===//
671 
672 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
673  Value lhs, Value rhs, Value acc,
674  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
675  ArrayRef<IteratorType> iteratorTypes) {
676  result.addOperands({lhs, rhs, acc});
677  result.addTypes(acc.getType());
678  result.addAttribute(
679  getIndexingMapsAttrName(result.name),
680  builder.getAffineMapArrayAttr(
681  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
682  result.addAttribute(
683  getIteratorTypesAttrName(result.name),
684  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
685  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
686  return IteratorTypeAttr::get(builder.getContext(), t);
687  }))));
688 }
689 
690 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
691  Value lhs, Value rhs, Value acc,
692  ArrayAttr indexingMaps,
693  ArrayAttr iteratorTypes) {
694  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
695  ContractionOp::getDefaultKind());
696 }
697 
698 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
699  Value lhs, Value rhs, Value acc,
700  ArrayAttr indexingMaps,
701  ArrayAttr iteratorTypes, CombiningKind kind) {
702  result.addOperands({lhs, rhs, acc});
703  result.addTypes(acc.getType());
704  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
705  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
706  result.addAttribute(getKindAttrName(result.name),
707  CombiningKindAttr::get(builder.getContext(), kind));
708 }
709 
710 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
711  OpAsmParser::UnresolvedOperand lhsInfo;
712  OpAsmParser::UnresolvedOperand rhsInfo;
713  OpAsmParser::UnresolvedOperand accInfo;
714  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
715  SmallVector<Type, 2> types;
716  Type resultType;
717  auto loc = parser.getCurrentLocation();
718  DictionaryAttr dictAttr;
719  // TODO: Unify linalg op attribute parsing.
720  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
721  parser.parseComma() || parser.parseOperand(rhsInfo) ||
722  parser.parseComma() || parser.parseOperand(accInfo) ||
723  parser.parseTrailingOperandList(masksInfo) ||
724  parser.parseOptionalAttrDict(result.attributes) ||
725  parser.parseColonTypeList(types) ||
726  parser.parseKeywordType("into", resultType) ||
727  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
728  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
729  parser.resolveOperand(accInfo, resultType, result.operands) ||
730  parser.addTypeToList(resultType, result.types))
731  return failure();
732  result.attributes.append(dictAttr.getValue().begin(),
733  dictAttr.getValue().end());
734 
735  // Convert an array of strings into an array of IteratorType enums. This is
736  // needed because tests still use the old format, where the 'iterator_types'
737  // attribute is represented as an array of strings.
738  // TODO: Remove this conversion once tests are fixed.
739  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
740  result.attributes.get(getIteratorTypesAttrName(result.name)));
741 
742  SmallVector<Attribute> iteratorTypeAttrs;
743 
744  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
745  auto maybeIteratorType = symbolizeIteratorType(s);
746  if (!maybeIteratorType.has_value())
747  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
748 
749  iteratorTypeAttrs.push_back(
750  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
751  }
752  result.attributes.set(getIteratorTypesAttrName(result.name),
753  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
754 
755  if (!result.attributes.get(getKindAttrName(result.name))) {
756  result.addAttribute(
757  getKindAttrName(result.name),
758  CombiningKindAttr::get(parser.getContext(),
759  ContractionOp::getDefaultKind()));
760  }
761  if (masksInfo.empty())
762  return success();
763  if (masksInfo.size() != 2)
764  return parser.emitError(parser.getNameLoc(),
765  "expected zero or exactly 2 vector mask operands");
766  auto lhsType = llvm::cast<VectorType>(types[0]);
767  auto rhsType = llvm::cast<VectorType>(types[1]);
768  auto maskElementType = parser.getBuilder().getI1Type();
769  std::array<VectorType, 2> maskTypes = {
770  VectorType::Builder(lhsType).setElementType(maskElementType),
771  VectorType::Builder(rhsType).setElementType(maskElementType)};
772  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
773  return failure();
774  return success();
775 }
776 
777 void ContractionOp::print(OpAsmPrinter &p) {
778  // TODO: Unify printing code with linalg ops.
779  auto attrNames = getTraitAttrNames();
780  llvm::StringSet<> traitAttrsSet;
781  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
782  SmallVector<NamedAttribute> attrs;
783  for (auto attr : (*this)->getAttrs()) {
784  if (attr.getName() == getIteratorTypesAttrName()) {
785  auto iteratorTypes =
786  llvm::cast<ArrayAttr>(attr.getValue())
787  .getAsValueRange<IteratorTypeAttr, IteratorType>();
788  // Convert IteratorType enums into their string representation. This is
789  // needed because tests still use the old format, where the 'iterator_types'
790  // attribute is represented as an array of strings.
791  // TODO: Remove this conversion once tests are fixed.
792  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
793  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
794  return StringAttr::get(getContext(), stringifyIteratorType(t));
795  }));
796 
797  attrs.emplace_back(getIteratorTypesAttrName(),
798  ArrayAttr::get(getContext(), iteratorTypeNames));
799  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
800  attrs.push_back(attr);
801  }
802 
803  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
804  p << " " << dictAttr << " " << getLhs() << ", ";
805  p << getRhs() << ", " << getAcc();
806 
807  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
808  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
809  << getResultType();
810 }
811 
812 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
813  const std::vector<std::pair<int64_t, int64_t>> &map) {
814  for (auto &dimPair : map) {
815  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
816  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
817  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
818  return false;
819  }
820  return true;
821 }
822 
823 static LogicalResult verifyOutputShape(
824  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
825  Type resType,
826  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
827  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
828  DenseSet<int64_t> lhsContractingDimSet;
829  DenseSet<int64_t> rhsContractingDimSet;
830  for (auto &dimPair : contractingDimMap) {
831  lhsContractingDimSet.insert(dimPair.first);
832  rhsContractingDimSet.insert(dimPair.second);
833  }
834  DenseSet<int64_t> rhsBatchDimSet;
835  for (auto &dimPair : batchDimMap)
836  rhsBatchDimSet.insert(dimPair.second);
837 
838  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
839  SmallVector<int64_t, 4> expectedResultDims;
840  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
841  if (lhsContractingDimSet.count(i) > 0)
842  continue;
843  expectedResultDims.push_back(lhsType.getDimSize(i));
844  }
845 
846  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
847  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
848  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
849  continue;
850  expectedResultDims.push_back(rhsType.getDimSize(i));
851  }
852 
853  // Verify 'expectedResultDims'.
854  if (expectedResultDims.empty()) {
855  // No batch or free dimension implies a scalar result.
856  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
857  return op.emitOpError("invalid accumulator/result vector shape");
858  } else {
859  // At least one batch or free dimension implies a vector result.
860  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
861  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
862  if (!resVectorType || !accVectorType)
863  return op.emitOpError("invalid accumulator/result vector shape");
864 
865  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
866  // types fully define the result vector type. This assumes the affine maps
867  // are well-formed, which must have been verified already.
868  MLIRContext *ctx = op.getContext();
869  AffineMap lhsMap = op.getIndexingMapsArray()[0];
870  AffineMap rhsMap = op.getIndexingMapsArray()[1];
871  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
872  return op.emitOpError(
873  "expected all dimensions to be either a LHS or a RHS dimension");
874  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
875  for (auto pair :
876  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
877  VectorType v = pair.first;
878  auto map = pair.second;
879  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
880  unsigned pos = map.getDimPosition(idx);
881  if (!extents[pos])
882  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
883  }
884  }
885  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
886  return op.emitOpError("expected all dimensions to get an extent as "
887  "either a LHS or a RHS dimension");
888 
889  AffineMap resMap = op.getIndexingMapsArray()[2];
890  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
891  /*symCount=*/0, extents, ctx);
892  // Compose the resMap with the extentsMap, which is a constant map.
893  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
894  assert(
895  llvm::all_of(expectedMap.getResults(),
896  [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
897  "expected constant extent along all dimensions.");
898  // Extract the expected shape and build the type.
899  auto expectedShape = llvm::to_vector<4>(
900  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
901  return cast<AffineConstantExpr>(e).getValue();
902  }));
903  auto expected =
904  VectorType::get(expectedShape, resVectorType.getElementType(),
905  resVectorType.getScalableDims());
906  if (resVectorType != expected || accVectorType != expected)
907  return op.emitOpError(
908  "invalid accumulator/result vector shape, expected: ")
909  << expected;
910  }
911  return success();
912 }
913 
914 LogicalResult ContractionOp::verify() {
915  VectorType lhsType = getLhsType();
916  VectorType rhsType = getRhsType();
917  Type accType = getAccType();
918  Type resType = getResultType();
919 
920  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
921  if (!lhsType.getElementType().isSignlessInteger())
922  return emitOpError("only supports signless integer types");
923  }
924 
925  // Verify that an indexing map was specified for each vector operand.
926  if (getIndexingMapsArray().size() != 3)
927  return emitOpError("expected an indexing map for each vector operand");
928 
929  // Verify that each index map has 'numIterators' inputs, no symbols, and
930  // that the number of map outputs equals the rank of its associated
931  // vector operand.
932  unsigned numIterators = getIteratorTypes().getValue().size();
933  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
934  auto index = it.index();
935  auto map = it.value();
936  if (map.getNumSymbols() != 0)
937  return emitOpError("expected indexing map ")
938  << index << " to have no symbols";
939  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
940  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
941  // Verify that the map has the right number of inputs, outputs, and indices.
942  // This also correctly accounts for (..) -> () for rank-0 results.
943  if (map.getNumDims() != numIterators)
944  return emitOpError("expected indexing map ")
945  << index << " to have " << numIterators << " number of inputs";
946  if (map.getNumResults() != rank)
947  return emitOpError("expected indexing map ")
948  << index << " to have " << rank << " number of outputs";
949  if (!map.isProjectedPermutation())
950  return emitOpError("expected indexing map ")
951  << index << " to be a projected permutation of its inputs";
952  }
953 
954  auto contractingDimMap = getContractingDimMap();
955  auto batchDimMap = getBatchDimMap();
956 
957  // Verify at least one contracting dimension pair was specified.
958  if (contractingDimMap.empty())
959  return emitOpError("expected at least one contracting dimension pair");
960 
961  // Verify contracting dimension map was properly constructed.
962  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
963  return emitOpError("invalid contracting dimension map");
964 
965  // Verify batch dimension map was properly constructed.
966  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
967  return emitOpError("invalid batch dimension map");
968 
969  // Verify 'accType' and 'resType' shape.
970  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
971  contractingDimMap, batchDimMap)))
972  return failure();
973 
974  // Verify supported combining kind.
975  auto vectorType = llvm::dyn_cast<VectorType>(resType);
976  auto elementType = vectorType ? vectorType.getElementType() : resType;
977  if (!isSupportedCombiningKind(getKind(), elementType))
978  return emitOpError("unsupported contraction type");
979 
980  return success();
981 }
982 
983 // MaskableOpInterface methods.
984 
985 /// Returns the mask type expected by this operation. Mostly used for
986 /// verification purposes. It requires the operation to be vectorized.
987 Type ContractionOp::getExpectedMaskType() {
988  auto indexingMaps = this->getIndexingMapsArray();
989  AffineMap lhsIdxMap = indexingMaps[0];
990  AffineMap rhsIdxMap = indexingMaps[1];
991  VectorType lhsType = this->getLhsType();
992  VectorType rhsType = this->getRhsType();
993 
994  unsigned numVecDims = lhsIdxMap.getNumDims();
995  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
996  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
997 
998  // Using the information in the indexing maps, extract the size of each
999  // dimension in the vector.contract operation from the two input operands.
1000  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1001  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1002  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1003  lhsType.getScalableDims()[dimIdx];
1004  }
1005  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1006  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1007  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1008  rhsType.getScalableDims()[dimIdx];
1009  }
1010 
1011  assert(!ShapedType::isDynamicShape(maskShape) &&
1012  "Mask shape couldn't be computed");
1013 
1014  return VectorType::get(maskShape,
1015  IntegerType::get(lhsType.getContext(), /*width=*/1),
1016  maskShapeScalableDims);
1017 }
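// Worked example (illustrative, with made-up shapes): for a matmul-style
// contraction where
//
//   lhs : vector<4x8xf32>  with map (m, n, k) -> (m, k)
//   rhs : vector<8x16xf32> with map (m, n, k) -> (k, n)
//
// the iteration-space sizes are m = 4, n = 16, k = 8, so the mask type
// computed above is vector<4x16x8xi1>.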
1018 
1019 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1020  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1021  getIteratorTypesAttrName(), getKindAttrName()};
1022 }
1023 
1024 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1025  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1026  if (targetExpr == map.getResult(i))
1027  return i;
1028  return -1;
1029 }
1030 
1031 static std::vector<std::pair<int64_t, int64_t>>
1032 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1033  IteratorType targetIteratorType, MLIRContext *context) {
1034  std::vector<std::pair<int64_t, int64_t>> dimMap;
1035  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1036  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1037  if (iteratorType != targetIteratorType)
1038  continue;
1039  // Search lhs/rhs map results for 'targetExpr'.
1040  auto targetExpr = getAffineDimExpr(it.index(), context);
1041  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1042  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1043  if (lhsDim >= 0 && rhsDim >= 0)
1044  dimMap.emplace_back(lhsDim, rhsDim);
1045  }
1046  return dimMap;
1047 }
1048 
1049 void ContractionOp::getIterationBounds(
1050  SmallVectorImpl<int64_t> &iterationBounds) {
1051  auto lhsShape = getLhsType().getShape();
1052  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1053  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1054  SmallVector<int64_t, 2> iterationShape;
1055  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1056  // Search lhs/rhs map results for 'targetExpr'.
1057  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1058  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1059  if (iteratorType == IteratorType::reduction) {
1060  // Get reduction dim size from lhs shape (same size in rhsShape).
1061  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1062  assert(lhsDimIndex >= 0);
1063  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1064  continue;
1065  }
1066  // Get parallel dimension size from result shape.
1067  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1068  assert(resDimIndex >= 0);
1069  assert(resVectorType != nullptr);
1070  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1071  }
1072 }
1073 
1074 void ContractionOp::getIterationIndexMap(
1075  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1076  unsigned numMaps = getIndexingMapsArray().size();
1077  iterationIndexMap.resize(numMaps);
1078  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1079  auto index = it.index();
1080  auto map = it.value();
1081  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1082  auto dim = cast<AffineDimExpr>(map.getResult(i));
1083  iterationIndexMap[index][dim.getPosition()] = i;
1084  }
1085  }
1086 }
1087 
1088 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1089  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1090  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1091  getContext());
1092 }
1093 
1094 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1095  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1096  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1097  getContext());
1098 }
1099 
1100 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1101  SmallVector<int64_t, 4> shape;
1102  getIterationBounds(shape);
1103  return shape;
1104 }
1105 
1106 /// Return a fused vector::ContractionOp which represents a pattern such as:
1107 ///
1108 /// ```mlir
1109 /// %c0 = vector.constant 0: ...
1110 /// %c = vector.contract %a, %b, %c0: ...
1111 /// %e = add %c, %d: ...
1112 /// ```
1113 ///
1114 /// by:
1115 ///
1116 /// ```mlir
1117 /// %e = vector.contract %a, %b, %d: ...
1118 /// ```
1119 ///
1120 /// Return null if the canonicalization does not apply.
1121 // TODO: This should be a folding of Add into Contract in core but while they
1122 // live in different dialects, it is not possible without unnatural
1123 // dependencies.
1124 template <typename AddOpType>
1125 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1126  using OpRewritePattern<AddOpType>::OpRewritePattern;
1127 
1128  LogicalResult matchAndRewrite(AddOpType addOp,
1129  PatternRewriter &rewriter) const override {
1130  auto canonicalize = [&](Value maybeContraction,
1131  Value otherOperand) -> vector::ContractionOp {
1132  vector::ContractionOp contractionOp =
1133  dyn_cast_or_null<vector::ContractionOp>(
1134  maybeContraction.getDefiningOp());
1135  if (!contractionOp)
1136  return vector::ContractionOp();
1137  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1138  contractionOp.getAcc().getDefiningOp())) {
1139  if (maybeZero.getValue() ==
1140  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1141  IRMapping bvm;
1142  bvm.map(contractionOp.getAcc(), otherOperand);
1143  auto newContraction =
1144  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1145  rewriter.replaceOp(addOp, newContraction.getResult());
1146  return newContraction;
1147  }
1148  }
1149  return vector::ContractionOp();
1150  };
1151 
1152  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1153  vector::ContractionOp contract = canonicalize(a, b);
1154  contract = contract ? contract : canonicalize(b, a);
1155  return contract ? success() : failure();
1156  }
1157 };
1158 
1159 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1160  MLIRContext *context) {
1161  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1162  CanonicalizeContractAdd<arith::AddFOp>>(context);
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // ExtractElementOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1170  Value source) {
1171  result.addOperands({source});
1172  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1173 }
1174 
1175 LogicalResult vector::ExtractElementOp::verify() {
1176  VectorType vectorType = getSourceVectorType();
1177  if (vectorType.getRank() == 0) {
1178  if (getPosition())
1179  return emitOpError("expected position to be empty with 0-D vector");
1180  return success();
1181  }
1182  if (vectorType.getRank() != 1)
1183  return emitOpError("unexpected >1 vector rank");
1184  if (!getPosition())
1185  return emitOpError("expected position for 1-D vector");
1186  return success();
1187 }
1188 
1189 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1190  // Skip 0-D vectors for now.
1191  if (!adaptor.getPosition())
1192  return {};
1193 
1194  // Fold extractelement (splat X) -> X.
1195  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1196  return splat.getInput();
1197 
1198  // Fold extractelement(broadcast(X)) -> X.
1199  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1200  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1201  return broadcast.getSource();
1202 
1203  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1204  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1205  if (!pos || !src)
1206  return {};
1207 
1208  auto srcElements = src.getValues<Attribute>();
1209 
1210  uint64_t posIdx = pos.getInt();
1211  if (posIdx >= srcElements.size())
1212  return {};
1213 
1214  return srcElements[posIdx];
1215 }
1216 
1217 //===----------------------------------------------------------------------===//
1218 // ExtractOp
1219 //===----------------------------------------------------------------------===//
1220 
1221 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1222  Value source, int64_t position) {
1223  build(builder, result, source, ArrayRef<int64_t>{position});
1224 }
1225 
1226 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1227  Value source, OpFoldResult position) {
1228  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1229 }
1230 
1231 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1232  Value source, ArrayRef<int64_t> position) {
1233  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1234  builder.getDenseI64ArrayAttr(position));
1235 }
1236 
1237 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1238  Value source, ArrayRef<OpFoldResult> position) {
1239  SmallVector<int64_t> staticPos;
1240  SmallVector<Value> dynamicPos;
1241  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1242  build(builder, result, source, dynamicPos,
1243  builder.getDenseI64ArrayAttr(staticPos));
1244 }
1245 
1246 LogicalResult
1247 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1248  ExtractOp::Adaptor adaptor,
1249  SmallVectorImpl<Type> &inferredReturnTypes) {
1250  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1251  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1252  vectorType.getRank()) {
1253  inferredReturnTypes.push_back(vectorType.getElementType());
1254  } else {
1255  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1256  vectorType.getRank());
1257  inferredReturnTypes.push_back(VectorType::get(
1258  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1259  vectorType.getScalableDims().drop_front(n)));
1260  }
1261  return success();
1262 }
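// Example of the inference above (illustrative): extracting with one index
// from a vector<2x3x4xf32> yields vector<3x4xf32>; extracting with three
// indices (the full rank) yields the scalar element type f32.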
1263 
1264 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1265  // Allow extracting 1-element vectors instead of scalars.
1266  auto isCompatible = [](TypeRange l, TypeRange r) {
1267  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1268  return vectorType && vectorType.getShape().equals({1}) &&
1269  vectorType.getElementType() == r.front();
1270  };
1271  if (l.size() == 1 && r.size() == 1 &&
1272  (isCompatible(l, r) || isCompatible(r, l)))
1273  return true;
1274  return l == r;
1275 }
1276 
1277 LogicalResult vector::ExtractOp::verify() {
1278  // Note: This check must come before getMixedPosition() to prevent a crash.
1279  auto dynamicMarkersCount =
1280  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1281  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1282  return emitOpError(
1283  "mismatch between dynamic and static positions (kDynamic marker but no "
1284  "corresponding dynamic position) -- this can only happen due to an "
1285  "incorrect fold/rewrite");
1286  auto position = getMixedPosition();
1287  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1288  return emitOpError(
1289  "expected position attribute of rank no greater than vector rank");
1290  for (auto [idx, pos] : llvm::enumerate(position)) {
1291  if (pos.is<Attribute>()) {
1292  int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
1293  if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1294  return emitOpError("expected position attribute #")
1295  << (idx + 1)
1296  << " to be a non-negative integer smaller than the "
1297  "corresponding vector dimension";
1298  }
1299  }
1300  }
1301  return success();
1302 }
1303 
1304 template <typename IntType>
1305 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1306  return llvm::to_vector<4>(llvm::map_range(
1307  arrayAttr.getAsRange<IntegerAttr>(),
1308  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1309 }
1310 
1311 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1312 /// positions.
1313 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1314  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1315  return failure();
1316 
1317  // TODO: Canonicalization for dynamic position not implemented yet.
1318  if (extractOp.hasDynamicPosition())
1319  return failure();
1320 
1321  SmallVector<int64_t> globalPosition;
1322  ExtractOp currentOp = extractOp;
1323  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1324  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1325  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1326  currentOp = nextOp;
1327  // TODO: Canonicalization for dynamic position not implemented yet.
1328  if (currentOp.hasDynamicPosition())
1329  return failure();
1330  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1331  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1332  }
1333  extractOp.setOperand(0, currentOp.getVector());
1334  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1335  OpBuilder b(extractOp.getContext());
1336  std::reverse(globalPosition.begin(), globalPosition.end());
1337  extractOp.setStaticPosition(globalPosition);
1338  return success();
1339 }
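// Illustrative fold performed by foldExtractOpFromExtractChain (hypothetical IR):
//
//   %a = vector.extract %v[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[2] : vector<4xf32> from vector<3x4xf32>
//
// is rewritten in place so that %b becomes
//
//   %b = vector.extract %v[1, 2] : vector<4xf32> from vector<2x3x4xf32>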
1340 
1341 namespace {
1342 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1343 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1344 /// Compose TransposeOp permutations as we walk back.
1345 /// This helper class keeps an updated extraction position `extractPosition`
1346 /// with extra trailing sentinels.
1347 /// The sentinels encode the internal transposition status of the result vector.
1348 /// As we iterate, extractPosition is permuted and updated.
1349 class ExtractFromInsertTransposeChainState {
1350 public:
1351  ExtractFromInsertTransposeChainState(ExtractOp e);
1352 
1353  /// Iterate over producing insert and transpose ops until we find a fold.
1354  Value fold();
1355 
1356 private:
1357  /// Return true if the vector at position `a` is contained within the vector
1358  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1359  /// is a prefix of `b`.
1360  template <typename ContainerA, typename ContainerB>
1361  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1362  return a.size() <= b.size() &&
1363  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1364  }
1365 
1366  /// Return true if the vector at position `a` intersects the vector at
1367  /// position `b`. Under insert/extract semantics, this is the same as equality
1368  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1369  /// Comparison is on the common prefix (i.e. zip).
1370  template <typename ContainerA, typename ContainerB>
1371  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1372  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1373  if (elemA < 0 || elemB < 0)
1374  continue;
1375  if (elemA != elemB)
1376  return false;
1377  }
1378  return true;
1379  }
1380 
1381  /// Folding is only possible in the absence of an internal permutation in the
1382  /// result vector.
1383  bool canFold() {
1384  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1385  }
1386 
1387  // Helper to get the next defining op of interest.
1388  void updateStateForNextIteration(Value v) {
1389  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1390  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1391  };
1392 
1393  // Case 1. If we hit a transpose, just compose the map and iterate.
1394  // Invariant: insert + transpose do not change rank, we can always compose.
1395  LogicalResult handleTransposeOp();
1396 
1397  // Case 2: the insert position matches extractPosition exactly, early return.
1398  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1399 
1400  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1401  /// portion of the source of the insert.
1402  /// Example:
1403  /// ```
1404  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1405  /// // extractPosition == [1, 2, 3]
1406  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1407  /// // can fold to vector.extract %source[0, 3]
1408  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1409  /// ```
1410  /// To traverse through %source, we need to set the leading dims to 0 and
1411  /// drop the extra leading dims.
1412  /// This method updates the internal state.
1413  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1414 
1415  /// Try to fold in place to extract(source, extractPosition) and return the
1416  /// folded result. Return null if folding is not possible (e.g. due to an
1417  /// internal transposition in the result).
1418  Value tryToFoldExtractOpInPlace(Value source);
1419 
1420  ExtractOp extractOp;
1421  int64_t vectorRank;
1422  int64_t extractedRank;
1423 
1424  InsertOp nextInsertOp;
1425  TransposeOp nextTransposeOp;
1426 
1427  /// Sentinel values that encode the internal permutation status of the result.
1428  /// They are set to (-1, ... , -k) at the beginning and appended to
1429  /// `extractPosition`.
1430  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1431  /// ensure that there is no internal transposition.
1432  /// Internal transposition cannot be accounted for with a folding pattern.
1433  // TODO: We could relax the internal transposition with an extra transposition
1434  // operation in a future canonicalizer.
1435  SmallVector<int64_t> sentinels;
1436  SmallVector<int64_t> extractPosition;
1437 };
1438 } // namespace
1439 
1440 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1441  ExtractOp e)
1442  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1443  extractedRank(extractOp.getNumIndices()) {
1444  assert(vectorRank >= extractedRank && "Extracted position overflow");
1445  sentinels.reserve(vectorRank - extractedRank);
1446  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1447  sentinels.push_back(-(i + 1));
1448  extractPosition.assign(extractOp.getStaticPosition().begin(),
1449  extractOp.getStaticPosition().end());
1450  llvm::append_range(extractPosition, sentinels);
1451 }
1452 
1453 // Case 1. If we hit a transpose, just compose the map and iterate.
1454 // Invariant: insert + transpose do not change rank, we can always compose.
1455 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1456  // TODO: Canonicalization for dynamic position not implemented yet.
1457  if (extractOp.hasDynamicPosition())
1458  return failure();
1459 
1460  if (!nextTransposeOp)
1461  return failure();
1462  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1463  nextTransposeOp.getPermutation(), extractOp.getContext()));
1464  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1465  return success();
1466 }
1467 
1468 // Case 2: the insert position matches extractPosition exactly, early return.
1469 LogicalResult
1470 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1471  Value &res) {
1472  // TODO: Canonicalization for dynamic position not implemented yet.
1473  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1474  return failure();
1475 
1476  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1477  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1478  return failure();
1479  // Case 2.a. early-exit fold.
1480  res = nextInsertOp.getSource();
1481  // Case 2.b. if internal transposition is present, canFold will be false.
1482  return success(canFold());
1483 }
1484 
1485 /// Case 3: if the inserted position is a prefix of extractPosition,
1486 /// extract a portion of the source of the insertion.
1487 /// This method updates the internal state.
1488 LogicalResult
1489 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1490  // TODO: Canonicalization for dynamic position not implemented yet.
1491  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1492  return failure();
1493 
1494  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1495  if (!isContainedWithin(insertedPos, extractPosition))
1496  return failure();
1497  // Set leading dims to zero.
1498  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1499  // Drop extra leading dims.
1500  extractPosition.erase(extractPosition.begin(),
1501  extractPosition.begin() + insertedPos.size());
1502  extractedRank = extractPosition.size() - sentinels.size();
1503  // Case 3.a. early-exit fold (break and delegate to post-while path).
1504  res = nextInsertOp.getSource();
1505  // Case 3.b. if internal transposition is present, canFold will be false.
1506  return success();
1507 }
1508 
1509 /// Try to fold in place to extract(source, extractPosition) and return the
1510 /// folded result. Return null if folding is not possible (e.g. due to an
1511 /// internal transposition in the result).
1512 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1513  Value source) {
1514  // TODO: Canonicalization for dynamic position not implemented yet.
1515  if (extractOp.hasDynamicPosition())
1516  return Value();
1517 
1518  // If we can't fold (either internal transposition, or nothing to fold), bail.
1519  bool nothingToFold = (source == extractOp.getVector());
1520  if (nothingToFold || !canFold())
1521  return Value();
1522 
1523  // Otherwise, fold by updating the op inplace and return its result.
1524  OpBuilder b(extractOp.getContext());
1525  extractOp.setStaticPosition(
1526  ArrayRef(extractPosition).take_front(extractedRank));
1527  extractOp.getVectorMutable().assign(source);
1528  return extractOp.getResult();
1529 }
1530 
1531 /// Iterate over producing insert and transpose ops until we find a fold.
1532 Value ExtractFromInsertTransposeChainState::fold() {
1533  // TODO: Canonicalization for dynamic position not implemented yet.
1534  if (extractOp.hasDynamicPosition())
1535  return Value();
1536 
1537  Value valueToExtractFrom = extractOp.getVector();
1538  updateStateForNextIteration(valueToExtractFrom);
1539  while (nextInsertOp || nextTransposeOp) {
1540  // Case 1. If we hit a transpose, just compose the map and iterate.
1541  // Invariant: insert + transpose do not change rank, we can always compose.
1542  if (succeeded(handleTransposeOp())) {
1543  valueToExtractFrom = nextTransposeOp.getVector();
1544  updateStateForNextIteration(valueToExtractFrom);
1545  continue;
1546  }
1547 
1548  Value result;
1549  // Case 2: the position matches exactly.
1550  if (succeeded(handleInsertOpWithMatchingPos(result)))
1551  return result;
1552 
1553  // Case 3: if the inserted position is a prefix of extractPosition, we can
1554  // just extract a portion of the source of the insert.
1555  if (succeeded(handleInsertOpWithPrefixPos(result)))
1556  return tryToFoldExtractOpInPlace(result);
1557 
1558  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1559  // values. This is a more difficult case and we bail.
1560  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1561  if (isContainedWithin(extractPosition, insertedPos) ||
1562  intersectsWhereNonNegative(extractPosition, insertedPos))
1563  return Value();
1564 
1565  // Case 5: No intersection, we forward the extract to insertOp.dest().
1566  valueToExtractFrom = nextInsertOp.getDest();
1567  updateStateForNextIteration(valueToExtractFrom);
1568  }
1569  // If after all this we can fold, go for it.
1570  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1571 }
1572 
1573 /// Returns true if the operation has a 0-D vector type operand or result.
1574 static bool hasZeroDimVectors(Operation *op) {
1575  auto hasZeroDimVectorType = [](Type type) -> bool {
1576  auto vecType = dyn_cast<VectorType>(type);
1577  return vecType && vecType.getRank() == 0;
1578  };
1579 
1580  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1581  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1582 }
1583 
1584 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
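///
/// For example (illustrative shapes):
///   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
///   %e = vector.extract %b[1, 2, 3] : f32 from vector<2x3x4xf32>
/// folds to
///   %e = vector.extract %src[3] : f32 from vector<4xf32>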
1585 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1586  // TODO: Canonicalization for dynamic position not implemented yet.
1587  if (extractOp.hasDynamicPosition())
1588  return Value();
1589 
1590  Operation *defOp = extractOp.getVector().getDefiningOp();
1591  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1592  return Value();
1593 
1594  // 0-D vectors not supported.
1595  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1596  if (hasZeroDimVectors(defOp))
1597  return Value();
1598 
1599  Value source = defOp->getOperand(0);
1600  if (extractOp.getType() == source.getType())
1601  return source;
1602  auto getRank = [](Type type) {
1603  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1604  : 0;
1605  };
1606 
1607  // If splat or broadcast from a scalar, just return the source scalar.
1608  unsigned broadcastSrcRank = getRank(source.getType());
1609  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1610  return source;
1611 
1612  unsigned extractResultRank = getRank(extractOp.getType());
1613  if (extractResultRank >= broadcastSrcRank)
1614  return Value();
1615  // Check that the dimensions of the result haven't been broadcasted.
1616  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1617  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1618  if (extractVecType && broadcastVecType &&
1619  extractVecType.getShape() !=
1620  broadcastVecType.getShape().take_back(extractResultRank))
1621  return Value();
1622 
1623  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1624  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1625 
1626  // Detect all the positions that come from "dim-1" broadcasting.
1627  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1628  // extract position to `0` when extracting from the source operand.
1629  llvm::SetVector<int64_t> broadcastedUnitDims =
1630  broadcastOp.computeBroadcastedUnitDims();
1631  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1632  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1633  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1634  if (broadcastedUnitDims.contains(i))
1635  extractPos[i] = 0;
1636  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1637  // matching extract position when extracting from the source operand.
1638  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1639  extractPos.erase(extractPos.begin(),
1640  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1641  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1642  OpBuilder b(extractOp.getContext());
1643  extractOp.setOperand(0, source);
1644  extractOp.setStaticPosition(extractPos);
1645  return extractOp.getResult();
1646 }
1647 
1648 // Fold extractOp with source coming from ShapeCast op.
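//
// For example (illustrative shapes), the position is linearized in the
// shape_cast result coordinates and delinearized in its source coordinates:
//   %sc = vector.shape_cast %src : vector<4x4xf32> to vector<2x8xf32>
//   %e  = vector.extract %sc[1, 3] : f32 from vector<2x8xf32>
// becomes
//   %e  = vector.extract %src[2, 3] : f32 from vector<4x4xf32>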
1649 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1650  // TODO: Canonicalization for dynamic position not implemented yet.
1651  if (extractOp.hasDynamicPosition())
1652  return Value();
1653 
1654  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1655  if (!shapeCastOp)
1656  return Value();
1657 
1658  // 0-D vectors not supported.
1659  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1660  if (hasZeroDimVectors(shapeCastOp))
1661  return Value();
1662 
1663  // Get the nth dimension size starting from lowest dimension.
1664  auto getDimReverse = [](VectorType type, int64_t n) {
1665  return type.getShape().take_back(n + 1).front();
1666  };
1667  int64_t destinationRank =
1668  llvm::isa<VectorType>(extractOp.getType())
1669  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1670  : 0;
1671  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1672  return Value();
1673  if (destinationRank > 0) {
1674  auto destinationType =
1675  llvm::cast<VectorType>(extractOp.getResult().getType());
1676  for (int64_t i = 0; i < destinationRank; i++) {
1677  // The lowest dimension of the destination must match the lowest
1678  // dimension of the shapecast op source.
1679  // TODO: This case could be supported in a canonicalization pattern.
1680  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1681  getDimReverse(destinationType, i))
1682  return Value();
1683  }
1684  }
1685  // Extract the strides associated with the extract op vector source. Then use
1686  // this to calculate a linearized position for the extract.
1687  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1688  std::reverse(extractedPos.begin(), extractedPos.end());
1689  SmallVector<int64_t, 4> strides;
1690  int64_t stride = 1;
1691  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1692  strides.push_back(stride);
1693  stride *=
1694  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1695  }
1696 
1697  int64_t position = linearize(extractedPos, strides);
1698  // Then extract the strides associated with the shapeCast op vector source and
1699  // delinearize the position using those strides.
1700  SmallVector<int64_t, 4> newStrides;
1701  int64_t numDimension =
1702  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1703  stride = 1;
1704  for (int64_t i = 0; i < numDimension; i++) {
1705  newStrides.push_back(stride);
1706  stride *=
1707  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1708  }
1709  std::reverse(newStrides.begin(), newStrides.end());
1710  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1711  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1712  OpBuilder b(extractOp.getContext());
1713  extractOp.setStaticPosition(newPosition);
1714  extractOp.setOperand(0, shapeCastOp.getSource());
1715  return extractOp.getResult();
1716 }
1717 
1718 /// Fold an ExtractOp from ExtractStridedSliceOp.
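///
/// For example (illustrative shapes), the slice offsets are simply added to
/// the extract position:
///   %s = vector.extract_strided_slice %src
///          {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
///          : vector<4x8xf32> to vector<2x2xf32>
///   %e = vector.extract %s[0, 1] : f32 from vector<2x2xf32>
/// becomes
///   %e = vector.extract %src[1, 3] : f32 from vector<4x8xf32>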
1719 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1720  // TODO: Canonicalization for dynamic position not implemented yet.
1721  if (extractOp.hasDynamicPosition())
1722  return Value();
1723 
1724  auto extractStridedSliceOp =
1725  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1726  if (!extractStridedSliceOp)
1727  return Value();
1728 
1729  // 0-D vectors not supported.
1730  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1731  if (hasZeroDimVectors(extractStridedSliceOp))
1732  return Value();
1733 
1734  // Return if 'extractStridedSliceOp' has non-unit strides.
1735  if (extractStridedSliceOp.hasNonUnitStrides())
1736  return Value();
1737 
1738  // Trim offsets for dimensions fully extracted.
1739  auto sliceOffsets =
1740  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1741  while (!sliceOffsets.empty()) {
1742  size_t lastOffset = sliceOffsets.size() - 1;
1743  if (sliceOffsets.back() != 0 ||
1744  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1745  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1746  break;
1747  sliceOffsets.pop_back();
1748  }
1749  unsigned destinationRank = 0;
1750  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1751  destinationRank = vecType.getRank();
1752  // The dimensions of the result need to be untouched by the
1753  // extractStridedSlice op.
1754  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1755  sliceOffsets.size())
1756  return Value();
1757 
1758  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1759  assert(extractedPos.size() >= sliceOffsets.size());
1760  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1761  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1762  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1763 
1764  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1765  OpBuilder b(extractOp.getContext());
1766  extractOp.setStaticPosition(extractedPos);
1767  return extractOp.getResult();
1768 }
1769 
1770 /// Fold extract_op fed from a chain of insertStridedSlice ops.
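///
/// For example (illustrative shapes), when the extracted chunk lies entirely
/// inside the inserted chunk:
///   %i = vector.insert_strided_slice %src, %dst
///          {offsets = [2, 0], strides = [1, 1]}
///          : vector<2x8xf32> into vector<4x8xf32>
///   %e = vector.extract %i[3] : vector<8xf32> from vector<4x8xf32>
/// becomes
///   %e = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>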
1771 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1772  // TODO: Canonicalization for dynamic position not implemented yet.
1773  if (extractOp.hasDynamicPosition())
1774  return Value();
1775 
1776  int64_t destinationRank =
1777  llvm::isa<VectorType>(extractOp.getType())
1778  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1779  : 0;
1780  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1781  if (!insertOp)
1782  return Value();
1783 
1784  // 0-D vectors not supported.
1785  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1786  if (hasZeroDimVectors(insertOp))
1787  return Value();
1788 
1789  while (insertOp) {
1790  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1791  insertOp.getSourceVectorType().getRank();
1792  if (destinationRank > insertOp.getSourceVectorType().getRank())
1793  return Value();
1794  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1795  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1796 
1797  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1798  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1799  }))
1800  return Value();
1801  bool disjoint = false;
1802  SmallVector<int64_t, 4> offsetDiffs;
1803  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1804  int64_t start = insertOffsets[dim];
1805  int64_t size =
1806  (dim < insertRankDiff)
1807  ? 1
1808  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1809  int64_t end = start + size;
1810  int64_t offset = extractOffsets[dim];
1811  // Check if the start of the extract offset is in the interval inserted.
1812  if (start <= offset && offset < end) {
1813  if (dim >= insertRankDiff)
1814  offsetDiffs.push_back(offset - start);
1815  continue;
1816  }
1817  disjoint = true;
1818  break;
1819  }
1820  // The extracted chunk overlaps with the inserted vector.
1821  if (!disjoint) {
1822  // If any of the inner dimensions are only partially inserted we have a
1823  // partial overlap.
1824  int64_t srcRankDiff =
1825  insertOp.getSourceVectorType().getRank() - destinationRank;
1826  for (int64_t i = 0; i < destinationRank; i++) {
1827  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1828  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1829  insertRankDiff))
1830  return Value();
1831  }
1832  extractOp.getVectorMutable().assign(insertOp.getSource());
1833  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1834  OpBuilder b(extractOp.getContext());
1835  extractOp.setStaticPosition(offsetDiffs);
1836  return extractOp.getResult();
1837  }
1838  // If the chunk extracted is disjoint from the chunk inserted, keep
1839  // looking in the insert chain.
1840  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1841  }
1842  return Value();
1843 }
1844 
1845 OpFoldResult ExtractOp::fold(FoldAdaptor) {
1846  if (getNumIndices() == 0)
1847  return getVector();
1848  if (succeeded(foldExtractOpFromExtractChain(*this)))
1849  return getResult();
1850  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1851  return res;
1852  if (auto res = foldExtractFromBroadcast(*this))
1853  return res;
1854  if (auto res = foldExtractFromShapeCast(*this))
1855  return res;
1856  if (auto val = foldExtractFromExtractStrided(*this))
1857  return val;
1858  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1859  return val;
1860  return OpFoldResult();
1861 }
1862 
1863 namespace {
1864 
1865 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1866 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1867 public:
1868  using OpRewritePattern::OpRewritePattern;
1869 
1870  LogicalResult matchAndRewrite(ExtractOp extractOp,
1871  PatternRewriter &rewriter) const override {
1872  Operation *defOp = extractOp.getVector().getDefiningOp();
1873  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1874  return failure();
1875 
1876  Value source = defOp->getOperand(0);
1877  if (extractOp.getType() == source.getType())
1878  return failure();
1879  auto getRank = [](Type type) {
1880  return llvm::isa<VectorType>(type)
1881  ? llvm::cast<VectorType>(type).getRank()
1882  : 0;
1883  };
1884  unsigned broadcastSrcRank = getRank(source.getType());
1885  unsigned extractResultRank = getRank(extractOp.getType());
1886  // We only consider the case where the rank of the source is less than or
1887  // equal to the rank of the extract dst. The other cases are handled in the
1888  // folding patterns.
1889  if (extractResultRank < broadcastSrcRank)
1890  return failure();
1891 
1892  // Special case if broadcast src is a 0D vector.
1893  if (extractResultRank == 0) {
1894  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
1895  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
1896  return success();
1897  }
1898  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1899  extractOp, extractOp.getType(), source);
1900  return success();
1901  }
1902 };
1903 
1904 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
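//
// For example (illustrative values):
//   %v = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %v[2] : vector<8xf32> from vector<4x8xf32>
// can be rewritten as
//   %e = arith.constant dense<1.0> : vector<8xf32>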
1905 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1906 public:
1907  using OpRewritePattern::OpRewritePattern;
1908 
1909  LogicalResult matchAndRewrite(ExtractOp extractOp,
1910  PatternRewriter &rewriter) const override {
1911  // Return if 'ExtractOp' operand is not defined by a splat vector
1912  // ConstantOp.
1913  Value sourceVector = extractOp.getVector();
1914  Attribute vectorCst;
1915  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1916  return failure();
1917  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1918  if (!splat)
1919  return failure();
1920  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
1921  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1922  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1923  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1924  return success();
1925  }
1926 };
1927 
1928 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
1929 class ExtractOpNonSplatConstantFolder final
1930  : public OpRewritePattern<ExtractOp> {
1931 public:
1932  using OpRewritePattern::OpRewritePattern;
1933 
1934  LogicalResult matchAndRewrite(ExtractOp extractOp,
1935  PatternRewriter &rewriter) const override {
1936  // TODO: Canonicalization for dynamic position not implemented yet.
1937  if (extractOp.hasDynamicPosition())
1938  return failure();
1939 
1940  // Return if 'ExtractOp' operand is not defined by a compatible vector
1941  // ConstantOp.
1942  Value sourceVector = extractOp.getVector();
1943  Attribute vectorCst;
1944  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1945  return failure();
1946 
1947  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
1948  if (vecTy.isScalable())
1949  return failure();
1950 
1951  // The splat case is handled by `ExtractOpSplatConstantFolder`.
1952  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
1953  if (!dense || dense.isSplat())
1954  return failure();
1955 
1956  // Calculate the linearized position of the contiguous chunk of elements to
1957  // extract.
1958  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
1959  copy(extractOp.getStaticPosition(), completePositions.begin());
1960  int64_t elemBeginPosition =
1961  linearize(completePositions, computeStrides(vecTy.getShape()));
1962  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
1963 
1964  TypedAttr newAttr;
1965  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
1966  SmallVector<Attribute> elementValues(
1967  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
1968  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
1969  } else {
1970  newAttr = *denseValuesBegin;
1971  }
1972 
1973  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1974  return success();
1975  }
1976 };
1977 
1978 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
1979 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
1980 public:
1981  using OpRewritePattern::OpRewritePattern;
1982 
1983  LogicalResult matchAndRewrite(ExtractOp extractOp,
1984  PatternRewriter &rewriter) const override {
1985  auto createMaskOp =
1986  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
1987  if (!createMaskOp)
1988  return failure();
1989 
1990  VectorType extractedMaskType =
1991  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
1992 
1993  if (!extractedMaskType)
1994  return failure();
1995 
1996  auto maskOperands = createMaskOp.getOperands();
1997  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
1998  VectorType maskType = createMaskOp.getVectorType();
1999 
2000  bool containsUnknownDims = false;
2001  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2002 
2003  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2004  dimIdx++) {
2005  int64_t pos = extractOpPos[dimIdx];
2006  Value operand = maskOperands[dimIdx];
2007  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2008  if (!constantOp) {
2009  // Bounds of this dim unknown.
2010  containsUnknownDims = true;
2011  continue;
2012  }
2013 
2014  int64_t createMaskBound =
2015  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2016 
2017  if (pos != ShapedType::kDynamic) {
2018  // If any position is outside the range from the `create_mask`, then the
2019  // extracted mask will be all-false.
2020  allFalse |= pos >= createMaskBound;
2021  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2022  // This dim is not all-true and since this is a dynamic index we don't
2023  // know if the extraction is within the true or false region.
2024  // Note: Zero dims have already been handled via getMaskFormat().
2025  containsUnknownDims = true;
2026  }
2027  }
2028 
2029  if (allFalse) {
2030  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2031  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2032  } else if (!containsUnknownDims) {
2033  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2034  extractOp, extractedMaskType,
2035  maskOperands.drop_front(extractOpPos.size()));
2036  } else {
2037  return failure();
2038  }
2039  return success();
2040  }
2041 };
2042 
2043 // Patterns to rewrite ExtractOp(ConstantMaskOp)
2044 //
2045 // When the result of ExtractOp is a subvector of input, we can rewrite it as
2046 // a ConstantMaskOp with subvector ranks.
2047 //
2048 // ExtractOp(ConstantMaskOp) -> ConstantMaskOp
2049 //
2050 // When the result of ExtractOp is a scalar, we can get the scalar value
2051 // directly.
2052 //
2053 // ExtractOp(ConstantMaskOp) -> ConstantOp
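//
// For example (illustrative shapes):
//   %m = vector.constant_mask [2, 3] : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// can be rewritten as
//   %e = vector.constant_mask [3] : vector<8xi1>
// while extracting a scalar element yields an arith.constant i1 instead.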
2054 class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
2055 public:
2056  using OpRewritePattern::OpRewritePattern;
2057 
2058  LogicalResult matchAndRewrite(ExtractOp extractOp,
2059  PatternRewriter &rewriter) const override {
2060  auto constantMaskOp =
2061  extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
2062  if (!constantMaskOp)
2063  return failure();
2064 
2065  // All indices must be static.
2066  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2067  unsigned dynamicPosCount =
2068  llvm::count_if(extractOpPos, ShapedType::isDynamic);
2069  // If there is any dynamic position in ExtractOp, we cannot determine the
2070  // scalar value.
2071  if (dynamicPosCount)
2072  return failure();
2073 
2074  ArrayRef<Attribute> maskDimSizes =
2075  constantMaskOp.getMaskDimSizes().getValue();
2076  Type resultTy = extractOp.getResult().getType();
2077  if (resultTy.isa<mlir::VectorType>()) {
2078  auto resultVectorTy = resultTy.cast<mlir::VectorType>();
2079  int64_t resultRank = resultVectorTy.getRank();
2080  int64_t n = maskDimSizes.size();
2081  std::vector<int64_t> indices;
2082  for (auto i = n - resultRank; i < n; ++i)
2083  indices.push_back(cast<IntegerAttr>(maskDimSizes[i]).getInt());
2084 
2085  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
2086  extractOp, resultVectorTy,
2087  vector::getVectorSubscriptAttr(rewriter, indices));
2088 
2089  return success();
2090  } else if (resultTy.isa<mlir::IntegerType>()) {
2091  // ConstantMaskOp creates and returns a vector mask where elements of the
2092  // result vector are set to '0' or '1', based on whether the element
2093  // indices are contained within a hyper-rectangular region.
2094  // We go through the ExtractOp static positions to determine whether the
2095  // position is within the hyper-rectangular region or not.
2096  Type boolType = rewriter.getI1Type();
2097  IntegerAttr setAttr = IntegerAttr::get(boolType, 1);
2098  for (size_t i = 0, end = extractOpPos.size(); i < end; ++i) {
2099  if (cast<IntegerAttr>(maskDimSizes[i]).getInt() <= extractOpPos[i]) {
2100  setAttr = IntegerAttr::get(boolType, 0);
2101  break;
2102  }
2103  }
2104 
2105  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, boolType,
2106  setAttr);
2107  return success();
2108  }
2109 
2110  return failure();
2111  }
2112 };
2113 
2114 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2115 // does not change.
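//
// For example (illustrative shapes), where the element counts match:
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<1x8xf32>
//   %e  = vector.extract %sc[0] : vector<8xf32> from vector<1x8xf32>
// can be rewritten as
//   %e  = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>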
2116 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2117  PatternRewriter &rewriter) {
2118  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2119  if (!castOp)
2120  return failure();
2121 
2122  VectorType sourceType = castOp.getSourceVectorType();
2123  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2124  if (!targetType)
2125  return failure();
2126 
2127  if (sourceType.getNumElements() != targetType.getNumElements())
2128  return failure();
2129 
2130  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2131  castOp.getSource());
2132  return success();
2133 }
2134 
2135 } // namespace
2136 
2137 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2138  MLIRContext *context) {
2139  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2140  ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2141  ExtractOpFromConstantMask>(context);
2142  results.add(foldExtractFromShapeCastToShapeCast);
2143 }
2144 
2145 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2146  SmallVectorImpl<int64_t> &results) {
2147  for (auto attr : arrayAttr)
2148  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2149 }
2150 
2151 //===----------------------------------------------------------------------===//
2152 // FMAOp
2153 //===----------------------------------------------------------------------===//
2154 
2155 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2156  return llvm::to_vector<4>(getVectorType().getShape());
2157 }
2158 
2159 //===----------------------------------------------------------------------===//
2160 // BroadcastOp
2161 //===----------------------------------------------------------------------===//
2162 
2163 /// Return the dimensions of the result vector that were formerly ones in the
2164 /// source vector and thus correspond to "dim-1" broadcasting.
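///
/// For example (illustrative shapes), broadcasting vector<1x4xf32> to
/// vector<2x3x4xf32> reports {1}: the source dim of size 1 is stretched to 3.
/// Broadcasting it to vector<2x1x4xf32> reports no dim-1 broadcast dims.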
2165 static llvm::SetVector<int64_t>
2166 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2167  ArrayRef<int64_t> dstShape) {
2168  int64_t rankDiff = dstShape.size() - srcShape.size();
2169  int64_t dstDim = rankDiff;
2170  llvm::SetVector<int64_t> res;
2171  for (auto [s1, s2] :
2172  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2173  if (s1 != s2) {
2174  assert(s1 == 1 && "expected dim-1 broadcasting");
2175  res.insert(dstDim);
2176  }
2177  ++dstDim;
2178  }
2179  return res;
2180 }
2181 
2182 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2183  // Scalar broadcast is without any unit dim broadcast.
2184  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2185  if (!srcVectorType)
2186  return {};
2187  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2188  getResultVectorType().getShape());
2189 }
2190 
2191 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2192 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2193 /// This requires (and asserts) that the broadcast is free of dim-1
2194 /// broadcasting.
2195 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2196 /// vector.transpose may be inserted to make the broadcast possible.
2197 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2198 /// the helper will assert. This means:
2199 /// 1. `dstShape` must not be empty.
2200 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2201 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2202 ///    must match the `value` shape.
2203 Value BroadcastOp::createOrFoldBroadcastOp(
2204  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2205  const llvm::SetVector<int64_t> &broadcastedDims) {
2206  assert(!dstShape.empty() && "unexpected empty dst shape");
2207 
2208  // Well-formedness check.
2209  SmallVector<int64_t> checkShape;
2210  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2211  if (broadcastedDims.contains(i))
2212  continue;
2213  checkShape.push_back(dstShape[i]);
2214  }
2215  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2216  "ill-formed broadcastedDims contains values not confined to "
2217  "destVectorShape");
2218 
2219  Location loc = value.getLoc();
2220  Type elementType = getElementTypeOrSelf(value.getType());
2221  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2222  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2223 
2224  // Step 2. If scalar -> dstShape broadcast, just do it.
2225  if (!srcVectorType) {
2226  assert(checkShape.empty() &&
2227  "ill-formed createOrFoldBroadcastOp arguments");
2228  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2229  }
2230 
2231  assert(srcVectorType.getShape().equals(checkShape) &&
2232  "ill-formed createOrFoldBroadcastOp arguments");
2233 
2234  // Step 3. Since vector.broadcast only allows creating leading dims,
2235  // vector -> dstShape broadcast may require a transpose.
2236  // Traverse the dims in order and construct:
2237  // 1. The leading entries of the broadcastShape that is guaranteed to be
2238  // achievable by a simple broadcast.
2239  // 2. The induced permutation for the subsequent vector.transpose that will
2240  // bring us from `broadcastShape` back to the desired `dstShape`.
2241  // If the induced permutation is not the identity, create a vector.transpose.
2242  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2243  broadcastShape.reserve(dstShape.size());
2244  // Consider the example:
2245  // srcShape = 2x4
2246  // dstShape = 1x2x3x4x5
2247  // broadcastedDims = [0, 2, 4]
2248  //
2249  // We want to build:
2250  // broadcastShape = 1x3x5x2x4
2251  // permutation = [0, 2, 4, 1, 3]
2252  // ---V--- -----V-----
2253  // leading broadcast part src shape part
2254  //
2255  // Note that the trailing dims of broadcastShape are exactly the srcShape
2256  // by construction.
2257  // nextSrcShapeDim is used to keep track of where in the permutation the
2258  // "src shape part" occurs.
2259  int64_t nextSrcShapeDim = broadcastedDims.size();
2260  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2261  if (broadcastedDims.contains(i)) {
2262  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2263  // bring it to the head of the broadcastShape.
2264  // It will need to be permuted back from `broadcastShape.size() - 1` into
2265  // position `i`.
2266  broadcastShape.push_back(dstShape[i]);
2267  permutation[i] = broadcastShape.size() - 1;
2268  } else {
2269  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2270  // shape and needs to be permuted into position `i`.
2271  // Don't touch `broadcastShape` here, the whole srcShape will be
2272  // appended after.
2273  permutation[i] = nextSrcShapeDim++;
2274  }
2275  }
2276  // 3.c. Append the srcShape.
2277  llvm::append_range(broadcastShape, srcVectorType.getShape());
2278 
2279  // Ensure there are no dim-1 broadcasts.
2280  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2281  .empty() &&
2282  "unexpected dim-1 broadcast");
2283 
2284  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2285  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2286  vector::BroadcastableToResult::Success &&
2287  "must be broadcastable");
2288  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2289  // Step 4. If we find any dimension that indeed needs to be permuted,
2290  // immediately return a new vector.transpose.
2291  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2292  if (permutation[i] != i)
2293  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2294  // Otherwise return res.
2295  return res;
2296 }
2297 
2298 BroadcastableToResult
2299 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
2300  std::pair<int, int> *mismatchingDims) {
2301  // Broadcast scalar to vector of the same element type.
2302  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2303  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2304  return BroadcastableToResult::Success;
2305  // From now on, only vectors broadcast.
2306  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2307  if (!srcVectorType)
2308  return BroadcastableToResult::SourceTypeNotAVector;
2309 
2310  int64_t srcRank = srcVectorType.getRank();
2311  int64_t dstRank = dstVectorType.getRank();
2312  if (srcRank > dstRank)
2313  return BroadcastableToResult::SourceRankHigher;
2314  // Source has an exact match or singleton value for all trailing dimensions
2315  // (all leading dimensions are simply duplicated).
2316  int64_t lead = dstRank - srcRank;
2317  for (int64_t r = 0; r < srcRank; ++r) {
2318  int64_t srcDim = srcVectorType.getDimSize(r);
2319  int64_t dstDim = dstVectorType.getDimSize(lead + r);
2320  if (srcDim != 1 && srcDim != dstDim) {
2321  if (mismatchingDims) {
2322  mismatchingDims->first = srcDim;
2323  mismatchingDims->second = dstDim;
2324  }
2325  return BroadcastableToResult::DimensionMismatch;
2326  }
2327  }
2328 
2329  return BroadcastableToResult::Success;
2330 }
2331 
2332 LogicalResult BroadcastOp::verify() {
2333  std::pair<int, int> mismatchingDims;
2334  BroadcastableToResult res = isBroadcastableTo(
2335  getSourceType(), getResultVectorType(), &mismatchingDims);
2336  if (res == BroadcastableToResult::Success)
2337  return success();
2338  if (res == BroadcastableToResult::SourceRankHigher)
2339  return emitOpError("source rank higher than destination rank");
2340  if (res == BroadcastableToResult::DimensionMismatch)
2341  return emitOpError("dimension mismatch (")
2342  << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
2343  if (res == BroadcastableToResult::SourceTypeNotAVector)
2344  return emitOpError("source type is not a vector");
2345  llvm_unreachable("unexpected vector.broadcast op error");
2346 }
2347 
2348 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2349  if (getSourceType() == getResultVectorType())
2350  return getSource();
2351  if (!adaptor.getSource())
2352  return {};
2353  auto vectorType = getResultVectorType();
2354  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2355  return DenseElementsAttr::get(vectorType, adaptor.getSource());
2356  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2357  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2358  return {};
2359 }
2360 
2361 namespace {
2362 
2363 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2364 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2365  using OpRewritePattern::OpRewritePattern;
2366 
2367  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2368  PatternRewriter &rewriter) const override {
2369  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2370  if (!srcBroadcast)
2371  return failure();
2372  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2373  broadcastOp.getResultVectorType(),
2374  srcBroadcast.getSource());
2375  return success();
2376  }
2377 };
2378 } // namespace
2379 
2380 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2381  MLIRContext *context) {
2382  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2383  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2384  results.add<BroadcastFolder>(context);
2385 }
2386 
2387 //===----------------------------------------------------------------------===//
2388 // ShuffleOp
2389 //===----------------------------------------------------------------------===//
2390 
2391 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
2392  Value v2, ArrayRef<int64_t> mask) {
2393  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
2394 }
2395 
2396 LogicalResult ShuffleOp::verify() {
2397  VectorType resultType = getResultVectorType();
2398  VectorType v1Type = getV1VectorType();
2399  VectorType v2Type = getV2VectorType();
2400  // Verify ranks.
2401  int64_t resRank = resultType.getRank();
2402  int64_t v1Rank = v1Type.getRank();
2403  int64_t v2Rank = v2Type.getRank();
2404  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2405  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2406  if (!wellFormed0DCase && !wellFormedNDCase)
2407  return emitOpError("rank mismatch");
2408 
2409  // Verify all but leading dimension sizes.
2410  for (int64_t r = 1; r < v1Rank; ++r) {
2411  int64_t resDim = resultType.getDimSize(r);
2412  int64_t v1Dim = v1Type.getDimSize(r);
2413  int64_t v2Dim = v2Type.getDimSize(r);
2414  if (resDim != v1Dim || v1Dim != v2Dim)
2415  return emitOpError("dimension mismatch");
2416  }
2417  // Verify mask length.
2418  auto maskAttr = getMask().getValue();
2419  int64_t maskLength = maskAttr.size();
2420  if (maskLength <= 0)
2421  return emitOpError("invalid mask length");
2422  if (maskLength != resultType.getDimSize(0))
2423  return emitOpError("mask length mismatch");
2424  // Verify all indices.
2425  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2426  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2427  for (const auto &en : llvm::enumerate(maskAttr)) {
2428  auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2429  if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2430  return emitOpError("mask index #") << (en.index() + 1) << " out of range";
2431  }
2432  return success();
2433 }
2434 
2435 LogicalResult
2436 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2437  ShuffleOp::Adaptor adaptor,
2438  SmallVectorImpl<Type> &inferredReturnTypes) {
2439  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2440  auto v1Rank = v1Type.getRank();
2441  // Construct resulting type: leading dimension matches mask
2442  // length, all trailing dimensions match the operands.
2443  SmallVector<int64_t, 4> shape;
2444  shape.reserve(v1Rank);
2445  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2446  // In the 0-D case there is no trailing shape to append.
2447  if (v1Rank > 0)
2448  llvm::append_range(shape, v1Type.getShape().drop_front());
2449  inferredReturnTypes.push_back(
2450  VectorType::get(shape, v1Type.getElementType()));
2451  return success();
2452 }
2453 
2454 static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2455  uint64_t expected = begin;
2456  return idxArr.size() == width &&
2457  llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2458  [&expected](auto attr) {
2459  return attr.getZExtValue() == expected++;
2460  });
2461 }
2462 
2463 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2464  VectorType v1Type = getV1VectorType();
2465  // For consistency: the 0-D shuffle's return type is 1-D, so this cannot be
2466  // a folding; it must be a canonicalization into a vector.broadcast.
2467  if (v1Type.getRank() == 0)
2468  return {};
2469 
2470  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2471  if (!v1Type.isScalable() &&
2472  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2473  return getV1();
2474  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2475  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2476  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2477  getV2VectorType().getDimSize(0)))
2478  return getV2();
2479 
2480  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2481  if (!lhs || !rhs)
2482  return {};
2483 
2484  auto lhsType =
2485  llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2486  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2487  // manipulation.
2488  if (lhsType.getRank() != 1)
2489  return {};
2490  int64_t lhsSize = lhsType.getDimSize(0);
2491 
2492  SmallVector<Attribute> results;
2493  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2494  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2495  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2496  int64_t i = index.getZExtValue();
2497  if (i >= lhsSize) {
2498  results.push_back(rhsElements[i - lhsSize]);
2499  } else {
2500  results.push_back(lhsElements[i]);
2501  }
2502  }
2503 
2504  return DenseElementsAttr::get(getResultVectorType(), results);
2505 }
2506 
2507 namespace {
2508 
2509 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2510 // to a broadcast.
2511 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2512  using OpRewritePattern::OpRewritePattern;
2513 
2514  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2515  PatternRewriter &rewriter) const override {
2516  VectorType v1VectorType = shuffleOp.getV1VectorType();
2517  ArrayAttr mask = shuffleOp.getMask();
2518  if (v1VectorType.getRank() > 0)
2519  return failure();
2520  if (mask.size() != 1)
2521  return failure();
2522  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2523  if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2524  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2525  shuffleOp.getV1());
2526  else
2527  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2528  shuffleOp.getV2());
2529  return success();
2530  }
2531 };
2532 
2533 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
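///
/// For example (illustrative shapes):
///   %a = vector.splat %x : vector<4xf32>
///   %b = vector.splat %x : vector<2xf32>
///   %s = vector.shuffle %a, %b [0, 4, 1, 5] : vector<4xf32>, vector<2xf32>
/// can be rewritten as
///   %s = vector.splat %x : vector<4xf32>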
2534 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2535 public:
2536  using OpRewritePattern::OpRewritePattern;
2537 
2538  LogicalResult matchAndRewrite(ShuffleOp op,
2539  PatternRewriter &rewriter) const override {
2540  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2541  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2542 
2543  if (!v1Splat || !v2Splat)
2544  return failure();
2545 
2546  if (v1Splat.getInput() != v2Splat.getInput())
2547  return failure();
2548 
2549  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2550  return success();
2551  }
2552 };
2553 
2554 } // namespace
2555 
2556 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2557  MLIRContext *context) {
2558  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2559 }
2560 
2561 //===----------------------------------------------------------------------===//
2562 // InsertElementOp
2563 //===----------------------------------------------------------------------===//
2564 
2565 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2566  Value source, Value dest) {
2567  build(builder, result, source, dest, {});
2568 }
2569 
2570 LogicalResult InsertElementOp::verify() {
2571  auto dstVectorType = getDestVectorType();
2572  if (dstVectorType.getRank() == 0) {
2573  if (getPosition())
2574  return emitOpError("expected position to be empty with 0-D vector");
2575  return success();
2576  }
2577  if (dstVectorType.getRank() != 1)
2578  return emitOpError("unexpected >1 vector rank");
2579  if (!getPosition())
2580  return emitOpError("expected position for 1-D vector");
2581  return success();
2582 }
2583 
2584 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2585  // Skip the 0-D vector here.
2586  if (!adaptor.getPosition())
2587  return {};
2588 
2589  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2590  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2591  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2592  if (!src || !dst || !pos)
2593  return {};
2594 
2595  if (src.getType() != getDestVectorType().getElementType())
2596  return {};
2597 
2598  auto dstElements = dst.getValues<Attribute>();
2599 
2600  SmallVector<Attribute> results(dstElements);
2601 
2602  uint64_t posIdx = pos.getInt();
2603  if (posIdx >= results.size())
2604  return {};
2605  results[posIdx] = src;
2606 
2607  return DenseElementsAttr::get(getDestVectorType(), results);
2608 }
2609 
2610 //===----------------------------------------------------------------------===//
2611 // InsertOp
2612 //===----------------------------------------------------------------------===//
2613 
2614 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2615  Value source, Value dest, int64_t position) {
2616  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2617 }
2618 
2619 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2620  Value source, Value dest, OpFoldResult position) {
2621  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2622 }
2623 
2624 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2625  Value source, Value dest,
2626  ArrayRef<int64_t> position) {
2627  SmallVector<OpFoldResult> posVals;
2628  posVals.reserve(position.size());
2629  llvm::transform(position, std::back_inserter(posVals),
2630  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2631  build(builder, result, source, dest, posVals);
2632 }
2633 
2634 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2635  Value source, Value dest,
2636  ArrayRef<OpFoldResult> position) {
2637  SmallVector<int64_t> staticPos;
2638  SmallVector<Value> dynamicPos;
2639  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2640  build(builder, result, source, dest, dynamicPos,
2641  builder.getDenseI64ArrayAttr(staticPos));
2642 }
2643 
2644 LogicalResult vector::InsertOp::verify() {
2645  SmallVector<OpFoldResult> position = getMixedPosition();
2646  auto destVectorType = getDestVectorType();
2647  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2648  return emitOpError(
2649  "expected position attribute of rank no greater than dest vector rank");
2650  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2651  if (srcVectorType &&
2652  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2653  static_cast<unsigned>(destVectorType.getRank())))
2654  return emitOpError("expected position attribute rank + source rank to "
2655  "match dest vector rank");
2656  if (!srcVectorType &&
2657  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2658  return emitOpError(
2659  "expected position attribute rank to match the dest vector rank");
2660  for (auto [idx, pos] : llvm::enumerate(position)) {
2661  if (auto attr = pos.dyn_cast<Attribute>()) {
2662  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2663  if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2664  return emitOpError("expected position attribute #")
2665  << (idx + 1)
2666  << " to be a non-negative integer smaller than the "
2667  "corresponding "
2668  "dest vector dimension";
2669  }
2670  }
2671  }
2672  return success();
2673 }
2674 
2675 namespace {
2676 
2677 // If insertOp is only inserting unit dimensions, it can be transformed into a
2678 // broadcast.
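//
// For example (illustrative shapes):
//   %i = vector.insert %src, %dst[0] : vector<1x4xf32> into vector<1x1x4xf32>
// can be rewritten as
//   %i = vector.broadcast %src : vector<1x4xf32> to vector<1x1x4xf32>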
2679 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2680 public:
2681  using OpRewritePattern::OpRewritePattern;
2682 
2683  LogicalResult matchAndRewrite(InsertOp insertOp,
2684  PatternRewriter &rewriter) const override {
2685  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2686  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2687  srcVecType.getNumElements())
2688  return failure();
2689  rewriter.replaceOpWithNewOp<BroadcastOp>(
2690  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2691  return success();
2692  }
2693 };
2694 
2695 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2696 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2697 public:
2698  using OpRewritePattern::OpRewritePattern;
2699 
2700  LogicalResult matchAndRewrite(InsertOp op,
2701  PatternRewriter &rewriter) const override {
2702  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2703  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2704 
2705  if (!srcSplat || !dstSplat)
2706  return failure();
2707 
2708  if (srcSplat.getInput() != dstSplat.getInput())
2709  return failure();
2710 
2711  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2712  return success();
2713  }
2714 };
2715 
2716 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
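//
// For example (illustrative values):
//   %v = arith.constant dense<0.0> : vector<4xf32>
//   %c = arith.constant 1.0 : f32
//   %i = vector.insert %c, %v[2] : f32 into vector<4xf32>
// can be rewritten as
//   %i = arith.constant dense<[0.0, 0.0, 1.0, 0.0]> : vector<4xf32>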
2717 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2718 public:
2719  using OpRewritePattern::OpRewritePattern;
2720 
2721  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2722  // unless the source vector constant has a single use.
2723  static constexpr int64_t vectorSizeFoldThreshold = 256;
2724 
2725  LogicalResult matchAndRewrite(InsertOp op,
2726  PatternRewriter &rewriter) const override {
2727  // TODO: Canonicalization for dynamic position not implemented yet.
2728  if (op.hasDynamicPosition())
2729  return failure();
2730 
2731  // Return if 'InsertOp' operand is not defined by a compatible vector
2732  // ConstantOp.
2733  TypedValue<VectorType> destVector = op.getDest();
2734  Attribute vectorDestCst;
2735  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2736  return failure();
2737 
2738  VectorType destTy = destVector.getType();
2739  if (destTy.isScalable())
2740  return failure();
2741 
2742  // Make sure we do not create too many large constants.
2743  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2744  !destVector.hasOneUse())
2745  return failure();
2746 
2747  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
2748 
2749  Value sourceValue = op.getSource();
2750  Attribute sourceCst;
2751  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2752  return failure();
2753 
2754  // Calculate the linearized position of the contiguous chunk of elements to
2755  // insert.
2756  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2757  copy(op.getStaticPosition(), completePositions.begin());
2758  int64_t insertBeginPosition =
2759  linearize(completePositions, computeStrides(destTy.getShape()));
2760 
2761  SmallVector<Attribute> insertedValues;
2762  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2763  llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
2764  else
2765  insertedValues.push_back(sourceCst);
2766 
2767  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2768  copy(insertedValues, allValues.begin() + insertBeginPosition);
2769  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2770 
2771  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2772  return success();
2773  }
2774 };
2775 
2776 } // namespace
2777 
2778 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2779  MLIRContext *context) {
2780  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2781  InsertOpConstantFolder>(context);
2782 }
2783 
2784 // Eliminates insert operations that produce values identical to their source
2785 // value. This happens when the source and destination vectors have identical
2786 // sizes.
2787 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2788  if (getNumIndices() == 0)
2789  return getSource();
2790  return {};
2791 }
2792 
2793 //===----------------------------------------------------------------------===//
2794 // InsertStridedSliceOp
2795 //===----------------------------------------------------------------------===//
2796 
2797 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2798  Value source, Value dest,
2799  ArrayRef<int64_t> offsets,
2800  ArrayRef<int64_t> strides) {
2801  result.addOperands({source, dest});
2802  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2803  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2804  result.addTypes(dest.getType());
2805  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
2806  offsetsAttr);
2807  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
2808  stridesAttr);
2809 }
2810 
2811 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
2812 template <typename OpType>
2813 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2814  ArrayAttr arrayAttr,
2815  ArrayRef<int64_t> shape,
2816  StringRef attrName) {
2817  if (arrayAttr.size() > shape.size())
2818  return op.emitOpError("expected ")
2819  << attrName << " attribute of rank no greater than vector rank";
2820  return success();
2821 }
2822 
2823 // Checks that all integers in `arrayAttr` are confined to an interval.
2824 // If `halfOpen` is true then the admissible interval is [min, max).
2825 // Otherwise, the admissible interval is [min, max].
2826 template <typename OpType>
2827 static LogicalResult
2828 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2829  int64_t max, StringRef attrName,
2830  bool halfOpen = true) {
2831  for (auto attr : arrayAttr) {
2832  auto val = llvm::cast<IntegerAttr>(attr).getInt();
2833  auto upper = max;
2834  if (!halfOpen)
2835  upper += 1;
2836  if (val < min || val >= upper)
2837  return op.emitOpError("expected ") << attrName << " to be confined to ["
2838  << min << ", " << upper << ")";
2839  }
2840  return success();
2841 }
2842 
2843 // Checks that all integers in `arrayAttr` are confined to an interval derived
2844 // from `shape`. If `halfOpen` is true then the admissible interval for
2845 // dimension i is [min, shape[i]). Otherwise, it is [min, shape[i]].
2846 template <typename OpType>
2847 static LogicalResult
2848 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2849  ArrayRef<int64_t> shape, StringRef attrName,
2850  bool halfOpen = true, int64_t min = 0) {
2851  for (auto [index, attrDimPair] :
2852  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
2853  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
2854  int64_t max = std::get<1>(attrDimPair);
2855  if (!halfOpen)
2856  max += 1;
2857  if (val < min || val >= max)
2858  return op.emitOpError("expected ")
2859  << attrName << " dimension " << index << " to be confined to ["
2860  << min << ", " << max << ")";
2861  }
2862  return success();
2863 }
2864 
2865 // Checks that, for all indices i = 0..shape.size()-1, the sum
2866 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
2867 // is confined to an interval derived from `shape`.
2868 // If `halfOpen` is true then the admissible interval is [min, shape[i]).
2869 // Otherwise, the admissible interval is [min, shape[i]].
2870 template <typename OpType>
2871 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2872  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2873  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2874  bool halfOpen = true, int64_t min = 1) {
2875  assert(arrayAttr1.size() <= shape.size());
2876  assert(arrayAttr2.size() <= shape.size());
2877  for (auto [index, it] :
2878  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
2879  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
2880  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
2881  int64_t max = std::get<2>(it);
2882  if (!halfOpen)
2883  max += 1;
2884  if (val1 + val2 < 0 || val1 + val2 >= max)
2885  return op.emitOpError("expected sum(")
2886  << attrName1 << ", " << attrName2 << ") dimension " << index
2887  << " to be confined to [" << min << ", " << max << ")";
2888  }
2889  return success();
2890 }
2891 
2892 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2893  MLIRContext *context) {
2894  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2895  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2896  });
2897  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2898 }
2899 
2900 LogicalResult InsertStridedSliceOp::verify() {
2901  auto sourceVectorType = getSourceVectorType();
2902  auto destVectorType = getDestVectorType();
2903  auto offsets = getOffsetsAttr();
2904  auto strides = getStridesAttr();
2905  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2906  return emitOpError(
2907  "expected offsets of same size as destination vector rank");
2908  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2909  return emitOpError("expected strides of same size as source vector rank");
2910  if (sourceVectorType.getRank() > destVectorType.getRank())
2911  return emitOpError(
2912  "expected source rank to be no greater than destination rank");
2913 
2914  auto sourceShape = sourceVectorType.getShape();
2915  auto destShape = destVectorType.getShape();
2916  SmallVector<int64_t, 4> sourceShapeAsDestShape(
2917  destShape.size() - sourceShape.size(), 0);
2918  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2919  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2920  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2921  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2922  offName)) ||
2923  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
2924  /*max=*/1, stridesName,
2925  /*halfOpen=*/false)) ||
2926  failed(isSumOfIntegerArrayAttrConfinedToShape(
2927  *this, offsets,
2928  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2929  offName, "source vector shape",
2930  /*halfOpen=*/false, /*min=*/1)))
2931  return failure();
2932 
2933  unsigned rankDiff = destShape.size() - sourceShape.size();
2934  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
2935  if (sourceVectorType.getScalableDims()[idx] !=
2936  destVectorType.getScalableDims()[idx + rankDiff]) {
2937  return emitOpError("mismatching scalable flags (at source vector idx=")
2938  << idx << ")";
2939  }
2940  if (sourceVectorType.getScalableDims()[idx]) {
2941  auto sourceSize = sourceShape[idx];
2942  auto destSize = destShape[idx + rankDiff];
2943  if (sourceSize != destSize) {
2944  return emitOpError("expected size at idx=")
2945  << idx
2946  << (" to match the corresponding base size from the input "
2947  "vector (")
2948  << sourceSize << (" vs ") << destSize << (")");
2949  }
2950  }
2951  }
2952 
2953  return success();
2954 }
2955 
2956 namespace {
2957 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2958 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
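/// A hypothetical example (illustrative IR, not taken from this file):
///   %s = vector.splat %x : vector<2x4xf32>
///   %d = vector.splat %x : vector<4x8xf32>
///   %r = vector.insert_strided_slice %s, %d
///        {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x8xf32>
/// folds to %d, since both operands splat the same value %x.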
2959 class FoldInsertStridedSliceSplat final
2960  : public OpRewritePattern<InsertStridedSliceOp> {
2961 public:
2962  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2963 
2964  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2965  PatternRewriter &rewriter) const override {
2966  auto srcSplatOp =
2967  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2968  auto destSplatOp =
2969  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2970 
2971  if (!srcSplatOp || !destSplatOp)
2972  return failure();
2973 
2974  if (srcSplatOp.getInput() != destSplatOp.getInput())
2975  return failure();
2976 
2977  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2978  return success();
2979  }
2980 };
2981 
2982 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2983 /// to dst.
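/// A hypothetical example (illustrative IR): re-inserting a slice at the exact
/// offsets/strides it was extracted from leaves %dst unchanged:
///   %e = vector.extract_strided_slice %dst
///        {offsets = [2], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %e, %dst
///        {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// folds to %dst.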
2984 class FoldInsertStridedSliceOfExtract final
2985  : public OpRewritePattern<InsertStridedSliceOp> {
2986 public:
2987  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2988 
2989  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2990  PatternRewriter &rewriter) const override {
2991  auto extractStridedSliceOp =
2992  insertStridedSliceOp.getSource()
2993  .getDefiningOp<vector::ExtractStridedSliceOp>();
2994 
2995  if (!extractStridedSliceOp)
2996  return failure();
2997 
2998  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2999  return failure();
3000 
3001  // Check if they have the same strides and offsets.
3002  if (extractStridedSliceOp.getStrides() !=
3003  insertStridedSliceOp.getStrides() ||
3004  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3005  return failure();
3006 
3007  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3008  return success();
3009  }
3010 };
3011 
3012 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3013 // ConstantOp.
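// A hypothetical example (illustrative IR):
//   %src = arith.constant dense<1> : vector<2xi32>
//   %dst = arith.constant dense<0> : vector<4xi32>
//   %r = vector.insert_strided_slice %src, %dst
//        {offsets = [1], strides = [1]} : vector<2xi32> into vector<4xi32>
// folds to arith.constant dense<[0, 1, 1, 0]> : vector<4xi32>.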
3014 class InsertStridedSliceConstantFolder final
3015  : public OpRewritePattern<InsertStridedSliceOp> {
3016 public:
3017  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
3018 
3019  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3020  // unless the destination vector constant has a single use.
3021  static constexpr int64_t vectorSizeFoldThreshold = 256;
3022 
3023  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3024  PatternRewriter &rewriter) const override {
3025  // Return if the 'InsertStridedSliceOp' destination is not defined by a
3026  // compatible vector ConstantOp.
3027  TypedValue<VectorType> destVector = op.getDest();
3028  Attribute vectorDestCst;
3029  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3030  return failure();
3031 
3032  VectorType destTy = destVector.getType();
3033  if (destTy.isScalable())
3034  return failure();
3035 
3036  // Make sure we do not create too many large constants.
3037  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3038  !destVector.hasOneUse())
3039  return failure();
3040 
3041  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3042 
3043  TypedValue<VectorType> sourceValue = op.getSource();
3044  Attribute sourceCst;
3045  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3046  return failure();
3047 
3048  // TODO: Handle non-unit strides when they become available.
3049  if (op.hasNonUnitStrides())
3050  return failure();
3051 
3052  VectorType sliceVecTy = sourceValue.getType();
3053  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3054  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3055  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3056  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3057 
3058  // Calculate the destination element indices by enumerating all slice
3059  // positions within the destination and linearizing them. The enumeration
3060  // order is lexicographic, which yields a sequence of monotonically
3061  // increasing linearized position indices.
3062  // Because the destination may have higher dimensionality than the slice,
3063  // we keep track of two overlapping sets of positions and offsets.
3064  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3065  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3066  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3067  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3068  MutableArrayRef<int64_t> currSlicePosition(
3069  currDestPosition.begin() + rankDifference, currDestPosition.end());
3070  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3071  offsets.end());
3072  do {
3073  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3074  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3075  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3076  "Invalid slice element");
3077  newValues[linearizedPosition] = *sliceValuesIt;
3078  ++sliceValuesIt;
3079  } while (succeeded(
3080  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3081 
3082  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3083  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3084  return success();
3085  }
3086 };
3087 
3088 } // namespace
3089 
3090 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3091  RewritePatternSet &results, MLIRContext *context) {
3092  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3093  InsertStridedSliceConstantFolder>(context);
3094 }
3095 
3096 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3097  if (getSourceVectorType() == getDestVectorType())
3098  return getSource();
3099  return {};
3100 }
3101 
3102 //===----------------------------------------------------------------------===//
3103 // OuterProductOp
3104 //===----------------------------------------------------------------------===//
3105 
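// For reference (hypothetical IR and operand names), the two accepted forms of
// the op are a proper outer product and an AXPY with a scalar right-hand side:
//   %o = vector.outerproduct %a, %b, %acc : vector<4xf32>, vector<8xf32>
//   %x = vector.outerproduct %a, %s, %acc : vector<4xf32>, f32
// producing vector<4x8xf32> and vector<4xf32>, respectively.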
3106 /// Build an op without a mask, using the type of `acc` as the return type.
3107 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3108  Value lhs, Value rhs, Value acc) {
3109  result.addOperands({lhs, rhs, acc});
3110  result.addTypes(acc.getType());
3111 }
3112 
3113 void OuterProductOp::print(OpAsmPrinter &p) {
3114  p << " " << getLhs() << ", " << getRhs();
3115  if (getAcc()) {
3116  p << ", " << getAcc();
3117  p.printOptionalAttrDict((*this)->getAttrs());
3118  }
3119  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3120 }
3121 
3122 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3123  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3124  Type tLHS, tRHS;
3125  if (parser.parseOperandList(operandsInfo) ||
3126  parser.parseOptionalAttrDict(result.attributes) ||
3127  parser.parseColonType(tLHS) || parser.parseComma() ||
3128  parser.parseType(tRHS))
3129  return failure();
3130  if (operandsInfo.size() < 2)
3131  return parser.emitError(parser.getNameLoc(),
3132  "expected at least 2 operands");
3133  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3134  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3135  if (!vLHS)
3136  return parser.emitError(parser.getNameLoc(),
3137  "expected vector type for operand #1");
3138 
3139  VectorType resType;
3140  if (vRHS) {
3141  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3142  vRHS.getScalableDims()[0]};
3143  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3144  vLHS.getElementType(), scalableDimsRes);
3145  } else {
3146  // Scalar RHS operand
3147  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3148  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3149  scalableDimsRes);
3150  }
3151 
3152  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3153  result.attributes.append(
3154  OuterProductOp::getKindAttrName(result.name),
3155  CombiningKindAttr::get(result.getContext(),
3156  OuterProductOp::getDefaultKind()));
3157  }
3158 
3159  return failure(
3160  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3161  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3162  (operandsInfo.size() > 2 &&
3163  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3164  parser.addTypeToList(resType, result.types));
3165 }
3166 
3167 LogicalResult OuterProductOp::verify() {
3168  Type tRHS = getOperandTypeRHS();
3169  VectorType vLHS = getOperandVectorTypeLHS(),
3170  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3171  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3172 
3173  if (vLHS.getRank() != 1)
3174  return emitOpError("expected 1-d vector for operand #1");
3175 
3176  if (vRHS) {
3177  // Proper OUTER operation.
3178  if (vRHS.getRank() != 1)
3179  return emitOpError("expected 1-d vector for operand #2");
3180  if (vRES.getRank() != 2)
3181  return emitOpError("expected 2-d vector result");
3182  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3183  return emitOpError("expected #1 operand dim to match result dim #1");
3184  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3185  return emitOpError("expected #2 operand dim to match result dim #2");
3186  if (vLHS.isScalable() && !vRHS.isScalable()) {
3187  // This restriction reflects what's currently supported in terms of
3188  // scalable vectors. However, we could relax this if there's a use case.
3189  return emitOpError(
3190  "expected either both or only #2 operand dim to be scalable");
3191  }
3192  } else {
3193  // An AXPY operation.
3194  if (vRES.getRank() != 1)
3195  return emitOpError("expected 1-d vector result");
3196  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3197  return emitOpError("expected #1 operand dim to match result dim #1");
3198  }
3199 
3200  if (vACC && vACC != vRES)
3201  return emitOpError("expected operand #3 of same type as result type");
3202 
3203  // Verify supported combining kind.
3204  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3205  return emitOpError("unsupported outerproduct type");
3206 
3207  return success();
3208 }
3209 
3210 // MaskableOpInterface methods.
3211 
3212 /// Returns the mask type expected by this operation. Mostly used for
3213 /// verification purposes. It requires the operation to be vectorized.
3214 Type OuterProductOp::getExpectedMaskType() {
3215  auto vecType = this->getResultVectorType();
3216  return VectorType::get(vecType.getShape(),
3217  IntegerType::get(vecType.getContext(), /*width=*/1),
3218  vecType.getScalableDims());
3219 }
3220 
3221 //===----------------------------------------------------------------------===//
3222 // ReshapeOp
3223 //===----------------------------------------------------------------------===//
3224 
3225 LogicalResult ReshapeOp::verify() {
3226  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
3227  auto inputVectorType = getInputVectorType();
3228  auto outputVectorType = getOutputVectorType();
3229  int64_t inputShapeRank = getNumInputShapeSizes();
3230  int64_t outputShapeRank = getNumOutputShapeSizes();
3231  SmallVector<int64_t, 4> fixedVectorSizes;
3232  getFixedVectorSizes(fixedVectorSizes);
3233  int64_t numFixedVectorSizes = fixedVectorSizes.size();
3234 
3235  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
3236  return emitError("invalid input shape for vector type ") << inputVectorType;
3237 
3238  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
3239  return emitError("invalid output shape for vector type ")
3240  << outputVectorType;
3241 
3242  // Verify that the 'fixedVectorSizes' match an input/output vector shape
3243  // suffix.
3244  unsigned inputVectorRank = inputVectorType.getRank();
3245  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3246  unsigned index = inputVectorRank - numFixedVectorSizes - i;
3247  if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
3248  return emitError("fixed vector size must match input vector for dim ")
3249  << i;
3250  }
3251 
3252  unsigned outputVectorRank = outputVectorType.getRank();
3253  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3254  unsigned index = outputVectorRank - numFixedVectorSizes - i;
3255  if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
3256  return emitError("fixed vector size must match output vector for dim ")
3257  << i;
3258  }
3259 
3260  // If all shape operands are produced by constant ops, verify that product
3261  // of dimensions for input/output shape match.
3262  auto isDefByConstant = [](Value operand) {
3263  return getConstantIntValue(operand).has_value();
3264  };
3265  if (llvm::all_of(getInputShape(), isDefByConstant) &&
3266  llvm::all_of(getOutputShape(), isDefByConstant)) {
3267  int64_t numInputElements = 1;
3268  for (auto operand : getInputShape())
3269  numInputElements *= getConstantIntValue(operand).value();
3270  int64_t numOutputElements = 1;
3271  for (auto operand : getOutputShape())
3272  numOutputElements *= getConstantIntValue(operand).value();
3273  if (numInputElements != numOutputElements)
3274  return emitError("product of input and output shape sizes must match");
3275  }
3276  return success();
3277 }
3278 
3279 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
3280  populateFromInt64AttrArray(getFixedVectorSizes(), results);
3281 }
3282 
3283 //===----------------------------------------------------------------------===//
3284 // ExtractStridedSliceOp
3285 //===----------------------------------------------------------------------===//
3286 
3287 // Inference works as follows:
3288 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3289 // 2. Add sizes from 'vectorType' for remaining dims.
3290 // Scalable flags are inherited from 'vectorType'.
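// E.g. (hypothetical shapes): slicing vector<4x8x16xf32> with rank-1
// offsets/sizes/strides where sizes = [2] infers the result type
// vector<2x8x16xf32> -- the first dim comes from 'sizes', the rest from the
// source vector type.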
3291 static Type inferStridedSliceOpResultType(VectorType vectorType,
3292  ArrayAttr offsets, ArrayAttr sizes,
3293  ArrayAttr strides) {
3294  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3295  SmallVector<int64_t, 4> shape;
3296  shape.reserve(vectorType.getRank());
3297  unsigned idx = 0;
3298  for (unsigned e = offsets.size(); idx < e; ++idx)
3299  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3300  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3301  shape.push_back(vectorType.getShape()[idx]);
3302 
3303  return VectorType::get(shape, vectorType.getElementType(),
3304  vectorType.getScalableDims());
3305 }
3306 
3307 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3308  Value source, ArrayRef<int64_t> offsets,
3309  ArrayRef<int64_t> sizes,
3310  ArrayRef<int64_t> strides) {
3311  result.addOperands(source);
3312  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3313  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3314  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3315  result.addTypes(
3316  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3317  offsetsAttr, sizesAttr, stridesAttr));
3318  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3319  offsetsAttr);
3320  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3321  sizesAttr);
3322  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3323  stridesAttr);
3324 }
3325 
3326 LogicalResult ExtractStridedSliceOp::verify() {
3327  auto type = getSourceVectorType();
3328  auto offsets = getOffsetsAttr();
3329  auto sizes = getSizesAttr();
3330  auto strides = getStridesAttr();
3331  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3332  return emitOpError(
3333  "expected offsets, sizes and strides attributes of same size");
3334 
3335  auto shape = type.getShape();
3336  auto offName = getOffsetsAttrName();
3337  auto sizesName = getSizesAttrName();
3338  auto stridesName = getStridesAttrName();
3339  if (failed(
3340  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3341  failed(
3342  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3343  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3344  stridesName)) ||
3345  failed(
3346  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3347  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3348  /*halfOpen=*/false,
3349  /*min=*/1)) ||
3350  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3351  /*max=*/1, stridesName,
3352  /*halfOpen=*/false)) ||
3353  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3354  shape, offName, sizesName,
3355  /*halfOpen=*/false)))
3356  return failure();
3357 
3358  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3359  offsets, sizes, strides);
3360  if (getResult().getType() != resultType)
3361  return emitOpError("expected result type to be ") << resultType;
3362 
3363  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3364  if (type.getScalableDims()[idx]) {
3365  auto inputDim = type.getShape()[idx];
3366  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3367  if (inputDim != inputSize)
3368  return emitOpError("expected size at idx=")
3369  << idx
3370  << (" to match the corresponding base size from the input "
3371  "vector (")
3372  << inputSize << (" vs ") << inputDim << (")");
3373  }
3374  }
3375 
3376  return success();
3377 }
3378 
3379 // When the source of an ExtractStridedSliceOp comes from a chain of
3380 // InsertStridedSliceOps, try to use the source of those inserts if we can
3381 // detect that the extracted vector is a subset of one of the inserted vectors.
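// A hypothetical example (illustrative IR):
//   %i = vector.insert_strided_slice %v, %dst
//        {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i
//        {offsets = [3], sizes = [2], strides = [1]} : vector<8xf32> to vector<2xf32>
// The extracted range [3, 5) lies inside the inserted range [2, 6), so the
// extract can instead read from %v with offsets = [1].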
3382 static LogicalResult
3383 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3384  // Helper to extract integer out of ArrayAttr.
3385  auto getElement = [](ArrayAttr array, int idx) {
3386  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3387  };
3388  ArrayAttr extractOffsets = op.getOffsets();
3389  ArrayAttr extractStrides = op.getStrides();
3390  ArrayAttr extractSizes = op.getSizes();
3391  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3392  while (insertOp) {
3393  if (op.getSourceVectorType().getRank() !=
3394  insertOp.getSourceVectorType().getRank())
3395  return failure();
3396  ArrayAttr insertOffsets = insertOp.getOffsets();
3397  ArrayAttr insertStrides = insertOp.getStrides();
3398  // If the rank of extract is greater than the rank of insert, we are likely
3399  // extracting a partial chunk of the vector inserted.
3400  if (extractOffsets.size() > insertOffsets.size())
3401  return failure();
3402  bool partialOverlap = false;
3403  bool disjoint = false;
3404  SmallVector<int64_t, 4> offsetDiffs;
3405  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3406  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3407  return failure();
3408  int64_t start = getElement(insertOffsets, dim);
3409  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3410  int64_t offset = getElement(extractOffsets, dim);
3411  int64_t size = getElement(extractSizes, dim);
3412  // Check if the start of the extract offset is in the interval inserted.
3413  if (start <= offset && offset < end) {
3414  // If the extract interval overlaps but is not fully included we may
3415  // have a partial overlap that will prevent any folding.
3416  if (offset + size > end)
3417  partialOverlap = true;
3418  offsetDiffs.push_back(offset - start);
3419  continue;
3420  }
3421  disjoint = true;
3422  break;
3423  }
3424  // The extract element chunk is a subset of the insert element.
3425  if (!disjoint && !partialOverlap) {
3426  op.setOperand(insertOp.getSource());
3427  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3428  OpBuilder b(op.getContext());
3429  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3430  return success();
3431  }
3432  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3433  // in the insert chain.
3434  if (disjoint)
3435  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3436  else {
3437  // The extracted vector partially overlaps the inserted vector; we cannot
3438  // fold.
3439  return failure();
3440  }
3441  }
3442  return failure();
3443 }
3444 
3445 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3446  if (getSourceVectorType() == getResult().getType())
3447  return getVector();
3448  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3449  return getResult();
3450  return {};
3451 }
3452 
3453 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3454  populateFromInt64AttrArray(getOffsets(), results);
3455 }
3456 
3457 namespace {
3458 
3459 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3460 // ConstantMaskOp.
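// A hypothetical example (illustrative IR):
//   %m = vector.constant_mask [3] : vector<8xi1>
//   %e = vector.extract_strided_slice %m
//        {offsets = [2], sizes = [4], strides = [1]} : vector<8xi1> to vector<4xi1>
// folds to vector.constant_mask [1] : vector<4xi1>, i.e. only the set bits
// that fall inside the sliced range survive.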
3461 class StridedSliceConstantMaskFolder final
3462  : public OpRewritePattern<ExtractStridedSliceOp> {
3463 public:
3464  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3465 
3466  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3467  PatternRewriter &rewriter) const override {
3468  // Return if 'extractStridedSliceOp' operand is not defined by a
3469  // ConstantMaskOp.
3470  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3471  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3472  if (!constantMaskOp)
3473  return failure();
3474  // Return if 'extractStridedSliceOp' has non-unit strides.
3475  if (extractStridedSliceOp.hasNonUnitStrides())
3476  return failure();
3477  // Gather constant mask dimension sizes.
3478  SmallVector<int64_t, 4> maskDimSizes;
3479  populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
3480  // Gather strided slice offsets and sizes.
3481  SmallVector<int64_t, 4> sliceOffsets;
3482  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3483  sliceOffsets);
3484  SmallVector<int64_t, 4> sliceSizes;
3485  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3486 
3487  // Compute slice of vector mask region.
3488  SmallVector<int64_t, 4> sliceMaskDimSizes;
3489  sliceMaskDimSizes.reserve(maskDimSizes.size());
3490  for (auto [maskDimSize, sliceOffset, sliceSize] :
3491  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3492  int64_t sliceMaskDimSize = std::max(
3493  static_cast<int64_t>(0),
3494  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3495  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3496  }
3497  // Add unchanged dimensions.
3498  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3499  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3500  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3501  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3502  // region is a conjunction of mask dim intervals).
3503  if (llvm::is_contained(sliceMaskDimSizes, 0))
3504  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3505 
3506  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3507  // region.
3508  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3509  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3510  vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
3511  return success();
3512  }
3513 };
3514 
3515 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3516 class StridedSliceSplatConstantFolder final
3517  : public OpRewritePattern<ExtractStridedSliceOp> {
3518 public:
3519  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3520 
3521  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3522  PatternRewriter &rewriter) const override {
3523  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3524  // ConstantOp.
3525  Value sourceVector = extractStridedSliceOp.getVector();
3526  Attribute vectorCst;
3527  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3528  return failure();
3529 
3530  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3531  if (!splat)
3532  return failure();
3533 
3534  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3535  splat.getSplatValue<Attribute>());
3536  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3537  newAttr);
3538  return success();
3539  }
3540 };
3541 
3542 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3543 // ConstantOp.
3544 class StridedSliceNonSplatConstantFolder final
3545  : public OpRewritePattern<ExtractStridedSliceOp> {
3546 public:
3547  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3548 
3549  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3550  PatternRewriter &rewriter) const override {
3551  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3552  // ConstantOp.
3553  Value sourceVector = extractStridedSliceOp.getVector();
3554  Attribute vectorCst;
3555  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3556  return failure();
3557 
3558  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3559  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3560  if (!dense || dense.isSplat())
3561  return failure();
3562 
3563  // TODO: Handle non-unit strides when they become available.
3564  if (extractStridedSliceOp.hasNonUnitStrides())
3565  return failure();
3566 
3567  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3568  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3569  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3570 
3571  VectorType sliceVecTy = extractStridedSliceOp.getType();
3572  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3573  int64_t sliceRank = sliceVecTy.getRank();
3574 
3575  // Expand offsets and sizes to match the vector rank.
3576  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3577  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3578 
3579  SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
3580  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3581 
3582  // Calculate the slice elements by enumerating all slice positions and
3583  // linearizing them. The enumeration order is lexicographic which yields a
3584  // sequence of monotonically increasing linearized position indices.
3585  auto denseValuesBegin = dense.value_begin<Attribute>();
3586  SmallVector<Attribute> sliceValues;
3587  sliceValues.reserve(sliceVecTy.getNumElements());
3588  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3589  do {
3590  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3591  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3592  "Invalid index");
3593  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3594  } while (
3595  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3596 
3597  assert(static_cast<int64_t>(sliceValues.size()) ==
3598  sliceVecTy.getNumElements() &&
3599  "Invalid number of slice elements");
3600  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3601  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3602  newAttr);
3603  return success();
3604  }
3605 };
3606 
3607 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3608 // BroadcastOp(ExtractStridedSliceOp).
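// A hypothetical example (illustrative IR): when only the broadcast dimensions
// are sliced and the original (inner) dimensions are kept whole,
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %e = vector.extract_strided_slice %b
//        {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
//        : vector<4x8xf32> to vector<2x8xf32>
// can be rewritten as vector.broadcast %src : vector<8xf32> to vector<2x8xf32>.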
3609 class StridedSliceBroadcast final
3610  : public OpRewritePattern<ExtractStridedSliceOp> {
3611 public:
3612  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3613 
3614  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3615  PatternRewriter &rewriter) const override {
3616  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3617  if (!broadcast)
3618  return failure();
3619  auto srcVecType =
3620  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3621  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3622  auto dstVecType = llvm::cast<VectorType>(op.getType());
3623  unsigned dstRank = dstVecType.getRank();
3624  unsigned rankDiff = dstRank - srcRank;
3625  // Check if the innermost dimensions of the source of the broadcast are the
3626  // same as the destination of the extract. If so, we can just use a
3627  // broadcast, as the original dimensions are untouched.
3628  bool lowerDimMatch = true;
3629  for (unsigned i = 0; i < srcRank; i++) {
3630  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3631  lowerDimMatch = false;
3632  break;
3633  }
3634  }
3635  Value source = broadcast.getSource();
3636  // If the inner dimensions don't match, it means we need to extract from the
3637  // source of the original broadcast and then broadcast the extracted value.
3638  // We also need to handle degenerate cases where the source is effectively
3639  // just a single scalar.
3640  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3641  if (!lowerDimMatch && !isScalarSrc) {
3642  source = rewriter.create<ExtractStridedSliceOp>(
3643  op->getLoc(), source,
3644  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3645  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3646  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3647  }
3648  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3649  return success();
3650  }
3651 };
3652 
3653 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
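/// E.g. (hypothetical IR): any strided slice of
///   %s = vector.splat %x : vector<8xf32>
/// is rewritten to vector.splat %x with the slice's (smaller) result type.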
3654 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3655 public:
3656  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3657 
3658  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3659  PatternRewriter &rewriter) const override {
3660  auto splat = op.getVector().getDefiningOp<SplatOp>();
3661  if (!splat)
3662  return failure();
3663  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3664  return success();
3665  }
3666 };
3667 
3668 } // namespace
3669 
3670 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3671  RewritePatternSet &results, MLIRContext *context) {
3672  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3673  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3674  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3675  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3676  StridedSliceSplat>(context);
3677 }
3678 
3679 //===----------------------------------------------------------------------===//
3680 // TransferReadOp
3681 //===----------------------------------------------------------------------===//
3682 
3683 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3684 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3685  VectorType vectorType, Value source,
3686  ValueRange indices, AffineMapAttr permutationMapAttr,
3687  /*optional*/ ArrayAttr inBoundsAttr) {
3688  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3689  Value padding = builder.create<arith::ConstantOp>(
3690  result.location, elemType, builder.getZeroAttr(elemType));
3691  build(builder, result, vectorType, source, indices, permutationMapAttr,
3692  padding, /*mask=*/Value(), inBoundsAttr);
3693 }
3694 
3695 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3696 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3697  VectorType vectorType, Value source,
3698  ValueRange indices, AffineMap permutationMap,
3699  std::optional<ArrayRef<bool>> inBounds) {
3700  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3701  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3702  ? builder.getBoolArrayAttr(inBounds.value())
3703  : ArrayAttr();
3704  build(builder, result, vectorType, source, indices, permutationMapAttr,
3705  inBoundsAttr);
3706 }
3707 
3708 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3709 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3710  VectorType vectorType, Value source,
3711  ValueRange indices, Value padding,
3712  std::optional<ArrayRef<bool>> inBounds) {
3713  AffineMap permutationMap = getTransferMinorIdentityMap(
3714  llvm::cast<ShapedType>(source.getType()), vectorType);
3715  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3716  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3717  ? builder.getBoolArrayAttr(inBounds.value())
3718  : ArrayAttr();
3719  build(builder, result, vectorType, source, indices, permutationMapAttr,
3720  padding,
3721  /*mask=*/Value(), inBoundsAttr);
3722 }
3723 
3724 /// 4. Builder that sets padding to zero and permutation map to
3725 /// 'getMinorIdentityMap'.
3726 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3727  VectorType vectorType, Value source,
3728  ValueRange indices,
3729  std::optional<ArrayRef<bool>> inBounds) {
3730  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3731  Value padding = builder.create<arith::ConstantOp>(
3732  result.location, elemType, builder.getZeroAttr(elemType));
3733  build(builder, result, vectorType, source, indices, padding, inBounds);
3734 }
3735 
3736 template <typename EmitFun>
3737 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3738  EmitFun emitOpError) {
3739  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3740  for (auto expr : permutationMap.getResults()) {
3741  auto dim = dyn_cast<AffineDimExpr>(expr);
3742  auto zero = dyn_cast<AffineConstantExpr>(expr);
3743  if (zero) {
3744  if (zero.getValue() != 0) {
3745  return emitOpError(
3746  "requires a projected permutation_map (at most one dim or the zero "
3747  "constant can appear in each result)");
3748  }
3749  continue;
3750  }
3751  if (!dim) {
3752  return emitOpError("requires a projected permutation_map (at most one "
3753  "dim or the zero constant can appear in each result)");
3754  }
3755  if (seen[dim.getPosition()]) {
3756  return emitOpError(
3757  "requires a permutation_map that is a permutation (found one dim "
3758  "used more than once)");
3759  }
3760  seen[dim.getPosition()] = true;
3761  }
3762  return success();
3763 }
3764 
3765 static LogicalResult
3766 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3767  VectorType vectorType, VectorType maskType,
3768  VectorType inferredMaskType, AffineMap permutationMap,
3769  ArrayAttr inBounds) {
3770  if (op->hasAttr("masked")) {
3771  return op->emitOpError("masked attribute has been removed. "
3772  "Use in_bounds instead.");
3773  }
3774 
3775  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3776  return op->emitOpError(
3777  "requires source to be a memref or ranked tensor type");
3778 
3779  auto elementType = shapedType.getElementType();
3780  DataLayout dataLayout = DataLayout::closest(op);
3781  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3782  // Memref or tensor has vector element type.
3783  unsigned sourceVecSize =
3784  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3785  vectorElementType.getShape().back();
3786  unsigned resultVecSize =
3787  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3788  vectorType.getShape().back();
3789  if (resultVecSize % sourceVecSize != 0)
3790  return op->emitOpError(
3791  "requires the bitwidth of the minor 1-D vector to be an integral "
3792  "multiple of the bitwidth of the minor 1-D vector of the source");
3793 
3794  unsigned sourceVecEltRank = vectorElementType.getRank();
3795  unsigned resultVecRank = vectorType.getRank();
3796  if (sourceVecEltRank > resultVecRank)
3797  return op->emitOpError(
3798  "requires source vector element and vector result ranks to match.");
3799  unsigned rankOffset = resultVecRank - sourceVecEltRank;
3800  // Check that permutation map results match 'rankOffset' of vector type.
3801  if (permutationMap.getNumResults() != rankOffset)
3802  return op->emitOpError("requires a permutation_map with result dims of "
3803  "the same rank as the vector type");
3804 
3805  if (maskType)
3806  return op->emitOpError("does not support masks with vector element type");
3807  } else {
3808  // Memref or tensor has scalar element type.
3809  unsigned minorSize =
3810  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3811  unsigned resultVecSize =
3812  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3813  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3814  return op->emitOpError(
3815  "requires the bitwidth of the minor 1-D vector to be an integral "
3816  "multiple of the bitwidth of the source element type");
3817 
3818  // Check that permutation map results match rank of vector type.
3819  if (permutationMap.getNumResults() != vectorType.getRank())
3820  return op->emitOpError("requires a permutation_map with result dims of "
3821  "the same rank as the vector type");
3822  }
3823 
3824  if (permutationMap.getNumSymbols() != 0)
3825  return op->emitOpError("requires permutation_map without symbols");
3826 
3827  if (permutationMap.getNumInputs() != shapedType.getRank())
3828  return op->emitOpError("requires a permutation_map with input dims of the "
3829  "same rank as the source type");
3830 
3831  if (maskType && maskType != inferredMaskType)
3832  return op->emitOpError("inferred mask type (")
3833  << inferredMaskType << ") and mask operand type (" << maskType
3834  << ") don't match";
3835 
3836  if (inBounds) {
3837  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3838  return op->emitOpError("expects the optional in_bounds attr of same rank "
3839  "as permutation_map results: ")
3840  << AffineMapAttr::get(permutationMap)
3841  << " vs inBounds of size: " << inBounds.size();
3842  for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
3843  if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
3844  !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3845  return op->emitOpError("requires broadcast dimensions to be in-bounds");
3846  }
3847 
3848  return success();
3849 }
3850 
3851 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3852  SmallVector<StringRef, 3> elidedAttrs;
3853  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3854  if (op.getPermutationMap().isMinorIdentity())
3855  elidedAttrs.push_back(op.getPermutationMapAttrName());
3856  // Elide in_bounds attribute if all dims are out-of-bounds.
3857  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
3858  elidedAttrs.push_back(op.getInBoundsAttrName());
3859  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3860 }
3861 
3862 void TransferReadOp::print(OpAsmPrinter &p) {
3863  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3864  if (getMask())
3865  p << ", " << getMask();
3866  printTransferAttrs(p, *this);
3867  p << " : " << getShapedType() << ", " << getVectorType();
3868 }
3869 
3870 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
3871  AffineMap permMap) {
3872  auto i1Type = IntegerType::get(permMap.getContext(), 1);
3873  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3874  assert(invPermMap && "Inversed permutation map couldn't be computed");
3875  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3876 
3877  SmallVector<bool> scalableDims =
3878  applyPermutationMap(invPermMap, vecType.getScalableDims());
3879 
3880  return VectorType::get(maskShape, i1Type, scalableDims);
3881 }
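// For instance (hypothetical types): with permutation_map
// (d0, d1) -> (d1, d0) and a transferred vector<2x4xf32>, the inferred mask
// type is vector<4x2xi1> -- the mask is shaped like the accessed memory
// region rather than like the result vector.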
3882 
3883 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3884  auto &builder = parser.getBuilder();
3885  SMLoc typesLoc;
3886  OpAsmParser::UnresolvedOperand sourceInfo;
3887  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3888  OpAsmParser::UnresolvedOperand paddingInfo;
3889  SmallVector<Type, 2> types;
3890  OpAsmParser::UnresolvedOperand maskInfo;
3891  // Parsing with support for paddingValue.
3892  if (parser.parseOperand(sourceInfo) ||
3893  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
3894  parser.parseComma() || parser.parseOperand(paddingInfo))
3895  return failure();
3896  ParseResult hasMask = parser.parseOptionalComma();
3897  if (hasMask.succeeded()) {
3898  if (parser.parseOperand(maskInfo))
3899  return failure();
3900  }
3901  if (parser.parseOptionalAttrDict(result.attributes) ||
3902  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3903  return failure();
3904  if (types.size() != 2)
3905  return parser.emitError(typesLoc, "requires two types");
3906  auto indexType = builder.getIndexType();
3907  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
3908  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
3909  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3910  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
3911  if (!vectorType)
3912  return parser.emitError(typesLoc, "requires vector type");
3913  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
3914  Attribute permMapAttr = result.attributes.get(permMapAttrName);
3915  AffineMap permMap;
3916  if (!permMapAttr) {
3917  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3918  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3919  } else {
3920  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
3921  }
3922  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3923  parser.resolveOperands(indexInfo, indexType, result.operands) ||
3924  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
3925  result.operands))
3926  return failure();
3927  if (hasMask.succeeded()) {
3928  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
3929  return parser.emitError(
3930  maskInfo.location, "does not support masks with vector element type");
3931  if (vectorType.getRank() != permMap.getNumResults()) {
3932  return parser.emitError(typesLoc,
3933  "expected the same rank for the vector and the "
3934  "results of the permutation map");
3935  }
3936  // Instead of adding the mask type as an op type, compute it based on the
3937  // vector type and the permutation map (to keep the type signature small).
3938  auto maskType = inferTransferOpMaskType(vectorType, permMap);
3939  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3940  return failure();
3941  }
3942  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3943  builder.getDenseI32ArrayAttr(
3944  {1, static_cast<int32_t>(indexInfo.size()), 1,
3945  static_cast<int32_t>(hasMask.succeeded())}));
3946  return parser.addTypeToList(vectorType, result.types);
3947 }
3948 
3949 LogicalResult TransferReadOp::verify() {
3950  // Consistency of elemental types in source and vector.
3951  ShapedType shapedType = getShapedType();
3952  VectorType vectorType = getVectorType();
3953  VectorType maskType = getMaskType();
3954  auto paddingType = getPadding().getType();
3955  auto permutationMap = getPermutationMap();
3956  VectorType inferredMaskType =
3957  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
3958  : VectorType();
3959  auto sourceElementType = shapedType.getElementType();
3960 
3961  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3962  return emitOpError("requires ") << shapedType.getRank() << " indices";
3963 
3964  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3965  shapedType, vectorType, maskType,
3966  inferredMaskType, permutationMap,
3967  getInBounds() ? *getInBounds() : ArrayAttr())))
3968  return failure();
3969 
3970  if (auto sourceVectorElementType =
3971  llvm::dyn_cast<VectorType>(sourceElementType)) {
3972  // Source has vector element type.
3973  // Check that 'sourceVectorElementType' and 'paddingType' types match.
3974  if (sourceVectorElementType != paddingType)
3975  return emitOpError(
3976  "requires source element type and padding type to match.");
3977 
3978  } else {
3979  // Check that 'paddingType' is valid to store in a vector type.
3980  if (!VectorType::isValidElementType(paddingType))
3981  return emitOpError("requires valid padding vector elemental type");
3982 
3983  // Check that padding type and vector element types match.
3984  if (paddingType != sourceElementType)
3985  return emitOpError(
3986  "requires formal padding and source of the same elemental type");
3987  }
3988 
3989  return verifyPermutationMap(permutationMap,
3990  [&](Twine t) { return emitOpError(t); });
3991 }
3992 
3993 // MaskableOpInterface methods.
3994 
3995 /// Returns the mask type expected by this operation. Mostly used for
3996 /// verification purposes. It requires the operation to be vectorized.
3997 Type TransferReadOp::getExpectedMaskType() {
3998  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
3999 }
4000 
4001 template <typename TransferOp>
4002 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4003  // TODO: support more aggressive createOrFold on:
4004  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4005  if (op.getShapedType().isDynamicDim(indicesIdx))
4006  return false;
4007  Value index = op.getIndices()[indicesIdx];
4008  std::optional<int64_t> cstOp = getConstantIntValue(index);
4009  if (!cstOp.has_value())
4010  return false;
4011 
4012  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4013  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4014 
4015  return cstOp.value() + vectorSize <= sourceSize;
4016 }
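// E.g. (hypothetical values): reading a vector dimension of size 4 starting at
// constant index 2 of a static source dimension of size 8 is provably
// in-bounds, since 2 + 4 <= 8.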
4017 
4018 template <typename TransferOp>
4019 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4020  // TODO: support 0-d corner case.
4021  // TODO: Be less conservative.
4022  if (op.getTransferRank() == 0)
4023  return failure();
4024  AffineMap permutationMap = op.getPermutationMap();
4025  bool changed = false;
4026  SmallVector<bool, 4> newInBounds;
4027  newInBounds.reserve(op.getTransferRank());
4028  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4029  // Already marked as in-bounds, nothing to see here.
4030  if (op.isDimInBounds(i)) {
4031  newInBounds.push_back(true);
4032  continue;
4033  }
4034  // Currently out-of-bounds, check whether we can statically determine it is
4035  // inBounds.
4036  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4037  assert(dimExpr && "Broadcast dims must be in-bounds");
4038  auto inBounds =
4039  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
4040  newInBounds.push_back(inBounds);
4041  // We commit the pattern if it is "more inbounds".
4042  changed |= inBounds;
4043  }
4044  if (!changed)
4045  return failure();
4046  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4047  OpBuilder b(op.getContext());
4048  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4049  return success();
4050 }
4051 
4052 template <typename TransferOp>
4053 static LogicalResult foldTransferFullMask(TransferOp op) {
4054  auto mask = op.getMask();
4055  if (!mask)
4056  return failure();
4057 
4058  auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
4059  if (!constantMask)
4060  return failure();
4061 
4062  if (!constantMask.isAllOnesMask())
4063  return failure();
4064 
4065  op.getMaskMutable().clear();
4066  return success();
4067 }
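// E.g. (hypothetical IR): a transfer whose mask is an all-true constant mask,
//   %m = vector.constant_mask [4] : vector<4xi1>
//   %r = vector.transfer_read %src[%c0], %pad, %m : memref<4xf32>, vector<4xf32>
// is folded to the same transfer_read without the mask operand.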
4068 
4069 /// ```
4070 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4071 /// : vector<1x4xf32>, tensor<4x4xf32>
4072 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4073 /// : tensor<4x4xf32>, vector<1x4xf32>
4074 /// ```
4075 /// -> Folds into
4076 /// ```
4077 /// %v0
4078 /// ```
4079 static Value foldRAW(TransferReadOp readOp) {
4080  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4081  return {};
4082  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4083  while (defWrite) {
4084  if (checkSameValueRAW(defWrite, readOp))
4085  return defWrite.getVector();
4086  if (!isDisjointTransferIndices(
4087  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4088  cast<VectorTransferOpInterface>(readOp.getOperation())))
4089  break;
4090  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4091  }
4092  return {};
4093 }
4094 
4095 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4096  if (Value vec = foldRAW(*this))
4097  return vec;
4098  /// transfer_read(memrefcast) -> transfer_read
4099  if (succeeded(foldTransferInBoundsAttribute(*this)))
4100  return getResult();
4101  if (succeeded(foldTransferFullMask(*this)))
4102  return getResult();
4103  if (succeeded(memref::foldMemRefCast(*this)))
4104  return getResult();
4105  if (succeeded(tensor::foldTensorCast(*this)))
4106  return getResult();
4107  return OpFoldResult();
4108 }
4109 
4110 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4111  return llvm::to_vector<4>(getVectorType().getShape());
4112 }
4113 
4114 void TransferReadOp::getEffects(
4115  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4116  &effects) {
4117  if (llvm::isa<MemRefType>(getShapedType()))
4118  effects.emplace_back(MemoryEffects::Read::get(), getSource(),
4119  SideEffects::DefaultResource::get());
4120 }
4121 
4122 namespace {
4123 /// Store to load forwarding for transfer operations with permutation maps.
4124 /// Even if the permutation maps are different we can still propagate the store
4125 /// into the load if the size of the dimensions read and written match. Then we
4126 /// can replace the transfer_read + transfer_write by vector.broadcast and
4127 /// vector.transpose.
4128 /// Example:
4129 /// ```
4130 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4131 /// {in_bounds = [true, true],
4132 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4133 /// vector<4x1xf32>, tensor<4x4x4xf32>
4134 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4135 /// {in_bounds = [true, true, true, true],
4136 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4137 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4138 /// ```
4139 /// To:
4140 /// ```
4141 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4142 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4143 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4144 /// ```
4145 struct TransferReadAfterWriteToBroadcast
4146  : public OpRewritePattern<TransferReadOp> {
4147  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
4148 
4149  LogicalResult matchAndRewrite(TransferReadOp readOp,
4150  PatternRewriter &rewriter) const override {
4151  if (readOp.hasOutOfBoundsDim() ||
4152  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4153  return failure();
4154  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4155  if (!defWrite)
4156  return failure();
4157  // TODO: If the written transfer chunk is a superset of the read transfer
4158  // chunk we could do an extract_strided_slice.
4159  if (readOp.getTransferChunkAccessed() !=
4160  defWrite.getTransferChunkAccessed())
4161  return failure();
4162  // TODO: Support cases where a dim is explicitly written but implicitly
4163  // read (i.e., a unit dim that is rank reduced).
4164  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4165  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4166  return failure();
4167  if (readOp.getIndices() != defWrite.getIndices() ||
4168  readOp.getMask() != defWrite.getMask())
4169  return failure();
4170  Value vec = defWrite.getVector();
4171  // TODO: loop through the chain of transfer_write if we can prove that they
4172  // don't overlap with the transfer_read. This requires improving
4173  // `isDisjointTransferIndices` helper.
4174  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4175  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4176  AffineMap map = readMap.compose(writeMap);
4177  if (map.getNumResults() == 0)
4178  return failure();
4179  // Calculate the permutation to apply to go from the vector stored to the
4180  // vector read.
4181  SmallVector<unsigned> permutation;
4182  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4183  return failure();
4184 
4185  Location loc = readOp.getLoc();
4186  // Calculate the broadcast shape by applying the reverse permutation to the
4187  // final shape we want.
4188  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4189  SmallVector<int64_t> broadcastShape(destShape.size());
4190  SmallVector<bool> broadcastScalableFlags(destShape.size());
4191  for (const auto &pos : llvm::enumerate(permutation)) {
4192  broadcastShape[pos.value()] = destShape[pos.index()];
4193  broadcastScalableFlags[pos.value()] =
4194  readOp.getVectorType().getScalableDims()[pos.index()];
4195  }
4196  VectorType broadcastedType = VectorType::get(
4197  broadcastShape, defWrite.getVectorType().getElementType(),
4198  broadcastScalableFlags);
4199  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4200  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4201  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4202  transposePerm);
4203  return success();
4204  }
4205 };
4206 } // namespace
4207 
4208 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4209  MLIRContext *context) {
4210  results.add<TransferReadAfterWriteToBroadcast>(context);
4211 }
4212 
4213 //===----------------------------------------------------------------------===//
4214 // TransferWriteOp
4215 //===----------------------------------------------------------------------===//
4216 
4217 /// 1. Builder with type inference.
4218 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4219  Value vector, Value dest, ValueRange indices,
4220  AffineMapAttr permutationMapAttr,
4221  /*optional*/ Value mask,
4222  /*optional*/ ArrayAttr inBoundsAttr) {
4223  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4224  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4225  mask, inBoundsAttr);
4226 }
4227 
4228 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4229 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4230  Value vector, Value dest, ValueRange indices,
4231  AffineMapAttr permutationMapAttr,
4232  /*optional*/ ArrayAttr inBoundsAttr) {
4233  build(builder, result, vector, dest, indices, permutationMapAttr,
4234  /*mask=*/Value(), inBoundsAttr);
4235 }
4236 
4237 /// 3. Builder with type inference that sets an empty mask (variant without
4238 /// attrs).
4239 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4240  Value vector, Value dest, ValueRange indices,
4241  AffineMap permutationMap,
4242  std::optional<ArrayRef<bool>> inBounds) {
4243  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4244  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4245  ? builder.getBoolArrayAttr(inBounds.value())
4246  : ArrayAttr();
4247  build(builder, result, vector, dest, indices, permutationMapAttr,
4248  /*mask=*/Value(), inBoundsAttr);
4249 }
4250 
4251 /// 4. Builder with type inference that sets an empty mask and sets permutation
4252 /// map to 'getMinorIdentityMap'.
4253 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4254  Value vector, Value dest, ValueRange indices,
4255  std::optional<ArrayRef<bool>> inBounds) {
4256  auto vectorType = llvm::cast<VectorType>(vector.getType());
4257  AffineMap permutationMap = getTransferMinorIdentityMap(
4258  llvm::cast<ShapedType>(dest.getType()), vectorType);
4259  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4260 }
4261 
4262 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4263  OperationState &result) {
4264  auto &builder = parser.getBuilder();
4265  SMLoc typesLoc;
4266  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4267  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4268  SmallVector<Type, 2> types;
4269  OpAsmParser::UnresolvedOperand maskInfo;
4270  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4271  parser.parseOperand(sourceInfo) ||
4272  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4273  return failure();
4274  ParseResult hasMask = parser.parseOptionalComma();
4275  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4276  return failure();
4277  if (parser.parseOptionalAttrDict(result.attributes) ||
4278  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4279  return failure();
4280  if (types.size() != 2)
4281  return parser.emitError(typesLoc, "requires two types");
4282  auto indexType = builder.getIndexType();
4283  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4284  if (!vectorType)
4285  return parser.emitError(typesLoc, "requires vector type");
4286  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4287  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4288  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4289  auto permMapAttrName =
4290  TransferWriteOp::getPermutationMapAttrName(result.name);
4291  auto permMapAttr = result.attributes.get(permMapAttrName);
4292  AffineMap permMap;
4293  if (!permMapAttr) {
4294  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4295  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4296  } else {
4297  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4298  }
4299  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4300  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4301  parser.resolveOperands(indexInfo, indexType, result.operands))
4302  return failure();
4303  if (hasMask.succeeded()) {
4304  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4305  return parser.emitError(
4306  maskInfo.location, "does not support masks with vector element type");
4307  if (vectorType.getRank() != permMap.getNumResults()) {
4308  return parser.emitError(typesLoc,
4309  "expected the same rank for the vector and the "
4310  "results of the permutation map");
4311  }
4312  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4313  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4314  return failure();
4315  }
4316  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4317  builder.getDenseI32ArrayAttr(
4318  {1, 1, static_cast<int32_t>(indexInfo.size()),
4319  static_cast<int32_t>(hasMask.succeeded())}));
4320  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4321  parser.addTypeToList(shapedType, result.types));
4322 }
4323 
4324 void TransferWriteOp::print(OpAsmPrinter &p) {
4325  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4326  if (getMask())
4327  p << ", " << getMask();
4328  printTransferAttrs(p, *this);
4329  p << " : " << getVectorType() << ", " << getShapedType();
4330 }
4331 
4332 LogicalResult TransferWriteOp::verify() {
4333  // Consistency of elemental types in shape and vector.
4334  ShapedType shapedType = getShapedType();
4335  VectorType vectorType = getVectorType();
4336  VectorType maskType = getMaskType();
4337  auto permutationMap = getPermutationMap();
4338  VectorType inferredMaskType =
4339  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4340  : VectorType();
4341 
4342  if (llvm::size(getIndices()) != shapedType.getRank())
4343  return emitOpError("requires ") << shapedType.getRank() << " indices";
4344 
4345  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4346  // as the semantics is unclear. This can be revisited later if necessary.
4347  if (hasBroadcastDim())
4348  return emitOpError("should not have broadcast dimensions");
4349 
4350  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4351  shapedType, vectorType, maskType,
4352  inferredMaskType, permutationMap,
4353  getInBounds() ? *getInBounds() : ArrayAttr())))
4354  return failure();
4355 
4356  return verifyPermutationMap(permutationMap,
4357  [&](Twine t) { return emitOpError(t); });
4358 }
4359 
4360 // MaskableOpInterface methods.
4361 
4362 /// Returns the mask type expected by this operation. Mostly used for
4363 /// verification purposes.
4364 Type TransferWriteOp::getExpectedMaskType() {
4365  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4366 }
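// Illustrative sketch (assuming a minor identity permutation map and no
// broadcast dims): a transfer_write of vector<2x4xf32> expects a mask of type
// vector<2x4xi1>, i.e. the transfer's vector shape with an i1 element type.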
4367 
4368 /// Fold:
4369 /// ```
4370 /// %t1 = ...
4371 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4372 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4373 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4374 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4375 /// ```
4376 ///
4377 /// into:
4378 ///
4379 /// ```
4380 /// %t0
4381 /// ```
4382 ///
4383 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4384 /// block argument or has side effects.
4385 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4386  ArrayRef<Attribute>,
4387  SmallVectorImpl<OpFoldResult> &results) {
4388  // TODO: support 0-d corner case.
4389  if (write.getTransferRank() == 0)
4390  return failure();
4391  auto rankedTensorType =
4392  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4393  // If not operating on tensors, bail.
4394  if (!rankedTensorType)
4395  return failure();
4396  // If no read, bail.
4397  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4398  if (!read)
4399  return failure();
4400  // TODO: support 0-d corner case.
4401  if (read.getTransferRank() == 0)
4402  return failure();
4403  // For now, only accept minor identity maps; in the future, also accept maps whose composition is a minor identity.
4404  if (!read.getPermutationMap().isMinorIdentity() ||
4405  !write.getPermutationMap().isMinorIdentity())
4406  return failure();
4407  // Bail on mismatching ranks.
4408  if (read.getTransferRank() != write.getTransferRank())
4409  return failure();
4410  // Bail on potential out-of-bounds accesses.
4411  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4412  return failure();
4413  // Tensor types must be the same.
4414  if (read.getSource().getType() != rankedTensorType)
4415  return failure();
4416  // Vector types must be the same.
4417  if (read.getVectorType() != write.getVectorType())
4418  return failure();
4419  // Vector and Tensor shapes must match.
4420  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4421  return failure();
4422  // Bail if any read or write index is nonzero.
4423  auto isNotConstantZero = [](Value v) {
4424  auto cstOp = getConstantIntValue(v);
4425  return !cstOp.has_value() || cstOp.value() != 0;
4426  };
4427  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4428  llvm::any_of(write.getIndices(), isNotConstantZero))
4429  return failure();
4430  // Success.
4431  results.push_back(read.getSource());
4432  return success();
4433 }
4434 
4435 static bool checkSameValueWAR(vector::TransferReadOp read,
4436  vector::TransferWriteOp write) {
4437  return read.getSource() == write.getSource() &&
4438  read.getIndices() == write.getIndices() &&
4439  read.getPermutationMap() == write.getPermutationMap() &&
4440  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4441  !write.getMask();
4442 }
4443 /// Fold transfer_write write after read:
4444 /// ```
4445 /// %t0 = ...
4446 /// %v = vector.transfer_read %t0[%c0...] :
4447 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4448 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4449 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4450 /// ```
4451 ///
4452 /// into:
4453 ///
4454 /// ```
4455 /// %t0
4456 /// ```
4457 static LogicalResult foldWAR(TransferWriteOp write,
4458  SmallVectorImpl<OpFoldResult> &results) {
4459  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4460  return failure();
4461  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4462  if (!read)
4463  return failure();
4464 
4465  if (!checkSameValueWAR(read, write))
4466  return failure();
4467  results.push_back(read.getSource());
4468  return success();
4469 }
4470 
4471 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4472  SmallVectorImpl<OpFoldResult> &results) {
4473  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4474  return success();
4475  if (succeeded(foldWAR(*this, results)))
4476  return success();
4477  if (succeeded(foldTransferInBoundsAttribute(*this)))
4478  return success();
4479  if (succeeded(foldTransferFullMask(*this)))
4480  return success();
4481  return memref::foldMemRefCast(*this);
4482 }
4483 
4484 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4485  return llvm::to_vector<4>(getVectorType().getShape());
4486 }
4487 
4488 void TransferWriteOp::getEffects(
4489  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4490  &effects) {
4491  if (llvm::isa<MemRefType>(getShapedType()))
4492  effects.emplace_back(MemoryEffects::Write::get(), getSource(),
4493  SideEffects::DefaultResource::get());
4494 }
4495 
4496 namespace {
4497 /// Remove dead transfer write from the SSA chain so that it can be eliminated
4498 /// by DCE.
4499 /// ```
4500 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4501 /// : vector<1x4xf32>, tensor<4x4xf32>
4502 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4503 /// : vector<1x4xf32>, tensor<4x4xf32>
4504 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4505 /// : vector<1x4xf32>, tensor<4x4xf32>
4506 /// ```
4507 ///
4508 /// into:
4509 ///
4510 /// ```
4511 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4512 /// : vector<1x4xf32>, tensor<4x4xf32>
4513 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4514 /// : vector<1x4xf32>, tensor<4x4xf32>
4515 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4516 /// : vector<1x4xf32>, tensor<4x4xf32>
4517 /// ```
4518 ///
4519 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4520 /// any other uses.
4521 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4522 public:
4524  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4525  PatternRewriter &rewriter) const override {
4526  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4527  return failure();
4528  vector::TransferWriteOp writeToModify = writeOp;
4529 
4530  auto defWrite =
4531  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4532  while (defWrite) {
4533  if (checkSameValueWAW(writeOp, defWrite)) {
4534  rewriter.modifyOpInPlace(writeToModify, [&]() {
4535  writeToModify.getSourceMutable().assign(defWrite.getSource());
4536  });
4537  return success();
4538  }
4539  if (!isDisjointTransferIndices(
4540  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4541  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4542  break;
4543  // If the previous write op doesn't have any other use we can safely look
4544  // at the previous store to see if it can be removed.
4545  if (!defWrite->hasOneUse())
4546  break;
4547  writeToModify = defWrite;
4548  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4549  }
4550  return failure();
4551  }
4552 };
4553 
4554 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4555 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4556 /// overwritten and inserted into another tensor. After this rewrite, the
4557 /// operations bufferize in-place since all of them work on the same slice.
4558 ///
4559 /// For example:
4560 /// ```mlir
4561 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4562 /// : vector<8x16xf32>, tensor<8x16xf32>
4563 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4564 /// : tensor<8x16xf32> to tensor<?x?xf32>
4565 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4566 /// : tensor<?x?xf32> into tensor<27x37xf32>
4567 /// ```
4568 /// folds to
4569 /// ```mlir
4570 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4571 /// : tensor<27x37xf32> to tensor<?x?xf32>
4572 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4573 /// : vector<8x16xf32>, tensor<?x?xf32>
4574 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4575 /// : tensor<?x?xf32> into tensor<27x37xf32>
4576 /// ```
4577 struct SwapExtractSliceOfTransferWrite
4578  : public OpRewritePattern<tensor::InsertSliceOp> {
4579 public:
4581 
4582  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4583  PatternRewriter &rewriter) const override {
4584  if (!insertOp.hasUnitStride())
4585  return failure();
4586  auto extractOp =
4587  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4588  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4589  return failure();
4590  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4591  if (!transferOp || !transferOp->hasOneUse())
4592  return failure();
4593 
4594  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4595  // rank-reducing.
4596  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4597  return rewriter.notifyMatchFailure(insertOp,
4598  "use-def chain is rank-reducing");
4599  }
4600 
4601  // Fail if tensor::ExtractSliceOp has non-zero offset.
4602  if (!extractOp.hasZeroOffset()) {
4603  return rewriter.notifyMatchFailure(insertOp,
4604  "ExtractSliceOp has non-zero offset");
4605  }
4606 
4607  // Fail if the vector::TransferWriteOp has non-zero offset.
4608  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4609  return getConstantIntValue(value) == static_cast<int64_t>(0);
4610  })) {
4611  return rewriter.notifyMatchFailure(insertOp,
4612  "TranferWriteOp has non-zero offset");
4613  }
4614 
4615  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4616  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4617  return rewriter.notifyMatchFailure(
4618  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4619  }
4620 
4621  for (auto [insertSize, extractSize] :
4622  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4623  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4624  return rewriter.notifyMatchFailure(
4625  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4626  }
4627  }
4628 
4629  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4630  assert(transferOp.getVectorType().hasStaticShape() &&
4631  "expected vector to have a static shape");
4632  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4633  SmallVector<int64_t> resultShape = applyPermutationMap(
4634  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4635  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4636  return rewriter.notifyMatchFailure(
4637  insertOp, "TransferWriteOp may not write the full tensor.");
4638  }
4639 
4640  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4641  // Set all in_bounds to false and let the folder infer them.
4642  SmallVector<bool> newInBounds(vectorShape.size(), false);
4643  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4644  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4645  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4646  insertOp.getMixedStrides());
4647  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4648  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4649  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4650  rewriter.getBoolArrayAttr(newInBounds));
4651  rewriter.modifyOpInPlace(insertOp, [&]() {
4652  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4653  });
4654  return success();
4655  }
4656 };
4657 
4658 } // namespace
4659 
4660 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4661  MLIRContext *context) {
4662  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4663 }
4664 
4665 //===----------------------------------------------------------------------===//
4666 // LoadOp
4667 //===----------------------------------------------------------------------===//
4668 
4669 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4670  MemRefType memRefTy) {
4671  if (!isLastMemrefDimUnitStride(memRefTy))
4672  return op->emitOpError("most minor memref dim must have unit stride");
4673  return success();
4674 }
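// Illustrative sketch of the layout requirement checked above:
//   memref<4x8xf32>                      - ok, identity layout
//   memref<4x8xf32, strided<[16, 1]>>    - ok, innermost stride is 1
//   memref<4x8xf32, strided<[1, 4]>>     - rejected, innermost stride is 4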
4675 
4676 LogicalResult vector::LoadOp::verify() {
4677  VectorType resVecTy = getVectorType();
4678  MemRefType memRefTy = getMemRefType();
4679 
4680  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4681  return failure();
4682 
4683  // Checks for vector memrefs.
4684  Type memElemTy = memRefTy.getElementType();
4685  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4686  if (memVecTy != resVecTy)
4687  return emitOpError("base memref and result vector types should match");
4688  memElemTy = memVecTy.getElementType();
4689  }
4690 
4691  if (resVecTy.getElementType() != memElemTy)
4692  return emitOpError("base and result element types should match");
4693  if (llvm::size(getIndices()) != memRefTy.getRank())
4694  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4695  return success();
4696 }
4697 
4698 OpFoldResult LoadOp::fold(FoldAdaptor) {
4699  if (succeeded(memref::foldMemRefCast(*this)))
4700  return getResult();
4701  return OpFoldResult();
4702 }
4703 
4704 //===----------------------------------------------------------------------===//
4705 // StoreOp
4706 //===----------------------------------------------------------------------===//
4707 
4708 LogicalResult vector::StoreOp::verify() {
4709  VectorType valueVecTy = getVectorType();
4710  MemRefType memRefTy = getMemRefType();
4711 
4712  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4713  return failure();
4714 
4715  // Checks for vector memrefs.
4716  Type memElemTy = memRefTy.getElementType();
4717  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4718  if (memVecTy != valueVecTy)
4719  return emitOpError(
4720  "base memref and valueToStore vector types should match");
4721  memElemTy = memVecTy.getElementType();
4722  }
4723 
4724  if (valueVecTy.getElementType() != memElemTy)
4725  return emitOpError("base and valueToStore element type should match");
4726  if (llvm::size(getIndices()) != memRefTy.getRank())
4727  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4728  return success();
4729 }
4730 
4731 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4732  SmallVectorImpl<OpFoldResult> &results) {
4733  return memref::foldMemRefCast(*this);
4734 }
4735 
4736 //===----------------------------------------------------------------------===//
4737 // MaskedLoadOp
4738 //===----------------------------------------------------------------------===//
4739 
4740 LogicalResult MaskedLoadOp::verify() {
4741  VectorType maskVType = getMaskVectorType();
4742  VectorType passVType = getPassThruVectorType();
4743  VectorType resVType = getVectorType();
4744  MemRefType memType = getMemRefType();
4745 
4746  if (resVType.getElementType() != memType.getElementType())
4747  return emitOpError("base and result element type should match");
4748  if (llvm::size(getIndices()) != memType.getRank())
4749  return emitOpError("requires ") << memType.getRank() << " indices";
4750  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4751  return emitOpError("expected result dim to match mask dim");
4752  if (resVType != passVType)
4753  return emitOpError("expected pass_thru of same type as result type");
4754  return success();
4755 }
4756 
4757 namespace {
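/// Canonicalizes a masked load whose mask is statically known. A minimal
/// illustrative sketch (SSA names are placeholders): with an all-true mask,
///
///   %x = vector.maskedload %base[%i], %all_true, %pass
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
///
/// becomes `vector.load %base[%i] : memref<?xf32>, vector<16xf32>`, and with
/// an all-false mask it folds to `%pass` directly.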
4758 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4759 public:
4761  LogicalResult matchAndRewrite(MaskedLoadOp load,
4762  PatternRewriter &rewriter) const override {
4763  switch (getMaskFormat(load.getMask())) {
4764  case MaskFormat::AllTrue:
4765  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4766  load, load.getType(), load.getBase(), load.getIndices());
4767  return success();
4768  case MaskFormat::AllFalse:
4769  rewriter.replaceOp(load, load.getPassThru());
4770  return success();
4771  case MaskFormat::Unknown:
4772  return failure();
4773  }
4774  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4775  }
4776 };
4777 } // namespace
4778 
4779 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4780  MLIRContext *context) {
4781  results.add<MaskedLoadFolder>(context);
4782 }
4783 
4784 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
4785  if (succeeded(memref::foldMemRefCast(*this)))
4786  return getResult();
4787  return OpFoldResult();
4788 }
4789 
4790 //===----------------------------------------------------------------------===//
4791 // MaskedStoreOp
4792 //===----------------------------------------------------------------------===//
4793 
4794 LogicalResult MaskedStoreOp::verify() {
4795  VectorType maskVType = getMaskVectorType();
4796  VectorType valueVType = getVectorType();
4797  MemRefType memType = getMemRefType();
4798 
4799  if (valueVType.getElementType() != memType.getElementType())
4800  return emitOpError("base and valueToStore element type should match");
4801  if (llvm::size(getIndices()) != memType.getRank())
4802  return emitOpError("requires ") << memType.getRank() << " indices";
4803  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4804  return emitOpError("expected valueToStore dim to match mask dim");
4805  return success();
4806 }
4807 
4808 namespace {
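/// Canonicalizes a masked store whose mask is statically known. A minimal
/// illustrative sketch (SSA names are placeholders): with an all-true mask,
///
///   vector.maskedstore %base[%i], %all_true, %val
///       : memref<?xf32>, vector<16xi1>, vector<16xf32>
///
/// becomes `vector.store %val, %base[%i] : memref<?xf32>, vector<16xf32>`;
/// with an all-false mask the store is simply erased.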
4809 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4810 public:
4812  LogicalResult matchAndRewrite(MaskedStoreOp store,
4813  PatternRewriter &rewriter) const override {
4814  switch (getMaskFormat(store.getMask())) {
4815  case MaskFormat::AllTrue:
4816  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4817  store, store.getValueToStore(), store.getBase(), store.getIndices());
4818  return success();
4819  case MaskFormat::AllFalse:
4820  rewriter.eraseOp(store);
4821  return success();
4822  case MaskFormat::Unknown:
4823  return failure();
4824  }
4825  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4826  }
4827 };
4828 } // namespace
4829 
4830 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4831  MLIRContext *context) {
4832  results.add<MaskedStoreFolder>(context);
4833 }
4834 
4835 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4836  SmallVectorImpl<OpFoldResult> &results) {
4837  return memref::foldMemRefCast(*this);
4838 }
4839 
4840 //===----------------------------------------------------------------------===//
4841 // GatherOp
4842 //===----------------------------------------------------------------------===//
4843 
4844 LogicalResult GatherOp::verify() {
4845  VectorType indVType = getIndexVectorType();
4846  VectorType maskVType = getMaskVectorType();
4847  VectorType resVType = getVectorType();
4848  ShapedType baseType = getBaseType();
4849 
4850  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4851  return emitOpError("requires base to be a memref or ranked tensor type");
4852 
4853  if (resVType.getElementType() != baseType.getElementType())
4854  return emitOpError("base and result element type should match");
4855  if (llvm::size(getIndices()) != baseType.getRank())
4856  return emitOpError("requires ") << baseType.getRank() << " indices";
4857  if (resVType.getShape() != indVType.getShape())
4858  return emitOpError("expected result dim to match indices dim");
4859  if (resVType.getShape() != maskVType.getShape())
4860  return emitOpError("expected result dim to match mask dim");
4861  if (resVType != getPassThruVectorType())
4862  return emitOpError("expected pass_thru of same type as result type");
4863  return success();
4864 }
4865 
4866 // MaskableOpInterface methods.
4867 
4868 /// Returns the mask type expected by this operation. Mostly used for
4869 /// verification purposes. It requires the operation to be vectorized."
4870 Type GatherOp::getExpectedMaskType() {
4871  auto vecType = this->getIndexVectorType();
4872  return VectorType::get(vecType.getShape(),
4873  IntegerType::get(vecType.getContext(), /*width=*/1),
4874  vecType.getScalableDims());
4875 }
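// Illustrative sketch: a gather whose index vector has type vector<16xi32>
// expects a mask of type vector<16xi1>, mirroring the index vector's shape
// and scalability.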
4876 
4877 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4878  return llvm::to_vector<4>(getVectorType().getShape());
4879 }
4880 
4881 namespace {
4882 class GatherFolder final : public OpRewritePattern<GatherOp> {
4883 public:
4885  LogicalResult matchAndRewrite(GatherOp gather,
4886  PatternRewriter &rewriter) const override {
4887  switch (getMaskFormat(gather.getMask())) {
4888  case MaskFormat::AllTrue:
4889  return failure(); // no unmasked equivalent
4890  case MaskFormat::AllFalse:
4891  rewriter.replaceOp(gather, gather.getPassThru());
4892  return success();
4893  case MaskFormat::Unknown:
4894  return failure();
4895  }
4896  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4897  }
4898 };
4899 } // namespace
4900 
4901 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4902  MLIRContext *context) {
4903  results.add<GatherFolder>(context);
4904 }
4905 
4906 //===----------------------------------------------------------------------===//
4907 // ScatterOp
4908 //===----------------------------------------------------------------------===//
4909 
4910 LogicalResult ScatterOp::verify() {
4911  VectorType indVType = getIndexVectorType();
4912  VectorType maskVType = getMaskVectorType();
4913  VectorType valueVType = getVectorType();
4914  MemRefType memType = getMemRefType();
4915 
4916  if (valueVType.getElementType() != memType.getElementType())
4917  return emitOpError("base and valueToStore element type should match");
4918  if (llvm::size(getIndices()) != memType.getRank())
4919  return emitOpError("requires ") << memType.getRank() << " indices";
4920  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4921  return emitOpError("expected valueToStore dim to match indices dim");
4922  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4923  return emitOpError("expected valueToStore dim to match mask dim");
4924  return success();
4925 }
4926 
4927 namespace {
4928 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4929 public:
4931  LogicalResult matchAndRewrite(ScatterOp scatter,
4932  PatternRewriter &rewriter) const override {
4933  switch (getMaskFormat(scatter.getMask())) {
4934  case MaskFormat::AllTrue:
4935  return failure(); // no unmasked equivalent
4936  case MaskFormat::AllFalse:
4937  rewriter.eraseOp(scatter);
4938  return success();
4939  case MaskFormat::Unknown:
4940  return failure();
4941  }
4942  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4943  }
4944 };
4945 } // namespace
4946 
4947 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4948  MLIRContext *context) {
4949  results.add<ScatterFolder>(context);
4950 }
4951 
4952 //===----------------------------------------------------------------------===//
4953 // ExpandLoadOp
4954 //===----------------------------------------------------------------------===//
4955 
4956 LogicalResult ExpandLoadOp::verify() {
4957  VectorType maskVType = getMaskVectorType();
4958  VectorType passVType = getPassThruVectorType();
4959  VectorType resVType = getVectorType();
4960  MemRefType memType = getMemRefType();
4961 
4962  if (resVType.getElementType() != memType.getElementType())
4963  return emitOpError("base and result element type should match");
4964  if (llvm::size(getIndices()) != memType.getRank())
4965  return emitOpError("requires ") << memType.getRank() << " indices";
4966  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4967  return emitOpError("expected result dim to match mask dim");
4968  if (resVType != passVType)
4969  return emitOpError("expected pass_thru of same type as result type");
4970  return success();
4971 }
4972 
4973 namespace {
4974 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4975 public:
4977  LogicalResult matchAndRewrite(ExpandLoadOp expand,
4978  PatternRewriter &rewriter) const override {
4979  switch (getMaskFormat(expand.getMask())) {
4980  case MaskFormat::AllTrue:
4981  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4982  expand, expand.getType(), expand.getBase(), expand.getIndices());
4983  return success();
4984  case MaskFormat::AllFalse:
4985  rewriter.replaceOp(expand, expand.getPassThru());
4986  return success();
4987  case MaskFormat::Unknown:
4988  return failure();
4989  }
4990  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4991  }
4992 };
4993 } // namespace
4994 
4995 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4996  MLIRContext *context) {
4997  results.add<ExpandLoadFolder>(context);
4998 }
4999 
5000 //===----------------------------------------------------------------------===//
5001 // CompressStoreOp
5002 //===----------------------------------------------------------------------===//
5003 
5004 LogicalResult CompressStoreOp::verify() {
5005  VectorType maskVType = getMaskVectorType();
5006  VectorType valueVType = getVectorType();
5007  MemRefType memType = getMemRefType();
5008 
5009  if (valueVType.getElementType() != memType.getElementType())
5010  return emitOpError("base and valueToStore element type should match");
5011  if (llvm::size(getIndices()) != memType.getRank())
5012  return emitOpError("requires ") << memType.getRank() << " indices";
5013  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5014  return emitOpError("expected valueToStore dim to match mask dim");
5015  return success();
5016 }
5017 
5018 namespace {
5019 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5020 public:
5022  LogicalResult matchAndRewrite(CompressStoreOp compress,
5023  PatternRewriter &rewriter) const override {
5024  switch (getMaskFormat(compress.getMask())) {
5025  case MaskFormat::AllTrue:
5026  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5027  compress, compress.getValueToStore(), compress.getBase(),
5028  compress.getIndices());
5029  return success();
5030  case MaskFormat::AllFalse:
5031  rewriter.eraseOp(compress);
5032  return success();
5033  case MaskFormat::Unknown:
5034  return failure();
5035  }
5036  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5037  }
5038 };
5039 } // namespace
5040 
5041 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5042  MLIRContext *context) {
5043  results.add<CompressStoreFolder>(context);
5044 }
5045 
5046 //===----------------------------------------------------------------------===//
5047 // ShapeCastOp
5048 //===----------------------------------------------------------------------===//
5049 
5050 /// Returns true if each element of 'a' is equal to the product of a contiguous
5051 /// sequence of the elements of 'b'. Returns false otherwise.
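/// For example, a = [4, 6] and b = [2, 2, 6] is valid (4 = 2 * 2 and 6 = 6),
/// whereas a = [4, 6] and b = [3, 8] is not, even though both contain 24
/// elements, since no contiguous prefix of b multiplies to exactly 4.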
5052 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5053  unsigned rankA = a.size();
5054  unsigned rankB = b.size();
5055  assert(rankA < rankB);
5056 
5057  auto isOne = [](int64_t v) { return v == 1; };
5058 
5059  // Special case for n-D to 0-D shape cast. 'b' must be all ones to be shape
5060  // cast to a 0-D vector.
5061  if (rankA == 0 && llvm::all_of(b, isOne))
5062  return true;
5063 
5064  unsigned i = 0;
5065  unsigned j = 0;
5066  while (i < rankA && j < rankB) {
5067  int64_t dimA = a[i];
5068  int64_t dimB = 1;
5069  while (dimB < dimA && j < rankB)
5070  dimB *= b[j++];
5071  if (dimA != dimB)
5072  break;
5073  ++i;
5074 
5075  // Handle the case when trailing dimensions are of size 1.
5076  // Include them into the contiguous sequence.
5077  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5078  i = rankA;
5079  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5080  j = rankB;
5081  }
5082 
5083  return i == rankA && j == rankB;
5084 }
5085 
5086 static LogicalResult verifyVectorShapeCast(Operation *op,
5087  VectorType sourceVectorType,
5088  VectorType resultVectorType) {
5089  // Check that element type is the same.
5090  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5091  return op->emitOpError("source/result vectors must have same element type");
5092  auto sourceShape = sourceVectorType.getShape();
5093  auto resultShape = resultVectorType.getShape();
5094 
5095  // Check that product of source dim sizes matches product of result dim sizes.
5096  int64_t sourceDimProduct = std::accumulate(
5097  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5098  int64_t resultDimProduct = std::accumulate(
5099  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5100  if (sourceDimProduct != resultDimProduct)
5101  return op->emitOpError("source/result number of elements must match");
5102 
5103  // Check the expanding/contracting rank cases.
5104  unsigned sourceRank = sourceVectorType.getRank();
5105  unsigned resultRank = resultVectorType.getRank();
5106  if (sourceRank < resultRank) {
5107  if (!isValidShapeCast(sourceShape, resultShape))
5108  return op->emitOpError("invalid shape cast");
5109  } else if (sourceRank > resultRank) {
5110  if (!isValidShapeCast(resultShape, sourceShape))
5111  return op->emitOpError("invalid shape cast");
5112  }
5113  return success();
5114 }
5115 
5116 LogicalResult ShapeCastOp::verify() {
5117  auto sourceVectorType =
5118  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5119  auto resultVectorType =
5120  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5121 
5122  // Check if source/result are of vector type.
5123  if (sourceVectorType && resultVectorType)
5124  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5125 
5126  return success();
5127 }
5128 
5129 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5130  // No-op shape cast.
5131  if (getSource().getType() == getResult().getType())
5132  return getSource();
5133 
5134  // Canceling shape casts.
5135  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5136  if (getResult().getType() == otherOp.getSource().getType())
5137  return otherOp.getSource();
5138 
5139  // Only allows valid transitive folding.
5140  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5141  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5142  if (srcType.getRank() < resultType.getRank()) {
5143  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5144  return {};
5145  } else if (srcType.getRank() > resultType.getRank()) {
5146  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5147  return {};
5148  } else {
5149  return {};
5150  }
5151 
5152  setOperand(otherOp.getSource());
5153  return getResult();
5154  }
5155 
5156  // Cancelling broadcast and shape cast ops.
5157  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5158  if (bcastOp.getSourceType() == getType())
5159  return bcastOp.getSource();
5160  }
5161 
5162  return {};
5163 }
5164 
5165 namespace {
5166 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5167 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5168 public:
5170 
5171  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5172  PatternRewriter &rewriter) const override {
5173  auto constantOp =
5174  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5175  if (!constantOp)
5176  return failure();
5177  // Only handle splat for now.
5178  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5179  if (!dense)
5180  return failure();
5181  auto newAttr =
5182  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5183  dense.getSplatValue<Attribute>());
5184  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5185  return success();
5186  }
5187 };
5188 
5189 /// Helper function that computes a new vector type based on the input vector
5190 /// type by removing the trailing one dims:
5191 ///
5192 /// vector<4x1x1xi1> --> vector<4xi1>
5193 ///
5194 static VectorType trimTrailingOneDims(VectorType oldType) {
5195  ArrayRef<int64_t> oldShape = oldType.getShape();
5196  ArrayRef<int64_t> newShape = oldShape;
5197 
5198  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5199  ArrayRef<bool> newScalableDims = oldScalableDims;
5200 
5201  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5202  newShape = newShape.drop_back(1);
5203  newScalableDims = newScalableDims.drop_back(1);
5204  }
5205 
5206  // Make sure we have at least 1 dimension.
5207  // TODO: Add support for 0-D vectors.
5208  if (newShape.empty()) {
5209  newShape = oldShape.take_back();
5210  newScalableDims = oldScalableDims.take_back();
5211  }
5212 
5213  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5214 }
5215 
5216 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5217 ///
5218 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5219 /// dimension. If the input vector comes from `vector.create_mask` for which
5220 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5221 /// to fold shape_cast into create_mask.
5222 ///
5223 /// BEFORE:
5224 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5225 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5226 /// AFTER:
5227 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5228 class ShapeCastCreateMaskFolderTrailingOneDim final
5229  : public OpRewritePattern<ShapeCastOp> {
5230 public:
5232 
5233  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5234  PatternRewriter &rewriter) const override {
5235  Value shapeOpSrc = shapeOp->getOperand(0);
5236  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5237  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5238  if (!createMaskOp && !constantMaskOp)
5239  return failure();
5240 
5241  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5242  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5243 
5244  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5245  if (newVecType != shapeOpResTy)
5246  return failure();
5247 
5248  auto numDimsToDrop =
5249  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5250 
5251  // No unit dims to drop
5252  if (!numDimsToDrop)
5253  return failure();
5254 
5255  if (createMaskOp) {
5256  auto maskOperands = createMaskOp.getOperands();
5257  auto numMaskOperands = maskOperands.size();
5258 
5259  // Check every mask dim size to see whether it can be dropped
5260  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5261  --i) {
5262  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5263  if (!constant || (constant.value() != 1))
5264  return failure();
5265  }
5266  SmallVector<Value> newMaskOperands =
5267  maskOperands.drop_back(numDimsToDrop);
5268 
5269  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5270  newMaskOperands);
5271  return success();
5272  }
5273 
5274  if (constantMaskOp) {
5275  auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5276  auto numMaskOperands = maskDimSizes.size();
5277 
5278  // Check every mask dim size to see whether it can be dropped
5279  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5280  --i) {
5281  if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5282  return failure();
5283  }
5284 
5285  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5286  ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
5287 
5288  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5289  newMaskOperandsAttr);
5290  return success();
5291  }
5292 
5293  return failure();
5294  }
5295 };
5296 
5297 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5298 /// This only applies when the shape of the broadcast source
5299 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5300 /// reshape is expressive enough to capture the result in a single op), or
5301 /// 2. has the same element count as the shape cast result.
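/// A minimal illustrative sketch for case 1: shape-casting the result of
/// `vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>` to
/// vector<6x4xf32> becomes a single
/// `vector.broadcast %v : vector<4xf32> to vector<6x4xf32>`, since [4] is a
/// suffix of [6, 4].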
5302 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5303 public:
5305 
5306  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5307  PatternRewriter &rewriter) const override {
5308  auto broadcastOp =
5309  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5310  if (!broadcastOp)
5311  return failure();
5312 
5313  ArrayRef<int64_t> broadcastSourceShape;
5314  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5315  broadcastSourceShape = srcType.getShape();
5316  ArrayRef<int64_t> shapeCastTargetShape =
5317  shapeCastOp.getResultVectorType().getShape();
5318 
5319  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5320  // with a broadcast to the final shape.
5321  if (broadcastSourceShape ==
5322  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5323  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5324  shapeCastOp, shapeCastOp.getResultVectorType(),
5325  broadcastOp.getSource());
5326  return success();
5327  }
5328 
5329  // Otherwise, if the final result has the same element count, we can replace
5330  // with a shape cast.
5331  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5332  if (srcType.getNumElements() ==
5333  shapeCastOp.getResultVectorType().getNumElements()) {
5334  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5335  shapeCastOp, shapeCastOp.getResultVectorType(),
5336  broadcastOp.getSource());
5337  return success();
5338  }
5339  }
5340 
5341  return failure();
5342  }
5343 };
5344 
5345 } // namespace
5346 
5347 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5348  MLIRContext *context) {
5349  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5350  ShapeCastBroadcastFolder>(context);
5351 }
5352 
5353 //===----------------------------------------------------------------------===//
5354 // VectorBitCastOp
5355 //===----------------------------------------------------------------------===//
5356 
5357 LogicalResult BitCastOp::verify() {
5358  auto sourceVectorType = getSourceVectorType();
5359  auto resultVectorType = getResultVectorType();
5360 
5361  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5362  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5363  return emitOpError("dimension size mismatch at: ") << i;
5364  }
5365 
5366  DataLayout dataLayout = DataLayout::closest(*this);
5367  auto sourceElementBits =
5368  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5369  auto resultElementBits =
5370  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5371 
5372  if (sourceVectorType.getRank() == 0) {
5373  if (sourceElementBits != resultElementBits)
5374  return emitOpError("source/result bitwidth of the 0-D vector element "
5375  "types must be equal");
5376  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5377  resultElementBits * resultVectorType.getShape().back()) {
5378  return emitOpError(
5379  "source/result bitwidth of the minor 1-D vectors must be equal");
5380  }
5381 
5382  return success();
5383 }
5384 
5385 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5386  // Nop cast.
5387  if (getSource().getType() == getResult().getType())
5388  return getSource();
5389 
5390  // Canceling bitcasts.
5391  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5392  if (getResult().getType() == otherOp.getSource().getType())
5393  return otherOp.getSource();
5394 
5395  setOperand(otherOp.getSource());
5396  return getResult();
5397  }
5398 
5399  Attribute sourceConstant = adaptor.getSource();
5400  if (!sourceConstant)
5401  return {};
5402 
5403  Type srcElemType = getSourceVectorType().getElementType();
5404  Type dstElemType = getResultVectorType().getElementType();
5405 
5406  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5407  if (floatPack.isSplat()) {
5408  auto splat = floatPack.getSplatValue<FloatAttr>();
5409 
5410  // Casting fp16 into fp32.
5411  if (srcElemType.isF16() && dstElemType.isF32()) {
5412  uint32_t bits = static_cast<uint32_t>(
5413  splat.getValue().bitcastToAPInt().getZExtValue());
5414  // Duplicate the 16-bit pattern.
5415  bits = (bits << 16) | (bits & 0xffff);
5416  APInt intBits(32, bits);
5417  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5418  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5419  }
5420  }
5421  }
5422 
5423  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5424  if (intPack.isSplat()) {
5425  auto splat = intPack.getSplatValue<IntegerAttr>();
5426 
5427  if (llvm::isa<IntegerType>(dstElemType)) {
5428  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5429  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5430 
5431  // Casting to a larger integer bit width.
5432  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5433  APInt intBits = splat.getValue().zext(dstBitWidth);
5434 
5435  // Duplicate the lower width element.
5436  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5437  intBits = (intBits << srcBitWidth) | intBits;
5438  return DenseElementsAttr::get(getResultVectorType(), intBits);
5439  }
5440  }
5441  }
5442  }
5443 
5444  return {};
5445 }
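// A worked sketch of the splat bit-duplication folds above (values are
// illustrative): bitcasting a splat vector<4xi8> of value 1 to vector<1xi32>
// yields the splat 0x01010101 (16843009), and bitcasting a splat vector<2xf16>
// of 1.0 (bit pattern 0x3C00) to vector<1xf32> yields the 32-bit pattern
// 0x3C003C00.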
5446 
5447 //===----------------------------------------------------------------------===//
5448 // TypeCastOp
5449 //===----------------------------------------------------------------------===//
5450 
5451 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5452  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5453  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
5454  memRefType.getShape().end());
5455  if (vectorType)
5456  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5457  return res;
5458 }
5459 
5460 /// Build the canonical memRefType with a single vector.
5461 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5462 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5463  Value source) {
5464  result.addOperands(source);
5465  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5466  VectorType vectorType =
5467  VectorType::get(extractShape(memRefType),
5468  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5469  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5470  memRefType.getMemorySpace()));
5471 }
5472 
5473 LogicalResult TypeCastOp::verify() {
5474  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5475  if (!canonicalType.getLayout().isIdentity())
5476  return emitOpError("expects operand to be a memref with identity layout");
5477  if (!getResultMemRefType().getLayout().isIdentity())
5478  return emitOpError("expects result to be a memref with identity layout");
5479  if (getResultMemRefType().getMemorySpace() !=
5480  getMemRefType().getMemorySpace())
5481  return emitOpError("expects result in same memory space");
5482 
5483  auto sourceType = getMemRefType();
5484  auto resultType = getResultMemRefType();
5485  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5486  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5487  return emitOpError(
5488  "expects result and operand with same underlying scalar type: ")
5489  << resultType;
5490  if (extractShape(sourceType) != extractShape(resultType))
5491  return emitOpError(
5492  "expects concatenated result and operand shapes to be equal: ")
5493  << resultType;
5494  return success();
5495 }
5496 
5497 //===----------------------------------------------------------------------===//
5498 // TransposeOp
5499 //===----------------------------------------------------------------------===//
5500 
5501 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5502  Value vector, ArrayRef<int64_t> permutation) {
5503  VectorType vt = llvm::cast<VectorType>(vector.getType());
5504  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5505  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5506  for (unsigned i = 0; i < permutation.size(); ++i) {
5507  transposedShape[i] = vt.getShape()[permutation[i]];
5508  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5509  }
5510 
5511  result.addOperands(vector);
5512  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5513  transposedScalableDims));
5514  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5515  builder.getDenseI64ArrayAttr(permutation));
5516 }
5517 
5518 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5519  // Eliminate splat constant transpose ops.
5520  if (auto attr =
5521  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5522  if (attr.isSplat())
5523  return attr.reshape(getResultVectorType());
5524 
5525  // Eliminate identity transpose ops. This happens when the dimensions of the
5526  // input vector remain in their original order after the transpose operation.
5527  ArrayRef<int64_t> perm = getPermutation();
5528 
5529  // Check if the permutation of the dimensions contains sequential values:
5530  // {0, 1, 2, ...}.
5531  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5532  if (perm[i] != i)
5533  return {};
5534  }
5535 
5536  return getVector();
5537 }
5538 
5539 LogicalResult vector::TransposeOp::verify() {
5540  VectorType vectorType = getSourceVectorType();
5541  VectorType resultType = getResultVectorType();
5542  int64_t rank = resultType.getRank();
5543  if (vectorType.getRank() != rank)
5544  return emitOpError("vector result rank mismatch: ") << rank;
5545  // Verify transposition array.
5546  ArrayRef<int64_t> perm = getPermutation();
5547  int64_t size = perm.size();
5548  if (rank != size)
5549  return emitOpError("transposition length mismatch: ") << size;
5550  SmallVector<bool, 8> seen(rank, false);
5551  for (const auto &ta : llvm::enumerate(perm)) {
5552  if (ta.value() < 0 || ta.value() >= rank)
5553  return emitOpError("transposition index out of range: ") << ta.value();
5554  if (seen[ta.value()])
5555  return emitOpError("duplicate position index: ") << ta.value();
5556  seen[ta.value()] = true;
5557  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5558  return emitOpError("dimension size mismatch at: ") << ta.value();
5559  }
5560  return success();
5561 }
5562 
5563 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5564  return llvm::to_vector<4>(getResultVectorType().getShape());
5565 }
5566 
5567 namespace {
5568 
5569 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
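// A worked sketch of the composition below, result[i] = p1[p2[i]]: for
//   %a = vector.transpose %v, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
//   %b = vector.transpose %a, [2, 0, 1] : vector<3x4x2xf32> to vector<2x3x4xf32>
// the composed permutation is the identity [0, 1, 2], so the pair ultimately
// folds away to %v.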
5570 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5571 public:
5573 
5574  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5575  PatternRewriter &rewriter) const override {
5576  // Composes two permutations: result[i] = permutation1[permutation2[i]].
5577  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5578  ArrayRef<int64_t> permutation2) {
5579  SmallVector<int64_t, 4> result;
5580  for (auto index : permutation2)
5581  result.push_back(permutation1[index]);
5582  return result;
5583  };
5584 
5585  // Return if the input of 'transposeOp' is not defined by another transpose.
5586  vector::TransposeOp parentTransposeOp =
5587  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5588  if (!parentTransposeOp)
5589  return failure();
5590 
5591  SmallVector<int64_t, 4> permutation = composePermutations(
5592  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5593  // Replace 'transposeOp' with a new transpose operation.
5594  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5595  transposeOp, transposeOp.getResult().getType(),
5596  parentTransposeOp.getVector(), permutation);
5597  return success();
5598  }
5599 };
5600 
5601 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
5602 struct FoldTransposedScalarBroadcast final
5603  : public OpRewritePattern<vector::TransposeOp> {
5605 
5606  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5607  PatternRewriter &rewriter) const override {
5608  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5609  if (!bcastOp)
5610  return failure();
5611 
5612  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5613  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5614  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5615  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5616  return success();
5617  }
5618 
5619  return failure();
5620  }
5621 };
5622 
5623 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5624 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5625 public:
5627 
5628  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5629  PatternRewriter &rewriter) const override {
5630  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5631  if (!splatOp)
5632  return failure();
5633 
5634  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5635  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5636  return success();
5637  }
5638 };
5639 
5640 /// Folds transpose(create_mask) into a new transposed create_mask.
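/// A minimal illustrative sketch: transposing
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
/// with permutation [1, 0] rewrites to
///   %m = vector.create_mask %b, %a : vector<8x4xi1>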
5641 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5642 public:
5644 
5645  LogicalResult matchAndRewrite(TransposeOp transpOp,
5646  PatternRewriter &rewriter) const override {
5647  Value transposeSrc = transpOp.getVector();
5648  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5649  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5650  if (!createMaskOp && !constantMaskOp)
5651  return failure();
5652 
5653  // Get the transpose permutation and apply it to the vector.create_mask or
5654  // vector.constant_mask operands.
5655  ArrayRef<int64_t> permutation = transpOp.getPermutation();
5656 
5657  if (createMaskOp) {
5658  auto maskOperands = createMaskOp.getOperands();
5659  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
5660  applyPermutationToVector(newOperands, permutation);
5661 
5662  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5663  transpOp, transpOp.getResultVectorType(), newOperands);
5664  return success();
5665  }
5666 
5667  // ConstantMaskOp case.
5668  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5669  SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
5670  applyPermutationToVector(newMaskDimSizes, permutation);
5671 
5672  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5673  transpOp, transpOp.getResultVectorType(),
5674  ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
5675  return success();
5676  }
5677 };
5678 
5679 } // namespace
5680 
5681 void vector::TransposeOp::getCanonicalizationPatterns(
5682  RewritePatternSet &results, MLIRContext *context) {
5683  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5684  TransposeFolder, FoldTransposeSplat>(context);
5685 }
5686 
5687 //===----------------------------------------------------------------------===//
5688 // ConstantMaskOp
5689 //===----------------------------------------------------------------------===//
5690 
5691 LogicalResult ConstantMaskOp::verify() {
5692  auto resultType = llvm::cast<VectorType>(getResult().getType());
5693  // Check the corner case of 0-D vectors first.
5694  if (resultType.getRank() == 0) {
5695  if (getMaskDimSizes().size() != 1)
5696  return emitError("array attr must have length 1 for 0-D vectors");
5697  auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5698  if (dim != 0 && dim != 1)
5699  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5700  return success();
5701  }
5702 
5703  // Verify that array attr size matches the rank of the vector result.
5704  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5705  return emitOpError(
5706  "must specify array attr of size equal vector result rank");
5707  // Verify that each array attr element is in bounds of corresponding vector
5708  // result dimension size.
5709  auto resultShape = resultType.getShape();
5710  auto resultScalableDims = resultType.getScalableDims();
5711  SmallVector<int64_t, 4> maskDimSizes;
5712  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
5713  int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5714  if (maskDimSize < 0 || maskDimSize > resultShape[index])
5715  return emitOpError(
5716  "array attr of size out of bounds of vector result dimension size");
5717  if (resultScalableDims[index] && maskDimSize != 0 &&
5718  maskDimSize != resultShape[index])
5719  return emitOpError(
5720  "only supports 'none set' or 'all set' scalable dimensions");
5721  maskDimSizes.push_back(maskDimSize);
5722  }
5723  // Verify that if any mask dim size is zero, then all of them are zero (the
5724  // mask region is the conjunction of the per-dimension intervals).
5725  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5726  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5727  if (anyZeros && !allZeros)
5728  return emitOpError("expected all mask dim sizes to be zeros, "
5729  "as a result of conjunction with zero mask dim");
5730  return success();
5731 }
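// For illustration, a few cases the verifier above accepts or rejects (a
// sketch, not exhaustive):
//   vector.constant_mask [2, 3] : vector<4x3xi1>    // ok: sizes are in bounds
//   vector.constant_mask [2, 0] : vector<4x3xi1>    // error: a zero dim size
//                                                   // forces all dims to zero
//   vector.constant_mask [4, 8] : vector<4x[8]xi1>  // ok: scalable dim "all set"
//   vector.constant_mask [4, 3] : vector<4x[8]xi1>  // error: scalable dim must
//                                                   // be 0 ("none") or 8 ("all")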
5732 
5733 bool ConstantMaskOp::isAllOnesMask() {
5734  auto resultType = getVectorType();
5735  // Check the corner case of 0-D vectors first.
5736  if (resultType.getRank() == 0) {
5737  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5738  return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5739  }
5740  for (const auto [resultSize, intAttr] :
5741  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5742  int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5743  if (maskDimSize < resultSize)
5744  return false;
5745  }
5746  return true;
5747 }
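// For example, per the check above, isAllOnesMask() returns true for
//   vector.constant_mask [4, 3] : vector<4x3xi1>
// and false for
//   vector.constant_mask [4, 2] : vector<4x3xi1>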
5748 
5749 //===----------------------------------------------------------------------===//
5750 // CreateMaskOp
5751 //===----------------------------------------------------------------------===//
5752 
5753 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
5754  VectorType type,
5755  ArrayRef<OpFoldResult> mixedOperands) {
5756  SmallVector<Value> operands =
5757  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
5758  build(builder, result, type, operands);
5759 }
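// A minimal usage sketch of the mixed-operand builder above (the OpBuilder
// `b`, Location `loc`, and index-typed Value `ub` are hypothetical,
// caller-provided values):
//
//   VectorType maskTy = VectorType::get({8}, b.getI1Type());
//   SmallVector<OpFoldResult> bounds = {OpFoldResult(ub)};
//   Value mask = b.create<vector::CreateMaskOp>(loc, maskTy, bounds);
//
// Constant bounds may instead be passed as attributes (e.g. b.getIndexAttr(5));
// the builder materializes arith.constant index ops for them as needed.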
5760 
5761 LogicalResult CreateMaskOp::verify() {
5762  auto vectorType = llvm::cast<VectorType>(getResult().getType());
5763  // Verify that an operand was specified for each result vector dimension.
5764  if (vectorType.getRank() == 0) {
5765  if (getNumOperands() != 1)
5766  return emitOpError(
5767  "must specify exactly one operand for 0-D create_mask");
5768  } else if (getNumOperands() !=
5769  llvm::cast<VectorType>(getResult().getType()).getRank()) {
5770  return emitOpError(
5771  "must specify an operand for each result vector dimension");
5772  }
5773  return success();
5774 }
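// For illustration (a sketch), the verifier above accepts
//   vector.create_mask %c3, %c2 : vector<4x3xi1>
// (one operand per result dimension) and rejects
//   vector.create_mask %c3 : vector<4x3xi1>
// (operand count does not match the result rank).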
5775 
5776 namespace {
5777 
5778 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5779 ///
5780 /// Ex 1:
5781 /// %c2 = arith.constant 2 : index
5782 /// %c3 = arith.constant 3 : index
5783 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5784 /// Becomes:
5785 /// vector.constant_mask [3, 2] : vector<4x3xi1>
5786 ///
5787 /// Ex 2:
5788 /// %c_neg_1 = arith.constant -1 : index
5789 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5790 /// becomes:
5791 /// vector.constant_mask [0] : vector<[8]xi1>
5792 ///
5793 /// Ex 3:
5794 /// %c8 = arith.constant 8 : index
5795 /// %c16 = arith.constant 16 : index
5796 /// %0 = vector.vscale
5797 /// %1 = arith.muli %0, %c16 : index
5798 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5799 /// becomes:
5800 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
5801 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5802 public:
5803  using OpRewritePattern::OpRewritePattern;
5804 
5805  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5806  PatternRewriter &rewriter) const override {
5807  VectorType retTy = createMaskOp.getResult().getType();
5808  bool isScalable = retTy.isScalable();
5809 
5810  // Check every mask operand
5811  for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
5812  if (auto cst = getConstantIntValue(operand)) {
5813  // Most basic case - this operand is a constant value. Note that for
5814  // scalable dimensions, CreateMaskOp can be folded only if the
5815  // corresponding operand is negative or zero.
5816  if (retTy.getScalableDims()[opIdx] && *cst > 0)
5817  return failure();
5818 
5819  continue;
5820  }
5821 
5822  // Non-constant operands are not allowed for non-scalable vectors.
5823  if (!isScalable)
5824  return failure();
5825 
5826  // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5827  // true" mask, so it can also be treated as a constant.
5828  auto mul = operand.getDefiningOp<arith::MulIOp>();
5829  if (!mul)
5830  return failure();
5831  auto mulLHS = mul.getLhs();
5832  auto mulRHS = mul.getRhs();
5833  bool isOneOpVscale =
5834  (mulLHS.getDefiningOp<vector::VectorScaleOp>() ||
5835  mulRHS.getDefiningOp<vector::VectorScaleOp>());
5836 
5837  auto isConstantValMatchingDim =
5838  [=, dim = retTy.getShape()[opIdx]](Value operand) {
5839  auto constantVal = getConstantIntValue(operand);
5840  return (constantVal.has_value() && constantVal.value() == dim);
5841  };
5842 
5843  bool isOneOpConstantMatchingDim =
5844  isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5845 
5846  if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5847  return failure();
5848  }
5849 
5850  // Gather constant mask dimension sizes.
5851  SmallVector<int64_t, 4> maskDimSizes;
5852  maskDimSizes.reserve(createMaskOp->getNumOperands());
5853  for (auto [operand, maxDimSize] : llvm::zip_equal(
5854  createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5855  std::optional<int64_t> dimSize = getConstantIntValue(operand);
5856  if (!dimSize) {
5857  // Although not a constant, it is safe to assume that `operand` is
5858  // "vscale * maxDimSize".
5859  maskDimSizes.push_back(maxDimSize);
5860  continue;
5861  }
5862  int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
5863  // If one of the dim sizes is zero or negative, set all dims to zero.
5864  if (dimSize <= 0) {
5865  maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5866  break;
5867  }
5868  maskDimSizes.push_back(dimSizeVal);
5869  }
5870 
5871  // Replace 'createMaskOp' with ConstantMaskOp.
5872  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
5873  createMaskOp, retTy,
5874  vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
5875  return success();
5876  }
5877 };
5878 
5879 } // namespace
5880 
5881 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
5882  MLIRContext *context) {
5883  results.add<CreateMaskFolder>(context);
5884 }
5885 
5886 //===----------------------------------------------------------------------===//
5887 // MaskOp
5888 //===----------------------------------------------------------------------===//
5889 
5890 void MaskOp::build(
5891  OpBuilder &builder, OperationState &result, Value mask,
5892  Operation *maskableOp,
5893  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5894  assert(maskRegionBuilder &&
5895  "builder callback for 'maskRegion' must be present");
5896 
5897  result.addOperands(mask);
5898  OpBuilder::InsertionGuard guard(builder);
5899  Region *maskRegion = result.addRegion();
5900  builder.createBlock(maskRegion);
5901  maskRegionBuilder(builder, maskableOp);
5902 }
5903 
5904 void MaskOp::build(
5905  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5906  Value mask, Operation *maskableOp,
5907  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5908  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
5909  maskRegionBuilder);
5910 }
5911 
5912 void MaskOp::build(
5913  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5914  Value mask, Value passthru, Operation *maskableOp,
5915  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5916  build(builder, result, mask, maskableOp, maskRegionBuilder);
5917  if (passthru)
5918  result.addOperands(passthru);
5919  result.addTypes(resultTypes);
5920 }
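// A usage sketch for the builders above (hypothetical: `b` is an OpBuilder,
// `loc` a Location, `mask` an i1-vector Value, and `maskableOp` the operation
// to be masked). The region builder callback populates the mask region,
// typically including its vector.yield terminator; one common choice is to
// clone the maskable op into the region and yield its results, after which the
// caller replaces the original op with the mask op's results:
//
//   auto maskOp = b.create<vector::MaskOp>(
//       loc, maskableOp->getResultTypes(), mask, maskableOp,
//       [](OpBuilder &nested, Operation *op) {
//         Operation *cloned = nested.clone(*op);
//         nested.create<vector::YieldOp>(op->getLoc(), cloned->getResults());
//       });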
5921 
5922 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
5923  // Create the op region.
5924  result.regions.reserve(1);
5925  Region &maskRegion = *result.addRegion();
5926 
5927  auto &builder = parser.getBuilder();
5928 
5929  // Parse all the operands.
5930  OpAsmParser::UnresolvedOperand mask;
5931  if (parser.parseOperand(mask))
5932  return failure();
5933 
5934  // Optional passthru operand.
5935  OpAsmParser::UnresolvedOperand passthru;
5936  ParseResult parsePassthru = parser.parseOptionalComma();
5937  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
5938  return failure();
5939 
5940