VectorOps.cpp
1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/IR/ValueRange.h"
39 #include "mlir/Support/LLVM.h"
41 #include "llvm/ADT/ArrayRef.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/StringSet.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Casting.h"
47 
48 #include <cassert>
49 #include <cstdint>
50 #include <numeric>
51 
52 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
53 // Pull in all enum type and utility function definitions.
54 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
55 
56 using namespace mlir;
57 using namespace mlir::vector;
58 
59 /// Helper enum to classify mask value.
60 enum class MaskFormat {
61  AllTrue = 0,
62  AllFalse = 1,
63  Unknown = 2,
64 };
65 
66 /// Helper method to classify a mask value. Currently, the method
67 /// looks "under the hood" of a constant value with dense attributes
68 /// and a constant mask operation (since the client may be called at
69 /// various stages during progressive lowering).
70 static MaskFormat getMaskFormat(Value mask) {
71  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
72  // Inspect constant dense values. We count up for bits that
73  // are set, count down for bits that are cleared, and bail
74  // when a mix is detected.
75  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
76  int64_t val = 0;
77  for (bool b : denseElts.getValues<bool>())
78  if (b && val >= 0)
79  val++;
80  else if (!b && val <= 0)
81  val--;
82  else
83  return MaskFormat::Unknown;
84  if (val > 0)
85  return MaskFormat::AllTrue;
86  if (val < 0)
87  return MaskFormat::AllFalse;
88  }
89  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
90  // Inspect constant mask index. If the index exceeds the
91  // dimension size, all bits are set. If the index is zero
92  // or less, no bits are set.
93  ArrayRef<int64_t> masks = m.getMaskDimSizes();
94  auto shape = m.getType().getShape();
95  bool allTrue = true;
96  bool allFalse = true;
97  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
98  if (maskIdx < dimSize)
99  allTrue = false;
100  if (maskIdx > 0)
101  allFalse = false;
102  }
103  if (allTrue)
104  return MaskFormat::AllTrue;
105  if (allFalse)
106  return MaskFormat::AllFalse;
107  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
108  // Finds all-false create_masks. An all-true create_mask requires all
109  // dims to be constants, so that'll be folded to a constant_mask, then
110  // detected in the constant_mask case.
111  auto maskOperands = m.getOperands();
112  for (Value operand : maskOperands) {
113  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
114  int64_t dimSize =
115  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
116  if (dimSize <= 0)
117  return MaskFormat::AllFalse;
118  }
119  }
120  return MaskFormat::Unknown;
121  }
122  return MaskFormat::Unknown;
123 }
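// Illustrative sketch (not from the original source), assuming standard
// vector-dialect assembly syntax, of how typical masks are classified:
//
//   %t = vector.constant_mask [4] : vector<4xi1>  // -> MaskFormat::AllTrue
//   %f = vector.constant_mask [0] : vector<4xi1>  // -> MaskFormat::AllFalse
//   %c0 = arith.constant 0 : index
//   %z = vector.create_mask %c0 : vector<4xi1>    // -> MaskFormat::AllFalse
//   %u = vector.create_mask %n : vector<4xi1>     // -> MaskFormat::Unknown (%n not constant)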
124 
125 /// Default callback to build a region with a 'vector.yield' terminator with no
126 /// arguments.
127 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
128  vector::YieldOp::create(builder, loc);
129 }
130 
131 // Helper for verifying combining kinds in contractions and reductions.
132 static bool isSupportedCombiningKind(CombiningKind combiningKind,
133  Type elementType) {
134  switch (combiningKind) {
135  case CombiningKind::ADD:
136  case CombiningKind::MUL:
137  return elementType.isIntOrIndexOrFloat();
138  case CombiningKind::MINUI:
139  case CombiningKind::MINSI:
140  case CombiningKind::MAXUI:
141  case CombiningKind::MAXSI:
142  case CombiningKind::AND:
143  case CombiningKind::OR:
144  case CombiningKind::XOR:
145  return elementType.isIntOrIndex();
146  case CombiningKind::MINNUMF:
147  case CombiningKind::MAXNUMF:
148  case CombiningKind::MINIMUMF:
149  case CombiningKind::MAXIMUMF:
150  return llvm::isa<FloatType>(elementType);
151  }
152  return false;
153 }
154 
155 /// Returns the effective rank of the vector to read/write for Xfer Ops
156 ///
157 /// When the element type of the shaped type is _a scalar_, this will simply
158 /// return the rank of the vector (the result for xfer_read or the value to
159 /// store for xfer_write).
160 ///
161 /// When the element type of the base shaped type is _a vector_, returns the
162 /// difference between the original vector type and the element type of the
163 /// shaped type.
164 ///
165 /// EXAMPLE 1 (element type is _a scalar_):
166 /// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167 /// - shapedType.getElementType() = f32 (rank 0)
168 /// - vectorType.getRank() = 2
169 /// - Result = 2 - 0 = 2
170 ///
171 /// EXAMPLE 2 (element type is _a vector_):
172 /// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
173 /// - shapedType.getElementType() = vector<20xf32> (rank 1)
174 /// - vectorType.getRank() = 1
175 /// - Result = 1 - 1 = 0
176 ///
177 /// This is used to determine the number of minor dimensions for identity maps
178 /// in vector transfer Ops.
179 static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
180  VectorType vectorType) {
181  unsigned elementVectorRank = 0;
182  VectorType elementVectorType =
183  llvm::dyn_cast<VectorType>(shapedType.getElementType());
184  if (elementVectorType)
185  elementVectorRank += elementVectorType.getRank();
186  return vectorType.getRank() - elementVectorRank;
187 }
188 
189 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
190  VectorType vectorType) {
191  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
192  // TODO: replace once we have 0-d vectors.
193  if (shapedType.getRank() == 0 &&
194  vectorType.getShape() == ArrayRef<int64_t>{1})
195  return AffineMap::get(
196  /*numDims=*/0, /*numSymbols=*/0,
197  getAffineConstantExpr(0, shapedType.getContext()));
198  return AffineMap::getMinorIdentityMap(
199  shapedType.getRank(),
200  getEffectiveVectorRankForXferOp(shapedType, vectorType),
201  shapedType.getContext());
202 }
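// For illustration (an assumed example, not from the original source): for a
// transfer between memref<10x20x30xf32> and vector<2x4xf32> the effective
// vector rank is 2, so the minor identity map returned above is
//   (d0, d1, d2) -> (d1, d2)
// while the 0-d special case (e.g. memref<f32> with vector<1xf32>) yields
//   () -> (0)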
203 
204 /// Check if `write` is of a constant splat and the masked `read` is padded with
205 /// the same splat value -- meaning it could be the same value as the initial
206 /// constant splat.
207 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
208  vector::TransferReadOp read) {
209  auto readMask = read.getMask();
210  auto writeMask = write.getMask();
211  // Check if the masks are consistent. The splat value could be the same if the
212  // read is masked (and padded with the splat value), and the write is unmasked
213  // or has the same mask. Note this does not allow the case where the write is
214  // masked and the read is unmasked, as then the read could be of more elements
215  // than the write (which may not be the same value).
216  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
217  if (!couldBeSameSplat)
218  return false;
219  // Check for constant splat (as the source of the write).
220  DenseElementsAttr splatAttr;
221  if (!matchPattern(write.getVector(),
222  m_Constant<DenseElementsAttr>(&splatAttr)) ||
223  !splatAttr.isSplat()) {
224  return false;
225  }
226  // The padding of the read and the constant splat value must be the same.
227  Attribute padAttr;
228  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
229  return false;
230  return padAttr == splatAttr.getSplatValue<Attribute>();
231 }
232 
233 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
234  vector::TransferReadOp read) {
235  return !defWrite.hasOutOfBoundsDim() &&
236  defWrite.getIndices() == read.getIndices() &&
237  defWrite.getVectorType() == read.getVectorType() &&
238  defWrite.getPermutationMap() == read.getPermutationMap() &&
239  ((!defWrite.getMask() && !read.getMask()) ||
240  isSplatWriteConsistentWithMaskedRead(defWrite, read));
241 }
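// A minimal RAW sketch (hypothetical IR, assuming standard transfer syntax)
// where the read returns the value just written and can be forwarded:
//
//   vector.transfer_write %v, %buf[%i, %j] {in_bounds = [true]}
//       : vector<4xf32>, memref<?x?xf32>
//   %r = vector.transfer_read %buf[%i, %j], %pad {in_bounds = [true]}
//       : memref<?x?xf32>, vector<4xf32>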
242 
243 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
244  vector::TransferWriteOp priorWrite) {
245  return priorWrite.getIndices() == write.getIndices() &&
246  priorWrite.getMask() == write.getMask() &&
247  priorWrite.getVectorType() == write.getVectorType() &&
248  priorWrite.getPermutationMap() == write.getPermutationMap();
249 }
250 
251 bool mlir::vector::isDisjointTransferIndices(
252  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
253  bool testDynamicValueUsingBounds) {
254  // For simplicity only look at transfer of same type.
255  if (transferA.getVectorType() != transferB.getVectorType())
256  return false;
257  unsigned rankOffset = transferA.getLeadingShapedRank();
258  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
259  Value indexA = transferA.getIndices()[i];
260  Value indexB = transferB.getIndices()[i];
261  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
262  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
263 
264  if (i < rankOffset) {
265  // For leading dimensions, if we can prove that the indices are different we
266  // know we are accessing disjoint slices.
267  if (cstIndexA.has_value() && cstIndexB.has_value()) {
268  if (*cstIndexA != *cstIndexB)
269  return true;
270  continue;
271  }
272  if (testDynamicValueUsingBounds) {
273  // First try to see if we can fully compose and simplify the affine
274  // expression as a fast track.
275  FailureOr<uint64_t> delta =
276  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
277  if (succeeded(delta) && *delta != 0)
278  return true;
279 
280  FailureOr<bool> testEqual =
281  ValueBoundsConstraintSet::areEqual(indexA, indexB);
282  if (succeeded(testEqual) && !testEqual.value())
283  return true;
284  }
285  } else {
286  // For this dimension, we slice a part of the memref, so we need to make
287  // sure the intervals accessed don't overlap.
288  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
289  if (cstIndexA.has_value() && cstIndexB.has_value()) {
290  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
291  if (distance >= vectorDim)
292  return true;
293  continue;
294  }
295  if (testDynamicValueUsingBounds) {
296  // First try to see if we can fully compose and simplify the affine
297  // expression as a fast track.
298  FailureOr<int64_t> delta =
299  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
300  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
301  return true;
302 
303  FailureOr<int64_t> computeDelta =
304  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
305  if (succeeded(computeDelta)) {
306  if (std::abs(computeDelta.value()) >= vectorDim)
307  return true;
308  }
309  }
310  }
311  }
312  return false;
313 }
314 
315 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
316  VectorTransferOpInterface transferB,
317  bool testDynamicValueUsingBounds) {
318  if (transferA.getBase() != transferB.getBase())
319  return false;
320  return isDisjointTransferIndices(transferA, transferB,
321  testDynamicValueUsingBounds);
322 }
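// Worked example (illustrative assumption): for two 1-D transfers of
// vector<4xf32> on the same base, constant indices 0 and 4 access [0, 4) and
// [4, 8); the distance (4) is >= the vector dimension (4), so the transfers
// are reported disjoint, whereas indices 0 and 2 would overlap.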
323 
324 // Helper to iterate over n-D vector slice elements. Calculate the next
325 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
326 // Modifies the `position` in place. Returns a failure when `position` becomes
327 // the end position.
328 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
329  ArrayRef<int64_t> shape,
330  ArrayRef<int64_t> offsets) {
331  for (auto [posInDim, dimSize, offsetInDim] :
332  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
333  ++posInDim;
334  if (posInDim < dimSize + offsetInDim)
335  return success();
336 
337  // Carry the overflow to the next loop iteration.
338  posInDim = offsetInDim;
339  }
340 
341  return failure();
342 }
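// Worked example: with shape = [2, 3] and offsets = [0, 0], starting from
// position [0, 0] successive calls yield [0, 1], [0, 2], [1, 0], [1, 1],
// [1, 2], and the next call returns failure() once the end is reached.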
343 
344 /// Returns the integer numbers in `values`. `values` are expected to be
345 /// constant operations.
346 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
347  SmallVector<int64_t> ints;
348  llvm::transform(values, std::back_inserter(ints), [](Value value) {
349  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
350  assert(constOp && "Unexpected non-constant index");
351  return constOp.value();
352  });
353  return ints;
354 }
355 
356 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
357 /// be constant operations.
358 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
359  SmallVector<int64_t> ints;
360  llvm::transform(
361  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
362  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
363  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
364  });
365  return ints;
366 }
367 
368 /// Convert `foldResults` into Values. Integer attributes are converted to
369 /// constant op.
370 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
371  ArrayRef<OpFoldResult> foldResults) {
372  SmallVector<Value> values;
373  llvm::transform(foldResults, std::back_inserter(values),
374  [&](OpFoldResult foldResult) {
375  if (auto attr = dyn_cast<Attribute>(foldResult))
376  return arith::ConstantIndexOp::create(
377  builder, loc, cast<IntegerAttr>(attr).getInt())
378  .getResult();
379 
380  return cast<Value>(foldResult);
381  });
382  return values;
383 }
384 
385 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
386  if (value.getDefiningOp<vector::VectorScaleOp>())
387  return 1;
388  auto mul = value.getDefiningOp<arith::MulIOp>();
389  if (!mul)
390  return {};
391  auto lhs = mul.getLhs();
392  auto rhs = mul.getRhs();
393  if (lhs.getDefiningOp<vector::VectorScaleOp>())
394  return getConstantIntValue(rhs);
395  if (rhs.getDefiningOp<vector::VectorScaleOp>())
396  return getConstantIntValue(lhs);
397  return {};
398 }
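// Illustrative sketch (assumed IR, not from the original source):
//
//   %vs = vector.vscale
//   %c8 = arith.constant 8 : index
//   %n  = arith.muli %vs, %c8 : index
//
// getConstantVscaleMultiplier(%n) returns 8, getConstantVscaleMultiplier(%vs)
// returns 1, and any other value yields std::nullopt.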
399 
400 /// Converts numeric attributes to the expected type. Supports
401 /// integer-to-integer and float-to-integer conversions. Returns the original
402 /// attribute if no conversion is needed or supported.
403 static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
404  // Integer-to-integer conversion
405  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
406  if (auto intType = dyn_cast<IntegerType>(expectedType)) {
407  if (intAttr.getType() != expectedType)
408  return IntegerAttr::get(expectedType, intAttr.getInt());
409  }
410  return attr;
411  }
412 
413  // Float-to-integer bitcast (preserves bit representation)
414  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
415  auto intType = dyn_cast<IntegerType>(expectedType);
416  if (!intType)
417  return attr;
418 
419  APFloat floatVal = floatAttr.getValue();
420  APInt intVal = floatVal.bitcastToAPInt();
421  return IntegerAttr::get(expectedType, intVal);
422  }
423 
424  return attr;
425 }
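// Worked example: an IntegerAttr `5 : i64` with expectedType i32 becomes
// `5 : i32`, and a FloatAttr `1.0 : f32` with expectedType i32 becomes the
// bitcast pattern `0x3F800000 : i32`; any other combination is returned
// unchanged.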
426 
427 //===----------------------------------------------------------------------===//
428 // CombiningKindAttr
429 //===----------------------------------------------------------------------===//
430 
431 namespace mlir {
432 namespace vector {
433 namespace detail {
434 struct BitmaskEnumStorage : public AttributeStorage {
435  using KeyTy = uint64_t;
436 
437  BitmaskEnumStorage(KeyTy val) : value(val) {}
438 
439  bool operator==(const KeyTy &key) const { return value == key; }
440 
441  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
442  const KeyTy &key) {
443  return new (allocator.allocate<BitmaskEnumStorage>())
444  BitmaskEnumStorage(key);
445  }
446 
447  KeyTy value = 0;
448 };
449 } // namespace detail
450 } // namespace vector
451 } // namespace mlir
452 
453 //===----------------------------------------------------------------------===//
454 // VectorDialect
455 //===----------------------------------------------------------------------===//
456 
457 namespace {
458 /// This class defines the interface for handling inlining with vector dialect
459 /// operations.
460 struct VectorInlinerInterface : public DialectInlinerInterface {
461  using DialectInlinerInterface::DialectInlinerInterface;
462 
463  /// All vector dialect ops can be inlined.
464  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
465  return true;
466  }
467 };
468 } // namespace
469 
470 void VectorDialect::initialize() {
471  addAttributes<
472 #define GET_ATTRDEF_LIST
473 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
474  >();
475 
476  addOperations<
477 #define GET_OP_LIST
478 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
479  >();
480 
481  addInterfaces<VectorInlinerInterface>();
482 
483  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
484  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
485  YieldOp>();
486  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
487  TransferWriteOp>();
488  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
489  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
490  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
491 }
492 
493 /// Materialize a single constant operation from a given attribute value with
494 /// the desired resultant type.
495 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
496  Attribute value, Type type,
497  Location loc) {
498  if (isa<ub::PoisonAttrInterface>(value))
499  return value.getDialect().materializeConstant(builder, value, type, loc);
500 
501  return arith::ConstantOp::materialize(builder, value, type, loc);
502 }
503 
504 IntegerType vector::getVectorSubscriptType(Builder &builder) {
505  return builder.getIntegerType(64);
506 }
507 
508 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
509  ArrayRef<int64_t> values) {
510  return builder.getI64ArrayAttr(values);
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // MultiDimReductionOp
515 //===----------------------------------------------------------------------===//
516 
517 void vector::MultiDimReductionOp::build(OpBuilder &builder,
518  OperationState &result, Value source,
519  Value acc, ArrayRef<bool> reductionMask,
520  CombiningKind kind) {
521  SmallVector<int64_t> reductionDims;
522  for (const auto &en : llvm::enumerate(reductionMask))
523  if (en.value())
524  reductionDims.push_back(en.index());
525  build(builder, result, kind, source, acc, reductionDims);
526 }
527 
528 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
529  // Single parallel dim, this is a noop.
530  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
531  return getSource();
532  return {};
533 }
534 
535 std::optional<SmallVector<int64_t, 4>>
536 MultiDimReductionOp::getShapeForUnroll() {
537  return llvm::to_vector<4>(getSourceVectorType().getShape());
538 }
539 
540 LogicalResult MultiDimReductionOp::verify() {
541  SmallVector<int64_t> targetShape;
542  SmallVector<bool> scalableDims;
543  Type inferredReturnType;
544  auto sourceScalableDims = getSourceVectorType().getScalableDims();
545  for (auto [dimIdx, dimSize] :
546  llvm::enumerate(getSourceVectorType().getShape()))
547  if (!llvm::any_of(getReductionDims(),
548  [dimIdx = dimIdx](int64_t reductionDimIdx) {
549  return reductionDimIdx == static_cast<int64_t>(dimIdx);
550  })) {
551  targetShape.push_back(dimSize);
552  scalableDims.push_back(sourceScalableDims[dimIdx]);
553  }
554  // TODO: update to also allow 0-d vectors when available.
555  if (targetShape.empty())
556  inferredReturnType = getSourceVectorType().getElementType();
557  else
558  inferredReturnType = VectorType::get(
559  targetShape, getSourceVectorType().getElementType(), scalableDims);
560  if (getType() != inferredReturnType)
561  return emitOpError() << "destination type " << getType()
562  << " is incompatible with source type "
563  << getSourceVectorType();
564 
565  return success();
566 }
567 
568 /// Returns the mask type expected by this operation.
569 Type MultiDimReductionOp::getExpectedMaskType() {
570  auto vecType = getSourceVectorType();
571  return VectorType::get(vecType.getShape(),
572  IntegerType::get(vecType.getContext(), /*width=*/1),
573  vecType.getScalableDims());
574 }
575 
576 namespace {
577 // Only unit dimensions that are being reduced are folded. If the dimension is
578 // unit, but not reduced, it is not folded, thereby keeping the output type the
579 // same. If not all dimensions which are reduced are of unit dimension, this
580 // transformation does nothing. This is just a generalization of
581 // ElideSingleElementReduction for ReduceOp.
582 struct ElideUnitDimsInMultiDimReduction
583  : public OpRewritePattern<MultiDimReductionOp> {
584  using Base::Base;
585 
586  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
587  PatternRewriter &rewriter) const override {
588  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
589  for (const auto &dim : enumerate(shape)) {
590  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
591  return failure();
592  }
593 
594  // Vector mask setup.
595  OpBuilder::InsertionGuard guard(rewriter);
596  Operation *rootOp;
597  Value mask;
598  if (reductionOp.isMasked()) {
599  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
600  rootOp = reductionOp.getMaskingOp();
601  mask = reductionOp.getMaskingOp().getMask();
602  } else {
603  rootOp = reductionOp;
604  }
605 
606  Location loc = reductionOp.getLoc();
607  Value acc = reductionOp.getAcc();
608  Value cast;
609  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
610  if (mask) {
611  VectorType newMaskType =
612  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
613  dstVecType.getScalableDims());
614  mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
615  }
616  cast = vector::ShapeCastOp::create(
617  rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
618  } else {
619  // This means we are reducing all the dimensions, and all reduction
620  // dimensions are of size 1. So a simple extraction would do.
621  if (mask)
622  mask = vector::ExtractOp::create(rewriter, loc, mask);
623  cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
624  }
625 
626  Value result =
627  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
628  cast, /*fastmath=*/nullptr, mask);
629  rewriter.replaceOp(rootOp, result);
630  return success();
631  }
632 };
633 } // namespace
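// Illustrative sketch (hypothetical IR) of the pattern above: since the only
// reduced dimension has unit size,
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<4x1xf32> to vector<4xf32>
//
// is rewritten into a shape_cast of %src to vector<4xf32> combined with %acc.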
634 
635 void MultiDimReductionOp::getCanonicalizationPatterns(
636  RewritePatternSet &results, MLIRContext *context) {
637  results.add<ElideUnitDimsInMultiDimReduction>(context);
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // ReductionOp
642 //===----------------------------------------------------------------------===//
643 
644 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
645  CombiningKind kind, Value vector,
646  arith::FastMathFlags fastMathFlags) {
647  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
648 }
649 
650 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
651  CombiningKind kind, Value vector, Value acc,
652  arith::FastMathFlags fastMathFlags) {
653  build(builder, result,
654  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
655  acc, fastMathFlags);
656 }
657 
658 LogicalResult ReductionOp::verify() {
659  // Verify for 0-D and 1-D vector.
660  int64_t rank = getSourceVectorType().getRank();
661  if (rank > 1)
662  return emitOpError("unsupported reduction rank: ") << rank;
663 
664  // Verify supported reduction kind.
665  Type eltType = getDest().getType();
666  if (!isSupportedCombiningKind(getKind(), eltType))
667  return emitOpError("unsupported reduction type '")
668  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
669  << "'";
670 
671  return success();
672 }
673 
674 // MaskableOpInterface methods.
675 
676 /// Returns the mask type expected by this operation.
677 Type ReductionOp::getExpectedMaskType() {
678  auto vecType = getSourceVectorType();
679  return VectorType::get(vecType.getShape(),
680  IntegerType::get(vecType.getContext(), /*width=*/1),
681  vecType.getScalableDims());
682 }
683 
684 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
685  OpBuilder &builder, Location loc,
686  Value vector) {
687  switch (op) {
688  case arith::AtomicRMWKind::addf:
689  case arith::AtomicRMWKind::addi:
690  return vector::ReductionOp::create(builder, vector.getLoc(),
691  CombiningKind::ADD, vector);
692  case arith::AtomicRMWKind::mulf:
693  case arith::AtomicRMWKind::muli:
694  return vector::ReductionOp::create(builder, vector.getLoc(),
695  CombiningKind::MUL, vector);
696  case arith::AtomicRMWKind::minimumf:
697  return vector::ReductionOp::create(builder, vector.getLoc(),
698  CombiningKind::MINIMUMF, vector);
699  case arith::AtomicRMWKind::mins:
700  return vector::ReductionOp::create(builder, vector.getLoc(),
701  CombiningKind::MINSI, vector);
702  case arith::AtomicRMWKind::minu:
703  return vector::ReductionOp::create(builder, vector.getLoc(),
704  CombiningKind::MINUI, vector);
705  case arith::AtomicRMWKind::maximumf:
706  return vector::ReductionOp::create(builder, vector.getLoc(),
707  CombiningKind::MAXIMUMF, vector);
708  case arith::AtomicRMWKind::maxs:
709  return vector::ReductionOp::create(builder, vector.getLoc(),
710  CombiningKind::MAXSI, vector);
711  case arith::AtomicRMWKind::maxu:
712  return vector::ReductionOp::create(builder, vector.getLoc(),
713  CombiningKind::MAXUI, vector);
714  case arith::AtomicRMWKind::andi:
715  return vector::ReductionOp::create(builder, vector.getLoc(),
716  CombiningKind::AND, vector);
717  case arith::AtomicRMWKind::ori:
718  return vector::ReductionOp::create(builder, vector.getLoc(),
719  CombiningKind::OR, vector);
720  // TODO: Add remaining reduction operations.
721  default:
722  (void)emitOptionalError(loc, "Reduction operation type not supported");
723  break;
724  }
725  return nullptr;
726 }
727 
728 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
729  return llvm::to_vector<4>(getSourceVectorType().getShape());
730 }
731 
732 namespace {
733 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
734  using Base::Base;
735 
736  LogicalResult matchAndRewrite(ReductionOp reductionOp,
737  PatternRewriter &rewriter) const override {
738  // Vector mask setup.
739  OpBuilder::InsertionGuard guard(rewriter);
740  auto maskableOp =
741  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
742  Operation *rootOp;
743  Value mask;
744  if (maskableOp.isMasked()) {
745  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
746  rootOp = maskableOp.getMaskingOp();
747  mask = maskableOp.getMaskingOp().getMask();
748  } else {
749  rootOp = reductionOp;
750  }
751 
752  auto vectorType = reductionOp.getSourceVectorType();
753  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
754  return failure();
755 
756  Location loc = reductionOp.getLoc();
757  if (mask)
758  mask = ExtractOp::create(rewriter, loc, mask);
759  Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
760 
761  if (Value acc = reductionOp.getAcc())
762  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
763  result, acc,
764  reductionOp.getFastmathAttr(), mask);
765 
766  rewriter.replaceOp(rootOp, result);
767  return success();
768  }
769 };
770 } // namespace
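// Illustrative sketch (hypothetical IR) of the pattern above: a
// single-element reduction such as
//
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
//
// is rewritten into an extract of the only element of %v combined with %acc.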
771 
772 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
773  MLIRContext *context) {
774  results.add<ElideSingleElementReduction>(context);
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // ContractionOp
779 //===----------------------------------------------------------------------===//
780 
781 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
782  Value lhs, Value rhs, Value acc,
783  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
784  ArrayRef<IteratorType> iteratorTypes) {
785  result.addOperands({lhs, rhs, acc});
786  result.addTypes(acc.getType());
787  result.addAttribute(
788  getIndexingMapsAttrName(result.name),
789  builder.getAffineMapArrayAttr(
790  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
791  result.addAttribute(
792  getIteratorTypesAttrName(result.name),
793  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
794  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
795  return IteratorTypeAttr::get(builder.getContext(), t);
796  }))));
797 }
798 
799 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
800  Value lhs, Value rhs, Value acc,
801  ArrayAttr indexingMaps,
802  ArrayAttr iteratorTypes) {
803  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
804  ContractionOp::getDefaultKind());
805 }
806 
807 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
808  Value lhs, Value rhs, Value acc,
809  ArrayAttr indexingMaps,
810  ArrayAttr iteratorTypes, CombiningKind kind) {
811  result.addOperands({lhs, rhs, acc});
812  result.addTypes(acc.getType());
813  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
814  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
815  result.addAttribute(getKindAttrName(result.name),
816  CombiningKindAttr::get(builder.getContext(), kind));
817 }
818 
819 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
820  OpAsmParser::UnresolvedOperand lhsInfo;
821  OpAsmParser::UnresolvedOperand rhsInfo;
822  OpAsmParser::UnresolvedOperand accInfo;
823  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
824  SmallVector<Type, 2> types;
825  Type resultType;
826  auto loc = parser.getCurrentLocation();
827  DictionaryAttr dictAttr;
828  // TODO: Unify linalg op attribute parsing.
829  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
830  parser.parseComma() || parser.parseOperand(rhsInfo) ||
831  parser.parseComma() || parser.parseOperand(accInfo) ||
832  parser.parseTrailingOperandList(masksInfo) ||
833  parser.parseOptionalAttrDict(result.attributes) ||
834  parser.parseColonTypeList(types) ||
835  parser.parseKeywordType("into", resultType) ||
836  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
837  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
838  parser.resolveOperand(accInfo, resultType, result.operands) ||
839  parser.addTypeToList(resultType, result.types))
840  return failure();
841  result.attributes.append(dictAttr.getValue().begin(),
842  dictAttr.getValue().end());
843 
844  // Convert the array of strings into an array of IteratorType enums. This is
845  // needed because tests still use the old format where the 'iterator_types'
846  // attribute is represented as an array of strings.
847  // TODO: Remove this conversion once tests are fixed.
848  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
849  result.attributes.get(getIteratorTypesAttrName(result.name)));
850  if (!iteratorTypes) {
851  return parser.emitError(loc)
852  << "expected " << getIteratorTypesAttrName(result.name)
853  << " array attribute";
854  }
855 
856  SmallVector<Attribute> iteratorTypeAttrs;
857 
858  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
859  auto maybeIteratorType = symbolizeIteratorType(s);
860  if (!maybeIteratorType.has_value())
861  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
862 
863  iteratorTypeAttrs.push_back(
864  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
865  }
866  result.attributes.set(getIteratorTypesAttrName(result.name),
867  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
868 
869  if (!result.attributes.get(getKindAttrName(result.name))) {
870  result.addAttribute(
871  getKindAttrName(result.name),
872  CombiningKindAttr::get(result.getContext(),
873  ContractionOp::getDefaultKind()));
874  }
875  if (masksInfo.empty())
876  return success();
877  if (masksInfo.size() != 2)
878  return parser.emitError(parser.getNameLoc(),
879  "expected zero or exactly 2 vector mask operands");
880  auto lhsType = llvm::cast<VectorType>(types[0]);
881  auto rhsType = llvm::cast<VectorType>(types[1]);
882  auto maskElementType = parser.getBuilder().getI1Type();
883  std::array<VectorType, 2> maskTypes = {
884  VectorType::Builder(lhsType).setElementType(maskElementType),
885  VectorType::Builder(rhsType).setElementType(maskElementType)};
886  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
887  return failure();
888  return success();
889 }
890 
891 void ContractionOp::print(OpAsmPrinter &p) {
892  // TODO: Unify printing code with linalg ops.
893  auto attrNames = getTraitAttrNames();
894  llvm::StringSet<> traitAttrsSet;
895  traitAttrsSet.insert_range(attrNames);
896  SmallVector<NamedAttribute> attrs;
897  for (auto attr : (*this)->getAttrs()) {
898  if (attr.getName() == getIteratorTypesAttrName()) {
899  auto iteratorTypes =
900  llvm::cast<ArrayAttr>(attr.getValue())
901  .getAsValueRange<IteratorTypeAttr, IteratorType>();
902  // Convert IteratorType enums into the string representation. This is
903  // needed because tests still use the old format where the 'iterator_types'
904  // attribute is represented as an array of strings.
905  // TODO: Remove this conversion once tests are fixed.
906  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
907  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
908  return StringAttr::get(getContext(), stringifyIteratorType(t));
909  }));
910 
911  attrs.emplace_back(getIteratorTypesAttrName(),
912  ArrayAttr::get(getContext(), iteratorTypeNames));
913  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
914  attrs.push_back(attr);
915  }
916 
917  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
918  p << " " << dictAttr << " " << getLhs() << ", ";
919  p << getRhs() << ", " << getAcc();
920 
921  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
922  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
923  << getResultType();
924 }
925 
926 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
927  const std::vector<std::pair<int64_t, int64_t>> &map) {
928  for (auto &dimPair : map) {
929  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
930  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
931  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
932  return false;
933  }
934  return true;
935 }
936 
937 static LogicalResult verifyOutputShape(
938  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
939  Type resType,
940  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
941  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
942  DenseSet<int64_t> lhsContractingDimSet;
943  DenseSet<int64_t> rhsContractingDimSet;
944  for (auto &dimPair : contractingDimMap) {
945  lhsContractingDimSet.insert(dimPair.first);
946  rhsContractingDimSet.insert(dimPair.second);
947  }
948  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
949  llvm::make_second_range(batchDimMap));
950 
951  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
952  SmallVector<int64_t, 4> expectedResultDims;
953  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
954  if (lhsContractingDimSet.count(i) > 0)
955  continue;
956  expectedResultDims.push_back(lhsType.getDimSize(i));
957  }
958 
959  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
960  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
961  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
962  continue;
963  expectedResultDims.push_back(rhsType.getDimSize(i));
964  }
965 
966  // Verify 'expectedResultDims'.
967  if (expectedResultDims.empty()) {
968  // No batch or free dimension implies a scalar result.
969  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
970  return op.emitOpError("invalid accumulator/result vector shape");
971  } else {
972  // At least one batch or free dimension implies a vector result.
973  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
974  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
975  if (!resVectorType || !accVectorType)
976  return op.emitOpError("invalid accumulator/result vector shape");
977 
978  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
979  // types fully define the result vector type. This assumes the affine maps
980  // are well-formed, which must have been verified already.
981  MLIRContext *ctx = op.getContext();
982  AffineMap lhsMap = op.getIndexingMapsArray()[0];
983  AffineMap rhsMap = op.getIndexingMapsArray()[1];
984  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
985  return op.emitOpError(
986  "expected all dimensions to be either a LHS or a RHS dimension");
987  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
988  for (auto pair :
989  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
990  VectorType v = pair.first;
991  auto map = pair.second;
992  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
993  unsigned pos = map.getDimPosition(idx);
994  if (!extents[pos])
995  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
996  }
997  }
998  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
999  return op.emitOpError("expected all dimensions to get an extent as "
1000  "either a LHS or a RHS dimension");
1001 
1002  AffineMap resMap = op.getIndexingMapsArray()[2];
1003  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1004  /*symbolCount=*/0, extents, ctx);
1005  // Compose the resMap with the extentsMap, which is a constant map.
1006  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1007  assert(llvm::all_of(expectedMap.getResults(),
1008  llvm::IsaPred<AffineConstantExpr>) &&
1009  "expected constant extent along all dimensions.");
1010  // Extract the expected shape and build the type.
1011  auto expectedShape = llvm::to_vector<4>(
1012  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
1013  return cast<AffineConstantExpr>(e).getValue();
1014  }));
1015  auto expected =
1016  VectorType::get(expectedShape, resVectorType.getElementType(),
1017  resVectorType.getScalableDims());
1018  if (resVectorType != expected || accVectorType != expected)
1019  return op.emitOpError(
1020  "invalid accumulator/result vector shape, expected: ")
1021  << expected;
1022  }
1023  return success();
1024 }
1025 
1026 LogicalResult ContractionOp::verify() {
1027  VectorType lhsType = getLhsType();
1028  VectorType rhsType = getRhsType();
1029  Type accType = getAccType();
1030  Type resType = getResultType();
1031 
1032  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1033  if (!lhsType.getElementType().isSignlessInteger())
1034  return emitOpError("only supports signless integer types");
1035  }
1036 
1037  // Verify that an indexing map was specified for each vector operand.
1038  if (getIndexingMapsArray().size() != 3)
1039  return emitOpError("expected an indexing map for each vector operand");
1040 
1041  // Verify that each index map has 'numIterators' inputs, no symbols, and
1042  // that the number of map outputs equals the rank of its associated
1043  // vector operand.
1044  unsigned numIterators = getIteratorTypes().getValue().size();
1045  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1046  auto index = it.index();
1047  auto map = it.value();
1048  if (map.getNumSymbols() != 0)
1049  return emitOpError("expected indexing map ")
1050  << index << " to have no symbols";
1051  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1052  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1053  // Verify that the map has the right number of inputs, outputs, and indices.
1054  // This also correctly accounts for (..) -> () for rank-0 results.
1055  if (map.getNumDims() != numIterators)
1056  return emitOpError("expected indexing map ")
1057  << index << " to have " << numIterators << " number of inputs";
1058  if (map.getNumResults() != rank)
1059  return emitOpError("expected indexing map ")
1060  << index << " to have " << rank << " number of outputs";
1061  if (!map.isProjectedPermutation())
1062  return emitOpError("expected indexing map ")
1063  << index << " to be a projected permutation of its inputs";
1064  }
1065 
1066  auto contractingDimMap = getContractingDimMap();
1067  auto batchDimMap = getBatchDimMap();
1068 
1069  // Verify at least one contracting dimension pair was specified.
1070  if (contractingDimMap.empty())
1071  return emitOpError("expected at least one contracting dimension pair");
1072 
1073  // Verify contracting dimension map was properly constructed.
1074  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1075  return emitOpError("invalid contracting dimension map");
1076 
1077  // Verify batch dimension map was properly constructed.
1078  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1079  return emitOpError("invalid batch dimension map");
1080 
1081  // Verify 'accType' and 'resType' shape.
1082  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1083  contractingDimMap, batchDimMap)))
1084  return failure();
1085 
1086  // Verify supported combining kind.
1087  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1088  auto elementType = vectorType ? vectorType.getElementType() : resType;
1089  if (!isSupportedCombiningKind(getKind(), elementType))
1090  return emitOpError("unsupported contraction type");
1091 
1092  // Delayed calling of IndexingMapOpInterface::verifyImpl.
1093  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1094 }
1095 
1096 // MaskableOpInterface methods.
1097 
1098 /// Returns the mask type expected by this operation. Mostly used for
1099 /// verification purposes. It requires the operation to be vectorized.
1100 Type ContractionOp::getExpectedMaskType() {
1101  auto indexingMaps = this->getIndexingMapsArray();
1102  AffineMap lhsIdxMap = indexingMaps[0];
1103  AffineMap rhsIdxMap = indexingMaps[1];
1104  VectorType lhsType = this->getLhsType();
1105  VectorType rhsType = this->getRhsType();
1106 
1107  unsigned numVecDims = lhsIdxMap.getNumDims();
1108  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1109  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1110 
1111  // Using the information in the indexing maps, extract the size of each
1112  // dimension in the vector.contract operation from the two input operands.
1113  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1114  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1115  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1116  lhsType.getScalableDims()[dimIdx];
1117  }
1118  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1119  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1120  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1121  rhsType.getScalableDims()[dimIdx];
1122  }
1123 
1124  assert(ShapedType::isStaticShape(maskShape) &&
1125  "Mask shape couldn't be computed");
1126 
1127  return VectorType::get(maskShape,
1128  IntegerType::get(lhsType.getContext(), /*width=*/1),
1129  maskShapeScalableDims);
1130 }
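// Worked example (illustrative assumption): for a matmul-style contraction
// with iterators (m, n, k), lhs vector<4x8xf32> indexed by (m, n, k) -> (m, k)
// and rhs vector<8x16xf32> indexed by (m, n, k) -> (k, n), the expected mask
// type is vector<4x16x8xi1>, i.e. one i1 per point of the (m, n, k) iteration
// space.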
1131 
1132 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1133  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1134  getIteratorTypesAttrName(), getKindAttrName()};
1135 }
1136 
1137 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1138  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1139  if (targetExpr == map.getResult(i))
1140  return i;
1141  return -1;
1142 }
1143 
1144 static std::vector<std::pair<int64_t, int64_t>>
1145 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1146  IteratorType targetIteratorType, MLIRContext *context) {
1147  std::vector<std::pair<int64_t, int64_t>> dimMap;
1148  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1149  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1150  if (iteratorType != targetIteratorType)
1151  continue;
1152  // Search lhs/rhs map results for 'targetExpr'.
1153  auto targetExpr = getAffineDimExpr(it.index(), context);
1154  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1155  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1156  if (lhsDim >= 0 && rhsDim >= 0)
1157  dimMap.emplace_back(lhsDim, rhsDim);
1158  }
1159  return dimMap;
1160 }
1161 
1162 void ContractionOp::getIterationBounds(
1163  SmallVectorImpl<int64_t> &iterationBounds) {
1164  auto lhsShape = getLhsType().getShape();
1165  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1166  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1167  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1168  // Search lhs/rhs map results for 'targetExpr'.
1169  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1170  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1171  if (iteratorType == IteratorType::reduction) {
1172  // Get reduction dim size from lhs shape (same size in rhsShape).
1173  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1174  assert(lhsDimIndex >= 0);
1175  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1176  continue;
1177  }
1178  // Get parallel dimension size from result shape.
1179  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1180  assert(resDimIndex >= 0);
1181  assert(resVectorType != nullptr);
1182  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1183  }
1184 }
1185 
1186 void ContractionOp::getIterationIndexMap(
1187  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1188  unsigned numMaps = getIndexingMapsArray().size();
1189  iterationIndexMap.resize(numMaps);
1190  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1191  auto index = it.index();
1192  auto map = it.value();
1193  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1194  auto dim = cast<AffineDimExpr>(map.getResult(i));
1195  iterationIndexMap[index][dim.getPosition()] = i;
1196  }
1197  }
1198 }
1199 
1200 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1201  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1202  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1203  getContext());
1204 }
1205 
1206 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1207  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1208  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1209  getContext());
1210 }
1211 
1212 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1213  SmallVector<int64_t, 4> shape;
1214  getIterationBounds(shape);
1215  return shape;
1216 }
1217 
1218 /// Return a fused vector::ContractionOp which represents a pattern such as:
1219 ///
1220 /// ```mlir
1221 /// %c0 = vector.constant 0: ...
1222 /// %c = vector.contract %a, %b, %c0: ...
1223 /// %e = add %c, %d: ...
1224 /// ```
1225 ///
1226 /// by:
1227 ///
1228 /// ```mlir
1229 /// %e = vector.contract %a, %b, %d: ...
1230 /// ```
1231 ///
1232 /// Return null if the canonicalization does not apply.
1233 // TODO: This should be a folding of Add into Contract in core but while they
1234 // live in different dialects, it is not possible without unnatural
1235 // dependencies.
1236 template <typename AddOpType>
1237 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1238  using OpRewritePattern<AddOpType>::OpRewritePattern;
1239 
1240  LogicalResult matchAndRewrite(AddOpType addOp,
1241  PatternRewriter &rewriter) const override {
1242  auto canonicalize = [&](Value maybeContraction,
1243  Value otherOperand) -> vector::ContractionOp {
1244  vector::ContractionOp contractionOp =
1245  dyn_cast_or_null<vector::ContractionOp>(
1246  maybeContraction.getDefiningOp());
1247  if (!contractionOp)
1248  return vector::ContractionOp();
1249  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1250  contractionOp.getAcc().getDefiningOp())) {
1251  if (maybeZero.getValue() ==
1252  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1253  IRMapping bvm;
1254  bvm.map(contractionOp.getAcc(), otherOperand);
1255  auto newContraction =
1256  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1257  rewriter.replaceOp(addOp, newContraction.getResult());
1258  return newContraction;
1259  }
1260  }
1261  return vector::ContractionOp();
1262  };
1263 
1264  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1265  vector::ContractionOp contract = canonicalize(a, b);
1266  contract = contract ? contract : canonicalize(b, a);
1267  return contract ? success() : failure();
1268  }
1269 };
1270 
1271 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1272  MLIRContext *context) {
1273  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1274  CanonicalizeContractAdd<arith::AddFOp>>(context);
1275 }
1276 
1277 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1278 // `poisonValue`.
1279 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1280  int64_t maxIndex) {
1281  return index == poisonValue || (index >= 0 && index < maxIndex);
1282 }
1283 
1284 //===----------------------------------------------------------------------===//
1285 // ExtractOp
1286 //===----------------------------------------------------------------------===//
1287 
1288 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1289  SetIntRangeFn setResultRanges) {
1290  setResultRanges(getResult(), argRanges.front());
1291 }
1292 
1293 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1294  Value source) {
1295  auto vectorTy = cast<VectorType>(source.getType());
1296  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1297 }
1298 
1299 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1300  Value source, int64_t position) {
1301  build(builder, result, source, ArrayRef<int64_t>{position});
1302 }
1303 
1304 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1305  Value source, OpFoldResult position) {
1306  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1307 }
1308 
1309 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1310  Value source, ArrayRef<int64_t> position) {
1311  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1312  builder.getDenseI64ArrayAttr(position));
1313 }
1314 
1315 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1316  Value source, ArrayRef<OpFoldResult> position) {
1317  SmallVector<int64_t> staticPos;
1318  SmallVector<Value> dynamicPos;
1319  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1320  build(builder, result, source, dynamicPos,
1321  builder.getDenseI64ArrayAttr(staticPos));
1322 }
1323 
1324 LogicalResult
1325 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1326  ExtractOp::Adaptor adaptor,
1327  SmallVectorImpl<Type> &inferredReturnTypes) {
1328  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1329  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1330  vectorType.getRank()) {
1331  inferredReturnTypes.push_back(vectorType.getElementType());
1332  } else {
1333  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1334  vectorType.getRank());
1335  inferredReturnTypes.push_back(VectorType::get(
1336  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1337  vectorType.getScalableDims().drop_front(n)));
1338  }
1339  return success();
1340 }
1341 
1342 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1343  // Allow extracting 1-element vectors instead of scalars.
1344  auto isCompatible = [](TypeRange l, TypeRange r) {
1345  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1346  return vectorType && vectorType.getShape().equals({1}) &&
1347  vectorType.getElementType() == r.front();
1348  };
1349  if (l.size() == 1 && r.size() == 1 &&
1350  (isCompatible(l, r) || isCompatible(r, l)))
1351  return true;
1352  return l == r;
1353 }
1354 
1355 LogicalResult vector::ExtractOp::verify() {
1356  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1357  if (resTy.getRank() == 0)
1358  return emitError(
1359  "expected a scalar instead of a 0-d vector as the result type");
1360 
1361  // Note: This check must come before getMixedPosition() to prevent a crash.
1362  auto dynamicMarkersCount =
1363  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1364  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1365  return emitOpError(
1366  "mismatch between dynamic and static positions (kDynamic marker but no "
1367  "corresponding dynamic position) -- this can only happen due to an "
1368  "incorrect fold/rewrite");
1369  auto position = getMixedPosition();
1370  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1371  return emitOpError(
1372  "expected position attribute of rank no greater than vector rank");
1373  for (auto [idx, pos] : llvm::enumerate(position)) {
1374  if (auto attr = dyn_cast<Attribute>(pos)) {
1375  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1376  if (!isValidPositiveIndexOrPoison(
1377  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1378  return emitOpError("expected position attribute #")
1379  << (idx + 1)
1380  << " to be a non-negative integer smaller than the "
1381  "corresponding vector dimension or poison (-1)";
1382  }
1383  }
1384  }
1385  return success();
1386 }
1387 
1388 template <typename IntType>
1389 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1390  return llvm::to_vector<4>(llvm::map_range(
1391  arrayAttr.getAsRange<IntegerAttr>(),
1392  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1393 }
1394 
1395 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1396 /// positions.
1397 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1398  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1399  return failure();
1400 
1401  // TODO: Canonicalization for dynamic position not implemented yet.
1402  if (extractOp.hasDynamicPosition())
1403  return failure();
1404 
1405  SmallVector<int64_t> globalPosition;
1406  ExtractOp currentOp = extractOp;
1407  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1408  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1409  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1410  currentOp = nextOp;
1411  // TODO: Canonicalization for dynamic position not implemented yet.
1412  if (currentOp.hasDynamicPosition())
1413  return failure();
1414  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1415  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1416  }
1417  extractOp.setOperand(0, currentOp.getSource());
1418  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1419  OpBuilder b(extractOp.getContext());
1420  std::reverse(globalPosition.begin(), globalPosition.end());
1421  extractOp.setStaticPosition(globalPosition);
1422  return success();
1423 }
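// Illustrative sketch (hypothetical IR): this in-place fold rewrites
//
//   %a = vector.extract %src[0, 1] : vector<4x5xf32> from vector<2x3x4x5xf32>
//   %b = vector.extract %a[2] : vector<5xf32> from vector<4x5xf32>
//
// into the single op
//
//   %b = vector.extract %src[0, 1, 2] : vector<5xf32> from vector<2x3x4x5xf32>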
1424 
1425 namespace {
1426 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1427 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1428 /// Compose TransposeOp permutations as we walk back.
1429 /// This helper class keeps an updated extraction position `extractPosition`
1430 /// with extra trailing sentinels.
1431 /// The sentinels encode the internal transposition status of the result vector.
1432 /// As we iterate, extractPosition is permuted and updated.
1433 class ExtractFromInsertTransposeChainState {
1434 public:
1435  ExtractFromInsertTransposeChainState(ExtractOp e);
1436 
1437  /// Iterate over producing insert and transpose ops until we find a fold.
1438  Value fold();
1439 
1440 private:
1441  /// Return true if the vector at position `a` is contained within the vector
1442  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1443  /// is a prefix of `b`.
1444  template <typename ContainerA, typename ContainerB>
1445  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1446  return a.size() <= b.size() &&
1447  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1448  }
1449 
1450  /// Return true if the vector at position `a` intersects the vector at
1451  /// position `b`. Under insert/extract semantics, this is the same as equality
1452  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1453  /// Comparison is on the common prefix (i.e. zip).
1454  template <typename ContainerA, typename ContainerB>
1455  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1456  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1457  if (elemA < 0 || elemB < 0)
1458  continue;
1459  if (elemA != elemB)
1460  return false;
1461  }
1462  return true;
1463  }
1464 
1465  /// Folding is only possible in the absence of an internal permutation in the
1466  /// result vector.
1467  bool canFold() {
1468  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1469  }
1470 
1471  // Helper to get the next defining op of interest.
1472  void updateStateForNextIteration(Value v) {
1473  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1474  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1475  };
1476 
1477  // Case 1. If we hit a transpose, just compose the map and iterate.
1478  // Invariant: insert + transpose do not change rank, we can always compose.
1479  LogicalResult handleTransposeOp();
1480 
1481  // Case 2: the insert position matches extractPosition exactly, early return.
1482  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1483 
1484  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1485  /// portion of the source of the insert.
1486  /// Example:
1487  /// ```
1488  /// %ins = vector.insert %source, %dest[1] : vector<3x4x5> into vector<2x3x4x5>
1489  /// // extractPosition == [1, 2]
1490  /// %ext = vector.extract %ins[1, 2] : vector<4x5> from vector<2x3x4x5>
1491  /// // insertedPos == [1] is a prefix of extractPosition, so this folds to
1492  /// %ext = vector.extract %source[2] : vector<4x5> from vector<3x4x5>
1493  /// ```
1494  /// To traverse through %source, we need to set the leading dims to 0 and
1495  /// drop the extra leading dims.
1496  /// This method updates the internal state.
1497  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1498 
1499  /// Try to fold in place to extract(source, extractPosition) and return the
1500  /// folded result. Return null if folding is not possible (e.g. due to an
1501  /// internal transposition in the result).
1502  Value tryToFoldExtractOpInPlace(Value source);
1503 
1504  ExtractOp extractOp;
1505  int64_t vectorRank;
1506  int64_t extractedRank;
1507 
1508  InsertOp nextInsertOp;
1509  TransposeOp nextTransposeOp;
1510 
1511  /// Sentinel values that encode the internal permutation status of the result.
1512  /// They are set to (-1, ... , -k) at the beginning and appended to
1513  /// `extractPosition`.
1514  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1515  /// ensure that there is no internal transposition.
1516  /// Internal transposition cannot be accounted for with a folding pattern.
1517  // TODO: We could relax the internal transposition with an extra transposition
1518  // operation in a future canonicalizer.
1519  SmallVector<int64_t> sentinels;
1520  SmallVector<int64_t> extractPosition;
1521 };
1522 } // namespace
1523 
1524 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1525  ExtractOp e)
1526  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1527  extractedRank(extractOp.getNumIndices()) {
1528  assert(vectorRank >= extractedRank && "Extracted position overflow");
1529  sentinels.reserve(vectorRank - extractedRank);
1530  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1531  sentinels.push_back(-(i + 1));
1532  extractPosition.assign(extractOp.getStaticPosition().begin(),
1533  extractOp.getStaticPosition().end());
1534  llvm::append_range(extractPosition, sentinels);
1535 }
1536 
1537 // Case 1. If we hit a transpose, just compose the map and iterate.
1538 // Invariant: insert + transpose do not change rank, we can always compose.
1539 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1540  // TODO: Canonicalization for dynamic position not implemented yet.
1541  if (extractOp.hasDynamicPosition())
1542  return failure();
1543 
1544  if (!nextTransposeOp)
1545  return failure();
1546  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1547  nextTransposeOp.getPermutation(), extractOp.getContext()));
1548  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1549  return success();
1550 }
1551 
1552 // Case 2: the insert position matches extractPosition exactly, early return.
1553 LogicalResult
1554 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1555  Value &res) {
1556  // TODO: Canonicalization for dynamic position not implemented yet.
1557  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1558  return failure();
1559 
1560  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1561  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1562  return failure();
1563  // Case 2.a. early-exit fold.
1564  res = nextInsertOp.getValueToStore();
1565  // Case 2.b. if internal transposition is present, canFold will be false.
1566  return success(canFold());
1567 }
1568 
1569 /// Case 3: if inserted position is a prefix of extractPosition,
1570 /// extract a portion of the source of the insertion.
1571 /// This method updates the internal state.
1572 LogicalResult
1573 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1574  // TODO: Canonicalization for dynamic position not implemented yet.
1575  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1576  return failure();
1577 
1578  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1579  if (!isContainedWithin(insertedPos, extractPosition))
1580  return failure();
1581  // Set leading dims to zero.
1582  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1583  // Drop extra leading dims.
1584  extractPosition.erase(extractPosition.begin(),
1585  extractPosition.begin() + insertedPos.size());
1586  extractedRank = extractPosition.size() - sentinels.size();
1587  // Case 3.a. early-exit fold (break and delegate to post-while path).
1588  res = nextInsertOp.getValueToStore();
1589  // Case 3.b. if internal transposition is present, canFold will be false.
1590  return success();
1591 }
1592 
1593 /// Try to fold in place to extract(source, extractPosition) and return the
1594 /// folded result. Return null if folding is not possible (e.g. due to an
1595 /// internal transposition in the result).
1596 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1597  Value source) {
1598  // TODO: Canonicalization for dynamic position not implemented yet.
1599  if (extractOp.hasDynamicPosition())
1600  return Value();
1601 
1602  // If we can't fold (either internal transposition, or nothing to fold), bail.
1603  bool nothingToFold = (source == extractOp.getSource());
1604  if (nothingToFold || !canFold())
1605  return Value();
1606 
1607  // Otherwise, fold by updating the op inplace and return its result.
1608  OpBuilder b(extractOp.getContext());
1609  extractOp.setStaticPosition(
1610  ArrayRef(extractPosition).take_front(extractedRank));
1611  extractOp.getSourceMutable().assign(source);
1612  return extractOp.getResult();
1613 }
1614 
1615 /// Iterate over producing insert and transpose ops until we find a fold.
1616 Value ExtractFromInsertTransposeChainState::fold() {
1617  // TODO: Canonicalization for dynamic position not implemented yet.
1618  if (extractOp.hasDynamicPosition())
1619  return Value();
1620 
1621  Value valueToExtractFrom = extractOp.getSource();
1622  updateStateForNextIteration(valueToExtractFrom);
1623  while (nextInsertOp || nextTransposeOp) {
1624  // Case 1. If we hit a transpose, just compose the map and iterate.
1625  // Invariant: insert + transpose do not change rank, we can always compose.
1626  if (succeeded(handleTransposeOp())) {
1627  valueToExtractFrom = nextTransposeOp.getVector();
1628  updateStateForNextIteration(valueToExtractFrom);
1629  continue;
1630  }
1631 
1632  Value result;
1633  // Case 2: the insert position matches extractPosition exactly.
1634  if (succeeded(handleInsertOpWithMatchingPos(result)))
1635  return result;
1636 
1637  // Case 3: if the inserted position is a prefix of extractPosition, we can
1638  // just extract a portion of the source of the insert.
1639  if (succeeded(handleInsertOpWithPrefixPos(result)))
1640  return tryToFoldExtractOpInPlace(result);
1641 
1642  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1643  // values. This is a more difficult case and we bail.
1644  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1645  if (isContainedWithin(extractPosition, insertedPos) ||
1646  intersectsWhereNonNegative(extractPosition, insertedPos))
1647  return Value();
1648 
1649  // Case 5: No intersection, we forward the extract to insertOp.dest().
1650  valueToExtractFrom = nextInsertOp.getDest();
1651  updateStateForNextIteration(valueToExtractFrom);
1652  }
1653  // If after all this we can fold, go for it.
1654  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1655 }
1656 
1657 /// Returns true if the operation has a 0-D vector type operand or result.
1658 static bool hasZeroDimVectors(Operation *op) {
1659  auto hasZeroDimVectorType = [](Type type) -> bool {
1660  auto vecType = dyn_cast<VectorType>(type);
1661  return vecType && vecType.getRank() == 0;
1662  };
1663 
1664  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1665  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1666 }
1667 
1668 /// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1669 /// considered to be 'broadcastlike'.
1670 static bool isBroadcastLike(Operation *op) {
1671  if (isa<BroadcastOp>(op))
1672  return true;
1673 
1674  auto shapeCast = dyn_cast<ShapeCastOp>(op);
1675  if (!shapeCast)
1676  return false;
1677 
1678  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1679  // Checking that the destination shape has a prefix of 1s is not sufficient,
1680  // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1681  // is that the source shape is a suffix of the destination shape.
1682  VectorType srcType = shapeCast.getSourceVectorType();
1683  ArrayRef<int64_t> srcShape = srcType.getShape();
1684  uint64_t srcRank = srcType.getRank();
1685  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1686  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1687 }
1688 
1689 /// Fold extract(broadcast(X)) to either extract(X) or just X.
1690 ///
1691 /// Example:
1692 ///
1693 /// broadcast extract [1][2]
1694 /// (3, 4) --------> (2, 3, 4) ----------------> (4)
1695 ///
1696 /// becomes
1697 /// extract [1]
1698 /// (3,4) -------------------------------------> (4)
1699 ///
1700 ///
1701 /// The variable names used in this implementation correspond to the above
1702 /// shapes as,
1703 ///
1704 /// - (3, 4) is `input` shape.
1705 /// - (2, 3, 4) is `broadcast` shape.
1706 /// - (4) is `extract` shape.
1707 ///
1708 /// This folding is possible when the suffix of `input` shape is the same as
1709 /// `extract` shape.
1710 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1711 
1712  Operation *defOp = extractOp.getSource().getDefiningOp();
1713  if (!defOp || !isBroadcastLike(defOp))
1714  return Value();
1715 
1716  Value input = defOp->getOperand(0);
1717 
1718  // Replace extract(broadcast(X)) with X
1719  if (extractOp.getType() == input.getType())
1720  return input;
1721 
1722  // Get required types and ranks in the chain
1723  // input -> broadcast -> extract
1724  // (scalars are treated as rank-0).
1725  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1726  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1727  unsigned inputRank = inputType ? inputType.getRank() : 0;
1728  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1729  unsigned extractRank = extractType ? extractType.getRank() : 0;
1730 
1731  // Cannot do without the broadcast if overall the rank increases.
1732  if (extractRank > inputRank)
1733  return Value();
1734 
1735  // The above condition guarantees that input is a vector.
1736  assert(inputType && "input must be a vector type because of previous checks");
1737  ArrayRef<int64_t> inputShape = inputType.getShape();
1738 
1739  // In the case where there is a broadcast dimension in the suffix, it is not
1740  // possible to replace extract(broadcast(X)) with extract(X). Example:
1741  //
1742  // broadcast extract
1743  // (1) --------> (3,4) ------> (4)
1744  if (extractType &&
1745  extractType.getShape() != inputShape.take_back(extractRank))
1746  return Value();
1747 
1748  // Replace extract(broadcast(X)) with extract(X).
1749  // First, determine the new extraction position.
1750  unsigned deltaOverall = inputRank - extractRank;
1751  unsigned deltaBroadcast = broadcastRank - inputRank;
1752  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1753  SmallVector<OpFoldResult> newPositions(deltaOverall);
1754  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1755  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1756  newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1757  }
1758  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1759  extractOp->setOperands(
1760  llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1761  extractOp.setStaticPosition(staticPos);
1762  return extractOp.getResult();
1763 }
1764 
1765 /// Fold extractOp coming from ShuffleOp.
1766 ///
1767 /// Example:
1768 ///
1769 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1770 /// : vector<8xf32>, vector<8xf32>
1771 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1772 /// ->
1773 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1774 ///
1775 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1776  // Dynamic positions are not folded as the resulting code would be more
1777  // complex than the input code.
1778  if (extractOp.hasDynamicPosition())
1779  return Value();
1780 
1781  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1782  if (!shuffleOp)
1783  return Value();
1784 
1785  // TODO: 0-D or multi-dimensional vectors not supported yet.
1786  if (shuffleOp.getResultVectorType().getRank() != 1)
1787  return Value();
1788 
1789  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1790  auto shuffleMask = shuffleOp.getMask();
1791  int64_t extractIdx = extractOp.getStaticPosition()[0];
1792  int64_t shuffleIdx = shuffleMask[extractIdx];
1793 
1794  // Find the shuffled vector to extract from based on the shuffle index.
1795  if (shuffleIdx < inputVecSize) {
1796  extractOp.setOperand(0, shuffleOp.getV1());
1797  extractOp.setStaticPosition({shuffleIdx});
1798  } else {
1799  extractOp.setOperand(0, shuffleOp.getV2());
1800  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1801  }
1802 
1803  return extractOp.getResult();
1804 }
1805 
1806 // Fold extractOp with source coming from ShapeCast op.
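// For example (illustrative):
//   %sc = vector.shape_cast %src : vector<4x4xf32> to vector<2x2x4xf32>
//   %e = vector.extract %sc[1, 0] : vector<4xf32> from vector<2x2x4xf32>
// folds in place to
//   %e = vector.extract %src[2] : vector<4xf32> from vector<4x4xf32>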
1807 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1808  // TODO: Canonicalization for dynamic position not implemented yet.
1809  if (extractOp.hasDynamicPosition())
1810  return Value();
1811 
1812  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1813  if (!shapeCastOp)
1814  return Value();
1815 
1816  // Get the nth dimension size starting from lowest dimension.
1817  auto getDimReverse = [](VectorType type, int64_t n) {
1818  return type.getShape().take_back(n + 1).front();
1819  };
1820  int64_t destinationRank =
1821  llvm::isa<VectorType>(extractOp.getType())
1822  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1823  : 0;
1824  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1825  return Value();
1826  if (destinationRank > 0) {
1827  auto destinationType =
1828  llvm::cast<VectorType>(extractOp.getResult().getType());
1829  for (int64_t i = 0; i < destinationRank; i++) {
1830  // The lowest dimension of the destination must match the lowest
1831  // dimension of the shapecast op source.
1832  // TODO: This case could be supported in a canonicalization pattern.
1833  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1834  getDimReverse(destinationType, i))
1835  return Value();
1836  }
1837  }
1838  // Extract the strides associated with the extract op vector source. Then use
1839  // this to calculate a linearized position for the extract.
1840  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1841  std::reverse(extractedPos.begin(), extractedPos.end());
1842  SmallVector<int64_t, 4> strides;
1843  int64_t stride = 1;
1844  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1845  strides.push_back(stride);
1846  stride *=
1847  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1848  }
1849 
1850  int64_t position = linearize(extractedPos, strides);
1851  // Then extract the strides associated to the shapeCast op vector source and
1852  // delinearize the position using those strides.
1853  SmallVector<int64_t, 4> newStrides;
1854  int64_t numDimension =
1855  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1856  stride = 1;
1857  for (int64_t i = 0; i < numDimension; i++) {
1858  newStrides.push_back(stride);
1859  stride *=
1860  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1861  }
1862  std::reverse(newStrides.begin(), newStrides.end());
1863  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1864  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1865  OpBuilder b(extractOp.getContext());
1866  extractOp.setStaticPosition(newPosition);
1867  extractOp.setOperand(0, shapeCastOp.getSource());
1868  return extractOp.getResult();
1869 }
1870 
1871 /// Fold an ExtractOp from ExtractStridedSliceOp.
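/// For example (illustrative):
///   %slice = vector.extract_strided_slice %src
///     {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
///     : vector<4x8xf32> to vector<2x4xf32>
///   %e = vector.extract %slice[0, 1] : f32 from vector<2x4xf32>
/// folds in place to
///   %e = vector.extract %src[1, 3] : f32 from vector<4x8xf32>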
1872 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1873  // TODO: Canonicalization for dynamic position not implemented yet.
1874  if (extractOp.hasDynamicPosition())
1875  return Value();
1876 
1877  auto extractStridedSliceOp =
1878  extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1879  if (!extractStridedSliceOp)
1880  return Value();
1881 
1882  // 0-D vectors not supported.
1883  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1884  if (hasZeroDimVectors(extractStridedSliceOp))
1885  return Value();
1886 
1887  // Return if 'extractStridedSliceOp' has non-unit strides.
1888  if (extractStridedSliceOp.hasNonUnitStrides())
1889  return Value();
1890 
1891  // Trim offsets for dimensions fully extracted.
1892  auto sliceOffsets =
1893  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1894  while (!sliceOffsets.empty()) {
1895  size_t lastOffset = sliceOffsets.size() - 1;
1896  if (sliceOffsets.back() != 0 ||
1897  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1898  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1899  break;
1900  sliceOffsets.pop_back();
1901  }
1902  unsigned destinationRank = 0;
1903  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1904  destinationRank = vecType.getRank();
1905  // The dimensions of the result need to be untouched by the
1906  // extractStridedSlice op.
1907  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1908  sliceOffsets.size())
1909  return Value();
1910 
1911  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1912  assert(extractedPos.size() >= sliceOffsets.size());
1913  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1914  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1915  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1916 
1917  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1918  OpBuilder b(extractOp.getContext());
1919  extractOp.setStaticPosition(extractedPos);
1920  return extractOp.getResult();
1921 }
1922 
1923 /// Fold extract_op fed from a chain of insertStridedSlice ops.
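/// For example (illustrative):
///   %ins = vector.insert_strided_slice %small, %large
///     {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<8x4xf32>
///   %e = vector.extract %ins[3] : vector<4xf32> from vector<8x4xf32>
/// folds in place to
///   %e = vector.extract %small[1] : vector<4xf32> from vector<2x4xf32>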
1924 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1925  // TODO: Canonicalization for dynamic position not implemented yet.
1926  if (extractOp.hasDynamicPosition())
1927  return Value();
1928 
1929  int64_t destinationRank =
1930  llvm::isa<VectorType>(extractOp.getType())
1931  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1932  : 0;
1933  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1934  if (!insertOp)
1935  return Value();
1936 
1937  // 0-D vectors not supported.
1938  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1939  if (hasZeroDimVectors(insertOp))
1940  return Value();
1941 
1942  while (insertOp) {
1943  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1944  insertOp.getSourceVectorType().getRank();
1945  if (destinationRank > insertOp.getSourceVectorType().getRank())
1946  return Value();
1947  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1948  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1949 
1950  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1951  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1952  }))
1953  return Value();
1954  bool disjoint = false;
1955  SmallVector<int64_t, 4> offsetDiffs;
1956  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1957  int64_t start = insertOffsets[dim];
1958  int64_t size =
1959  (dim < insertRankDiff)
1960  ? 1
1961  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1962  int64_t end = start + size;
1963  int64_t offset = extractOffsets[dim];
1964  // Check if the start of the extract offset is in the interval inserted.
1965  if (start <= offset && offset < end) {
1966  if (dim >= insertRankDiff)
1967  offsetDiffs.push_back(offset - start);
1968  continue;
1969  }
1970  disjoint = true;
1971  break;
1972  }
1973  // The extracted element chunk overlaps with the inserted vector.
1974  if (!disjoint) {
1975  // If any of the inner dimensions are only partially inserted we have a
1976  // partial overlap.
1977  int64_t srcRankDiff =
1978  insertOp.getSourceVectorType().getRank() - destinationRank;
1979  for (int64_t i = 0; i < destinationRank; i++) {
1980  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1981  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1982  insertRankDiff))
1983  return Value();
1984  }
1985  extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1986  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1987  OpBuilder b(extractOp.getContext());
1988  extractOp.setStaticPosition(offsetDiffs);
1989  return extractOp.getResult();
1990  }
1991  // If the chunk extracted is disjoint from the chunk inserted, keep
1992  // looking in the insert chain.
1993  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1994  }
1995  return Value();
1996 }
1997 
1998 /// Try to fold the extraction of a scalar from a vector defined by
1999 /// vector.from_elements. E.g.:
2000 ///
2001 /// %0 = vector.from_elements %a, %b : vector<2xf32>
2002 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2003 /// ==> fold to %a
2004 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2005  // Dynamic extractions cannot be folded.
2006  if (extractOp.hasDynamicPosition())
2007  return {};
2008 
2009  // Look for extract(from_elements).
2010  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2011  if (!fromElementsOp)
2012  return {};
2013 
2014  // Scalable vectors are not supported.
2015  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2016  if (vecType.isScalable())
2017  return {};
2018 
2019  // Only extractions of scalars are supported.
2020  int64_t rank = vecType.getRank();
2021  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2022  if (extractOp.getType() != vecType.getElementType())
2023  return {};
2024  assert(static_cast<int64_t>(indices.size()) == rank &&
2025  "unexpected number of indices");
2026 
2027  // Compute flattened/linearized index and fold to operand.
2028  int flatIndex = 0;
2029  int stride = 1;
2030  for (int i = rank - 1; i >= 0; --i) {
2031  flatIndex += indices[i] * stride;
2032  stride *= vecType.getDimSize(i);
2033  }
2034  return fromElementsOp.getElements()[flatIndex];
2035 }
2036 
2037 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2038 /// then fold it.
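/// For example (illustrative), where %c1 is an arith.constant of value 1:
///   %e = vector.extract %v[%c1, %i] : f32 from vector<4x4xf32>
/// folds in place to
///   %e = vector.extract %v[1, %i] : f32 from vector<4x4xf32>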
2039 template <typename OpType, typename AdaptorType>
2040 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2041  SmallVectorImpl<Value> &operands) {
2042  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2043  OperandRange dynamicPosition = op.getDynamicPosition();
2044  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2045  ArrayRef<int64_t> vectorShape;
2046  if constexpr (std::is_same_v<OpType, ExtractOp>)
2047  vectorShape = op.getSourceVectorType().getShape();
2048  else
2049  vectorShape = op.getDestVectorType().getShape();
2050 
2051  // If there are no dynamic position operands, there is nothing to fold.
2052  if (!dynamicPosition.size())
2053  return {};
2054 
2055  // `index` is used to iterate over the `dynamicPosition`.
2056  unsigned index = 0;
2057 
2058  // `opChange` is a flag. If it is true, it means to update `op` in place.
2059  bool opChange = false;
2060  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2061  if (ShapedType::isStatic(staticPosition[i]))
2062  continue;
2063  Attribute positionAttr = dynamicPositionAttr[index];
2064  Value position = dynamicPosition[index++];
2065  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2066  int64_t value = attr.getInt();
2067  // Do not fold if the value is out of bounds (-1 signifies a poison
2068  // value rather than OOB index).
2069  if (value >= -1 && value < vectorShape[i]) {
2070  staticPosition[i] = attr.getInt();
2071  opChange = true;
2072  continue;
2073  }
2074  }
2075  operands.push_back(position);
2076  }
2077 
2078  if (opChange) {
2079  op.setStaticPosition(staticPosition);
2080  op.getOperation()->setOperands(operands);
2081  // Return the original result to indicate an in-place folding happened.
2082  return op.getResult();
2083  }
2084  return {};
2085 }
2086 
2087 /// Fold an insert or extract operation into a poison value when a poison index
2088 /// is found at any dimension of the static position.
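/// For example (illustrative), with -1 the poison index sentinel:
///   %e = vector.extract %v[-1] : f32 from vector<4xf32>
/// folds to a ub.poison value.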
2089 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2090  ArrayRef<int64_t> staticPos,
2091  int64_t poisonVal) {
2092  if (!is_contained(staticPos, poisonVal))
2093  return {};
2094 
2095  return ub::PoisonAttr::get(context);
2096 }
2097 
2098 /// Fold a vector extract whose source is a poison value.
2099 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2100  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2101  return srcAttr;
2102 
2103  return {};
2104 }
2105 
2106 /// Fold a vector extract extracting from a DenseElementsAttr.
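/// For example (illustrative):
///   %cst = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
///   %e = vector.extract %cst[1] : vector<2xi32> from vector<2x2xi32>
/// folds to the constant dense<[2, 3]> : vector<2xi32>.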
2107 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2108  Attribute srcAttr) {
2109  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2110  if (!denseAttr) {
2111  return {};
2112  }
2113 
2114  if (denseAttr.isSplat()) {
2115  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2116  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2117  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2118  return newAttr;
2119  }
2120 
2121  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2122  if (vecTy.isScalable())
2123  return {};
2124 
2125  if (extractOp.hasDynamicPosition()) {
2126  return {};
2127  }
2128 
2129  // Materializing subsets of a large constant array can generally lead to
2130  // explosion in IR size because of the different combinations of subsets that
2131  // can exist. However, vector.extract is a restricted form of subset
2132  // extract where you can only extract non-overlapping (or the same) subset for
2133  // a given rank of the subset. Because of this property, the IR size can only
2134  // increase at most by `rank * size(array)` from a single constant array being
2135  // extracted by multiple extracts.
2136 
2137  // Calculate the linearized position of the continuous chunk of elements to
2138  // extract.
2139  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2140  copy(extractOp.getStaticPosition(), completePositions.begin());
2141  int64_t startPos =
2142  linearize(completePositions, computeStrides(vecTy.getShape()));
2143  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2144 
2145  TypedAttr newAttr;
2146  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2147  SmallVector<Attribute> elementValues(
2148  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2149  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2150  } else {
2151  newAttr = *denseValuesBegin;
2152  }
2153 
2154  return newAttr;
2155 }
2156 
2157 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2158  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2159  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2160  // mismatch).
2161  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2162  return getSource();
2163  if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2164  return res;
2165  // Fold `arith.constant` indices into the `vector.extract` operation.
2166  // Do not stop here as this fold may enable subsequent folds that require
2167  // constant indices.
2168  SmallVector<Value> operands = {getSource()};
2169  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2170 
2171  if (auto res = foldPoisonIndexInsertExtractOp(
2172  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2173  return res;
2174  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2175  return res;
2176  if (succeeded(foldExtractOpFromExtractChain(*this)))
2177  return getResult();
2178  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2179  return res;
2180  if (auto res = foldExtractFromBroadcast(*this))
2181  return res;
2182  if (auto res = foldExtractFromShuffle(*this))
2183  return res;
2184  if (auto res = foldExtractFromShapeCast(*this))
2185  return res;
2186  if (auto val = foldExtractFromExtractStrided(*this))
2187  return val;
2188  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2189  return val;
2190  if (auto val = foldScalarExtractFromFromElements(*this))
2191  return val;
2192 
2193  return inplaceFolded;
2194 }
2195 
2196 namespace {
2197 
2198 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
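// For example (illustrative):
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// is rewritten to
//   %e = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>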
2199 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2200 public:
2201  using Base::Base;
2202 
2203  LogicalResult matchAndRewrite(ExtractOp extractOp,
2204  PatternRewriter &rewriter) const override {
2205 
2206  Operation *defOp = extractOp.getSource().getDefiningOp();
2207  VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2208  if (!defOp || !isBroadcastLike(defOp) || !outType)
2209  return failure();
2210 
2211  Value source = defOp->getOperand(0);
2212  if (isBroadcastableTo(source.getType(), outType) !=
2213  BroadcastableToResult::Success)
2214  return failure();
2215 
2216  rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2217  return success();
2218  }
2219 };
2220 
2221 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
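// For example (illustrative), where %c2 and %c3 are arith.constant index values
// and %d is a dynamic value:
//   %m = vector.create_mask %c2, %c3, %d : vector<4x5x6xi1>
//   %e = vector.extract %m[1] : vector<5x6xi1> from vector<4x5x6xi1>
// is rewritten to
//   %e = vector.create_mask %c3, %d : vector<5x6xi1>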
2222 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2223 public:
2224  using Base::Base;
2225 
2226  LogicalResult matchAndRewrite(ExtractOp extractOp,
2227  PatternRewriter &rewriter) const override {
2228  auto createMaskOp =
2229  extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2230  if (!createMaskOp)
2231  return failure();
2232 
2233  VectorType extractedMaskType =
2234  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2235 
2236  if (!extractedMaskType)
2237  return failure();
2238 
2239  auto maskOperands = createMaskOp.getOperands();
2240  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2241  VectorType maskType = createMaskOp.getVectorType();
2242 
2243  bool containsUnknownDims = false;
2244  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2245 
2246  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2247  dimIdx++) {
2248  int64_t pos = extractOpPos[dimIdx];
2249  Value operand = maskOperands[dimIdx];
2250  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2251  if (!constantOp) {
2252  // Bounds of this dim unknown.
2253  containsUnknownDims = true;
2254  continue;
2255  }
2256 
2257  int64_t createMaskBound =
2258  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2259 
2260  if (pos != ShapedType::kDynamic) {
2261  // If any position is outside the range from the `create_mask`, then the
2262  // extracted mask will be all-false.
2263  allFalse |= pos >= createMaskBound;
2264  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2265  // This dim is not all-true and since this is a dynamic index we don't
2266  // know if the extraction is within the true or false region.
2267  // Note: Zero dims have already been handled via getMaskFormat().
2268  containsUnknownDims = true;
2269  }
2270  }
2271 
2272  if (allFalse) {
2273  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2274  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2275  } else if (!containsUnknownDims) {
2276  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2277  extractOp, extractedMaskType,
2278  maskOperands.drop_front(extractOpPos.size()));
2279  } else {
2280  return failure();
2281  }
2282  return success();
2283  }
2284 };
2285 
2286 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2287 // does not change.
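// For example (illustrative):
//   %sc = vector.shape_cast %src : vector<2x2xf32> to vector<1x4xf32>
//   %e = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
// is rewritten to
//   %e = vector.shape_cast %src : vector<2x2xf32> to vector<4xf32>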
2288 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2289  PatternRewriter &rewriter) {
2290  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2291  if (!castOp)
2292  return failure();
2293 
2294  VectorType sourceType = castOp.getSourceVectorType();
2295  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2296  if (!targetType)
2297  return failure();
2298 
2299  if (sourceType.getNumElements() != targetType.getNumElements())
2300  return failure();
2301 
2302  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2303  castOp.getSource());
2304  return success();
2305 }
2306 
2307 /// Try to canonicalize the extraction of a subvector from a vector defined by
2308 /// vector.from_elements. E.g.:
2309 ///
2310 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2311 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2312 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2313 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2314  PatternRewriter &rewriter) {
2315  // Dynamic positions are not supported.
2316  if (extractOp.hasDynamicPosition())
2317  return failure();
2318 
2319  // Scalar extracts are handled by the folder.
2320  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2321  if (!resultType)
2322  return failure();
2323 
2324  // Look for extracts from a from_elements op.
2325  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2326  if (!fromElementsOp)
2327  return failure();
2328  VectorType inputType = fromElementsOp.getType();
2329 
2330  // Scalable vectors are not supported.
2331  if (resultType.isScalable() || inputType.isScalable())
2332  return failure();
2333 
2334  // Compute the position of first extracted element and flatten/linearize the
2335  // position.
2336  SmallVector<int64_t> firstElementPos =
2337  llvm::to_vector(extractOp.getStaticPosition());
2338  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2339  int flatIndex = 0;
2340  int stride = 1;
2341  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2342  flatIndex += firstElementPos[i] * stride;
2343  stride *= inputType.getDimSize(i);
2344  }
2345 
2346  // Replace the op with a smaller from_elements op.
2347  rewriter.replaceOpWithNewOp<FromElementsOp>(
2348  extractOp, resultType,
2349  fromElementsOp.getElements().slice(flatIndex,
2350  resultType.getNumElements()));
2351  return success();
2352 }
2353 
2354 } // namespace
2355 
2356 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2357  MLIRContext *context) {
2358  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2359  results.add(foldExtractFromShapeCastToShapeCast);
2360  results.add(foldExtractFromFromElements);
2361 }
2362 
2363 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2364  SmallVectorImpl<int64_t> &results) {
2365  for (auto attr : arrayAttr)
2366  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2367 }
2368 
2369 //===----------------------------------------------------------------------===//
2370 // FmaOp
2371 //===----------------------------------------------------------------------===//
2372 
2373 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2374  return llvm::to_vector<4>(getVectorType().getShape());
2375 }
2376 
2377 //===----------------------------------------------------------------------===//
2378 // ToElementsOp
2379 //===----------------------------------------------------------------------===//
2380 
2381 /// Returns true if all the `operands` are defined by `defOp`.
2382 /// Otherwise, returns false.
2383 static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2384  if (operands.empty())
2385  return false;
2386 
2387  return llvm::all_of(operands, [&](Value operand) {
2388  Operation *currentDef = operand.getDefiningOp();
2389  return currentDef == defOp;
2390  });
2391 }
2392 
2393 /// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2394 /// (%e0, %e1, ...). For example:
2395 ///
2396 /// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2397 /// %1:3 = vector.to_elements %0 : vector<3xf32>
2398 /// user_op %1#0, %1#1, %1#2
2399 ///
2400 /// becomes:
2401 ///
2402 /// user_op %a, %b, %c
2403 ///
2404 static LogicalResult
2405 foldToElementsFromElements(ToElementsOp toElementsOp,
2406  SmallVectorImpl<OpFoldResult> &results) {
2407  auto fromElementsOp =
2408  toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2409  if (!fromElementsOp)
2410  return failure();
2411 
2412  llvm::append_range(results, fromElementsOp.getElements());
2413  return success();
2414 }
2415 
2416 /// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2417 ///
2418 /// Example:
2419 /// %b = vector.broadcast %x : f32 to vector<3xf32>
2420 /// %e:3 = vector.to_elements %b : vector<3xf32>
2421 /// user_op %e#0, %e#1, %e#2
2422 /// becomes:
2423 /// user_op %x, %x, %x
2424 ///
2425 /// The vector source case is handled by a canonicalization pattern.
2426 static LogicalResult
2427 foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2428  SmallVectorImpl<OpFoldResult> &results) {
2429  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2430  if (!bcastOp)
2431  return failure();
2432  // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2433  if (isa<VectorType>(bcastOp.getSource().getType()))
2434  return failure();
2435 
2436  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2437 
2438  Value scalar = bcastOp.getSource();
2439  results.assign(resultVecType.getNumElements(), scalar);
2440  return success();
2441 }
2442 
2443 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2444  SmallVectorImpl<OpFoldResult> &results) {
2445  if (succeeded(foldToElementsFromElements(*this, results)))
2446  return success();
2447  return foldToElementsOfBroadcast(*this, results);
2448 }
2449 
2450 LogicalResult
2451 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2452  ToElementsOp::Adaptor adaptor,
2453  SmallVectorImpl<Type> &inferredReturnTypes) {
2454  auto vecType = cast<VectorType>(adaptor.getSource().getType());
2455  Type elType = vecType.getElementType();
2456  inferredReturnTypes.append(vecType.getNumElements(), elType);
2457  return success();
2458 }
2459 
2460 /// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2461 /// vector.
2462 /// - Build `vector.to_elements %v` and remap each destination element to the
2463 /// corresponding source element using broadcast rules (match or 1 →
2464 /// replicate).
2465 ///
2466 /// Example:
2467 /// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2468 /// %e:6 = vector.to_elements %v : vector<3x2xf32>
2469 /// becomes:
2470 /// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2471 /// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2472 /// // %src_elems#1, %src_elems#0, %src_elems#1
2473 struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2474  using Base::Base;
2475 
2476  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2477  PatternRewriter &rewriter) const override {
2478  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2479  if (!bcastOp)
2480  return failure();
2481 
2482  // Only handle broadcasts from a vector source here.
2483  auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2484  if (!srcType)
2485  return failure();
2486 
2487  auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2488 
2489  ArrayRef<int64_t> dstShape = dstType.getShape();
2490  ArrayRef<int64_t> srcShape = srcType.getShape();
2491 
2492  int64_t dstRank = dstShape.size();
2493  int64_t srcRank = srcShape.size();
2494 
2495  // Create elements for the broadcast source vector.
2496  auto srcElems = vector::ToElementsOp::create(
2497  rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2498 
2499  int64_t dstCount = llvm::product_of(dstShape);
2500 
2501  SmallVector<Value> replacements;
2502  replacements.reserve(dstCount);
2503 
2504  // For each element of the destination, determine which element of the
2505  // source should be used. We walk all destination positions using a single
2506  // counter, decode it into per-dimension indices, then build the matching
2507  // source position: use the same index where sizes match, and use 0 where
2508  // the source size is 1 (replication). This mapping is needed so we can
2509  // replace each result of to_elements with the corresponding element from
2510  // the broadcast source.
2511  // Inner-dimension stretch example:
2512  // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2513  // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2514  // becomes:
2515  // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2516  // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2517  // // %src_elems#1, %src_elems#0, %src_elems#1,
2518  // // %src_elems#2, %src_elems#3, %src_elems#2,
2519  // // %src_elems#3, %src_elems#2, %src_elems#3
2520 
2521  // Row-major strides for the destination shape.
2522  SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2523  // Row-major strides for the source shape.
2524  SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2525  SmallVector<int64_t> dstIdx(dstRank);
2526  SmallVector<int64_t> srcIdx(srcRank);
2527  for (int64_t lin = 0; lin < dstCount; ++lin) {
2528  // Convert linear destination index to per-dimension indices.
2529  dstIdx = delinearize(lin, dstStrides);
2530  for (int64_t k = 0; k < srcRank; ++k)
2531  srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2532  // Convert per-dimension source indices back to a linear index.
2533  int64_t srcLin = linearize(srcIdx, srcStrides);
2534  replacements.push_back(srcElems.getResult(srcLin));
2535  }
2536 
2537  rewriter.replaceOp(toElementsOp, replacements);
2538  return success();
2539  }
2540 };
2541 
2542 void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2543  MLIRContext *context) {
2544  results.add<ToElementsOfBroadcast>(context);
2545 }
2546 
2547 //===----------------------------------------------------------------------===//
2548 // FromElementsOp
2549 //===----------------------------------------------------------------------===//
2550 
2551 /// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2552 ///
2553 /// Case #1: Input and output vectors are the same.
2554 ///
2555 /// %0:3 = vector.to_elements %a : vector<3xf32>
2556 /// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2557 /// user_op %1
2558 ///
2559 /// becomes:
2560 ///
2561 /// user_op %a
2562 ///
2563 static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2564  OperandRange fromElemsOperands = fromElementsOp.getElements();
2565  if (fromElemsOperands.empty())
2566  return {};
2567 
2568  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2569  if (!toElementsOp)
2570  return {};
2571 
2572  if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2573  return {};
2574 
2575  // Case #1: Input and output vectors are the same. Forward the input vector.
2576  Value toElementsInput = toElementsOp.getSource();
2577  if (fromElementsOp.getType() == toElementsInput.getType() &&
2578  llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2579  return toElementsInput;
2580  }
2581 
2582  // TODO: Support cases with different input and output shapes and different
2583  // number of elements.
2584 
2585  return {};
2586 }
2587 
2588 /// Fold vector.from_elements to a constant when all operands are constants.
2589 /// Example:
2590 /// %c1 = arith.constant 1 : i32
2591 /// %c2 = arith.constant 2 : i32
2592 /// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2593 /// =>
2594 /// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2595 ///
2596 static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2597  ArrayRef<Attribute> elements) {
2598  // Check for null or poison attributes before any processing.
2599  if (llvm::any_of(elements, [](Attribute attr) {
2600  return !attr || isa<ub::PoisonAttrInterface>(attr);
2601  }))
2602  return {};
2603 
2604  // DenseElementsAttr only supports int/index/float/complex types.
2605  auto destVecType = fromElementsOp.getDest().getType();
2606  auto destEltType = destVecType.getElementType();
2607  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2608  return {};
2609 
2610  // Constant attributes might have a different type than the return type.
2611  // Convert them before creating the dense elements attribute.
2612  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2613  return convertNumericAttr(attr, destEltType);
2614  });
2615 
2616  return DenseElementsAttr::get(destVecType, convertedElements);
2617 }
2618 
2619 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2620  if (auto res = foldFromElementsToElements(*this))
2621  return res;
2622  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2623  return res;
2624 
2625  return {};
2626 }
2627 
2628 /// Rewrite vector.from_elements as vector.broadcast if the elements are the
2629 /// same. Example:
2630 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2631 /// =>
2632 /// %0 = vector.broadcast %a : f32 to vector<3xf32>
2633 static LogicalResult
2634 rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2635  PatternRewriter &rewriter) {
2636  if (!llvm::all_equal(fromElementsOp.getElements()))
2637  return failure();
2638  rewriter.replaceOpWithNewOp<BroadcastOp>(
2639  fromElementsOp, fromElementsOp.getType(),
2640  fromElementsOp.getElements().front());
2641  return success();
2642 }
2643 
2644 /// Rewrite from_elements on multiple scalar extracts as a shape_cast
2645 /// on a single extract. Example:
2646 /// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2647 /// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2648 /// %2 = vector.from_elements %0, %1 : vector<2xi8>
2649 ///
2650 /// becomes
2651 /// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2652 /// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2653 ///
2654 /// The requirements for this to be valid are
2655 ///
2656 /// i) The elements are extracted from the same vector (%source).
2657 ///
2658 /// ii) The elements form a suffix of %source. Specifically, the number
2659 /// of elements is the same as the product of the last N dimension sizes
2660 /// of %source, for some N.
2661 ///
2662 /// iii) The elements are extracted contiguously in ascending order.
2663 
2664 class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2665 
2666  using Base::Base;
2667 
2668  LogicalResult matchAndRewrite(FromElementsOp fromElements,
2669  PatternRewriter &rewriter) const override {
2670 
2671  // Handled by `rewriteFromElementsAsBroadcast`.
2672  if (fromElements.getType().getNumElements() == 1)
2673  return failure();
2674 
2675  // The common source that all elements are extracted from, if one exists.
2676  TypedValue<VectorType> source;
2677  // The position of the combined extract operation, if one is created.
2678  ArrayRef<int64_t> combinedPosition;
2679  // The expected index of extraction of the current element in the loop, if
2680  // elements are extracted contiguously in ascending order.
2681  SmallVector<int64_t> expectedPosition;
2682 
2683  for (auto [insertIndex, element] :
2684  llvm::enumerate(fromElements.getElements())) {
2685 
2686  // Check that the element is from a vector.extract operation.
2687  auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2688  if (!extractOp) {
2689  return rewriter.notifyMatchFailure(fromElements,
2690  "element not from vector.extract");
2691  }
2692 
2693  // Check condition (i) by checking that all elements have the same source
2694  // as the first element.
2695  if (insertIndex == 0) {
2696  source = extractOp.getSource();
2697  } else if (extractOp.getSource() != source) {
2698  return rewriter.notifyMatchFailure(fromElements,
2699  "element from different vector");
2700  }
2701 
2702  ArrayRef<int64_t> position = extractOp.getStaticPosition();
2703  int64_t rank = position.size();
2704  assert(rank == source.getType().getRank() &&
2705  "scalar extract must have full rank position");
2706 
2707  // Check condition (ii) by checking that the position that the first
2708  // element is extracted from has sufficient trailing 0s. For example, in
2709  //
2710  // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2711  // [...]
2712  // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2713  //
2714  // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2715  // elements, which is the number of elements of %elms, so this is valid.
2716  if (insertIndex == 0) {
2717  const int64_t numElms = fromElements.getType().getNumElements();
2718  int64_t numSuffixElms = 1;
2719  int64_t index = rank;
2720  while (index > 0 && position[index - 1] == 0 &&
2721  numSuffixElms < numElms) {
2722  numSuffixElms *= source.getType().getDimSize(index - 1);
2723  --index;
2724  }
2725  if (numSuffixElms != numElms) {
2726  return rewriter.notifyMatchFailure(
2727  fromElements, "elements do not form a suffix of source");
2728  }
2729  expectedPosition = llvm::to_vector(position);
2730  combinedPosition = position.drop_back(rank - index);
2731  }
2732 
2733  // Check condition (iii).
2734  else if (expectedPosition != position) {
2735  return rewriter.notifyMatchFailure(
2736  fromElements, "elements not in ascending order (static order)");
2737  }
2738  increment(expectedPosition, source.getType().getShape());
2739  }
2740 
2741  auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2742  fromElements.getLoc(), source, combinedPosition);
2743 
2744  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2745  fromElements, fromElements.getType(), extracted);
2746 
2747  return success();
2748  }
2749 
2750  /// Increments n-D `indices` by 1 starting from the innermost dimension.
2751  static void increment(MutableArrayRef<int64_t> indices,
2752  ArrayRef<int64_t> shape) {
2753  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2754  indices[dim] += 1;
2755  if (indices[dim] < shape[dim])
2756  break;
2757  indices[dim] = 0;
2758  }
2759  }
2760 };
2761 
2762 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2763  MLIRContext *context) {
2764  results.add(rewriteFromElementsAsBroadcast);
2765  results.add<FromElementsToShapeCast>(context);
2766 }
2767 
2768 //===----------------------------------------------------------------------===//
2769 // BroadcastOp
2770 //===----------------------------------------------------------------------===//
2771 
2772 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2773  SetIntRangeFn setResultRanges) {
2774  setResultRanges(getResult(), argRanges.front());
2775 }
2776 
2777 std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2778  return llvm::to_vector<4>(getResultVectorType().getShape());
2779 }
2780 
2781 /// Return the dimensions of the result vector that were formerly ones in the
2782 /// source vector and thus correspond to "dim-1" broadcasting.
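/// For example (illustrative): broadcasting vector<1x4xf32> to vector<2x3x4xf32>
/// returns {1}, since result dimension 1 stems from a unit ("dim-1") source dim.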
2783 static llvm::SetVector<int64_t>
2784 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2785  ArrayRef<int64_t> dstShape) {
2786  int64_t rankDiff = dstShape.size() - srcShape.size();
2787  int64_t dstDim = rankDiff;
2788  llvm::SetVector<int64_t> res;
2789  for (auto [s1, s2] :
2790  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2791  if (s1 != s2) {
2792  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2793  res.insert(dstDim);
2794  }
2795  ++dstDim;
2796  }
2797  return res;
2798 }
2799 
2800 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2801  // Scalar broadcast is without any unit dim broadcast.
2802  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2803  if (!srcVectorType)
2804  return {};
2805  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2806  getResultVectorType().getShape());
2807 }
2808 
2809 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2810 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2811 /// This requires (and asserts) that the broadcast is free of "dim-1"
2812 /// broadcasting.
2813 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2814 /// vector.transpose may be inserted to make the broadcast possible.
2815 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2816 /// the helper will assert. This means:
2817 /// 1. `dstShape` must not be empty.
2818 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2819 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2820 ///    must match the `value` shape.
2821 Value BroadcastOp::createOrFoldBroadcastOp(
2822  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2823  const llvm::SetVector<int64_t> &broadcastedDims) {
2824  assert(!dstShape.empty() && "unexpected empty dst shape");
2825 
2826  // Well-formedness check.
2827  SmallVector<int64_t> checkShape;
2828  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2829  if (broadcastedDims.contains(i))
2830  continue;
2831  checkShape.push_back(dstShape[i]);
2832  }
2833  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2834  "ill-formed broadcastedDims contains values not confined to "
2835  "destVectorShape");
2836 
2837  Location loc = value.getLoc();
2838  Type elementType = getElementTypeOrSelf(value.getType());
2839  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2840  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2841 
2842  // Step 2. If scalar -> dstShape broadcast, just do it.
2843  if (!srcVectorType) {
2844  assert(checkShape.empty() &&
2845  "ill-formed createOrFoldBroadcastOp arguments");
2846  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2847  }
2848 
2849  assert(srcVectorType.getShape().equals(checkShape) &&
2850  "ill-formed createOrFoldBroadcastOp arguments");
2851 
2852  // Step 3. Since vector.broadcast only allows creating leading dims,
2853  // vector -> dstShape broadcast may require a transpose.
2854  // Traverse the dims in order and construct:
2855  // 1. The leading entries of the broadcastShape that is guaranteed to be
2856  // achievable by a simple broadcast.
2857  // 2. The induced permutation for the subsequent vector.transpose that will
2858  // bring us from `broadcastShape` back to the desired `dstShape`.
2859  // If the induced permutation is not the identity, create a vector.transpose.
2860  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2861  broadcastShape.reserve(dstShape.size());
2862  // Consider the example:
2863  // srcShape = 2x4
2864  // dstShape = 1x2x3x4x5
2865  // broadcastedDims = [0, 2, 4]
2866  //
2867  // We want to build:
2868  // broadcastShape = 1x3x5x2x4
2869  // permutation = [0, 2, 4, 1, 3]
2870  // ---V--- -----V-----
2871  // leading broadcast part src shape part
2872  //
2873  // Note that the trailing dims of broadcastShape are exactly the srcShape
2874  // by construction.
2875  // nextSrcShapeDim is used to keep track of where in the permutation the
2876  // "src shape part" occurs.
2877  int64_t nextSrcShapeDim = broadcastedDims.size();
2878  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2879  if (broadcastedDims.contains(i)) {
2880  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2881  // bring it to the head of the broadcastShape.
2882  // It will need to be permuted back from `broadcastShape.size() - 1` into
2883  // position `i`.
2884  broadcastShape.push_back(dstShape[i]);
2885  permutation[i] = broadcastShape.size() - 1;
2886  } else {
2887  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2888  // shape and needs to be permuted into position `i`.
2889  // Don't touch `broadcastShape` here, the whole srcShape will be
2890  // appended after.
2891  permutation[i] = nextSrcShapeDim++;
2892  }
2893  }
2894  // 3.c. Append the srcShape.
2895  llvm::append_range(broadcastShape, srcVectorType.getShape());
2896 
2897  // Ensure there are no "dim-1" broadcasts.
2898  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2899  .empty() &&
2900  "unexpected \"dim-1\" broadcast");
2901 
2902  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2903  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2904  vector::BroadcastableToResult::Success &&
2905  "must be broadcastable");
2906  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2907  // Step 4. If we find any dimension that indeed needs to be permuted,
2908  // immediately return a new vector.transpose.
2909  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2910  if (permutation[i] != i)
2911  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2912  // Otherwise return res.
2913  return res;
2914 }
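// As an illustrative sketch of what the helper above emits for the worked
// example in its body (srcShape = 2x4, dstShape = 1x2x3x4x5,
// broadcastedDims = [0, 2, 4]); the value names and f32 element type are
// hypothetical:
//   %b = vector.broadcast %src : vector<2x4xf32> to vector<1x3x5x2x4xf32>
//   %r = vector.transpose %b, [0, 3, 1, 4, 2]
//            : vector<1x3x5x2x4xf32> to vector<1x2x3x4x5xf32>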
2915 
2916 BroadcastableToResult mlir::vector::isBroadcastableTo(
2917  Type srcType, VectorType dstVectorType,
2918  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2919  // Broadcast scalar to vector of the same element type.
2920  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2921  srcType == getElementTypeOrSelf(dstVectorType))
2922  return BroadcastableToResult::Success;
2923  // From now on, only vectors broadcast.
2924  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2925  if (!srcVectorType)
2926  return BroadcastableToResult::SourceTypeNotAVector;
2927 
2928  int64_t srcRank = srcVectorType.getRank();
2929  int64_t dstRank = dstVectorType.getRank();
2930  if (srcRank > dstRank)
2931  return BroadcastableToResult::SourceRankHigher;
2932  // Source has an exact match or singleton value for all trailing dimensions
2933  // (all leading dimensions are simply duplicated).
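  // For example (illustrative): vector<1x4xf32> is broadcastable to
  // vector<3x2x4xf32> (the trailing unit dim stretches to 2 and a leading 3 is
  // duplicated), while vector<2xf32> is not broadcastable to vector<4xf32>.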
2934  int64_t lead = dstRank - srcRank;
2935  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2936  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2937  // encountered?
2938  bool foundMismatchingDims = false;
2939 
2940  // Check fixed-width dims.
2941  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2942  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2943  if (srcDim != 1 && srcDim != dstDim)
2944  foundMismatchingDims = true;
2945 
2946  // Check scalable flags.
2947  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2948  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2949  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2950  // 1 -> [N] is fine, everything else should be rejected when mixing
2951  // fixed-width and scalable dims
2952  (srcDimScalableFlag != dstDimScalableFlag &&
2953  (srcDim != 1 || srcDimScalableFlag)))
2954  foundMismatchingDims = true;
2955 
2956  if (foundMismatchingDims) {
2957  if (mismatchingDims != nullptr) {
2958  mismatchingDims->first.dim = srcDim;
2959  mismatchingDims->first.isScalable = srcDimScalableFlag;
2960 
2961  mismatchingDims->second.dim = dstDim;
2962  mismatchingDims->second.isScalable = dstDimScalableFlag;
2963  }
2964  return BroadcastableToResult::DimensionMismatch;
2965  }
2966  }
2967 
2968  return BroadcastableToResult::Success;
2969 }
2970 
2971 LogicalResult BroadcastOp::verify() {
2972  std::pair<VectorDim, VectorDim> mismatchingDims;
2973  BroadcastableToResult res = isBroadcastableTo(
2974  getSourceType(), getResultVectorType(), &mismatchingDims);
2975  if (res == BroadcastableToResult::Success)
2976  return success();
2977  if (res == BroadcastableToResult::SourceRankHigher)
2978  return emitOpError("source rank higher than destination rank");
2979  if (res == BroadcastableToResult::DimensionMismatch) {
2980  return emitOpError("dimension mismatch (")
2981  << (mismatchingDims.first.isScalable ? "[" : "")
2982  << mismatchingDims.first.dim
2983  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2984  << (mismatchingDims.second.isScalable ? "[" : "")
2985  << mismatchingDims.second.dim
2986  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2987  }
2988  if (res == BroadcastableToResult::SourceTypeNotAVector)
2989  return emitOpError("source type is not a vector");
2990  llvm_unreachable("unexpected vector.broadcast op error");
2991 }
2992 
2993 // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2994 // with broadcast's result type and shape_cast only adds or removes ones in the
2995 // leading dimensions.
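// For example (an illustrative sketch; value names are hypothetical):
//   %0 = vector.shape_cast %x : vector<4xf32> to vector<1x4xf32>
//   %1 = vector.broadcast %0 : vector<1x4xf32> to vector<2x4xf32>
// is folded to
//   %1 = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>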
2996 static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
2997  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2998  if (!srcShapeCast)
2999  return failure();
3000 
3001  VectorType srcType = srcShapeCast.getSourceVectorType();
3002  VectorType destType = broadcastOp.getResultVectorType();
3003  // Check type compatibility.
3004  if (vector::isBroadcastableTo(srcType, destType) !=
3005  BroadcastableToResult::Success)
3006  return failure();
3007 
3008  ArrayRef<int64_t> srcShape = srcType.getShape();
3009  ArrayRef<int64_t> shapecastShape =
3010  srcShapeCast.getResultVectorType().getShape();
3011  // Trailing dimensions should be the same if shape_cast only alters the
3012  // leading dimensions.
3013  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3014  if (!llvm::equal(srcShape.take_back(numTrailingDims),
3015  shapecastShape.take_back(numTrailingDims)))
3016  return failure();
3017 
3018  assert(all_of(srcShape.drop_back(numTrailingDims),
3019  [](int64_t E) { return E == 1; }) &&
3020  all_of(shapecastShape.drop_back(numTrailingDims),
3021  [](int64_t E) { return E == 1; }) &&
3022  "ill-formed shape_cast");
3023 
3024  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3025  return success();
3026 }
3027 
3028 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3029  if (getSourceType() == getResultVectorType())
3030  return getSource();
3031  if (succeeded(foldBroadcastOfShapeCast(*this)))
3032  return getResult();
3033 
3034  if (!adaptor.getSource())
3035  return {};
3036  auto vectorType = getResultVectorType();
3037  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3038  if (vectorType.getElementType() != attr.getType())
3039  return {};
3040  return DenseElementsAttr::get(vectorType, attr);
3041  }
3042  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3043  if (vectorType.getElementType() != attr.getType())
3044  return {};
3045  return DenseElementsAttr::get(vectorType, attr);
3046  }
3047  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3048  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3049  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
3050  return ub::PoisonAttr::get(getContext());
3051  return {};
3052 }
3053 
3054 namespace {
3055 
3056 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
3057 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3058  using Base::Base;
3059 
3060  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3061  PatternRewriter &rewriter) const override {
3062  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3063  if (!srcBroadcast)
3064  return failure();
3065  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3066  broadcastOp.getResultVectorType(),
3067  srcBroadcast.getSource());
3068  return success();
3069  }
3070 };
3071 } // namespace
3072 
3073 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3074  MLIRContext *context) {
3075  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
3076  // calling `populateCastAwayVectorLeadingOneDimPatterns`
3077  results.add<BroadcastFolder>(context);
3078 }
3079 
3080 //===----------------------------------------------------------------------===//
3081 // ShuffleOp
3082 //===----------------------------------------------------------------------===//
3083 
3084 LogicalResult ShuffleOp::verify() {
3085  VectorType resultType = getResultVectorType();
3086  VectorType v1Type = getV1VectorType();
3087  VectorType v2Type = getV2VectorType();
3088  // Verify ranks.
3089  int64_t resRank = resultType.getRank();
3090  int64_t v1Rank = v1Type.getRank();
3091  int64_t v2Rank = v2Type.getRank();
3092  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3093  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3094  if (!wellFormed0DCase && !wellFormedNDCase)
3095  return emitOpError("rank mismatch");
3096 
3097  // Verify all but leading dimension sizes.
3098  for (int64_t r = 1; r < v1Rank; ++r) {
3099  int64_t resDim = resultType.getDimSize(r);
3100  int64_t v1Dim = v1Type.getDimSize(r);
3101  int64_t v2Dim = v2Type.getDimSize(r);
3102  if (resDim != v1Dim || v1Dim != v2Dim)
3103  return emitOpError("dimension mismatch");
3104  }
3105  // Verify mask length.
3106  ArrayRef<int64_t> mask = getMask();
3107  int64_t maskLength = mask.size();
3108  if (maskLength <= 0)
3109  return emitOpError("invalid mask length");
3110  if (maskLength != resultType.getDimSize(0))
3111  return emitOpError("mask length mismatch");
3112  // Verify all indices.
3113  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3114  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3115  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3116  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3117  return emitOpError("mask index #") << (idx + 1) << " out of range";
3118  }
3119  return success();
3120 }
3121 
3122 LogicalResult
3123 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
3124  ShuffleOp::Adaptor adaptor,
3125  SmallVectorImpl<Type> &inferredReturnTypes) {
3126  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
3127  auto v1Rank = v1Type.getRank();
3128  // Construct resulting type: leading dimension matches mask
3129  // length, all trailing dimensions match the operands.
3130  SmallVector<int64_t, 4> shape;
3131  shape.reserve(v1Rank);
3132  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3133  // In the 0-D case there is no trailing shape to append.
3134  if (v1Rank > 0)
3135  llvm::append_range(shape, v1Type.getShape().drop_front());
3136  inferredReturnTypes.push_back(
3137  VectorType::get(shape, v1Type.getElementType()));
3138  return success();
3139 }
3140 
3141 template <typename T>
3142 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3143  T expected = begin;
3144  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3145  return value == expected++;
3146  });
3147 }
3148 
3149 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3150  auto v1Type = getV1VectorType();
3151  auto v2Type = getV2VectorType();
3152 
3153  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3154  "Vector shuffle does not support scalable vectors");
3155 
3156  // For consistency, the 0-D shuffle's return type is 1-D, so this cannot be
3157  // a fold; it is handled as a canonicalization into a vector.broadcast.
3158  if (v1Type.getRank() == 0)
3159  return {};
3160 
3161  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3162  auto mask = getMask();
3163  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3164  return getV1();
3165  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3166  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3167  return getV2();
3168 
3169  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3170  if (!v1Attr || !v2Attr)
3171  return {};
3172 
3173  // Fold shuffle poison, poison -> poison.
3174  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3175  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3176  if (isV1Poison && isV2Poison)
3177  return ub::PoisonAttr::get(getContext());
3178 
3179  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3180  // manipulation.
3181  if (v1Type.getRank() != 1)
3182  return {};
3183 
3184  // Poison input attributes need special handling as they are not
3185  // DenseElementsAttr. If an index is poison, we select the first element of
3186  // the first non-poison input.
3187  SmallVector<Attribute> v1Elements, v2Elements;
3188  Attribute poisonElement;
3189  if (!isV2Poison) {
3190  auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3191  if (!v2DenseAttr)
3192  return {};
3193  v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3194  poisonElement = v2Elements[0];
3195  }
3196  if (!isV1Poison) {
3197  auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3198  if (!v1DenseAttr)
3199  return {};
3200  v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3201  poisonElement = v1Elements[0];
3202  }
3203 
3204  SmallVector<Attribute> results;
3205  int64_t v1Size = v1Type.getDimSize(0);
3206  for (int64_t maskIdx : mask) {
3207  Attribute indexedElm;
3208  // TODO: Return a partial poison vector when supported by the UB dialect.
3209  if (maskIdx == ShuffleOp::kPoisonIndex) {
3210  indexedElm = poisonElement;
3211  } else {
3212  if (maskIdx < v1Size)
3213  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3214  else
3215  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3216  }
3217 
3218  results.push_back(indexedElm);
3219  }
3220 
3221  return DenseElementsAttr::get(getResultVectorType(), results);
3222 }
3223 
3224 namespace {
3225 
3226 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3227 // to a broadcast.
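// For example (an illustrative sketch; value names are hypothetical):
//   vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
// yields a vector<1xf32> and is rewritten to
//   vector.broadcast %a : vector<f32> to vector<1xf32>
// while mask [1] selects %b instead.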
3228 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3229  using Base::Base;
3230 
3231  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3232  PatternRewriter &rewriter) const override {
3233  VectorType v1VectorType = shuffleOp.getV1VectorType();
3234  ArrayRef<int64_t> mask = shuffleOp.getMask();
3235  if (v1VectorType.getRank() > 0)
3236  return failure();
3237  if (mask.size() != 1)
3238  return failure();
3239  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3240  if (mask[0] == 0)
3241  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3242  shuffleOp.getV1());
3243  else
3244  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3245  shuffleOp.getV2());
3246  return success();
3247  }
3248 };
3249 
3250 /// Consider the defining operation `defOp` of `value`. If `defOp` is a
3251 /// vector.broadcast with a scalar operand, return the scalar value that is
3252 /// splatted. Otherwise return null.
3253 ///
3254 /// Example:
3255 ///
3256 /// scalar_source --> vector.broadcast --> value - return scalar_source
3257 static Value getScalarSplatSource(Value value) {
3258  // Block argument:
3259  Operation *defOp = value.getDefiningOp();
3260  if (!defOp)
3261  return {};
3262 
3263  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3264 
3265  // Not broadcast (and not splat):
3266  if (!broadcast)
3267  return {};
3268 
3269  // Broadcast of a vector:
3270  if (isa<VectorType>(broadcast.getSourceType()))
3271  return {};
3272 
3273  // Broadcast of a scalar:
3274  return broadcast.getSource();
3275 }
3276 
3277 /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3278 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3279 public:
3280  using Base::Base;
3281 
3282  LogicalResult matchAndRewrite(ShuffleOp op,
3283  PatternRewriter &rewriter) const override {
3284  Value splat = getScalarSplatSource(op.getV1());
3285  if (!splat || getScalarSplatSource(op.getV2()) != splat)
3286  return failure();
3287 
3288  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3289  return success();
3290  }
3291 };
3292 
3293 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3294 /// vector.interleave.
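/// For example (illustrative): with two vector<4xi32> operands, the mask
/// [0, 4, 1, 5, 2, 6, 3, 7] alternates the elements of %v1 and %v2 and is
/// rewritten to a vector.interleave producing a vector<8xi32>.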
3295 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3296 public:
3297  using Base::Base;
3298 
3299  LogicalResult matchAndRewrite(ShuffleOp op,
3300  PatternRewriter &rewriter) const override {
3301  VectorType resultType = op.getResultVectorType();
3302  if (resultType.isScalable())
3303  return rewriter.notifyMatchFailure(
3304  op, "ShuffleOp can't represent a scalable interleave");
3305 
3306  if (resultType.getRank() != 1)
3307  return rewriter.notifyMatchFailure(
3308  op, "ShuffleOp can't represent an n-D interleave");
3309 
3310  VectorType sourceType = op.getV1VectorType();
3311  if (sourceType != op.getV2VectorType() ||
3312  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3313  return rewriter.notifyMatchFailure(
3314  op, "ShuffleOp types don't match an interleave");
3315  }
3316 
3317  ArrayRef<int64_t> shuffleMask = op.getMask();
3318  int64_t resultVectorSize = resultType.getNumElements();
3319  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3320  int64_t maskValueA = shuffleMask[i * 2];
3321  int64_t maskValueB = shuffleMask[(i * 2) + 1];
3322  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3323  return rewriter.notifyMatchFailure(op,
3324  "ShuffleOp mask not interleaving");
3325  }
3326 
3327  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3328  return success();
3329  }
3330 };
3331 
3332 } // namespace
3333 
3334 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3335  MLIRContext *context) {
3336  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3337  context);
3338 }
3339 
3340 //===----------------------------------------------------------------------===//
3341 // InsertOp
3342 //===----------------------------------------------------------------------===//
3343 
3344 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3345  SetIntRangeFn setResultRanges) {
3346  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3347 }
3348 
3349 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3350  Value source, Value dest) {
3351  auto vectorTy = cast<VectorType>(dest.getType());
3352  build(builder, result, source, dest,
3353  SmallVector<int64_t>(vectorTy.getRank(), 0));
3354 }
3355 
3356 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3357  Value source, Value dest, int64_t position) {
3358  build(builder, result, source, dest, ArrayRef<int64_t>{position});
3359 }
3360 
3361 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3362  Value source, Value dest, OpFoldResult position) {
3363  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3364 }
3365 
3366 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3367  Value source, Value dest,
3368  ArrayRef<int64_t> position) {
3369  SmallVector<OpFoldResult> posVals;
3370  posVals.reserve(position.size());
3371  llvm::transform(position, std::back_inserter(posVals),
3372  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3373  build(builder, result, source, dest, posVals);
3374 }
3375 
3376 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3377  Value source, Value dest,
3378  ArrayRef<OpFoldResult> position) {
3379  SmallVector<int64_t> staticPos;
3380  SmallVector<Value> dynamicPos;
3381  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3382  build(builder, result, source, dest, dynamicPos,
3383  builder.getDenseI64ArrayAttr(staticPos));
3384 }
3385 
3386 LogicalResult InsertOp::verify() {
3387  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3388  if (srcTy.getRank() == 0)
3389  return emitError(
3390  "expected a scalar instead of a 0-d vector as the source operand");
3391 
3392  SmallVector<OpFoldResult> position = getMixedPosition();
3393  auto destVectorType = getDestVectorType();
3394  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3395  return emitOpError(
3396  "expected position attribute of rank no greater than dest vector rank");
3397  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3398  if (srcVectorType &&
3399  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3400  static_cast<unsigned>(destVectorType.getRank())))
3401  return emitOpError("expected position attribute rank + source rank to "
3402  "match dest vector rank");
3403  if (!srcVectorType &&
3404  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3405  return emitOpError(
3406  "expected position attribute rank to match the dest vector rank");
3407  for (auto [idx, pos] : llvm::enumerate(position)) {
3408  if (auto attr = dyn_cast<Attribute>(pos)) {
3409  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3410  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3411  destVectorType.getDimSize(idx))) {
3412  return emitOpError("expected position attribute #")
3413  << (idx + 1)
3414  << " to be a non-negative integer smaller than the "
3415  "corresponding "
3416  "dest vector dimension";
3417  }
3418  }
3419  }
3420  return success();
3421 }
3422 
3423 // Calculate the linearized position of the contiguous chunk of elements to
3424 // insert, based on the shape of the value to insert and the positions to
3425 // insert at.
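// For example (an illustrative sketch): with destTy = vector<2x3x4xf32> and
// positions = [1, 2], the completed position is [1, 2, 0], the row-major
// strides are [12, 4, 1], and the linearized position is 1*12 + 2*4 + 0 = 20.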
3426 static int64_t calculateInsertPosition(VectorType destTy,
3427  ArrayRef<int64_t> positions) {
3428  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3429  assert(positions.size() <= completePositions.size() &&
3430  "positions size must be less than or equal to destTy rank");
3431  copy(positions, completePositions.begin());
3432  return linearize(completePositions, computeStrides(destTy.getShape()));
3433 }
3434 
3435 namespace {
3436 
3437 // If insertOp is only inserting unit dimensions it can be transformed to a
3438 // broadcast.
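// For example (an illustrative sketch; value names are hypothetical):
//   vector.insert %v, %dest [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// is rewritten to
//   vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>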
3439 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3440 public:
3441  using Base::Base;
3442 
3443  LogicalResult matchAndRewrite(InsertOp insertOp,
3444  PatternRewriter &rewriter) const override {
3445  auto srcVecType =
3446  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3447  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3448  srcVecType.getNumElements())
3449  return failure();
3450  rewriter.replaceOpWithNewOp<BroadcastOp>(
3451  insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3452  return success();
3453  }
3454 };
3455 
3456 /// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
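/// For example (an illustrative sketch; value names are hypothetical):
///   %s = vector.broadcast %x : f32 to vector<4xf32>
///   %d = vector.broadcast %x : f32 to vector<2x4xf32>
///   %r = vector.insert %s, %d [0] : vector<4xf32> into vector<2x4xf32>
/// is rewritten to
///   %r = vector.broadcast %x : f32 to vector<2x4xf32>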
3457 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3458 public:
3459  using Base::Base;
3460 
3461  LogicalResult matchAndRewrite(InsertOp op,
3462  PatternRewriter &rewriter) const override {
3463 
3464  Value splat = getScalarSplatSource(op.getValueToStore());
3465  if (!splat || getScalarSplatSource(op.getDest()) != splat)
3466  return failure();
3467 
3468  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3469  return success();
3470  }
3471 };
3472 
3473 /// Pattern to optimize a chain of insertions.
3474 ///
3475 /// This pattern identifies chains of vector.insert operations that:
3476 /// 1. Only insert values at static positions.
3477 /// 2. Completely initialize all elements in the resulting vector.
3478 /// 3. All intermediate insert operations have only one use.
3479 ///
3480 /// When these conditions are met, the entire chain can be replaced with a
3481 /// single vector.from_elements operation.
3482 ///
3483 /// To keep this pattern simple, and avoid spending too much time on matching
3484 /// fragmented insert chains, this pattern only considers the last insert op in
3485 /// the chain.
3486 ///
3487 /// Example transformation:
3488 /// %poison = ub.poison : vector<2xi32>
3489 /// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3490 /// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3491 /// ->
3492 /// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3493 class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3494 public:
3495  using Base::Base;
3496  LogicalResult matchAndRewrite(InsertOp op,
3497  PatternRewriter &rewriter) const override {
3498 
3499  VectorType destTy = op.getDestVectorType();
3500  if (destTy.isScalable())
3501  return failure();
3502  // Ensure this is the trailing vector.insert op in a chain of inserts.
3503  for (Operation *user : op.getResult().getUsers())
3504  if (auto insertOp = dyn_cast<InsertOp>(user))
3505  if (insertOp.getDest() == op.getResult())
3506  return failure();
3507 
3508  InsertOp currentOp = op;
3509  SmallVector<InsertOp> chainInsertOps;
3510  while (currentOp) {
3511  // Check cond 1: Dynamic position is not supported.
3512  if (currentOp.hasDynamicPosition())
3513  return failure();
3514 
3515  chainInsertOps.push_back(currentOp);
3516  currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3517  // Check cond 3: Intermediate inserts have only one use to avoid an
3518  // explosion of vectors.
3519  if (currentOp && !currentOp->hasOneUse())
3520  return failure();
3521  }
3522 
3523  int64_t vectorSize = destTy.getNumElements();
3524  int64_t initializedCount = 0;
3525  SmallVector<bool> initializedDestIdxs(vectorSize, false);
3526  SmallVector<int64_t> pendingInsertPos;
3527  SmallVector<int64_t> pendingInsertSize;
3528  SmallVector<Value> pendingInsertValues;
3529 
3530  for (auto insertOp : chainInsertOps) {
3531  // This pattern can do nothing with poison index.
3532  if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3533  return failure();
3534 
3535  // Calculate the linearized position for inserting elements.
3536  int64_t insertBeginPosition =
3537  calculateInsertPosition(destTy, insertOp.getStaticPosition());
3538 
3539  // The valueToStore operand may be a vector or a scalar. Need to handle
3540  // both cases.
3541  int64_t insertSize = 1;
3542  if (auto srcVectorType =
3543  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3544  insertSize = srcVectorType.getNumElements();
3545 
3546  assert(insertBeginPosition + insertSize <= vectorSize &&
3547  "insert would overflow the vector");
3548 
3549  for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3550  insertBeginPosition + insertSize)) {
3551  if (initializedDestIdxs[index])
3552  continue;
3553  initializedDestIdxs[index] = true;
3554  ++initializedCount;
3555  }
3556 
3557  // Defer the creation of ops until we are sure the pattern will
3558  // succeed.
3559  pendingInsertPos.push_back(insertBeginPosition);
3560  pendingInsertSize.push_back(insertSize);
3561  pendingInsertValues.push_back(insertOp.getValueToStore());
3562 
3563  if (initializedCount == vectorSize)
3564  break;
3565  }
3566 
3567  // Check cond 2: all positions must be initialized.
3568  if (initializedCount != vectorSize)
3569  return failure();
3570 
3571  SmallVector<Value> elements(vectorSize);
3572  for (auto [insertBeginPosition, insertSize, valueToStore] :
3573  llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3574  pendingInsertValues))) {
3575  auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3576 
3577  if (!srcVectorType) {
3578  elements[insertBeginPosition] = valueToStore;
3579  continue;
3580  }
3581 
3582  SmallVector<Type> elementToInsertTypes(insertSize,
3583  srcVectorType.getElementType());
3584  // Get all elements from the vector in row-major order.
3585  auto elementsToInsert = vector::ToElementsOp::create(
3586  rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3587  for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3588  elements[insertBeginPosition + linearIdx] =
3589  elementsToInsert.getResult(linearIdx);
3590  }
3591  }
3592 
3593  rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3594  return success();
3595  }
3596 };
3597 
3598 } // namespace
3599 
3600 static Attribute
3601 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3602  Attribute dstAttr,
3603  int64_t maxVectorSizeFoldThreshold) {
3604  if (insertOp.hasDynamicPosition())
3605  return {};
3606 
3607  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3608  if (!denseDst)
3609  return {};
3610 
3611  if (!srcAttr) {
3612  return {};
3613  }
3614 
3615  VectorType destTy = insertOp.getDestVectorType();
3616  if (destTy.isScalable())
3617  return {};
3618 
3619  // Make sure we do not create too many large constants.
3620  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3621  !insertOp->hasOneUse())
3622  return {};
3623 
3624  // Calculate the linearized position for inserting elements.
3625  int64_t insertBeginPosition =
3626  calculateInsertPosition(destTy, insertOp.getStaticPosition());
3627  SmallVector<Attribute> insertedValues;
3628  Type destEltType = destTy.getElementType();
3629 
3630  // Convert each inserted attribute to the expected element type if there
3631  // is a mismatch.
3632  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3633  for (auto value : denseSource.getValues<Attribute>())
3634  insertedValues.push_back(convertNumericAttr(value, destEltType));
3635  } else {
3636  insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3637  }
3638 
3639  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3640  copy(insertedValues, allValues.begin() + insertBeginPosition);
3641  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3642 
3643  return newAttr;
3644 }
3645 
3646 /// Folder to replace the `dest` operand of the insert op with the root dest of
3647 /// the insert op use chain.
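/// For example (an illustrative sketch; value names are hypothetical):
///   %0 = vector.insert %a, %dest [1] : f32 into vector<4xf32>
///   %1 = vector.insert %b, %0 [1] : f32 into vector<4xf32>
/// The second insert overwrites the same position, so its `dest` operand can
/// be replaced with %dest directly.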
3648 static Value foldInsertUseChain(InsertOp insertOp) {
3649  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3650  if (!destInsert)
3651  return {};
3652 
3653  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3654  return {};
3655 
3656  insertOp.setOperand(1, destInsert.getDest());
3657  return insertOp.getResult();
3658 }
3659 
3660 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3661  MLIRContext *context) {
3662  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3663  InsertChainFullyInitialized>(context);
3664 }
3665 
3666 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3667  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3668  // unless the source vector constant has a single use.
3669  constexpr int64_t vectorSizeFoldThreshold = 256;
3670  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3671  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3672  // (type mismatch).
3673  if (getNumIndices() == 0 && getValueToStoreType() == getType())
3674  return getValueToStore();
3675  // Fold `arith.constant` indices into the `vector.insert` operation.
3676  // Do not stop here as this fold may enable subsequent folds that require
3677  // constant indices.
3678  SmallVector<Value> operands = {getValueToStore(), getDest()};
3679  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3680 
3681  if (auto res = foldInsertUseChain(*this))
3682  return res;
3683  if (auto res = foldPoisonIndexInsertExtractOp(
3684  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3685  return res;
3686  if (auto res = foldDenseElementsAttrDestInsertOp(
3687  *this, adaptor.getValueToStore(), adaptor.getDest(),
3688  vectorSizeFoldThreshold)) {
3689  return res;
3690  }
3691 
3692  return inplaceFolded;
3693 }
3694 
3695 //===----------------------------------------------------------------------===//
3696 // InsertStridedSliceOp
3697 //===----------------------------------------------------------------------===//
3698 
3699 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3700  Value source, Value dest,
3701  ArrayRef<int64_t> offsets,
3702  ArrayRef<int64_t> strides) {
3703  result.addOperands({source, dest});
3704  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3705  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3706  result.addTypes(dest.getType());
3707  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3708  offsetsAttr);
3709  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3710  stridesAttr);
3711 }
3712 
3713 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3714 template <typename OpType>
3715 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3716  ArrayAttr arrayAttr,
3717  ArrayRef<int64_t> shape,
3718  StringRef attrName) {
3719  if (arrayAttr.size() > shape.size())
3720  return op.emitOpError("expected ")
3721  << attrName << " attribute of rank no greater than vector rank";
3722  return success();
3723 }
3724 
3725 // Returns success if all integers in `arrayAttr` are in the admissible
3726 // interval. If `halfOpen` is true the admissible interval is [min, max);
3727 // otherwise, the admissible interval is [min, max].
3728 template <typename OpType>
3729 static LogicalResult
3730 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3731  int64_t max, StringRef attrName,
3732  bool halfOpen = true) {
3733  for (auto attr : arrayAttr) {
3734  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3735  auto upper = max;
3736  if (!halfOpen)
3737  upper += 1;
3738  if (val < min || val >= upper)
3739  return op.emitOpError("expected ") << attrName << " to be confined to ["
3740  << min << ", " << upper << ")";
3741  }
3742  return success();
3743 }
3744 
3745 // Returns success if, for each index i, `arrayAttr[i]` is in the admissible
3746 // interval with max = shape[i]. If `halfOpen` is true the interval is
3747 // [min, shape[i]); otherwise it is [min, shape[i]].
3748 template <typename OpType>
3749 static LogicalResult
3750 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3751  ArrayRef<int64_t> shape, StringRef attrName,
3752  bool halfOpen = true, int64_t min = 0) {
3753  for (auto [index, attrDimPair] :
3754  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3755  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3756  int64_t max = std::get<1>(attrDimPair);
3757  if (!halfOpen)
3758  max += 1;
3759  if (val < min || val >= max)
3760  return op.emitOpError("expected ")
3761  << attrName << " dimension " << index << " to be confined to ["
3762  << min << ", " << max << ")";
3763  }
3764  return success();
3765 }
3766 
3767 // Returns success if, for all indices i = 0..shape.size()-1, the value
3768 // val = `arrayAttr1[i]` + `arrayAttr2[i]` is in the admissible interval with
3769 // max = shape[i].
3770 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
3771 // the admissible interval is [min, max].
3772 template <typename OpType>
3773 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3774  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3775  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3776  bool halfOpen = true, int64_t min = 1) {
3777  assert(arrayAttr1.size() <= shape.size());
3778  assert(arrayAttr2.size() <= shape.size());
3779  for (auto [index, it] :
3780  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3781  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3782  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3783  int64_t max = std::get<2>(it);
3784  if (!halfOpen)
3785  max += 1;
3786  if (val1 + val2 < 0 || val1 + val2 >= max)
3787  return op.emitOpError("expected sum(")
3788  << attrName1 << ", " << attrName2 << ") dimension " << index
3789  << " to be confined to [" << min << ", " << max << ")";
3790  }
3791  return success();
3792 }
3793 
3794 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3795  MLIRContext *context) {
3796  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3797  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3798  });
3799  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3800 }
3801 
3802 LogicalResult InsertStridedSliceOp::verify() {
3803  auto sourceVectorType = getSourceVectorType();
3804  auto destVectorType = getDestVectorType();
3805  auto offsets = getOffsetsAttr();
3806  auto strides = getStridesAttr();
3807  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3808  return emitOpError(
3809  "expected offsets of same size as destination vector rank");
3810  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3811  return emitOpError("expected strides of same size as source vector rank");
3812  if (sourceVectorType.getRank() > destVectorType.getRank())
3813  return emitOpError(
3814  "expected source rank to be no greater than destination rank");
3815 
3816  auto sourceShape = sourceVectorType.getShape();
3817  auto destShape = destVectorType.getShape();
3818  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3819  destShape.size() - sourceShape.size(), 0);
3820  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3821  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3822  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3823  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3824  offName)) ||
3825  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3826  /*max=*/1, stridesName,
3827  /*halfOpen=*/false)) ||
3828  failed(isSumOfIntegerArrayAttrConfinedToShape(
3829  *this, offsets,
3830  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3831  offName, "source vector shape",
3832  /*halfOpen=*/false, /*min=*/1)))
3833  return failure();
3834 
3835  unsigned rankDiff = destShape.size() - sourceShape.size();
3836  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3837  if (sourceVectorType.getScalableDims()[idx] !=
3838  destVectorType.getScalableDims()[idx + rankDiff]) {
3839  return emitOpError("mismatching scalable flags (at source vector idx=")
3840  << idx << ")";
3841  }
3842  if (sourceVectorType.getScalableDims()[idx]) {
3843  auto sourceSize = sourceShape[idx];
3844  auto destSize = destShape[idx + rankDiff];
3845  if (sourceSize != destSize) {
3846  return emitOpError("expected size at idx=")
3847  << idx
3848  << (" to match the corresponding base size from the input "
3849  "vector (")
3850  << sourceSize << (" vs ") << destSize << (")");
3851  }
3852  }
3853  }
3854 
3855  return success();
3856 }
3857 
3858 namespace {
3859 /// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
3860 class FoldInsertStridedSliceSplat final
3861  : public OpRewritePattern<InsertStridedSliceOp> {
3862 public:
3863  using Base::Base;
3864 
3865  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3866  PatternRewriter &rewriter) const override {
3867 
3868  auto dst = insertStridedSliceOp.getDest();
3869  auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3870  if (!splat || getScalarSplatSource(dst) != splat)
3871  return failure();
3872 
3873  rewriter.replaceOp(insertStridedSliceOp, dst);
3874  return success();
3875  }
3876 };
3877 
3878 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3879 /// to dst.
3880 class FoldInsertStridedSliceOfExtract final
3881  : public OpRewritePattern<InsertStridedSliceOp> {
3882 public:
3883  using Base::Base;
3884 
3885  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3886  PatternRewriter &rewriter) const override {
3887  auto extractStridedSliceOp =
3888  insertStridedSliceOp.getValueToStore()
3889  .getDefiningOp<vector::ExtractStridedSliceOp>();
3890 
3891  if (!extractStridedSliceOp)
3892  return failure();
3893 
3894  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3895  return failure();
3896 
3897  // Check if have the same strides and offsets.
3898  if (extractStridedSliceOp.getStrides() !=
3899  insertStridedSliceOp.getStrides() ||
3900  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3901  return failure();
3902 
3903  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3904  return success();
3905  }
3906 };
3907 
3908 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3909 // ConstantOp.
3910 class InsertStridedSliceConstantFolder final
3911  : public OpRewritePattern<InsertStridedSliceOp> {
3912 public:
3913  using Base::Base;
3914 
3915  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3916  // unless the source vector constant has a single use.
3917  static constexpr int64_t vectorSizeFoldThreshold = 256;
3918 
3919  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3920  PatternRewriter &rewriter) const override {
3921  // Return if the destination operand is not defined by a compatible vector
3922  // ConstantOp.
3923  TypedValue<VectorType> destVector = op.getDest();
3924  Attribute vectorDestCst;
3925  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3926  return failure();
3927 
3928  VectorType destTy = destVector.getType();
3929  if (destTy.isScalable())
3930  return failure();
3931 
3932  // Make sure we do not create too many large constants.
3933  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3934  !destVector.hasOneUse())
3935  return failure();
3936 
3937  TypedValue<VectorType> sourceValue = op.getValueToStore();
3938  Attribute sourceCst;
3939  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3940  return failure();
3941 
3942  // TODO: Support poison.
3943  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3944  return failure();
3945 
3946  // TODO: Handle non-unit strides when they become available.
3947  if (op.hasNonUnitStrides())
3948  return failure();
3949 
3950  VectorType sliceVecTy = sourceValue.getType();
3951  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3952  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3953  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3954  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3955 
3956  // Calculate the destination element indices by enumerating all slice
3957  // positions within the destination and linearizing them. The enumeration
3958  // order is lexicographic, which yields a sequence of monotonically
3959  // increasing linearized position indices.
3960  // Because the destination may have higher dimensionality than the slice,
3961  // we keep track of two overlapping sets of positions and offsets.
3962  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3963  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3964  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3965  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3966  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3967  MutableArrayRef<int64_t> currSlicePosition(
3968  currDestPosition.begin() + rankDifference, currDestPosition.end());
3969  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3970  offsets.end());
3971  do {
3972  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3973  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3974  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3975  "Invalid slice element");
3976  newValues[linearizedPosition] = *sliceValuesIt;
3977  ++sliceValuesIt;
3978  } while (succeeded(
3979  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3980 
3981  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3982  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3983  return success();
3984  }
3985 };
3986 
3987 } // namespace
3988 
3989 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3990  RewritePatternSet &results, MLIRContext *context) {
3991  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3992  InsertStridedSliceConstantFolder>(context);
3993 }
3994 
3995 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3996  if (getSourceVectorType() == getDestVectorType())
3997  return getValueToStore();
3998  return {};
3999 }
4000 
4001 //===----------------------------------------------------------------------===//
4002 // OuterProductOp
4003 //===----------------------------------------------------------------------===//
4004 
4005 /// Build an op without mask, use the type of `acc` as the return type.
4006 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4007  Value lhs, Value rhs, Value acc) {
4008  result.addOperands({lhs, rhs, acc});
4009  result.addTypes(acc.getType());
4010 }
4011 
4012 void OuterProductOp::print(OpAsmPrinter &p) {
4013  p << " " << getLhs() << ", " << getRhs();
4014  if (getAcc()) {
4015  p << ", " << getAcc();
4016  p.printOptionalAttrDict((*this)->getAttrs());
4017  }
4018  p << " : " << getLhs().getType() << ", " << getRhs().getType();
4019 }
4020 
4021 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4022  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4023  Type tLHS, tRHS;
4024  if (parser.parseOperandList(operandsInfo) ||
4025  parser.parseOptionalAttrDict(result.attributes) ||
4026  parser.parseColonType(tLHS) || parser.parseComma() ||
4027  parser.parseType(tRHS))
4028  return failure();
4029  if (operandsInfo.size() < 2)
4030  return parser.emitError(parser.getNameLoc(),
4031  "expected at least 2 operands");
4032  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4033  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4034  if (!vLHS)
4035  return parser.emitError(parser.getNameLoc(),
4036  "expected vector type for operand #1");
4037 
4038  VectorType resType;
4039  if (vRHS) {
4040  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4041  vRHS.getScalableDims()[0]};
4042  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4043  vLHS.getElementType(), scalableDimsRes);
4044  } else {
4045  // Scalar RHS operand
4046  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4047  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4048  scalableDimsRes);
4049  }
4050 
4051  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4052  result.attributes.append(
4053  OuterProductOp::getKindAttrName(result.name),
4054  CombiningKindAttr::get(result.getContext(),
4055  OuterProductOp::getDefaultKind()));
4056  }
4057 
4058  return failure(
4059  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4060  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4061  (operandsInfo.size() > 2 &&
4062  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4063  parser.addTypeToList(resType, result.types));
4064 }
4065 
4066 LogicalResult OuterProductOp::verify() {
4067  Type tRHS = getOperandTypeRHS();
4068  VectorType vLHS = getOperandVectorTypeLHS(),
4069  vRHS = llvm::dyn_cast<VectorType>(tRHS),
4070  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4071 
4072  if (vLHS.getRank() != 1)
4073  return emitOpError("expected 1-d vector for operand #1");
4074 
4075  if (vRHS) {
4076  // Proper OUTER operation.
4077  if (vRHS.getRank() != 1)
4078  return emitOpError("expected 1-d vector for operand #2");
4079  if (vRES.getRank() != 2)
4080  return emitOpError("expected 2-d vector result");
4081  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4082  return emitOpError("expected #1 operand dim to match result dim #1");
4083  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4084  return emitOpError("expected #2 operand dim to match result dim #2");
4085  if (vLHS.isScalable() && !vRHS.isScalable()) {
4086  // This restriction reflects what's currently supported in terms of
4087  // scalable vectors. However, we could relax this if there's a use case.
4088  return emitOpError(
4089  "expected either both or only #2 operand dim to be scalable");
4090  }
4091  } else {
4092  // An AXPY operation.
4093  if (vRES.getRank() != 1)
4094  return emitOpError("expected 1-d vector result");
4095  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4096  return emitOpError("expected #1 operand dim to match result dim #1");
4097  }
4098 
4099  if (vACC && vACC != vRES)
4100  return emitOpError("expected operand #3 of same type as result type");
4101 
4102  // Verify supported combining kind.
4103  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4104  return emitOpError("unsupported outerproduct type");
4105 
4106  return success();
4107 }
4108 
4109 // MaskableOpInterface methods.
4110 
4111 /// Returns the mask type expected by this operation. Mostly used for
4112 /// verification purposes. It requires the operation to be vectorized.
4113 Type OuterProductOp::getExpectedMaskType() {
4114  auto vecType = this->getResultVectorType();
4115  return VectorType::get(vecType.getShape(),
4116  IntegerType::get(vecType.getContext(), /*width=*/1),
4117  vecType.getScalableDims());
4118 }
4119 
4120 //===----------------------------------------------------------------------===//
4121 // ExtractStridedSliceOp
4122 //===----------------------------------------------------------------------===//
4123 
4124 // Inference works as follows:
4125 // 1. For the leading dims covered by 'offsets', take the corresponding 'sizes'.
4126 // 2. For the remaining dims, take the sizes from 'vectorType'.
4127 // Scalable flags are inherited from 'vectorType'.
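// For example (illustrative): vectorType = vector<4x8x16xf32> with
// sizes = [2, 4] (and matching offsets/strides of rank 2) infers the result
// type vector<2x4x16xf32>.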
4128 static Type inferStridedSliceOpResultType(VectorType vectorType,
4129  ArrayAttr offsets, ArrayAttr sizes,
4130  ArrayAttr strides) {
4131  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4132  SmallVector<int64_t, 4> shape;
4133  shape.reserve(vectorType.getRank());
4134  unsigned idx = 0;
4135  for (unsigned e = offsets.size(); idx < e; ++idx)
4136  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4137  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4138  shape.push_back(vectorType.getShape()[idx]);
4139 
4140  return VectorType::get(shape, vectorType.getElementType(),
4141  vectorType.getScalableDims());
4142 }
4143 
4144 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4145  Value source, ArrayRef<int64_t> offsets,
4146  ArrayRef<int64_t> sizes,
4147  ArrayRef<int64_t> strides) {
4148  result.addOperands(source);
4149  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4150  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4151  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4152  result.addTypes(
4153  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4154  offsetsAttr, sizesAttr, stridesAttr));
4155  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4156  offsetsAttr);
4157  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4158  sizesAttr);
4159  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4160  stridesAttr);
4161 }
4162 
4163 LogicalResult ExtractStridedSliceOp::verify() {
4164  auto type = getSourceVectorType();
4165  auto offsets = getOffsetsAttr();
4166  auto sizes = getSizesAttr();
4167  auto strides = getStridesAttr();
4168  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4169  return emitOpError(
4170  "expected offsets, sizes and strides attributes of same size");
4171 
4172  auto shape = type.getShape();
4173  auto offName = getOffsetsAttrName();
4174  auto sizesName = getSizesAttrName();
4175  auto stridesName = getStridesAttrName();
4176  if (failed(
4177  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4178  failed(
4179  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4180  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4181  stridesName)) ||
4182  failed(
4183  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4184  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4185  /*halfOpen=*/false,
4186  /*min=*/1)) ||
4187  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4188  /*max=*/1, stridesName,
4189  /*halfOpen=*/false)) ||
4190  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4191  shape, offName, sizesName,
4192  /*halfOpen=*/false)))
4193  return failure();
4194 
4195  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4196  offsets, sizes, strides);
4197  if (getResult().getType() != resultType)
4198  return emitOpError("expected result type to be ") << resultType;
4199 
4200  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4201  if (type.getScalableDims()[idx]) {
4202  auto inputDim = type.getShape()[idx];
4203  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4204  if (inputDim != inputSize)
4205  return emitOpError("expected size at idx=")
4206  << idx
4207  << (" to match the corresponding base size from the input "
4208  "vector (")
4209  << inputSize << (" vs ") << inputDim << (")");
4210  }
4211  }
4212 
4213  return success();
4214 }
4215 
4216 // When the source of the ExtractStridedSliceOp comes from a chain of
4217 // InsertStridedSliceOps, try to use the source of one of those inserts if we
4218 // can detect that the extracted vector is a subset of an inserted vector.
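// For example (an illustrative sketch; value names are hypothetical):
//   %0 = vector.insert_strided_slice %src, %dst
//          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [2], sizes = [4], strides = [1]}
//          : vector<8xf32> to vector<4xf32>
// Here the extracted chunk is exactly the inserted one, so %1 is rewritten to
// read directly from %src with its offsets rebased to [0].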
4219 static LogicalResult
4220 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4221  // Helper to extract integer out of ArrayAttr.
4222  auto getElement = [](ArrayAttr array, int idx) {
4223  return llvm::cast<IntegerAttr>(array[idx]).getInt();
4224  };
4225  ArrayAttr extractOffsets = op.getOffsets();
4226  ArrayAttr extractStrides = op.getStrides();
4227  ArrayAttr extractSizes = op.getSizes();
4228  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4229  while (insertOp) {
4230  if (op.getSourceVectorType().getRank() !=
4231  insertOp.getSourceVectorType().getRank())
4232  return failure();
4233  ArrayAttr insertOffsets = insertOp.getOffsets();
4234  ArrayAttr insertStrides = insertOp.getStrides();
4235  // If the rank of extract is greater than the rank of insert, we are likely
4236  // extracting a partial chunk of the vector inserted.
4237  if (extractOffsets.size() > insertOffsets.size())
4238  return failure();
4239  bool partialOverlap = false;
4240  bool disjoint = false;
4241  SmallVector<int64_t, 4> offsetDiffs;
4242  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4243  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4244  return failure();
4245  int64_t start = getElement(insertOffsets, dim);
4246  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4247  int64_t offset = getElement(extractOffsets, dim);
4248  int64_t size = getElement(extractSizes, dim);
4249  // Check if the start of the extract offset is in the interval inserted.
4250  if (start <= offset && offset < end) {
4251  // If the extract interval overlaps but is not fully included we may
4252  // have a partial overlap that will prevent any folding.
4253  if (offset + size > end)
4254  partialOverlap = true;
4255  offsetDiffs.push_back(offset - start);
4256  continue;
4257  }
4258  disjoint = true;
4259  break;
4260  }
4261  // The extract element chunk is a subset of the insert element.
4262  if (!disjoint && !partialOverlap) {
4263  op.setOperand(insertOp.getValueToStore());
4264  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4265  OpBuilder b(op.getContext());
4266  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4267  return success();
4268  }
4269  // If the chunk extracted is disjoint from the chunk inserted, keep looking
4270  // in the insert chain.
4271  if (disjoint)
4272  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4273  else {
4274  // The extracted vector partially overlaps the inserted vector; we cannot
4275  // fold.
4276  return failure();
4277  }
4278  }
4279  return failure();
4280 }
4281 
4282 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4283 static OpFoldResult
4284 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4285  Attribute foldInput) {
4286 
4287  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4288  if (!dense)
4289  return {};
4290 
4291  // TODO: Handle non-unit strides when they become available.
4292  if (op.hasNonUnitStrides())
4293  return {};
4294 
4295  VectorType sourceVecTy = op.getSourceVectorType();
4296  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4297  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4298 
4299  VectorType sliceVecTy = op.getType();
4300  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4301  int64_t rank = sliceVecTy.getRank();
4302 
4303  // Expand offsets and sizes to match the vector rank.
4304  SmallVector<int64_t, 4> offsets(rank, 0);
4305  copy(getI64SubArray(op.getOffsets()), offsets.begin());
4306 
4307  SmallVector<int64_t, 4> sizes(sourceShape);
4308  copy(getI64SubArray(op.getSizes()), sizes.begin());
4309 
4310  // Calculate the slice elements by enumerating all slice positions and
4311  // linearizing them. The enumeration order is lexicographic which yields a
4312  // sequence of monotonically increasing linearized position indices.
4313  const auto denseValuesBegin = dense.value_begin<Attribute>();
4314  SmallVector<Attribute> sliceValues;
4315  sliceValues.reserve(sliceVecTy.getNumElements());
4316  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4317  do {
4318  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4319  assert(linearizedPosition < sourceVecTy.getNumElements() &&
4320  "Invalid index");
4321  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4322  } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4323 
4324  assert(static_cast<int64_t>(sliceValues.size()) ==
4325  sliceVecTy.getNumElements() &&
4326  "Invalid number of slice elements");
4327  return DenseElementsAttr::get(sliceVecTy, sliceValues);
4328 }
4329 
4330 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4331  if (getSourceVectorType() == getResult().getType())
4332  return getSource();
4333  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4334  return getResult();
4335 
4336  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4337  if (auto splat =
4338  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4339  return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4340 
4341  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4342  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4343 }
4344 
4345 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4346  populateFromInt64AttrArray(getOffsets(), results);
4347 }
4348 
4349 namespace {
4350 
4351 // Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4352 // CreateMaskOp.
4353 //
4354 // Example:
4355 //
4356 //   %mask = vector.create_mask %ub : vector<16xi1>
4357 //   %slice = vector.extract_strided_slice %mask
4358 //       {offsets = [8], sizes = [8], strides = [1]} : vector<16xi1> to vector<8xi1>
4359 // to
4360 //   %c8 = arith.constant 8 : index
4361 //   %new_ub = arith.subi %ub, %c8
4362 //   %mask = vector.create_mask %new_ub : vector<8xi1>
4363 class StridedSliceCreateMaskFolder final
4364  : public OpRewritePattern<ExtractStridedSliceOp> {
4365  using Base::Base;
4366 
4367 public:
4368  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4369  PatternRewriter &rewriter) const override {
4370  Location loc = extractStridedSliceOp.getLoc();
4371  // Return if 'extractStridedSliceOp' operand is not defined by a
4372  // CreateMaskOp.
4373  auto createMaskOp =
4374  extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4375  if (!createMaskOp)
4376  return failure();
4377  // Return if 'extractStridedSliceOp' has non-unit strides.
4378  if (extractStridedSliceOp.hasNonUnitStrides())
4379  return failure();
4380  // Gather constant mask dimension sizes.
4381  SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4382  // Gather strided slice offsets and sizes.
4383  SmallVector<int64_t> sliceOffsets;
4384  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4385  sliceOffsets);
4386  SmallVector<int64_t> sliceSizes;
4387  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4388 
4389  // Compute slice of vector mask region.
4390  SmallVector<Value> sliceMaskDimSizes;
4391  sliceMaskDimSizes.reserve(maskDimSizes.size());
4392  // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4393  // only iterate on the leading dim sizes. The tail accounts for the
4394  // remaining dim sizes.
4395  for (auto [maskDimSize, sliceOffset, sliceSize] :
4396  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4397  // No need to clamp on min/max values, because create_mask has clamping
4398  // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4399  // greater than the vector dim size.
4400  IntegerAttr offsetAttr =
4401  rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4402  Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4403  Value sliceMaskDimSize =
4404  arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4405  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4406  }
4407  // Add unchanged dimensions.
4408  llvm::append_range(
4409  sliceMaskDimSizes,
4410  llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4411  // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4412  // region.
4413  rewriter.replaceOpWithNewOp<CreateMaskOp>(
4414  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4415  sliceMaskDimSizes);
4416  return success();
4417  }
4418 };
4419 
4420 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4421 // ConstantMaskOp.
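// For example (illustrative):
//
//   %mask = vector.constant_mask [8] : vector<16xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [4], sizes = [8], strides = [1]}
//       : vector<16xi1> to vector<8xi1>
//
// is rewritten to
//
//   %slice = vector.constant_mask [4] : vector<8xi1>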
4422 class StridedSliceConstantMaskFolder final
4423  : public OpRewritePattern<ExtractStridedSliceOp> {
4424 public:
4425  using Base::Base;
4426 
4427  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4428  PatternRewriter &rewriter) const override {
4429  // Return if 'extractStridedSliceOp' operand is not defined by a
4430  // ConstantMaskOp.
4431  auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4432  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4433  if (!constantMaskOp)
4434  return failure();
4435  // Return if 'extractStridedSliceOp' has non-unit strides.
4436  if (extractStridedSliceOp.hasNonUnitStrides())
4437  return failure();
4438  // Gather constant mask dimension sizes.
4439  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4440  // Gather strided slice offsets and sizes.
4441  SmallVector<int64_t> sliceOffsets;
4442  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4443  sliceOffsets);
4444  SmallVector<int64_t> sliceSizes;
4445  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4446 
4447  // Compute slice of vector mask region.
4448  SmallVector<int64_t> sliceMaskDimSizes;
4449  sliceMaskDimSizes.reserve(maskDimSizes.size());
4450  for (auto [maskDimSize, sliceOffset, sliceSize] :
4451  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4452  int64_t sliceMaskDimSize = std::max(
4453  static_cast<int64_t>(0),
4454  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4455  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4456  }
4457  // Add unchanged dimensions.
4458  if (sliceMaskDimSizes.size() < maskDimSizes.size())
4459  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4460  sliceMaskDimSizes.push_back(maskDimSizes[i]);
4461  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4462  // region is a conjunction of mask dim intervals).
4463  if (llvm::is_contained(sliceMaskDimSizes, 0))
4464  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4465 
4466  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4467  // region.
4468  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4469  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4470  sliceMaskDimSizes);
4471  return success();
4472  }
4473 };
4474 
4475 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4476 // BroadcastOp(ExtractStridedSliceOp).
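// For example (illustrative):
//
//   %b = vector.broadcast %v : vector<1x8xf32> to vector<4x8xf32>
//   %slice = vector.extract_strided_slice %b
//       {offsets = [0, 2], sizes = [4, 4], strides = [1, 1]}
//       : vector<4x8xf32> to vector<4x4xf32>
//
// is rewritten to
//
//   %s = vector.extract_strided_slice %v
//       {offsets = [0, 2], sizes = [1, 4], strides = [1, 1]}
//       : vector<1x8xf32> to vector<1x4xf32>
//   %slice = vector.broadcast %s : vector<1x4xf32> to vector<4x4xf32>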
4477 class StridedSliceBroadcast final
4478  : public OpRewritePattern<ExtractStridedSliceOp> {
4479 public:
4480  using Base::Base;
4481 
4482  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4483  PatternRewriter &rewriter) const override {
4484  auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4485  if (!broadcast)
4486  return failure();
4487  auto srcVecType =
4488  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4489  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4490  auto dstVecType = llvm::cast<VectorType>(op.getType());
4491  unsigned dstRank = dstVecType.getRank();
4492  unsigned rankDiff = dstRank - srcRank;
4493  // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4494  // (n -> m with n > m). If they are originally both broadcasted *and*
4495  // sliced, this can be simplified to just broadcasting.
4496  bool needsSlice = false;
4497  for (unsigned i = 0; i < srcRank; i++) {
4498  if (srcVecType.getDimSize(i) != 1 &&
4499  srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4500  needsSlice = true;
4501  break;
4502  }
4503  }
4504  Value source = broadcast.getSource();
4505  if (needsSlice) {
4506  SmallVector<int64_t> offsets =
4507  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4508  SmallVector<int64_t> sizes =
4509  getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4510  for (unsigned i = 0; i < srcRank; i++) {
4511  if (srcVecType.getDimSize(i) == 1) {
4512  // In case this dimension was broadcasted *and* sliced, the offset
4513  // and size need to be updated now that there is no broadcast before
4514  // the slice.
4515  offsets[i] = 0;
4516  sizes[i] = 1;
4517  }
4518  }
4519  source = ExtractStridedSliceOp::create(
4520  rewriter, op->getLoc(), source, offsets, sizes,
4521  getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4522  }
4523  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4524  return success();
4525  }
4526 };
4527 
4528 /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
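/// For example (illustrative):
///
///   %splat = vector.broadcast %scalar : f32 to vector<8xf32>
///   %slice = vector.extract_strided_slice %splat
///       {offsets = [2], sizes = [4], strides = [1]}
///       : vector<8xf32> to vector<4xf32>
///
/// becomes
///
///   %slice = vector.broadcast %scalar : f32 to vector<4xf32>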
4529 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4530 public:
4531  using Base::Base;
4532 
4533  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4534  PatternRewriter &rewriter) const override {
4535 
4536  Value splat = getScalarSplatSource(op.getSource());
4537  if (!splat)
4538  return failure();
4539  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4540  return success();
4541  }
4542 };
4543 
4544 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4545 /// slice is contiguous, into extract and shape_cast.
4546 ///
4547 /// Example:
4548 /// Before:
4549 /// %1 = vector.extract_strided_slice %arg0 {
4550 /// offsets = [0, 0, 0, 0, 0],
4551 /// sizes = [1, 1, 1, 1, 8],
4552 /// strides = [1, 1, 1, 1, 1]
4553 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4554 /// After:
4555 /// %0 = vector.extract %arg0[0, 0, 0, 0]
4556 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
4557 /// %1 = vector.shape_cast %0
4558 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
4559 ///
4560 class ContiguousExtractStridedSliceToExtract final
4561  : public OpRewritePattern<ExtractStridedSliceOp> {
4562 public:
4563  using Base::Base;
4564 
4565  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4566  PatternRewriter &rewriter) const override {
4567  if (op.hasNonUnitStrides())
4568  return failure();
4569  Value source = op.getOperand();
4570  auto sourceType = cast<VectorType>(source.getType());
4571  if (sourceType.isScalable() || sourceType.getRank() == 0)
4572  return failure();
4573 
4574  // Compute the number of offsets to pass to ExtractOp::build. That is the
4575  // difference between the source rank and the desired slice rank. We walk
4576  // the dimensions from innermost out, and stop when the next slice dimension
4577  // is not full-size.
4578  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4579  int numOffsets;
4580  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4581  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4582  break;
4583  }
4584 
4585  // If the created extract op would have no offsets, then this whole
4586  // extract_strided_slice is the identity and should have been handled by
4587  // other canonicalizations.
4588  if (numOffsets == 0)
4589  return failure();
4590 
4591  // If not even the inner-most dimension is full-size, this op can't be
4592  // rewritten as an ExtractOp.
4593  if (numOffsets == sourceType.getRank() &&
4594  static_cast<int>(sizes.size()) == sourceType.getRank())
4595  return failure();
4596 
4597  // The outer dimensions must have unit size.
4598  for (int i = 0; i < numOffsets; ++i) {
4599  if (sizes[i] != 1)
4600  return failure();
4601  }
4602 
4603  // Avoid generating slices that have leading unit dimensions. The shape_cast
4604  // op created below would otherwise be lowered through the inefficient generic
4605  // fallback patterns (ShapeCastOpRewritePattern).
4606  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4607  sizes[numOffsets] == 1) {
4608  ++numOffsets;
4609  }
4610 
4611  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4612  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4613  Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4614  extractOffsets);
4615  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4616  return success();
4617  }
4618 };
4619 
4620 } // namespace
4621 
4622 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4623  RewritePatternSet &results, MLIRContext *context) {
4624  // Patterns that rewrite ExtractStridedSliceOp of CreateMaskOp/ConstantMaskOp,
4625  // BroadcastOp, splat-like values, and contiguous slices of a larger vector.
4626  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4627  StridedSliceBroadcast, StridedSliceSplat,
4628  ContiguousExtractStridedSliceToExtract>(context);
4629 }
4630 
4631 //===----------------------------------------------------------------------===//
4632 // TransferReadOp
4633 //===----------------------------------------------------------------------===//
4634 
4635 /// 1. Builder that sets padding to poison (if unspecified) and an empty mask (variant with attrs).
4636 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4637  VectorType vectorType, Value source,
4638  ValueRange indices, std::optional<Value> padding,
4639  AffineMapAttr permutationMapAttr,
4640  /*optional*/ ArrayAttr inBoundsAttr) {
4641 
4642  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4643  if (!padding)
4644  padding = ub::PoisonOp::create(builder, result.location, elemType);
4645  build(builder, result, vectorType, source, indices, permutationMapAttr,
4646  *padding, /*mask=*/Value(), inBoundsAttr);
4647 }
4648 
4649 /// 2. Builder that sets padding to poison (if unspecified) and an empty mask (variant without attrs).
4650 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4651  VectorType vectorType, Value source,
4652  ValueRange indices, std::optional<Value> padding,
4653  AffineMap permutationMap,
4654  std::optional<ArrayRef<bool>> inBounds) {
4655  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4656  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4657  ? builder.getBoolArrayAttr(inBounds.value())
4658  : builder.getBoolArrayAttr(
4659  SmallVector<bool>(vectorType.getRank(), false));
4660  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4661  if (!padding)
4662  padding = ub::PoisonOp::create(builder, result.location, elemType);
4663  build(builder, result, vectorType, source, indices, *padding,
4664  permutationMapAttr, inBoundsAttr);
4665 }
4666 
4667 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4668 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4669  VectorType vectorType, Value source,
4670  ValueRange indices, std::optional<Value> padding,
4671  std::optional<ArrayRef<bool>> inBounds) {
4672  AffineMap permutationMap = getTransferMinorIdentityMap(
4673  llvm::cast<ShapedType>(source.getType()), vectorType);
4674  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4675  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4676  ? builder.getBoolArrayAttr(inBounds.value())
4677  : builder.getBoolArrayAttr(
4678  SmallVector<bool>(vectorType.getRank(), false));
4679  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4680  if (!padding)
4681  padding = ub::PoisonOp::create(builder, result.location, elemType);
4682  build(builder, result, vectorType, source, indices, permutationMapAttr,
4683  *padding,
4684  /*mask=*/Value(), inBoundsAttr);
4685 }
4686 
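/// Checks that the map is a projected permutation: every result is either a
/// distinct input dim or the constant 0. For example (illustrative):
///   (d0, d1, d2) -> (d2, d0)   // accepted
///   (d0, d1) -> (d1, 0)        // accepted (zero constant allowed)
///   (d0, d1) -> (d0 + d1)      // rejected (result is not a dim or zero)
///   (d0, d1) -> (d0, d0)       // rejected (dim used more than once)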
4687 template <typename EmitFun>
4688 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4689  EmitFun emitOpError) {
4690  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4691  for (auto expr : permutationMap.getResults()) {
4692  auto dim = dyn_cast<AffineDimExpr>(expr);
4693  auto zero = dyn_cast<AffineConstantExpr>(expr);
4694  if (zero) {
4695  if (zero.getValue() != 0) {
4696  return emitOpError(
4697  "requires a projected permutation_map (at most one dim or the zero "
4698  "constant can appear in each result)");
4699  }
4700  continue;
4701  }
4702  if (!dim) {
4703  return emitOpError("requires a projected permutation_map (at most one "
4704  "dim or the zero constant can appear in each result)");
4705  }
4706  if (seen[dim.getPosition()]) {
4707  return emitOpError(
4708  "requires a permutation_map that is a permutation (found one dim "
4709  "used more than once)");
4710  }
4711  seen[dim.getPosition()] = true;
4712  }
4713  return success();
4714 }
4715 
4716 static LogicalResult
4717 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4718  VectorType vectorType, VectorType maskType,
4719  VectorType inferredMaskType, AffineMap permutationMap,
4720  ArrayAttr inBounds) {
4721  if (op->hasAttr("masked")) {
4722  return op->emitOpError("masked attribute has been removed. "
4723  "Use in_bounds instead.");
4724  }
4725 
4726  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4727  return op->emitOpError(
4728  "requires source to be a memref or ranked tensor type");
4729 
4730  auto elementType = shapedType.getElementType();
4731  DataLayout dataLayout = DataLayout::closest(op);
4732  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4733  // Memref or tensor has vector element type.
4734  unsigned sourceVecSize =
4735  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4736  vectorElementType.getShape().back();
4737  unsigned resultVecSize =
4738  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4739  vectorType.getShape().back();
4740  if (resultVecSize % sourceVecSize != 0)
4741  return op->emitOpError(
4742  "requires the bitwidth of the minor 1-D vector to be an integral "
4743  "multiple of the bitwidth of the minor 1-D vector of the source");
4744 
4745  unsigned sourceVecEltRank = vectorElementType.getRank();
4746  unsigned resultVecRank = vectorType.getRank();
4747  if (sourceVecEltRank > resultVecRank)
4748  return op->emitOpError(
4749  "requires source vector element and vector result ranks to match.");
4750  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4751  // Check that permutation map results match 'rankOffset' of vector type.
4752  if (permutationMap.getNumResults() != rankOffset)
4753  return op->emitOpError("requires a permutation_map with result dims of "
4754  "the same rank as the vector type");
4755 
4756  if (maskType)
4757  return op->emitOpError("does not support masks with vector element type");
4758  } else {
4759  // Memref or tensor has scalar element type.
4760  unsigned minorSize =
4761  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4762  unsigned resultVecSize =
4763  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4764  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4765  return op->emitOpError(
4766  "requires the bitwidth of the minor 1-D vector to be an integral "
4767  "multiple of the bitwidth of the source element type");
4768 
4769  // Check that permutation map results match rank of vector type.
4770  if (permutationMap.getNumResults() != vectorType.getRank())
4771  return op->emitOpError("requires a permutation_map with result dims of "
4772  "the same rank as the vector type");
4773  }
4774 
4775  if (permutationMap.getNumSymbols() != 0)
4776  return op->emitOpError("requires permutation_map without symbols");
4777 
4778  if (permutationMap.getNumInputs() != shapedType.getRank())
4779  return op->emitOpError("requires a permutation_map with input dims of the "
4780  "same rank as the source type");
4781 
4782  if (maskType && maskType != inferredMaskType)
4783  return op->emitOpError("inferred mask type (")
4784  << inferredMaskType << ") and mask operand type (" << maskType
4785  << ") don't match";
4786 
4787  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4788  return op->emitOpError("expects the in_bounds attr of same rank "
4789  "as permutation_map results: ")
4790  << AffineMapAttr::get(permutationMap)
4791  << " vs inBounds of size: " << inBounds.size();
4792 
4793  return success();
4794 }
4795 
4796 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4797  SmallVector<StringRef, 3> elidedAttrs;
4798  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4799  if (op.getPermutationMap().isMinorIdentity())
4800  elidedAttrs.push_back(op.getPermutationMapAttrName());
4801  // Elide in_bounds attribute if all dims are out-of-bounds.
4802  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4803  elidedAttrs.push_back(op.getInBoundsAttrName());
4804  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4805 }
4806 
4807 void TransferReadOp::print(OpAsmPrinter &p) {
4808  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4809  if (getMask())
4810  p << ", " << getMask();
4811  printTransferAttrs(p, *this);
4812  p << " : " << getShapedType() << ", " << getVectorType();
4813 }
4814 
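/// Infers the mask type for a transfer op from its vector type and permutation
/// map: the mask shape follows the permuted (memory-side) order of the
/// transferred dims. For example (illustrative):
///   vector type:      vector<2x4xf32>
///   permutation map:  (d0, d1, d2) -> (d2, d1)
///   inferred mask:    vector<4x2xi1>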
4815 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4816  AffineMap permMap) {
4817  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4818  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4819  assert(invPermMap && "Inverse permutation map couldn't be computed");
4820  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4821 
4822  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4823  // 0-D mask into a single-element 1-D mask.
4824  if (maskShape.empty())
4825  maskShape.push_back(1);
4826 
4827  SmallVector<bool> scalableDims =
4828  applyPermutationMap(invPermMap, vecType.getScalableDims());
4829 
4830  return VectorType::get(maskShape, i1Type, scalableDims);
4831 }
4832 
4833 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4834  auto &builder = parser.getBuilder();
4835  SMLoc typesLoc;
4836  OpAsmParser::UnresolvedOperand sourceInfo;
4837  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4838  OpAsmParser::UnresolvedOperand paddingInfo;
4839  SmallVector<Type, 2> types;
4840  OpAsmParser::UnresolvedOperand maskInfo;
4841  // Parsing with support for paddingValue.
4842  if (parser.parseOperand(sourceInfo) ||
4843  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4844  parser.parseComma() || parser.parseOperand(paddingInfo))
4845  return failure();
4846  ParseResult hasMask = parser.parseOptionalComma();
4847  if (hasMask.succeeded()) {
4848  if (parser.parseOperand(maskInfo))
4849  return failure();
4850  }
4851  if (parser.parseOptionalAttrDict(result.attributes) ||
4852  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4853  return failure();
4854  if (types.size() != 2)
4855  return parser.emitError(typesLoc, "requires two types");
4856  auto indexType = builder.getIndexType();
4857  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4858  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4859  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4860  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4861  if (!vectorType)
4862  return parser.emitError(typesLoc, "requires vector type");
4863  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4864  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4865  AffineMap permMap;
4866  if (!permMapAttr) {
4867  if (shapedType.getRank() <
4868  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4869  return parser.emitError(typesLoc,
4870  "expected a custom permutation_map when "
4871  "rank(source) != rank(destination)");
4872  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4873  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4874  } else {
4875  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4876  }
4877  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4878  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4879  if (!inBoundsAttr) {
4880  result.addAttribute(inBoundsAttrName,
4881  builder.getBoolArrayAttr(
4882  SmallVector<bool>(permMap.getNumResults(), false)));
4883  }
4884  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4885  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4886  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4887  result.operands))
4888  return failure();
4889  if (hasMask.succeeded()) {
4890  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4891  return parser.emitError(
4892  maskInfo.location, "does not support masks with vector element type");
4893  if (vectorType.getRank() != permMap.getNumResults()) {
4894  return parser.emitError(typesLoc,
4895  "expected the same rank for the vector and the "
4896  "results of the permutation map");
4897  }
4898  // Instead of adding the mask type as an op type, compute it based on the
4899  // vector type and the permutation map (to keep the type signature small).
4900  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4901  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4902  return failure();
4903  }
4904  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4905  builder.getDenseI32ArrayAttr(
4906  {1, static_cast<int32_t>(indexInfo.size()), 1,
4907  static_cast<int32_t>(hasMask.succeeded())}));
4908  return parser.addTypeToList(vectorType, result.types);
4909 }
4910 
4911 LogicalResult TransferReadOp::verify() {
4912  // Consistency of elemental types in source and vector.
4913  ShapedType shapedType = getShapedType();
4914  VectorType vectorType = getVectorType();
4915  VectorType maskType = getMaskType();
4916  auto paddingType = getPadding().getType();
4917  auto permutationMap = getPermutationMap();
4918  VectorType inferredMaskType =
4919  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4920  : VectorType();
4921  auto sourceElementType = shapedType.getElementType();
4922 
4923  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4924  return emitOpError("requires ") << shapedType.getRank() << " indices";
4925 
4926  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4927  shapedType, vectorType, maskType,
4928  inferredMaskType, permutationMap, getInBounds())))
4929  return failure();
4930 
4931  if (auto sourceVectorElementType =
4932  llvm::dyn_cast<VectorType>(sourceElementType)) {
4933  // Source has vector element type.
4934  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4935  if (sourceVectorElementType != paddingType)
4936  return emitOpError(
4937  "requires source element type and padding type to match.");
4938 
4939  } else {
4940  // Check that 'paddingType' is valid to store in a vector type.
4941  if (!VectorType::isValidElementType(paddingType))
4942  return emitOpError("requires valid padding vector elemental type");
4943 
4944  // Check that padding type and vector element types match.
4945  if (paddingType != sourceElementType)
4946  return emitOpError(
4947  "requires formal padding and source of the same elemental type");
4948  }
4949 
4950  return verifyPermutationMap(permutationMap,
4951  [&](Twine t) { return emitOpError(t); });
4952 }
4953 
4954 // MaskableOpInterface methods.
4955 
4956 /// Returns the mask type expected by this operation. Mostly used for
4957 /// verification purposes. It requires the operation to be vectorized.
4958 Type TransferReadOp::getExpectedMaskType() {
4959  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4960 }
4961 
4962 //===----------------------------------------------------------------------===//
4963 // TransferReadOp: VectorTransferOpInterface methods.
4964 //===----------------------------------------------------------------------===//
4965 VectorType TransferReadOp::getVectorType() {
4966  return cast<VectorType>(getVector().getType());
4967 }
4968 
4969 template <typename TransferOp>
4970 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4971  // TODO: support more aggressive createOrFold on:
4972  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4973  if (op.getShapedType().isDynamicDim(indicesIdx))
4974  return false;
4975  Value index = op.getIndices()[indicesIdx];
4976  std::optional<int64_t> cstOp = getConstantIntValue(index);
4977  if (!cstOp.has_value())
4978  return false;
4979 
4980  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4981  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4982 
4983  return cstOp.value() + vectorSize <= sourceSize;
4984 }
4985 
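/// Marks transfer dimensions as in-bounds when that can be proven statically.
/// For example (illustrative): a read of vector<4xf32> from memref<8xf32> at
/// constant index 2 is in-bounds because 2 + 4 <= 8, so its in_bounds entry
/// can be folded to true.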
4986 template <typename TransferOp>
4987 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4988  // TODO: support 0-d corner case.
4989  // TODO: Be less conservative.
4990  if (op.getTransferRank() == 0)
4991  return failure();
4992  AffineMap permutationMap = op.getPermutationMap();
4993  bool changed = false;
4994  SmallVector<bool, 4> newInBounds;
4995  newInBounds.reserve(op.getTransferRank());
4996  // Idxs of non-bcast dims - used when analysing bcast dims.
4997  SmallVector<unsigned> nonBcastDims;
4998 
4999  // 1. Process non-broadcast dims
5000  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5001  // 1.1. Already marked as in-bounds, nothing to see here.
5002  if (op.isDimInBounds(i)) {
5003  newInBounds.push_back(true);
5004  continue;
5005  }
5006  // 1.2. Currently out-of-bounds, check whether we can statically determine
5007  // it is inBounds.
5008  bool inBounds = false;
5009  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5010  if (dimExpr) {
5011  inBounds = isInBounds(op, /*resultIdx=*/i,
5012  /*indicesIdx=*/dimExpr.getPosition());
5013  nonBcastDims.push_back(i);
5014  }
5015 
5016  newInBounds.push_back(inBounds);
5017  // We commit the pattern if it is "more inbounds".
5018  changed |= inBounds;
5019  }
5020 
5021  // 2. Handle broadcast dims
5022  // If all non-broadcast dims are "in bounds", then all bcast dims should be
5023  // "in bounds" as well.
5024  bool allNonBcastDimsInBounds = llvm::all_of(
5025  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5026  if (allNonBcastDimsInBounds) {
5027  for (size_t idx : permutationMap.getBroadcastDims()) {
5028  changed |= !newInBounds[idx];
5029  newInBounds[idx] = true;
5030  }
5031  }
5032 
5033  if (!changed)
5034  return failure();
5035  // OpBuilder is only used as a helper to build an I64ArrayAttr.
5036  OpBuilder b(op.getContext());
5037  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5038  return success();
5039 }
5040 
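/// Drops the mask operand when it is statically known to be all-true, e.g. a
/// mask produced by `vector.constant_mask [4] : vector<4xi1>` (illustrative).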
5041 template <typename TransferOp>
5042 static LogicalResult foldTransferFullMask(TransferOp op) {
5043  auto mask = op.getMask();
5044  if (!mask)
5045  return failure();
5046 
5047  if (getMaskFormat(mask) != MaskFormat::AllTrue)
5048  return failure();
5049 
5050  op.getMaskMutable().clear();
5051  return success();
5052 }
5053 
5054 /// ```
5055 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5056 /// : vector<1x4xf32>, tensor<4x4xf32>
5057 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5058 /// : tensor<4x4xf32>, vector<1x4xf32>
5059 /// ```
5060 /// -> Folds into
5061 /// ```
5062 /// %v0
5063 /// ```
5064 static Value foldRAW(TransferReadOp readOp) {
5065  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5066  return {};
5067  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5068  while (defWrite) {
5069  if (checkSameValueRAW(defWrite, readOp))
5070  return defWrite.getVector();
5071  if (!isDisjointTransferIndices(
5072  cast<VectorTransferOpInterface>(defWrite.getOperation()),
5073  cast<VectorTransferOpInterface>(readOp.getOperation())))
5074  break;
5075  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5076  }
5077  return {};
5078 }
5079 
5080 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5081  if (Value vec = foldRAW(*this))
5082  return vec;
5083  if (succeeded(foldTransferInBoundsAttribute(*this)))
5084  return getResult();
5085  if (succeeded(foldTransferFullMask(*this)))
5086  return getResult();
5087  // transfer_read(memref.cast) -> transfer_read
5088  if (succeeded(memref::foldMemRefCast(*this)))
5089  return getResult();
5090  if (succeeded(tensor::foldTensorCast(*this)))
5091  return getResult();
5092  return OpFoldResult();
5093 }
5094 
5095 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5096  return llvm::to_vector<4>(getVectorType().getShape());
5097 }
5098 
5099 void TransferReadOp::getEffects(
5100  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5101  &effects) {
5102  if (llvm::isa<MemRefType>(getShapedType()))
5103  effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5104  SideEffects::DefaultResource::get());
5105 }
5106 
5107 Speculation::Speculatability TransferReadOp::getSpeculatability() {
5108  if (hasPureTensorSemantics())
5109  return Speculation::Speculatable;
5110  return Speculation::NotSpeculatable;
5111 }
5112 
5113 namespace {
5114 /// Store to load forwarding for transfer operations with permutation maps.
5115 /// Even if the permutation maps are different we can still propagate the store
5116 /// into the load if the size of the dimensions read and written match. Then we
5117 /// can replace the transfer_read + transfer_write by vector.broadcast and
5118 /// vector.transpose.
5119 /// Example:
5120 /// ```
5121 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5122 /// {in_bounds = [true, true],
5123 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5124 /// vector<4x1xf32>, tensor<4x4x4xf32>
5125 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5126 /// {in_bounds = [true, true, true, true],
5127 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5128 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5129 /// ```
5130 /// To:
5131 /// ```
5132 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5133 /// %r = vector.transpose %0, [3, 0, 2, 1] :
5134 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5135 /// ```
5136 struct TransferReadAfterWriteToBroadcast
5137  : public OpRewritePattern<TransferReadOp> {
5138  using Base::Base;
5139 
5140  LogicalResult matchAndRewrite(TransferReadOp readOp,
5141  PatternRewriter &rewriter) const override {
5142  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5143  if (!defWrite)
5144  return failure();
5145  // Bail if we need an alias analysis.
5146  if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5147  return failure();
5148  // Bail if we need a bounds analysis.
5149  if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5150  return failure();
5151  // TODO: If the written transfer chunk is a superset of the read transfer
5152  // chunk we could do an extract_strided_slice.
5153  if (readOp.getTransferChunkAccessed() !=
5154  defWrite.getTransferChunkAccessed())
5155  return failure();
5156  // TODO: Support cases where a dim is explicitly written but implicitly
5157  // read (i.e., a unit dim that is rank reduced).
5158  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
5159  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
5160  return failure();
5161  // This pattern should only catch the broadcast case; the non-broadcast case
5162  // should be handled separately to keep the application conditions clean and
5163  // separate.
5164  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
5165  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
5166  bool bcast = !readMap.getBroadcastDims().empty() ||
5167  !writeMap.getBroadcastDims().empty();
5168  if (!bcast)
5169  return failure();
5170  // At this point, we know we have a bcast.
5171  // Bail in the masked case (too complex atm and needed to properly account
5172  // for padding).
5173  if (readOp.getMask() || defWrite.getMask())
5174  return failure();
5175  // If the indices are not the same, a shift may be required; bail.
5176  if (readOp.getIndices() != defWrite.getIndices())
5177  return failure();
5178 
5179  Value vec = defWrite.getVector();
5180  // TODO: loop through the chain of transfer_write if we can prove that they
5181  // don't overlap with the transfer_read. This requires improving
5182  // `isDisjointTransferIndices` helper.
5183  AffineMap map = readMap.compose(writeMap);
5184  if (map.getNumResults() == 0)
5185  return failure();
5186  // Calculate the permutation to apply to go from the vector stored to the
5187  // vector read.
5188  SmallVector<unsigned> permutation;
5189  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
5190  return failure();
5191 
5192  Location loc = readOp.getLoc();
5193  // Calculate the broadcast shape by applying the reverse permutation to the
5194  // final shape we want.
5195  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
5196  SmallVector<int64_t> broadcastShape(destShape.size());
5197  SmallVector<bool> broadcastScalableFlags(destShape.size());
5198  for (const auto &pos : llvm::enumerate(permutation)) {
5199  broadcastShape[pos.value()] = destShape[pos.index()];
5200  broadcastScalableFlags[pos.value()] =
5201  readOp.getVectorType().getScalableDims()[pos.index()];
5202  }
5203  VectorType broadcastedType = VectorType::get(
5204  broadcastShape, defWrite.getVectorType().getElementType(),
5205  broadcastScalableFlags);
5206  vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
5207  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
5208  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
5209  transposePerm);
5210  return success();
5211  }
5212 };
5213 } // namespace
5214 
5215 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5216  MLIRContext *context) {
5217  results.add<TransferReadAfterWriteToBroadcast>(context);
5218 }
5219 
5220 FailureOr<std::optional<SmallVector<Value>>>
5221 TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5222  if (!hasPureBufferSemantics())
5223  return failure();
5225  getResult());
5226 }
5227 
5228 //===----------------------------------------------------------------------===//
5229 // TransferWriteOp
5230 //===----------------------------------------------------------------------===//
5231 
5232 /// 1. Builder with type inference.
5233 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5234  Value vector, Value dest, ValueRange indices,
5235  AffineMapAttr permutationMapAttr,
5236  /*optional*/ Value mask,
5237  /*optional*/ ArrayAttr inBoundsAttr) {
5238  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5239  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5240  mask, inBoundsAttr);
5241 }
5242 
5243 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
5244 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5245  Value vector, Value dest, ValueRange indices,
5246  AffineMapAttr permutationMapAttr,
5247  /*optional*/ ArrayAttr inBoundsAttr) {
5248  build(builder, result, vector, dest, indices, permutationMapAttr,
5249  /*mask=*/Value(), inBoundsAttr);
5250 }
5251 
5252 /// 3. Builder with type inference that sets an empty mask (variant without
5253 /// attrs)
5254 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5255  Value vector, Value dest, ValueRange indices,
5256  AffineMap permutationMap,
5257  std::optional<ArrayRef<bool>> inBounds) {
5258  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5259  auto inBoundsAttr =
5260  (inBounds && !inBounds.value().empty())
5261  ? builder.getBoolArrayAttr(inBounds.value())
5262  : builder.getBoolArrayAttr(SmallVector<bool>(
5263  llvm::cast<VectorType>(vector.getType()).getRank(), false));
5264  build(builder, result, vector, dest, indices, permutationMapAttr,
5265  /*mask=*/Value(), inBoundsAttr);
5266 }
5267 
5268 /// 4. Builder with type inference that sets an empty mask and sets permutation
5269 /// map to 'getMinorIdentityMap'.
5270 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5271  Value vector, Value dest, ValueRange indices,
5272  std::optional<ArrayRef<bool>> inBounds) {
5273  auto vectorType = llvm::cast<VectorType>(vector.getType());
5274  AffineMap permutationMap = getTransferMinorIdentityMap(
5275  llvm::cast<ShapedType>(dest.getType()), vectorType);
5276  build(builder, result, vector, dest, indices, permutationMap, inBounds);
5277 }
5278 
5279 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5280  OperationState &result) {
5281  auto &builder = parser.getBuilder();
5282  SMLoc typesLoc;
5283  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5284  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5285  SmallVector<Type, 2> types;
5286  OpAsmParser::UnresolvedOperand maskInfo;
5287  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5288  parser.parseOperand(sourceInfo) ||
5289  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5290  return failure();
5291  ParseResult hasMask = parser.parseOptionalComma();
5292  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5293  return failure();
5294  if (parser.parseOptionalAttrDict(result.attributes) ||
5295  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5296  return failure();
5297  if (types.size() != 2)
5298  return parser.emitError(typesLoc, "requires two types");
5299  auto indexType = builder.getIndexType();
5300  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5301  if (!vectorType)
5302  return parser.emitError(typesLoc, "requires vector type");
5303  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5304  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5305  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5306  auto permMapAttrName =
5307  TransferWriteOp::getPermutationMapAttrName(result.name);
5308  auto permMapAttr = result.attributes.get(permMapAttrName);
5309  AffineMap permMap;
5310  if (!permMapAttr) {
5311  if (shapedType.getRank() <
5312  getEffectiveVectorRankForXferOp(shapedType, vectorType))
5313  return parser.emitError(typesLoc,
5314  "expected a custom permutation_map when "
5315  "rank(source) != rank(destination)");
5316  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5317  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5318  } else {
5319  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5320  }
5321  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5322  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5323  if (!inBoundsAttr) {
5324  result.addAttribute(inBoundsAttrName,
5325  builder.getBoolArrayAttr(
5326  SmallVector<bool>(permMap.getNumResults(), false)));
5327  }
5328  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5329  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5330  parser.resolveOperands(indexInfo, indexType, result.operands))
5331  return failure();
5332  if (hasMask.succeeded()) {
5333  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5334  return parser.emitError(
5335  maskInfo.location, "does not support masks with vector element type");
5336  if (vectorType.getRank() != permMap.getNumResults()) {
5337  return parser.emitError(typesLoc,
5338  "expected the same rank for the vector and the "
5339  "results of the permutation map");
5340  }
5341  auto maskType = inferTransferOpMaskType(vectorType, permMap);
5342  if (parser.resolveOperand(maskInfo, maskType, result.operands))
5343  return failure();
5344  }
5345  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5346  builder.getDenseI32ArrayAttr(
5347  {1, 1, static_cast<int32_t>(indexInfo.size()),
5348  static_cast<int32_t>(hasMask.succeeded())}));
5349  return failure(llvm::isa<RankedTensorType>(shapedType) &&
5350  parser.addTypeToList(shapedType, result.types));
5351 }
5352 
5353 void TransferWriteOp::print(OpAsmPrinter &p) {
5354  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5355  if (getMask())
5356  p << ", " << getMask();
5357  printTransferAttrs(p, *this);
5358  p << " : " << getVectorType() << ", " << getShapedType();
5359 }
5360 
5361 LogicalResult TransferWriteOp::verify() {
5362  // Consistency of elemental types in shape and vector.
5363  ShapedType shapedType = getShapedType();
5364  VectorType vectorType = getVectorType();
5365  VectorType maskType = getMaskType();
5366  auto permutationMap = getPermutationMap();
5367  VectorType inferredMaskType =
5368  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5369  : VectorType();
5370 
5371  if (llvm::size(getIndices()) != shapedType.getRank())
5372  return emitOpError("requires ") << shapedType.getRank() << " indices";
5373 
5374  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5375  // as the semantics is unclear. This can be revisited later if necessary.
5376  if (hasBroadcastDim())
5377  return emitOpError("should not have broadcast dimensions");
5378 
5379  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5380  shapedType, vectorType, maskType,
5381  inferredMaskType, permutationMap, getInBounds())))
5382  return failure();
5383 
5384  return verifyPermutationMap(permutationMap,
5385  [&](Twine t) { return emitOpError(t); });
5386 }
5387 
5388 //===----------------------------------------------------------------------===//
5389 // TransferWriteOp: MaskableOpInterface methods.
5390 //===----------------------------------------------------------------------===//
5391 
5392 /// Returns the mask type expected by this operation. Mostly used for
5393 /// verification purposes.
5394 Type TransferWriteOp::getExpectedMaskType() {
5395  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5396 }
5397 
5398 //===----------------------------------------------------------------------===//
5399 // TransferWriteOp: VectorTransferOpInterface methods.
5400 //===----------------------------------------------------------------------===//
5401 Value TransferWriteOp::getVector() { return getOperand(0); }
5402 VectorType TransferWriteOp::getVectorType() {
5403  return cast<VectorType>(getValueToStore().getType());
5404 }
5405 
5406 //===----------------------------------------------------------------------===//
5407 // TransferWriteOp: fold methods.
5408 //===----------------------------------------------------------------------===//
5409 /// Fold:
5410 /// ```
5411 /// %t1 = ...
5412 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5413 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5414 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5415 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5416 /// ```
5417 ///
5418 /// into:
5419 ///
5420 /// ```
5421 /// %t0
5422 /// ```
5423 ///
5424 /// The producer of t1 may or may not be DCE'd depending on whether it is a
5425 /// block argument or has side effects.
5426 static LogicalResult foldReadInitWrite(TransferWriteOp write,
5427  ArrayRef<Attribute>,
5428  SmallVectorImpl<OpFoldResult> &results) {
5429  // TODO: support 0-d corner case.
5430  if (write.getTransferRank() == 0)
5431  return failure();
5432  auto rankedTensorType =
5433  llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5434  // If not operating on tensors, bail.
5435  if (!rankedTensorType)
5436  return failure();
5437  // If no read, bail.
5438  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5439  if (!read)
5440  return failure();
5441  // TODO: support 0-d corner case.
5442  if (read.getTransferRank() == 0)
5443  return failure();
5444  // For now, only accept minor identity maps (future: maps whose composition is a minor identity).
5445  if (!read.getPermutationMap().isMinorIdentity() ||
5446  !write.getPermutationMap().isMinorIdentity())
5447  return failure();
5448  // Bail on mismatching ranks.
5449  if (read.getTransferRank() != write.getTransferRank())
5450  return failure();
5451  // Bail on potential out-of-bounds accesses.
5452  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5453  return failure();
5454  // Tensor types must be the same.
5455  if (read.getBase().getType() != rankedTensorType)
5456  return failure();
5457  // Vector types must be the same.
5458  if (read.getVectorType() != write.getVectorType())
5459  return failure();
5460  // Vector and Tensor shapes must match.
5461  if (read.getVectorType().getShape() != rankedTensorType.getShape())
5462  return failure();
5463  // If any index is nonzero.
5464  auto isNotConstantZero = [](Value v) {
5465  auto cstOp = getConstantIntValue(v);
5466  return !cstOp.has_value() || cstOp.value() != 0;
5467  };
5468  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5469  llvm::any_of(write.getIndices(), isNotConstantZero))
5470  return failure();
5471  // Success.
5472  results.push_back(read.getBase());
5473  return success();
5474 }
5475 
5476 static bool checkSameValueWAR(vector::TransferReadOp read,
5477  vector::TransferWriteOp write) {
5478  return read.getBase() == write.getBase() &&
5479  read.getIndices() == write.getIndices() &&
5480  read.getPermutationMap() == write.getPermutationMap() &&
5481  read.getVectorType() == write.getVectorType() && !read.getMask() &&
5482  !write.getMask();
5483 }
5484 /// Fold transfer_write write after read:
5485 /// ```
5486 /// %t0 = ...
5487 /// %v = vector.transfer_read %t0[%c0...] :
5488 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5489 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
5490 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5491 /// ```
5492 ///
5493 /// into:
5494 ///
5495 /// ```
5496 /// %t0
5497 /// ```
5498 static LogicalResult foldWAR(TransferWriteOp write,
5499  SmallVectorImpl<OpFoldResult> &results) {
5500  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5501  return failure();
5502  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5503  if (!read)
5504  return failure();
5505 
5506  if (!checkSameValueWAR(read, write))
5507  return failure();
5508  results.push_back(read.getBase());
5509  return success();
5510 }
5511 
5512 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5513  SmallVectorImpl<OpFoldResult> &results) {
5514  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5515  return success();
5516  if (succeeded(foldWAR(*this, results)))
5517  return success();
5518  if (succeeded(foldTransferInBoundsAttribute(*this)))
5519  return success();
5520  if (succeeded(foldTransferFullMask(*this)))
5521  return success();
5522  return memref::foldMemRefCast(*this);
5523 }
5524 
5525 //===----------------------------------------------------------------------===//
5526 // TransferWriteOp: other methods.
5527 //===----------------------------------------------------------------------===//
5528 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5529  return llvm::to_vector<4>(getVectorType().getShape());
5530 }
5531 
5532 void TransferWriteOp::getEffects(
5533  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5534  &effects) {
5535  if (llvm::isa<MemRefType>(getShapedType()))
5536  effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5537  SideEffects::DefaultResource::get());
5538 }
5539 
5540 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5541  if (hasPureTensorSemantics())
5542  return Speculation::Speculatable;
5543  return Speculation::NotSpeculatable;
5544 }
5545 
5546 namespace {
5547 /// Remove dead transfer write from the SSA chain so that it can be eliminated
5548 /// by DCE.
5549 /// ```
5550 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5551 /// : vector<1x4xf32>, tensor<4x4xf32>
5552 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5553 /// : vector<1x4xf32>, tensor<4x4xf32>
5554 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5555 /// : vector<1x4xf32>, tensor<4x4xf32>
5556 /// ```
5557 ///
5558 /// into:
5559 ///
5560 /// ```
5561 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5562 /// : vector<1x4xf32>, tensor<4x4xf32>
5563 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5564 /// : vector<1x4xf32>, tensor<4x4xf32>
5565 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5566 /// : vector<1x4xf32>, tensor<4x4xf32>
5567 /// ```
5568 ///
5569 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5570 /// any other uses.
5571 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5572 public:
5573  using Base::Base;
5574  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5575  PatternRewriter &rewriter) const override {
5576  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5577  return failure();
5578  vector::TransferWriteOp writeToModify = writeOp;
5579 
5580  auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5581  while (defWrite) {
5582  if (checkSameValueWAW(writeOp, defWrite)) {
5583  rewriter.modifyOpInPlace(writeToModify, [&]() {
5584  writeToModify.getBaseMutable().assign(defWrite.getBase());
5585  });
5586  return success();
5587  }
5588  if (!isDisjointTransferIndices(
5589  cast<VectorTransferOpInterface>(defWrite.getOperation()),
5590  cast<VectorTransferOpInterface>(writeOp.getOperation())))
5591  break;
5592  // If the previous write op doesn't have any other use, we can safely look
5593  // at the previous store to see if it can be removed.
5594  if (!defWrite->hasOneUse())
5595  break;
5596  writeToModify = defWrite;
5597  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5598  }
5599  return failure();
5600  }
5601 };
5602 
5603 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5604 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5605 /// overwritten and inserted into another tensor. After this rewrite, the
5606 /// operations bufferize in-place since all of them work on the same slice.
5607 ///
5608 /// For example:
5609 /// ```mlir
5610 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5611 /// : vector<8x16xf32>, tensor<8x16xf32>
5612 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5613 /// : tensor<8x16xf32> to tensor<?x?xf32>
5614 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5615 /// : tensor<?x?xf32> into tensor<27x37xf32>
5616 /// ```
5617 /// folds to
5618 /// ```mlir
5619 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5620 /// : tensor<27x37xf32> to tensor<?x?xf32>
5621 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5622 /// : vector<8x16xf32>, tensor<?x?xf32>
5623 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5624 /// : tensor<?x?xf32> into tensor<27x37xf32>
5625 /// ```
5626 struct SwapExtractSliceOfTransferWrite
5627  : public OpRewritePattern<tensor::InsertSliceOp> {
5628 public:
5629  using Base::Base;
5630 
5631  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5632  PatternRewriter &rewriter) const override {
5633  if (!insertOp.hasUnitStride())
5634  return failure();
5635  auto extractOp =
5636  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5637  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5638  return failure();
5639  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5640  if (!transferOp || !transferOp->hasOneUse())
5641  return failure();
5642 
5643  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5644  // rank-reducing.
5645  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5646  return rewriter.notifyMatchFailure(insertOp,
5647  "use-def chain is rank-reducing");
5648  }
5649 
5650  // Fail if tensor::ExtractSliceOp has non-zero offset.
5651  if (!extractOp.hasZeroOffset()) {
5652  return rewriter.notifyMatchFailure(insertOp,
5653  "ExtractSliceOp has non-zero offset");
5654  }
5655 
 5656  // Fail if vector::TransferWriteOp has non-zero offset.
5657  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5658  return getConstantIntValue(value) == static_cast<int64_t>(0);
5659  })) {
5660  return rewriter.notifyMatchFailure(insertOp,
 5661  "TransferWriteOp has non-zero offset");
5662  }
5663 
5664  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5665  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5666  return rewriter.notifyMatchFailure(
5667  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5668  }
5669 
5670  for (auto [insertSize, extractSize] :
5671  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5672  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5673  return rewriter.notifyMatchFailure(
5674  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5675  }
5676  }
5677 
5678  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5679  assert(transferOp.getVectorType().hasStaticShape() &&
5680  "expected vector to have a static shape");
5681  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
 5682  SmallVector<int64_t> resultShape = applyPermutationMap(
 5683  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5684  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5685  return rewriter.notifyMatchFailure(
5686  insertOp, "TransferWriteOp may not write the full tensor.");
5687  }
5688 
5689  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5690  // Set all in_bounds to false and let the folder infer them.
5691  SmallVector<bool> newInBounds(vectorShape.size(), false);
5692  auto newExtractOp = tensor::ExtractSliceOp::create(
5693  rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5694  insertOp.getDest(), insertOp.getMixedOffsets(),
5695  insertOp.getMixedSizes(), insertOp.getMixedStrides());
5696  auto newTransferWriteOp = TransferWriteOp::create(
5697  rewriter, transferOp.getLoc(), transferOp.getVector(),
5698  newExtractOp.getResult(), transferOp.getIndices(),
5699  transferOp.getPermutationMapAttr(),
5700  rewriter.getBoolArrayAttr(newInBounds));
5701  rewriter.modifyOpInPlace(insertOp, [&]() {
5702  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5703  });
5704  return success();
5705  }
5706 };
5707 
5708 } // namespace
5709 
5710 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5711  MLIRContext *context) {
5712  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5713 }
5714 
5715 FailureOr<std::optional<SmallVector<Value>>>
5716 TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5717  if (!hasPureBufferSemantics())
5718  return failure();
5720  ValueRange());
5721 }
5722 
5723 //===----------------------------------------------------------------------===//
5724 // LoadOp
5725 //===----------------------------------------------------------------------===//
5726 
5727 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5728  VectorType vecTy,
5729  MemRefType memRefTy) {
 5730  // If rank==0 or size==1 it's equivalent to a scalar load/store, so we don't
 5731  // need any stride limitations.
5732  if (!vecTy.isScalable() &&
5733  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5734  return success();
5735 
5736  if (!memRefTy.isLastDimUnitStride())
5737  return op->emitOpError("most minor memref dim must have unit stride");
5738  return success();
5739 }
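
// Illustrative sketch of the layout rule checked above; it is not taken from
// the upstream tests, and %i and the memref values are hypothetical.
//
//   // Accepted: the innermost memref dimension has unit stride.
//   %0 = vector.load %unit[%i, %i] : memref<8x16xf32>, vector<16xf32>
//
//   // Rejected: strided<[32, 2]> gives the innermost dimension stride 2.
//   %1 = vector.load %strided[%i, %i]
//       : memref<8x16xf32, strided<[32, 2]>>, vector<16xf32>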
5740 
5741 LogicalResult vector::LoadOp::verify() {
5742  VectorType resVecTy = getVectorType();
5743  MemRefType memRefTy = getMemRefType();
5744 
5745  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5746  return failure();
5747 
5748  if (memRefTy.getRank() < resVecTy.getRank())
5749  return emitOpError(
5750  "destination memref has lower rank than the result vector");
5751 
5752  // Checks for vector memrefs.
5753  Type memElemTy = memRefTy.getElementType();
5754  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5755  if (memVecTy != resVecTy)
5756  return emitOpError("base memref and result vector types should match");
5757  memElemTy = memVecTy.getElementType();
5758  }
5759 
5760  if (resVecTy.getElementType() != memElemTy)
5761  return emitOpError("base and result element types should match");
5762  if (llvm::size(getIndices()) != memRefTy.getRank())
5763  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5764  return success();
5765 }
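
// Illustrative sketch (hypothetical names): when the memref element type is
// itself a vector, the result type must match that element type exactly.
//
//   %0 = vector.load %vmem[%i] : memref<16xvector<8xf32>>, vector<8xf32>
//
// Loading, e.g., vector<4xf32> from the same memref would be rejected by the
// checks above.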
5766 
5767 OpFoldResult LoadOp::fold(FoldAdaptor) {
5768  if (succeeded(memref::foldMemRefCast(*this)))
5769  return getResult();
5770  return OpFoldResult();
5771 }
5772 
5773 std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5774  return llvm::to_vector<4>(getVectorType().getShape());
5775 }
5776 
5777 FailureOr<std::optional<SmallVector<Value>>>
5778 LoadOp::bubbleDownCasts(OpBuilder &builder) {
5780  getResult());
5781 }
5782 
5783 //===----------------------------------------------------------------------===//
5784 // StoreOp
5785 //===----------------------------------------------------------------------===//
5786 
5787 LogicalResult vector::StoreOp::verify() {
5788  VectorType valueVecTy = getVectorType();
5789  MemRefType memRefTy = getMemRefType();
5790 
5791  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5792  return failure();
5793 
5794  if (memRefTy.getRank() < valueVecTy.getRank())
5795  return emitOpError("source memref has lower rank than the vector to store");
5796 
5797  // Checks for vector memrefs.
5798  Type memElemTy = memRefTy.getElementType();
5799  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5800  if (memVecTy != valueVecTy)
5801  return emitOpError(
5802  "base memref and valueToStore vector types should match");
5803  memElemTy = memVecTy.getElementType();
5804  }
5805 
5806  if (valueVecTy.getElementType() != memElemTy)
5807  return emitOpError("base and valueToStore element type should match");
5808  if (llvm::size(getIndices()) != memRefTy.getRank())
5809  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5810  return success();
5811 }
5812 
5813 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5814  SmallVectorImpl<OpFoldResult> &results) {
5815  return memref::foldMemRefCast(*this);
5816 }
5817 
5818 std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5819  return llvm::to_vector<4>(getVectorType().getShape());
5820 }
5821 
5822 FailureOr<std::optional<SmallVector<Value>>>
5823 StoreOp::bubbleDownCasts(OpBuilder &builder) {
5825  ValueRange());
5826 }
5827 
5828 //===----------------------------------------------------------------------===//
5829 // MaskedLoadOp
5830 //===----------------------------------------------------------------------===//
5831 
5832 LogicalResult MaskedLoadOp::verify() {
5833  VectorType maskVType = getMaskVectorType();
5834  VectorType passVType = getPassThruVectorType();
5835  VectorType resVType = getVectorType();
5836  MemRefType memType = getMemRefType();
5837 
5838  if (resVType.getElementType() != memType.getElementType())
5839  return emitOpError("base and result element type should match");
5840  if (llvm::size(getIndices()) != memType.getRank())
5841  return emitOpError("requires ") << memType.getRank() << " indices";
5842  if (resVType.getShape() != maskVType.getShape())
5843  return emitOpError("expected result shape to match mask shape");
5844  if (resVType != passVType)
5845  return emitOpError("expected pass_thru of same type as result type");
5846  return success();
5847 }
5848 
5849 namespace {
5850 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5851 public:
5852  using Base::Base;
5853  LogicalResult matchAndRewrite(MaskedLoadOp load,
5854  PatternRewriter &rewriter) const override {
5855  switch (getMaskFormat(load.getMask())) {
5856  case MaskFormat::AllTrue:
5857  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5858  load, load.getType(), load.getBase(), load.getIndices());
5859  return success();
5860  case MaskFormat::AllFalse:
5861  rewriter.replaceOp(load, load.getPassThru());
5862  return success();
5863  case MaskFormat::Unknown:
5864  return failure();
5865  }
5866  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5867  }
5868 };
5869 } // namespace
5870 
5871 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5872  MLIRContext *context) {
5873  results.add<MaskedLoadFolder>(context);
5874 }
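
// Illustrative sketch of the MaskedLoadFolder canonicalization above; SSA
// names are hypothetical, not from the upstream tests.
//
//   %m = vector.constant_mask [16] : vector<16xi1>
//   %x = vector.maskedload %base[%i], %m, %pass
//       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// becomes
//   %x = vector.load %base[%i] : memref<?xf32>, vector<16xf32>
//
// An all-false mask instead folds the op away and forwards %pass.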
5875 
5876 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5877  if (succeeded(memref::foldMemRefCast(*this)))
5878  return getResult();
5879  return OpFoldResult();
5880 }
5881 
5882 FailureOr<std::optional<SmallVector<Value>>>
5883 MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
5885  getResult());
5886 }
5887 
5888 //===----------------------------------------------------------------------===//
5889 // MaskedStoreOp
5890 //===----------------------------------------------------------------------===//
5891 
5892 LogicalResult MaskedStoreOp::verify() {
5893  VectorType maskVType = getMaskVectorType();
5894  VectorType valueVType = getVectorType();
5895  MemRefType memType = getMemRefType();
5896 
5897  if (valueVType.getElementType() != memType.getElementType())
5898  return emitOpError("base and valueToStore element type should match");
5899  if (llvm::size(getIndices()) != memType.getRank())
5900  return emitOpError("requires ") << memType.getRank() << " indices";
5901  if (valueVType.getShape() != maskVType.getShape())
5902  return emitOpError("expected valueToStore shape to match mask shape");
5903  return success();
5904 }
5905 
5906 namespace {
5907 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5908 public:
5909  using Base::Base;
5910  LogicalResult matchAndRewrite(MaskedStoreOp store,
5911  PatternRewriter &rewriter) const override {
5912  switch (getMaskFormat(store.getMask())) {
5913  case MaskFormat::AllTrue:
5914  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5915  store, store.getValueToStore(), store.getBase(), store.getIndices());
5916  return success();
5917  case MaskFormat::AllFalse:
5918  rewriter.eraseOp(store);
5919  return success();
5920  case MaskFormat::Unknown:
5921  return failure();
5922  }
5923  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5924  }
5925 };
5926 } // namespace
5927 
5928 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5929  MLIRContext *context) {
5930  results.add<MaskedStoreFolder>(context);
5931 }
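
// Illustrative sketch of the MaskedStoreFolder canonicalization above
// (hypothetical SSA names): an all-true mask becomes a plain store, while an
// all-false mask erases the op.
//
//   %m = vector.constant_mask [16] : vector<16xi1>
//   vector.maskedstore %base[%i], %m, %value
//       : memref<?xf32>, vector<16xi1>, vector<16xf32>
// becomes
//   vector.store %value, %base[%i] : memref<?xf32>, vector<16xf32>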
5932 
5933 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5934  SmallVectorImpl<OpFoldResult> &results) {
5935  return memref::foldMemRefCast(*this);
5936 }
5937 
5938 FailureOr<std::optional<SmallVector<Value>>>
5939 MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
5941  ValueRange());
5942 }
5943 
5944 //===----------------------------------------------------------------------===//
5945 // GatherOp
5946 //===----------------------------------------------------------------------===//
5947 
5948 LogicalResult GatherOp::verify() {
5949  VectorType indVType = getIndexVectorType();
5950  VectorType maskVType = getMaskVectorType();
5951  VectorType resVType = getVectorType();
5952  ShapedType baseType = getBaseType();
5953 
5954  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5955  return emitOpError("requires base to be a memref or ranked tensor type");
5956 
5957  if (resVType.getElementType() != baseType.getElementType())
5958  return emitOpError("base and result element type should match");
5959  if (llvm::size(getOffsets()) != baseType.getRank())
5960  return emitOpError("requires ") << baseType.getRank() << " indices";
5961  if (resVType.getShape() != indVType.getShape())
5962  return emitOpError("expected result dim to match indices dim");
5963  if (resVType.getShape() != maskVType.getShape())
5964  return emitOpError("expected result dim to match mask dim");
5965  if (resVType != getPassThruVectorType())
5966  return emitOpError("expected pass_thru of same type as result type");
5967  return success();
5968 }
5969 
5970 // MaskableOpInterface methods.
5971 
5972 /// Returns the mask type expected by this operation. Mostly used for
 5973 /// verification purposes. It requires the operation to be vectorized.
5974 Type GatherOp::getExpectedMaskType() {
5975  auto vecType = this->getIndexVectorType();
5976  return VectorType::get(vecType.getShape(),
5977  IntegerType::get(vecType.getContext(), /*width=*/1),
5978  vecType.getScalableDims());
5979 }
5980 
5981 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5982  return llvm::to_vector<4>(getVectorType().getShape());
5983 }
5984 
 5985 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
5986 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5987  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5988  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5989  return failure();
5990 
5991  if (indexVec.getDefiningOp<StepOp>())
5992  return success();
5993 
5994  DenseIntElementsAttr elements;
5995  if (!matchPattern(indexVec, m_Constant(&elements)))
5996  return failure();
5997 
5998  return success(
5999  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6000 }
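
// Illustrative sketch of index vectors accepted by the helper above
// (hypothetical names):
//
//   %s = vector.step : vector<8xindex>
//   %c = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//
// Both qualify; a scalable vector, a vector of rank > 1, or a non-consecutive
// constant such as dense<[0, 2, 1, 3]> does not.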
6001 
6002 namespace {
6003 class GatherFolder final : public OpRewritePattern<GatherOp> {
6004 public:
6005  using Base::Base;
6006  LogicalResult matchAndRewrite(GatherOp gather,
6007  PatternRewriter &rewriter) const override {
6008  switch (getMaskFormat(gather.getMask())) {
6009  case MaskFormat::AllTrue:
6010  return failure(); // no unmasked equivalent
6011  case MaskFormat::AllFalse:
6012  rewriter.replaceOp(gather, gather.getPassThru());
6013  return success();
6014  case MaskFormat::Unknown:
6015  return failure();
6016  }
6017  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6018  }
6019 };
6020 
6021 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6022 /// maskedload. Only 1D fixed vectors are supported for now.
6023 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6024 public:
6025  using Base::Base;
6026  LogicalResult matchAndRewrite(GatherOp op,
6027  PatternRewriter &rewriter) const override {
6028  if (!isa<MemRefType>(op.getBase().getType()))
6029  return rewriter.notifyMatchFailure(op, "base must be of memref type");
6030 
6031  if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6032  return failure();
6033 
6034  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6035  op.getOffsets(), op.getMask(),
6036  op.getPassThru());
6037  return success();
6038  }
6039 };
6040 } // namespace
6041 
6042 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6043  MLIRContext *context) {
6044  results.add<GatherFolder, FoldContiguousGather>(context);
6045 }
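
// Illustrative sketch of FoldContiguousGather (hypothetical SSA names, not
// from the upstream tests):
//
//   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//   %g = vector.gather %base[%c0][%idx], %mask, %pass
//       : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// becomes
//   %g = vector.maskedload %base[%c0], %mask, %pass
//       : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>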
6046 
6047 FailureOr<std::optional<SmallVector<Value>>>
6048 GatherOp::bubbleDownCasts(OpBuilder &builder) {
6050  getResult());
6051 }
6052 
6053 //===----------------------------------------------------------------------===//
6054 // ScatterOp
6055 //===----------------------------------------------------------------------===//
6056 
6057 LogicalResult ScatterOp::verify() {
6058  VectorType indVType = getIndexVectorType();
6059  VectorType maskVType = getMaskVectorType();
6060  VectorType valueVType = getVectorType();
6061  MemRefType memType = getMemRefType();
6062 
6063  if (valueVType.getElementType() != memType.getElementType())
6064  return emitOpError("base and valueToStore element type should match");
6065  if (llvm::size(getOffsets()) != memType.getRank())
6066  return emitOpError("requires ") << memType.getRank() << " indices";
6067  if (valueVType.getShape() != indVType.getShape())
6068  return emitOpError("expected valueToStore dim to match indices dim");
6069  if (valueVType.getShape() != maskVType.getShape())
6070  return emitOpError("expected valueToStore dim to match mask dim");
6071  return success();
6072 }
6073 
6074 namespace {
6075 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6076 public:
6077  using Base::Base;
6078  LogicalResult matchAndRewrite(ScatterOp scatter,
6079  PatternRewriter &rewriter) const override {
6080  switch (getMaskFormat(scatter.getMask())) {
6081  case MaskFormat::AllTrue:
6082  return failure(); // no unmasked equivalent
6083  case MaskFormat::AllFalse:
6084  rewriter.eraseOp(scatter);
6085  return success();
6086  case MaskFormat::Unknown:
6087  return failure();
6088  }
6089  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6090  }
6091 };
6092 
6093 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6094 /// maskedstore. Only 1D fixed vectors are supported for now.
6095 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6096 public:
6097  using Base::Base;
6098  LogicalResult matchAndRewrite(ScatterOp op,
6099  PatternRewriter &rewriter) const override {
6100  if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6101  return failure();
6102 
6103  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6104  op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6105  return success();
6106  }
6107 };
6108 } // namespace
6109 
6110 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6111  MLIRContext *context) {
6112  results.add<ScatterFolder, FoldContiguousScatter>(context);
6113 }
6114 
6115 FailureOr<std::optional<SmallVector<Value>>>
6116 ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6118  ValueRange());
6119 }
6120 
6121 //===----------------------------------------------------------------------===//
6122 // ExpandLoadOp
6123 //===----------------------------------------------------------------------===//
6124 
6125 LogicalResult ExpandLoadOp::verify() {
6126  VectorType maskVType = getMaskVectorType();
6127  VectorType passVType = getPassThruVectorType();
6128  VectorType resVType = getVectorType();
6129  MemRefType memType = getMemRefType();
6130 
6131  if (resVType.getElementType() != memType.getElementType())
6132  return emitOpError("base and result element type should match");
6133  if (llvm::size(getIndices()) != memType.getRank())
6134  return emitOpError("requires ") << memType.getRank() << " indices";
6135  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6136  return emitOpError("expected result dim to match mask dim");
6137  if (resVType != passVType)
6138  return emitOpError("expected pass_thru of same type as result type");
6139  return success();
6140 }
6141 
6142 namespace {
6143 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6144 public:
6145  using Base::Base;
6146  LogicalResult matchAndRewrite(ExpandLoadOp expand,
6147  PatternRewriter &rewriter) const override {
6148  switch (getMaskFormat(expand.getMask())) {
6149  case MaskFormat::AllTrue:
6150  rewriter.replaceOpWithNewOp<vector::LoadOp>(
6151  expand, expand.getType(), expand.getBase(), expand.getIndices());
6152  return success();
6153  case MaskFormat::AllFalse:
6154  rewriter.replaceOp(expand, expand.getPassThru());
6155  return success();
6156  case MaskFormat::Unknown:
6157  return failure();
6158  }
6159  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6160  }
6161 };
6162 } // namespace
6163 
6164 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6165  MLIRContext *context) {
6166  results.add<ExpandLoadFolder>(context);
6167 }
6168 
6169 FailureOr<std::optional<SmallVector<Value>>>
6170 ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6172  getResult());
6173 }
6174 
6175 //===----------------------------------------------------------------------===//
6176 // CompressStoreOp
6177 //===----------------------------------------------------------------------===//
6178 
6179 LogicalResult CompressStoreOp::verify() {
6180  VectorType maskVType = getMaskVectorType();
6181  VectorType valueVType = getVectorType();
6182  MemRefType memType = getMemRefType();
6183 
6184  if (valueVType.getElementType() != memType.getElementType())
6185  return emitOpError("base and valueToStore element type should match");
6186  if (llvm::size(getIndices()) != memType.getRank())
6187  return emitOpError("requires ") << memType.getRank() << " indices";
6188  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6189  return emitOpError("expected valueToStore dim to match mask dim");
6190  return success();
6191 }
6192 
6193 namespace {
6194 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6195 public:
6196  using Base::Base;
6197  LogicalResult matchAndRewrite(CompressStoreOp compress,
6198  PatternRewriter &rewriter) const override {
6199  switch (getMaskFormat(compress.getMask())) {
6200  case MaskFormat::AllTrue:
6201  rewriter.replaceOpWithNewOp<vector::StoreOp>(
6202  compress, compress.getValueToStore(), compress.getBase(),
6203  compress.getIndices());
6204  return success();
6205  case MaskFormat::AllFalse:
6206  rewriter.eraseOp(compress);
6207  return success();
6208  case MaskFormat::Unknown:
6209  return failure();
6210  }
6211  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6212  }
6213 };
6214 } // namespace
6215 
6216 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6217  MLIRContext *context) {
6218  results.add<CompressStoreFolder>(context);
6219 }
6220 
6221 FailureOr<std::optional<SmallVector<Value>>>
6222 CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6224  ValueRange());
6225 }
6226 
6227 //===----------------------------------------------------------------------===//
6228 // ShapeCastOp
6229 //===----------------------------------------------------------------------===//
6230 
6231 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6232  SetIntRangeFn setResultRanges) {
6233  setResultRanges(getResult(), argRanges.front());
6234 }
6235 
6236 LogicalResult ShapeCastOp::verify() {
6237 
6238  VectorType sourceType = getSourceVectorType();
6239  VectorType resultType = getResultVectorType();
6240 
6241  // Check that element type is preserved
6242  if (sourceType.getElementType() != resultType.getElementType())
6243  return emitOpError("has different source and result element types");
6244 
6245  // Check that number of elements is preserved
6246  int64_t sourceNElms = sourceType.getNumElements();
6247  int64_t resultNElms = resultType.getNumElements();
6248  if (sourceNElms != resultNElms) {
6249  return emitOpError() << "has different number of elements at source ("
6250  << sourceNElms << ") and result (" << resultNElms
6251  << ")";
6252  }
6253 
6254  // Check that (non-)scalability is preserved
6255  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6256  int64_t resultNScalableDims = resultType.getNumScalableDims();
6257  if (sourceNScalableDims != resultNScalableDims)
6258  return emitOpError() << "has different number of scalable dims at source ("
6259  << sourceNScalableDims << ") and result ("
6260  << resultNScalableDims << ")";
6261 
6262  return success();
6263 }
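
// Illustrative sketch of what the verifier above accepts and rejects
// (hypothetical values):
//
//   vector.shape_cast %a : vector<2x3xf32> to vector<6xf32>  // accepted: 6 elements on both sides
//   vector.shape_cast %b : vector<2x3xf32> to vector<4xf32>  // rejected: 6 vs. 4 elements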
6264 
6265 /// Return true if `transpose` does not permute a pair of non-unit dims.
6266 /// By `order preserving` we mean that the flattened versions of the input and
6267 /// output vectors are (numerically) identical. In other words `transpose` is
6268 /// effectively a shape cast.
6269 static bool isOrderPreserving(TransposeOp transpose) {
6270  ArrayRef<int64_t> permutation = transpose.getPermutation();
6271  VectorType sourceType = transpose.getSourceVectorType();
6272  ArrayRef<int64_t> inShape = sourceType.getShape();
6273  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6274  auto isNonScalableUnitDim = [&](int64_t dim) {
6275  return inShape[dim] == 1 && !inDimIsScalable[dim];
6276  };
6277  int64_t current = 0;
6278  for (auto p : permutation) {
6279  if (!isNonScalableUnitDim(p)) {
6280  if (p < current) {
6281  return false;
6282  }
6283  current = p;
6284  }
6285  }
6286  return true;
6287 }
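
// Illustrative example (hypothetical shapes): permutation [1, 0, 2] on
// vector<1x4x8xf32> only moves a unit dim, so the flattened data is unchanged
// and the transpose is order preserving; the same permutation on
// vector<2x4x8xf32> swaps two non-unit dims and is not.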
6288 
6289 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6290 
6291  VectorType resultType = getType();
6292 
6293  // No-op shape cast.
6294  if (getSource().getType() == resultType)
6295  return getSource();
6296 
6297  // shape_cast(shape_cast(x)) -> shape_cast(x)
6298  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6299  setOperand(precedingShapeCast.getSource());
6300  return getResult();
6301  }
6302 
6303  // shape_cast(transpose(x)) -> shape_cast(x)
6304  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6305  if (isOrderPreserving(transpose)) {
6306  setOperand(transpose.getVector());
6307  return getResult();
6308  }
6309  return {};
6310  }
6311 
6312  // Y = shape_cast(broadcast(X))
6313  // -> X, if X and Y have same type
6314  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6315  if (bcastOp.getSourceType() == resultType)
6316  return bcastOp.getSource();
6317  }
6318 
6319  // shape_cast(constant) -> constant
6320  if (auto denseAttr =
6321  dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6322  return denseAttr.reshape(getType());
6323 
6324  // shape_cast(poison) -> poison
6325  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6326  return ub::PoisonAttr::get(getContext());
6327 
6328  return {};
6329 }
6330 
6331 namespace {
6332 
6333 /// Helper function that computes a new vector type based on the input vector
6334 /// type by removing the trailing one dims:
6335 ///
6336 /// vector<4x1x1xi1> --> vector<4x1xi1>
6337 ///
6338 static VectorType trimTrailingOneDims(VectorType oldType) {
6339  ArrayRef<int64_t> oldShape = oldType.getShape();
6340  ArrayRef<int64_t> newShape = oldShape;
6341 
6342  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6343  ArrayRef<bool> newScalableDims = oldScalableDims;
6344 
6345  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6346  newShape = newShape.drop_back(1);
6347  newScalableDims = newScalableDims.drop_back(1);
6348  }
6349 
6350  // Make sure we have at least 1 dimension.
6351  // TODO: Add support for 0-D vectors.
6352  if (newShape.empty()) {
6353  newShape = oldShape.take_back();
6354  newScalableDims = oldScalableDims.take_back();
6355  }
6356 
6357  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6358 }
6359 
6360 /// Folds qualifying shape_cast(create_mask) into a new create_mask
6361 ///
6362 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6363 /// dimension. If the input vector comes from `vector.create_mask` for which
6364 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6365 /// to fold shape_cast into create_mask.
6366 ///
6367 /// BEFORE:
6368 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6369 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6370 /// AFTER:
6371 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6372 class ShapeCastCreateMaskFolderTrailingOneDim final
6373  : public OpRewritePattern<ShapeCastOp> {
6374 public:
6375  using Base::Base;
6376 
6377  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6378  PatternRewriter &rewriter) const override {
6379  Value shapeOpSrc = shapeOp->getOperand(0);
6380  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6381  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6382  if (!createMaskOp && !constantMaskOp)
6383  return failure();
6384 
6385  VectorType shapeOpResTy = shapeOp.getResultVectorType();
6386  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6387 
6388  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6389  if (newVecType != shapeOpResTy)
6390  return failure();
6391 
6392  auto numDimsToDrop =
6393  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6394 
6395  // No unit dims to drop
6396  if (!numDimsToDrop)
6397  return failure();
6398 
6399  if (createMaskOp) {
6400  auto maskOperands = createMaskOp.getOperands();
6401  auto numMaskOperands = maskOperands.size();
6402 
6403  // Check every mask dim size to see whether it can be dropped
6404  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6405  --i) {
6406  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6407  if (!constant || (constant.value() != 1))
6408  return failure();
6409  }
6410  SmallVector<Value> newMaskOperands =
6411  maskOperands.drop_back(numDimsToDrop);
6412 
6413  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6414  newMaskOperands);
6415  return success();
6416  }
6417 
6418  if (constantMaskOp) {
6419  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6420  auto numMaskOperands = maskDimSizes.size();
6421 
6422  // Check every mask dim size to see whether it can be dropped
6423  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6424  --i) {
6425  if (maskDimSizes[i] != 1)
6426  return failure();
6427  }
6428 
6429  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6430  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6431  newMaskOperands);
6432  return success();
6433  }
6434 
6435  return failure();
6436  }
6437 };
6438 
6439 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6440 /// i) Y = ShapeCast(X), or
6441 /// ii) Y = Broadcast(X)
6442 /// If both (i) and (ii) are possible, (i) is chosen.
6443 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6444 public:
6445  using Base::Base;
6446 
6447  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6448  PatternRewriter &rewriter) const override {
6449  auto broadcastOp =
6450  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6451  if (!broadcastOp)
6452  return failure();
6453 
6454  auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6455  bool srcIsScalar = !srcVectorType;
6456 
6457  // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6458  // Example:
6459  // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6460  // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6461  // to
6462  // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6463  if (srcVectorType) {
6464  if (srcVectorType.getNumElements() ==
6465  shapeCastOp.getResultVectorType().getNumElements()) {
6466  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6467  shapeCastOp, shapeCastOp.getResultVectorType(),
6468  broadcastOp.getSource());
6469  return success();
6470  }
6471  }
6472 
6473  // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6474  // Example
6475  // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6476  // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6477  // to
6478  // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6479  VectorType dstVectorType = shapeCastOp.getResultVectorType();
6480  if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6481  BroadcastableToResult::Success) {
6482  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6483  shapeCastOp, dstVectorType, broadcastOp.getSource());
6484  return success();
6485  }
6486  return failure();
6487  }
6488 };
6489 
6490 } // namespace
6491 
6492 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6493  MLIRContext *context) {
6494  results
6495  .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6496  context);
6497 }
6498 
6499 //===----------------------------------------------------------------------===//
6500 // VectorBitCastOp
6501 //===----------------------------------------------------------------------===//
6502 
6503 LogicalResult BitCastOp::verify() {
6504  auto sourceVectorType = getSourceVectorType();
6505  auto resultVectorType = getResultVectorType();
6506 
6507  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6508  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6509  return emitOpError("dimension size mismatch at: ") << i;
6510  }
6511 
6512  DataLayout dataLayout = DataLayout::closest(*this);
6513  auto sourceElementBits =
6514  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6515  auto resultElementBits =
6516  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6517 
6518  if (sourceVectorType.getRank() == 0) {
6519  if (sourceElementBits != resultElementBits)
6520  return emitOpError("source/result bitwidth of the 0-D vector element "
6521  "types must be equal");
6522  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6523  resultElementBits * resultVectorType.getShape().back()) {
6524  return emitOpError(
6525  "source/result bitwidth of the minor 1-D vectors must be equal");
6526  }
6527 
6528  return success();
6529 }
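
// Illustrative sketch of the minor-dimension bitwidth rule checked above
// (hypothetical values):
//
//   vector.bitcast %a : vector<4x8xi8> to vector<4x2xi32>  // accepted: 8 x 8 bits == 2 x 32 bits
//   vector.bitcast %b : vector<4x8xi8> to vector<4x3xi32>  // rejected: 64 bits vs. 96 bits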
6530 
6531 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6532  // Nop cast.
6533  if (getSource().getType() == getResult().getType())
6534  return getSource();
6535 
6536  // Canceling bitcasts.
6537  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6538  if (getResult().getType() == otherOp.getSource().getType())
6539  return otherOp.getSource();
6540 
6541  setOperand(otherOp.getSource());
6542  return getResult();
6543  }
6544 
6545  Attribute sourceConstant = adaptor.getSource();
6546  if (!sourceConstant)
6547  return {};
6548 
6549  Type srcElemType = getSourceVectorType().getElementType();
6550  Type dstElemType = getResultVectorType().getElementType();
6551 
6552  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6553  if (floatPack.isSplat()) {
6554  auto splat = floatPack.getSplatValue<FloatAttr>();
6555 
6556  // Casting fp16 into fp32.
6557  if (srcElemType.isF16() && dstElemType.isF32()) {
6558  uint32_t bits = static_cast<uint32_t>(
6559  splat.getValue().bitcastToAPInt().getZExtValue());
6560  // Duplicate the 16-bit pattern.
6561  bits = (bits << 16) | (bits & 0xffff);
6562  APInt intBits(32, bits);
6563  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6564  return DenseElementsAttr::get(getResultVectorType(), floatBits);
6565  }
6566  }
6567  }
6568 
6569  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6570  if (intPack.isSplat()) {
6571  auto splat = intPack.getSplatValue<IntegerAttr>();
6572 
6573  if (llvm::isa<IntegerType>(dstElemType)) {
6574  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6575  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6576 
6577  // Casting to a larger integer bit width.
6578  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6579  APInt intBits = splat.getValue().zext(dstBitWidth);
6580 
6581  // Duplicate the lower width element.
6582  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6583  intBits = (intBits << srcBitWidth) | intBits;
6584  return DenseElementsAttr::get(getResultVectorType(), intBits);
6585  }
6586  }
6587  }
6588  }
6589 
6590  return {};
6591 }
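
// Illustrative arithmetic for the f16 -> f32 splat fold above: f16 1.0 has
// the bit pattern 0x3C00; duplicating it into both halves yields 0x3C003C00,
// which becomes the bit pattern of the resulting f32 splat constant.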
6592 
6593 //===----------------------------------------------------------------------===//
6594 // TypeCastOp
6595 //===----------------------------------------------------------------------===//
6596 
6597 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6598  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6599  SmallVector<int64_t, 8> res(memRefType.getShape());
6600  if (vectorType)
6601  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6602  return res;
6603 }
6604 
6605 /// Build the canonical memRefType with a single vector.
6606 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6607 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6608  Value source) {
6609  result.addOperands(source);
6610  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6611  VectorType vectorType =
6612  VectorType::get(extractShape(memRefType),
 6613  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
 6614  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6615  memRefType.getMemorySpace()));
6616 }
6617 
6618 LogicalResult TypeCastOp::verify() {
6619  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6620  if (!canonicalType.getLayout().isIdentity())
6621  return emitOpError("expects operand to be a memref with identity layout");
6622  if (!getResultMemRefType().getLayout().isIdentity())
6623  return emitOpError("expects result to be a memref with identity layout");
6624  if (getResultMemRefType().getMemorySpace() !=
6625  getMemRefType().getMemorySpace())
6626  return emitOpError("expects result in same memory space");
6627 
6628  auto sourceType = getMemRefType();
6629  auto resultType = getResultMemRefType();
6630  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
 6631  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
 6632  return emitOpError(
6633  "expects result and operand with same underlying scalar type: ")
6634  << resultType;
6635  if (extractShape(sourceType) != extractShape(resultType))
6636  return emitOpError(
6637  "expects concatenated result and operand shapes to be equal: ")
6638  << resultType;
6639  return success();
6640 }
6641 
6642 //===----------------------------------------------------------------------===//
6643 // TransposeOp
6644 //===----------------------------------------------------------------------===//
6645 
6646 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6647  Value vector, ArrayRef<int64_t> permutation) {
6648  VectorType vt = llvm::cast<VectorType>(vector.getType());
6649  SmallVector<int64_t, 4> transposedShape(vt.getRank());
6650  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6651  for (unsigned i = 0; i < permutation.size(); ++i) {
6652  transposedShape[i] = vt.getShape()[permutation[i]];
6653  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6654  }
6655 
6656  result.addOperands(vector);
6657  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6658  transposedScalableDims));
6659  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6660  builder.getDenseI64ArrayAttr(permutation));
6661 }
6662 
6663 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6664  // Eliminate splat constant transpose ops.
6665  if (auto splat =
6666  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6667  return splat.reshape(getResultVectorType());
6668 
6669  // Eliminate poison transpose ops.
6670  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6671  return ub::PoisonAttr::get(getContext());
6672 
6673  // Eliminate identity transposes, and more generally any transposes that
 6674  // preserve the shape without permuting elements.
6675  //
6676  // Examples of what to fold:
6677  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6678  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6679  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6680  //
6681  // Example of what NOT to fold:
6682  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6683  //
6684  if (getSourceVectorType() == getResultVectorType() &&
6685  isOrderPreserving(*this))
6686  return getVector();
6687 
6688  return {};
6689 }
6690 
6691 LogicalResult vector::TransposeOp::verify() {
6692  VectorType vectorType = getSourceVectorType();
6693  VectorType resultType = getResultVectorType();
6694  int64_t rank = resultType.getRank();
6695  if (vectorType.getRank() != rank)
6696  return emitOpError("vector result rank mismatch: ") << rank;
6697  // Verify transposition array.
6698  ArrayRef<int64_t> perm = getPermutation();
6699  int64_t size = perm.size();
6700  if (rank != size)
6701  return emitOpError("transposition length mismatch: ") << size;
6702  SmallVector<bool, 8> seen(rank, false);
6703  for (const auto &ta : llvm::enumerate(perm)) {
6704  if (ta.value() < 0 || ta.value() >= rank)
6705  return emitOpError("transposition index out of range: ") << ta.value();
6706  if (seen[ta.value()])
6707  return emitOpError("duplicate position index: ") << ta.value();
6708  seen[ta.value()] = true;
6709  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6710  return emitOpError("dimension size mismatch at: ") << ta.value();
6711  }
6712  return success();
6713 }
6714 
6715 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6716  return llvm::to_vector<4>(getResultVectorType().getShape());
6717 }
6718 
6719 void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6720  SetIntRangeFn setResultRanges) {
6721  setResultRanges(getResult(), argRanges.front());
6722 }
6723 
6724 namespace {
6725 
6726 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
6727 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6728 public:
6729  using Base::Base;
6730 
6731  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6732  PatternRewriter &rewriter) const override {
6733  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6734  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6735  ArrayRef<int64_t> permutation2) {
6736  SmallVector<int64_t, 4> result;
6737  for (auto index : permutation2)
6738  result.push_back(permutation1[index]);
6739  return result;
6740  };
6741 
6742  // Return if the input of 'transposeOp' is not defined by another transpose.
6743  vector::TransposeOp parentTransposeOp =
6744  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6745  if (!parentTransposeOp)
6746  return failure();
6747 
6748  SmallVector<int64_t, 4> permutation = composePermutations(
6749  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6750  // Replace 'transposeOp' with a new transpose operation.
6751  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6752  transposeOp, transposeOp.getResult().getType(),
6753  parentTransposeOp.getVector(), permutation);
6754  return success();
6755  }
6756 };
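
// Illustrative sketch of TransposeFolder (hypothetical SSA names):
//
//   %0 = vector.transpose %arg, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %1 = vector.transpose %0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
// becomes
//   %1 = vector.transpose %arg, [0, 1] : vector<2x3xf32> to vector<2x3xf32>
//
// Composing [1, 0] with [1, 0] via result[i] = permutation1[permutation2[i]]
// gives the identity [0, 1], which TransposeOp::fold then removes.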
6757 
6758 /// Replace transpose(splat-like(v)) with broadcast(v)
6759 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6760 public:
6761  using Base::Base;
6762 
6763  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6764  PatternRewriter &rewriter) const override {
6765  Value splat = getScalarSplatSource(transposeOp.getVector());
6766  if (!splat)
6767  return failure();
6768 
6769  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6770  transposeOp, transposeOp.getResultVectorType(), splat);
6771  return success();
6772  }
6773 };
6774 
6775 /// Folds transpose(create_mask) into a new transposed create_mask.
6776 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6777 public:
6778  using Base::Base;
6779 
6780  LogicalResult matchAndRewrite(TransposeOp transpOp,
6781  PatternRewriter &rewriter) const override {
6782  Value transposeSrc = transpOp.getVector();
6783  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6784  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6785  if (!createMaskOp && !constantMaskOp)
6786  return failure();
6787 
6788  // Get the transpose permutation and apply it to the vector.create_mask or
6789  // vector.constant_mask operands.
6790  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6791 
6792  if (createMaskOp) {
6793  auto maskOperands = createMaskOp.getOperands();
6794  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6795  applyPermutationToVector(newOperands, permutation);
6796 
6797  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6798  transpOp, transpOp.getResultVectorType(), newOperands);
6799  return success();
6800  }
6801 
6802  // ConstantMaskOp case.
6803  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6804  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6805 
6806  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6807  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6808  return success();
6809  }
6810 };
6811 
6812 /// Folds transpose(shape_cast) into a new shape_cast.
6813 class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6814 public:
6815  using Base::Base;
6816 
6817  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6818  PatternRewriter &rewriter) const override {
6819  auto shapeCastOp =
6820  transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6821  if (!shapeCastOp)
6822  return failure();
6823  if (!isOrderPreserving(transposeOp))
6824  return failure();
6825 
6826  VectorType resultType = transposeOp.getType();
6827 
6828  // We don't need to check isValidShapeCast at this point, because it is
 6829  // guaranteed that merging the transpose into the shape_cast is a valid
6830  // shape_cast, because the transpose just inserts/removes ones.
6831 
6832  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6833  shapeCastOp.getSource());
6834  return success();
6835  }
6836 };
6837 
6838 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6839 /// 'order preserving', where 'order preserving' means the flattened
6840 /// inputs and outputs of the transpose have identical (numerical) values.
6841 ///
6842 /// Example:
6843 /// ```
6844 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6845 /// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6846 /// to vector<8x1xi32>
6847 /// ```
6848 /// can be rewritten as the equivalent
6849 /// ```
6850 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6851 /// ```
6852 /// The algorithm works by partitioning dimensions into groups that can be
6853 /// locally permuted while preserving order, and checks that the transpose
6854 /// only permutes within these groups.
6855 ///
6856 /// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6857 /// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6858 /// broadcasting from 1x1x4x1x1x7.
 6859 ///                   ^^^ ^ ^^^ ^
 6860 ///          groups:   0  1  2  3
6861 /// Order preserving permutations for this example are ones that only permute
6862 /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6863 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6864 public:
6865  using Base::Base;
6866  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6867  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6868 
6869  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6870  PatternRewriter &rewriter) const override {
6871 
6872  vector::BroadcastOp broadcast =
6873  transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6874  if (!broadcast) {
6875  return rewriter.notifyMatchFailure(transpose,
6876  "not preceded by a broadcast");
6877  }
6878 
6879  auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6880  VectorType outputType = transpose.getResultVectorType();
6881 
6882  // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6883  bool inputIsScalar = !inputType;
6884  if (inputIsScalar) {
6885  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6886  broadcast.getSource());
6887  return success();
6888  }
6889 
6890  ArrayRef<int64_t> permutation = transpose.getPermutation();
6891  ArrayRef<int64_t> inputShape = inputType.getShape();
6892  int64_t inputRank = inputType.getRank();
6893  int64_t outputRank = transpose.getType().getRank();
6894  int64_t deltaRank = outputRank - inputRank;
6895 
6896  int low = 0;
6897  for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6898  bool notOne = inputShape[inputIndex] != 1;
6899  bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6900  bool groupEndFound = notOne || prevNotOne;
6901  if (groupEndFound) {
6902  int high = inputIndex + deltaRank;
6903  // Return failure if not all permutation destinations for indices in
6904  // [low, high) are in [low, high), i.e. the permutation is not local to
6905  // the group.
6906  for (int i = low; i < high; ++i) {
6907  if (permutation[i] < low || permutation[i] >= high) {
6908  return rewriter.notifyMatchFailure(
6909  transpose, "permutation not local to group");
6910  }
6911  }
6912  low = high;
6913  }
6914  }
6915 
6916  // We don't need to check the final group [low, outputRank) because if it is
6917  // not locally bound, there must be a preceding group that already failed
6918  // the check (impossible to have just 1 non-locally bound group).
6919 
6920  // The preceding logic also ensures that at this point, the output of the
6921  // transpose is definitely broadcastable from the input shape, assert so:
6922  assert(vector::isBroadcastableTo(inputType, outputType) ==
6923  vector::BroadcastableToResult::Success &&
6924  "not broadcastable directly to transpose output");
6925 
6926  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6927  broadcast.getSource());
6928 
6929  return success();
6930  }
6931 };
6932 
6933 } // namespace
6934 
6935 void vector::TransposeOp::getCanonicalizationPatterns(
6936  RewritePatternSet &results, MLIRContext *context) {
6937  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6938  FoldTransposeSplat, FoldTransposeBroadcast>(context);
6939 }
6940 
6941 //===----------------------------------------------------------------------===//
6942 // ConstantMaskOp
6943 //===----------------------------------------------------------------------===//
6944 
6945 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6946  VectorType type, ConstantMaskKind kind) {
6947  assert(kind == ConstantMaskKind::AllTrue ||
6948  kind == ConstantMaskKind::AllFalse);
6949  build(builder, result, type,
6950  kind == ConstantMaskKind::AllTrue
6951  ? type.getShape()
6952  : SmallVector<int64_t>(type.getRank(), 0));
6953 }
6954 
6955 LogicalResult ConstantMaskOp::verify() {
6956  auto resultType = llvm::cast<VectorType>(getResult().getType());
6957  // Check the corner case of 0-D vectors first.
6958  if (resultType.getRank() == 0) {
6959  if (getMaskDimSizes().size() != 1)
6960  return emitError("array attr must have length 1 for 0-D vectors");
6961  auto dim = getMaskDimSizes()[0];
6962  if (dim != 0 && dim != 1)
6963  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6964  return success();
6965  }
6966 
6967  // Verify that array attr size matches the rank of the vector result.
6968  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6969  return emitOpError(
6970  "must specify array attr of size equal vector result rank");
6971  // Verify that each array attr element is in bounds of corresponding vector
6972  // result dimension size.
6973  auto resultShape = resultType.getShape();
6974  auto resultScalableDims = resultType.getScalableDims();
6975  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6976  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6977  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6978  return emitOpError(
6979  "array attr of size out of bounds of vector result dimension size");
6980  if (resultScalableDims[index] && maskDimSize != 0 &&
6981  maskDimSize != resultShape[index])
6982  return emitOpError(
6983  "only supports 'none set' or 'all set' scalable dimensions");
6984  }
6985  // Verify that if one mask dim size is zero, they all should be zero (because
6986  // the mask region is a conjunction of each mask dimension interval).
6987  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6988  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6989  if (anyZeros && !allZeros)
6990  return emitOpError("expected all mask dim sizes to be zeros, "
6991  "as a result of conjunction with zero mask dim");
6992  return success();
6993 }
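
// Illustrative sketch of the zero-conjunction rule checked above
// (hypothetical values):
//
//   vector.constant_mask [0, 0] : vector<4x3xi1>  // accepted: all dims zero
//   vector.constant_mask [2, 0] : vector<4x3xi1>  // rejected: one zero dim forces all dims to zero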
6994 
6995 bool ConstantMaskOp::isAllOnesMask() {
6996  auto resultType = getVectorType();
6997  // Check the corner case of 0-D vectors first.
6998  if (resultType.getRank() == 0) {
6999  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7000  return getMaskDimSizes()[0] == 1;
7001  }
7002  for (const auto [resultSize, maskDimSize] :
7003  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7004  if (maskDimSize < resultSize)
7005  return false;
7006  }
7007  return true;
7008 }
7009 
7010 OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7011  ArrayRef<int64_t> bounds = getMaskDimSizes();
7012  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7013 
7014  auto createBoolSplat = [&](bool x) {
7016  BoolAttr::get(getContext(), x));
7017  };
7018 
7019  // Check the corner case of 0-D vectors first.
7020  if (vectorSizes.empty()) {
7021  assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7022  return createBoolSplat(bounds[0] == 1);
7023  }
7024  // Fold vector.constant_mask to splat if possible.
7025  if (bounds == vectorSizes)
7026  return createBoolSplat(true);
7027  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7028  return createBoolSplat(false);
7029  return OpFoldResult();
7030 }
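
// Illustrative sketch of the fold above (hypothetical values):
//
//   vector.constant_mask [4, 3] : vector<4x3xi1>  // folds to a dense<true> splat
//   vector.constant_mask [0, 0] : vector<4x3xi1>  // folds to a dense<false> splat
//   vector.constant_mask [2, 3] : vector<4x3xi1>  // not folded: only partially set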
7031 
7032 //===----------------------------------------------------------------------===//
7033 // CreateMaskOp
7034 //===----------------------------------------------------------------------===//
7035 
7036 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7037  VectorType type,
7038  ArrayRef<OpFoldResult> mixedOperands) {
7039  SmallVector<Value> operands =
7040  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7041  build(builder, result, type, operands);
7042 }
7043 
7044 LogicalResult CreateMaskOp::verify() {
7045  auto vectorType = llvm::cast<VectorType>(getResult().getType());
 7046  // Verify that an operand was specified for each result vector dimension.
7047  if (vectorType.getRank() == 0) {
7048  if (getNumOperands() != 1)
7049  return emitOpError(
7050  "must specify exactly one operand for 0-D create_mask");
7051  } else if (getNumOperands() !=
7052  llvm::cast<VectorType>(getResult().getType()).getRank()) {
7053  return emitOpError(
7054  "must specify an operand for each result vector dimension");
7055  }
7056  return success();
7057 }
7058 
7059 namespace {
7060 
7061 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7062 ///
7063 /// Ex 1:
7064 /// %c2 = arith.constant 2 : index
7065 /// %c3 = arith.constant 3 : index
7066 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7067 /// Becomes:
7068 /// vector.constant_mask [3, 2] : vector<4x3xi1>
7069 ///
7070 /// Ex 2:
7071 /// %c_neg_1 = arith.constant -1 : index
7072 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7073 /// becomes:
7074 /// vector.constant_mask [0] : vector<[8]xi1>
7075 ///
7076 /// Ex 3:
7077 /// %c8 = arith.constant 8 : index
7078 /// %c16 = arith.constant 16 : index
7079 /// %0 = vector.vscale
7080 /// %1 = arith.muli %0, %c16 : index
7081 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7082 /// becomes:
7083 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7084 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7085 public:
7086  using Base::Base;
7087 
7088  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7089  PatternRewriter &rewriter) const override {
7090  VectorType maskType = createMaskOp.getVectorType();
7091  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7092  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7093 
7094  // Special case: Rank zero shape.
7095  constexpr std::array<int64_t, 1> rankZeroShape{1};
7096  constexpr std::array<bool, 1> rankZeroScalableDims{false};
7097  if (maskType.getRank() == 0) {
7098  maskTypeDimSizes = rankZeroShape;
7099  maskTypeDimScalableFlags = rankZeroScalableDims;
7100  }
7101 
7102  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7103  // collect the `constantDims` (for the ConstantMaskOp).
7104  SmallVector<int64_t, 4> constantDims;
7105  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7106  if (auto intSize = getConstantIntValue(dimSize)) {
7107  // Constant value.
7108  // If the mask dim is non-scalable this can be any value.
7109  // If the mask dim is scalable only zero (all-false) is supported.
7110  if (maskTypeDimScalableFlags[i] && intSize >= 0)
7111  return failure();
7112  constantDims.push_back(*intSize);
7113  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7114  // Constant vscale multiple (e.g. 4 x vscale).
7115  // Must be all-true to fold to a ConstantMask.
7116  if (vscaleMultiplier < maskTypeDimSizes[i])
7117  return failure();
7118  constantDims.push_back(*vscaleMultiplier);
7119  } else {
7120  return failure();
7121  }
7122  }
7123 
7124  // Clamp values to constant_mask bounds.
7125  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7126  value = std::clamp<int64_t>(value, 0, maskDimSize);
7127 
7128  // If any of the dim sizes is zero, set all dims to zero.
7129  if (llvm::is_contained(constantDims, 0))
7130  constantDims.assign(constantDims.size(), 0);
7131 
7132  // Replace 'createMaskOp' with ConstantMaskOp.
7133  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7134  constantDims);
7135  return success();
7136  }
7137 };
7138 
7139 } // namespace
7140 
7141 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7142  MLIRContext *context) {
7143  results.add<CreateMaskFolder>(context);
7144 }
7145 
7146 //===----------------------------------------------------------------------===//
7147 // MaskOp
7148 //===----------------------------------------------------------------------===//
7149 
7150 void MaskOp::build(
7151  OpBuilder &builder, OperationState &result, Value mask,
7152  Operation *maskableOp,
7153  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7154  assert(maskRegionBuilder &&
7155  "builder callback for 'maskRegion' must be present");
7156 
7157  result.addOperands(mask);
7158  OpBuilder::InsertionGuard guard(builder);
7159  Region *maskRegion = result.addRegion();
7160  builder.createBlock(maskRegion);
7161  maskRegionBuilder(builder, maskableOp);
7162 }
7163 
7164 void MaskOp::build(
7165  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7166  Value mask, Operation *maskableOp,
7167  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7168  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7169  maskRegionBuilder);
7170 }
7171 
7172 void MaskOp::build(
7173  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7174  Value mask, Value passthru, Operation *maskableOp,
7175  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7176  build(builder, result, mask, maskableOp, maskRegionBuilder);
7177  if (passthru)
7178  result.addOperands(passthru);
7179  result.addTypes(resultTypes);
7180 }
7181 
7182 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7183  // Create the op region.
7184  result.regions.reserve(1);
7185  Region &maskRegion = *result.addRegion();
7186 
7187  auto &builder = parser.getBuilder();
7188 
7189  // Parse all the operands.
7190  OpAsmParser::UnresolvedOperand mask;
7191  if (parser.parseOperand(mask))
7192  return failure();
7193 
7194  // Optional passthru operand.
7195  OpAsmParser::UnresolvedOperand passthru;
7196  ParseResult parsePassthru = parser.parseOptionalComma();
7197  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7198  return failure();
7199 
7200  // Parse op region.
7201  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7202  return failure();
7203 
7204  MaskOp::ensureTerminator(maskRegion, builder, result.location);
7205 
7206  // Parse the optional attribute list.
7207  if (parser.parseOptionalAttrDict(result.attributes))
7208  return failure();
7209 
7210  // Parse all the types.
7211  Type maskType;
7212  if (parser.parseColonType(maskType))
7213  return failure();
7214 
7215  SmallVector<Type> resultTypes;
7216  if (parser.parseOptionalArrowTypeList(resultTypes))
7217  return failure();
7218  result.types.append(resultTypes);
7219 
7220  // Resolve operands.
7221  if (parser.resolveOperand(mask, maskType, result.operands))
7222  return failure();
7223 
7224  if (parsePassthru.succeeded()) {
7225  if (resultTypes.empty())
7226  return parser.emitError(
7227  parser.getNameLoc(),
7228  "expects a result if passthru operand is provided");
7229 
7230  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7231  return failure();
7232  }
7233 
7234  return success();
7235 }
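// Textual forms accepted by the parser above (illustrative sketch):
//   %0 = vector.mask %m { vector.transfer_read %mem[%i], %pad
//            : memref<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
//   %1 = vector.mask %m, %passthru { vector.yield %a : vector<8xf32> }
//            : vector<8xi1> -> vector<8xf32>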
7236 
7237 void MaskOp::print(OpAsmPrinter &p) {
7238  p << " " << getMask();
7239  if (getPassthru())
7240  p << ", " << getPassthru();
7241 
7242  // Print single masked operation and skip terminator.
7243  p << " { ";
7244  Block *singleBlock = &getMaskRegion().getBlocks().front();
7245  if (singleBlock && !singleBlock->getOperations().empty())
7246  p.printCustomOrGenericOp(&singleBlock->front());
7247  p << " }";
7248 
7249  p.printOptionalAttrDict(getOperation()->getAttrs());
7250 
7251  p << " : " << getMask().getType();
7252  if (getNumResults() > 0)
7253  p << " -> " << getResultTypes();
7254 }
7255 
7256 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7257  // 1. For an empty `vector.mask`, create a default terminator.
7258  if (region.empty() || region.front().empty()) {
7259  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7260  MaskOp>::ensureTerminator(region, builder, loc);
7261  return;
7262  }
7263 
7264  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7265  Block &block = region.front();
7266  if (isa<vector::YieldOp>(block.back()))
7267  return;
7268 
7269  // 3. For a non-empty `vector.mask` without an explicit terminator:
7270 
7271  // Create default terminator if the number of masked operations is not
7272  // one. This case will trigger a verification failure.
7273  if (block.getOperations().size() != 1) {
7274  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7275  MaskOp>::ensureTerminator(region, builder, loc);
7276  return;
7277  }
7278 
7279  // Create a terminator that yields the results from the masked operation.
7280  OpBuilder opBuilder(builder.getContext());
7281  Operation *maskedOp = &block.front();
7282  opBuilder.setInsertionPointToEnd(&block);
7283  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7284 }
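// Illustration of case 3 above (sketch): if the region holds only
//   %r = vector.transfer_read %mem[%i], %pad : memref<?xf32>, vector<8xf32>
// then an implicit `vector.yield %r : vector<8xf32>` is appended as the terminator.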
7285 
7286 LogicalResult MaskOp::verify() {
7287  // Structural checks.
7288  Block &block = getMaskRegion().getBlocks().front();
7289  if (block.getOperations().empty())
7290  return emitOpError("expects a terminator within the mask region");
7291 
7292  unsigned numMaskRegionOps = block.getOperations().size();
7293  if (numMaskRegionOps > 2)
7294  return emitOpError("expects only one operation to mask");
7295 
7296  // Terminator checks.
7297  auto terminator = dyn_cast<vector::YieldOp>(block.back());
7298  if (!terminator)
7299  return emitOpError("expects a terminator within the mask region");
7300 
7301  if (terminator->getNumOperands() != getNumResults())
7302  return emitOpError(
7303  "expects number of results to match mask region yielded values");
7304 
7305  // Empty vector.mask. Nothing else to check.
7306  if (numMaskRegionOps == 1)
7307  return success();
7308 
7309  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7310  if (!maskableOp)
7311  return emitOpError("expects a MaskableOpInterface within the mask region");
7312 
7313  // Result checks.
7314  if (maskableOp->getNumResults() != getNumResults())
7315  return emitOpError("expects number of results to match maskable operation "
7316  "number of results");
7317 
7318  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7319  return emitOpError("expects all the results from the MaskableOpInterface "
7320  "to match all the values returned by the terminator");
7321 
7322  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7323  return emitOpError(
7324  "expects result type to match maskable operation result type");
7325 
7326  if (llvm::count_if(maskableOp->getResultTypes(),
7327  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7328  return emitOpError("multiple vector results not supported");
7329 
7330  // Mask checks.
7331  Type expectedMaskType = maskableOp.getExpectedMaskType();
7332  if (getMask().getType() != expectedMaskType)
7333  return emitOpError("expects a ")
7334  << expectedMaskType << " mask for the maskable operation";
7335 
7336  // Passthru checks.
7337  Value passthru = getPassthru();
7338  if (passthru) {
7339  if (!maskableOp.supportsPassthru())
7340  return emitOpError(
7341  "doesn't expect a passthru argument for this maskable operation");
7342 
7343  if (maskableOp->getNumResults() != 1)
7344  return emitOpError("expects result when passthru argument is provided");
7345 
7346  if (passthru.getType() != maskableOp->getResultTypes()[0])
7347  return emitOpError("expects passthru type to match result type");
7348  }
7349 
7350  return success();
7351 }
7352 
7353 /// Folds empty `vector.mask` with no passthru operand and with or without
7354 /// return values. For example:
7355 ///
7356 /// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7357 /// vector<8xi1> -> vector<8xf32>
7358 /// %1 = user_op %0 : vector<8xf32>
7359 ///
7360 /// becomes:
7361 ///
7362 /// %0 = user_op %a : vector<8xf32>
7363 ///
7364 /// Empty `vector.mask` ops with a passthru operand are handled by the
7365 /// canonicalizer, as that requires creating new operations.
7366 
7367 static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7368  SmallVectorImpl<OpFoldResult> &results) {
7369  if (!maskOp.isEmpty() || maskOp.hasPassthru())
7370  return failure();
7371 
7372  Block *block = maskOp.getMaskBlock();
7373  auto terminator = cast<vector::YieldOp>(block->front());
7374  if (terminator.getNumOperands() == 0) {
7375  // `vector.mask` has no results, just remove the `vector.mask`.
7376  return success();
7377  }
7378 
7379  // `vector.mask` has results, propagate the results.
7380  llvm::append_range(results, terminator.getOperands());
7381  return success();
7382 }
7383 
7384 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7385  SmallVectorImpl<OpFoldResult> &results) {
7386  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7387  return success();
7388 
7389  MaskFormat maskFormat = getMaskFormat(getMask());
7390  if (maskFormat != MaskFormat::AllTrue)
7391  return failure();
7392 
7393  // Move maskable operation outside of the `vector.mask` region.
7394  Operation *maskableOp = getMaskableOp();
7395  maskableOp->dropAllUses();
7396  maskableOp->moveBefore(getOperation());
7397 
7398  llvm::append_range(results, maskableOp->getResults());
7399  return success();
7400 }
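// Illustrative all-true fold (sketch):
//   %all = vector.constant_mask [8] : vector<8xi1>
//   %0 = vector.mask %all { vector.reduction <add>, %v : vector<8xf32> into f32 }
//            : vector<8xi1> -> f32
// folds to:
//   %0 = vector.reduction <add>, %v : vector<8xf32> into f32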
7401 
7402 /// Canonicalize empty `vector.mask` operations that can't be handled in
7403 /// `MaskOp::fold`, as they require creating new operations.
7404 ///
7405 /// Example 1: Empty `vector.mask` with passthru operand.
7406 ///
7407 /// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7408 /// vector<8xi1> -> vector<8xf32>
7409 ///
7410 /// becomes:
7411 ///
7412 /// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7413 ///
7414 class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7415  using Base::Base;
7416 
7417  LogicalResult matchAndRewrite(MaskOp maskOp,
7418  PatternRewriter &rewriter) const override {
7419  if (!maskOp.isEmpty())
7420  return failure();
7421 
7422  if (!maskOp.hasPassthru())
7423  return failure();
7424 
7425  Block *block = maskOp.getMaskBlock();
7426  auto terminator = cast<vector::YieldOp>(block->front());
7427  assert(terminator.getNumOperands() == 1 &&
7428  "expected one result when passthru is provided");
7429 
7430  rewriter.replaceOpWithNewOp<arith::SelectOp>(
7431  maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7432  terminator.getOperand(0), maskOp.getPassthru());
7433 
7434  return success();
7435  }
7436 };
7437 
7438 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7439  MLIRContext *context) {
7440  results.add<CanonializeEmptyMaskOp>(context);
7441 }
7442 
7443 // MaskingOpInterface definitions.
7444 
7445 /// Returns the operation masked by this 'vector.mask'.
7446 Operation *MaskOp::getMaskableOp() {
7447  Block *block = getMaskBlock();
7448  if (block->getOperations().size() < 2)
7449  return nullptr;
7450 
7451  return &block->front();
7452 }
7453 
7454 /// Returns true if 'vector.mask' has a passthru value.
7455 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7456 
7457 //===----------------------------------------------------------------------===//
7458 // ScanOp
7459 //===----------------------------------------------------------------------===//
7460 
7461 LogicalResult ScanOp::verify() {
7462  VectorType srcType = getSourceType();
7463  VectorType initialType = getInitialValueType();
7464  // Check reduction dimension < rank.
7465  int64_t srcRank = srcType.getRank();
7466  int64_t reductionDim = getReductionDim();
7467  if (reductionDim >= srcRank)
7468  return emitOpError("reduction dimension ")
7469  << reductionDim << " has to be less than " << srcRank;
7470 
7471  // Check that rank(initial_value) = rank(src) - 1.
7472  int64_t initialValueRank = initialType.getRank();
7473  if (initialValueRank != srcRank - 1)
7474  return emitOpError("initial value rank ")
7475  << initialValueRank << " has to be equal to " << srcRank - 1;
7476 
7477  // Check shapes of initial value and src.
7478  ArrayRef<int64_t> srcShape = srcType.getShape();
7479  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7480  SmallVector<int64_t> expectedShape;
7481  for (int i = 0; i < srcRank; i++) {
7482  if (i != reductionDim)
7483  expectedShape.push_back(srcShape[i]);
7484  }
7485  if (!llvm::equal(initialValueShapes, expectedShape)) {
7486  return emitOpError("incompatible input/initial value shapes");
7487  }
7488 
7489  // Verify supported reduction kind.
7490  Type eltType = getDestType().getElementType();
7491  if (!isSupportedCombiningKind(getKind(), eltType))
7492  return emitOpError("unsupported reduction type ")
7493  << eltType << " for kind '" << stringifyCombiningKind(getKind())
7494  << "'";
7495 
7496  return success();
7497 }
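// Illustrative scan that satisfies the shape rules above (sketch; the exact
// attribute spelling may differ):
//   %dst, %acc = vector.scan <add>, %src, %init
//       {inclusive = true, reduction_dim = 1 : i64}
//       : vector<4x8x16xf32>, vector<4x16xf32>
// rank(%init) is rank(%src) - 1 and its shape drops the reduction dimension (8).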
7498 
7499 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7500  RewritePatternSet &patterns, PatternBenefit benefit) {
7501  patterns
7502  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7503  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7504  StridedSliceConstantMaskFolder, TransposeFolder>(
7505  patterns.getContext(), benefit);
7506 }
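// Typical usage from a pass (sketch; `getOperation()` is the enclosing pass's
// root operation):
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorToVectorCanonicalizationPatterns(patterns);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));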
7507 
7508 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7509  CombiningKind kind, Value v1, Value acc,
7510  arith::FastMathFlagsAttr fastmath,
7511  Value mask) {
7512  Type t1 = getElementTypeOrSelf(v1.getType());
7513  Type tAcc = getElementTypeOrSelf(acc.getType());
7514  Value result;
7515 
7516  switch (kind) {
7517  case CombiningKind::ADD:
7518  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7519  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7520  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7521  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7522  else
7523  llvm_unreachable("invalid value types for ADD reduction");
7524  break;
7525  case CombiningKind::AND:
7526  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7527  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7528  break;
7529  case CombiningKind::MAXNUMF:
7530  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7531  "expected float values");
7532  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7533  break;
7534  case CombiningKind::MAXIMUMF:
7535  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7536  "expected float values");
7537  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7538  break;
7539  case CombiningKind::MINNUMF:
7540  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7541  "expected float values");
7542  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7543  break;
7544  case CombiningKind::MINIMUMF:
7545  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7546  "expected float values");
7547  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7548  break;
7549  case CombiningKind::MAXSI:
7550  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7551  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7552  break;
7553  case CombiningKind::MINSI:
7554  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7555  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7556  break;
7557  case CombiningKind::MAXUI:
7558  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7559  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7560  break;
7561  case CombiningKind::MINUI:
7562  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7563  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7564  break;
7565  case CombiningKind::MUL:
7566  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7567  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7568  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7569  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7570  else
7571  llvm_unreachable("invalid value types for MUL reduction");
7572  break;
7573  case CombiningKind::OR:
7574  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7575  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7576  break;
7577  case CombiningKind::XOR:
7578  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7579  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7580  break;
7581  };
7582 
7583  assert(result && "unknown CombiningKind");
7584  return selectPassthru(b, mask, result, acc);
7585 }
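// Illustration of the masked floating-point ADD case above (sketch):
//   %sum = arith.addf %v1, %acc : vector<8xf32>
//   %res = arith.select %mask, %sum, %acc : vector<8xi1>, vector<8xf32>
// When no mask is provided, only the arith.addf is emitted.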
7586 
7587 //===----------------------------------------------------------------------===//
7588 // StepOp
7589 //===----------------------------------------------------------------------===//
7590 
7591 void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7592  SetIntRangeFn setResultRanges) {
7593  auto resultType = cast<VectorType>(getType());
7594  if (resultType.isScalable()) {
7595  return;
7596  }
7597  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7598  APInt zero(bitwidth, 0);
7599  APInt high(bitwidth, resultType.getDimSize(0) - 1);
7600  ConstantIntRanges result = {zero, high, zero, high};
7601  setResultRanges(getResult(), result);
7602 }
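// E.g. (illustrative): for `vector.step : vector<4xindex>` each lane is inferred
// to lie in [0, 3], for both the signed and unsigned interpretations.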
7603 
7604 //===----------------------------------------------------------------------===//
7605 // Vector Masking Utilities
7606 //===----------------------------------------------------------------------===//
7607 
7608 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7609 /// as the masked operation.
7610 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7611  Operation *maskableOp) {
7612  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7613  Block *insBlock = builder.getInsertionBlock();
7614  // Create a block and move the op to that block.
7615  insBlock->getOperations().splice(
7616  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7617  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7618 }
7619 
7620 /// Creates a vector.mask operation around a maskable operation. Returns the
7621 /// vector.mask operation if the mask provided is valid. Otherwise, returns
7622 /// the maskable operation itself.
7623 Operation *mlir::vector::maskOperation(OpBuilder &builder,
7624  Operation *maskableOp, Value mask,
7625  Value passthru) {
7626  if (!mask)
7627  return maskableOp;
7628  if (passthru)
7629  return MaskOp::create(builder, maskableOp->getLoc(),
7630  maskableOp->getResultTypes(), mask, passthru,
7631  maskableOp, createMaskOpRegion);
7632  return MaskOp::create(builder, maskableOp->getLoc(),
7633  maskableOp->getResultTypes(), mask, maskableOp,
7634  createMaskOpRegion);
7635 }
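// Usage sketch (illustrative; `readOp` is a maskable operation and `maskVal` an
// i1 vector value, both made up for this example):
//   Operation *masked = vector::maskOperation(builder, readOp, maskVal);
// With a null `maskVal`, `readOp` is returned unchanged.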
7636 
7637 /// Creates a vector select operation that picks values from `newValue` or
7638 /// `passthru` for each result vector lane based on `mask`. This utility is used
7639 /// to propagate the pass-thru value of vector.mask or for cases where only the
7640 /// pass-thru value propagation is needed. VP intrinsics do not support
7641 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
7642 /// are usually able to match op + select patterns and fold them into native
7643 /// target instructions.
7644 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7645  Value newValue, Value passthru) {
7646  if (!mask)
7647  return newValue;
7648 
7649  return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7650  mask, newValue, passthru);
7651 }
7652 
7653 //===----------------------------------------------------------------------===//
7654 // TableGen'd op method definitions
7655 //===----------------------------------------------------------------------===//
7656 
7657 #define GET_ATTRDEF_CLASSES
7658 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7659 
7660 #define GET_OP_CLASSES
7661 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"