MLIR  22.0.0git
VectorOps.cpp
Go to the documentation of this file.
1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/IR/ValueRange.h"
39 #include "mlir/Support/LLVM.h"
41 #include "llvm/ADT/ArrayRef.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/StringSet.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Casting.h"
47 
48 #include <cassert>
49 #include <cstdint>
50 
51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 // Pull in all enum type and utility function definitions.
53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
54 
55 using namespace mlir;
56 using namespace mlir::vector;
57 
58 /// Helper enum to classify mask value.
59 enum class MaskFormat {
60  AllTrue = 0,
61  AllFalse = 1,
62  Unknown = 2,
63 };
64 
65 /// Helper method to classify a mask value. Currently, the method
66 /// looks "under the hood" of a constant value with dense attributes
67 /// and a constant mask operation (since the client may be called at
68 /// various stages during progressive lowering).
70  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
71  // Inspect constant dense values. We count up for bits that
72  // are set, count down for bits that are cleared, and bail
73  // when a mix is detected.
74  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75  int64_t val = 0;
76  for (bool b : denseElts.getValues<bool>())
77  if (b && val >= 0)
78  val++;
79  else if (!b && val <= 0)
80  val--;
81  else
82  return MaskFormat::Unknown;
83  if (val > 0)
84  return MaskFormat::AllTrue;
85  if (val < 0)
86  return MaskFormat::AllFalse;
87  }
88  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
89  // Inspect constant mask index. If the index exceeds the
90  // dimension size, all bits are set. If the index is zero
91  // or less, no bits are set.
92  ArrayRef<int64_t> masks = m.getMaskDimSizes();
93  auto shape = m.getType().getShape();
94  bool allTrue = true;
95  bool allFalse = true;
96  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
97  if (maskIdx < dimSize)
98  allTrue = false;
99  if (maskIdx > 0)
100  allFalse = false;
101  }
102  if (allTrue)
103  return MaskFormat::AllTrue;
104  if (allFalse)
105  return MaskFormat::AllFalse;
106  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107  // Finds all-false create_masks. An all-true create_mask requires all
108  // dims to be constants, so that'll be folded to a constant_mask, then
109  // detected in the constant_mask case.
110  auto maskOperands = m.getOperands();
111  for (Value operand : maskOperands) {
112  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113  int64_t dimSize =
114  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
115  if (dimSize <= 0)
116  return MaskFormat::AllFalse;
117  }
118  }
119  return MaskFormat::Unknown;
120  }
121  return MaskFormat::Unknown;
122 }
123 
124 /// Default callback to build a region with a 'vector.yield' terminator with no
125 /// arguments.
127  vector::YieldOp::create(builder, loc);
128 }
129 
130 // Helper for verifying combining kinds in contractions and reductions.
131 static bool isSupportedCombiningKind(CombiningKind combiningKind,
132  Type elementType) {
133  switch (combiningKind) {
134  case CombiningKind::ADD:
135  case CombiningKind::MUL:
136  return elementType.isIntOrIndexOrFloat();
138  case CombiningKind::MINSI:
139  case CombiningKind::MAXUI:
140  case CombiningKind::MAXSI:
141  case CombiningKind::AND:
142  case CombiningKind::OR:
143  case CombiningKind::XOR:
144  return elementType.isIntOrIndex();
145  case CombiningKind::MINNUMF:
146  case CombiningKind::MAXNUMF:
147  case CombiningKind::MINIMUMF:
148  case CombiningKind::MAXIMUMF:
149  return llvm::isa<FloatType>(elementType);
150  }
151  return false;
152 }
153 
154 /// Returns the effective rank of the vector to read/write for Xfer Ops
155 ///
156 /// When the element type of the shaped type is _a scalar_, this will simply
157 /// return the rank of the vector ( the result for xfer_read or the value to
158 /// store for xfer_write).
159 ///
160 /// When the element type of the base shaped type is _a vector_, returns the
161 /// difference between the original vector type and the element type of the
162 /// shaped type.
163 ///
164 /// EXAMPLE 1 (element type is _a scalar_):
165 /// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
166 /// - shapedType.getElementType() = f32 (rank 0)
167 /// - vectorType.getRank() = 2
168 /// - Result = 2 - 0 = 2
169 ///
170 /// EXAMPLE 2 (element type is _a vector_):
171 /// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172 /// - shapedType.getElementType() = vector<20xf32> (rank 1)
173 /// - vectorType.getRank() = 1
174 /// - Result = 1 - 1 = 0
175 ///
176 /// This is used to determine the number of minor dimensions for identity maps
177 /// in vector transfer Ops.
178 static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
179  VectorType vectorType) {
180  unsigned elementVectorRank = 0;
181  VectorType elementVectorType =
182  llvm::dyn_cast<VectorType>(shapedType.getElementType());
183  if (elementVectorType)
184  elementVectorRank += elementVectorType.getRank();
185  return vectorType.getRank() - elementVectorRank;
186 }
187 
189  VectorType vectorType) {
190  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
191  // TODO: replace once we have 0-d vectors.
192  if (shapedType.getRank() == 0 &&
193  vectorType.getShape() == ArrayRef<int64_t>{1})
194  return AffineMap::get(
195  /*numDims=*/0, /*numSymbols=*/0,
196  getAffineConstantExpr(0, shapedType.getContext()));
198  shapedType.getRank(),
199  getEffectiveVectorRankForXferOp(shapedType, vectorType),
200  shapedType.getContext());
201 }
202 
203 /// Check if `write` is of a constant splat and the masked `read` is padded with
204 /// the same splat value -- meaning it could be the same value as the initial
205 /// constant splat.
206 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
207  vector::TransferReadOp read) {
208  auto readMask = read.getMask();
209  auto writeMask = write.getMask();
210  // Check if the masks are consistent. The splat value could be the same if the
211  // read is masked (and padded with the splat value), and the write is unmasked
212  // or has the same mask. Note this does not allow the case where the write is
213  // masked and the read is unmasked, as then the read could be of more elements
214  // than the write (which may not be the same value).
215  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
216  if (!couldBeSameSplat)
217  return false;
218  // Check for constant splat (as the source of the write).
219  DenseElementsAttr splatAttr;
220  if (!matchPattern(write.getVector(),
221  m_Constant<DenseElementsAttr>(&splatAttr)) ||
222  !splatAttr.isSplat()) {
223  return false;
224  }
225  // The padding of the read and the constant splat value must be the same.
226  Attribute padAttr;
227  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
228  return false;
229  return padAttr == splatAttr.getSplatValue<Attribute>();
230 }
231 
232 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
233  vector::TransferReadOp read) {
234  return !defWrite.hasOutOfBoundsDim() &&
235  defWrite.getIndices() == read.getIndices() &&
236  defWrite.getVectorType() == read.getVectorType() &&
237  defWrite.getPermutationMap() == read.getPermutationMap() &&
238  ((!defWrite.getMask() && !read.getMask()) ||
239  isSplatWriteConsistentWithMaskedRead(defWrite, read));
240 }
241 
242 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
243  vector::TransferWriteOp priorWrite) {
244  return priorWrite.getIndices() == write.getIndices() &&
245  priorWrite.getMask() == write.getMask() &&
246  priorWrite.getVectorType() == write.getVectorType() &&
247  priorWrite.getPermutationMap() == write.getPermutationMap();
248 }
249 
251  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
252  bool testDynamicValueUsingBounds) {
253  // For simplicity only look at transfer of same type.
254  if (transferA.getVectorType() != transferB.getVectorType())
255  return false;
256  unsigned rankOffset = transferA.getLeadingShapedRank();
257  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
258  Value indexA = transferA.getIndices()[i];
259  Value indexB = transferB.getIndices()[i];
260  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
261  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
262 
263  if (i < rankOffset) {
264  // For leading dimensions, if we can prove that index are different we
265  // know we are accessing disjoint slices.
266  if (cstIndexA.has_value() && cstIndexB.has_value()) {
267  if (*cstIndexA != *cstIndexB)
268  return true;
269  continue;
270  }
271  if (testDynamicValueUsingBounds) {
272  // First try to see if we can fully compose and simplify the affine
273  // expression as a fast track.
274  FailureOr<uint64_t> delta =
276  if (succeeded(delta) && *delta != 0)
277  return true;
278 
279  FailureOr<bool> testEqual =
280  ValueBoundsConstraintSet::areEqual(indexA, indexB);
281  if (succeeded(testEqual) && !testEqual.value())
282  return true;
283  }
284  } else {
285  // For this dimension, we slice a part of the memref we need to make sure
286  // the intervals accessed don't overlap.
287  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
288  if (cstIndexA.has_value() && cstIndexB.has_value()) {
289  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
290  if (distance >= vectorDim)
291  return true;
292  continue;
293  }
294  if (testDynamicValueUsingBounds) {
295  // First try to see if we can fully compose and simplify the affine
296  // expression as a fast track.
297  FailureOr<int64_t> delta =
299  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
300  return true;
301 
302  FailureOr<int64_t> computeDelta =
304  if (succeeded(computeDelta)) {
305  if (std::abs(computeDelta.value()) >= vectorDim)
306  return true;
307  }
308  }
309  }
310  }
311  return false;
312 }
313 
314 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
315  VectorTransferOpInterface transferB,
316  bool testDynamicValueUsingBounds) {
317  if (transferA.getBase() != transferB.getBase())
318  return false;
319  return isDisjointTransferIndices(transferA, transferB,
320  testDynamicValueUsingBounds);
321 }
322 
323 // Helper to iterate over n-D vector slice elements. Calculate the next
324 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
325 // Modifies the `position` in place. Returns a failure when `position` becomes
326 // the end position.
327 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
328  ArrayRef<int64_t> shape,
329  ArrayRef<int64_t> offsets) {
330  for (auto [posInDim, dimSize, offsetInDim] :
331  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
332  ++posInDim;
333  if (posInDim < dimSize + offsetInDim)
334  return success();
335 
336  // Carry the overflow to the next loop iteration.
337  posInDim = offsetInDim;
338  }
339 
340  return failure();
341 }
342 
343 /// Returns the integer numbers in `values`. `values` are expected to be
344 /// constant operations.
347  llvm::transform(values, std::back_inserter(ints), [](Value value) {
348  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
349  assert(constOp && "Unexpected non-constant index");
350  return constOp.value();
351  });
352  return ints;
353 }
354 
355 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
356 /// be constant operations.
359  llvm::transform(
360  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
361  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
362  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
363  });
364  return ints;
365 }
366 
367 /// Convert `foldResults` into Values. Integer attributes are converted to
368 /// constant op.
370  ArrayRef<OpFoldResult> foldResults) {
371  SmallVector<Value> values;
372  llvm::transform(foldResults, std::back_inserter(values),
373  [&](OpFoldResult foldResult) {
374  if (auto attr = dyn_cast<Attribute>(foldResult))
376  builder, loc, cast<IntegerAttr>(attr).getInt())
377  .getResult();
378 
379  return cast<Value>(foldResult);
380  });
381  return values;
382 }
383 
384 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
385  if (value.getDefiningOp<vector::VectorScaleOp>())
386  return 1;
387  auto mul = value.getDefiningOp<arith::MulIOp>();
388  if (!mul)
389  return {};
390  auto lhs = mul.getLhs();
391  auto rhs = mul.getRhs();
392  if (lhs.getDefiningOp<vector::VectorScaleOp>())
393  return getConstantIntValue(rhs);
394  if (rhs.getDefiningOp<vector::VectorScaleOp>())
395  return getConstantIntValue(lhs);
396  return {};
397 }
398 
399 /// Converts an IntegerAttr to have the specified type if needed.
400 /// This handles cases where constant attributes have a different type than the
401 /// target element type. If the input attribute is not an IntegerAttr or already
402 /// has the correct type, returns it unchanged.
403 static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
404  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
405  if (intAttr.getType() != expectedType)
406  return IntegerAttr::get(expectedType, intAttr.getInt());
407  }
408  return attr;
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // CombiningKindAttr
413 //===----------------------------------------------------------------------===//
414 
415 namespace mlir {
416 namespace vector {
417 namespace detail {
419  using KeyTy = uint64_t;
420 
421  BitmaskEnumStorage(KeyTy val) : value(val) {}
422 
423  bool operator==(const KeyTy &key) const { return value == key; }
424 
426  const KeyTy &key) {
427  return new (allocator.allocate<BitmaskEnumStorage>())
428  BitmaskEnumStorage(key);
429  }
430 
431  KeyTy value = 0;
432 };
433 } // namespace detail
434 } // namespace vector
435 } // namespace mlir
436 
437 //===----------------------------------------------------------------------===//
438 // VectorDialect
439 //===----------------------------------------------------------------------===//
440 
441 namespace {
442 /// This class defines the interface for handling inlining with vector dialect
443 /// operations.
444 struct VectorInlinerInterface : public DialectInlinerInterface {
446 
447  /// All vector dialect ops can be inlined.
448  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
449  return true;
450  }
451 };
452 } // namespace
453 
454 void VectorDialect::initialize() {
455  addAttributes<
456 #define GET_ATTRDEF_LIST
457 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
458  >();
459 
460  addOperations<
461 #define GET_OP_LIST
462 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
463  >();
464 
465  addInterfaces<VectorInlinerInterface>();
466 
467  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
468  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
469  YieldOp>();
470  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
471  TransferWriteOp>();
472  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
473  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
474  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
475 }
476 
477 /// Materialize a single constant operation from a given attribute value with
478 /// the desired resultant type.
480  Attribute value, Type type,
481  Location loc) {
482  if (isa<ub::PoisonAttrInterface>(value))
483  return value.getDialect().materializeConstant(builder, value, type, loc);
484 
485  return arith::ConstantOp::materialize(builder, value, type, loc);
486 }
487 
489  return builder.getIntegerType(64);
490 }
491 
493  ArrayRef<int64_t> values) {
494  return builder.getI64ArrayAttr(values);
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // MultiDimReductionOp
499 //===----------------------------------------------------------------------===//
500 
501 void vector::MultiDimReductionOp::build(OpBuilder &builder,
502  OperationState &result, Value source,
503  Value acc, ArrayRef<bool> reductionMask,
504  CombiningKind kind) {
505  SmallVector<int64_t> reductionDims;
506  for (const auto &en : llvm::enumerate(reductionMask))
507  if (en.value())
508  reductionDims.push_back(en.index());
509  build(builder, result, kind, source, acc, reductionDims);
510 }
511 
512 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
513  // Single parallel dim, this is a noop.
514  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
515  return getSource();
516  return {};
517 }
518 
519 std::optional<SmallVector<int64_t, 4>>
520 MultiDimReductionOp::getShapeForUnroll() {
521  return llvm::to_vector<4>(getSourceVectorType().getShape());
522 }
523 
524 LogicalResult MultiDimReductionOp::verify() {
525  SmallVector<int64_t> targetShape;
526  SmallVector<bool> scalableDims;
527  Type inferredReturnType;
528  auto sourceScalableDims = getSourceVectorType().getScalableDims();
529  for (auto [dimIdx, dimSize] :
530  llvm::enumerate(getSourceVectorType().getShape()))
531  if (!llvm::any_of(getReductionDims(),
532  [dimIdx = dimIdx](int64_t reductionDimIdx) {
533  return reductionDimIdx == static_cast<int64_t>(dimIdx);
534  })) {
535  targetShape.push_back(dimSize);
536  scalableDims.push_back(sourceScalableDims[dimIdx]);
537  }
538  // TODO: update to also allow 0-d vectors when available.
539  if (targetShape.empty())
540  inferredReturnType = getSourceVectorType().getElementType();
541  else
542  inferredReturnType = VectorType::get(
543  targetShape, getSourceVectorType().getElementType(), scalableDims);
544  if (getType() != inferredReturnType)
545  return emitOpError() << "destination type " << getType()
546  << " is incompatible with source type "
547  << getSourceVectorType();
548 
549  return success();
550 }
551 
552 /// Returns the mask type expected by this operation.
553 Type MultiDimReductionOp::getExpectedMaskType() {
554  auto vecType = getSourceVectorType();
555  return VectorType::get(vecType.getShape(),
556  IntegerType::get(vecType.getContext(), /*width=*/1),
557  vecType.getScalableDims());
558 }
559 
560 namespace {
561 // Only unit dimensions that are being reduced are folded. If the dimension is
562 // unit, but not reduced, it is not folded, thereby keeping the output type the
563 // same. If not all dimensions which are reduced are of unit dimension, this
564 // transformation does nothing. This is just a generalization of
565 // ElideSingleElementReduction for ReduceOp.
566 struct ElideUnitDimsInMultiDimReduction
567  : public OpRewritePattern<MultiDimReductionOp> {
569 
570  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
571  PatternRewriter &rewriter) const override {
572  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
573  for (const auto &dim : enumerate(shape)) {
574  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
575  return failure();
576  }
577 
578  // Vector mask setup.
579  OpBuilder::InsertionGuard guard(rewriter);
580  Operation *rootOp;
581  Value mask;
582  if (reductionOp.isMasked()) {
583  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
584  rootOp = reductionOp.getMaskingOp();
585  mask = reductionOp.getMaskingOp().getMask();
586  } else {
587  rootOp = reductionOp;
588  }
589 
590  Location loc = reductionOp.getLoc();
591  Value acc = reductionOp.getAcc();
592  Value cast;
593  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
594  if (mask) {
595  VectorType newMaskType =
596  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
597  dstVecType.getScalableDims());
598  mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
599  }
600  cast = vector::ShapeCastOp::create(
601  rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
602  } else {
603  // This means we are reducing all the dimensions, and all reduction
604  // dimensions are of size 1. So a simple extraction would do.
605  if (mask)
606  mask = vector::ExtractOp::create(rewriter, loc, mask);
607  cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
608  }
609 
610  Value result =
611  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
612  cast, /*fastmath=*/nullptr, mask);
613  rewriter.replaceOp(rootOp, result);
614  return success();
615  }
616 };
617 } // namespace
618 
619 void MultiDimReductionOp::getCanonicalizationPatterns(
620  RewritePatternSet &results, MLIRContext *context) {
621  results.add<ElideUnitDimsInMultiDimReduction>(context);
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // ReductionOp
626 //===----------------------------------------------------------------------===//
627 
628 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
629  CombiningKind kind, Value vector,
630  arith::FastMathFlags fastMathFlags) {
631  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
632 }
633 
634 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
635  CombiningKind kind, Value vector, Value acc,
636  arith::FastMathFlags fastMathFlags) {
637  build(builder, result,
638  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
639  acc, fastMathFlags);
640 }
641 
642 LogicalResult ReductionOp::verify() {
643  // Verify for 0-D and 1-D vector.
644  int64_t rank = getSourceVectorType().getRank();
645  if (rank > 1)
646  return emitOpError("unsupported reduction rank: ") << rank;
647 
648  // Verify supported reduction kind.
649  Type eltType = getDest().getType();
650  if (!isSupportedCombiningKind(getKind(), eltType))
651  return emitOpError("unsupported reduction type '")
652  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
653  << "'";
654 
655  return success();
656 }
657 
658 // MaskableOpInterface methods.
659 
660 /// Returns the mask type expected by this operation.
661 Type ReductionOp::getExpectedMaskType() {
662  auto vecType = getSourceVectorType();
663  return VectorType::get(vecType.getShape(),
664  IntegerType::get(vecType.getContext(), /*width=*/1),
665  vecType.getScalableDims());
666 }
667 
669  OpBuilder &builder, Location loc,
670  Value vector) {
671  switch (op) {
672  case arith::AtomicRMWKind::addf:
673  case arith::AtomicRMWKind::addi:
674  return vector::ReductionOp::create(builder, vector.getLoc(),
675  CombiningKind::ADD, vector);
676  case arith::AtomicRMWKind::mulf:
677  case arith::AtomicRMWKind::muli:
678  return vector::ReductionOp::create(builder, vector.getLoc(),
679  CombiningKind::MUL, vector);
680  case arith::AtomicRMWKind::minimumf:
681  return vector::ReductionOp::create(builder, vector.getLoc(),
682  CombiningKind::MINIMUMF, vector);
683  case arith::AtomicRMWKind::mins:
684  return vector::ReductionOp::create(builder, vector.getLoc(),
685  CombiningKind::MINSI, vector);
686  case arith::AtomicRMWKind::minu:
687  return vector::ReductionOp::create(builder, vector.getLoc(),
688  CombiningKind::MINUI, vector);
689  case arith::AtomicRMWKind::maximumf:
690  return vector::ReductionOp::create(builder, vector.getLoc(),
691  CombiningKind::MAXIMUMF, vector);
692  case arith::AtomicRMWKind::maxs:
693  return vector::ReductionOp::create(builder, vector.getLoc(),
694  CombiningKind::MAXSI, vector);
695  case arith::AtomicRMWKind::maxu:
696  return vector::ReductionOp::create(builder, vector.getLoc(),
697  CombiningKind::MAXUI, vector);
698  case arith::AtomicRMWKind::andi:
699  return vector::ReductionOp::create(builder, vector.getLoc(),
700  CombiningKind::AND, vector);
701  case arith::AtomicRMWKind::ori:
702  return vector::ReductionOp::create(builder, vector.getLoc(),
703  CombiningKind::OR, vector);
704  // TODO: Add remaining reduction operations.
705  default:
706  (void)emitOptionalError(loc, "Reduction operation type not supported");
707  break;
708  }
709  return nullptr;
710 }
711 
712 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
713  return llvm::to_vector<4>(getSourceVectorType().getShape());
714 }
715 
716 namespace {
717 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
719 
720  LogicalResult matchAndRewrite(ReductionOp reductionOp,
721  PatternRewriter &rewriter) const override {
722  // Vector mask setup.
723  OpBuilder::InsertionGuard guard(rewriter);
724  auto maskableOp =
725  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
726  Operation *rootOp;
727  Value mask;
728  if (maskableOp.isMasked()) {
729  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
730  rootOp = maskableOp.getMaskingOp();
731  mask = maskableOp.getMaskingOp().getMask();
732  } else {
733  rootOp = reductionOp;
734  }
735 
736  auto vectorType = reductionOp.getSourceVectorType();
737  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
738  return failure();
739 
740  Location loc = reductionOp.getLoc();
741  if (mask)
742  mask = ExtractOp::create(rewriter, loc, mask);
743  Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
744 
745  if (Value acc = reductionOp.getAcc())
746  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
747  result, acc,
748  reductionOp.getFastmathAttr(), mask);
749 
750  rewriter.replaceOp(rootOp, result);
751  return success();
752  }
753 };
754 } // namespace
755 
756 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
757  MLIRContext *context) {
758  results.add<ElideSingleElementReduction>(context);
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // ContractionOp
763 //===----------------------------------------------------------------------===//
764 
765 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
766  Value lhs, Value rhs, Value acc,
767  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
768  ArrayRef<IteratorType> iteratorTypes) {
769  result.addOperands({lhs, rhs, acc});
770  result.addTypes(acc.getType());
771  result.addAttribute(
772  getIndexingMapsAttrName(result.name),
773  builder.getAffineMapArrayAttr(
774  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
775  result.addAttribute(
776  getIteratorTypesAttrName(result.name),
777  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
778  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
779  return IteratorTypeAttr::get(builder.getContext(), t);
780  }))));
781 }
782 
783 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
784  Value lhs, Value rhs, Value acc,
785  ArrayAttr indexingMaps,
786  ArrayAttr iteratorTypes) {
787  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
788  ContractionOp::getDefaultKind());
789 }
790 
791 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
792  Value lhs, Value rhs, Value acc,
793  ArrayAttr indexingMaps,
794  ArrayAttr iteratorTypes, CombiningKind kind) {
795  result.addOperands({lhs, rhs, acc});
796  result.addTypes(acc.getType());
797  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
798  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
799  result.addAttribute(getKindAttrName(result.name),
801 }
802 
803 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
808  SmallVector<Type, 2> types;
809  Type resultType;
810  auto loc = parser.getCurrentLocation();
811  DictionaryAttr dictAttr;
812  // TODO: Unify linalg op attribute parsing.
813  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
814  parser.parseComma() || parser.parseOperand(rhsInfo) ||
815  parser.parseComma() || parser.parseOperand(accInfo) ||
816  parser.parseTrailingOperandList(masksInfo) ||
817  parser.parseOptionalAttrDict(result.attributes) ||
818  parser.parseColonTypeList(types) ||
819  parser.parseKeywordType("into", resultType) ||
820  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
821  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
822  parser.resolveOperand(accInfo, resultType, result.operands) ||
823  parser.addTypeToList(resultType, result.types))
824  return failure();
825  result.attributes.append(dictAttr.getValue().begin(),
826  dictAttr.getValue().end());
827 
828  // Convert array of string into an array of IteratyType enums. This is needed,
829  // because tests still use the old format when 'iterator_types' attribute is
830  // represented as an array of strings.
831  // TODO: Remove this conversion once tests are fixed.
832  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
833  result.attributes.get(getIteratorTypesAttrName(result.name)));
834  if (!iteratorTypes) {
835  return parser.emitError(loc)
836  << "expected " << getIteratorTypesAttrName(result.name)
837  << " array attribute";
838  }
839 
840  SmallVector<Attribute> iteratorTypeAttrs;
841 
842  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
843  auto maybeIteratorType = symbolizeIteratorType(s);
844  if (!maybeIteratorType.has_value())
845  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
846 
847  iteratorTypeAttrs.push_back(
848  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
849  }
850  result.attributes.set(getIteratorTypesAttrName(result.name),
851  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
852 
853  if (!result.attributes.get(getKindAttrName(result.name))) {
854  result.addAttribute(
855  getKindAttrName(result.name),
857  ContractionOp::getDefaultKind()));
858  }
859  if (masksInfo.empty())
860  return success();
861  if (masksInfo.size() != 2)
862  return parser.emitError(parser.getNameLoc(),
863  "expected zero or exactly 2 vector mask operands");
864  auto lhsType = llvm::cast<VectorType>(types[0]);
865  auto rhsType = llvm::cast<VectorType>(types[1]);
866  auto maskElementType = parser.getBuilder().getI1Type();
867  std::array<VectorType, 2> maskTypes = {
868  VectorType::Builder(lhsType).setElementType(maskElementType),
869  VectorType::Builder(rhsType).setElementType(maskElementType)};
870  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
871  return failure();
872  return success();
873 }
874 
876  // TODO: Unify printing code with linalg ops.
877  auto attrNames = getTraitAttrNames();
878  llvm::StringSet<> traitAttrsSet;
879  traitAttrsSet.insert_range(attrNames);
881  for (auto attr : (*this)->getAttrs()) {
882  if (attr.getName() == getIteratorTypesAttrName()) {
883  auto iteratorTypes =
884  llvm::cast<ArrayAttr>(attr.getValue())
885  .getAsValueRange<IteratorTypeAttr, IteratorType>();
886  // Convert IteratorType enums into the string representation. This is
887  // needed, because tests still use the old format when 'iterator_types'
888  // attribute is represented as an array of strings.
889  // TODO: Remove this conversion once tests are fixed.
890  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
891  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
892  return StringAttr::get(getContext(), stringifyIteratorType(t));
893  }));
894 
895  attrs.emplace_back(getIteratorTypesAttrName(),
896  ArrayAttr::get(getContext(), iteratorTypeNames));
897  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
898  attrs.push_back(attr);
899  }
900 
901  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
902  p << " " << dictAttr << " " << getLhs() << ", ";
903  p << getRhs() << ", " << getAcc();
904 
905  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
906  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
907  << getResultType();
908 }
909 
910 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
911  const std::vector<std::pair<int64_t, int64_t>> &map) {
912  for (auto &dimPair : map) {
913  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
914  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
915  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
916  return false;
917  }
918  return true;
919 }
920 
921 static LogicalResult verifyOutputShape(
922  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
923  Type resType,
924  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
925  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
926  DenseSet<int64_t> lhsContractingDimSet;
927  DenseSet<int64_t> rhsContractingDimSet;
928  for (auto &dimPair : contractingDimMap) {
929  lhsContractingDimSet.insert(dimPair.first);
930  rhsContractingDimSet.insert(dimPair.second);
931  }
932  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
933  llvm::make_second_range(batchDimMap));
934 
935  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
936  SmallVector<int64_t, 4> expectedResultDims;
937  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
938  if (lhsContractingDimSet.count(i) > 0)
939  continue;
940  expectedResultDims.push_back(lhsType.getDimSize(i));
941  }
942 
943  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
944  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
945  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
946  continue;
947  expectedResultDims.push_back(rhsType.getDimSize(i));
948  }
949 
950  // Verify 'expectedResultDims'.
951  if (expectedResultDims.empty()) {
952  // No batch or free dimension implies a scalar result.
953  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
954  return op.emitOpError("invalid accumulator/result vector shape");
955  } else {
956  // At least one batch or free dimension implies a vector result.
957  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
958  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
959  if (!resVectorType || !accVectorType)
960  return op.emitOpError("invalid accumulator/result vector shape");
961 
962  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
963  // types fully define the result vector type. This assumes the affine maps
964  // are well-formed, which must have been verified already.
965  MLIRContext *ctx = op.getContext();
966  AffineMap lhsMap = op.getIndexingMapsArray()[0];
967  AffineMap rhsMap = op.getIndexingMapsArray()[1];
968  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
969  return op.emitOpError(
970  "expected all dimensions to be either a LHS or a RHS dimension");
971  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
972  for (auto pair :
973  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
974  VectorType v = pair.first;
975  auto map = pair.second;
976  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
977  unsigned pos = map.getDimPosition(idx);
978  if (!extents[pos])
979  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
980  }
981  }
982  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
983  return op.emitOpError("expected all dimensions to get an extent as "
984  "either a LHS or a RHS dimension");
985 
986  AffineMap resMap = op.getIndexingMapsArray()[2];
987  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
988  /*symbolCount=*/0, extents, ctx);
989  // Compose the resMap with the extentsMap, which is a constant map.
990  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
991  assert(llvm::all_of(expectedMap.getResults(),
992  llvm::IsaPred<AffineConstantExpr>) &&
993  "expected constant extent along all dimensions.");
994  // Extract the expected shape and build the type.
995  auto expectedShape = llvm::to_vector<4>(
996  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
997  return cast<AffineConstantExpr>(e).getValue();
998  }));
999  auto expected =
1000  VectorType::get(expectedShape, resVectorType.getElementType(),
1001  resVectorType.getScalableDims());
1002  if (resVectorType != expected || accVectorType != expected)
1003  return op.emitOpError(
1004  "invalid accumulator/result vector shape, expected: ")
1005  << expected;
1006  }
1007  return success();
1008 }
1009 
1010 LogicalResult ContractionOp::verify() {
1011  VectorType lhsType = getLhsType();
1012  VectorType rhsType = getRhsType();
1013  Type accType = getAccType();
1014  Type resType = getResultType();
1015 
1016  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1017  if (!lhsType.getElementType().isSignlessInteger())
1018  return emitOpError("only supports signless integer types");
1019  }
1020 
1021  // Verify that an indexing map was specified for each vector operand.
1022  if (getIndexingMapsArray().size() != 3)
1023  return emitOpError("expected an indexing map for each vector operand");
1024 
1025  // Verify that each index map has 'numIterators' inputs, no symbols, and
1026  // that the number of map outputs equals the rank of its associated
1027  // vector operand.
1028  unsigned numIterators = getIteratorTypes().getValue().size();
1029  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1030  auto index = it.index();
1031  auto map = it.value();
1032  if (map.getNumSymbols() != 0)
1033  return emitOpError("expected indexing map ")
1034  << index << " to have no symbols";
1035  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1036  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1037  // Verify that the map has the right number of inputs, outputs, and indices.
1038  // This also correctly accounts for (..) -> () for rank-0 results.
1039  if (map.getNumDims() != numIterators)
1040  return emitOpError("expected indexing map ")
1041  << index << " to have " << numIterators << " number of inputs";
1042  if (map.getNumResults() != rank)
1043  return emitOpError("expected indexing map ")
1044  << index << " to have " << rank << " number of outputs";
1045  if (!map.isProjectedPermutation())
1046  return emitOpError("expected indexing map ")
1047  << index << " to be a projected permutation of its inputs";
1048  }
1049 
1050  auto contractingDimMap = getContractingDimMap();
1051  auto batchDimMap = getBatchDimMap();
1052 
1053  // Verify at least one contracting dimension pair was specified.
1054  if (contractingDimMap.empty())
1055  return emitOpError("expected at least one contracting dimension pair");
1056 
1057  // Verify contracting dimension map was properly constructed.
1058  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1059  return emitOpError("invalid contracting dimension map");
1060 
1061  // Verify batch dimension map was properly constructed.
1062  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1063  return emitOpError("invalid batch dimension map");
1064 
1065  // Verify 'accType' and 'resType' shape.
1066  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1067  contractingDimMap, batchDimMap)))
1068  return failure();
1069 
1070  // Verify supported combining kind.
1071  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1072  auto elementType = vectorType ? vectorType.getElementType() : resType;
1073  if (!isSupportedCombiningKind(getKind(), elementType))
1074  return emitOpError("unsupported contraction type");
1075 
1076  // Delayed calling of IndexingMapOpInterface::verifyImpl.
1077  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1078 }
1079 
1080 // MaskableOpInterface methods.
1081 
1082 /// Returns the mask type expected by this operation. Mostly used for
1083 /// verification purposes. It requires the operation to be vectorized."
1084 Type ContractionOp::getExpectedMaskType() {
1085  auto indexingMaps = this->getIndexingMapsArray();
1086  AffineMap lhsIdxMap = indexingMaps[0];
1087  AffineMap rhsIdxMap = indexingMaps[1];
1088  VectorType lhsType = this->getLhsType();
1089  VectorType rhsType = this->getRhsType();
1090 
1091  unsigned numVecDims = lhsIdxMap.getNumDims();
1092  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1093  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1094 
1095  // Using the information in the indexing maps, extract the size of each
1096  // dimension in the vector.contract operation from the two input operands.
1097  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1098  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1099  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1100  lhsType.getScalableDims()[dimIdx];
1101  }
1102  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1103  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1104  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1105  rhsType.getScalableDims()[dimIdx];
1106  }
1107 
1108  assert(ShapedType::isStaticShape(maskShape) &&
1109  "Mask shape couldn't be computed");
1110 
1111  return VectorType::get(maskShape,
1112  IntegerType::get(lhsType.getContext(), /*width=*/1),
1113  maskShapeScalableDims);
1114 }
1115 
1116 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1117  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1118  getIteratorTypesAttrName(), getKindAttrName()};
1119 }
1120 
1121 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1122  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1123  if (targetExpr == map.getResult(i))
1124  return i;
1125  return -1;
1126 }
1127 
1128 static std::vector<std::pair<int64_t, int64_t>>
1129 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1130  IteratorType targetIteratorType, MLIRContext *context) {
1131  std::vector<std::pair<int64_t, int64_t>> dimMap;
1132  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1133  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1134  if (iteratorType != targetIteratorType)
1135  continue;
1136  // Search lhs/rhs map results for 'targetExpr'.
1137  auto targetExpr = getAffineDimExpr(it.index(), context);
1138  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1139  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1140  if (lhsDim >= 0 && rhsDim >= 0)
1141  dimMap.emplace_back(lhsDim, rhsDim);
1142  }
1143  return dimMap;
1144 }
1145 
1146 void ContractionOp::getIterationBounds(
1147  SmallVectorImpl<int64_t> &iterationBounds) {
1148  auto lhsShape = getLhsType().getShape();
1149  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1150  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1151  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1152  // Search lhs/rhs map results for 'targetExpr'.
1153  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1154  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1155  if (iteratorType == IteratorType::reduction) {
1156  // Get reduction dim size from lhs shape (same size in rhsShape).
1157  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1158  assert(lhsDimIndex >= 0);
1159  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1160  continue;
1161  }
1162  // Get parallel dimension size from result shape.
1163  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1164  assert(resDimIndex >= 0);
1165  assert(resVectorType != nullptr);
1166  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1167  }
1168 }
1169 
1170 void ContractionOp::getIterationIndexMap(
1171  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1172  unsigned numMaps = getIndexingMapsArray().size();
1173  iterationIndexMap.resize(numMaps);
1174  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1175  auto index = it.index();
1176  auto map = it.value();
1177  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1178  auto dim = cast<AffineDimExpr>(map.getResult(i));
1179  iterationIndexMap[index][dim.getPosition()] = i;
1180  }
1181  }
1182 }
1183 
1184 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1185  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1186  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1187  getContext());
1188 }
1189 
1190 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1191  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1192  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1193  getContext());
1194 }
1195 
1196 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1198  getIterationBounds(shape);
1199  return shape;
1200 }
1201 
1202 /// Return a fused vector::ContractionOp which represents a patterns such as:
1203 ///
1204 /// ```mlir
1205 /// %c0 = vector.constant 0: ...
1206 /// %c = vector.contract %a, %b, %c0: ...
1207 /// %e = add %c, %d: ...
1208 /// ```
1209 ///
1210 /// by:
1211 ///
1212 /// ```mlir
1213 /// %e = vector.contract %a, %b, %d: ...
1214 /// ```
1215 ///
1216 /// Return null if the canonicalization does not apply.
1217 // TODO: This should be a folding of Add into Contract in core but while they
1218 // live in different dialects, it is not possible without unnatural
1219 // dependencies.
1220 template <typename AddOpType>
1221 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1223 
1224  LogicalResult matchAndRewrite(AddOpType addOp,
1225  PatternRewriter &rewriter) const override {
1226  auto canonicalize = [&](Value maybeContraction,
1227  Value otherOperand) -> vector::ContractionOp {
1228  vector::ContractionOp contractionOp =
1229  dyn_cast_or_null<vector::ContractionOp>(
1230  maybeContraction.getDefiningOp());
1231  if (!contractionOp)
1232  return vector::ContractionOp();
1233  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1234  contractionOp.getAcc().getDefiningOp())) {
1235  if (maybeZero.getValue() ==
1236  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1237  IRMapping bvm;
1238  bvm.map(contractionOp.getAcc(), otherOperand);
1239  auto newContraction =
1240  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1241  rewriter.replaceOp(addOp, newContraction.getResult());
1242  return newContraction;
1243  }
1244  }
1245  return vector::ContractionOp();
1246  };
1247 
1248  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1249  vector::ContractionOp contract = canonicalize(a, b);
1250  contract = contract ? contract : canonicalize(b, a);
1251  return contract ? success() : failure();
1252  }
1253 };
1254 
1255 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1256  MLIRContext *context) {
1259 }
1260 
1261 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1262 // `poisonValue`.
1263 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1264  int64_t maxIndex) {
1265  return index == poisonValue || (index >= 0 && index < maxIndex);
1266 }
1267 
1268 //===----------------------------------------------------------------------===//
1269 // ExtractOp
1270 //===----------------------------------------------------------------------===//
1271 
1272 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1273  SetIntRangeFn setResultRanges) {
1274  setResultRanges(getResult(), argRanges.front());
1275 }
1276 
1277 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1278  Value source) {
1279  auto vectorTy = cast<VectorType>(source.getType());
1280  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1281 }
1282 
1283 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1284  Value source, int64_t position) {
1285  build(builder, result, source, ArrayRef<int64_t>{position});
1286 }
1287 
1288 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1289  Value source, OpFoldResult position) {
1290  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1291 }
1292 
1293 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1294  Value source, ArrayRef<int64_t> position) {
1295  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1296  builder.getDenseI64ArrayAttr(position));
1297 }
1298 
1299 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1300  Value source, ArrayRef<OpFoldResult> position) {
1301  SmallVector<int64_t> staticPos;
1302  SmallVector<Value> dynamicPos;
1303  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1304  build(builder, result, source, dynamicPos,
1305  builder.getDenseI64ArrayAttr(staticPos));
1306 }
1307 
1308 LogicalResult
1309 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1310  ExtractOp::Adaptor adaptor,
1311  SmallVectorImpl<Type> &inferredReturnTypes) {
1312  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1313  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1314  vectorType.getRank()) {
1315  inferredReturnTypes.push_back(vectorType.getElementType());
1316  } else {
1317  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1318  vectorType.getRank());
1319  inferredReturnTypes.push_back(VectorType::get(
1320  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1321  vectorType.getScalableDims().drop_front(n)));
1322  }
1323  return success();
1324 }
1325 
1326 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1327  // Allow extracting 1-element vectors instead of scalars.
1328  auto isCompatible = [](TypeRange l, TypeRange r) {
1329  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1330  return vectorType && vectorType.getShape().equals({1}) &&
1331  vectorType.getElementType() == r.front();
1332  };
1333  if (l.size() == 1 && r.size() == 1 &&
1334  (isCompatible(l, r) || isCompatible(r, l)))
1335  return true;
1336  return l == r;
1337 }
1338 
1339 LogicalResult vector::ExtractOp::verify() {
1340  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1341  if (resTy.getRank() == 0)
1342  return emitError(
1343  "expected a scalar instead of a 0-d vector as the result type");
1344 
1345  // Note: This check must come before getMixedPosition() to prevent a crash.
1346  auto dynamicMarkersCount =
1347  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1348  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1349  return emitOpError(
1350  "mismatch between dynamic and static positions (kDynamic marker but no "
1351  "corresponding dynamic position) -- this can only happen due to an "
1352  "incorrect fold/rewrite");
1353  auto position = getMixedPosition();
1354  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1355  return emitOpError(
1356  "expected position attribute of rank no greater than vector rank");
1357  for (auto [idx, pos] : llvm::enumerate(position)) {
1358  if (auto attr = dyn_cast<Attribute>(pos)) {
1359  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1361  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1362  return emitOpError("expected position attribute #")
1363  << (idx + 1)
1364  << " to be a non-negative integer smaller than the "
1365  "corresponding vector dimension or poison (-1)";
1366  }
1367  }
1368  }
1369  return success();
1370 }
1371 
1372 template <typename IntType>
1373 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1374  return llvm::to_vector<4>(llvm::map_range(
1375  arrayAttr.getAsRange<IntegerAttr>(),
1376  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1377 }
1378 
1379 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1380 /// positions.
1381 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1382  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1383  return failure();
1384 
1385  // TODO: Canonicalization for dynamic position not implemented yet.
1386  if (extractOp.hasDynamicPosition())
1387  return failure();
1388 
1389  SmallVector<int64_t> globalPosition;
1390  ExtractOp currentOp = extractOp;
1391  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1392  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1393  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1394  currentOp = nextOp;
1395  // TODO: Canonicalization for dynamic position not implemented yet.
1396  if (currentOp.hasDynamicPosition())
1397  return failure();
1398  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1399  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1400  }
1401  extractOp.setOperand(0, currentOp.getVector());
1402  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1403  OpBuilder b(extractOp.getContext());
1404  std::reverse(globalPosition.begin(), globalPosition.end());
1405  extractOp.setStaticPosition(globalPosition);
1406  return success();
1407 }
1408 
1409 namespace {
1410 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1411 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1412 /// Compose TransposeOp permutations as we walk back.
1413 /// This helper class keeps an updated extraction position `extractPosition`
1414 /// with extra trailing sentinels.
1415 /// The sentinels encode the internal transposition status of the result vector.
1416 /// As we iterate, extractPosition is permuted and updated.
1417 class ExtractFromInsertTransposeChainState {
1418 public:
1419  ExtractFromInsertTransposeChainState(ExtractOp e);
1420 
1421  /// Iterate over producing insert and transpose ops until we find a fold.
1422  Value fold();
1423 
1424 private:
1425  /// Return true if the vector at position `a` is contained within the vector
1426  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1427  /// is a prefix of `b`.
1428  template <typename ContainerA, typename ContainerB>
1429  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1430  return a.size() <= b.size() &&
1431  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1432  }
1433 
1434  /// Return true if the vector at position `a` intersects the vector at
1435  /// position `b`. Under insert/extract semantics, this is the same as equality
1436  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1437  /// Comparison is on the common prefix (i.e. zip).
1438  template <typename ContainerA, typename ContainerB>
1439  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1440  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1441  if (elemA < 0 || elemB < 0)
1442  continue;
1443  if (elemA != elemB)
1444  return false;
1445  }
1446  return true;
1447  }
1448 
1449  /// Folding is only possible in the absence of an internal permutation in the
1450  /// result vector.
1451  bool canFold() {
1452  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1453  }
1454 
1455  // Helper to get the next defining op of interest.
1456  void updateStateForNextIteration(Value v) {
1457  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1458  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1459  };
1460 
1461  // Case 1. If we hit a transpose, just compose the map and iterate.
1462  // Invariant: insert + transpose do not change rank, we can always compose.
1463  LogicalResult handleTransposeOp();
1464 
1465  // Case 2: the insert position matches extractPosition exactly, early return.
1466  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1467 
1468  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1469  /// portion of the source of the insert.
1470  /// Example:
1471  /// ```
1472  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1473  /// // extractPosition == [1, 2, 3]
1474  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1475  /// // can fold to vector.extract %source[0, 3]
1476  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1477  /// ```
1478  /// To traverse through %source, we need to set the leading dims to 0 and
1479  /// drop the extra leading dims.
1480  /// This method updates the internal state.
1481  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1482 
1483  /// Try to fold in place to extract(source, extractPosition) and return the
1484  /// folded result. Return null if folding is not possible (e.g. due to an
1485  /// internal transposition in the result).
1486  Value tryToFoldExtractOpInPlace(Value source);
1487 
1488  ExtractOp extractOp;
1489  int64_t vectorRank;
1490  int64_t extractedRank;
1491 
1492  InsertOp nextInsertOp;
1493  TransposeOp nextTransposeOp;
1494 
1495  /// Sentinel values that encode the internal permutation status of the result.
1496  /// They are set to (-1, ... , -k) at the beginning and appended to
1497  /// `extractPosition`.
1498  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1499  /// ensure that there is no internal transposition.
1500  /// Internal transposition cannot be accounted for with a folding pattern.
1501  // TODO: We could relax the internal transposition with an extra transposition
1502  // operation in a future canonicalizer.
1503  SmallVector<int64_t> sentinels;
1504  SmallVector<int64_t> extractPosition;
1505 };
1506 } // namespace
1507 
1508 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1509  ExtractOp e)
1510  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1511  extractedRank(extractOp.getNumIndices()) {
1512  assert(vectorRank >= extractedRank && "Extracted position overflow");
1513  sentinels.reserve(vectorRank - extractedRank);
1514  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1515  sentinels.push_back(-(i + 1));
1516  extractPosition.assign(extractOp.getStaticPosition().begin(),
1517  extractOp.getStaticPosition().end());
1518  llvm::append_range(extractPosition, sentinels);
1519 }
1520 
1521 // Case 1. If we hit a transpose, just compose the map and iterate.
1522 // Invariant: insert + transpose do not change rank, we can always compose.
1523 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1524  // TODO: Canonicalization for dynamic position not implemented yet.
1525  if (extractOp.hasDynamicPosition())
1526  return failure();
1527 
1528  if (!nextTransposeOp)
1529  return failure();
1530  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1531  nextTransposeOp.getPermutation(), extractOp.getContext()));
1532  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1533  return success();
1534 }
1535 
1536 // Case 2: the insert position matches extractPosition exactly, early return.
1537 LogicalResult
1538 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1539  Value &res) {
1540  // TODO: Canonicalization for dynamic position not implemented yet.
1541  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1542  return failure();
1543 
1544  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1545  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1546  return failure();
1547  // Case 2.a. early-exit fold.
1548  res = nextInsertOp.getValueToStore();
1549  // Case 2.b. if internal transposition is present, canFold will be false.
1550  return success(canFold());
1551 }
1552 
1553 /// Case 3: if inserted position is a prefix of extractPosition,
1554 /// extract a portion of the source of the insertion.
1555 /// This method updates the internal state.
1556 LogicalResult
1557 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1558  // TODO: Canonicalization for dynamic position not implemented yet.
1559  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1560  return failure();
1561 
1562  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1563  if (!isContainedWithin(insertedPos, extractPosition))
1564  return failure();
1565  // Set leading dims to zero.
1566  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1567  // Drop extra leading dims.
1568  extractPosition.erase(extractPosition.begin(),
1569  extractPosition.begin() + insertedPos.size());
1570  extractedRank = extractPosition.size() - sentinels.size();
1571  // Case 3.a. early-exit fold (break and delegate to post-while path).
1572  res = nextInsertOp.getValueToStore();
1573  // Case 3.b. if internal transposition is present, canFold will be false.
1574  return success();
1575 }
1576 
1577 /// Try to fold in place to extract(source, extractPosition) and return the
1578 /// folded result. Return null if folding is not possible (e.g. due to an
1579 /// internal transposition in the result).
1580 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1581  Value source) {
1582  // TODO: Canonicalization for dynamic position not implemented yet.
1583  if (extractOp.hasDynamicPosition())
1584  return Value();
1585 
1586  // If we can't fold (either internal transposition, or nothing to fold), bail.
1587  bool nothingToFold = (source == extractOp.getVector());
1588  if (nothingToFold || !canFold())
1589  return Value();
1590 
1591  // Otherwise, fold by updating the op inplace and return its result.
1592  OpBuilder b(extractOp.getContext());
1593  extractOp.setStaticPosition(
1594  ArrayRef(extractPosition).take_front(extractedRank));
1595  extractOp.getVectorMutable().assign(source);
1596  return extractOp.getResult();
1597 }
1598 
1599 /// Iterate over producing insert and transpose ops until we find a fold.
1600 Value ExtractFromInsertTransposeChainState::fold() {
1601  // TODO: Canonicalization for dynamic position not implemented yet.
1602  if (extractOp.hasDynamicPosition())
1603  return Value();
1604 
1605  Value valueToExtractFrom = extractOp.getVector();
1606  updateStateForNextIteration(valueToExtractFrom);
1607  while (nextInsertOp || nextTransposeOp) {
1608  // Case 1. If we hit a transpose, just compose the map and iterate.
1609  // Invariant: insert + transpose do not change rank, we can always compose.
1610  if (succeeded(handleTransposeOp())) {
1611  valueToExtractFrom = nextTransposeOp.getVector();
1612  updateStateForNextIteration(valueToExtractFrom);
1613  continue;
1614  }
1615 
1616  Value result;
1617  // Case 2: the insert position matches extractPosition exactly.
1618  if (succeeded(handleInsertOpWithMatchingPos(result)))
1619  return result;
1620 
1621  // Case 3: if the inserted position is a prefix of extractPosition, we can
1622  // just extract a portion of the source of the insert.
1623  if (succeeded(handleInsertOpWithPrefixPos(result)))
1624  return tryToFoldExtractOpInPlace(result);
1625 
1626  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1627  // values. This is a more difficult case and we bail.
1628  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1629  if (isContainedWithin(extractPosition, insertedPos) ||
1630  intersectsWhereNonNegative(extractPosition, insertedPos))
1631  return Value();
1632 
1633  // Case 5: No intersection, we forward the extract to insertOp.dest().
1634  valueToExtractFrom = nextInsertOp.getDest();
1635  updateStateForNextIteration(valueToExtractFrom);
1636  }
1637  // If after all this we can fold, go for it.
1638  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1639 }
1640 
1641 /// Returns true if the operation has a 0-D vector type operand or result.
1642 static bool hasZeroDimVectors(Operation *op) {
1643  auto hasZeroDimVectorType = [](Type type) -> bool {
1644  auto vecType = dyn_cast<VectorType>(type);
1645  return vecType && vecType.getRank() == 0;
1646  };
1647 
1648  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1649  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1650 }
1651 
1652 /// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1653 /// 1s, are considered to be 'broadcastlike'.
1654 static bool isBroadcastLike(Operation *op) {
1655  if (isa<BroadcastOp, SplatOp>(op))
1656  return true;
1657 
1658  auto shapeCast = dyn_cast<ShapeCastOp>(op);
1659  if (!shapeCast)
1660  return false;
1661 
1662  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1663  // Checking that the destination shape has a prefix of 1s is not sufficient,
1664  // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1665  // is that the source shape is a suffix of the destination shape.
1666  VectorType srcType = shapeCast.getSourceVectorType();
1667  ArrayRef<int64_t> srcShape = srcType.getShape();
1668  uint64_t srcRank = srcType.getRank();
1669  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1670  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1671 }
1672 
1673 /// Fold extract(broadcast(X)) to either extract(X) or just X.
1674 ///
1675 /// Example:
1676 ///
1677 /// broadcast extract [1][2]
1678 /// (3, 4) --------> (2, 3, 4) ----------------> (4)
1679 ///
1680 /// becomes
1681 /// extract [1]
1682 /// (3,4) -------------------------------------> (4)
1683 ///
1684 ///
1685 /// The variable names used in this implementation correspond to the above
1686 /// shapes as,
1687 ///
1688 /// - (3, 4) is `input` shape.
1689 /// - (2, 3, 4) is `broadcast` shape.
1690 /// - (4) is `extract` shape.
1691 ///
1692 /// This folding is possible when the suffix of `input` shape is the same as
1693 /// `extract` shape.
1694 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1695 
1696  Operation *defOp = extractOp.getVector().getDefiningOp();
1697  if (!defOp || !isBroadcastLike(defOp))
1698  return Value();
1699 
1700  Value input = defOp->getOperand(0);
1701 
1702  // Replace extract(broadcast(X)) with X
1703  if (extractOp.getType() == input.getType())
1704  return input;
1705 
1706  // Get required types and ranks in the chain
1707  // input -> broadcast -> extract
1708  // (scalars are treated as rank-0).
1709  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1710  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1711  unsigned inputRank = inputType ? inputType.getRank() : 0;
1712  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1713  unsigned extractRank = extractType ? extractType.getRank() : 0;
1714 
1715  // The broadcast cannot be removed if the overall rank increases.
1716  if (extractRank > inputRank)
1717  return Value();
1718 
1719  // The above condition guarantees that input is a vector.
1720  assert(inputType && "input must be a vector type because of previous checks");
1721  ArrayRef<int64_t> inputShape = inputType.getShape();
1722 
1723  // In the case where there is a broadcast dimension in the suffix, it is not
1724  // possible to replace extract(broadcast(X)) with extract(X). Example:
1725  //
1726  // broadcast extract
1727  // (1) --------> (3,4) ------> (4)
1728  if (extractType &&
1729  extractType.getShape() != inputShape.take_back(extractRank))
1730  return Value();
1731 
1732  // Replace extract(broadcast(X)) with extract(X).
1733  // First, determine the new extraction position.
1734  unsigned deltaOverall = inputRank - extractRank;
1735  unsigned deltaBroadcast = broadcastRank - inputRank;
1736  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1737  SmallVector<OpFoldResult> newPositions(deltaOverall);
1738  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1739  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1740  newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1741  }
1742  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1743  extractOp->setOperands(
1744  llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1745  extractOp.setStaticPosition(staticPos);
1746  return extractOp.getResult();
1747 }
1748 
1749 /// Fold extractOp coming from ShuffleOp.
1750 ///
1751 /// Example:
1752 ///
1753 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1754 /// : vector<8xf32>, vector<8xf32>
1755 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1756 /// ->
1757 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1758 ///
1759 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1760  // Dynamic positions are not folded as the resulting code would be more
1761  // complex than the input code.
1762  if (extractOp.hasDynamicPosition())
1763  return Value();
1764 
1765  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1766  if (!shuffleOp)
1767  return Value();
1768 
1769  // TODO: 0-D or multi-dimensional vectors not supported yet.
1770  if (shuffleOp.getResultVectorType().getRank() != 1)
1771  return Value();
1772 
1773  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1774  auto shuffleMask = shuffleOp.getMask();
1775  int64_t extractIdx = extractOp.getStaticPosition()[0];
1776  int64_t shuffleIdx = shuffleMask[extractIdx];
1777 
1778  // Find the shuffled vector to extract from based on the shuffle index.
1779  if (shuffleIdx < inputVecSize) {
1780  extractOp.setOperand(0, shuffleOp.getV1());
1781  extractOp.setStaticPosition({shuffleIdx});
1782  } else {
1783  extractOp.setOperand(0, shuffleOp.getV2());
1784  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1785  }
1786 
1787  return extractOp.getResult();
1788 }
1789 
1790 // Fold extractOp with source coming from ShapeCast op.
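// A minimal illustrative sketch (hypothetical IR, not taken from a test):
//   %0 = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %1 = vector.extract %0[5] : f32 from vector<8xf32>
// may fold in place to
//   %1 = vector.extract %src[1, 1] : f32 from vector<2x4xf32>
// since position 5 delinearized with the source strides [4, 1] gives [1, 1].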
1791 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1792  // TODO: Canonicalization for dynamic position not implemented yet.
1793  if (extractOp.hasDynamicPosition())
1794  return Value();
1795 
1796  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1797  if (!shapeCastOp)
1798  return Value();
1799 
1800  // Get the nth dimension size starting from lowest dimension.
1801  auto getDimReverse = [](VectorType type, int64_t n) {
1802  return type.getShape().take_back(n + 1).front();
1803  };
1804  int64_t destinationRank =
1805  llvm::isa<VectorType>(extractOp.getType())
1806  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1807  : 0;
1808  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1809  return Value();
1810  if (destinationRank > 0) {
1811  auto destinationType =
1812  llvm::cast<VectorType>(extractOp.getResult().getType());
1813  for (int64_t i = 0; i < destinationRank; i++) {
1814  // The lowest dimension of the destination must match the lowest
1815  // dimension of the shapecast op source.
1816  // TODO: This case could be supported in a canonicalization pattern.
1817  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1818  getDimReverse(destinationType, i))
1819  return Value();
1820  }
1821  }
1822  // Extract the strides associated with the extract op vector source. Then use
1823  // this to calculate a linearized position for the extract.
1824  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1825  std::reverse(extractedPos.begin(), extractedPos.end());
1826  SmallVector<int64_t, 4> strides;
1827  int64_t stride = 1;
1828  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1829  strides.push_back(stride);
1830  stride *=
1831  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1832  }
1833 
1834  int64_t position = linearize(extractedPos, strides);
1835  // Then extract the strides associated to the shapeCast op vector source and
1836  // delinearize the position using those strides.
1837  SmallVector<int64_t, 4> newStrides;
1838  int64_t numDimension =
1839  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1840  stride = 1;
1841  for (int64_t i = 0; i < numDimension; i++) {
1842  newStrides.push_back(stride);
1843  stride *=
1844  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1845  }
1846  std::reverse(newStrides.begin(), newStrides.end());
1847  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1848  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1849  OpBuilder b(extractOp.getContext());
1850  extractOp.setStaticPosition(newPosition);
1851  extractOp.setOperand(0, shapeCastOp.getSource());
1852  return extractOp.getResult();
1853 }
1854 
1855 /// Fold an ExtractOp from ExtractStridedSliceOp.
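/// A minimal illustrative sketch (hypothetical IR): the slice offsets are
/// added onto the extraction position, e.g.
///   %0 = vector.extract_strided_slice %src
///       {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
///       : vector<4x8xf32> to vector<2x4xf32>
///   %1 = vector.extract %0[1, 3] : f32 from vector<2x4xf32>
/// may fold in place to
///   %1 = vector.extract %src[2, 5] : f32 from vector<4x8xf32>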
1856 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1857  // TODO: Canonicalization for dynamic position not implemented yet.
1858  if (extractOp.hasDynamicPosition())
1859  return Value();
1860 
1861  auto extractStridedSliceOp =
1862  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1863  if (!extractStridedSliceOp)
1864  return Value();
1865 
1866  // 0-D vectors not supported.
1867  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1868  if (hasZeroDimVectors(extractStridedSliceOp))
1869  return Value();
1870 
1871  // Return if 'extractStridedSliceOp' has non-unit strides.
1872  if (extractStridedSliceOp.hasNonUnitStrides())
1873  return Value();
1874 
1875  // Trim offsets for dimensions fully extracted.
1876  auto sliceOffsets =
1877  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1878  while (!sliceOffsets.empty()) {
1879  size_t lastOffset = sliceOffsets.size() - 1;
1880  if (sliceOffsets.back() != 0 ||
1881  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1882  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1883  break;
1884  sliceOffsets.pop_back();
1885  }
1886  unsigned destinationRank = 0;
1887  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1888  destinationRank = vecType.getRank();
1889  // The dimensions of the result need to be untouched by the
1890  // extractStridedSlice op.
1891  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1892  sliceOffsets.size())
1893  return Value();
1894 
1895  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1896  assert(extractedPos.size() >= sliceOffsets.size());
1897  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1898  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1899  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1900 
1901  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1902  OpBuilder b(extractOp.getContext());
1903  extractOp.setStaticPosition(extractedPos);
1904  return extractOp.getResult();
1905 }
1906 
1907 /// Fold extract_op fed from a chain of insertStridedSlice ops.
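/// A minimal illustrative sketch (hypothetical IR): when the extracted chunk
/// lies entirely inside the inserted slice, the extract is rebased onto the
/// inserted value, e.g.
///   %0 = vector.insert_strided_slice %v, %dst
///       {offsets = [2, 0], strides = [1, 1]}
///       : vector<2x4xf32> into vector<4x4xf32>
///   %1 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32>
/// may fold in place to
///   %1 = vector.extract %v[1] : vector<4xf32> from vector<2x4xf32>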
1908 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1909  // TODO: Canonicalization for dynamic position not implemented yet.
1910  if (extractOp.hasDynamicPosition())
1911  return Value();
1912 
1913  int64_t destinationRank =
1914  llvm::isa<VectorType>(extractOp.getType())
1915  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1916  : 0;
1917  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1918  if (!insertOp)
1919  return Value();
1920 
1921  // 0-D vectors not supported.
1922  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1923  if (hasZeroDimVectors(insertOp))
1924  return Value();
1925 
1926  while (insertOp) {
1927  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1928  insertOp.getSourceVectorType().getRank();
1929  if (destinationRank > insertOp.getSourceVectorType().getRank())
1930  return Value();
1931  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1932  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1933 
1934  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1935  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1936  }))
1937  return Value();
1938  bool disjoint = false;
1939  SmallVector<int64_t, 4> offsetDiffs;
1940  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1941  int64_t start = insertOffsets[dim];
1942  int64_t size =
1943  (dim < insertRankDiff)
1944  ? 1
1945  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1946  int64_t end = start + size;
1947  int64_t offset = extractOffsets[dim];
1948  // Check if the start of the extract offset is in the interval inserted.
1949  if (start <= offset && offset < end) {
1950  if (dim >= insertRankDiff)
1951  offsetDiffs.push_back(offset - start);
1952  continue;
1953  }
1954  disjoint = true;
1955  break;
1956  }
1957  // The extracted chunk overlaps with the inserted vector.
1958  if (!disjoint) {
1959  // If any of the inner dimensions are only partially inserted we have a
1960  // partial overlap.
1961  int64_t srcRankDiff =
1962  insertOp.getSourceVectorType().getRank() - destinationRank;
1963  for (int64_t i = 0; i < destinationRank; i++) {
1964  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1965  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1966  insertRankDiff))
1967  return Value();
1968  }
1969  extractOp.getVectorMutable().assign(insertOp.getValueToStore());
1970  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1971  OpBuilder b(extractOp.getContext());
1972  extractOp.setStaticPosition(offsetDiffs);
1973  return extractOp.getResult();
1974  }
1975  // If the chunk extracted is disjoint from the chunk inserted, keep
1976  // looking in the insert chain.
1977  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1978  }
1979  return Value();
1980 }
1981 
1982 /// Try to fold the extraction of a scalar from a vector defined by
1983 /// vector.from_elements. E.g.:
1984 ///
1985 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1986 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1987 /// ==> fold to %a
1988 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1989  // Dynamic extractions cannot be folded.
1990  if (extractOp.hasDynamicPosition())
1991  return {};
1992 
1993  // Look for extract(from_elements).
1994  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1995  if (!fromElementsOp)
1996  return {};
1997 
1998  // Scalable vectors are not supported.
1999  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2000  if (vecType.isScalable())
2001  return {};
2002 
2003  // Only extractions of scalars are supported.
2004  int64_t rank = vecType.getRank();
2005  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2006  if (extractOp.getType() != vecType.getElementType())
2007  return {};
2008  assert(static_cast<int64_t>(indices.size()) == rank &&
2009  "unexpected number of indices");
2010 
2011  // Compute flattened/linearized index and fold to operand.
2012  int flatIndex = 0;
2013  int stride = 1;
2014  for (int i = rank - 1; i >= 0; --i) {
2015  flatIndex += indices[i] * stride;
2016  stride *= vecType.getDimSize(i);
2017  }
2018  return fromElementsOp.getElements()[flatIndex];
2019 }
2020 
2021 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2022 /// then fold it.
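/// A minimal illustrative sketch (hypothetical IR): a dynamic index defined by
/// an in-bounds arith.constant becomes part of the static position, e.g.
///   %c1 = arith.constant 1 : index
///   %0 = vector.extract %v[%c1] : f32 from vector<4xf32>
/// may fold in place to
///   %0 = vector.extract %v[1] : f32 from vector<4xf32>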
2023 template <typename OpType, typename AdaptorType>
2024 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2025  SmallVectorImpl<Value> &operands) {
2026  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2027  OperandRange dynamicPosition = op.getDynamicPosition();
2028  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2029  ArrayRef<int64_t> vectorShape;
2030  if constexpr (std::is_same_v<OpType, ExtractOp>)
2031  vectorShape = op.getSourceVectorType().getShape();
2032  else
2033  vectorShape = op.getDestVectorType().getShape();
2034 
2035  // If there are no dynamic position operands, there is nothing to fold.
2036  if (!dynamicPosition.size())
2037  return {};
2038 
2039  // `index` is used to iterate over the `dynamicPosition`.
2040  unsigned index = 0;
2041 
2042  // `opChange` records whether `op` should be updated in place.
2043  bool opChange = false;
2044  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2045  if (ShapedType::isStatic(staticPosition[i]))
2046  continue;
2047  Attribute positionAttr = dynamicPositionAttr[index];
2048  Value position = dynamicPosition[index++];
2049  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2050  int64_t value = attr.getInt();
2051  // Do not fold if the value is out of bounds (-1 signifies a poison
2052  // value rather than OOB index).
2053  if (value >= -1 && value < vectorShape[i]) {
2054  staticPosition[i] = attr.getInt();
2055  opChange = true;
2056  continue;
2057  }
2058  }
2059  operands.push_back(position);
2060  }
2061 
2062  if (opChange) {
2063  op.setStaticPosition(staticPosition);
2064  op.getOperation()->setOperands(operands);
2065  // Return the original result to indicate an in-place folding happened.
2066  return op.getResult();
2067  }
2068  return {};
2069 }
2070 
2071 /// Fold an insert or extract operation into a poison value when a poison index
2072 /// is found at any dimension of the static position.
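/// For illustration (hypothetical values): a static position such as [2, -1],
/// where -1 is the poison index, folds the op to a ub.poison value.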
2073 static OpFoldResult foldPoisonIndexInsertExtractOp(MLIRContext *context,
2074  ArrayRef<int64_t> staticPos,
2075  int64_t poisonVal) {
2076  if (!is_contained(staticPos, poisonVal))
2077  return {};
2078 
2079  return ub::PoisonAttr::get(context);
2080 }
2081 
2082 /// Fold a vector extract whose source value is poison.
2083 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2084  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2085  return srcAttr;
2086 
2087  return {};
2088 }
2089 
2090 /// Fold a vector extract whose source is a DenseElementsAttr.
2091 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2092  Attribute srcAttr) {
2093  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2094  if (!denseAttr) {
2095  return {};
2096  }
2097 
2098  if (denseAttr.isSplat()) {
2099  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2100  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2101  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2102  return newAttr;
2103  }
2104 
2105  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2106  if (vecTy.isScalable())
2107  return {};
2108 
2109  if (extractOp.hasDynamicPosition()) {
2110  return {};
2111  }
2112 
2113  // Materializing subsets of a large constant array can generally lead to
2114  // an explosion in IR size because of the different combinations of subsets that
2115  // can exist. However, vector.extract is a restricted form of subset
2116  // extract where you can only extract non-overlapping (or the same) subset for
2117  // a given rank of the subset. Because of this property, the IR size can only
2118  // increase at most by `rank * size(array)` from a single constant array being
2119  // extracted by multiple extracts.
2120 
2121  // Calculate the linearized position of the continuous chunk of elements to
2122  // extract.
2123  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2124  copy(extractOp.getStaticPosition(), completePositions.begin());
2125  int64_t startPos =
2126  linearize(completePositions, computeStrides(vecTy.getShape()));
2127  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2128 
2129  TypedAttr newAttr;
2130  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2131  SmallVector<Attribute> elementValues(
2132  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2133  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2134  } else {
2135  newAttr = *denseValuesBegin;
2136  }
2137 
2138  return newAttr;
2139 }
2140 
2141 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2142  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2143  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2144  // mismatch).
2145  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2146  return getVector();
2147  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2148  return res;
2149  // Fold `arith.constant` indices into the `vector.extract` operation.
2150  // Do not stop here as this fold may enable subsequent folds that require
2151  // constant indices.
2152  SmallVector<Value> operands = {getVector()};
2153  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2154 
2155  if (auto res = foldPoisonIndexInsertExtractOp(
2156  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2157  return res;
2158  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
2159  return res;
2160  if (succeeded(foldExtractOpFromExtractChain(*this)))
2161  return getResult();
2162  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2163  return res;
2164  if (auto res = foldExtractFromBroadcast(*this))
2165  return res;
2166  if (auto res = foldExtractFromShuffle(*this))
2167  return res;
2168  if (auto res = foldExtractFromShapeCast(*this))
2169  return res;
2170  if (auto val = foldExtractFromExtractStrided(*this))
2171  return val;
2172  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2173  return val;
2174  if (auto val = foldScalarExtractFromFromElements(*this))
2175  return val;
2176 
2177  return inplaceFolded;
2178 }
2179 
2180 namespace {
2181 
2182 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
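// A minimal illustrative sketch (hypothetical IR):
//   %0 = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %1 = vector.extract %0[1] : vector<3x4xf32> from vector<2x3x4xf32>
// may be rewritten as
//   %1 = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>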
2183 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2184 public:
2186 
2187  LogicalResult matchAndRewrite(ExtractOp extractOp,
2188  PatternRewriter &rewriter) const override {
2189 
2190  Operation *defOp = extractOp.getVector().getDefiningOp();
2191  VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2192  if (!defOp || !isBroadcastLike(defOp) || !outType)
2193  return failure();
2194 
2195  Value source = defOp->getOperand(0);
2196  if (isBroadcastableTo(source.getType(), outType) !=
2197  BroadcastableToResult::Success)
2198  return failure();
2199 
2200  rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2201  return success();
2202  }
2203 };
2204 
2205 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
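// A minimal illustrative sketch (hypothetical IR): the constant bound of the
// leading dimension covers the extraction position, so it can be dropped:
//   %c2 = arith.constant 2 : index
//   %m = vector.create_mask %c2, %d : vector<4x8xi1>
//   %0 = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// may be rewritten as
//   %0 = vector.create_mask %d : vector<8xi1>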
2206 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2207 public:
2209 
2210  LogicalResult matchAndRewrite(ExtractOp extractOp,
2211  PatternRewriter &rewriter) const override {
2212  auto createMaskOp =
2213  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2214  if (!createMaskOp)
2215  return failure();
2216 
2217  VectorType extractedMaskType =
2218  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2219 
2220  if (!extractedMaskType)
2221  return failure();
2222 
2223  auto maskOperands = createMaskOp.getOperands();
2224  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2225  VectorType maskType = createMaskOp.getVectorType();
2226 
2227  bool containsUnknownDims = false;
2228  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2229 
2230  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2231  dimIdx++) {
2232  int64_t pos = extractOpPos[dimIdx];
2233  Value operand = maskOperands[dimIdx];
2234  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2235  if (!constantOp) {
2236  // Bounds of this dim unknown.
2237  containsUnknownDims = true;
2238  continue;
2239  }
2240 
2241  int64_t createMaskBound =
2242  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2243 
2244  if (pos != ShapedType::kDynamic) {
2245  // If any position is outside the range from the `create_mask`, then the
2246  // extracted mask will be all-false.
2247  allFalse |= pos >= createMaskBound;
2248  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2249  // This dim is not all-true and since this is a dynamic index we don't
2250  // know if the extraction is within the true or false region.
2251  // Note: Zero dims have already been handled via getMaskFormat().
2252  containsUnknownDims = true;
2253  }
2254  }
2255 
2256  if (allFalse) {
2257  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2258  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2259  } else if (!containsUnknownDims) {
2260  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2261  extractOp, extractedMaskType,
2262  maskOperands.drop_front(extractOpPos.size()));
2263  } else {
2264  return failure();
2265  }
2266  return success();
2267  }
2268 };
2269 
2270 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2271 // does not change.
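// A minimal illustrative sketch (hypothetical IR):
//   %0 = vector.shape_cast %src : vector<8xf32> to vector<1x2x4xf32>
//   %1 = vector.extract %0[0] : vector<2x4xf32> from vector<1x2x4xf32>
// may be rewritten as
//   %1 = vector.shape_cast %src : vector<8xf32> to vector<2x4xf32>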
2272 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2273  PatternRewriter &rewriter) {
2274  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2275  if (!castOp)
2276  return failure();
2277 
2278  VectorType sourceType = castOp.getSourceVectorType();
2279  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2280  if (!targetType)
2281  return failure();
2282 
2283  if (sourceType.getNumElements() != targetType.getNumElements())
2284  return failure();
2285 
2286  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2287  castOp.getSource());
2288  return success();
2289 }
2290 
2291 /// Try to canonicalize the extraction of a subvector from a vector defined by
2292 /// vector.from_elements. E.g.:
2293 ///
2294 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2295 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2296 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2297 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2298  PatternRewriter &rewriter) {
2299  // Dynamic positions are not supported.
2300  if (extractOp.hasDynamicPosition())
2301  return failure();
2302 
2303  // Scalar extracts are handled by the folder.
2304  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2305  if (!resultType)
2306  return failure();
2307 
2308  // Look for extracts from a from_elements op.
2309  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2310  if (!fromElementsOp)
2311  return failure();
2312  VectorType inputType = fromElementsOp.getType();
2313 
2314  // Scalable vectors are not supported.
2315  if (resultType.isScalable() || inputType.isScalable())
2316  return failure();
2317 
2318  // Compute the position of the first extracted element and flatten/linearize the
2319  // position.
2320  SmallVector<int64_t> firstElementPos =
2321  llvm::to_vector(extractOp.getStaticPosition());
2322  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2323  int flatIndex = 0;
2324  int stride = 1;
2325  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2326  flatIndex += firstElementPos[i] * stride;
2327  stride *= inputType.getDimSize(i);
2328  }
2329 
2330  // Replace the op with a smaller from_elements op.
2331  rewriter.replaceOpWithNewOp<FromElementsOp>(
2332  extractOp, resultType,
2333  fromElementsOp.getElements().slice(flatIndex,
2334  resultType.getNumElements()));
2335  return success();
2336 }
2337 
2338 } // namespace
2339 
2340 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2341  MLIRContext *context) {
2342  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2343  results.add(foldExtractFromShapeCastToShapeCast);
2344  results.add(foldExtractFromFromElements);
2345 }
2346 
2347 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2348  SmallVectorImpl<int64_t> &results) {
2349  for (auto attr : arrayAttr)
2350  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2351 }
2352 
2353 //===----------------------------------------------------------------------===//
2354 // FMAOp
2355 //===----------------------------------------------------------------------===//
2356 
2357 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2358  return llvm::to_vector<4>(getVectorType().getShape());
2359 }
2360 
2361 //===----------------------------------------------------------------------===//
2362 // ToElementsOp
2363 //===----------------------------------------------------------------------===//
2364 
2365 /// Returns true if all the `operands` are defined by `defOp`.
2366 /// Otherwise, returns false.
2367 static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2368  if (operands.empty())
2369  return false;
2370 
2371  return llvm::all_of(operands, [&](Value operand) {
2372  Operation *currentDef = operand.getDefiningOp();
2373  return currentDef == defOp;
2374  });
2375 }
2376 
2377 /// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2378 /// (%e0, %e1, ...). For example:
2379 ///
2380 /// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2381 /// %1:3 = vector.to_elements %0 : vector<3xf32>
2382 /// user_op %1#0, %1#1, %1#2
2383 ///
2384 /// becomes:
2385 ///
2386 /// user_op %a, %b, %c
2387 ///
2388 static LogicalResult
2389 foldToElementsFromElements(ToElementsOp toElementsOp,
2390  SmallVectorImpl<OpFoldResult> &results) {
2391  auto fromElementsOp =
2392  toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2393  if (!fromElementsOp)
2394  return failure();
2395 
2396  llvm::append_range(results, fromElementsOp.getElements());
2397  return success();
2398 }
2399 
2400 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2401  SmallVectorImpl<OpFoldResult> &results) {
2402  return foldToElementsFromElements(*this, results);
2403 }
2404 
2405 LogicalResult
2406 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2407  ToElementsOp::Adaptor adaptor,
2408  SmallVectorImpl<Type> &inferredReturnTypes) {
2409  auto vecType = cast<VectorType>(adaptor.getSource().getType());
2410  Type elType = vecType.getElementType();
2411  inferredReturnTypes.append(vecType.getNumElements(), elType);
2412  return success();
2413 }
2414 
2415 //===----------------------------------------------------------------------===//
2416 // FromElementsOp
2417 //===----------------------------------------------------------------------===//
2418 
2419 /// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2420 ///
2421 /// Case #1: Input and output vectors are the same.
2422 ///
2423 /// %0:3 = vector.to_elements %a : vector<3xf32>
2424 /// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2425 /// user_op %1
2426 ///
2427 /// becomes:
2428 ///
2429 /// user_op %a
2430 ///
2431 static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2432  OperandRange fromElemsOperands = fromElementsOp.getElements();
2433  if (fromElemsOperands.empty())
2434  return {};
2435 
2436  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2437  if (!toElementsOp)
2438  return {};
2439 
2440  if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2441  return {};
2442 
2443  // Case #1: Input and output vectors are the same. Forward the input vector.
2444  Value toElementsInput = toElementsOp.getSource();
2445  if (fromElementsOp.getType() == toElementsInput.getType() &&
2446  llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2447  return toElementsInput;
2448  }
2449 
2450  // TODO: Support cases with different input and output shapes and different
2451  // number of elements.
2452 
2453  return {};
2454 }
2455 
2456 /// Fold vector.from_elements to a constant when all operands are constants.
2457 /// Example:
2458 /// %c1 = arith.constant 1 : i32
2459 /// %c2 = arith.constant 2 : i32
2460 /// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2461 /// =>
2462 /// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2463 ///
2464 static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2465  ArrayRef<Attribute> elements) {
2466  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
2467  return {};
2468 
2469  // DenseElementsAttr only supports int/index/float/complex types.
2470  auto destVecType = fromElementsOp.getDest().getType();
2471  auto destEltType = destVecType.getElementType();
2472  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2473  return {};
2474 
2475  // Constant attributes might have a different type than the return type.
2476  // Convert them before creating the dense elements attribute.
2477  auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2478  return convertIntegerAttr(attr, destEltType);
2479  });
2480 
2481  return DenseElementsAttr::get(destVecType, convertedElements);
2482 }
2483 
2484 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2485  if (auto res = foldFromElementsToElements(*this))
2486  return res;
2487  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2488  return res;
2489 
2490  return {};
2491 }
2492 
2493 /// Rewrite vector.from_elements as vector.broadcast if the elements are the
2494 /// same. Example:
2495 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2496 /// =>
2497 /// %0 = vector.broadcast %a : f32 to vector<3xf32>
2498 static LogicalResult
2499 rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2500  PatternRewriter &rewriter) {
2501  if (!llvm::all_equal(fromElementsOp.getElements()))
2502  return failure();
2503  rewriter.replaceOpWithNewOp<BroadcastOp>(
2504  fromElementsOp, fromElementsOp.getType(),
2505  fromElementsOp.getElements().front());
2506  return success();
2507 }
2508 
2509 /// Rewrite from_elements on multiple scalar extracts as a shape_cast
2510 /// on a single extract. Example:
2511 /// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2512 /// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2513 /// %2 = vector.from_elements %0, %1 : vector<2xi8>
2514 ///
2515 /// becomes
2516 /// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2517 /// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2518 ///
2519 /// The requirements for this to be valid are
2520 ///
2521 /// i) The elements are extracted from the same vector (%source).
2522 ///
2523 /// ii) The elements form a suffix of %source. Specifically, the number
2524 /// of elements is the same as the product of the last N dimension sizes
2525 /// of %source, for some N.
2526 ///
2527 /// iii) The elements are extracted contiguously in ascending order.
2528 
2529 class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2530 
2531  using OpRewritePattern::OpRewritePattern;
2532 
2533  LogicalResult matchAndRewrite(FromElementsOp fromElements,
2534  PatternRewriter &rewriter) const override {
2535 
2536  // Handled by `rewriteFromElementsAsBroadcast`.
2537  if (fromElements.getType().getNumElements() == 1)
2538  return failure();
2539 
2540  // The common source that all elements are extracted from, if one exists.
2541  TypedValue<VectorType> source;
2542  // The position of the combined extract operation, if one is created.
2543  ArrayRef<int64_t> combinedPosition;
2544  // The expected index of extraction of the current element in the loop, if
2545  // elements are extracted contiguously in ascending order.
2546  SmallVector<int64_t> expectedPosition;
2547 
2548  for (auto [insertIndex, element] :
2549  llvm::enumerate(fromElements.getElements())) {
2550 
2551  // Check that the element is from a vector.extract operation.
2552  auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2553  if (!extractOp) {
2554  return rewriter.notifyMatchFailure(fromElements,
2555  "element not from vector.extract");
2556  }
2557 
2558  // Check condition (i) by checking that all elements have the same source
2559  // as the first element.
2560  if (insertIndex == 0) {
2561  source = extractOp.getVector();
2562  } else if (extractOp.getVector() != source) {
2563  return rewriter.notifyMatchFailure(fromElements,
2564  "element from different vector");
2565  }
2566 
2567  ArrayRef<int64_t> position = extractOp.getStaticPosition();
2568  int64_t rank = position.size();
2569  assert(rank == source.getType().getRank() &&
2570  "scalar extract must have full rank position");
2571 
2572  // Check condition (ii) by checking that the position that the first
2573  // element is extracted from has sufficient trailing 0s. For example, in
2574  //
2575  // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2576  // [...]
2577  // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2578  //
2579  // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2580  // elements, which is the number of elements of %elms, so this is valid.
2581  if (insertIndex == 0) {
2582  const int64_t numElms = fromElements.getType().getNumElements();
2583  int64_t numSuffixElms = 1;
2584  int64_t index = rank;
2585  while (index > 0 && position[index - 1] == 0 &&
2586  numSuffixElms < numElms) {
2587  numSuffixElms *= source.getType().getDimSize(index - 1);
2588  --index;
2589  }
2590  if (numSuffixElms != numElms) {
2591  return rewriter.notifyMatchFailure(
2592  fromElements, "elements do not form a suffix of source");
2593  }
2594  expectedPosition = llvm::to_vector(position);
2595  combinedPosition = position.drop_back(rank - index);
2596  }
2597 
2598  // Check condition (iii).
2599  else if (expectedPosition != position) {
2600  return rewriter.notifyMatchFailure(
2601  fromElements, "elements not in ascending order (static order)");
2602  }
2603  increment(expectedPosition, source.getType().getShape());
2604  }
2605 
2606  auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2607  fromElements.getLoc(), source, combinedPosition);
2608 
2609  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2610  fromElements, fromElements.getType(), extracted);
2611 
2612  return success();
2613  }
2614 
2615  /// Increments n-D `indices` by 1 starting from the innermost dimension.
2616  static void increment(MutableArrayRef<int64_t> indices,
2617  ArrayRef<int64_t> shape) {
2618  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2619  indices[dim] += 1;
2620  if (indices[dim] < shape[dim])
2621  break;
2622  indices[dim] = 0;
2623  }
2624  }
2625 };
2626 
2627 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2628  MLIRContext *context) {
2629  results.add(rewriteFromElementsAsBroadcast);
2630  results.add<FromElementsToShapeCast>(context);
2631 }
2632 
2633 //===----------------------------------------------------------------------===//
2634 // BroadcastOp
2635 //===----------------------------------------------------------------------===//
2636 
2637 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2638  SetIntRangeFn setResultRanges) {
2639  setResultRanges(getResult(), argRanges.front());
2640 }
2641 
2642 std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2643  return llvm::to_vector<4>(getResultVectorType().getShape());
2644 }
2645 
2646 /// Return the dimensions of the result vector that were formerly ones in the
2647 /// source vector and thus correspond to "dim-1" broadcasting.
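/// For illustration (hypothetical shapes): srcShape = (1, 4, 1) with
/// dstShape = (2, 1, 4, 5) yields {3}, since only the trailing unit dim of the
/// source is expanded (to 5); the prepended leading dim 2 does not count as a
/// "dim-1" broadcast.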
2648 static llvm::SetVector<int64_t>
2649 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2650  ArrayRef<int64_t> dstShape) {
2651  int64_t rankDiff = dstShape.size() - srcShape.size();
2652  int64_t dstDim = rankDiff;
2653  llvm::SetVector<int64_t> res;
2654  for (auto [s1, s2] :
2655  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2656  if (s1 != s2) {
2657  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2658  res.insert(dstDim);
2659  }
2660  ++dstDim;
2661  }
2662  return res;
2663 }
2664 
2665 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2666  // Broadcasting from a scalar introduces no "dim-1" broadcast dimensions.
2667  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2668  if (!srcVectorType)
2669  return {};
2670  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2671  getResultVectorType().getShape());
2672 }
2673 
2674 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2675 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2676 /// This requires (and asserts) that the broadcast is free of "dim-1"
2677 /// broadcasting.
2678 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2679 /// vector.transpose may be inserted to make the broadcast possible.
2680 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2681 /// the helper will assert. This means:
2682 /// 1. `dstShape` must not be empty.
2683 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2684 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2685 ///    must match the `value` shape.
2686 Value BroadcastOp::createOrFoldBroadcastOp(
2687  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2688  const llvm::SetVector<int64_t> &broadcastedDims) {
2689  assert(!dstShape.empty() && "unexpected empty dst shape");
2690 
2691  // Well-formedness check.
2692  SmallVector<int64_t> checkShape;
2693  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2694  if (broadcastedDims.contains(i))
2695  continue;
2696  checkShape.push_back(dstShape[i]);
2697  }
2698  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2699  "ill-formed broadcastedDims contains values not confined to "
2700  "destVectorShape");
2701 
2702  Location loc = value.getLoc();
2703  Type elementType = getElementTypeOrSelf(value.getType());
2704  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2705  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2706 
2707  // Step 2. If scalar -> dstShape broadcast, just do it.
2708  if (!srcVectorType) {
2709  assert(checkShape.empty() &&
2710  "ill-formed createOrFoldBroadcastOp arguments");
2711  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2712  }
2713 
2714  assert(srcVectorType.getShape().equals(checkShape) &&
2715  "ill-formed createOrFoldBroadcastOp arguments");
2716 
2717  // Step 3. Since vector.broadcast only allows creating leading dims,
2718  // vector -> dstShape broadcast may require a transpose.
2719  // Traverse the dims in order and construct:
2720  // 1. The leading entries of the broadcastShape that is guaranteed to be
2721  // achievable by a simple broadcast.
2722  // 2. The induced permutation for the subsequent vector.transpose that will
2723  // bring us from `broadcastShape` back to the desired `dstShape`.
2724  // If the induced permutation is not the identity, create a vector.transpose.
2725  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2726  broadcastShape.reserve(dstShape.size());
2727  // Consider the example:
2728  // srcShape = 2x4
2729  // dstShape = 1x2x3x4x5
2730  // broadcastedDims = [0, 2, 4]
2731  //
2732  // We want to build:
2733  // broadcastShape = 1x3x5x2x4
2734  // permutation = [0, 2, 4, 1, 3]
2735  // ---V--- -----V-----
2736  // leading broadcast part src shape part
2737  //
2738  // Note that the trailing dims of broadcastShape are exactly the srcShape
2739  // by construction.
2740  // nextSrcShapeDim is used to keep track of where in the permutation the
2741  // "src shape part" occurs.
2742  int64_t nextSrcShapeDim = broadcastedDims.size();
2743  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2744  if (broadcastedDims.contains(i)) {
2745  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2746  // bring it to the head of the broadcastShape.
2747  // It will need to be permuted back from `broadcastShape.size() - 1` into
2748  // position `i`.
2749  broadcastShape.push_back(dstShape[i]);
2750  permutation[i] = broadcastShape.size() - 1;
2751  } else {
2752  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2753  // shape and needs to be permuted into position `i`.
2754  // Don't touch `broadcastShape` here, the whole srcShape will be
2755  // appended after.
2756  permutation[i] = nextSrcShapeDim++;
2757  }
2758  }
2759  // 3.c. Append the srcShape.
2760  llvm::append_range(broadcastShape, srcVectorType.getShape());
2761 
2762  // Ensure there are no "dim-1" broadcasts.
2763  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2764  .empty() &&
2765  "unexpected \"dim-1\" broadcast");
2766 
2767  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2768  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2769  vector::BroadcastableToResult::Success &&
2770  "must be broadcastable");
2771  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2772  // Step 4. If we find any dimension that indeed needs to be permuted,
2773  // immediately return a new vector.transpose.
2774  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2775  if (permutation[i] != i)
2776  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2777  // Otherwise return res.
2778  return res;
2779 }
2780 
2781 BroadcastableToResult mlir::vector::isBroadcastableTo(
2782  Type srcType, VectorType dstVectorType,
2783  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2784  // Broadcast scalar to vector of the same element type.
2785  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2786  srcType == getElementTypeOrSelf(dstVectorType))
2787  return BroadcastableToResult::Success;
2788  // From now on, only vectors broadcast.
2789  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2790  if (!srcVectorType)
2791  return BroadcastableToResult::SourceTypeNotAVector;
2792 
2793  int64_t srcRank = srcVectorType.getRank();
2794  int64_t dstRank = dstVectorType.getRank();
2795  if (srcRank > dstRank)
2796  return BroadcastableToResult::SourceRankHigher;
2797  // Source has an exact match or singleton value for all trailing dimensions
2798  // (all leading dimensions are simply duplicated).
2799  int64_t lead = dstRank - srcRank;
2800  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2801  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2802  // encountered?
2803  bool foundMismatchingDims = false;
2804 
2805  // Check fixed-width dims.
2806  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2807  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2808  if (srcDim != 1 && srcDim != dstDim)
2809  foundMismatchingDims = true;
2810 
2811  // Check scalable flags.
2812  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2813  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2814  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2815  // 1 -> [N] is fine, everything else should be rejected when mixing
2816  // fixed-width and scalable dims
2817  (srcDimScalableFlag != dstDimScalableFlag &&
2818  (srcDim != 1 || srcDimScalableFlag)))
2819  foundMismatchingDims = true;
2820 
2821  if (foundMismatchingDims) {
2822  if (mismatchingDims != nullptr) {
2823  mismatchingDims->first.dim = srcDim;
2824  mismatchingDims->first.isScalable = srcDimScalableFlag;
2825 
2826  mismatchingDims->second.dim = dstDim;
2827  mismatchingDims->second.isScalable = dstDimScalableFlag;
2828  }
2829  return BroadcastableToResult::DimensionMismatch;
2830  }
2831  }
2832 
2833  return BroadcastableToResult::Success;
2834 }
2835 
2836 LogicalResult BroadcastOp::verify() {
2837  std::pair<VectorDim, VectorDim> mismatchingDims;
2838  BroadcastableToResult res = isBroadcastableTo(
2839  getSourceType(), getResultVectorType(), &mismatchingDims);
2840  if (res == BroadcastableToResult::Success)
2841  return success();
2842  if (res == BroadcastableToResult::SourceRankHigher)
2843  return emitOpError("source rank higher than destination rank");
2844  if (res == BroadcastableToResult::DimensionMismatch) {
2845  return emitOpError("dimension mismatch (")
2846  << (mismatchingDims.first.isScalable ? "[" : "")
2847  << mismatchingDims.first.dim
2848  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2849  << (mismatchingDims.second.isScalable ? "[" : "")
2850  << mismatchingDims.second.dim
2851  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2852  }
2853  if (res == BroadcastableToResult::SourceTypeNotAVector)
2854  return emitOpError("source type is not a vector");
2855  llvm_unreachable("unexpected vector.broadcast op error");
2856 }
2857 
2858 // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2859 // with broadcast's result type and shape_cast only adds or removes ones in the
2860 // leading dimensions.
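// A minimal illustrative sketch (hypothetical IR):
//   %0 = vector.shape_cast %x : vector<4xf32> to vector<1x4xf32>
//   %1 = vector.broadcast %0 : vector<1x4xf32> to vector<2x3x4xf32>
// may fold in place to
//   %1 = vector.broadcast %x : vector<4xf32> to vector<2x3x4xf32>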
2861 static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
2862  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2863  if (!srcShapeCast)
2864  return failure();
2865 
2866  VectorType srcType = srcShapeCast.getSourceVectorType();
2867  VectorType destType = broadcastOp.getResultVectorType();
2868  // Check type compatibility.
2869  if (vector::isBroadcastableTo(srcType, destType) !=
2870  BroadcastableToResult::Success)
2871  return failure();
2872 
2873  ArrayRef<int64_t> srcShape = srcType.getShape();
2874  ArrayRef<int64_t> shapecastShape =
2875  srcShapeCast.getResultVectorType().getShape();
2876  // Trailing dimensions should be the same if shape_cast only alters the
2877  // leading dimensions.
2878  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
2879  if (!llvm::equal(srcShape.take_back(numTrailingDims),
2880  shapecastShape.take_back(numTrailingDims)))
2881  return failure();
2882 
2883  assert(all_of(srcShape.drop_back(numTrailingDims),
2884  [](int64_t E) { return E == 1; }) &&
2885  all_of(shapecastShape.drop_back(numTrailingDims),
2886  [](int64_t E) { return E == 1; }) &&
2887  "ill-formed shape_cast");
2888 
2889  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
2890  return success();
2891 }
2892 
2893 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2894  if (getSourceType() == getResultVectorType())
2895  return getSource();
2896  if (succeeded(foldBroadcastOfShapeCast(*this)))
2897  return getResult();
2898 
2899  if (!adaptor.getSource())
2900  return {};
2901  auto vectorType = getResultVectorType();
2902  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2903  if (vectorType.getElementType() != attr.getType())
2904  return {};
2905  return DenseElementsAttr::get(vectorType, attr);
2906  }
2907  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2908  if (vectorType.getElementType() != attr.getType())
2909  return {};
2910  return DenseElementsAttr::get(vectorType, attr);
2911  }
2912  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2913  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2914  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
2915  return ub::PoisonAttr::get(getContext());
2916  return {};
2917 }
2918 
2919 namespace {
2920 
2921 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
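// Illustrative example (hypothetical types):
//   %0 = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.broadcast %0 : vector<2x4xf32> to vector<8x2x4xf32>
// becomes
//   %1 = vector.broadcast %x : vector<4xf32> to vector<8x2x4xf32>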
2922 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2923  using OpRewritePattern::OpRewritePattern;
2924 
2925  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2926  PatternRewriter &rewriter) const override {
2927  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2928  if (!srcBroadcast)
2929  return failure();
2930  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2931  broadcastOp.getResultVectorType(),
2932  srcBroadcast.getSource());
2933  return success();
2934  }
2935 };
2936 } // namespace
2937 
2938 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2939  MLIRContext *context) {
2940  // BroadcastToShapeCast is not a default canonicalization; it is opt-in and
2941  // enabled by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2942  results.add<BroadcastFolder>(context);
2943 }
2944 
2945 //===----------------------------------------------------------------------===//
2946 // ShuffleOp
2947 //===----------------------------------------------------------------------===//
2948 
2949 LogicalResult ShuffleOp::verify() {
2950  VectorType resultType = getResultVectorType();
2951  VectorType v1Type = getV1VectorType();
2952  VectorType v2Type = getV2VectorType();
2953  // Verify ranks.
2954  int64_t resRank = resultType.getRank();
2955  int64_t v1Rank = v1Type.getRank();
2956  int64_t v2Rank = v2Type.getRank();
2957  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2958  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2959  if (!wellFormed0DCase && !wellFormedNDCase)
2960  return emitOpError("rank mismatch");
2961 
2962  // Verify all but leading dimension sizes.
2963  for (int64_t r = 1; r < v1Rank; ++r) {
2964  int64_t resDim = resultType.getDimSize(r);
2965  int64_t v1Dim = v1Type.getDimSize(r);
2966  int64_t v2Dim = v2Type.getDimSize(r);
2967  if (resDim != v1Dim || v1Dim != v2Dim)
2968  return emitOpError("dimension mismatch");
2969  }
2970  // Verify mask length.
2971  ArrayRef<int64_t> mask = getMask();
2972  int64_t maskLength = mask.size();
2973  if (maskLength <= 0)
2974  return emitOpError("invalid mask length");
2975  if (maskLength != resultType.getDimSize(0))
2976  return emitOpError("mask length mismatch");
2977  // Verify all indices.
2978  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2979  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2980  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2981  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2982  return emitOpError("mask index #") << (idx + 1) << " out of range";
2983  }
2984  return success();
2985 }
2986 
2987 LogicalResult
2988 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2989  ShuffleOp::Adaptor adaptor,
2990  SmallVectorImpl<Type> &inferredReturnTypes) {
2991  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2992  auto v1Rank = v1Type.getRank();
2993  // Construct resulting type: leading dimension matches mask
2994  // length, all trailing dimensions match the operands.
2995  SmallVector<int64_t, 4> shape;
2996  shape.reserve(v1Rank);
2997  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2998  // In the 0-D case there is no trailing shape to append.
2999  if (v1Rank > 0)
3000  llvm::append_range(shape, v1Type.getShape().drop_front());
3001  inferredReturnTypes.push_back(
3002  VectorType::get(shape, v1Type.getElementType()));
3003  return success();
3004 }
3005 
3006 template <typename T>
3007 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3008  T expected = begin;
3009  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3010  return value == expected++;
3011  });
3012 }
3013 
3014 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3015  auto v1Type = getV1VectorType();
3016  auto v2Type = getV2VectorType();
3017 
3018  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3019  "Vector shuffle does not support scalable vectors");
3020 
3021  // For consistency, the 0-D shuffle return type is 1-D; this cannot be a fold
3022  // and must instead be canonicalized into a vector.broadcast.
3023  if (v1Type.getRank() == 0)
3024  return {};
3025 
3026  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3027  auto mask = getMask();
3028  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3029  return getV1();
3030  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3031  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3032  return getV2();
3033 
3034  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3035  if (!v1Attr || !v2Attr)
3036  return {};
3037 
3038  // Fold shuffle poison, poison -> poison.
3039  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3040  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3041  if (isV1Poison && isV2Poison)
3042  return ub::PoisonAttr::get(getContext());
3043 
3044  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3045  // manipulation.
3046  if (v1Type.getRank() != 1)
3047  return {};
3048 
3049  // Poison input attributes need special handling as they are not
3050  // DenseElementsAttr. If an index is poison, we select the first element of
3051  // the first non-poison input.
3052  SmallVector<Attribute> v1Elements, v2Elements;
3053  Attribute poisonElement;
3054  if (!isV2Poison) {
3055  auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3056  if (!v2DenseAttr)
3057  return {};
3058  v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3059  poisonElement = v2Elements[0];
3060  }
3061  if (!isV1Poison) {
3062  auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3063  if (!v1DenseAttr)
3064  return {};
3065  v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3066  poisonElement = v1Elements[0];
3067  }
3068 
3069  SmallVector<Attribute> results;
3070  int64_t v1Size = v1Type.getDimSize(0);
3071  for (int64_t maskIdx : mask) {
3072  Attribute indexedElm;
3073  // TODO: Return a partial poison vector when supported by the UB dialect.
3074  if (maskIdx == ShuffleOp::kPoisonIndex) {
3075  indexedElm = poisonElement;
3076  } else {
3077  if (maskIdx < v1Size)
3078  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3079  else
3080  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3081  }
3082 
3083  results.push_back(indexedElm);
3084  }
3085 
3086  return DenseElementsAttr::get(getResultVectorType(), results);
3087 }
3088 
3089 namespace {
3090 
3091 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3092 // to a broadcast.
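// Illustrative example (hypothetical types): with 0-D inputs %a and %b,
//   %0 = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
// is rewritten to
//   %0 = vector.broadcast %a : vector<f32> to vector<1xf32>
// (a [1] mask would broadcast %b instead).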
3093 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3094  using OpRewritePattern::OpRewritePattern;
3095 
3096  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3097  PatternRewriter &rewriter) const override {
3098  VectorType v1VectorType = shuffleOp.getV1VectorType();
3099  ArrayRef<int64_t> mask = shuffleOp.getMask();
3100  if (v1VectorType.getRank() > 0)
3101  return failure();
3102  if (mask.size() != 1)
3103  return failure();
3104  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3105  if (mask[0] == 0)
3106  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3107  shuffleOp.getV1());
3108  else
3109  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3110  shuffleOp.getV2());
3111  return success();
3112  }
3113 };
3114 
3115 /// Consider the defining operation `defOp` of `value`. If `defOp` is a
3116 /// vector.splat or a vector.broadcast with a scalar operand, return the scalar
3117 /// value that is splatted. Otherwise return null.
3118 ///
3119 /// Examples:
3120 ///
3121 /// scalar_source --> vector.splat --> value - return scalar_source
3122 /// scalar_source --> vector.broadcast --> value - return scalar_source
3123 static Value getScalarSplatSource(Value value) {
3124  // Block argument:
3125  Operation *defOp = value.getDefiningOp();
3126  if (!defOp)
3127  return {};
3128 
3129  // Splat:
3130  if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3131  return splat.getInput();
3132 
3133  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3134 
3135  // Not broadcast (and not splat):
3136  if (!broadcast)
3137  return {};
3138 
3139  // Broadcast of a vector:
3140  if (isa<VectorType>(broadcast.getSourceType()))
3141  return {};
3142 
3143  // Broadcast of a scalar:
3144  return broadcast.getSource();
3145 }
3146 
3147 /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
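/// Illustrative example (hypothetical types): if %s is a broadcast of scalar
/// %x, then
///   %0 = vector.shuffle %s, %s [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
/// becomes
///   %0 = vector.broadcast %x : f32 to vector<4xf32>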
3148 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3149 public:
3150  using OpRewritePattern::OpRewritePattern;
3151 
3152  LogicalResult matchAndRewrite(ShuffleOp op,
3153  PatternRewriter &rewriter) const override {
3154  Value splat = getScalarSplatSource(op.getV1());
3155  if (!splat || getScalarSplatSource(op.getV2()) != splat)
3156  return failure();
3157 
3158  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3159  return success();
3160  }
3161 };
3162 
3163 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3164 /// vector.interleave.
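/// Illustrative example (hypothetical types; interleave syntax abbreviated):
///   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///       : vector<4xf32>, vector<4xf32>
/// becomes
///   %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>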
3165 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3166 public:
3167  using OpRewritePattern::OpRewritePattern;
3168 
3169  LogicalResult matchAndRewrite(ShuffleOp op,
3170  PatternRewriter &rewriter) const override {
3171  VectorType resultType = op.getResultVectorType();
3172  if (resultType.isScalable())
3173  return rewriter.notifyMatchFailure(
3174  op, "ShuffleOp can't represent a scalable interleave");
3175 
3176  if (resultType.getRank() != 1)
3177  return rewriter.notifyMatchFailure(
3178  op, "ShuffleOp can't represent an n-D interleave");
3179 
3180  VectorType sourceType = op.getV1VectorType();
3181  if (sourceType != op.getV2VectorType() ||
3182  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3183  return rewriter.notifyMatchFailure(
3184  op, "ShuffleOp types don't match an interleave");
3185  }
3186 
3187  ArrayRef<int64_t> shuffleMask = op.getMask();
3188  int64_t resultVectorSize = resultType.getNumElements();
3189  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3190  int64_t maskValueA = shuffleMask[i * 2];
3191  int64_t maskValueB = shuffleMask[(i * 2) + 1];
3192  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3193  return rewriter.notifyMatchFailure(op,
3194  "ShuffleOp mask not interleaving");
3195  }
3196 
3197  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3198  return success();
3199  }
3200 };
3201 
3202 } // namespace
3203 
3204 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3205  MLIRContext *context) {
3206  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3207  context);
3208 }
3209 
3210 //===----------------------------------------------------------------------===//
3211 // InsertOp
3212 //===----------------------------------------------------------------------===//
3213 
3214 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3215  SetIntRangeFn setResultRanges) {
3216  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3217 }
3218 
3219 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3220  Value source, Value dest) {
3221  auto vectorTy = cast<VectorType>(dest.getType());
3222  build(builder, result, source, dest,
3223  SmallVector<int64_t>(vectorTy.getRank(), 0));
3224 }
3225 
3226 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3227  Value source, Value dest, int64_t position) {
3228  build(builder, result, source, dest, ArrayRef<int64_t>{position});
3229 }
3230 
3231 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3232  Value source, Value dest, OpFoldResult position) {
3233  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3234 }
3235 
3236 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3237  Value source, Value dest,
3238  ArrayRef<int64_t> position) {
3239  SmallVector<OpFoldResult> posVals;
3240  posVals.reserve(position.size());
3241  llvm::transform(position, std::back_inserter(posVals),
3242  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3243  build(builder, result, source, dest, posVals);
3244 }
3245 
3246 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3247  Value source, Value dest,
3248  ArrayRef<OpFoldResult> position) {
3249  SmallVector<int64_t> staticPos;
3250  SmallVector<Value> dynamicPos;
3251  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3252  build(builder, result, source, dest, dynamicPos,
3253  builder.getDenseI64ArrayAttr(staticPos));
3254 }
3255 
3256 LogicalResult InsertOp::verify() {
3257  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3258  if (srcTy.getRank() == 0)
3259  return emitError(
3260  "expected a scalar instead of a 0-d vector as the source operand");
3261 
3262  SmallVector<OpFoldResult> position = getMixedPosition();
3263  auto destVectorType = getDestVectorType();
3264  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3265  return emitOpError(
3266  "expected position attribute of rank no greater than dest vector rank");
3267  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3268  if (srcVectorType &&
3269  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3270  static_cast<unsigned>(destVectorType.getRank())))
3271  return emitOpError("expected position attribute rank + source rank to "
3272  "match dest vector rank");
3273  if (!srcVectorType &&
3274  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3275  return emitOpError(
3276  "expected position attribute rank to match the dest vector rank");
3277  for (auto [idx, pos] : llvm::enumerate(position)) {
3278  if (auto attr = dyn_cast<Attribute>(pos)) {
3279  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3280  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3281  destVectorType.getDimSize(idx))) {
3282  return emitOpError("expected position attribute #")
3283  << (idx + 1)
3284  << " to be a non-negative integer smaller than the "
3285  "corresponding "
3286  "dest vector dimension";
3287  }
3288  }
3289  }
3290  return success();
3291 }
3292 
3293 // Calculate the linearized position of the contiguous chunk of elements to
3294 // insert, based on the shape of the value to insert and the positions to insert
3295 // at.
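// Worked example (hypothetical shapes): for destTy = vector<2x3xf32> and
// positions = [1], completePositions = [1, 0] and the strides are [3, 1], so
// the returned linearized position is 1 * 3 + 0 * 1 = 3.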
3296 static int64_t calculateInsertPosition(VectorType destTy,
3297  ArrayRef<int64_t> positions) {
3298  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3299  assert(positions.size() <= completePositions.size() &&
3300  "positions size must be less than or equal to destTy rank");
3301  copy(positions, completePositions.begin());
3302  return linearize(completePositions, computeStrides(destTy.getShape()));
3303 }
3304 
3305 namespace {
3306 
3307 // If insertOp is only inserting unit dimensions it can be transformed to a
3308 // broadcast.
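// Illustrative example (hypothetical types):
//   %0 = vector.insert %v, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// becomes
//   %0 = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>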
3309 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3310 public:
3311  using OpRewritePattern::OpRewritePattern;
3312 
3313  LogicalResult matchAndRewrite(InsertOp insertOp,
3314  PatternRewriter &rewriter) const override {
3315  auto srcVecType =
3316  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3317  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3318  srcVecType.getNumElements())
3319  return failure();
3320  rewriter.replaceOpWithNewOp<BroadcastOp>(
3321  insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3322  return success();
3323  }
3324 };
3325 
3326 /// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
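/// Illustrative example (hypothetical types): if %s and %d are both broadcasts
/// of the same scalar %x, then
///   %0 = vector.insert %s, %d [0] : vector<4xf32> into vector<2x4xf32>
/// becomes
///   %0 = vector.broadcast %x : f32 to vector<2x4xf32>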
3327 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3328 public:
3329  using OpRewritePattern::OpRewritePattern;
3330 
3331  LogicalResult matchAndRewrite(InsertOp op,
3332  PatternRewriter &rewriter) const override {
3333 
3334  Value splat = getScalarSplatSource(op.getValueToStore());
3335  if (!splat || getScalarSplatSource(op.getDest()) != splat)
3336  return failure();
3337 
3338  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3339  return success();
3340  }
3341 };
3342 
3343 /// Pattern to optimize a chain of insertions.
3344 ///
3345 /// This pattern identifies chains of vector.insert operations that:
3346 /// 1. Only insert values at static positions.
3347 /// 2. Completely initialize all elements in the resulting vector.
3348 /// 3. All intermediate insert operations have only one use.
3349 ///
3350 /// When these conditions are met, the entire chain can be replaced with a
3351 /// single vector.from_elements operation.
3352 ///
3353 /// To keep this pattern simple, and avoid spending too much time on matching
3354 /// fragmented insert chains, this pattern only considers the last insert op in
3355 /// the chain.
3356 ///
3357 /// Example transformation:
3358 /// %poison = ub.poison : vector<2xi32>
3359 /// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3360 /// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3361 /// ->
3362 /// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3363 class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3364 public:
3365  using OpRewritePattern::OpRewritePattern;
3366  LogicalResult matchAndRewrite(InsertOp op,
3367  PatternRewriter &rewriter) const override {
3368 
3369  VectorType destTy = op.getDestVectorType();
3370  if (destTy.isScalable())
3371  return failure();
3372  // Ensure this is the trailing vector.insert op in a chain of inserts.
3373  for (Operation *user : op.getResult().getUsers())
3374  if (auto insertOp = dyn_cast<InsertOp>(user))
3375  if (insertOp.getDest() == op.getResult())
3376  return failure();
3377 
3378  InsertOp currentOp = op;
3379  SmallVector<InsertOp> chainInsertOps;
3380  while (currentOp) {
3381  // Check cond 1: Dynamic position is not supported.
3382  if (currentOp.hasDynamicPosition())
3383  return failure();
3384 
3385  chainInsertOps.push_back(currentOp);
3386  currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3387  // Check cond 3: Intermediate inserts have only one use to avoid an
3388  // explosion of vectors.
3389  if (currentOp && !currentOp->hasOneUse())
3390  return failure();
3391  }
3392 
3393  int64_t vectorSize = destTy.getNumElements();
3394  int64_t initializedCount = 0;
3395  SmallVector<bool> initializedDestIdxs(vectorSize, false);
3396  SmallVector<int64_t> pendingInsertPos;
3397  SmallVector<int64_t> pendingInsertSize;
3398  SmallVector<Value> pendingInsertValues;
3399 
3400  for (auto insertOp : chainInsertOps) {
3401  // This pattern can do nothing with poison index.
3402  if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3403  return failure();
3404 
3405  // Calculate the linearized position for inserting elements.
3406  int64_t insertBeginPosition =
3407  calculateInsertPosition(destTy, insertOp.getStaticPosition());
3408 
3409  // The valueToStore operand may be a vector or a scalar. Need to handle
3410  // both cases.
3411  int64_t insertSize = 1;
3412  if (auto srcVectorType =
3413  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3414  insertSize = srcVectorType.getNumElements();
3415 
3416  assert(insertBeginPosition + insertSize <= vectorSize &&
3417  "insert would overflow the vector");
3418 
3419  for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3420  insertBeginPosition + insertSize)) {
3421  if (initializedDestIdxs[index])
3422  continue;
3423  initializedDestIdxs[index] = true;
3424  ++initializedCount;
3425  }
3426 
3427  // Defer creating any new ops until we are sure that the pattern will
3428  // succeed.
3429  pendingInsertPos.push_back(insertBeginPosition);
3430  pendingInsertSize.push_back(insertSize);
3431  pendingInsertValues.push_back(insertOp.getValueToStore());
3432 
3433  if (initializedCount == vectorSize)
3434  break;
3435  }
3436 
3437  // Check cond 2: all positions must be initialized.
3438  if (initializedCount != vectorSize)
3439  return failure();
3440 
3441  SmallVector<Value> elements(vectorSize);
3442  for (auto [insertBeginPosition, insertSize, valueToStore] :
3443  llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3444  pendingInsertValues))) {
3445  auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3446 
3447  if (!srcVectorType) {
3448  elements[insertBeginPosition] = valueToStore;
3449  continue;
3450  }
3451 
3452  SmallVector<Type> elementToInsertTypes(insertSize,
3453  srcVectorType.getElementType());
3454  // Get all elements from the vector in row-major order.
3455  auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
3456  op.getLoc(), elementToInsertTypes, valueToStore);
3457  for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3458  elements[insertBeginPosition + linearIdx] =
3459  elementsToInsert.getResult(linearIdx);
3460  }
3461  }
3462 
3463  rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3464  return success();
3465  }
3466 };
3467 
3468 } // namespace
3469 
3470 static Attribute
3471 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3472  Attribute dstAttr,
3473  int64_t maxVectorSizeFoldThreshold) {
3474  if (insertOp.hasDynamicPosition())
3475  return {};
3476 
3477  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3478  if (!denseDst)
3479  return {};
3480 
3481  if (!srcAttr) {
3482  return {};
3483  }
3484 
3485  VectorType destTy = insertOp.getDestVectorType();
3486  if (destTy.isScalable())
3487  return {};
3488 
3489  // Make sure we do not create too many large constants.
3490  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3491  !insertOp->hasOneUse())
3492  return {};
3493 
3494  // Calculate the linearized position for inserting elements.
3495  int64_t insertBeginPosition =
3496  calculateInsertPosition(destTy, insertOp.getStaticPosition());
3497  SmallVector<Attribute> insertedValues;
3498  Type destEltType = destTy.getElementType();
3499 
3500  /// Convert the inserted values to the destination element type via
3501  /// `convertIntegerAttr` when the integer types mismatch.
3502  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3503  for (auto value : denseSource.getValues<Attribute>())
3504  insertedValues.push_back(convertIntegerAttr(value, destEltType));
3505  } else {
3506  insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
3507  }
3508 
3509  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3510  copy(insertedValues, allValues.begin() + insertBeginPosition);
3511  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3512 
3513  return newAttr;
3514 }
3515 
3516 /// Folder to replace the `dest` operand of the insert op with the root dest of
3517 /// the insert op use chain.
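/// Illustrative example (hypothetical types): in
///   %0 = vector.insert %a, %dst [3] : f32 into vector<8xf32>
///   %1 = vector.insert %b, %0 [3] : f32 into vector<8xf32>
/// the second insert overwrites the element written by the first, so its dest
/// operand can be replaced with %dst directly.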
3518 static Value foldInsertUseChain(InsertOp insertOp) {
3519  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3520  if (!destInsert)
3521  return {};
3522 
3523  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3524  return {};
3525 
3526  insertOp.setOperand(1, destInsert.getDest());
3527  return insertOp.getResult();
3528 }
3529 
3530 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3531  MLIRContext *context) {
3532  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3533  InsertChainFullyInitialized>(context);
3534 }
3535 
3536 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3537  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3538  // unless the source vector constant has a single use.
3539  constexpr int64_t vectorSizeFoldThreshold = 256;
3540  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
3541  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3542  // (type mismatch).
3543  if (getNumIndices() == 0 && getValueToStoreType() == getType())
3544  return getValueToStore();
3545  // Fold `arith.constant` indices into the `vector.insert` operation.
3546  // Do not stop here as this fold may enable subsequent folds that require
3547  // constant indices.
3548  SmallVector<Value> operands = {getValueToStore(), getDest()};
3549  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3550 
3551  if (auto res = foldInsertUseChain(*this))
3552  return res;
3553  if (auto res = foldPoisonIndexInsertExtractOp(
3554  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3555  return res;
3556  if (auto res = foldDenseElementsAttrDestInsertOp(
3557  *this, adaptor.getValueToStore(), adaptor.getDest(),
3558  vectorSizeFoldThreshold)) {
3559  return res;
3560  }
3561 
3562  return inplaceFolded;
3563 }
3564 
3565 //===----------------------------------------------------------------------===//
3566 // InsertStridedSliceOp
3567 //===----------------------------------------------------------------------===//
3568 
3569 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3570  Value source, Value dest,
3571  ArrayRef<int64_t> offsets,
3572  ArrayRef<int64_t> strides) {
3573  result.addOperands({source, dest});
3574  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3575  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3576  result.addTypes(dest.getType());
3577  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3578  offsetsAttr);
3579  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3580  stridesAttr);
3581 }
3582 
3583 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3584 template <typename OpType>
3585 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3586  ArrayAttr arrayAttr,
3587  ArrayRef<int64_t> shape,
3588  StringRef attrName) {
3589  if (arrayAttr.size() > shape.size())
3590  return op.emitOpError("expected ")
3591  << attrName << " attribute of rank no greater than vector rank";
3592  return success();
3593 }
3594 
3595 // Returns success if all integers in `arrayAttr` are confined to the interval
3596 // [min, max). If `halfOpen` is false, the admissible interval is [min, max]
3597 // instead.
3598 template <typename OpType>
3599 static LogicalResult
3600 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3601  int64_t max, StringRef attrName,
3602  bool halfOpen = true) {
3603  for (auto attr : arrayAttr) {
3604  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3605  auto upper = max;
3606  if (!halfOpen)
3607  upper += 1;
3608  if (val < min || val >= upper)
3609  return op.emitOpError("expected ") << attrName << " to be confined to ["
3610  << min << ", " << upper << ")";
3611  }
3612  return success();
3613 }
3614 
3615 // Returns success if, for each dimension i, the integer `arrayAttr[i]` is
3616 // confined to the interval [min, shape[i]). If `halfOpen` is false, the
3617 // admissible interval is [min, shape[i]] instead.
3618 template <typename OpType>
3619 static LogicalResult
3620 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3621  ArrayRef<int64_t> shape, StringRef attrName,
3622  bool halfOpen = true, int64_t min = 0) {
3623  for (auto [index, attrDimPair] :
3624  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3625  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3626  int64_t max = std::get<1>(attrDimPair);
3627  if (!halfOpen)
3628  max += 1;
3629  if (val < min || val >= max)
3630  return op.emitOpError("expected ")
3631  << attrName << " dimension " << index << " to be confined to ["
3632  << min << ", " << max << ")";
3633  }
3634  return success();
3635 }
3636 
3637 // Returns success if, for each dimension i, the sum
3638 // `arrayAttr1[i]` + `arrayAttr2[i]` is confined to the interval
3639 // [min, max) where max = shape[i]. If `halfOpen` is false, the admissible
3640 // interval is [min, max] instead.
3642 template <typename OpType>
3643 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3644  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3645  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3646  bool halfOpen = true, int64_t min = 1) {
3647  assert(arrayAttr1.size() <= shape.size());
3648  assert(arrayAttr2.size() <= shape.size());
3649  for (auto [index, it] :
3650  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3651  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3652  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3653  int64_t max = std::get<2>(it);
3654  if (!halfOpen)
3655  max += 1;
3656  if (val1 + val2 < 0 || val1 + val2 >= max)
3657  return op.emitOpError("expected sum(")
3658  << attrName1 << ", " << attrName2 << ") dimension " << index
3659  << " to be confined to [" << min << ", " << max << ")";
3660  }
3661  return success();
3662 }
3663 
3664 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3665  MLIRContext *context) {
3666  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3667  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3668  });
3669  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3670 }
3671 
3672 LogicalResult InsertStridedSliceOp::verify() {
3673  auto sourceVectorType = getSourceVectorType();
3674  auto destVectorType = getDestVectorType();
3675  auto offsets = getOffsetsAttr();
3676  auto strides = getStridesAttr();
3677  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3678  return emitOpError(
3679  "expected offsets of same size as destination vector rank");
3680  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3681  return emitOpError("expected strides of same size as source vector rank");
3682  if (sourceVectorType.getRank() > destVectorType.getRank())
3683  return emitOpError(
3684  "expected source rank to be no greater than destination rank");
3685 
3686  auto sourceShape = sourceVectorType.getShape();
3687  auto destShape = destVectorType.getShape();
3688  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3689  destShape.size() - sourceShape.size(), 0);
3690  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3691  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3692  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3693  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3694  offName)) ||
3695  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3696  /*max=*/1, stridesName,
3697  /*halfOpen=*/false)) ||
3698  failed(isSumOfIntegerArrayAttrConfinedToShape(
3699  *this, offsets,
3700  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3701  offName, "source vector shape",
3702  /*halfOpen=*/false, /*min=*/1)))
3703  return failure();
3704 
3705  unsigned rankDiff = destShape.size() - sourceShape.size();
3706  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3707  if (sourceVectorType.getScalableDims()[idx] !=
3708  destVectorType.getScalableDims()[idx + rankDiff]) {
3709  return emitOpError("mismatching scalable flags (at source vector idx=")
3710  << idx << ")";
3711  }
3712  if (sourceVectorType.getScalableDims()[idx]) {
3713  auto sourceSize = sourceShape[idx];
3714  auto destSize = destShape[idx + rankDiff];
3715  if (sourceSize != destSize) {
3716  return emitOpError("expected size at idx=")
3717  << idx
3718  << (" to match the corresponding base size from the input "
3719  "vector (")
3720  << sourceSize << (" vs ") << destSize << (")");
3721  }
3722  }
3723  }
3724 
3725  return success();
3726 }
3727 
3728 namespace {
3729 /// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
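/// Illustrative example (hypothetical types): if %src and %dst are both
/// broadcasts of the same scalar, then
///   %0 = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
///       : vector<4xf32> into vector<16xf32>
/// can simply be replaced by %dst.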
3730 class FoldInsertStridedSliceSplat final
3731  : public OpRewritePattern<InsertStridedSliceOp> {
3732 public:
3733  using OpRewritePattern::OpRewritePattern;
3734 
3735  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3736  PatternRewriter &rewriter) const override {
3737 
3738  auto dst = insertStridedSliceOp.getDest();
3739  auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3740  if (!splat || getScalarSplatSource(dst) != splat)
3741  return failure();
3742 
3743  rewriter.replaceOp(insertStridedSliceOp, dst);
3744  return success();
3745  }
3746 };
3747 
3748 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3749 /// to dst.
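/// Illustrative example (hypothetical types):
///   %e = vector.extract_strided_slice %dst
///       {offsets = [4], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
///   %0 = vector.insert_strided_slice %e, %dst {offsets = [4], strides = [1]}
///       : vector<8xf32> into vector<16xf32>
/// reinserts the slice where it came from, so %0 folds to %dst.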
3750 class FoldInsertStridedSliceOfExtract final
3751  : public OpRewritePattern<InsertStridedSliceOp> {
3752 public:
3753  using OpRewritePattern::OpRewritePattern;
3754 
3755  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3756  PatternRewriter &rewriter) const override {
3757  auto extractStridedSliceOp =
3758  insertStridedSliceOp.getValueToStore()
3759  .getDefiningOp<vector::ExtractStridedSliceOp>();
3760 
3761  if (!extractStridedSliceOp)
3762  return failure();
3763 
3764  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3765  return failure();
3766 
3767  // Check if have the same strides and offsets.
3768  if (extractStridedSliceOp.getStrides() !=
3769  insertStridedSliceOp.getStrides() ||
3770  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3771  return failure();
3772 
3773  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3774  return success();
3775  }
3776 };
3777 
3778 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3779 // ConstantOp.
3780 class InsertStridedSliceConstantFolder final
3781  : public OpRewritePattern<InsertStridedSliceOp> {
3782 public:
3783  using OpRewritePattern::OpRewritePattern;
3784 
3785  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3786  // unless the source vector constant has a single use.
3787  static constexpr int64_t vectorSizeFoldThreshold = 256;
3788 
3789  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3790  PatternRewriter &rewriter) const override {
3791  // Return if the 'InsertStridedSliceOp' dest operand is not defined by a
3792  // compatible vector ConstantOp.
3793  TypedValue<VectorType> destVector = op.getDest();
3794  Attribute vectorDestCst;
3795  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3796  return failure();
3797 
3798  VectorType destTy = destVector.getType();
3799  if (destTy.isScalable())
3800  return failure();
3801 
3802  // Make sure we do not create too many large constants.
3803  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3804  !destVector.hasOneUse())
3805  return failure();
3806 
3807  TypedValue<VectorType> sourceValue = op.getValueToStore();
3808  Attribute sourceCst;
3809  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3810  return failure();
3811 
3812  // TODO: Support poison.
3813  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3814  return failure();
3815 
3816  // TODO: Handle non-unit strides when they become available.
3817  if (op.hasNonUnitStrides())
3818  return failure();
3819 
3820  VectorType sliceVecTy = sourceValue.getType();
3821  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3822  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3823  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3824  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3825 
3826  // Calculate the destination element indices by enumerating all slice
3827  // positions within the destination and linearizing them. The enumeration
3828  // order is lexicographic which yields a sequence of monotonically
3829  // increasing linearized position indices.
3830  // Because the destination may have higher dimensionality than the slice,
3831  // we keep track of two overlapping sets of positions and offsets.
3832  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3833  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3834  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3835  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3836  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3837  MutableArrayRef<int64_t> currSlicePosition(
3838  currDestPosition.begin() + rankDifference, currDestPosition.end());
3839  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3840  offsets.end());
3841  do {
3842  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3843  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3844  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3845  "Invalid slice element");
3846  newValues[linearizedPosition] = *sliceValuesIt;
3847  ++sliceValuesIt;
3848  } while (succeeded(
3849  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3850 
3851  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3852  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3853  return success();
3854  }
3855 };
3856 
3857 } // namespace
3858 
3859 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3860  RewritePatternSet &results, MLIRContext *context) {
3861  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3862  InsertStridedSliceConstantFolder>(context);
3863 }
3864 
3865 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3866  if (getSourceVectorType() == getDestVectorType())
3867  return getValueToStore();
3868  return {};
3869 }
3870 
3871 //===----------------------------------------------------------------------===//
3872 // OuterProductOp
3873 //===----------------------------------------------------------------------===//
3874 
3875 /// Build an op without mask, use the type of `acc` as the return type.
3876 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3877  Value lhs, Value rhs, Value acc) {
3878  result.addOperands({lhs, rhs, acc});
3879  result.addTypes(acc.getType());
3880 }
3881 
3882 void OuterProductOp::print(OpAsmPrinter &p) {
3883  p << " " << getLhs() << ", " << getRhs();
3884  if (getAcc()) {
3885  p << ", " << getAcc();
3886  p.printOptionalAttrDict((*this)->getAttrs());
3887  }
3888  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3889 }
3890 
3891 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3892  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3893  Type tLHS, tRHS;
3894  if (parser.parseOperandList(operandsInfo) ||
3895  parser.parseOptionalAttrDict(result.attributes) ||
3896  parser.parseColonType(tLHS) || parser.parseComma() ||
3897  parser.parseType(tRHS))
3898  return failure();
3899  if (operandsInfo.size() < 2)
3900  return parser.emitError(parser.getNameLoc(),
3901  "expected at least 2 operands");
3902  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3903  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3904  if (!vLHS)
3905  return parser.emitError(parser.getNameLoc(),
3906  "expected vector type for operand #1");
3907 
3908  VectorType resType;
3909  if (vRHS) {
3910  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3911  vRHS.getScalableDims()[0]};
3912  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3913  vLHS.getElementType(), scalableDimsRes);
3914  } else {
3915  // Scalar RHS operand
3916  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3917  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3918  scalableDimsRes);
3919  }
3920 
3921  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3922  result.attributes.append(
3923  OuterProductOp::getKindAttrName(result.name),
3924  CombiningKindAttr::get(result.getContext(),
3925  OuterProductOp::getDefaultKind()));
3926  }
3927 
3928  return failure(
3929  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3930  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3931  (operandsInfo.size() > 2 &&
3932  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3933  parser.addTypeToList(resType, result.types));
3934 }
3935 
3936 LogicalResult OuterProductOp::verify() {
3937  Type tRHS = getOperandTypeRHS();
3938  VectorType vLHS = getOperandVectorTypeLHS(),
3939  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3940  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3941 
3942  if (vLHS.getRank() != 1)
3943  return emitOpError("expected 1-d vector for operand #1");
3944 
3945  if (vRHS) {
3946  // Proper OUTER operation.
3947  if (vRHS.getRank() != 1)
3948  return emitOpError("expected 1-d vector for operand #2");
3949  if (vRES.getRank() != 2)
3950  return emitOpError("expected 2-d vector result");
3951  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3952  return emitOpError("expected #1 operand dim to match result dim #1");
3953  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3954  return emitOpError("expected #2 operand dim to match result dim #2");
3955  if (vLHS.isScalable() && !vRHS.isScalable()) {
3956  // This restriction reflects what's currently supported in terms of
3957  // scalable vectors. However, we could relax this if there's a use case.
3958  return emitOpError(
3959  "expected either both or only #2 operand dim to be scalable");
3960  }
3961  } else {
3962  // An AXPY operation.
3963  if (vRES.getRank() != 1)
3964  return emitOpError("expected 1-d vector result");
3965  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3966  return emitOpError("expected #1 operand dim to match result dim #1");
3967  }
3968 
3969  if (vACC && vACC != vRES)
3970  return emitOpError("expected operand #3 of same type as result type");
3971 
3972  // Verify supported combining kind.
3973  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3974  return emitOpError("unsupported outerproduct type");
3975 
3976  return success();
3977 }
3978 
3979 // MaskableOpInterface methods.
3980 
3981 /// Returns the mask type expected by this operation. Mostly used for
3982 /// verification purposes. It requires the operation to be vectorized.
3983 Type OuterProductOp::getExpectedMaskType() {
3984  auto vecType = this->getResultVectorType();
3985  return VectorType::get(vecType.getShape(),
3986  IntegerType::get(vecType.getContext(), /*width=*/1),
3987  vecType.getScalableDims());
3988 }
3989 
3990 //===----------------------------------------------------------------------===//
3991 // ExtractStridedSliceOp
3992 //===----------------------------------------------------------------------===//
3993 
3994 // Inference works as follows:
3995 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3996 // 2. Add sizes from 'vectorType' for remaining dims.
3997 // Scalable flags are inherited from 'vectorType'.
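// Worked example (hypothetical types): for vectorType = vector<4x8x16xf32> and
// offsets/sizes/strides of length 2 with sizes = [2, 4], the inferred result
// type is vector<2x4x16xf32> (the trailing dimension is taken from the input).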
3998 static Type inferStridedSliceOpResultType(VectorType vectorType,
3999  ArrayAttr offsets, ArrayAttr sizes,
4000  ArrayAttr strides) {
4001  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4002  SmallVector<int64_t, 4> shape;
4003  shape.reserve(vectorType.getRank());
4004  unsigned idx = 0;
4005  for (unsigned e = offsets.size(); idx < e; ++idx)
4006  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4007  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4008  shape.push_back(vectorType.getShape()[idx]);
4009 
4010  return VectorType::get(shape, vectorType.getElementType(),
4011  vectorType.getScalableDims());
4012 }
4013 
4014 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4015  Value source, ArrayRef<int64_t> offsets,
4016  ArrayRef<int64_t> sizes,
4017  ArrayRef<int64_t> strides) {
4018  result.addOperands(source);
4019  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4020  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4021  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4022  result.addTypes(
4023  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4024  offsetsAttr, sizesAttr, stridesAttr));
4025  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4026  offsetsAttr);
4027  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4028  sizesAttr);
4029  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4030  stridesAttr);
4031 }
4032 
4033 LogicalResult ExtractStridedSliceOp::verify() {
4034  auto type = getSourceVectorType();
4035  auto offsets = getOffsetsAttr();
4036  auto sizes = getSizesAttr();
4037  auto strides = getStridesAttr();
4038  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4039  return emitOpError(
4040  "expected offsets, sizes and strides attributes of same size");
4041 
4042  auto shape = type.getShape();
4043  auto offName = getOffsetsAttrName();
4044  auto sizesName = getSizesAttrName();
4045  auto stridesName = getStridesAttrName();
4046  if (failed(
4047  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4048  failed(
4049  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4050  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4051  stridesName)) ||
4052  failed(
4053  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4054  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4055  /*halfOpen=*/false,
4056  /*min=*/1)) ||
4057  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4058  /*max=*/1, stridesName,
4059  /*halfOpen=*/false)) ||
4060  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4061  shape, offName, sizesName,
4062  /*halfOpen=*/false)))
4063  return failure();
4064 
4065  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4066  offsets, sizes, strides);
4067  if (getResult().getType() != resultType)
4068  return emitOpError("expected result type to be ") << resultType;
4069 
4070  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4071  if (type.getScalableDims()[idx]) {
4072  auto inputDim = type.getShape()[idx];
4073  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4074  if (inputDim != inputSize)
4075  return emitOpError("expected size at idx=")
4076  << idx
4077  << (" to match the corresponding base size from the input "
4078  "vector (")
4079  << inputSize << (" vs ") << inputDim << (")");
4080  }
4081  }
4082 
4083  return success();
4084 }
4085 
4086 // When the source of an ExtractStridedSliceOp comes from a chain of
4087 // InsertStridedSliceOp ops, try to use the source of one of those inserts when
4088 // we can detect that the extracted vector is a subset of an inserted vector.
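// Illustrative example (hypothetical types):
//   %0 = vector.insert_strided_slice %v, %dst {offsets = [8], strides = [1]}
//       : vector<4xf32> into vector<16xf32>
//   %1 = vector.extract_strided_slice %0
//       {offsets = [9], sizes = [2], strides = [1]} : vector<16xf32> to vector<2xf32>
// The extracted chunk lies entirely inside the inserted one, so the extract can
// read directly from %v with its offsets rebased to [1].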
4089 static LogicalResult
4090 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4091  // Helper to extract integer out of ArrayAttr.
4092  auto getElement = [](ArrayAttr array, int idx) {
4093  return llvm::cast<IntegerAttr>(array[idx]).getInt();
4094  };
4095  ArrayAttr extractOffsets = op.getOffsets();
4096  ArrayAttr extractStrides = op.getStrides();
4097  ArrayAttr extractSizes = op.getSizes();
4098  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
4099  while (insertOp) {
4100  if (op.getSourceVectorType().getRank() !=
4101  insertOp.getSourceVectorType().getRank())
4102  return failure();
4103  ArrayAttr insertOffsets = insertOp.getOffsets();
4104  ArrayAttr insertStrides = insertOp.getStrides();
4105  // If the rank of extract is greater than the rank of insert, we are likely
4106  // extracting a partial chunk of the vector inserted.
4107  if (extractOffsets.size() > insertOffsets.size())
4108  return failure();
4109  bool partialOverlap = false;
4110  bool disjoint = false;
4111  SmallVector<int64_t, 4> offsetDiffs;
4112  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4113  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4114  return failure();
4115  int64_t start = getElement(insertOffsets, dim);
4116  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4117  int64_t offset = getElement(extractOffsets, dim);
4118  int64_t size = getElement(extractSizes, dim);
4119  // Check if the start of the extract offset is in the interval inserted.
4120  if (start <= offset && offset < end) {
4121  // If the extract interval overlaps but is not fully included we may
4122  // have a partial overlap that will prevent any folding.
4123  if (offset + size > end)
4124  partialOverlap = true;
4125  offsetDiffs.push_back(offset - start);
4126  continue;
4127  }
4128  disjoint = true;
4129  break;
4130  }
4131  // The extracted chunk is a subset of the inserted chunk.
4132  if (!disjoint && !partialOverlap) {
4133  op.setOperand(insertOp.getValueToStore());
4134  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4135  OpBuilder b(op.getContext());
4136  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4137  return success();
4138  }
4139  // If the chunk extracted is disjoint from the chunk inserted, keep looking
4140  // in the insert chain.
4141  if (disjoint)
4142  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4143  else {
4144  // The extracted vector partially overlap the inserted vector, we cannot
4145  // fold.
4146  return failure();
4147  }
4148  }
4149  return failure();
4150 }
4151 
4152 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4153 static OpFoldResult
4154 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4155  Attribute foldInput) {
4156 
4157  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4158  if (!dense)
4159  return {};
4160 
4161  // TODO: Handle non-unit strides when they become available.
4162  if (op.hasNonUnitStrides())
4163  return {};
4164 
4165  VectorType sourceVecTy = op.getSourceVectorType();
4166  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4167  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4168 
4169  VectorType sliceVecTy = op.getType();
4170  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4171  int64_t rank = sliceVecTy.getRank();
4172 
4173  // Expand offsets and sizes to match the vector rank.
4174  SmallVector<int64_t, 4> offsets(rank, 0);
4175  copy(getI64SubArray(op.getOffsets()), offsets.begin());
4176 
4177  SmallVector<int64_t, 4> sizes(sourceShape);
4178  copy(getI64SubArray(op.getSizes()), sizes.begin());
4179 
4180  // Calculate the slice elements by enumerating all slice positions and
4181  // linearizing them. The enumeration order is lexicographic which yields a
4182  // sequence of monotonically increasing linearized position indices.
4183  const auto denseValuesBegin = dense.value_begin<Attribute>();
4184  SmallVector<Attribute> sliceValues;
4185  sliceValues.reserve(sliceVecTy.getNumElements());
4186  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4187  do {
4188  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4189  assert(linearizedPosition < sourceVecTy.getNumElements() &&
4190  "Invalid index");
4191  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4192  } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4193 
4194  assert(static_cast<int64_t>(sliceValues.size()) ==
4195  sliceVecTy.getNumElements() &&
4196  "Invalid number of slice elements");
4197  return DenseElementsAttr::get(sliceVecTy, sliceValues);
4198 }
4199 
4200 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4201  if (getSourceVectorType() == getResult().getType())
4202  return getVector();
4203  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4204  return getResult();
4205 
4206  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4207  if (auto splat =
4208  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
4209  return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4210 
4211  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4212  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
4213 }
4214 
4215 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4216  populateFromInt64AttrArray(getOffsets(), results);
4217 }
4218 
4219 namespace {
4220 
4221 // Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4222 // CreateMaskOp.
4223 //
4224 // Example:
4225 //
4226 // %mask = vector.create_mask %ub : vector<16xi1>
4227 // %slice = vector.extract_strided_slice [%offset] [8] [1]
4228 //
4229 // to
4230 //
4231 // %new_ub = arith.subi %ub, %offset
4232 // %mask = vector.create_mask %new_ub : vector<8xi1>
4233 class StridedSliceCreateMaskFolder final
4234  : public OpRewritePattern<ExtractStridedSliceOp> {
4235  using OpRewritePattern::OpRewritePattern;
4236 
4237 public:
4238  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4239  PatternRewriter &rewriter) const override {
4240  Location loc = extractStridedSliceOp.getLoc();
4241  // Return if 'extractStridedSliceOp' operand is not defined by a
4242  // CreateMaskOp.
4243  auto createMaskOp =
4244  extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
4245  if (!createMaskOp)
4246  return failure();
4247  // Return if 'extractStridedSliceOp' has non-unit strides.
4248  if (extractStridedSliceOp.hasNonUnitStrides())
4249  return failure();
4250  // Gather constant mask dimension sizes.
4251  SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4252  // Gather strided slice offsets and sizes.
4253  SmallVector<int64_t> sliceOffsets;
4254  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4255  sliceOffsets);
4256  SmallVector<int64_t> sliceSizes;
4257  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4258 
4259  // Compute slice of vector mask region.
4260  SmallVector<Value> sliceMaskDimSizes;
4261  sliceMaskDimSizes.reserve(maskDimSizes.size());
4262  // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4263  // only iterate on the leading dim sizes. The tail accounts for the
4264  // remaining dim sizes.
4265  for (auto [maskDimSize, sliceOffset, sliceSize] :
4266  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4267  // No need to clamp on min/max values, because create_mask has clamping
4268  // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4269  // greater than the vector dim size.
4270  IntegerAttr offsetAttr =
4271  rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4272  Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4273  Value sliceMaskDimSize =
4274  arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4275  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4276  }
4277  // Add unchanged dimensions.
4278  llvm::append_range(
4279  sliceMaskDimSizes,
4280  llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4281  // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4282  // region.
4283  rewriter.replaceOpWithNewOp<CreateMaskOp>(
4284  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4285  sliceMaskDimSizes);
4286  return success();
4287  }
4288 };
4289 
4290 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4291 // ConstantMaskOp.
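// Illustrative example (hypothetical types):
//   %m = vector.constant_mask [8] : vector<16xi1>
//   %s = vector.extract_strided_slice %m
//       {offsets = [4], sizes = [8], strides = [1]} : vector<16xi1> to vector<8xi1>
// becomes
//   %s = vector.constant_mask [4] : vector<8xi1>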
4292 class StridedSliceConstantMaskFolder final
4293  : public OpRewritePattern<ExtractStridedSliceOp> {
4294 public:
4296 
4297  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4298  PatternRewriter &rewriter) const override {
4299  // Return if 'extractStridedSliceOp' operand is not defined by a
4300  // ConstantMaskOp.
4301  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
4302  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4303  if (!constantMaskOp)
4304  return failure();
4305  // Return if 'extractStridedSliceOp' has non-unit strides.
4306  if (extractStridedSliceOp.hasNonUnitStrides())
4307  return failure();
4308  // Gather constant mask dimension sizes.
4309  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4310  // Gather strided slice offsets and sizes.
4311  SmallVector<int64_t> sliceOffsets;
4312  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4313  sliceOffsets);
4314  SmallVector<int64_t> sliceSizes;
4315  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4316 
4317  // Compute slice of vector mask region.
4318  SmallVector<int64_t> sliceMaskDimSizes;
4319  sliceMaskDimSizes.reserve(maskDimSizes.size());
4320  for (auto [maskDimSize, sliceOffset, sliceSize] :
4321  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4322  int64_t sliceMaskDimSize = std::max(
4323  static_cast<int64_t>(0),
4324  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4325  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4326  }
4327  // Add unchanged dimensions.
4328  if (sliceMaskDimSizes.size() < maskDimSizes.size())
4329  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4330  sliceMaskDimSizes.push_back(maskDimSizes[i]);
4331  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4332  // region is a conjunction of mask dim intervals).
4333  if (llvm::is_contained(sliceMaskDimSizes, 0))
4334  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4335 
4336  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4337  // region.
4338  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4339  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4340  sliceMaskDimSizes);
4341  return success();
4342  }
4343 };
4344 
4345 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4346 // BroadcastOp(ExtractStridedSliceOp).
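//
// For illustration (shapes chosen arbitrarily, not taken from the original
// source):
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
//       : vector<2x4xf32> to vector<2x2xf32>
//
// becomes
//
//   %e = vector.extract_strided_slice %src
//       {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
//   %s = vector.broadcast %e : vector<2xf32> to vector<2x2xf32>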
4347 class StridedSliceBroadcast final
4348  : public OpRewritePattern<ExtractStridedSliceOp> {
4349 public:
4351 
4352  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4353  PatternRewriter &rewriter) const override {
4354  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
4355  if (!broadcast)
4356  return failure();
4357  auto srcVecType =
4358  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4359  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4360  auto dstVecType = llvm::cast<VectorType>(op.getType());
4361  unsigned dstRank = dstVecType.getRank();
4362  unsigned rankDiff = dstRank - srcRank;
4363  // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4364  // (n -> m with n > m). If they are originally both broadcasted *and*
4365  // sliced, this can be simplified to just broadcasting.
4366  bool needsSlice = false;
4367  for (unsigned i = 0; i < srcRank; i++) {
4368  if (srcVecType.getDimSize(i) != 1 &&
4369  srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4370  needsSlice = true;
4371  break;
4372  }
4373  }
4374  Value source = broadcast.getSource();
4375  if (needsSlice) {
4376  SmallVector<int64_t> offsets =
4377  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4378  SmallVector<int64_t> sizes =
4379  getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4380  for (unsigned i = 0; i < srcRank; i++) {
4381  if (srcVecType.getDimSize(i) == 1) {
4382  // In case this dimension was broadcasted *and* sliced, the offset
4383  // and size need to be updated now that there is no broadcast before
4384  // the slice.
4385  offsets[i] = 0;
4386  sizes[i] = 1;
4387  }
4388  }
4389  source = ExtractStridedSliceOp::create(
4390  rewriter, op->getLoc(), source, offsets, sizes,
4391  getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4392  }
4393  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4394  return success();
4395  }
4396 };
4397 
4398 /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
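///
/// For illustration (shapes chosen arbitrarily, not taken from the original
/// source):
///
///   %splat = vector.broadcast %scalar : f32 to vector<8xf32>
///   %slice = vector.extract_strided_slice %splat
///       {offsets = [2], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
///
/// becomes
///
///   %slice = vector.broadcast %scalar : f32 to vector<4xf32>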
4399 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4400 public:
4402 
4403  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4404  PatternRewriter &rewriter) const override {
4405 
4406  Value splat = getScalarSplatSource(op.getVector());
4407  if (!splat)
4408  return failure();
4409  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4410  return success();
4411  }
4412 };
4413 
4414 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4415 /// slice is contiguous, into extract and shape_cast.
4416 ///
4417 /// Example:
4418 /// Before:
4419 /// %1 = vector.extract_strided_slice %arg0 {
4420 /// offsets = [0, 0, 0, 0, 0],
4421 /// sizes = [1, 1, 1, 1, 8],
4422 /// strides = [1, 1, 1, 1, 1]
4423 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4424 /// After:
4425 /// %0 = vector.extract %arg0[0, 0, 0, 0]
4426 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
4427 /// %1 = vector.shape_cast %0
4428 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
4429 ///
4430 class ContiguousExtractStridedSliceToExtract final
4431  : public OpRewritePattern<ExtractStridedSliceOp> {
4432 public:
4434 
4435  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4436  PatternRewriter &rewriter) const override {
4437  if (op.hasNonUnitStrides())
4438  return failure();
4439  Value source = op.getOperand();
4440  auto sourceType = cast<VectorType>(source.getType());
4441  if (sourceType.isScalable() || sourceType.getRank() == 0)
4442  return failure();
4443 
4444  // Compute the number of offsets to pass to ExtractOp::build. That is the
4445  // difference between the source rank and the desired slice rank. We walk
4446  // the dimensions from innermost out, and stop when the next slice dimension
4447  // is not full-size.
4448  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4449  int numOffsets;
4450  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4451  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4452  break;
4453  }
4454 
4455  // If the created extract op would have no offsets, then this whole
4456  // extract_strided_slice is the identity and should have been handled by
4457  // other canonicalizations.
4458  if (numOffsets == 0)
4459  return failure();
4460 
4461  // If not even the inner-most dimension is full-size, this op can't be
4462  // rewritten as an ExtractOp.
4463  if (numOffsets == sourceType.getRank() &&
4464  static_cast<int>(sizes.size()) == sourceType.getRank())
4465  return failure();
4466 
4467  // The outer dimensions must have unit size.
4468  for (int i = 0; i < numOffsets; ++i) {
4469  if (sizes[i] != 1)
4470  return failure();
4471  }
4472 
4473  // Avoid generating slices that have leading unit dimensions. The shape_cast
4474  // op created below would otherwise be lowered through the inefficient generic
4475  // fallback pattern (ShapeCastOpRewritePattern).
4476  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4477  sizes[numOffsets] == 1) {
4478  ++numOffsets;
4479  }
4480 
4481  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4482  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4483  Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4484  extractOffsets);
4485  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4486  return success();
4487  }
4488 };
4489 
4490 } // namespace
4491 
4492 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4493  RewritePatternSet &results, MLIRContext *context) {
4494  // Patterns to simplify ExtractStridedSliceOp of CreateMaskOp, ConstantMaskOp,
4495  // BroadcastOp, splats, and contiguous slices that reduce to ExtractOp.
4496  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4497  StridedSliceBroadcast, StridedSliceSplat,
4498  ContiguousExtractStridedSliceToExtract>(context);
4499 }
4500 
4501 //===----------------------------------------------------------------------===//
4502 // TransferReadOp
4503 //===----------------------------------------------------------------------===//
4504 
4505 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4506 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4507  VectorType vectorType, Value source,
4508  ValueRange indices, std::optional<Value> padding,
4509  AffineMapAttr permutationMapAttr,
4510  /*optional*/ ArrayAttr inBoundsAttr) {
4511 
4512  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4513  if (!padding)
4514  padding = ub::PoisonOp::create(builder, result.location, elemType);
4515  build(builder, result, vectorType, source, indices, permutationMapAttr,
4516  *padding, /*mask=*/Value(), inBoundsAttr);
4517 }
4518 
4519 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4520 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4521  VectorType vectorType, Value source,
4522  ValueRange indices, std::optional<Value> padding,
4523  AffineMap permutationMap,
4524  std::optional<ArrayRef<bool>> inBounds) {
4525  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4526  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4527  ? builder.getBoolArrayAttr(inBounds.value())
4528  : builder.getBoolArrayAttr(
4529  SmallVector<bool>(vectorType.getRank(), false));
4530  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4531  if (!padding)
4532  padding = ub::PoisonOp::create(builder, result.location, elemType);
4533  build(builder, result, vectorType, source, indices, *padding,
4534  permutationMapAttr, inBoundsAttr);
4535 }
4536 
4537 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4538 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4539  VectorType vectorType, Value source,
4540  ValueRange indices, std::optional<Value> padding,
4541  std::optional<ArrayRef<bool>> inBounds) {
4542  AffineMap permutationMap = getTransferMinorIdentityMap(
4543  llvm::cast<ShapedType>(source.getType()), vectorType);
4544  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4545  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4546  ? builder.getBoolArrayAttr(inBounds.value())
4547  : builder.getBoolArrayAttr(
4548  SmallVector<bool>(vectorType.getRank(), false));
4549  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4550  if (!padding)
4551  padding = ub::PoisonOp::create(builder, result.location, elemType);
4552  build(builder, result, vectorType, source, indices, permutationMapAttr,
4553  *padding,
4554  /*mask=*/Value(), inBoundsAttr);
4555 }
4556 
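/// For illustration (example maps chosen by the editor, not from the original
/// source): affine_map<(d0, d1, d2) -> (d2, 0)> is a valid projected
/// permutation for a transfer op, while affine_map<(d0, d1) -> (d0, d0)> is
/// rejected because d0 appears twice, and affine_map<(d0, d1) -> (d0 + d1)> is
/// rejected because the result is neither a single dim nor the zero constant.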
4557 template <typename EmitFun>
4558 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4559  EmitFun emitOpError) {
4560  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4561  for (auto expr : permutationMap.getResults()) {
4562  auto dim = dyn_cast<AffineDimExpr>(expr);
4563  auto zero = dyn_cast<AffineConstantExpr>(expr);
4564  if (zero) {
4565  if (zero.getValue() != 0) {
4566  return emitOpError(
4567  "requires a projected permutation_map (at most one dim or the zero "
4568  "constant can appear in each result)");
4569  }
4570  continue;
4571  }
4572  if (!dim) {
4573  return emitOpError("requires a projected permutation_map (at most one "
4574  "dim or the zero constant can appear in each result)");
4575  }
4576  if (seen[dim.getPosition()]) {
4577  return emitOpError(
4578  "requires a permutation_map that is a permutation (found one dim "
4579  "used more than once)");
4580  }
4581  seen[dim.getPosition()] = true;
4582  }
4583  return success();
4584 }
4585 
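/// For illustration of the vector-element-type case checked below (types
/// assumed for the example, not taken from the original source): a transfer of
/// vector<4x8xf32> from memref<16xvector<8xf32>> passes the bitwidth check
/// (the minor 1-D parts are 8 x f32 on both sides), whereas vector<4x6xf32>
/// fails because 6 x 32 bits is not an integral multiple of 8 x 32 bits.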
4586 static LogicalResult
4587 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4588  VectorType vectorType, VectorType maskType,
4589  VectorType inferredMaskType, AffineMap permutationMap,
4590  ArrayAttr inBounds) {
4591  if (op->hasAttr("masked")) {
4592  return op->emitOpError("masked attribute has been removed. "
4593  "Use in_bounds instead.");
4594  }
4595 
4596  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4597  return op->emitOpError(
4598  "requires source to be a memref or ranked tensor type");
4599 
4600  auto elementType = shapedType.getElementType();
4601  DataLayout dataLayout = DataLayout::closest(op);
4602  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4603  // Memref or tensor has vector element type.
4604  unsigned sourceVecSize =
4605  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4606  vectorElementType.getShape().back();
4607  unsigned resultVecSize =
4608  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4609  vectorType.getShape().back();
4610  if (resultVecSize % sourceVecSize != 0)
4611  return op->emitOpError(
4612  "requires the bitwidth of the minor 1-D vector to be an integral "
4613  "multiple of the bitwidth of the minor 1-D vector of the source");
4614 
4615  unsigned sourceVecEltRank = vectorElementType.getRank();
4616  unsigned resultVecRank = vectorType.getRank();
4617  if (sourceVecEltRank > resultVecRank)
4618  return op->emitOpError(
4619  "requires source vector element and vector result ranks to match.");
4620  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4621  // Check that permutation map results match 'rankOffset' of vector type.
4622  if (permutationMap.getNumResults() != rankOffset)
4623  return op->emitOpError("requires a permutation_map with result dims of "
4624  "the same rank as the vector type");
4625 
4626  if (maskType)
4627  return op->emitOpError("does not support masks with vector element type");
4628  } else {
4629  // Memref or tensor has scalar element type.
4630  unsigned minorSize =
4631  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4632  unsigned resultVecSize =
4633  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4634  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4635  return op->emitOpError(
4636  "requires the bitwidth of the minor 1-D vector to be an integral "
4637  "multiple of the bitwidth of the source element type");
4638 
4639  // Check that permutation map results match rank of vector type.
4640  if (permutationMap.getNumResults() != vectorType.getRank())
4641  return op->emitOpError("requires a permutation_map with result dims of "
4642  "the same rank as the vector type");
4643  }
4644 
4645  if (permutationMap.getNumSymbols() != 0)
4646  return op->emitOpError("requires permutation_map without symbols");
4647 
4648  if (permutationMap.getNumInputs() != shapedType.getRank())
4649  return op->emitOpError("requires a permutation_map with input dims of the "
4650  "same rank as the source type");
4651 
4652  if (maskType && maskType != inferredMaskType)
4653  return op->emitOpError("inferred mask type (")
4654  << inferredMaskType << ") and mask operand type (" << maskType
4655  << ") don't match";
4656 
4657  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4658  return op->emitOpError("expects the in_bounds attr of same rank "
4659  "as permutation_map results: ")
4660  << AffineMapAttr::get(permutationMap)
4661  << " vs inBounds of size: " << inBounds.size();
4662 
4663  return success();
4664 }
4665 
4666 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4667  SmallVector<StringRef, 3> elidedAttrs;
4668  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4669  if (op.getPermutationMap().isMinorIdentity())
4670  elidedAttrs.push_back(op.getPermutationMapAttrName());
4671  // Elide in_bounds attribute if all dims are out-of-bounds.
4672  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4673  elidedAttrs.push_back(op.getInBoundsAttrName());
4674  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4675 }
4676 
4677 void TransferReadOp::print(OpAsmPrinter &p) {
4678  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4679  if (getMask())
4680  p << ", " << getMask();
4681  printTransferAttrs(p, *this);
4682  p << " : " << getShapedType() << ", " << getVectorType();
4683 }
4684 
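/// For illustration (an assumed example, not from the original source): with
/// permutation_map = affine_map<(d0, d1) -> (d1, d0)> and vector type
/// vector<4x8xf32>, the inferred mask type is vector<8x4xi1>, i.e. the vector
/// shape permuted back through the (compressed) permutation map.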
4685 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4686  AffineMap permMap) {
4687  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4688  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4689  assert(invPermMap && "Inversed permutation map couldn't be computed");
4690  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4691 
4692  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4693  // 0-D mask into a single-element 1-D mask.
4694  if (maskShape.empty())
4695  maskShape.push_back(1);
4696 
4697  SmallVector<bool> scalableDims =
4698  applyPermutationMap(invPermMap, vecType.getScalableDims());
4699 
4700  return VectorType::get(maskShape, i1Type, scalableDims);
4701 }
4702 
4703 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4704  auto &builder = parser.getBuilder();
4705  SMLoc typesLoc;
4706  OpAsmParser::UnresolvedOperand sourceInfo;
4707  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4708  OpAsmParser::UnresolvedOperand paddingInfo;
4709  SmallVector<Type, 2> types;
4710  OpAsmParser::UnresolvedOperand maskInfo;
4711  // Parsing with support for paddingValue.
4712  if (parser.parseOperand(sourceInfo) ||
4713  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4714  parser.parseComma() || parser.parseOperand(paddingInfo))
4715  return failure();
4716  ParseResult hasMask = parser.parseOptionalComma();
4717  if (hasMask.succeeded()) {
4718  if (parser.parseOperand(maskInfo))
4719  return failure();
4720  }
4721  if (parser.parseOptionalAttrDict(result.attributes) ||
4722  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4723  return failure();
4724  if (types.size() != 2)
4725  return parser.emitError(typesLoc, "requires two types");
4726  auto indexType = builder.getIndexType();
4727  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4728  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4729  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4730  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4731  if (!vectorType)
4732  return parser.emitError(typesLoc, "requires vector type");
4733  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4734  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4735  AffineMap permMap;
4736  if (!permMapAttr) {
4737  if (shapedType.getRank() <
4738  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4739  return parser.emitError(typesLoc,
4740  "expected a custom permutation_map when "
4741  "rank(source) != rank(destination)");
4742  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4743  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4744  } else {
4745  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4746  }
4747  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4748  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4749  if (!inBoundsAttr) {
4750  result.addAttribute(inBoundsAttrName,
4751  builder.getBoolArrayAttr(
4752  SmallVector<bool>(permMap.getNumResults(), false)));
4753  }
4754  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4755  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4756  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4757  result.operands))
4758  return failure();
4759  if (hasMask.succeeded()) {
4760  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4761  return parser.emitError(
4762  maskInfo.location, "does not support masks with vector element type");
4763  if (vectorType.getRank() != permMap.getNumResults()) {
4764  return parser.emitError(typesLoc,
4765  "expected the same rank for the vector and the "
4766  "results of the permutation map");
4767  }
4768  // Instead of adding the mask type as an op type, compute it based on the
4769  // vector type and the permutation map (to keep the type signature small).
4770  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4771  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4772  return failure();
4773  }
4774  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4775  builder.getDenseI32ArrayAttr(
4776  {1, static_cast<int32_t>(indexInfo.size()), 1,
4777  static_cast<int32_t>(hasMask.succeeded())}));
4778  return parser.addTypeToList(vectorType, result.types);
4779 }
4780 
4781 LogicalResult TransferReadOp::verify() {
4782  // Consistency of elemental types in source and vector.
4783  ShapedType shapedType = getShapedType();
4784  VectorType vectorType = getVectorType();
4785  VectorType maskType = getMaskType();
4786  auto paddingType = getPadding().getType();
4787  auto permutationMap = getPermutationMap();
4788  VectorType inferredMaskType =
4789  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4790  : VectorType();
4791  auto sourceElementType = shapedType.getElementType();
4792 
4793  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4794  return emitOpError("requires ") << shapedType.getRank() << " indices";
4795 
4796  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4797  shapedType, vectorType, maskType,
4798  inferredMaskType, permutationMap, getInBounds())))
4799  return failure();
4800 
4801  if (auto sourceVectorElementType =
4802  llvm::dyn_cast<VectorType>(sourceElementType)) {
4803  // Source has vector element type.
4804  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4805  if (sourceVectorElementType != paddingType)
4806  return emitOpError(
4807  "requires source element type and padding type to match.");
4808 
4809  } else {
4810  // Check that 'paddingType' is valid to store in a vector type.
4811  if (!VectorType::isValidElementType(paddingType))
4812  return emitOpError("requires valid padding vector elemental type");
4813 
4814  // Check that padding type and vector element types match.
4815  if (paddingType != sourceElementType)
4816  return emitOpError(
4817  "requires formal padding and source of the same elemental type");
4818  }
4819 
4820  return verifyPermutationMap(permutationMap,
4821  [&](Twine t) { return emitOpError(t); });
4822 }
4823 
4824 // MaskableOpInterface methods.
4825 
4826 /// Returns the mask type expected by this operation. Mostly used for
4827 /// verification purposes. It requires the operation to be vectorized.
4828 Type TransferReadOp::getExpectedMaskType() {
4829  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4830 }
4831 
4832 //===----------------------------------------------------------------------===//
4833 // TransferReadOp: VectorTransferOpInterface methods.
4834 //===----------------------------------------------------------------------===//
4835 VectorType TransferReadOp::getVectorType() {
4836  return cast<VectorType>(getVector().getType());
4837 }
4838 
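/// For illustration (assumed shapes, not from the original source): a read of
/// vector<4xf32> from memref<8xf32> at constant index 4 is provably in bounds
/// (4 + 4 <= 8), while the same read at constant index 5 is not (5 + 4 > 8).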
4839 template <typename TransferOp>
4840 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4841  // TODO: support more aggressive createOrFold on:
4842  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4843  if (op.getShapedType().isDynamicDim(indicesIdx))
4844  return false;
4845  Value index = op.getIndices()[indicesIdx];
4846  std::optional<int64_t> cstOp = getConstantIntValue(index);
4847  if (!cstOp.has_value())
4848  return false;
4849 
4850  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4851  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4852 
4853  return cstOp.value() + vectorSize <= sourceSize;
4854 }
4855 
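/// Folds the in_bounds attribute of a transfer op to true for every dimension
/// that can be statically proven in bounds. For illustration (assumed shapes,
/// not from the original source):
///
///   %r = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<4xf32>
///
/// gets in_bounds = [true], since 0 + 4 <= 8.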
4856 template <typename TransferOp>
4857 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4858  // TODO: support 0-d corner case.
4859  // TODO: Be less conservative.
4860  if (op.getTransferRank() == 0)
4861  return failure();
4862  AffineMap permutationMap = op.getPermutationMap();
4863  bool changed = false;
4864  SmallVector<bool, 4> newInBounds;
4865  newInBounds.reserve(op.getTransferRank());
4866  // Idxs of non-bcast dims - used when analysing bcast dims.
4867  SmallVector<unsigned> nonBcastDims;
4868 
4869  // 1. Process non-broadcast dims
4870  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4871  // 1.1. Already marked as in-bounds, nothing to see here.
4872  if (op.isDimInBounds(i)) {
4873  newInBounds.push_back(true);
4874  continue;
4875  }
4876  // 1.2. Currently out-of-bounds, check whether we can statically determine
4877  // it is inBounds.
4878  bool inBounds = false;
4879  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4880  if (dimExpr) {
4881  inBounds = isInBounds(op, /*resultIdx=*/i,
4882  /*indicesIdx=*/dimExpr.getPosition());
4883  nonBcastDims.push_back(i);
4884  }
4885 
4886  newInBounds.push_back(inBounds);
4887  // We commit the pattern if it is "more inbounds".
4888  changed |= inBounds;
4889  }
4890 
4891  // 2. Handle broadcast dims
4892  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4893  // "in bounds" as well.
4894  bool allNonBcastDimsInBounds = llvm::all_of(
4895  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4896  if (allNonBcastDimsInBounds) {
4897  for (size_t idx : permutationMap.getBroadcastDims()) {
4898  changed |= !newInBounds[idx];
4899  newInBounds[idx] = true;
4900  }
4901  }
4902 
4903  if (!changed)
4904  return failure();
4905  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4906  OpBuilder b(op.getContext());
4907  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4908  return success();
4909 }
4910 
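/// Removes the mask operand of a transfer op when the mask is statically known
/// to be all-true. For illustration (assumed shapes, not from the original
/// source):
///
///   %mask = vector.constant_mask [4] : vector<4xi1>
///   %r = vector.transfer_read %m[%c0], %pad, %mask
///       : memref<8xf32>, vector<4xf32>
///
/// folds to the equivalent unmasked transfer_read.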
4911 template <typename TransferOp>
4912 static LogicalResult foldTransferFullMask(TransferOp op) {
4913  auto mask = op.getMask();
4914  if (!mask)
4915  return failure();
4916 
4917  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4918  return failure();
4919 
4920  op.getMaskMutable().clear();
4921  return success();
4922 }
4923 
4924 /// ```
4925 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4926 /// : vector<1x4xf32>, tensor<4x4xf32>
4927 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4928 /// : tensor<4x4xf32>, vector<1x4xf32>
4929 /// ```
4930 /// -> Folds into
4931 /// ```
4932 /// %v0
4933 /// ```
4934 static Value foldRAW(TransferReadOp readOp) {
4935  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4936  return {};
4937  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4938  while (defWrite) {
4939  if (checkSameValueRAW(defWrite, readOp))
4940  return defWrite.getVector();
4941  if (!isDisjointTransferIndices(
4942  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4943  cast<VectorTransferOpInterface>(readOp.getOperation())))
4944  break;
4945  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4946  }
4947  return {};
4948 }
4949 
4950 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4951  if (Value vec = foldRAW(*this))
4952  return vec;
4953  /// transfer_read(memrefcast) -> transfer_read
4954  if (succeeded(foldTransferInBoundsAttribute(*this)))
4955  return getResult();
4956  if (succeeded(foldTransferFullMask(*this)))
4957  return getResult();
4958  if (succeeded(memref::foldMemRefCast(*this)))
4959  return getResult();
4960  if (succeeded(tensor::foldTensorCast(*this)))
4961  return getResult();
4962  return OpFoldResult();
4963 }
4964 
4965 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4966  return llvm::to_vector<4>(getVectorType().getShape());
4967 }
4968 
4969 void TransferReadOp::getEffects(
4970  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4971  &effects) {
4972  if (llvm::isa<MemRefType>(getShapedType()))
4973  effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
4974  SideEffects::DefaultResource::get());
4975 }
4976 
4977 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4978  if (hasPureTensorSemantics())
4979  return Speculation::Speculatable;
4980  return Speculation::NotSpeculatable;
4981 }
4982 
4983 namespace {
4984 /// Store to load forwarding for transfer operations with permutation maps.
4985 /// Even if the permutation maps are different we can still propagate the store
4986 /// into the load if the size of the dimensions read and written match. Then we
4987 /// can replace the transfer_read + transfer_write by vector.broadcast and
4988 /// vector.transpose.
4989 /// Example:
4990 /// ```
4991 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4992 /// {in_bounds = [true, true],
4993 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4994 /// vector<4x1xf32>, tensor<4x4x4xf32>
4995 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4996 /// {in_bounds = [true, true, true, true],
4997 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4998 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4999 /// ```
5000 /// To:
5001 /// ```
5002 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
5003 /// %r = vector.transpose %0, [3, 0, 2, 1] :
5004 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5005 /// ```
5006 struct TransferReadAfterWriteToBroadcast
5007  : public OpRewritePattern<TransferReadOp> {
5009 
5010  LogicalResult matchAndRewrite(TransferReadOp readOp,
5011  PatternRewriter &rewriter) const override {
5012  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5013  if (!defWrite)
5014  return failure();
5015  // Bail if we need an alias analysis.
5016  if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5017  return failure();
5018  // Bail if we need a bounds analysis.
5019  if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5020  return failure();
5021  // TODO: If the written transfer chunk is a superset of the read transfer
5022  // chunk we could do an extract_strided_slice.
5023  if (readOp.getTransferChunkAccessed() !=
5024  defWrite.getTransferChunkAccessed())
5025  return failure();
5026  // TODO: Support cases where a dim is explicitly written but implicitly
5027  // read (i.e., a unit dim that is rank reduced).
5028  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
5029  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
5030  return failure();
5031  // This pattern should only catch the broadcast case; the non-broadcast case
5032  // should be handled separately to keep the application conditions clean and
5033  // separate.
5034  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
5035  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
5036  bool bcast = !readMap.getBroadcastDims().empty() ||
5037  !writeMap.getBroadcastDims().empty();
5038  if (!bcast)
5039  return failure();
5040  // At this point, we know we have a bcast.
5041  // Bail in the masked case (too complex atm and needed to properly account
5042  // for padding).
5043  if (readOp.getMask() || defWrite.getMask())
5044  return failure();
5045  // If indices are not the same a shift may be required, bail.
5046  if (readOp.getIndices() != defWrite.getIndices())
5047  return failure();
5048 
5049  Value vec = defWrite.getVector();
5050  // TODO: loop through the chain of transfer_write if we can prove that they
5051  // don't overlap with the transfer_read. This requires improving
5052  // `isDisjointTransferIndices` helper.
5053  AffineMap map = readMap.compose(writeMap);
5054  if (map.getNumResults() == 0)
5055  return failure();
5056  // Calculate the permutation to apply to go from the vector stored to the
5057  // vector read.
5058  SmallVector<unsigned> permutation;
5059  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
5060  return failure();
5061 
5062  Location loc = readOp.getLoc();
5063  // Calculate the broadcast shape by applying the reverse permutation to the
5064  // final shape we want.
5065  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
5066  SmallVector<int64_t> broadcastShape(destShape.size());
5067  SmallVector<bool> broadcastScalableFlags(destShape.size());
5068  for (const auto &pos : llvm::enumerate(permutation)) {
5069  broadcastShape[pos.value()] = destShape[pos.index()];
5070  broadcastScalableFlags[pos.value()] =
5071  readOp.getVectorType().getScalableDims()[pos.index()];
5072  }
5073  VectorType broadcastedType = VectorType::get(
5074  broadcastShape, defWrite.getVectorType().getElementType(),
5075  broadcastScalableFlags);
5076  vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
5077  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
5078  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
5079  transposePerm);
5080  return success();
5081  }
5082 };
5083 } // namespace
5084 
5085 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5086  MLIRContext *context) {
5087  results.add<TransferReadAfterWriteToBroadcast>(context);
5088 }
5089 
5090 //===----------------------------------------------------------------------===//
5091 // TransferWriteOp
5092 //===----------------------------------------------------------------------===//
5093 
5094 /// 1. Builder with type inference.
5095 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5096  Value vector, Value dest, ValueRange indices,
5097  AffineMapAttr permutationMapAttr,
5098  /*optional*/ Value mask,
5099  /*optional*/ ArrayAttr inBoundsAttr) {
5100  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5101  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5102  mask, inBoundsAttr);
5103 }
5104 
5105 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
5106 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5107  Value vector, Value dest, ValueRange indices,
5108  AffineMapAttr permutationMapAttr,
5109  /*optional*/ ArrayAttr inBoundsAttr) {
5110  build(builder, result, vector, dest, indices, permutationMapAttr,
5111  /*mask=*/Value(), inBoundsAttr);
5112 }
5113 
5114 /// 3. Builder with type inference that sets an empty mask (variant without
5115 /// attrs)
5116 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5117  Value vector, Value dest, ValueRange indices,
5118  AffineMap permutationMap,
5119  std::optional<ArrayRef<bool>> inBounds) {
5120  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5121  auto inBoundsAttr =
5122  (inBounds && !inBounds.value().empty())
5123  ? builder.getBoolArrayAttr(inBounds.value())
5124  : builder.getBoolArrayAttr(SmallVector<bool>(
5125  llvm::cast<VectorType>(vector.getType()).getRank(), false));
5126  build(builder, result, vector, dest, indices, permutationMapAttr,
5127  /*mask=*/Value(), inBoundsAttr);
5128 }
5129 
5130 /// 4. Builder with type inference that sets an empty mask and sets permutation
5131 /// map to 'getMinorIdentityMap'.
5132 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5133  Value vector, Value dest, ValueRange indices,
5134  std::optional<ArrayRef<bool>> inBounds) {
5135  auto vectorType = llvm::cast<VectorType>(vector.getType());
5136  AffineMap permutationMap = getTransferMinorIdentityMap(
5137  llvm::cast<ShapedType>(dest.getType()), vectorType);
5138  build(builder, result, vector, dest, indices, permutationMap, inBounds);
5139 }
5140 
5141 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5142  OperationState &result) {
5143  auto &builder = parser.getBuilder();
5144  SMLoc typesLoc;
5145  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5146  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5147  SmallVector<Type, 2> types;
5148  OpAsmParser::UnresolvedOperand maskInfo;
5149  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5150  parser.parseOperand(sourceInfo) ||
5151  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5152  return failure();
5153  ParseResult hasMask = parser.parseOptionalComma();
5154  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5155  return failure();
5156  if (parser.parseOptionalAttrDict(result.attributes) ||
5157  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5158  return failure();
5159  if (types.size() != 2)
5160  return parser.emitError(typesLoc, "requires two types");
5161  auto indexType = builder.getIndexType();
5162  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5163  if (!vectorType)
5164  return parser.emitError(typesLoc, "requires vector type");
5165  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5166  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5167  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5168  auto permMapAttrName =
5169  TransferWriteOp::getPermutationMapAttrName(result.name);
5170  auto permMapAttr = result.attributes.get(permMapAttrName);
5171  AffineMap permMap;
5172  if (!permMapAttr) {
5173  if (shapedType.getRank() <
5174  getEffectiveVectorRankForXferOp(shapedType, vectorType))
5175  return parser.emitError(typesLoc,
5176  "expected a custom permutation_map when "
5177  "rank(source) != rank(destination)");
5178  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5179  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5180  } else {
5181  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5182  }
5183  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5184  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5185  if (!inBoundsAttr) {
5186  result.addAttribute(inBoundsAttrName,
5187  builder.getBoolArrayAttr(
5188  SmallVector<bool>(permMap.getNumResults(), false)));
5189  }
5190  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5191  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5192  parser.resolveOperands(indexInfo, indexType, result.operands))
5193  return failure();
5194  if (hasMask.succeeded()) {
5195  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5196  return parser.emitError(
5197  maskInfo.location, "does not support masks with vector element type");
5198  if (vectorType.getRank() != permMap.getNumResults()) {
5199  return parser.emitError(typesLoc,
5200  "expected the same rank for the vector and the "
5201  "results of the permutation map");
5202  }
5203  auto maskType = inferTransferOpMaskType(vectorType, permMap);
5204  if (parser.resolveOperand(maskInfo, maskType, result.operands))
5205  return failure();
5206  }
5207  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5208  builder.getDenseI32ArrayAttr(
5209  {1, 1, static_cast<int32_t>(indexInfo.size()),
5210  static_cast<int32_t>(hasMask.succeeded())}));
5211  return failure(llvm::isa<RankedTensorType>(shapedType) &&
5212  parser.addTypeToList(shapedType, result.types));
5213 }
5214 
5215 void TransferWriteOp::print(OpAsmPrinter &p) {
5216  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5217  if (getMask())
5218  p << ", " << getMask();
5219  printTransferAttrs(p, *this);
5220  p << " : " << getVectorType() << ", " << getShapedType();
5221 }
5222 
5223 LogicalResult TransferWriteOp::verify() {
5224  // Consistency of elemental types in shape and vector.
5225  ShapedType shapedType = getShapedType();
5226  VectorType vectorType = getVectorType();
5227  VectorType maskType = getMaskType();
5228  auto permutationMap = getPermutationMap();
5229  VectorType inferredMaskType =
5230  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5231  : VectorType();
5232 
5233  if (llvm::size(getIndices()) != shapedType.getRank())
5234  return emitOpError("requires ") << shapedType.getRank() << " indices";
5235 
5236  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5237  // as the semantics is unclear. This can be revisited later if necessary.
5238  if (hasBroadcastDim())
5239  return emitOpError("should not have broadcast dimensions");
5240 
5241  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5242  shapedType, vectorType, maskType,
5243  inferredMaskType, permutationMap, getInBounds())))
5244  return failure();
5245 
5246  return verifyPermutationMap(permutationMap,
5247  [&](Twine t) { return emitOpError(t); });
5248 }
5249 
5250 //===----------------------------------------------------------------------===//
5251 // TransferWriteOp: MaskableOpInterface methods.
5252 //===----------------------------------------------------------------------===//
5253 
5254 /// Returns the mask type expected by this operation. Mostly used for
5255 /// verification purposes.
5256 Type TransferWriteOp::getExpectedMaskType() {
5257  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5258 }
5259 
5260 //===----------------------------------------------------------------------===//
5261 // TransferWriteOp: VectorTransferOpInterface methods.
5262 //===----------------------------------------------------------------------===//
5263 Value TransferWriteOp::getVector() { return getOperand(0); }
5264 VectorType TransferWriteOp::getVectorType() {
5265  return cast<VectorType>(getValueToStore().getType());
5266 }
5267 
5268 //===----------------------------------------------------------------------===//
5269 // TransferWriteOp: fold methods.
5270 //===----------------------------------------------------------------------===//
5271 /// Fold:
5272 /// ```
5273 /// %t1 = ...
5274 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5275 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5276 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5277 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5278 /// ```
5279 ///
5280 /// into:
5281 ///
5282 /// ```
5283 /// %t0
5284 /// ```
5285 ///
5286 /// The producer of t1 may or may not be DCE'd depending on whether it is a
5287 /// block argument or has side effects.
5288 static LogicalResult foldReadInitWrite(TransferWriteOp write,
5289  ArrayRef<Attribute>,
5290  SmallVectorImpl<OpFoldResult> &results) {
5291  // TODO: support 0-d corner case.
5292  if (write.getTransferRank() == 0)
5293  return failure();
5294  auto rankedTensorType =
5295  llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5296  // If not operating on tensors, bail.
5297  if (!rankedTensorType)
5298  return failure();
5299  // If no read, bail.
5300  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5301  if (!read)
5302  return failure();
5303  // TODO: support 0-d corner case.
5304  if (read.getTransferRank() == 0)
5305  return failure();
5306  // For now, only accept minor identity maps. Future: also accept maps whose composition is a minor identity.
5307  if (!read.getPermutationMap().isMinorIdentity() ||
5308  !write.getPermutationMap().isMinorIdentity())
5309  return failure();
5310  // Bail on mismatching ranks.
5311  if (read.getTransferRank() != write.getTransferRank())
5312  return failure();
5313  // Bail on potential out-of-bounds accesses.
5314  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5315  return failure();
5316  // Tensor types must be the same.
5317  if (read.getBase().getType() != rankedTensorType)
5318  return failure();
5319  // Vector types must be the same.
5320  if (read.getVectorType() != write.getVectorType())
5321  return failure();
5322  // Vector and Tensor shapes must match.
5323  if (read.getVectorType().getShape() != rankedTensorType.getShape())
5324  return failure();
5325  // Bail if any index is nonzero.
5326  auto isNotConstantZero = [](Value v) {
5327  auto cstOp = getConstantIntValue(v);
5328  return !cstOp.has_value() || cstOp.value() != 0;
5329  };
5330  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5331  llvm::any_of(write.getIndices(), isNotConstantZero))
5332  return failure();
5333  // Success.
5334  results.push_back(read.getBase());
5335  return success();
5336 }
5337 
5338 static bool checkSameValueWAR(vector::TransferReadOp read,
5339  vector::TransferWriteOp write) {
5340  return read.getBase() == write.getBase() &&
5341  read.getIndices() == write.getIndices() &&
5342  read.getPermutationMap() == write.getPermutationMap() &&
5343  read.getVectorType() == write.getVectorType() && !read.getMask() &&
5344  !write.getMask();
5345 }
5346 /// Fold transfer_write write after read:
5347 /// ```
5348 /// %t0 = ...
5349 /// %v = vector.transfer_read %t0[%c0...] :
5350 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5351 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
5352 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5353 /// ```
5354 ///
5355 /// into:
5356 ///
5357 /// ```
5358 /// %t0
5359 /// ```
5360 static LogicalResult foldWAR(TransferWriteOp write,
5361  SmallVectorImpl<OpFoldResult> &results) {
5362  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5363  return failure();
5364  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5365  if (!read)
5366  return failure();
5367 
5368  if (!checkSameValueWAR(read, write))
5369  return failure();
5370  results.push_back(read.getBase());
5371  return success();
5372 }
5373 
5374 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5375  SmallVectorImpl<OpFoldResult> &results) {
5376  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5377  return success();
5378  if (succeeded(foldWAR(*this, results)))
5379  return success();
5380  if (succeeded(foldTransferInBoundsAttribute(*this)))
5381  return success();
5382  if (succeeded(foldTransferFullMask(*this)))
5383  return success();
5384  return memref::foldMemRefCast(*this);
5385 }
5386 
5387 //===----------------------------------------------------------------------===//
5388 // TransferWriteOp: other methods.
5389 //===----------------------------------------------------------------------===//
5390 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5391  return llvm::to_vector<4>(getVectorType().getShape());
5392 }
5393 
5394 void TransferWriteOp::getEffects(
5395  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5396  &effects) {
5397  if (llvm::isa<MemRefType>(getShapedType()))
5398  effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5399  SideEffects::DefaultResource::get());
5400 }
5401 
5402 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5403  if (hasPureTensorSemantics())
5404  return Speculation::Speculatable;
5405  return Speculation::NotSpeculatable;
5406 }
5407 
5408 namespace {
5409 /// Remove dead transfer write from the SSA chain so that it can be eliminated
5410 /// by DCE.
5411 /// ```
5412 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5413 /// : vector<1x4xf32>, tensor<4x4xf32>
5414 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5415 /// : vector<1x4xf32>, tensor<4x4xf32>
5416 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5417 /// : vector<1x4xf32>, tensor<4x4xf32>
5418 /// ```
5419 ///
5420 /// into:
5421 ///
5422 /// ```
5423 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5424 /// : vector<1x4xf32>, tensor<4x4xf32>
5425 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5426 /// : vector<1x4xf32>, tensor<4x4xf32>
5427 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5428 /// : vector<1x4xf32>, tensor<4x4xf32>
5429 /// ```
5430 ///
5431 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5432 /// any other uses.
5433 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5434 public:
5436  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5437  PatternRewriter &rewriter) const override {
5438  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5439  return failure();
5440  vector::TransferWriteOp writeToModify = writeOp;
5441 
5442  auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5443  while (defWrite) {
5444  if (checkSameValueWAW(writeOp, defWrite)) {
5445  rewriter.modifyOpInPlace(writeToModify, [&]() {
5446  writeToModify.getBaseMutable().assign(defWrite.getBase());
5447  });
5448  return success();
5449  }
5450  if (!isDisjointTransferIndices(
5451  cast<VectorTransferOpInterface>(defWrite.getOperation()),
5452  cast<VectorTransferOpInterface>(writeOp.getOperation())))
5453  break;
5454  // If the previous write op doesn't have any other use, we can safely look
5455  // at the previous store to see if it can be removed.
5456  if (!defWrite->hasOneUse())
5457  break;
5458  writeToModify = defWrite;
5459  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5460  }
5461  return failure();
5462  }
5463 };
5464 
5465 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5466 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5467 /// overwritten and inserted into another tensor. After this rewrite, the
5468 /// operations bufferize in-place since all of them work on the same slice.
5469 ///
5470 /// For example:
5471 /// ```mlir
5472 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5473 /// : vector<8x16xf32>, tensor<8x16xf32>
5474 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5475 /// : tensor<8x16xf32> to tensor<?x?xf32>
5476 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5477 /// : tensor<?x?xf32> into tensor<27x37xf32>
5478 /// ```
5479 /// folds to
5480 /// ```mlir
5481 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5482 /// : tensor<27x37xf32> to tensor<?x?xf32>
5483 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5484 /// : vector<8x16xf32>, tensor<?x?xf32>
5485 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5486 /// : tensor<?x?xf32> into tensor<27x37xf32>
5487 /// ```
5488 struct SwapExtractSliceOfTransferWrite
5489  : public OpRewritePattern<tensor::InsertSliceOp> {
5490 public:
5492 
5493  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5494  PatternRewriter &rewriter) const override {
5495  if (!insertOp.hasUnitStride())
5496  return failure();
5497  auto extractOp =
5498  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5499  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5500  return failure();
5501  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5502  if (!transferOp || !transferOp->hasOneUse())
5503  return failure();
5504 
5505  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5506  // rank-reducing.
5507  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5508  return rewriter.notifyMatchFailure(insertOp,
5509  "use-def chain is rank-reducing");
5510  }
5511 
5512  // Fail if tensor::ExtractSliceOp has non-zero offset.
5513  if (!extractOp.hasZeroOffset()) {
5514  return rewriter.notifyMatchFailure(insertOp,
5515  "ExtractSliceOp has non-zero offset");
5516  }
5517 
5518  // Fail if vector::TransferWriteOp has non-zero offset.
5519  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5520  return getConstantIntValue(value) == static_cast<int64_t>(0);
5521  })) {
5522  return rewriter.notifyMatchFailure(insertOp,
5523  "TranferWriteOp has non-zero offset");
5524  }
5525 
5526  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5527  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5528  return rewriter.notifyMatchFailure(
5529  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5530  }
5531 
5532  for (auto [insertSize, extractSize] :
5533  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5534  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5535  return rewriter.notifyMatchFailure(
5536  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5537  }
5538  }
5539 
5540  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5541  assert(transferOp.getVectorType().hasStaticShape() &&
5542  "expected vector to have a static shape");
5543  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5544  SmallVector<int64_t> resultShape = applyPermutationMap(
5545  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5546  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5547  return rewriter.notifyMatchFailure(
5548  insertOp, "TransferWriteOp may not write the full tensor.");
5549  }
5550 
5551  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5552  // Set all in_bounds to false and let the folder infer them.
5553  SmallVector<bool> newInBounds(vectorShape.size(), false);
5554  auto newExtractOp = tensor::ExtractSliceOp::create(
5555  rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5556  insertOp.getDest(), insertOp.getMixedOffsets(),
5557  insertOp.getMixedSizes(), insertOp.getMixedStrides());
5558  auto newTransferWriteOp = TransferWriteOp::create(
5559  rewriter, transferOp.getLoc(), transferOp.getVector(),
5560  newExtractOp.getResult(), transferOp.getIndices(),
5561  transferOp.getPermutationMapAttr(),
5562  rewriter.getBoolArrayAttr(newInBounds));
5563  rewriter.modifyOpInPlace(insertOp, [&]() {
5564  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5565  });
5566  return success();
5567  }
5568 };
5569 
5570 } // namespace
5571 
5572 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5573  MLIRContext *context) {
5574  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5575 }
5576 
5577 //===----------------------------------------------------------------------===//
5578 // LoadOp
5579 //===----------------------------------------------------------------------===//
5580 
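/// Verifies that a memref layout is compatible with vector load/store: the
/// most minor dimension must have unit stride unless the vector is 0-D or has
/// a single (non-scalable) element. For illustration (types assumed for the
/// example): loading vector<4xf32> from memref<8x16xf32, strided<[32, 1]>> is
/// accepted, while loading it from memref<8x16xf32, strided<[32, 2]>> is not.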
5581 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5582  VectorType vecTy,
5583  MemRefType memRefTy) {
5584  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5585  // need any stride limitations.
5586  if (!vecTy.isScalable() &&
5587  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5588  return success();
5589 
5590  if (!memRefTy.isLastDimUnitStride())
5591  return op->emitOpError("most minor memref dim must have unit stride");
5592  return success();
5593 }
5594 
5595 LogicalResult vector::LoadOp::verify() {
5596  VectorType resVecTy = getVectorType();
5597  MemRefType memRefTy = getMemRefType();
5598 
5599  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5600  return failure();
5601 
5602  if (memRefTy.getRank() < resVecTy.getRank())
5603  return emitOpError(
5604  "destination memref has lower rank than the result vector");
5605 
5606  // Checks for vector memrefs.
5607  Type memElemTy = memRefTy.getElementType();
5608  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5609  if (memVecTy != resVecTy)
5610  return emitOpError("base memref and result vector types should match");
5611  memElemTy = memVecTy.getElementType();
5612  }
5613 
5614  if (resVecTy.getElementType() != memElemTy)
5615  return emitOpError("base and result element types should match");
5616  if (llvm::size(getIndices()) != memRefTy.getRank())
5617  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5618  return success();
5619 }
5620 
5621 OpFoldResult LoadOp::fold(FoldAdaptor) {
5622  if (succeeded(memref::foldMemRefCast(*this)))
5623  return getResult();
5624  return OpFoldResult();
5625 }
5626 
5627 std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5628  return llvm::to_vector<4>(getVectorType().getShape());
5629 }
5630 
5631 //===----------------------------------------------------------------------===//
5632 // StoreOp
5633 //===----------------------------------------------------------------------===//
5634 
5635 LogicalResult vector::StoreOp::verify() {
5636  VectorType valueVecTy = getVectorType();
5637  MemRefType memRefTy = getMemRefType();
5638 
5639  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5640  return failure();
5641 
5642  if (memRefTy.getRank() < valueVecTy.getRank())
5643  return emitOpError("source memref has lower rank than the vector to store");
5644 
5645  // Checks for vector memrefs.
5646  Type memElemTy = memRefTy.getElementType();
5647  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5648  if (memVecTy != valueVecTy)
5649  return emitOpError(
5650  "base memref and valueToStore vector types should match");
5651  memElemTy = memVecTy.getElementType();
5652  }
5653 
5654  if (valueVecTy.getElementType() != memElemTy)
5655  return emitOpError("base and valueToStore element type should match");
5656  if (llvm::size(getIndices()) != memRefTy.getRank())
5657  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5658  return success();
5659 }
5660 
5661 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5662  SmallVectorImpl<OpFoldResult> &results) {
5663  return memref::foldMemRefCast(*this);
5664 }
5665 
5666 std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5667  return llvm::to_vector<4>(getVectorType().getShape());
5668 }
5669 
5670 //===----------------------------------------------------------------------===//
5671 // MaskedLoadOp
5672 //===----------------------------------------------------------------------===//
5673 
5674 LogicalResult MaskedLoadOp::verify() {
5675  VectorType maskVType = getMaskVectorType();
5676  VectorType passVType = getPassThruVectorType();
5677  VectorType resVType = getVectorType();
5678  MemRefType memType = getMemRefType();
5679 
5680  if (resVType.getElementType() != memType.getElementType())
5681  return emitOpError("base and result element type should match");
5682  if (llvm::size(getIndices()) != memType.getRank())
5683  return emitOpError("requires ") << memType.getRank() << " indices";
5684  if (resVType.getShape() != maskVType.getShape())
5685  return emitOpError("expected result shape to match mask shape");
5686  if (resVType != passVType)
5687  return emitOpError("expected pass_thru of same type as result type");
5688  return success();
5689 }
5690 
5691 namespace {
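/// Folds a masked load whose mask is statically known to be all-true or
/// all-false. For example (illustrative IR):
///
///   %0 = vector.maskedload %base[%i], %mask, %pass_thru
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
///
/// becomes a plain vector.load when %mask is all-true, and folds to
/// %pass_thru when %mask is all-false.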
5692 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5693 public:
5695  LogicalResult matchAndRewrite(MaskedLoadOp load,
5696  PatternRewriter &rewriter) const override {
5697  switch (getMaskFormat(load.getMask())) {
5698  case MaskFormat::AllTrue:
5699  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5700  load, load.getType(), load.getBase(), load.getIndices());
5701  return success();
5702  case MaskFormat::AllFalse:
5703  rewriter.replaceOp(load, load.getPassThru());
5704  return success();
5705  case MaskFormat::Unknown:
5706  return failure();
5707  }
5708  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5709  }
5710 };
5711 } // namespace
5712 
5713 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5714  MLIRContext *context) {
5715  results.add<MaskedLoadFolder>(context);
5716 }
5717 
5718 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5719  if (succeeded(memref::foldMemRefCast(*this)))
5720  return getResult();
5721  return OpFoldResult();
5722 }
5723 
5724 //===----------------------------------------------------------------------===//
5725 // MaskedStoreOp
5726 //===----------------------------------------------------------------------===//
5727 
5728 LogicalResult MaskedStoreOp::verify() {
5729  VectorType maskVType = getMaskVectorType();
5730  VectorType valueVType = getVectorType();
5731  MemRefType memType = getMemRefType();
5732 
5733  if (valueVType.getElementType() != memType.getElementType())
5734  return emitOpError("base and valueToStore element type should match");
5735  if (llvm::size(getIndices()) != memType.getRank())
5736  return emitOpError("requires ") << memType.getRank() << " indices";
5737  if (valueVType.getShape() != maskVType.getShape())
5738  return emitOpError("expected valueToStore shape to match mask shape");
5739  return success();
5740 }
5741 
5742 namespace {
5743 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5744 public:
5746  LogicalResult matchAndRewrite(MaskedStoreOp store,
5747  PatternRewriter &rewriter) const override {
5748  switch (getMaskFormat(store.getMask())) {
5749  case MaskFormat::AllTrue:
5750  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5751  store, store.getValueToStore(), store.getBase(), store.getIndices());
5752  return success();
5753  case MaskFormat::AllFalse:
5754  rewriter.eraseOp(store);
5755  return success();
5756  case MaskFormat::Unknown:
5757  return failure();
5758  }
5759  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5760  }
5761 };
5762 } // namespace
5763 
5764 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5765  MLIRContext *context) {
5766  results.add<MaskedStoreFolder>(context);
5767 }
5768 
5769 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5770  SmallVectorImpl<OpFoldResult> &results) {
5771  return memref::foldMemRefCast(*this);
5772 }
5773 
5774 //===----------------------------------------------------------------------===//
5775 // GatherOp
5776 //===----------------------------------------------------------------------===//
5777 
5778 LogicalResult GatherOp::verify() {
5779  VectorType indVType = getIndexVectorType();
5780  VectorType maskVType = getMaskVectorType();
5781  VectorType resVType = getVectorType();
5782  ShapedType baseType = getBaseType();
5783 
5784  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5785  return emitOpError("requires base to be a memref or ranked tensor type");
5786 
5787  if (resVType.getElementType() != baseType.getElementType())
5788  return emitOpError("base and result element type should match");
5789  if (llvm::size(getOffsets()) != baseType.getRank())
5790  return emitOpError("requires ") << baseType.getRank() << " indices";
5791  if (resVType.getShape() != indVType.getShape())
5792  return emitOpError("expected result dim to match indices dim");
5793  if (resVType.getShape() != maskVType.getShape())
5794  return emitOpError("expected result dim to match mask dim");
5795  if (resVType != getPassThruVectorType())
5796  return emitOpError("expected pass_thru of same type as result type");
5797  return success();
5798 }
5799 
5800 // MaskableOpInterface methods.
5801 
5802 /// Returns the mask type expected by this operation. Mostly used for
5803 /// verification purposes. It requires the operation to be vectorized.
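/// For example, a gather whose index vector type is vector<4x[8]xindex> is
/// expected to take a mask of type vector<4x[8]xi1>.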
5804 Type GatherOp::getExpectedMaskType() {
5805  auto vecType = this->getIndexVectorType();
5806  return VectorType::get(vecType.getShape(),
5807  IntegerType::get(vecType.getContext(), /*width=*/1),
5808  vecType.getScalableDims());
5809 }
5810 
5811 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5812  return llvm::to_vector<4>(getVectorType().getShape());
5813 }
5814 
5815 /// Check if `indexVec` is a constant 1D vector of consecutive values [0, 1, 2, ...].
5816 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5817  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5818  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5819  return failure();
5820 
5821  if (indexVec.getDefiningOp<StepOp>())
5822  return success();
5823 
5824  DenseIntElementsAttr elements;
5825  if (!matchPattern(indexVec, m_Constant(&elements)))
5826  return failure();
5827 
5828  return success(
5829  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5830 }
5831 
5832 namespace {
5833 class GatherFolder final : public OpRewritePattern<GatherOp> {
5834 public:
5836  LogicalResult matchAndRewrite(GatherOp gather,
5837  PatternRewriter &rewriter) const override {
5838  switch (getMaskFormat(gather.getMask())) {
5839  case MaskFormat::AllTrue:
5840  return failure(); // no unmasked equivalent
5841  case MaskFormat::AllFalse:
5842  rewriter.replaceOp(gather, gather.getPassThru());
5843  return success();
5844  case MaskFormat::Unknown:
5845  return failure();
5846  }
5847  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5848  }
5849 };
5850 
5851 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5852 /// maskedload. Only 1D fixed vectors are supported for now.
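/// For example (illustrative IR), when %idx is the constant [0, 1, ..., 15]:
///
///   %0 = vector.gather %base[%c0][%idx], %mask, %pass_thru
///       : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
///         into vector<16xf32>
///
/// can be rewritten as:
///
///   %0 = vector.maskedload %base[%c0], %mask, %pass_thru
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>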
5853 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5854 public:
5856  LogicalResult matchAndRewrite(GatherOp op,
5857  PatternRewriter &rewriter) const override {
5858  if (!isa<MemRefType>(op.getBase().getType()))
5859  return rewriter.notifyMatchFailure(op, "base must be of memref type");
5860 
5861  if (failed(isZeroBasedContiguousSeq(op.getIndices())))
5862  return failure();
5863 
5864  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5865  op.getOffsets(), op.getMask(),
5866  op.getPassThru());
5867  return success();
5868  }
5869 };
5870 } // namespace
5871 
5872 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5873  MLIRContext *context) {
5874  results.add<GatherFolder, FoldContiguousGather>(context);
5875 }
5876 
5877 //===----------------------------------------------------------------------===//
5878 // ScatterOp
5879 //===----------------------------------------------------------------------===//
5880 
5881 LogicalResult ScatterOp::verify() {
5882  VectorType indVType = getIndexVectorType();
5883  VectorType maskVType = getMaskVectorType();
5884  VectorType valueVType = getVectorType();
5885  MemRefType memType = getMemRefType();
5886 
5887  if (valueVType.getElementType() != memType.getElementType())
5888  return emitOpError("base and valueToStore element type should match");
5889  if (llvm::size(getOffsets()) != memType.getRank())
5890  return emitOpError("requires ") << memType.getRank() << " indices";
5891  if (valueVType.getShape() != indVType.getShape())
5892  return emitOpError("expected valueToStore dim to match indices dim");
5893  if (valueVType.getShape() != maskVType.getShape())
5894  return emitOpError("expected valueToStore dim to match mask dim");
5895  return success();
5896 }
5897 
5898 namespace {
5899 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5900 public:
5902  LogicalResult matchAndRewrite(ScatterOp scatter,
5903  PatternRewriter &rewriter) const override {
5904  switch (getMaskFormat(scatter.getMask())) {
5905  case MaskFormat::AllTrue:
5906  return failure(); // no unmasked equivalent
5907  case MaskFormat::AllFalse:
5908  rewriter.eraseOp(scatter);
5909  return success();
5910  case MaskFormat::Unknown:
5911  return failure();
5912  }
5913  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5914  }
5915 };
5916 
5917 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5918 /// maskedstore. Only 1D fixed vectors are supported for now.
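/// For example (illustrative IR), when %idx is the constant [0, 1, ..., 15]:
///
///   vector.scatter %base[%c0][%idx], %mask, %value
///       : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
///
/// can be rewritten as:
///
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<?xf32>, vector<16xi1>, vector<16xf32>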
5919 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5920 public:
5922  LogicalResult matchAndRewrite(ScatterOp op,
5923  PatternRewriter &rewriter) const override {
5924  if (failed(isZeroBasedContiguousSeq(op.getIndices())))
5925  return failure();
5926 
5927  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5928  op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
5929  return success();
5930  }
5931 };
5932 } // namespace
5933 
5934 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5935  MLIRContext *context) {
5936  results.add<ScatterFolder, FoldContiguousScatter>(context);
5937 }
5938 
5939 //===----------------------------------------------------------------------===//
5940 // ExpandLoadOp
5941 //===----------------------------------------------------------------------===//
5942 
5943 LogicalResult ExpandLoadOp::verify() {
5944  VectorType maskVType = getMaskVectorType();
5945  VectorType passVType = getPassThruVectorType();
5946  VectorType resVType = getVectorType();
5947  MemRefType memType = getMemRefType();
5948 
5949  if (resVType.getElementType() != memType.getElementType())
5950  return emitOpError("base and result element type should match");
5951  if (llvm::size(getIndices()) != memType.getRank())
5952  return emitOpError("requires ") << memType.getRank() << " indices";
5953  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5954  return emitOpError("expected result dim to match mask dim");
5955  if (resVType != passVType)
5956  return emitOpError("expected pass_thru of same type as result type");
5957  return success();
5958 }
5959 
5960 namespace {
5961 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5962 public:
5964  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5965  PatternRewriter &rewriter) const override {
5966  switch (getMaskFormat(expand.getMask())) {
5967  case MaskFormat::AllTrue:
5968  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5969  expand, expand.getType(), expand.getBase(), expand.getIndices());
5970  return success();
5971  case MaskFormat::AllFalse:
5972  rewriter.replaceOp(expand, expand.getPassThru());
5973  return success();
5974  case MaskFormat::Unknown:
5975  return failure();
5976  }
5977  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5978  }
5979 };
5980 } // namespace
5981 
5982 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5983  MLIRContext *context) {
5984  results.add<ExpandLoadFolder>(context);
5985 }
5986 
5987 //===----------------------------------------------------------------------===//
5988 // CompressStoreOp
5989 //===----------------------------------------------------------------------===//
5990 
5991 LogicalResult CompressStoreOp::verify() {
5992  VectorType maskVType = getMaskVectorType();
5993  VectorType valueVType = getVectorType();
5994  MemRefType memType = getMemRefType();
5995 
5996  if (valueVType.getElementType() != memType.getElementType())
5997  return emitOpError("base and valueToStore element type should match");
5998  if (llvm::size(getIndices()) != memType.getRank())
5999  return emitOpError("requires ") << memType.getRank() << " indices";
6000  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6001  return emitOpError("expected valueToStore dim to match mask dim");
6002  return success();
6003 }
6004 
6005 namespace {
6006 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6007 public:
6009  LogicalResult matchAndRewrite(CompressStoreOp compress,
6010  PatternRewriter &rewriter) const override {
6011  switch (getMaskFormat(compress.getMask())) {
6012  case MaskFormat::AllTrue:
6013  rewriter.replaceOpWithNewOp<vector::StoreOp>(
6014  compress, compress.getValueToStore(), compress.getBase(),
6015  compress.getIndices());
6016  return success();
6017  case MaskFormat::AllFalse:
6018  rewriter.eraseOp(compress);
6019  return success();
6020  case MaskFormat::Unknown:
6021  return failure();
6022  }
6023  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6024  }
6025 };
6026 } // namespace
6027 
6028 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6029  MLIRContext *context) {
6030  results.add<CompressStoreFolder>(context);
6031 }
6032 
6033 //===----------------------------------------------------------------------===//
6034 // ShapeCastOp
6035 //===----------------------------------------------------------------------===//
6036 
6037 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6038  SetIntRangeFn setResultRanges) {
6039  setResultRanges(getResult(), argRanges.front());
6040 }
6041 
6042 LogicalResult ShapeCastOp::verify() {
6043 
6044  VectorType sourceType = getSourceVectorType();
6045  VectorType resultType = getResultVectorType();
6046 
6047  // Check that element type is preserved
6048  if (sourceType.getElementType() != resultType.getElementType())
6049  return emitOpError("has different source and result element types");
6050 
6051  // Check that number of elements is preserved
6052  int64_t sourceNElms = sourceType.getNumElements();
6053  int64_t resultNElms = resultType.getNumElements();
6054  if (sourceNElms != resultNElms) {
6055  return emitOpError() << "has different number of elements at source ("
6056  << sourceNElms << ") and result (" << resultNElms
6057  << ")";
6058  }
6059 
6060  // Check that (non-)scalability is preserved
6061  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6062  int64_t resultNScalableDims = resultType.getNumScalableDims();
6063  if (sourceNScalableDims != resultNScalableDims)
6064  return emitOpError() << "has different number of scalable dims at source ("
6065  << sourceNScalableDims << ") and result ("
6066  << resultNScalableDims << ")";
6067 
6068  return success();
6069 }
6070 
6071 /// Return true if `transpose` does not permute a pair of non-unit dims.
6072 /// By `order preserving` we mean that the flattened versions of the input and
6073 /// output vectors are (numerically) identical. In other words `transpose` is
6074 /// effectively a shape cast.
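///
/// For example, for a source type of vector<2x1x4xf32>, the permutations
/// [0, 1, 2] and [1, 0, 2] are order preserving (the middle unit dim may move
/// freely), whereas [2, 1, 0] is not, since it swaps the two non-unit dims.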
6075 static bool isOrderPreserving(TransposeOp transpose) {
6076  ArrayRef<int64_t> permutation = transpose.getPermutation();
6077  VectorType sourceType = transpose.getSourceVectorType();
6078  ArrayRef<int64_t> inShape = sourceType.getShape();
6079  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6080  auto isNonScalableUnitDim = [&](int64_t dim) {
6081  return inShape[dim] == 1 && !inDimIsScalable[dim];
6082  };
6083  int64_t current = 0;
6084  for (auto p : permutation) {
6085  if (!isNonScalableUnitDim(p)) {
6086  if (p < current) {
6087  return false;
6088  }
6089  current = p;
6090  }
6091  }
6092  return true;
6093 }
6094 
6095 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6096 
6097  VectorType resultType = getType();
6098 
6099  // No-op shape cast.
6100  if (getSource().getType() == resultType)
6101  return getSource();
6102 
6103  // shape_cast(shape_cast(x)) -> shape_cast(x)
6104  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6105  setOperand(precedingShapeCast.getSource());
6106  return getResult();
6107  }
6108 
6109  // shape_cast(transpose(x)) -> shape_cast(x)
6110  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6111  if (isOrderPreserving(transpose)) {
6112  setOperand(transpose.getVector());
6113  return getResult();
6114  }
6115  return {};
6116  }
6117 
6118  // Y = shape_cast(broadcast(X))
6119  // -> X, if X and Y have same type
6120  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6121  if (bcastOp.getSourceType() == resultType)
6122  return bcastOp.getSource();
6123  }
6124 
6125  // shape_cast(constant) -> constant
6126  if (auto denseAttr =
6127  dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6128  return denseAttr.reshape(getType());
6129 
6130  // shape_cast(poison) -> poison
6131  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6132  return ub::PoisonAttr::get(getContext());
6133 
6134  return {};
6135 }
6136 
6137 namespace {
6138 
6139 /// Helper function that computes a new vector type based on the input vector
6140 /// type by removing the trailing one dims:
6141 ///
6142 /// vector<4x1x1xi1> --> vector<4x1xi1>
6143 ///
6144 static VectorType trimTrailingOneDims(VectorType oldType) {
6145  ArrayRef<int64_t> oldShape = oldType.getShape();
6146  ArrayRef<int64_t> newShape = oldShape;
6147 
6148  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6149  ArrayRef<bool> newScalableDims = oldScalableDims;
6150 
6151  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6152  newShape = newShape.drop_back(1);
6153  newScalableDims = newScalableDims.drop_back(1);
6154  }
6155 
6156  // Make sure we have at least 1 dimension.
6157  // TODO: Add support for 0-D vectors.
6158  if (newShape.empty()) {
6159  newShape = oldShape.take_back();
6160  newScalableDims = oldScalableDims.take_back();
6161  }
6162 
6163  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6164 }
6165 
6166 /// Folds qualifying shape_cast(create_mask) into a new create_mask
6167 ///
6168 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6169 /// dimension. If the input vector comes from `vector.create_mask` for which
6170 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6171 /// to fold shape_cast into create_mask.
6172 ///
6173 /// BEFORE:
6174 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6175 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6176 /// AFTER:
6177 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6178 class ShapeCastCreateMaskFolderTrailingOneDim final
6179  : public OpRewritePattern<ShapeCastOp> {
6180 public:
6182 
6183  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6184  PatternRewriter &rewriter) const override {
6185  Value shapeOpSrc = shapeOp->getOperand(0);
6186  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6187  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6188  if (!createMaskOp && !constantMaskOp)
6189  return failure();
6190 
6191  VectorType shapeOpResTy = shapeOp.getResultVectorType();
6192  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6193 
6194  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6195  if (newVecType != shapeOpResTy)
6196  return failure();
6197 
6198  auto numDimsToDrop =
6199  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6200 
6201  // No unit dims to drop
6202  if (!numDimsToDrop)
6203  return failure();
6204 
6205  if (createMaskOp) {
6206  auto maskOperands = createMaskOp.getOperands();
6207  auto numMaskOperands = maskOperands.size();
6208 
6209  // Check every mask dim size to see whether it can be dropped
6210  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6211  --i) {
6212  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6213  if (!constant || (constant.value() != 1))
6214  return failure();
6215  }
6216  SmallVector<Value> newMaskOperands =
6217  maskOperands.drop_back(numDimsToDrop);
6218 
6219  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6220  newMaskOperands);
6221  return success();
6222  }
6223 
6224  if (constantMaskOp) {
6225  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6226  auto numMaskOperands = maskDimSizes.size();
6227 
6228  // Check every mask dim size to see whether it can be dropped
6229  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6230  --i) {
6231  if (maskDimSizes[i] != 1)
6232  return failure();
6233  }
6234 
6235  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6236  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6237  newMaskOperands);
6238  return success();
6239  }
6240 
6241  return failure();
6242  }
6243 };
6244 
6245 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6246 /// i) Y = ShapeCast(X), or
6247 /// ii) Y = Broadcast(X)
6248 /// If both (i) and (ii) are possible, (i) is chosen.
6249 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6250 public:
6252 
6253  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6254  PatternRewriter &rewriter) const override {
6255  auto broadcastOp =
6256  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6257  if (!broadcastOp)
6258  return failure();
6259 
6260  auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6261  bool srcIsScalar = !srcVectorType;
6262 
6263  // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6264  // Example:
6265  // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6266  // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6267  // to
6268  // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6269  if (srcVectorType) {
6270  if (srcVectorType.getNumElements() ==
6271  shapeCastOp.getResultVectorType().getNumElements()) {
6272  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6273  shapeCastOp, shapeCastOp.getResultVectorType(),
6274  broadcastOp.getSource());
6275  return success();
6276  }
6277  }
6278 
6279  // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6280  // Example
6281  // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6282  // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6283  // to
6284  // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6285  VectorType dstVectorType = shapeCastOp.getResultVectorType();
6286  if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6287  BroadcastableToResult::Success) {
6288  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6289  shapeCastOp, dstVectorType, broadcastOp.getSource());
6290  return success();
6291  }
6292  return failure();
6293  }
6294 };
6295 
6296 } // namespace
6297 
6298 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6299  MLIRContext *context) {
6300  results
6301  .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6302  context);
6303 }
6304 
6305 //===----------------------------------------------------------------------===//
6306 // VectorBitCastOp
6307 //===----------------------------------------------------------------------===//
6308 
6309 LogicalResult BitCastOp::verify() {
6310  auto sourceVectorType = getSourceVectorType();
6311  auto resultVectorType = getResultVectorType();
6312 
6313  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6314  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6315  return emitOpError("dimension size mismatch at: ") << i;
6316  }
6317 
6318  DataLayout dataLayout = DataLayout::closest(*this);
6319  auto sourceElementBits =
6320  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6321  auto resultElementBits =
6322  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6323 
6324  if (sourceVectorType.getRank() == 0) {
6325  if (sourceElementBits != resultElementBits)
6326  return emitOpError("source/result bitwidth of the 0-D vector element "
6327  "types must be equal");
6328  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6329  resultElementBits * resultVectorType.getShape().back()) {
6330  return emitOpError(
6331  "source/result bitwidth of the minor 1-D vectors must be equal");
6332  }
6333 
6334  return success();
6335 }
6336 
6337 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6338  // Nop cast.
6339  if (getSource().getType() == getResult().getType())
6340  return getSource();
6341 
6342  // Canceling bitcasts.
6343  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6344  if (getResult().getType() == otherOp.getSource().getType())
6345  return otherOp.getSource();
6346 
6347  setOperand(otherOp.getSource());
6348  return getResult();
6349  }
6350 
6351  Attribute sourceConstant = adaptor.getSource();
6352  if (!sourceConstant)
6353  return {};
6354 
6355  Type srcElemType = getSourceVectorType().getElementType();
6356  Type dstElemType = getResultVectorType().getElementType();
6357 
6358  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6359  if (floatPack.isSplat()) {
6360  auto splat = floatPack.getSplatValue<FloatAttr>();
6361 
6362  // Casting fp16 into fp32.
6363  if (srcElemType.isF16() && dstElemType.isF32()) {
6364  uint32_t bits = static_cast<uint32_t>(
6365  splat.getValue().bitcastToAPInt().getZExtValue());
6366  // Duplicate the 16-bit pattern.
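 // For example, an f16 splat of 1.0 (bit pattern 0x3C00) yields the 32-bit
 // pattern 0x3C003C00, which then becomes the f32 splat element.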
6367  bits = (bits << 16) | (bits & 0xffff);
6368  APInt intBits(32, bits);
6369  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6370  return DenseElementsAttr::get(getResultVectorType(), floatBits);
6371  }
6372  }
6373  }
6374 
6375  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6376  if (intPack.isSplat()) {
6377  auto splat = intPack.getSplatValue<IntegerAttr>();
6378 
6379  if (llvm::isa<IntegerType>(dstElemType)) {
6380  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6381  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6382 
6383  // Casting to a larger integer bit width.
6384  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6385  APInt intBits = splat.getValue().zext(dstBitWidth);
6386 
6387  // Duplicate the lower width element.
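 // For example, an i8 splat with value 0x01 bitcast to i32 produces the
 // splat value 0x01010101.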
6388  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6389  intBits = (intBits << srcBitWidth) | intBits;
6390  return DenseElementsAttr::get(getResultVectorType(), intBits);
6391  }
6392  }
6393  }
6394  }
6395 
6396  return {};
6397 }
6398 
6399 //===----------------------------------------------------------------------===//
6400 // TypeCastOp
6401 //===----------------------------------------------------------------------===//
6402 
6403 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6404  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6405  SmallVector<int64_t, 8> res(memRefType.getShape());
6406  if (vectorType)
6407  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6408  return res;
6409 }
6410 
6411 /// Build the canonical memRefType with a single vector.
6412 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6413 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6414  Value source) {
6415  result.addOperands(source);
6416  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6417  VectorType vectorType =
6418  VectorType::get(extractShape(memRefType),
6419  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6420  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6421  memRefType.getMemorySpace()));
6422 }
6423 
6424 LogicalResult TypeCastOp::verify() {
6425  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6426  if (!canonicalType.getLayout().isIdentity())
6427  return emitOpError("expects operand to be a memref with identity layout");
6428  if (!getResultMemRefType().getLayout().isIdentity())
6429  return emitOpError("expects result to be a memref with identity layout");
6430  if (getResultMemRefType().getMemorySpace() !=
6431  getMemRefType().getMemorySpace())
6432  return emitOpError("expects result in same memory space");
6433 
6434  auto sourceType = getMemRefType();
6435  auto resultType = getResultMemRefType();
6436  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6437  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6438  return emitOpError(
6439  "expects result and operand with same underlying scalar type: ")
6440  << resultType;
6441  if (extractShape(sourceType) != extractShape(resultType))
6442  return emitOpError(
6443  "expects concatenated result and operand shapes to be equal: ")
6444  << resultType;
6445  return success();
6446 }
6447 
6448 //===----------------------------------------------------------------------===//
6449 // TransposeOp
6450 //===----------------------------------------------------------------------===//
6451 
6452 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6453  Value vector, ArrayRef<int64_t> permutation) {
6454  VectorType vt = llvm::cast<VectorType>(vector.getType());
6455  SmallVector<int64_t, 4> transposedShape(vt.getRank());
6456  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6457  for (unsigned i = 0; i < permutation.size(); ++i) {
6458  transposedShape[i] = vt.getShape()[permutation[i]];
6459  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6460  }
6461 
6462  result.addOperands(vector);
6463  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6464  transposedScalableDims));
6465  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6466  builder.getDenseI64ArrayAttr(permutation));
6467 }
6468 
6469 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6470  // Eliminate splat constant transpose ops.
6471  if (auto splat =
6472  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6473  return splat.reshape(getResultVectorType());
6474 
6475  // Eliminate poison transpose ops.
6476  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6477  return ub::PoisonAttr::get(getContext());
6478 
6479  // Eliminate identity transposes, and more generally any transposes that
6480  // preserve the shape without permuting elements.
6481  //
6482  // Examples of what to fold:
6483  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6484  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6485  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6486  //
6487  // Example of what NOT to fold:
6488  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6489  //
6490  if (getSourceVectorType() == getResultVectorType() &&
6491  isOrderPreserving(*this))
6492  return getVector();
6493 
6494  return {};
6495 }
6496 
6497 LogicalResult vector::TransposeOp::verify() {
6498  VectorType vectorType = getSourceVectorType();
6499  VectorType resultType = getResultVectorType();
6500  int64_t rank = resultType.getRank();
6501  if (vectorType.getRank() != rank)
6502  return emitOpError("vector result rank mismatch: ") << rank;
6503  // Verify transposition array.
6504  ArrayRef<int64_t> perm = getPermutation();
6505  int64_t size = perm.size();
6506  if (rank != size)
6507  return emitOpError("transposition length mismatch: ") << size;
6508  SmallVector<bool, 8> seen(rank, false);
6509  for (const auto &ta : llvm::enumerate(perm)) {
6510  if (ta.value() < 0 || ta.value() >= rank)
6511  return emitOpError("transposition index out of range: ") << ta.value();
6512  if (seen[ta.value()])
6513  return emitOpError("duplicate position index: ") << ta.value();
6514  seen[ta.value()] = true;
6515  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6516  return emitOpError("dimension size mismatch at: ") << ta.value();
6517  }
6518  return success();
6519 }
6520 
6521 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6522  return llvm::to_vector<4>(getResultVectorType().getShape());
6523 }
6524 
6525 void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6526  SetIntRangeFn setResultRanges) {
6527  setResultRanges(getResult(), argRanges.front());
6528 }
6529 
6530 namespace {
6531 
6532 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
6533 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6534 public:
6536 
6537  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6538  PatternRewriter &rewriter) const override {
6539  // Composes two permutations: result[i] = permutation1[permutation2[i]].
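 // For example, permutation1 = [1, 0, 2] composed with permutation2 =
 // [2, 1, 0] yields [2, 0, 1].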
6540  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6541  ArrayRef<int64_t> permutation2) {
6542  SmallVector<int64_t, 4> result;
6543  for (auto index : permutation2)
6544  result.push_back(permutation1[index]);
6545  return result;
6546  };
6547 
6548  // Return if the input of 'transposeOp' is not defined by another transpose.
6549  vector::TransposeOp parentTransposeOp =
6550  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6551  if (!parentTransposeOp)
6552  return failure();
6553 
6554  SmallVector<int64_t, 4> permutation = composePermutations(
6555  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6556  // Replace 'transposeOp' with a new transpose operation.
6557  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6558  transposeOp, transposeOp.getResult().getType(),
6559  parentTransposeOp.getVector(), permutation);
6560  return success();
6561  }
6562 };
6563 
6564 /// Replace transpose(splat-like(v)) with broadcast(v)
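///
/// For example (illustrative IR):
///
///   %0 = vector.broadcast %s : f32 to vector<4x8xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///
/// folds to:
///
///   %1 = vector.broadcast %s : f32 to vector<8x4xf32>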
6565 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6566 public:
6568 
6569  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6570  PatternRewriter &rewriter) const override {
6571  Value splat = getScalarSplatSource(transposeOp.getVector());
6572  if (!splat)
6573  return failure();
6574 
6575  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6576  transposeOp, transposeOp.getResultVectorType(), splat);
6577  return success();
6578  }
6579 };
6580 
6581 /// Folds transpose(create_mask) into a new transposed create_mask.
6582 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6583 public:
6585 
6586  LogicalResult matchAndRewrite(TransposeOp transpOp,
6587  PatternRewriter &rewriter) const override {
6588  Value transposeSrc = transpOp.getVector();
6589  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6590  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6591  if (!createMaskOp && !constantMaskOp)
6592  return failure();
6593 
6594  // Get the transpose permutation and apply it to the vector.create_mask or
6595  // vector.constant_mask operands.
6596  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6597 
6598  if (createMaskOp) {
6599  auto maskOperands = createMaskOp.getOperands();
6600  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6601  applyPermutationToVector(newOperands, permutation);
6602 
6603  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6604  transpOp, transpOp.getResultVectorType(), newOperands);
6605  return success();
6606  }
6607 
6608  // ConstantMaskOp case.
6609  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6610  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6611 
6612  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6613  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6614  return success();
6615  }
6616 };
6617 
6618 /// Folds transpose(shape_cast) into a new shape_cast.
6619 class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6620 public:
6622 
6623  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6624  PatternRewriter &rewriter) const override {
6625  auto shapeCastOp =
6626  transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6627  if (!shapeCastOp)
6628  return failure();
6629  if (!isOrderPreserving(transposeOp))
6630  return failure();
6631 
6632  VectorType resultType = transposeOp.getType();
6633 
6634  // We don't need to check isValidShapeCast at this point: merging the
6635  // transpose into the shape_cast is guaranteed to yield a valid shape_cast,
6636  // because the transpose just inserts/removes unit dims.
6637 
6638  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6639  shapeCastOp.getSource());
6640  return success();
6641  }
6642 };
6643 
6644 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6645 /// 'order preserving', where 'order preserving' means the flattened
6646 /// inputs and outputs of the transpose have identical (numerical) values.
6647 ///
6648 /// Example:
6649 /// ```
6650 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6651 /// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6652 /// to vector<8x1xi32>
6653 /// ```
6654 /// can be rewritten as the equivalent
6655 /// ```
6656 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6657 /// ```
6658 /// The algorithm works by partitioning dimensions into groups that can be
6659 /// locally permuted while preserving order, and checks that the transpose
6660 /// only permutes within these groups.
6661 ///
6662 /// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6663 /// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6664 /// broadcasting from 1x1x4x1x1x7.
6665 ///                   ^^^ ^ ^^^ ^
6666 ///          groups:   0  1  2  3
6667 /// Order preserving permutations for this example are ones that only permute
6668 /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6669 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6670 public:
6672  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6673  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6674 
6675  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6676  PatternRewriter &rewriter) const override {
6677 
6678  vector::BroadcastOp broadcast =
6679  transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6680  if (!broadcast) {
6681  return rewriter.notifyMatchFailure(transpose,
6682  "not preceded by a broadcast");
6683  }
6684 
6685  auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6686  VectorType outputType = transpose.getResultVectorType();
6687 
6688  // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6689  bool inputIsScalar = !inputType;
6690  if (inputIsScalar) {
6691  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6692  broadcast.getSource());
6693  return success();
6694  }
6695 
6696  ArrayRef<int64_t> permutation = transpose.getPermutation();
6697  ArrayRef<int64_t> inputShape = inputType.getShape();
6698  int64_t inputRank = inputType.getRank();
6699  int64_t outputRank = transpose.getType().getRank();
6700  int64_t deltaRank = outputRank - inputRank;
6701 
6702  int low = 0;
6703  for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6704  bool notOne = inputShape[inputIndex] != 1;
6705  bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6706  bool groupEndFound = notOne || prevNotOne;
6707  if (groupEndFound) {
6708  int high = inputIndex + deltaRank;
6709  // Return failure if not all permutation destinations for indices in
6710  // [low, high) are in [low, high), i.e. the permutation is not local to
6711  // the group.
6712  for (int i = low; i < high; ++i) {
6713  if (permutation[i] < low || permutation[i] >= high) {
6714  return rewriter.notifyMatchFailure(
6715  transpose, "permutation not local to group");
6716  }
6717  }
6718  low = high;
6719  }
6720  }
6721 
6722  // We don't need to check the final group [low, outputRank) because if it is
6723  // not locally bound, there must be a preceding group that already failed
6724  // the check (impossible to have just 1 non-locally bound group).
6725 
6726  // The preceding logic also ensures that at this point, the output of the
6727  // transpose is definitely broadcastable from the input shape, assert so:
6728  assert(vector::isBroadcastableTo(inputType, outputType) ==
6729  vector::BroadcastableToResult::Success &&
6730  "not broadcastable directly to transpose output");
6731 
6732  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6733  broadcast.getSource());
6734 
6735  return success();
6736  }
6737 };
6738 
6739 } // namespace
6740 
6741 void vector::TransposeOp::getCanonicalizationPatterns(
6742  RewritePatternSet &results, MLIRContext *context) {
6743  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6744  FoldTransposeSplat, FoldTransposeBroadcast>(context);
6745 }
6746 
6747 //===----------------------------------------------------------------------===//
6748 // ConstantMaskOp
6749 //===----------------------------------------------------------------------===//
6750 
6751 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6752  VectorType type, ConstantMaskKind kind) {
6753  assert(kind == ConstantMaskKind::AllTrue ||
6754  kind == ConstantMaskKind::AllFalse);
6755  build(builder, result, type,
6756  kind == ConstantMaskKind::AllTrue
6757  ? type.getShape()
6758  : SmallVector<int64_t>(type.getRank(), 0));
6759 }
6760 
6761 LogicalResult ConstantMaskOp::verify() {
6762  auto resultType = llvm::cast<VectorType>(getResult().getType());
6763  // Check the corner case of 0-D vectors first.
6764  if (resultType.getRank() == 0) {
6765  if (getMaskDimSizes().size() != 1)
6766  return emitError("array attr must have length 1 for 0-D vectors");
6767  auto dim = getMaskDimSizes()[0];
6768  if (dim != 0 && dim != 1)
6769  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6770  return success();
6771  }
6772 
6773  // Verify that array attr size matches the rank of the vector result.
6774  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6775  return emitOpError(
6776  "must specify array attr of size equal vector result rank");
6777  // Verify that each array attr element is in bounds of corresponding vector
6778  // result dimension size.
6779  auto resultShape = resultType.getShape();
6780  auto resultScalableDims = resultType.getScalableDims();
6781  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6782  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6783  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6784  return emitOpError(
6785  "array attr of size out of bounds of vector result dimension size");
6786  if (resultScalableDims[index] && maskDimSize != 0 &&
6787  maskDimSize != resultShape[index])
6788  return emitOpError(
6789  "only supports 'none set' or 'all set' scalable dimensions");
6790  }
6791  // Verify that if one mask dim size is zero, they all should be zero (because
6792  // the mask region is a conjunction of each mask dimension interval).
6793  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6794  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6795  if (anyZeros && !allZeros)
6796  return emitOpError("expected all mask dim sizes to be zeros, "
6797  "as a result of conjunction with zero mask dim");
6798  return success();
6799 }
6800 
6801 bool ConstantMaskOp::isAllOnesMask() {
6802  auto resultType = getVectorType();
6803  // Check the corner case of 0-D vectors first.
6804  if (resultType.getRank() == 0) {
6805  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6806  return getMaskDimSizes()[0] == 1;
6807  }
6808  for (const auto [resultSize, maskDimSize] :
6809  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6810  if (maskDimSize < resultSize)
6811  return false;
6812  }
6813  return true;
6814 }
6815 
6816 OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
6817  ArrayRef<int64_t> bounds = getMaskDimSizes();
6818  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
6819 
6820  auto createBoolSplat = [&](bool x) {
6821  return SplatElementsAttr::get(getVectorType(),
6822  BoolAttr::get(getContext(), x));
6823  };
6824 
6825  // Check the corner case of 0-D vectors first.
6826  if (vectorSizes.empty()) {
6827  assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
6828  return createBoolSplat(bounds[0] == 1);
6829  }
6830  // Fold vector.constant_mask to splat if possible.
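 // For example, vector.constant_mask [2, 3] : vector<2x3xi1> folds to an
 // all-true splat, and vector.constant_mask [0, 0] : vector<2x3xi1> folds to
 // an all-false splat.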
6831  if (bounds == vectorSizes)
6832  return createBoolSplat(true);
6833  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
6834  return createBoolSplat(false);
6835  return OpFoldResult();
6836 }
6837 
6838 //===----------------------------------------------------------------------===//
6839 // CreateMaskOp
6840 //===----------------------------------------------------------------------===//
6841 
6842 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6843  VectorType type,
6844  ArrayRef<OpFoldResult> mixedOperands) {
6845  SmallVector<Value> operands =
6846  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6847  build(builder, result, type, operands);
6848 }
6849 
6850 LogicalResult CreateMaskOp::verify() {
6851  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6852  // Verify that an operand was specified for each result vector dimension.
6853  if (vectorType.getRank() == 0) {
6854  if (getNumOperands() != 1)
6855  return emitOpError(
6856  "must specify exactly one operand for 0-D create_mask");
6857  } else if (getNumOperands() !=
6858  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6859  return emitOpError(
6860  "must specify an operand for each result vector dimension");
6861  }
6862  return success();
6863 }
6864 
6865 namespace {
6866 
6867 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6868 ///
6869 /// Ex 1:
6870 /// %c2 = arith.constant 2 : index
6871 /// %c3 = arith.constant 3 : index
6872 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6873 /// Becomes:
6874 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6875 ///
6876 /// Ex 2:
6877 /// %c_neg_1 = arith.constant -1 : index
6878 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6879 /// becomes:
6880 /// vector.constant_mask [0] : vector<[8]xi1>
6881 ///
6882 /// Ex 3:
6883 /// %c8 = arith.constant 8 : index
6884 /// %c16 = arith.constant 16 : index
6885 /// %0 = vector.vscale
6886 /// %1 = arith.muli %0, %c16 : index
6887 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6888 /// becomes:
6889 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6890 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6891 public:
6893 
6894  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6895  PatternRewriter &rewriter) const override {
6896  VectorType maskType = createMaskOp.getVectorType();
6897  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6898  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6899 
6900  // Special case: Rank zero shape.
6901  constexpr std::array<int64_t, 1> rankZeroShape{1};
6902  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6903  if (maskType.getRank() == 0) {
6904  maskTypeDimSizes = rankZeroShape;
6905  maskTypeDimScalableFlags = rankZeroScalableDims;
6906  }
6907 
6908  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6909  // collect the `constantDims` (for the ConstantMaskOp).
6910  SmallVector<int64_t, 4> constantDims;
6911  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6912  if (auto intSize = getConstantIntValue(dimSize)) {
6913  // Constant value.
6914  // If the mask dim is non-scalable this can be any value.
6915  // If the mask dim is scalable only zero (all-false) is supported.
6916  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6917  return failure();
6918  constantDims.push_back(*intSize);
6919  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6920  // Constant vscale multiple (e.g. 4 x vscale).
6921  // Must be all-true to fold to a ConstantMask.
6922  if (vscaleMultiplier < maskTypeDimSizes[i])
6923  return failure();
6924  constantDims.push_back(*vscaleMultiplier);
6925  } else {
6926  return failure();
6927  }
6928  }
6929 
6930  // Clamp values to constant_mask bounds.
6931  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6932  value = std::clamp<int64_t>(value, 0, maskDimSize);
6933 
6934  // If one of dim sizes is zero, set all dims to zero.
6935  if (llvm::is_contained(constantDims, 0))
6936  constantDims.assign(constantDims.size(), 0);
6937 
6938  // Replace 'createMaskOp' with ConstantMaskOp.
6939  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6940  constantDims);
6941  return success();
6942  }
6943 };
6944 
6945 } // namespace
6946 
6947 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6948  MLIRContext *context) {
6949  results.add<CreateMaskFolder>(context);
6950 }
6951 
6952 //===----------------------------------------------------------------------===//
6953 // MaskOp
6954 //===----------------------------------------------------------------------===//
6955 
6956 void MaskOp::build(
6957  OpBuilder &builder, OperationState &result, Value mask,
6958  Operation *maskableOp,
6959  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6960  assert(maskRegionBuilder &&
6961  "builder callback for 'maskRegion' must be present");
6962 
6963  result.addOperands(mask);
6964  OpBuilder::InsertionGuard guard(builder);
6965  Region *maskRegion = result.addRegion();
6966  builder.createBlock(maskRegion);
6967  maskRegionBuilder(builder, maskableOp);
6968 }
6969 
6970 void MaskOp::build(
6971  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6972  Value mask, Operation *maskableOp,
6973  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6974  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6975  maskRegionBuilder);
6976 }
6977 
6978 void MaskOp::build(
6979  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6980  Value mask, Value passthru, Operation *maskableOp,
6981  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6982  build(builder, result, mask, maskableOp, maskRegionBuilder);
6983  if (passthru)
6984  result.addOperands(passthru);
6985  result.addTypes(resultTypes);
6986 }
6987 
6988 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6989  // Create the op region.
6990  result.regions.reserve(1);
6991  Region &maskRegion = *result.addRegion();
6992 
6993  auto &builder = parser.getBuilder();
6994 
6995  // Parse all the operands.
6996  OpAsmParser::UnresolvedOperand mask;
6997  if (parser.parseOperand(mask))
6998  return failure();
6999 
7000  // Optional passthru operand.
7001  OpAsmParser::UnresolvedOperand passthru;
7002  ParseResult parsePassthru = parser.parseOptionalComma();
7003  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7004  return failure();
7005 
7006  // Parse op region.
7007  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7008  return failure();
7009 
7010  MaskOp::ensureTerminator(maskRegion, builder, result.location);
7011 
7012  // Parse the optional attribute list.
7013  if (parser.parseOptionalAttrDict(result.attributes))
7014  return failure();
7015 
7016  // Parse all the types.
7017  Type maskType;
7018  if (parser.parseColonType(maskType))
7019  return failure();
7020 
7021  SmallVector<Type> resultTypes;
7022  if (parser.parseOptionalArrowTypeList(resultTypes))
7023  return failure();
7024  result.types.append(resultTypes);
7025 
7026  // Resolve operands.
7027  if (parser.resolveOperand(mask, maskType, result.operands))
7028  return failure();
7029 
7030  if (parsePassthru.succeeded()) {
7031  if (resultTypes.empty())
7032  return parser.emitError(
7033  parser.getNameLoc(),
7034  "expects a result if passthru operand is provided");
7035 
7036  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7037  return failure();
7038  }
7039 
7040  return success();
7041 }
7042 
7043 void MaskOp::print(OpAsmPrinter &p) {
7044  p << " " << getMask();
7045  if (getPassthru())
7046  p << ", " << getPassthru();
7047 
7048  // Print single masked operation and skip terminator.
7049  p << " { ";
7050  Block *singleBlock = &getMaskRegion().getBlocks().front();
7051  if (singleBlock && !singleBlock->getOperations().empty())
7052  p.printCustomOrGenericOp(&singleBlock->front());
7053  p << " }";
7054 
7055  p.printOptionalAttrDict(getOperation()->getAttrs());
7056 
7057  p << " : " << getMask().getType();
7058  if (getNumResults() > 0)
7059  p << " -> " << getResultTypes();
7060 }
7061 
7062 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7063  // 1. For an empty `vector.mask`, create a default terminator.
7064  if (region.empty() || region.front().empty()) {
7065  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7066  MaskOp>::ensureTerminator(region, builder, loc);
7067  return;
7068  }
7069 
7070  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7071  Block &block = region.front();
7072  if (isa<vector::YieldOp>(block.back()))
7073  return;
7074 
7075  // 3. For a non-empty `vector.mask` without an explicit terminator:
7076 
7077  // Create default terminator if the number of masked operations is not
7078  // one. This case will trigger a verification failure.
7079  if (block.getOperations().size() != 1) {
7080  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7081  MaskOp>::ensureTerminator(region, builder, loc);
7082  return;
7083  }
7084 
7085  // Create a terminator that yields the results from the masked operation.
7086  OpBuilder opBuilder(builder.getContext());
7087  Operation *maskedOp = &block.front();
7088  opBuilder.setInsertionPointToEnd(&block);
7089  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7090 }
7091 
7092 LogicalResult MaskOp::verify() {
7093  // Structural checks.
7094  Block &block = getMaskRegion().getBlocks().front();
7095  if (block.getOperations().empty())
7096  return emitOpError("expects a terminator within the mask region");
7097 
7098  unsigned numMaskRegionOps = block.getOperations().size();
7099  if (numMaskRegionOps > 2)
7100  return emitOpError("expects only one operation to mask");
7101 
7102  // Terminator checks.
7103  auto terminator = dyn_cast<vector::YieldOp>(block.back());
7104  if (!terminator)
7105  return emitOpError("expects a terminator within the mask region");
7106 
7107  if (terminator->getNumOperands() != getNumResults())
7108  return emitOpError(
7109  "expects number of results to match mask region yielded values");
7110 
7111  // Empty vector.mask. Nothing else to check.
7112  if (numMaskRegionOps == 1)
7113  return success();
7114 
7115  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7116  if (!maskableOp)
7117  return emitOpError("expects a MaskableOpInterface within the mask region");
7118 
7119  // Result checks.
7120  if (maskableOp->getNumResults() != getNumResults())
7121  return emitOpError("expects number of results to match maskable operation "
7122  "number of results");
7123 
7124  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7125  return emitOpError("expects all the results from the MaskableOpInterface "
7126  "to match all the values returned by the terminator");
7127 
7128  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7129  return emitOpError(
7130  "expects result type to match maskable operation result type");
7131 
7132  if (llvm::count_if(maskableOp->getResultTypes(),
7133  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7134  return emitOpError("multiple vector results not supported");
7135 
7136  // Mask checks.
7137  Type expectedMaskType = maskableOp.getExpectedMaskType();
7138  if (getMask().getType() != expectedMaskType)
7139  return emitOpError("expects a ")
7140  << expectedMaskType << " mask for the maskable operation";
7141 
7142  // Passthru checks.
7143  Value passthru = getPassthru();
7144  if (passthru) {
7145  if (!maskableOp.supportsPassthru())
7146  return emitOpError(
7147  "doesn't expect a passthru argument for this maskable operation");
7148 
7149  if (maskableOp->getNumResults() != 1)
7150  return emitOpError("expects result when passthru argument is provided");
7151 
7152  if (passthru.getType() != maskableOp->getResultTypes()[0])
7153  return emitOpError("expects passthru type to match result type");
7154  }
7155 
7156  return success();
7157 }
7158 
7159 /// Folds empty `vector.mask` with no passthru operand and with or without
7160 /// return values. For example:
7161 ///
7162 /// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7163 /// vector<8xi1> -> vector<8xf32>
7164 /// %1 = user_op %0 : vector<8xf32>
7165 ///
7166 /// becomes:
7167 ///
7168 /// %0 = user_op %a : vector<8xf32>
7169 ///
7170 /// Empty `vector.mask` ops with a passthru operand are handled by the
7171 /// canonicalizer below, as they require creating new operations.
7172 
7173 static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7174  SmallVectorImpl<OpFoldResult> &results) {
7175  if (!maskOp.isEmpty() || maskOp.hasPassthru())
7176  return failure();
7177 
7178  Block *block = maskOp.getMaskBlock();
7179  auto terminator = cast<vector::YieldOp>(block->front());
7180  if (terminator.getNumOperands() == 0) {
7181  // `vector.mask` has no results, just remove the `vector.mask`.
7182  return success();
7183  }
7184 
7185  // `vector.mask` has results, propagate the results.
7186  llvm::append_range(results, terminator.getOperands());
7187  return success();
7188 }
7189 
7190 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7191  SmallVectorImpl<OpFoldResult> &results) {
7192  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7193  return success();
7194 
7195  MaskFormat maskFormat = getMaskFormat(getMask());
7196  if (maskFormat != MaskFormat::AllTrue)
7197  return failure();
7198 
7199  // Move maskable operation outside of the `vector.mask` region.
7200  Operation *maskableOp = getMaskableOp();
7201  maskableOp->dropAllUses();
7202  maskableOp->moveBefore(getOperation());
7203 
7204  llvm::append_range(results, maskableOp->getResults());
7205  return success();
7206 }
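// Illustrative sketch of the all-true fold above (not part of this file):
// assuming %alltrue is a constant mask that is all true,
//
//   %0 = vector.mask %alltrue {
//     vector.reduction <add>, %a : vector<8xf32> into f32
//   } : vector<8xi1> -> f32
//
// folds to the unmasked operation
//
//   %0 = vector.reduction <add>, %a : vector<8xf32> into f32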
7207 
7208 /// Canonicalize empty `vector.mask` operations that can't be handled in
7209 /// `MaskOp::fold` as they require creating new operations.
7210 ///
7211 /// Example 1: Empty `vector.mask` with passthru operand.
7212 ///
7213 /// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7214 /// vector<8xi1> -> vector<8xf32>
7215 ///
7216 /// becomes:
7217 ///
7218 /// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7219 ///
7220 class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7222 
7223  LogicalResult matchAndRewrite(MaskOp maskOp,
7224  PatternRewriter &rewriter) const override {
7225  if (!maskOp.isEmpty())
7226  return failure();
7227 
7228  if (!maskOp.hasPassthru())
7229  return failure();
7230 
7231  Block *block = maskOp.getMaskBlock();
7232  auto terminator = cast<vector::YieldOp>(block->front());
7233  assert(terminator.getNumOperands() == 1 &&
7234  "expected one result when passthru is provided");
7235 
7236  rewriter.replaceOpWithNewOp<arith::SelectOp>(
7237  maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7238  terminator.getOperand(0), maskOp.getPassthru());
7239 
7240  return success();
7241  }
7242 };
7243 
7244 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7245  MLIRContext *context) {
7246  results.add<CanonializeEmptyMaskOp>(context);
7247 }
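// Minimal usage sketch (not part of this file): collect the canonicalization
// pattern registered above and run it with the greedy driver. Assumes the
// caller has an Operation *root and includes
// mlir/Transforms/GreedyPatternRewriteDriver.h for applyPatternsGreedily.
//
//   RewritePatternSet patterns(root->getContext());
//   MaskOp::getCanonicalizationPatterns(patterns, root->getContext());
//   (void)applyPatternsGreedily(root, std::move(patterns));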
7248 
7249 // MaskingOpInterface definitions.
7250 
7251 /// Returns the operation masked by this 'vector.mask'.
7252 Operation *MaskOp::getMaskableOp() {
7253  Block *block = getMaskBlock();
7254  if (block->getOperations().size() < 2)
7255  return nullptr;
7256 
7257  return &block->front();
7258 }
7259 
7260 /// Returns true if 'vector.mask' has a passthru value.
7261 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7262 
7263 //===----------------------------------------------------------------------===//
7264 // ScanOp
7265 //===----------------------------------------------------------------------===//
7266 
7267 LogicalResult ScanOp::verify() {
7268  VectorType srcType = getSourceType();
7269  VectorType initialType = getInitialValueType();
7270  // Check reduction dimension < rank.
7271  int64_t srcRank = srcType.getRank();
7272  int64_t reductionDim = getReductionDim();
7273  if (reductionDim >= srcRank)
7274  return emitOpError("reduction dimension ")
7275  << reductionDim << " has to be less than " << srcRank;
7276 
7277  // Check that rank(initial_value) = rank(src) - 1.
7278  int64_t initialValueRank = initialType.getRank();
7279  if (initialValueRank != srcRank - 1)
7280  return emitOpError("initial value rank ")
7281  << initialValueRank << " has to be equal to " << srcRank - 1;
7282 
7283  // Check shapes of initial value and src.
7284  ArrayRef<int64_t> srcShape = srcType.getShape();
7285  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7286  SmallVector<int64_t> expectedShape;
7287  for (int i = 0; i < srcRank; i++) {
7288  if (i != reductionDim)
7289  expectedShape.push_back(srcShape[i]);
7290  }
7291  if (!llvm::equal(initialValueShapes, expectedShape)) {
7292  return emitOpError("incompatible input/initial value shapes");
7293  }
7294 
7295  // Verify supported reduction kind.
7296  Type eltType = getDestType().getElementType();
7297  if (!isSupportedCombiningKind(getKind(), eltType))
7298  return emitOpError("unsupported reduction type ")
7299  << eltType << " for kind '" << stringifyCombiningKind(getKind())
7300  << "'";
7301 
7302  return success();
7303 }
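// Illustrative example (not part of this file) of a vector.scan that passes
// the checks above: reduction_dim (1) is less than the source rank (2), and
// the initial value drops exactly the reduction dimension:
//
//   %res:2 = vector.scan <add>, %src, %init
//     {inclusive = true, reduction_dim = 1 : i64}
//     : vector<4x8xf32>, vector<4xf32>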
7304 
7305 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7306  RewritePatternSet &patterns, PatternBenefit benefit) {
7307  patterns
7308  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7309  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7310  StridedSliceConstantMaskFolder, TransposeFolder>(
7311  patterns.getContext(), benefit);
7312 }
7313 
7314 //===----------------------------------------------------------------------===//
7315 // SplatOp
7316 //===----------------------------------------------------------------------===//
7317 
7318 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
7319  auto constOperand = adaptor.getInput();
7320  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
7321  return {};
7322 
7323  // SplatElementsAttr::get treats a single value for the second arg as a splat.
7324  return SplatElementsAttr::get(getType(), {constOperand});
7325 }
7326 
7327 // Canonicalizer for vector.splat: a vector.splat is always rewritten to an
7328 // equivalent vector.broadcast.
7329 class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
7330 public:
7332  LogicalResult matchAndRewrite(SplatOp splatOp,
7333  PatternRewriter &rewriter) const override {
7334  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
7335  splatOp.getOperand());
7336  return success();
7337  }
7338 };
7339 void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
7340  MLIRContext *context) {
7341  results.add<SplatToBroadcastPattern>(context);
7342 }
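// Illustrative sketch of the canonicalization above (not part of this file):
//
//   %v = vector.splat %s : vector<4xf32>
//
// is rewritten to the equivalent
//
//   %v = vector.broadcast %s : f32 to vector<4xf32>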
7343 
7344 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7345  SetIntRangeFn setResultRanges) {
7346  setResultRanges(getResult(), argRanges.front());
7347 }
7348 
7349 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7350  CombiningKind kind, Value v1, Value acc,
7351  arith::FastMathFlagsAttr fastmath,
7352  Value mask) {
7353  Type t1 = getElementTypeOrSelf(v1.getType());
7354  Type tAcc = getElementTypeOrSelf(acc.getType());
7355  Value result;
7356 
7357  switch (kind) {
7358  case CombiningKind::ADD:
7359  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7360  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7361  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7362  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7363  else
7364  llvm_unreachable("invalid value types for ADD reduction");
7365  break;
7366  case CombiningKind::AND:
7367  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7368  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7369  break;
7370  case CombiningKind::MAXNUMF:
7371  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7372  "expected float values");
7373  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7374  break;
7375  case CombiningKind::MAXIMUMF:
7376  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7377  "expected float values");
7378  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7379  break;
7380  case CombiningKind::MINNUMF:
7381  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7382  "expected float values");
7383  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7384  break;
7385  case CombiningKind::MINIMUMF:
7386  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7387  "expected float values");
7388  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7389  break;
7390  case CombiningKind::MAXSI:
7391  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7392  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7393  break;
7394  case CombiningKind::MINSI:
7395  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7396  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7397  break;
7398  case CombiningKind::MAXUI:
7399  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7400  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7401  break;
7402  case CombiningKind::MINUI:
7403  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7404  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7405  break;
7406  case CombiningKind::MUL:
7407  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7408  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7409  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7410  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7411  else
7412  llvm_unreachable("invalid value types for MUL reduction");
7413  break;
7414  case CombiningKind::OR:
7415  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7416  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7417  break;
7418  case CombiningKind::XOR:
7419  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7420  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7421  break;
7422  };
7423 
7424  assert(result && "unknown CombiningKind");
7425  return selectPassthru(b, mask, result, acc);
7426 }
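// Minimal usage sketch (not part of this file): build a masked floating-point
// ADD combining `lhs` into the accumulator `acc`. `b`, `loc`, `lhs`, `acc`
// and `mask` are assumed to exist in the caller.
//
//   Value red = makeArithReduction(b, loc, CombiningKind::ADD, lhs, acc,
//                                  /*fastmath=*/nullptr, mask);
//
// When a mask is given, the result is wrapped in an arith.select that keeps
// `acc` for masked-off lanes (see selectPassthru below).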
7427 
7428 //===----------------------------------------------------------------------===//
7429 // StepOp
7430 //===----------------------------------------------------------------------===//
7431 
7432 void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7433  SetIntRangeFn setResultRanges) {
7434  auto resultType = cast<VectorType>(getType());
7435  if (resultType.isScalable()) {
7436  return;
7437  }
7438  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7439  APInt zero(bitwidth, 0);
7440  APInt high(bitwidth, resultType.getDimSize(0) - 1);
7441  ConstantIntRanges result = {zero, high, zero, high};
7442  setResultRanges(getResult(), result);
7443 }
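// For example, a fixed-length `%0 = vector.step : vector<8xindex>` is inferred
// to produce values in [0, 7] (for both the signed and unsigned ranges);
// scalable vectors are skipped above because their upper bound is unknown at
// compile time.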
7444 
7445 //===----------------------------------------------------------------------===//
7446 // Vector Masking Utilities
7447 //===----------------------------------------------------------------------===//
7448 
7449 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7450 /// as the masked operation.
7451 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7452  Operation *maskableOp) {
7453  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7454  Block *insBlock = builder.getInsertionBlock();
7455  // Create a block and move the op to that block.
7456  insBlock->getOperations().splice(
7457  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7458  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7459 }
7460 
7461 /// Creates a vector.mask operation around a maskable operation. Returns the
7462 /// vector.mask operation if a (non-null) mask is provided. Otherwise, returns
7463 /// the maskable operation itself.
7464 Operation *mlir::vector::maskOperation(OpBuilder &builder,
7465  Operation *maskableOp, Value mask,
7466  Value passthru) {
7467  if (!mask)
7468  return maskableOp;
7469  if (passthru)
7470  return MaskOp::create(builder, maskableOp->getLoc(),
7471  maskableOp->getResultTypes(), mask, passthru,
7472  maskableOp, createMaskOpRegion);
7473  return MaskOp::create(builder, maskableOp->getLoc(),
7474  maskableOp->getResultTypes(), mask, maskableOp,
7475  createMaskOpRegion);
7476 }
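// Minimal usage sketch (not part of this file): wrap an existing maskable
// operation, e.g. a vector.transfer_read `readOp`, in a vector.mask.
// `builder`, `mask` and `passthru` are assumed to be available; passing a
// null `mask` returns the original operation unchanged.
//
//   Operation *masked =
//       maskOperation(builder, readOp.getOperation(), mask, passthru);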
7477 
7478 /// Creates a vector select operation that picks values from `newValue` or
7479 /// `passthru` for each result vector lane based on `mask`. This utility is used
7480 /// to propagate the pass-thru value of vector.mask or for cases where only the
7481 /// pass-thru value propagation is needed. VP intrinsics do not support
7482 /// pass-thru values and every masked-off lane is set to poison. LLVM backends
7483 /// are usually able to match op + select patterns and fold them into native
7484 /// target instructions.
7485 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7486  Value newValue, Value passthru) {
7487  if (!mask)
7488  return newValue;
7489 
7490  return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7491  mask, newValue, passthru);
7492 }
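// Illustrative sketch of the IR produced when a passthru is propagated
// (not part of this file):
//
//   %sel = arith.select %mask, %new, %passthru : vector<8xi1>, vector<8xf32>
//
// which backends can typically fold back into a native masked instruction.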
7493 
7494 //===----------------------------------------------------------------------===//
7495 // TableGen'd op method definitions
7496 //===----------------------------------------------------------------------===//
7497 
7498 #define GET_ATTRDEF_CLASSES
7499 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7500 
7501 #define GET_OP_CLASSES
7502 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"