1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/IR/ValueRange.h"
39 #include "mlir/Support/LLVM.h"
41 #include "llvm/ADT/ArrayRef.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/StringSet.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Casting.h"
47 
48 #include <cassert>
49 #include <cstdint>
50 
51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 // Pull in all enum type and utility function definitions.
53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
54 
55 using namespace mlir;
56 using namespace mlir::vector;
57 
58 /// Helper enum to classify mask value.
59 enum class MaskFormat {
60  AllTrue = 0,
61  AllFalse = 1,
62  Unknown = 2,
63 };
64 
65 /// Helper method to classify a mask value. Currently, the method
66 /// looks "under the hood" of a constant value with dense attributes
67 /// and a constant mask operation (since the client may be called at
68 /// various stages during progressive lowering).
69 static MaskFormat getMaskFormat(Value mask) {
70  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
71  // Inspect constant dense values. We count up for bits that
72  // are set, count down for bits that are cleared, and bail
73  // when a mix is detected.
74  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75  int64_t val = 0;
76  for (bool b : denseElts.getValues<bool>())
77  if (b && val >= 0)
78  val++;
79  else if (!b && val <= 0)
80  val--;
81  else
82  return MaskFormat::Unknown;
83  if (val > 0)
84  return MaskFormat::AllTrue;
85  if (val < 0)
86  return MaskFormat::AllFalse;
87  }
88  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
89  // Inspect constant mask index. If the index exceeds the
90  // dimension size, all bits are set. If the index is zero
91  // or less, no bits are set.
92  ArrayRef<int64_t> masks = m.getMaskDimSizes();
93  auto shape = m.getType().getShape();
94  bool allTrue = true;
95  bool allFalse = true;
96  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
97  if (maskIdx < dimSize)
98  allTrue = false;
99  if (maskIdx > 0)
100  allFalse = false;
101  }
102  if (allTrue)
103  return MaskFormat::AllTrue;
104  if (allFalse)
105  return MaskFormat::AllFalse;
106  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107  // Finds all-false create_masks. An all-true create_mask requires all
108  // dims to be constants, so that'll be folded to a constant_mask, then
109  // detected in the constant_mask case.
110  auto maskOperands = m.getOperands();
111  for (Value operand : maskOperands) {
112  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113  int64_t dimSize =
114  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
115  if (dimSize <= 0)
116  return MaskFormat::AllFalse;
117  }
118  }
119  return MaskFormat::Unknown;
120  }
121  return MaskFormat::Unknown;
122 }
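
// For illustration (the values shown are hypothetical):
//   vector.constant_mask [4] : vector<4xi1>   -> MaskFormat::AllTrue
//   vector.constant_mask [0] : vector<4xi1>   -> MaskFormat::AllFalse
//   arith.constant dense<[true, false, true, false]> : vector<4xi1>
//                                             -> MaskFormat::Unknown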
123 
124 /// Default callback to build a region with a 'vector.yield' terminator with no
125 /// arguments.
126 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127  vector::YieldOp::create(builder, loc);
128 }
129 
130 // Helper for verifying combining kinds in contractions and reductions.
131 static bool isSupportedCombiningKind(CombiningKind combiningKind,
132  Type elementType) {
133  switch (combiningKind) {
134  case CombiningKind::ADD:
135  case CombiningKind::MUL:
136  return elementType.isIntOrIndexOrFloat();
137  case CombiningKind::MINUI:
138  case CombiningKind::MINSI:
139  case CombiningKind::MAXUI:
140  case CombiningKind::MAXSI:
141  case CombiningKind::AND:
142  case CombiningKind::OR:
143  case CombiningKind::XOR:
144  return elementType.isIntOrIndex();
145  case CombiningKind::MINNUMF:
146  case CombiningKind::MAXNUMF:
147  case CombiningKind::MINIMUMF:
148  case CombiningKind::MAXIMUMF:
149  return llvm::isa<FloatType>(elementType);
150  }
151  return false;
152 }
153 
154 /// Returns the effective rank of the vector to read/write for Xfer Ops
155 ///
156 /// When the element type of the shaped type is _a scalar_, this simply
157 /// returns the rank of the vector (the result for xfer_read or the value to
158 /// store for xfer_write).
159 ///
160 /// When the element type of the shaped type is itself _a vector_, returns the
161 /// difference between the rank of the vector and the rank of the shaped
162 /// type's element vector.
163 ///
164 /// EXAMPLE 1 (element type is _a scalar_):
165 /// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
166 /// - shapedType.getElementType() = f32 (rank 0)
167 /// - vectorType.getRank() = 2
168 /// - Result = 2 - 0 = 2
169 ///
170 /// EXAMPLE 2 (element type is _a vector_):
171 /// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172 /// - shapedType.getElementType() = vector<20xf32> (rank 1)
173 /// - vectorType.getRank() = 1
174 /// - Result = 1 - 1 = 0
175 ///
176 /// This is used to determine the number of minor dimensions for identity maps
177 /// in vector transfer Ops.
178 static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
179  VectorType vectorType) {
180  unsigned elementVectorRank = 0;
181  VectorType elementVectorType =
182  llvm::dyn_cast<VectorType>(shapedType.getElementType());
183  if (elementVectorType)
184  elementVectorRank += elementVectorType.getRank();
185  return vectorType.getRank() - elementVectorRank;
186 }
187 
188 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
189  VectorType vectorType) {
190  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
191  // TODO: replace once we have 0-d vectors.
192  if (shapedType.getRank() == 0 &&
193  vectorType.getShape() == ArrayRef<int64_t>{1})
194  return AffineMap::get(
195  /*numDims=*/0, /*numSymbols=*/0,
196  getAffineConstantExpr(0, shapedType.getContext()));
197  return AffineMap::getMinorIdentityMap(
198  shapedType.getRank(),
199  getEffectiveVectorRankForXferOp(shapedType, vectorType),
200  shapedType.getContext());
201 }
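
// For example, with shapedType = memref<10x20x30xf32> and
// vectorType = vector<20x30xf32>, the effective vector rank is 2, so the
// resulting minor identity map is (d0, d1, d2) -> (d1, d2).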
202 
203 /// Check if `write` is of a constant splat and the masked `read` is padded with
204 /// the same splat value -- meaning it could be the same value as the initial
205 /// constant splat.
206 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
207  vector::TransferReadOp read) {
208  auto readMask = read.getMask();
209  auto writeMask = write.getMask();
210  // Check if the masks are consistent. The splat value could be the same if the
211  // read is masked (and padded with the splat value), and the write is unmasked
212  // or has the same mask. Note this does not allow the case where the write is
213  // masked and the read is unmasked, as then the read could be of more elements
214  // than the write (which may not be the same value).
215  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
216  if (!couldBeSameSplat)
217  return false;
218  // Check for constant splat (as the source of the write).
219  DenseElementsAttr splatAttr;
220  if (!matchPattern(write.getVector(),
221  m_Constant<DenseElementsAttr>(&splatAttr)) ||
222  !splatAttr.isSplat()) {
223  return false;
224  }
225  // The padding of the read and the constant splat value must be the same.
226  Attribute padAttr;
227  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
228  return false;
229  return padAttr == splatAttr.getSplatValue<Attribute>();
230 }
231 
232 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
233  vector::TransferReadOp read) {
234  return !defWrite.hasOutOfBoundsDim() &&
235  defWrite.getIndices() == read.getIndices() &&
236  defWrite.getVectorType() == read.getVectorType() &&
237  defWrite.getPermutationMap() == read.getPermutationMap() &&
238  ((!defWrite.getMask() && !read.getMask()) ||
239  isSplatWriteConsistentWithMaskedRead(defWrite, read));
240 }
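
// A minimal sketch of the unmasked RAW case this matches (names are
// illustrative); the read can then be forwarded from the written vector:
//   %t1 = vector.transfer_write %v, %t[%i, %j] {in_bounds = [true, true]}
//       : vector<4x8xf32>, tensor<?x?xf32>
//   %r  = vector.transfer_read %t1[%i, %j], %pad {in_bounds = [true, true]}
//       : tensor<?x?xf32>, vector<4x8xf32>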
241 
242 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
243  vector::TransferWriteOp priorWrite) {
244  return priorWrite.getIndices() == write.getIndices() &&
245  priorWrite.getMask() == write.getMask() &&
246  priorWrite.getVectorType() == write.getVectorType() &&
247  priorWrite.getPermutationMap() == write.getPermutationMap();
248 }
249 
250 bool mlir::vector::isDisjointTransferIndices(
251  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
252  bool testDynamicValueUsingBounds) {
253  // For simplicity only look at transfer of same type.
254  if (transferA.getVectorType() != transferB.getVectorType())
255  return false;
256  unsigned rankOffset = transferA.getLeadingShapedRank();
257  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
258  Value indexA = transferA.getIndices()[i];
259  Value indexB = transferB.getIndices()[i];
260  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
261  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
262 
263  if (i < rankOffset) {
264  // For leading dimensions, if we can prove that index are different we
265  // know we are accessing disjoint slices.
266  if (cstIndexA.has_value() && cstIndexB.has_value()) {
267  if (*cstIndexA != *cstIndexB)
268  return true;
269  continue;
270  }
271  if (testDynamicValueUsingBounds) {
272  // First try to see if we can fully compose and simplify the affine
273  // expression as a fast track.
274  FailureOr<uint64_t> delta =
275  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
276  if (succeeded(delta) && *delta != 0)
277  return true;
278 
279  FailureOr<bool> testEqual =
280  ValueBoundsConstraintSet::areEqual(indexA, indexB);
281  if (succeeded(testEqual) && !testEqual.value())
282  return true;
283  }
284  } else {
285  // For this dimension, we slice a part of the memref we need to make sure
286  // the intervals accessed don't overlap.
287  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
288  if (cstIndexA.has_value() && cstIndexB.has_value()) {
289  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
290  if (distance >= vectorDim)
291  return true;
292  continue;
293  }
294  if (testDynamicValueUsingBounds) {
295  // First try to see if we can fully compose and simplify the affine
296  // expression as a fast track.
297  FailureOr<int64_t> delta =
298  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
299  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
300  return true;
301 
302  FailureOr<int64_t> computeDelta =
303  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
304  if (succeeded(computeDelta)) {
305  if (std::abs(computeDelta.value()) >= vectorDim)
306  return true;
307  }
308  }
309  }
310  }
311  return false;
312 }
313 
314 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
315  VectorTransferOpInterface transferB,
316  bool testDynamicValueUsingBounds) {
317  if (transferA.getBase() != transferB.getBase())
318  return false;
319  return isDisjointTransferIndices(transferA, transferB,
320  testDynamicValueUsingBounds);
321 }
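
// For example, two transfers of vector<4xf32> on the same memref whose
// (trailing-dimension) indices are the constants 0 and 4 are disjoint:
// |0 - 4| >= 4, so the accessed intervals [0, 4) and [4, 8) do not overlap.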
322 
323 // Helper to iterate over n-D vector slice elements. Calculate the next
324 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
325 // Modifies the `position` in place. Returns a failure when `position` becomes
326 // the end position.
327 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
328  ArrayRef<int64_t> shape,
329  ArrayRef<int64_t> offsets) {
330  for (auto [posInDim, dimSize, offsetInDim] :
331  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
332  ++posInDim;
333  if (posInDim < dimSize + offsetInDim)
334  return success();
335 
336  // Carry the overflow to the next loop iteration.
337  posInDim = offsetInDim;
338  }
339 
340  return failure();
341 }
342 
343 /// Returns the integer numbers in `values`. `values` are expected to be
344 /// constant operations.
345 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
346  SmallVector<int64_t> ints;
347  llvm::transform(values, std::back_inserter(ints), [](Value value) {
348  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
349  assert(constOp && "Unexpected non-constant index");
350  return constOp.value();
351  });
352  return ints;
353 }
354 
355 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
356 /// be constant operations.
357 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
358  SmallVector<int64_t> ints;
359  llvm::transform(
360  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
361  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
362  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
363  });
364  return ints;
365 }
366 
367 /// Convert `foldResults` into Values. Integer attributes are converted to
368 /// constant ops.
369 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
370  ArrayRef<OpFoldResult> foldResults) {
371  SmallVector<Value> values;
372  llvm::transform(foldResults, std::back_inserter(values),
373  [&](OpFoldResult foldResult) {
374  if (auto attr = dyn_cast<Attribute>(foldResult))
375  return arith::ConstantIndexOp::create(
376  builder, loc, cast<IntegerAttr>(attr).getInt())
377  .getResult();
378 
379  return cast<Value>(foldResult);
380  });
381  return values;
382 }
383 
384 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
385  if (value.getDefiningOp<vector::VectorScaleOp>())
386  return 1;
387  auto mul = value.getDefiningOp<arith::MulIOp>();
388  if (!mul)
389  return {};
390  auto lhs = mul.getLhs();
391  auto rhs = mul.getRhs();
392  if (lhs.getDefiningOp<vector::VectorScaleOp>())
393  return getConstantIntValue(rhs);
394  if (rhs.getDefiningOp<vector::VectorScaleOp>())
395  return getConstantIntValue(lhs);
396  return {};
397 }
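
// Illustrative sketch of what this matches:
//   %vs = vector.vscale                    -> returns 1
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index      -> returns 4
// Any other producer yields std::nullopt.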
398 
399 /// Converts an IntegerAttr to have the specified type if needed.
400 /// This handles cases where integer constant attributes have a different type
401 /// than the target element type.
402 static IntegerAttr convertIntegerAttr(IntegerAttr intAttr, Type expectedType) {
403  if (intAttr.getType() == expectedType)
404  return intAttr; // Already correct type
405 
406  return IntegerAttr::get(expectedType, intAttr.getInt());
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // CombiningKindAttr
411 //===----------------------------------------------------------------------===//
412 
413 namespace mlir {
414 namespace vector {
415 namespace detail {
416 struct BitmaskEnumStorage : public AttributeStorage {
417  using KeyTy = uint64_t;
418 
419  BitmaskEnumStorage(KeyTy val) : value(val) {}
420 
421  bool operator==(const KeyTy &key) const { return value == key; }
422 
423  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
424  const KeyTy &key) {
425  return new (allocator.allocate<BitmaskEnumStorage>())
426  BitmaskEnumStorage(key);
427  }
428 
429  KeyTy value = 0;
430 };
431 } // namespace detail
432 } // namespace vector
433 } // namespace mlir
434 
435 //===----------------------------------------------------------------------===//
436 // VectorDialect
437 //===----------------------------------------------------------------------===//
438 
439 namespace {
440 /// This class defines the interface for handling inlining with vector dialect
441 /// operations.
442 struct VectorInlinerInterface : public DialectInlinerInterface {
443  using DialectInlinerInterface::DialectInlinerInterface;
444 
445  /// All vector dialect ops can be inlined.
446  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
447  return true;
448  }
449 };
450 } // namespace
451 
452 void VectorDialect::initialize() {
453  addAttributes<
454 #define GET_ATTRDEF_LIST
455 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
456  >();
457 
458  addOperations<
459 #define GET_OP_LIST
460 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
461  >();
462 
463  addInterfaces<VectorInlinerInterface>();
464 
465  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
466  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
467  YieldOp>();
468  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
469  TransferWriteOp>();
470  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
471  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
472  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
473 }
474 
475 /// Materialize a single constant operation from a given attribute value with
476 /// the desired resultant type.
477 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
478  Attribute value, Type type,
479  Location loc) {
480  if (isa<ub::PoisonAttrInterface>(value))
481  return value.getDialect().materializeConstant(builder, value, type, loc);
482 
483  return arith::ConstantOp::materialize(builder, value, type, loc);
484 }
485 
486 IntegerType vector::getVectorSubscriptType(Builder &builder) {
487  return builder.getIntegerType(64);
488 }
489 
490 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
491  ArrayRef<int64_t> values) {
492  return builder.getI64ArrayAttr(values);
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // MultiDimReductionOp
497 //===----------------------------------------------------------------------===//
498 
499 void vector::MultiDimReductionOp::build(OpBuilder &builder,
500  OperationState &result, Value source,
501  Value acc, ArrayRef<bool> reductionMask,
502  CombiningKind kind) {
503  SmallVector<int64_t> reductionDims;
504  for (const auto &en : llvm::enumerate(reductionMask))
505  if (en.value())
506  reductionDims.push_back(en.index());
507  build(builder, result, kind, source, acc, reductionDims);
508 }
509 
510 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
511  // Single parallel dim, this is a noop.
512  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
513  return getSource();
514  return {};
515 }
516 
517 std::optional<SmallVector<int64_t, 4>>
518 MultiDimReductionOp::getShapeForUnroll() {
519  return llvm::to_vector<4>(getSourceVectorType().getShape());
520 }
521 
522 LogicalResult MultiDimReductionOp::verify() {
523  SmallVector<int64_t> targetShape;
524  SmallVector<bool> scalableDims;
525  Type inferredReturnType;
526  auto sourceScalableDims = getSourceVectorType().getScalableDims();
527  for (auto [dimIdx, dimSize] :
528  llvm::enumerate(getSourceVectorType().getShape()))
529  if (!llvm::any_of(getReductionDims(),
530  [dimIdx = dimIdx](int64_t reductionDimIdx) {
531  return reductionDimIdx == static_cast<int64_t>(dimIdx);
532  })) {
533  targetShape.push_back(dimSize);
534  scalableDims.push_back(sourceScalableDims[dimIdx]);
535  }
536  // TODO: update to also allow 0-d vectors when available.
537  if (targetShape.empty())
538  inferredReturnType = getSourceVectorType().getElementType();
539  else
540  inferredReturnType = VectorType::get(
541  targetShape, getSourceVectorType().getElementType(), scalableDims);
542  if (getType() != inferredReturnType)
543  return emitOpError() << "destination type " << getType()
544  << " is incompatible with source type "
545  << getSourceVectorType();
546 
547  return success();
548 }
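
// For example, reducing dim 1 of vector<4x8x16xf32> must produce
// vector<4x16xf32>, while reducing all dims must produce the scalar element
// type f32 (0-d vector results are not yet inferred here; see the TODO above).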
549 
550 /// Returns the mask type expected by this operation.
551 Type MultiDimReductionOp::getExpectedMaskType() {
552  auto vecType = getSourceVectorType();
553  return VectorType::get(vecType.getShape(),
554  IntegerType::get(vecType.getContext(), /*width=*/1),
555  vecType.getScalableDims());
556 }
557 
558 namespace {
559 // Only unit dimensions that are being reduced are folded. If the dimension is
560 // unit, but not reduced, it is not folded, thereby keeping the output type the
561 // same. If not all dimensions which are reduced are of unit dimension, this
562 // transformation does nothing. This is just a generalization of
563 // ElideSingleElementReduction for ReduceOp.
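//
// An illustrative rewrite (operand names are hypothetical):
//   %r = vector.multi_reduction <add>, %src, %acc [0]
//       : vector<1x8xf32> to vector<8xf32>
// becomes a vector.shape_cast of %src to vector<8xf32> combined with %acc
// via the ADD arith reduction.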
564 struct ElideUnitDimsInMultiDimReduction
565  : public OpRewritePattern<MultiDimReductionOp> {
566  using OpRewritePattern::OpRewritePattern;
567 
568  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
569  PatternRewriter &rewriter) const override {
570  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
571  for (const auto &dim : enumerate(shape)) {
572  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
573  return failure();
574  }
575 
576  // Vector mask setup.
577  OpBuilder::InsertionGuard guard(rewriter);
578  Operation *rootOp;
579  Value mask;
580  if (reductionOp.isMasked()) {
581  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
582  rootOp = reductionOp.getMaskingOp();
583  mask = reductionOp.getMaskingOp().getMask();
584  } else {
585  rootOp = reductionOp;
586  }
587 
588  Location loc = reductionOp.getLoc();
589  Value acc = reductionOp.getAcc();
590  Value cast;
591  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
592  if (mask) {
593  VectorType newMaskType =
594  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
595  dstVecType.getScalableDims());
596  mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
597  }
598  cast = vector::ShapeCastOp::create(
599  rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
600  } else {
601  // This means we are reducing all the dimensions, and all reduction
602  // dimensions are of size 1. So a simple extraction would do.
603  if (mask)
604  mask = vector::ExtractOp::create(rewriter, loc, mask);
605  cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
606  }
607 
608  Value result =
609  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
610  cast, /*fastmath=*/nullptr, mask);
611  rewriter.replaceOp(rootOp, result);
612  return success();
613  }
614 };
615 } // namespace
616 
617 void MultiDimReductionOp::getCanonicalizationPatterns(
618  RewritePatternSet &results, MLIRContext *context) {
619  results.add<ElideUnitDimsInMultiDimReduction>(context);
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // ReductionOp
624 //===----------------------------------------------------------------------===//
625 
626 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
627  CombiningKind kind, Value vector,
628  arith::FastMathFlags fastMathFlags) {
629  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
630 }
631 
632 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
633  CombiningKind kind, Value vector, Value acc,
634  arith::FastMathFlags fastMathFlags) {
635  build(builder, result,
636  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
637  acc, fastMathFlags);
638 }
639 
640 LogicalResult ReductionOp::verify() {
641  // Verify for 0-D and 1-D vector.
642  int64_t rank = getSourceVectorType().getRank();
643  if (rank > 1)
644  return emitOpError("unsupported reduction rank: ") << rank;
645 
646  // Verify supported reduction kind.
647  Type eltType = getDest().getType();
648  if (!isSupportedCombiningKind(getKind(), eltType))
649  return emitOpError("unsupported reduction type '")
650  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
651  << "'";
652 
653  return success();
654 }
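
// For example, `vector.reduction <add>, %v : vector<8xf32> into f32` is
// accepted, whereas kind <and> on the same operand is rejected because AND
// requires an integer or index element type.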
655 
656 // MaskableOpInterface methods.
657 
658 /// Returns the mask type expected by this operation.
659 Type ReductionOp::getExpectedMaskType() {
660  auto vecType = getSourceVectorType();
661  return VectorType::get(vecType.getShape(),
662  IntegerType::get(vecType.getContext(), /*width=*/1),
663  vecType.getScalableDims());
664 }
665 
666 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
667  OpBuilder &builder, Location loc,
668  Value vector) {
669  switch (op) {
670  case arith::AtomicRMWKind::addf:
671  case arith::AtomicRMWKind::addi:
672  return vector::ReductionOp::create(builder, vector.getLoc(),
673  CombiningKind::ADD, vector);
674  case arith::AtomicRMWKind::mulf:
675  case arith::AtomicRMWKind::muli:
676  return vector::ReductionOp::create(builder, vector.getLoc(),
677  CombiningKind::MUL, vector);
678  case arith::AtomicRMWKind::minimumf:
679  return vector::ReductionOp::create(builder, vector.getLoc(),
680  CombiningKind::MINIMUMF, vector);
681  case arith::AtomicRMWKind::mins:
682  return vector::ReductionOp::create(builder, vector.getLoc(),
683  CombiningKind::MINSI, vector);
684  case arith::AtomicRMWKind::minu:
685  return vector::ReductionOp::create(builder, vector.getLoc(),
686  CombiningKind::MINUI, vector);
687  case arith::AtomicRMWKind::maximumf:
688  return vector::ReductionOp::create(builder, vector.getLoc(),
689  CombiningKind::MAXIMUMF, vector);
690  case arith::AtomicRMWKind::maxs:
691  return vector::ReductionOp::create(builder, vector.getLoc(),
692  CombiningKind::MAXSI, vector);
693  case arith::AtomicRMWKind::maxu:
694  return vector::ReductionOp::create(builder, vector.getLoc(),
695  CombiningKind::MAXUI, vector);
696  case arith::AtomicRMWKind::andi:
697  return vector::ReductionOp::create(builder, vector.getLoc(),
698  CombiningKind::AND, vector);
699  case arith::AtomicRMWKind::ori:
700  return vector::ReductionOp::create(builder, vector.getLoc(),
701  CombiningKind::OR, vector);
702  // TODO: Add remaining reduction operations.
703  default:
704  (void)emitOptionalError(loc, "Reduction operation type not supported");
705  break;
706  }
707  return nullptr;
708 }
709 
710 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
711  return llvm::to_vector<4>(getSourceVectorType().getShape());
712 }
713 
714 namespace {
715 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
716  using OpRewritePattern::OpRewritePattern;
717 
718  LogicalResult matchAndRewrite(ReductionOp reductionOp,
719  PatternRewriter &rewriter) const override {
720  // Vector mask setup.
721  OpBuilder::InsertionGuard guard(rewriter);
722  auto maskableOp =
723  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
724  Operation *rootOp;
725  Value mask;
726  if (maskableOp.isMasked()) {
727  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
728  rootOp = maskableOp.getMaskingOp();
729  mask = maskableOp.getMaskingOp().getMask();
730  } else {
731  rootOp = reductionOp;
732  }
733 
734  auto vectorType = reductionOp.getSourceVectorType();
735  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
736  return failure();
737 
738  Location loc = reductionOp.getLoc();
739  if (mask)
740  mask = ExtractOp::create(rewriter, loc, mask);
741  Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
742 
743  if (Value acc = reductionOp.getAcc())
744  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
745  result, acc,
746  reductionOp.getFastmathAttr(), mask);
747 
748  rewriter.replaceOp(rootOp, result);
749  return success();
750  }
751 };
752 } // namespace
753 
754 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
755  MLIRContext *context) {
756  results.add<ElideSingleElementReduction>(context);
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // ContractionOp
761 //===----------------------------------------------------------------------===//
762 
763 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
764  Value lhs, Value rhs, Value acc,
765  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
766  ArrayRef<IteratorType> iteratorTypes) {
767  result.addOperands({lhs, rhs, acc});
768  result.addTypes(acc.getType());
769  result.addAttribute(
770  getIndexingMapsAttrName(result.name),
771  builder.getAffineMapArrayAttr(
772  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
773  result.addAttribute(
774  getIteratorTypesAttrName(result.name),
775  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
776  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
777  return IteratorTypeAttr::get(builder.getContext(), t);
778  }))));
779 }
780 
781 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
782  Value lhs, Value rhs, Value acc,
783  ArrayAttr indexingMaps,
784  ArrayAttr iteratorTypes) {
785  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
786  ContractionOp::getDefaultKind());
787 }
788 
789 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
790  Value lhs, Value rhs, Value acc,
791  ArrayAttr indexingMaps,
792  ArrayAttr iteratorTypes, CombiningKind kind) {
793  result.addOperands({lhs, rhs, acc});
794  result.addTypes(acc.getType());
795  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
796  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
797  result.addAttribute(getKindAttrName(result.name),
798  CombiningKindAttr::get(builder.getContext(), kind));
799 }
800 
801 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
802  OpAsmParser::UnresolvedOperand lhsInfo;
803  OpAsmParser::UnresolvedOperand rhsInfo;
804  OpAsmParser::UnresolvedOperand accInfo;
805  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
806  SmallVector<Type, 2> types;
807  Type resultType;
808  auto loc = parser.getCurrentLocation();
809  DictionaryAttr dictAttr;
810  // TODO: Unify linalg op attribute parsing.
811  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
812  parser.parseComma() || parser.parseOperand(rhsInfo) ||
813  parser.parseComma() || parser.parseOperand(accInfo) ||
814  parser.parseTrailingOperandList(masksInfo) ||
815  parser.parseOptionalAttrDict(result.attributes) ||
816  parser.parseColonTypeList(types) ||
817  parser.parseKeywordType("into", resultType) ||
818  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
819  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
820  parser.resolveOperand(accInfo, resultType, result.operands) ||
821  parser.addTypeToList(resultType, result.types))
822  return failure();
823  result.attributes.append(dictAttr.getValue().begin(),
824  dictAttr.getValue().end());
825 
826  // Convert an array of strings into an array of IteratorType enums. This is
827  // needed because tests still use the old format where the 'iterator_types'
828  // attribute is represented as an array of strings.
829  // TODO: Remove this conversion once tests are fixed.
830  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
831  result.attributes.get(getIteratorTypesAttrName(result.name)));
832  if (!iteratorTypes) {
833  return parser.emitError(loc)
834  << "expected " << getIteratorTypesAttrName(result.name)
835  << " array attribute";
836  }
837 
838  SmallVector<Attribute> iteratorTypeAttrs;
839 
840  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
841  auto maybeIteratorType = symbolizeIteratorType(s);
842  if (!maybeIteratorType.has_value())
843  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
844 
845  iteratorTypeAttrs.push_back(
846  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
847  }
848  result.attributes.set(getIteratorTypesAttrName(result.name),
849  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
850 
851  if (!result.attributes.get(getKindAttrName(result.name))) {
852  result.addAttribute(
853  getKindAttrName(result.name),
854  CombiningKindAttr::get(result.getContext(),
855  ContractionOp::getDefaultKind()));
856  }
857  if (masksInfo.empty())
858  return success();
859  if (masksInfo.size() != 2)
860  return parser.emitError(parser.getNameLoc(),
861  "expected zero or exactly 2 vector mask operands");
862  auto lhsType = llvm::cast<VectorType>(types[0]);
863  auto rhsType = llvm::cast<VectorType>(types[1]);
864  auto maskElementType = parser.getBuilder().getI1Type();
865  std::array<VectorType, 2> maskTypes = {
866  VectorType::Builder(lhsType).setElementType(maskElementType),
867  VectorType::Builder(rhsType).setElementType(maskElementType)};
868  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
869  return failure();
870  return success();
871 }
872 
873 void ContractionOp::print(OpAsmPrinter &p) {
874  // TODO: Unify printing code with linalg ops.
875  auto attrNames = getTraitAttrNames();
876  llvm::StringSet<> traitAttrsSet;
877  traitAttrsSet.insert_range(attrNames);
878  SmallVector<NamedAttribute> attrs;
879  for (auto attr : (*this)->getAttrs()) {
880  if (attr.getName() == getIteratorTypesAttrName()) {
881  auto iteratorTypes =
882  llvm::cast<ArrayAttr>(attr.getValue())
883  .getAsValueRange<IteratorTypeAttr, IteratorType>();
884  // Convert IteratorType enums into the string representation. This is
885  // needed, because tests still use the old format when 'iterator_types'
886  // attribute is represented as an array of strings.
887  // TODO: Remove this conversion once tests are fixed.
888  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
889  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
890  return StringAttr::get(getContext(), stringifyIteratorType(t));
891  }));
892 
893  attrs.emplace_back(getIteratorTypesAttrName(),
894  ArrayAttr::get(getContext(), iteratorTypeNames));
895  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
896  attrs.push_back(attr);
897  }
898 
899  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
900  p << " " << dictAttr << " " << getLhs() << ", ";
901  p << getRhs() << ", " << getAcc();
902 
903  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
904  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
905  << getResultType();
906 }
907 
908 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
909  const std::vector<std::pair<int64_t, int64_t>> &map) {
910  for (auto &dimPair : map) {
911  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
912  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
913  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
914  return false;
915  }
916  return true;
917 }
918 
919 static LogicalResult verifyOutputShape(
920  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
921  Type resType,
922  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
923  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
924  DenseSet<int64_t> lhsContractingDimSet;
925  DenseSet<int64_t> rhsContractingDimSet;
926  for (auto &dimPair : contractingDimMap) {
927  lhsContractingDimSet.insert(dimPair.first);
928  rhsContractingDimSet.insert(dimPair.second);
929  }
930  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
931  llvm::make_second_range(batchDimMap));
932 
933  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
934  SmallVector<int64_t, 4> expectedResultDims;
935  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
936  if (lhsContractingDimSet.count(i) > 0)
937  continue;
938  expectedResultDims.push_back(lhsType.getDimSize(i));
939  }
940 
941  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
942  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
943  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
944  continue;
945  expectedResultDims.push_back(rhsType.getDimSize(i));
946  }
947 
948  // Verify 'expectedResultDims'.
949  if (expectedResultDims.empty()) {
950  // No batch or free dimension implies a scalar result.
951  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
952  return op.emitOpError("invalid accumulator/result vector shape");
953  } else {
954  // At least one batch or free dimension implies a vector result.
955  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
956  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
957  if (!resVectorType || !accVectorType)
958  return op.emitOpError("invalid accumulator/result vector shape");
959 
960  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
961  // types fully define the result vector type. This assumes the affine maps
962  // are well-formed, which must have been verified already.
963  MLIRContext *ctx = op.getContext();
964  AffineMap lhsMap = op.getIndexingMapsArray()[0];
965  AffineMap rhsMap = op.getIndexingMapsArray()[1];
966  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
967  return op.emitOpError(
968  "expected all dimensions to be either a LHS or a RHS dimension");
969  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
970  for (auto pair :
971  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
972  VectorType v = pair.first;
973  auto map = pair.second;
974  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
975  unsigned pos = map.getDimPosition(idx);
976  if (!extents[pos])
977  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
978  }
979  }
980  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
981  return op.emitOpError("expected all dimensions to get an extent as "
982  "either a LHS or a RHS dimension");
983 
984  AffineMap resMap = op.getIndexingMapsArray()[2];
985  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
986  /*symbolCount=*/0, extents, ctx);
987  // Compose the resMap with the extentsMap, which is a constant map.
988  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
989  assert(llvm::all_of(expectedMap.getResults(),
990  llvm::IsaPred<AffineConstantExpr>) &&
991  "expected constant extent along all dimensions.");
992  // Extract the expected shape and build the type.
993  auto expectedShape = llvm::to_vector<4>(
994  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
995  return cast<AffineConstantExpr>(e).getValue();
996  }));
997  auto expected =
998  VectorType::get(expectedShape, resVectorType.getElementType(),
999  resVectorType.getScalableDims());
1000  if (resVectorType != expected || accVectorType != expected)
1001  return op.emitOpError(
1002  "invalid accumulator/result vector shape, expected: ")
1003  << expected;
1004  }
1005  return success();
1006 }
1007 
1008 LogicalResult ContractionOp::verify() {
1009  VectorType lhsType = getLhsType();
1010  VectorType rhsType = getRhsType();
1011  Type accType = getAccType();
1012  Type resType = getResultType();
1013 
1014  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1015  if (!lhsType.getElementType().isSignlessInteger())
1016  return emitOpError("only supports signless integer types");
1017  }
1018 
1019  // Verify that an indexing map was specified for each vector operand.
1020  if (getIndexingMapsArray().size() != 3)
1021  return emitOpError("expected an indexing map for each vector operand");
1022 
1023  // Verify that each index map has 'numIterators' inputs, no symbols, and
1024  // that the number of map outputs equals the rank of its associated
1025  // vector operand.
1026  unsigned numIterators = getIteratorTypes().getValue().size();
1027  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1028  auto index = it.index();
1029  auto map = it.value();
1030  if (map.getNumSymbols() != 0)
1031  return emitOpError("expected indexing map ")
1032  << index << " to have no symbols";
1033  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1034  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1035  // Verify that the map has the right number of inputs, outputs, and indices.
1036  // This also correctly accounts for (..) -> () for rank-0 results.
1037  if (map.getNumDims() != numIterators)
1038  return emitOpError("expected indexing map ")
1039  << index << " to have " << numIterators << " number of inputs";
1040  if (map.getNumResults() != rank)
1041  return emitOpError("expected indexing map ")
1042  << index << " to have " << rank << " number of outputs";
1043  if (!map.isProjectedPermutation())
1044  return emitOpError("expected indexing map ")
1045  << index << " to be a projected permutation of its inputs";
1046  }
1047 
1048  auto contractingDimMap = getContractingDimMap();
1049  auto batchDimMap = getBatchDimMap();
1050 
1051  // Verify at least one contracting dimension pair was specified.
1052  if (contractingDimMap.empty())
1053  return emitOpError("expected at least one contracting dimension pair");
1054 
1055  // Verify contracting dimension map was properly constructed.
1056  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1057  return emitOpError("invalid contracting dimension map");
1058 
1059  // Verify batch dimension map was properly constructed.
1060  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1061  return emitOpError("invalid batch dimension map");
1062 
1063  // Verify 'accType' and 'resType' shape.
1064  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1065  contractingDimMap, batchDimMap)))
1066  return failure();
1067 
1068  // Verify supported combining kind.
1069  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1070  auto elementType = vectorType ? vectorType.getElementType() : resType;
1071  if (!isSupportedCombiningKind(getKind(), elementType))
1072  return emitOpError("unsupported contraction type");
1073 
1074  // Delayed calling of IndexingMapOpInterface::verifyImpl.
1075  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1076 }
1077 
1078 // MaskableOpInterface methods.
1079 
1080 /// Returns the mask type expected by this operation. Mostly used for
1081 /// verification purposes. It requires the operation to be vectorized.
1082 Type ContractionOp::getExpectedMaskType() {
1083  auto indexingMaps = this->getIndexingMapsArray();
1084  AffineMap lhsIdxMap = indexingMaps[0];
1085  AffineMap rhsIdxMap = indexingMaps[1];
1086  VectorType lhsType = this->getLhsType();
1087  VectorType rhsType = this->getRhsType();
1088 
1089  unsigned numVecDims = lhsIdxMap.getNumDims();
1090  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1091  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1092 
1093  // Using the information in the indexing maps, extract the size of each
1094  // dimension in the vector.contract operation from the two input operands.
1095  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1096  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1097  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1098  lhsType.getScalableDims()[dimIdx];
1099  }
1100  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1101  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1102  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1103  rhsType.getScalableDims()[dimIdx];
1104  }
1105 
1106  assert(ShapedType::isStaticShape(maskShape) &&
1107  "Mask shape couldn't be computed");
1108 
1109  return VectorType::get(maskShape,
1110  IntegerType::get(lhsType.getContext(), /*width=*/1),
1111  maskShapeScalableDims);
1112 }
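
// For example, for a plain matmul contraction with indexing maps
//   [(d0, d2), (d2, d1), (d0, d1)]
// and lhs = vector<4x8xf32>, rhs = vector<8x16xf32>, the mask spans the
// iteration space (d0, d1, d2) and has type vector<4x16x8xi1>.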
1113 
1114 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1115  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1116  getIteratorTypesAttrName(), getKindAttrName()};
1117 }
1118 
1119 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1120  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1121  if (targetExpr == map.getResult(i))
1122  return i;
1123  return -1;
1124 }
1125 
1126 static std::vector<std::pair<int64_t, int64_t>>
1127 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1128  IteratorType targetIteratorType, MLIRContext *context) {
1129  std::vector<std::pair<int64_t, int64_t>> dimMap;
1130  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1131  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1132  if (iteratorType != targetIteratorType)
1133  continue;
1134  // Search lhs/rhs map results for 'targetExpr'.
1135  auto targetExpr = getAffineDimExpr(it.index(), context);
1136  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1137  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1138  if (lhsDim >= 0 && rhsDim >= 0)
1139  dimMap.emplace_back(lhsDim, rhsDim);
1140  }
1141  return dimMap;
1142 }
1143 
1144 void ContractionOp::getIterationBounds(
1145  SmallVectorImpl<int64_t> &iterationBounds) {
1146  auto lhsShape = getLhsType().getShape();
1147  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1148  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1149  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1150  // Search lhs/rhs map results for 'targetExpr'.
1151  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1152  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1153  if (iteratorType == IteratorType::reduction) {
1154  // Get reduction dim size from lhs shape (same size in rhsShape).
1155  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1156  assert(lhsDimIndex >= 0);
1157  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1158  continue;
1159  }
1160  // Get parallel dimension size from result shape.
1161  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1162  assert(resDimIndex >= 0);
1163  assert(resVectorType != nullptr);
1164  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1165  }
1166 }
1167 
1168 void ContractionOp::getIterationIndexMap(
1169  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1170  unsigned numMaps = getIndexingMapsArray().size();
1171  iterationIndexMap.resize(numMaps);
1172  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1173  auto index = it.index();
1174  auto map = it.value();
1175  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1176  auto dim = cast<AffineDimExpr>(map.getResult(i));
1177  iterationIndexMap[index][dim.getPosition()] = i;
1178  }
1179  }
1180 }
1181 
1182 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1183  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1184  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1185  getContext());
1186 }
1187 
1188 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1189  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1190  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1191  getContext());
1192 }
1193 
1194 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1195  SmallVector<int64_t, 4> shape;
1196  getIterationBounds(shape);
1197  return shape;
1198 }
1199 
1200 /// Return a fused vector::ContractionOp which represents a pattern such as:
1201 ///
1202 /// ```mlir
1203 /// %c0 = vector.constant 0: ...
1204 /// %c = vector.contract %a, %b, %c0: ...
1205 /// %e = add %c, %d: ...
1206 /// ```
1207 ///
1208 /// by:
1209 ///
1210 /// ```mlir
1211 /// %e = vector.contract %a, %b, %d: ...
1212 /// ```
1213 ///
1214 /// Return null if the canonicalization does not apply.
1215 // TODO: This should be a folding of Add into Contract in core but while they
1216 // live in different dialects, it is not possible without unnatural
1217 // dependencies.
1218 template <typename AddOpType>
1219 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1220  using OpRewritePattern<AddOpType>::OpRewritePattern;
1221 
1222  LogicalResult matchAndRewrite(AddOpType addOp,
1223  PatternRewriter &rewriter) const override {
1224  auto canonicalize = [&](Value maybeContraction,
1225  Value otherOperand) -> vector::ContractionOp {
1226  vector::ContractionOp contractionOp =
1227  dyn_cast_or_null<vector::ContractionOp>(
1228  maybeContraction.getDefiningOp());
1229  if (!contractionOp)
1230  return vector::ContractionOp();
1231  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1232  contractionOp.getAcc().getDefiningOp())) {
1233  if (maybeZero.getValue() ==
1234  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1235  IRMapping bvm;
1236  bvm.map(contractionOp.getAcc(), otherOperand);
1237  auto newContraction =
1238  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1239  rewriter.replaceOp(addOp, newContraction.getResult());
1240  return newContraction;
1241  }
1242  }
1243  return vector::ContractionOp();
1244  };
1245 
1246  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1247  vector::ContractionOp contract = canonicalize(a, b);
1248  contract = contract ? contract : canonicalize(b, a);
1249  return contract ? success() : failure();
1250  }
1251 };
1252 
1253 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1254  MLIRContext *context) {
1255  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1256  CanonicalizeContractAdd<arith::AddFOp>>(context);
1257 }
1258 
1259 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1260 // `poisonValue`.
1261 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1262  int64_t maxIndex) {
1263  return index == poisonValue || (index >= 0 && index < maxIndex);
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // ExtractOp
1268 //===----------------------------------------------------------------------===//
1269 
1270 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1271  SetIntRangeFn setResultRanges) {
1272  setResultRanges(getResult(), argRanges.front());
1273 }
1274 
1275 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1276  Value source) {
1277  auto vectorTy = cast<VectorType>(source.getType());
1278  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1279 }
1280 
1281 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1282  Value source, int64_t position) {
1283  build(builder, result, source, ArrayRef<int64_t>{position});
1284 }
1285 
1286 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1287  Value source, OpFoldResult position) {
1288  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1289 }
1290 
1291 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1292  Value source, ArrayRef<int64_t> position) {
1293  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1294  builder.getDenseI64ArrayAttr(position));
1295 }
1296 
1297 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1298  Value source, ArrayRef<OpFoldResult> position) {
1299  SmallVector<int64_t> staticPos;
1300  SmallVector<Value> dynamicPos;
1301  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1302  build(builder, result, source, dynamicPos,
1303  builder.getDenseI64ArrayAttr(staticPos));
1304 }
1305 
1306 LogicalResult
1307 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1308  ExtractOp::Adaptor adaptor,
1309  SmallVectorImpl<Type> &inferredReturnTypes) {
1310  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1311  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1312  vectorType.getRank()) {
1313  inferredReturnTypes.push_back(vectorType.getElementType());
1314  } else {
1315  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1316  vectorType.getRank());
1317  inferredReturnTypes.push_back(VectorType::get(
1318  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1319  vectorType.getScalableDims().drop_front(n)));
1320  }
1321  return success();
1322 }
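
// For example, extracting at a single leading position from
// vector<4x8x16xf32> infers vector<8x16xf32>, while extracting with three
// indices infers the scalar element type f32.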
1323 
1324 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1325  // Allow extracting 1-element vectors instead of scalars.
1326  auto isCompatible = [](TypeRange l, TypeRange r) {
1327  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1328  return vectorType && vectorType.getShape().equals({1}) &&
1329  vectorType.getElementType() == r.front();
1330  };
1331  if (l.size() == 1 && r.size() == 1 &&
1332  (isCompatible(l, r) || isCompatible(r, l)))
1333  return true;
1334  return l == r;
1335 }
1336 
1337 LogicalResult vector::ExtractOp::verify() {
1338  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1339  if (resTy.getRank() == 0)
1340  return emitError(
1341  "expected a scalar instead of a 0-d vector as the result type");
1342 
1343  // Note: This check must come before getMixedPosition() to prevent a crash.
1344  auto dynamicMarkersCount =
1345  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1346  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1347  return emitOpError(
1348  "mismatch between dynamic and static positions (kDynamic marker but no "
1349  "corresponding dynamic position) -- this can only happen due to an "
1350  "incorrect fold/rewrite");
1351  auto position = getMixedPosition();
1352  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1353  return emitOpError(
1354  "expected position attribute of rank no greater than vector rank");
1355  for (auto [idx, pos] : llvm::enumerate(position)) {
1356  if (auto attr = dyn_cast<Attribute>(pos)) {
1357  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1358  if (!isValidPositiveIndexOrPoison(
1359  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1360  return emitOpError("expected position attribute #")
1361  << (idx + 1)
1362  << " to be a non-negative integer smaller than the "
1363  "corresponding vector dimension or poison (-1)";
1364  }
1365  }
1366  }
1367  return success();
1368 }
1369 
1370 template <typename IntType>
1371 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1372  return llvm::to_vector<4>(llvm::map_range(
1373  arrayAttr.getAsRange<IntegerAttr>(),
1374  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1375 }
1376 
1377 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1378 /// positions.
1379 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1380  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1381  return failure();
1382 
1383  // TODO: Canonicalization for dynamic position not implemented yet.
1384  if (extractOp.hasDynamicPosition())
1385  return failure();
1386 
1387  SmallVector<int64_t> globalPosition;
1388  ExtractOp currentOp = extractOp;
1389  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1390  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1391  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1392  currentOp = nextOp;
1393  // TODO: Canonicalization for dynamic position not implemented yet.
1394  if (currentOp.hasDynamicPosition())
1395  return failure();
1396  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1397  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1398  }
1399  extractOp.setOperand(0, currentOp.getSource());
1400  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1401  OpBuilder b(extractOp.getContext());
1402  std::reverse(globalPosition.begin(), globalPosition.end());
1403  extractOp.setStaticPosition(globalPosition);
1404  return success();
1405 }
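
// An illustrative chain fold (value names are hypothetical):
//   %a = vector.extract %v[0] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[1] : vector<4xf32> from vector<3x4xf32>
// folds in place to:
//   %b = vector.extract %v[0, 1] : vector<4xf32> from vector<2x3x4xf32>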
1406 
1407 namespace {
1408 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1409 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1410 /// Compose TransposeOp permutations as we walk back.
1411 /// This helper class keeps an updated extraction position `extractPosition`
1412 /// with extra trailing sentinels.
1413 /// The sentinels encode the internal transposition status of the result vector.
1414 /// As we iterate, extractPosition is permuted and updated.
1415 class ExtractFromInsertTransposeChainState {
1416 public:
1417  ExtractFromInsertTransposeChainState(ExtractOp e);
1418 
1419  /// Iterate over producing insert and transpose ops until we find a fold.
1420  Value fold();
1421 
1422 private:
1423  /// Return true if the vector at position `a` is contained within the vector
1424  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1425  /// is a prefix of `b`.
1426  template <typename ContainerA, typename ContainerB>
1427  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1428  return a.size() <= b.size() &&
1429  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1430  }
1431 
1432  /// Return true if the vector at position `a` intersects the vector at
1433  /// position `b`. Under insert/extract semantics, this is the same as equality
1434  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1435  /// Comparison is on the common prefix (i.e. zip).
1436  template <typename ContainerA, typename ContainerB>
1437  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1438  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1439  if (elemA < 0 || elemB < 0)
1440  continue;
1441  if (elemA != elemB)
1442  return false;
1443  }
1444  return true;
1445  }
1446 
1447  /// Folding is only possible in the absence of an internal permutation in the
1448  /// result vector.
1449  bool canFold() {
1450  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1451  }
1452 
1453  // Helper to get the next defining op of interest.
1454  void updateStateForNextIteration(Value v) {
1455  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1456  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1457  };
1458 
1459  // Case 1. If we hit a transpose, just compose the map and iterate.
1460  // Invariant: insert + transpose do not change rank, we can always compose.
1461  LogicalResult handleTransposeOp();
1462 
1463  // Case 2: the insert position matches extractPosition exactly, early return.
1464  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1465 
1466  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1467  /// portion of the source of the insert.
1468  /// Example:
1469  /// ```
1470  /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5>
1471  /// // extractPosition == [1, 2, 3]
1472  /// %ext = vector.extract %ins[1, 2, 3]: vector<5> from vector<2x3x4x5>
1473  /// // can fold to
1474  /// %ext = vector.extract %source[2, 3]: vector<5> from vector<3x4x5>
1475  /// ```
1476  /// To traverse through %source, we need to set the leading dims to 0 and
1477  /// drop the extra leading dims.
1478  /// This method updates the internal state.
1479  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1480 
1481  /// Try to fold in place to extract(source, extractPosition) and return the
1482  /// folded result. Return null if folding is not possible (e.g. due to an
1483  /// internal transposition in the result).
1484  Value tryToFoldExtractOpInPlace(Value source);
1485 
1486  ExtractOp extractOp;
1487  int64_t vectorRank;
1488  int64_t extractedRank;
1489 
1490  InsertOp nextInsertOp;
1491  TransposeOp nextTransposeOp;
1492 
1493  /// Sentinel values that encode the internal permutation status of the result.
1494  /// They are set to (-1, ... , -k) at the beginning and appended to
1495  /// `extractPosition`.
1496  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1497  /// ensure that there is no internal transposition.
1498  /// Internal transposition cannot be accounted for with a folding pattern.
1499  // TODO: We could relax the internal transposition with an extra transposition
1500  // operation in a future canonicalizer.
1501  SmallVector<int64_t> sentinels;
1502  SmallVector<int64_t> extractPosition;
1503 };
1504 } // namespace
1505 
1506 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1507  ExtractOp e)
1508  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1509  extractedRank(extractOp.getNumIndices()) {
1510  assert(vectorRank >= extractedRank && "Extracted position overflow");
1511  sentinels.reserve(vectorRank - extractedRank);
1512  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1513  sentinels.push_back(-(i + 1));
1514  extractPosition.assign(extractOp.getStaticPosition().begin(),
1515  extractOp.getStaticPosition().end());
1516  llvm::append_range(extractPosition, sentinels);
1517 }
1518 
1519 // Case 1. If we hit a transpose, just compose the map and iterate.
1520 // Invariant: insert + transpose do not change rank, we can always compose.
1521 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1522  // TODO: Canonicalization for dynamic position not implemented yet.
1523  if (extractOp.hasDynamicPosition())
1524  return failure();
1525 
1526  if (!nextTransposeOp)
1527  return failure();
1528  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1529  nextTransposeOp.getPermutation(), extractOp.getContext()));
1530  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1531  return success();
1532 }
1533 
1534 // Case 2: the insert position matches extractPosition exactly, early return.
1535 LogicalResult
1536 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1537  Value &res) {
1538  // TODO: Canonicalization for dynamic position not implemented yet.
1539  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1540  return failure();
1541 
1542  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1543  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1544  return failure();
1545  // Case 2.a. early-exit fold.
1546  res = nextInsertOp.getValueToStore();
1547  // Case 2.b. if internal transposition is present, canFold will be false.
1548  return success(canFold());
1549 }
1550 
1551 /// Case 3: if inserted position is a prefix of extractPosition,
1552 /// extract a portion of the source of the insertion.
1553 /// This method updates the internal state.
1554 LogicalResult
1555 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1556  // TODO: Canonicalization for dynamic position not implemented yet.
1557  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1558  return failure();
1559 
1560  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1561  if (!isContainedWithin(insertedPos, extractPosition))
1562  return failure();
1563  // Set leading dims to zero.
1564  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1565  // Drop extra leading dims.
1566  extractPosition.erase(extractPosition.begin(),
1567  extractPosition.begin() + insertedPos.size());
1568  extractedRank = extractPosition.size() - sentinels.size();
1569  // Case 3.a. early-exit fold (break and delegate to post-while path).
1570  res = nextInsertOp.getValueToStore();
1571  // Case 3.b. if internal transposition is present, canFold will be false.
1572  return success();
1573 }
1574 
1575 /// Try to fold in place to extract(source, extractPosition) and return the
1576 /// folded result. Return null if folding is not possible (e.g. due to an
1577 /// internal transposition in the result).
1578 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1579  Value source) {
1580  // TODO: Canonicalization for dynamic position not implemented yet.
1581  if (extractOp.hasDynamicPosition())
1582  return Value();
1583 
1584  // If we can't fold (either internal transposition, or nothing to fold), bail.
1585  bool nothingToFold = (source == extractOp.getSource());
1586  if (nothingToFold || !canFold())
1587  return Value();
1588 
1589  // Otherwise, fold by updating the op inplace and return its result.
1590  OpBuilder b(extractOp.getContext());
1591  extractOp.setStaticPosition(
1592  ArrayRef(extractPosition).take_front(extractedRank));
1593  extractOp.getSourceMutable().assign(source);
1594  return extractOp.getResult();
1595 }
1596 
1597 /// Iterate over producing insert and transpose ops until we find a fold.
1598 Value ExtractFromInsertTransposeChainState::fold() {
1599  // TODO: Canonicalization for dynamic position not implemented yet.
1600  if (extractOp.hasDynamicPosition())
1601  return Value();
1602 
1603  Value valueToExtractFrom = extractOp.getSource();
1604  updateStateForNextIteration(valueToExtractFrom);
1605  while (nextInsertOp || nextTransposeOp) {
1606  // Case 1. If we hit a transpose, just compose the map and iterate.
1607  // Invariant: insert + transpose do not change rank, we can always compose.
1608  if (succeeded(handleTransposeOp())) {
1609  valueToExtractFrom = nextTransposeOp.getVector();
1610  updateStateForNextIteration(valueToExtractFrom);
1611  continue;
1612  }
1613 
1614  Value result;
1615  // Case 2: the position match exactly.
1616  if (succeeded(handleInsertOpWithMatchingPos(result)))
1617  return result;
1618 
1619  // Case 3: if the inserted position is a prefix of extractPosition, we can
1620  // just extract a portion of the source of the insert.
1621  if (succeeded(handleInsertOpWithPrefixPos(result)))
1622  return tryToFoldExtractOpInPlace(result);
1623 
1624  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1625  // values. This is a more difficult case and we bail.
1626  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1627  if (isContainedWithin(extractPosition, insertedPos) ||
1628  intersectsWhereNonNegative(extractPosition, insertedPos))
1629  return Value();
1630 
1631  // Case 5: No intersection, we forward the extract to insertOp.dest().
1632  valueToExtractFrom = nextInsertOp.getDest();
1633  updateStateForNextIteration(valueToExtractFrom);
1634  }
1635  // If after all this we can fold, go for it.
1636  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1637 }
1638 
1639 /// Returns true if the operation has a 0-D vector type operand or result.
1640 static bool hasZeroDimVectors(Operation *op) {
1641  auto hasZeroDimVectorType = [](Type type) -> bool {
1642  auto vecType = dyn_cast<VectorType>(type);
1643  return vecType && vecType.getRank() == 0;
1644  };
1645 
1646  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1647  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1648 }
1649 
1650 /// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1651 /// 1s, are considered to be 'broadcastlike'.
1652 static bool isBroadcastLike(Operation *op) {
1653  if (isa<BroadcastOp, SplatOp>(op))
1654  return true;
1655 
1656  auto shapeCast = dyn_cast<ShapeCastOp>(op);
1657  if (!shapeCast)
1658  return false;
1659 
1660  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1661  // Checking that the destination shape has a prefix of 1s is not sufficient,
1662  // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1663  // is that the source shape is a suffix of the destination shape.
1664  VectorType srcType = shapeCast.getSourceVectorType();
1665  ArrayRef<int64_t> srcShape = srcType.getShape();
1666  uint64_t srcRank = srcType.getRank();
1667  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1668  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1669 }
1670 
1671 /// Fold extract(broadcast(X)) to either extract(X) or just X.
1672 ///
1673 /// Example:
1674 ///
1675 /// broadcast extract [1][2]
1676 /// (3, 4) --------> (2, 3, 4) ----------------> (4)
1677 ///
1678 /// becomes
1679 /// extract [2]
1680 /// (3,4) -------------------------------------> (4)
1681 ///
1682 ///
1683 /// The variable names used in this implementation correspond to the above
1684 /// shapes as,
1685 ///
1686 /// - (3, 4) is `input` shape.
1687 /// - (2, 3, 4) is `broadcast` shape.
1688 /// - (4) is `extract` shape.
1689 ///
1690 /// This folding is possible when the suffix of `input` shape is the same as
1691 /// `extract` shape.
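/// An illustrative (hypothetical) IR instance of the shapes above:
///   %b = vector.broadcast %input : vector<3x4xf32> to vector<2x3x4xf32>
///   %e = vector.extract %b[1, 2] : vector<4xf32> from vector<2x3x4xf32>
/// folds to
///   %e = vector.extract %input[2] : vector<4xf32> from vector<3x4xf32>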
1692 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1693 
1694  Operation *defOp = extractOp.getSource().getDefiningOp();
1695  if (!defOp || !isBroadcastLike(defOp))
1696  return Value();
1697 
1698  Value input = defOp->getOperand(0);
1699 
1700  // Replace extract(broadcast(X)) with X
1701  if (extractOp.getType() == input.getType())
1702  return input;
1703 
1704  // Get required types and ranks in the chain
1705  // input -> broadcast -> extract
1706  // (scalars are treated as rank-0).
1707  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1708  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1709  unsigned inputRank = inputType ? inputType.getRank() : 0;
1710  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1711  unsigned extractRank = extractType ? extractType.getRank() : 0;
1712 
1713  // Cannot do without the broadcast if overall the rank increases.
1714  if (extractRank > inputRank)
1715  return Value();
1716 
1717  // The above condition guarantees that input is a vector.
1718  assert(inputType && "input must be a vector type because of previous checks");
1719  ArrayRef<int64_t> inputShape = inputType.getShape();
1720 
1721  // In the case where there is a broadcast dimension in the suffix, it is not
1722  // possible to replace extract(broadcast(X)) with extract(X). Example:
1723  //
1724  // broadcast extract
1725  // (1) --------> (3,4) ------> (4)
1726  if (extractType &&
1727  extractType.getShape() != inputShape.take_back(extractRank))
1728  return Value();
1729 
1730  // Replace extract(broadcast(X)) with extract(X).
1731  // First, determine the new extraction position.
1732  unsigned deltaOverall = inputRank - extractRank;
1733  unsigned deltaBroadcast = broadcastRank - inputRank;
1734  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1735  SmallVector<OpFoldResult> newPositions(deltaOverall);
1736  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1737  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1738  newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1739  }
1740  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1741  extractOp->setOperands(
1742  llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1743  extractOp.setStaticPosition(staticPos);
1744  return extractOp.getResult();
1745 }
1746 
1747 /// Fold extractOp coming from ShuffleOp.
1748 ///
1749 /// Example:
1750 ///
1751 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1752 /// : vector<8xf32>, vector<8xf32>
1753 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1754 /// ->
1755 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1756 ///
1757 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1758  // Dynamic positions are not folded as the resulting code would be more
1759  // complex than the input code.
1760  if (extractOp.hasDynamicPosition())
1761  return Value();
1762 
1763  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1764  if (!shuffleOp)
1765  return Value();
1766 
1767  // TODO: 0-D or multi-dimensional vectors not supported yet.
1768  if (shuffleOp.getResultVectorType().getRank() != 1)
1769  return Value();
1770 
1771  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1772  auto shuffleMask = shuffleOp.getMask();
1773  int64_t extractIdx = extractOp.getStaticPosition()[0];
1774  int64_t shuffleIdx = shuffleMask[extractIdx];
1775 
1776  // Find the shuffled vector to extract from based on the shuffle index.
1777  if (shuffleIdx < inputVecSize) {
1778  extractOp.setOperand(0, shuffleOp.getV1());
1779  extractOp.setStaticPosition({shuffleIdx});
1780  } else {
1781  extractOp.setOperand(0, shuffleOp.getV2());
1782  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1783  }
1784 
1785  return extractOp.getResult();
1786 }
1787 
1788 // Fold extractOp with source coming from ShapeCast op.
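// An illustrative (hypothetical) example, assuming static positions: the
// extract position is linearized w.r.t. the shape_cast result and then
// delinearized w.r.t. the shape_cast source, e.g.
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %e  = vector.extract %sc[6] : f32 from vector<8xf32>
// folds to
//   %e  = vector.extract %src[1, 2] : f32 from vector<2x4xf32>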
1789 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1790  // TODO: Canonicalization for dynamic position not implemented yet.
1791  if (extractOp.hasDynamicPosition())
1792  return Value();
1793 
1794  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1795  if (!shapeCastOp)
1796  return Value();
1797 
1798  // Get the nth dimension size starting from lowest dimension.
1799  auto getDimReverse = [](VectorType type, int64_t n) {
1800  return type.getShape().take_back(n + 1).front();
1801  };
1802  int64_t destinationRank =
1803  llvm::isa<VectorType>(extractOp.getType())
1804  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1805  : 0;
1806  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1807  return Value();
1808  if (destinationRank > 0) {
1809  auto destinationType =
1810  llvm::cast<VectorType>(extractOp.getResult().getType());
1811  for (int64_t i = 0; i < destinationRank; i++) {
1812  // The lowest dimension of the destination must match the lowest
1813  // dimension of the shapecast op source.
1814  // TODO: This case could be supported in a canonicalization pattern.
1815  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1816  getDimReverse(destinationType, i))
1817  return Value();
1818  }
1819  }
1820  // Extract the strides associated with the extract op vector source. Then use
1821  // this to calculate a linearized position for the extract.
1822  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1823  std::reverse(extractedPos.begin(), extractedPos.end());
1824  SmallVector<int64_t, 4> strides;
1825  int64_t stride = 1;
1826  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1827  strides.push_back(stride);
1828  stride *=
1829  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1830  }
1831 
1832  int64_t position = linearize(extractedPos, strides);
1833  // Then extract the strides associated to the shapeCast op vector source and
1834  // delinearize the position using those strides.
1835  SmallVector<int64_t, 4> newStrides;
1836  int64_t numDimension =
1837  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1838  stride = 1;
1839  for (int64_t i = 0; i < numDimension; i++) {
1840  newStrides.push_back(stride);
1841  stride *=
1842  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1843  }
1844  std::reverse(newStrides.begin(), newStrides.end());
1845  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1846  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1847  OpBuilder b(extractOp.getContext());
1848  extractOp.setStaticPosition(newPosition);
1849  extractOp.setOperand(0, shapeCastOp.getSource());
1850  return extractOp.getResult();
1851 }
1852 
1853 /// Fold an ExtractOp from ExtractStridedSliceOp.
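/// An illustrative (hypothetical) example with unit strides: the slice offsets
/// are added to the extract position, e.g.
///   %s = vector.extract_strided_slice %v
///          {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8x16xf32> to vector<4x16xf32>
///   %e = vector.extract %s[1] : vector<16xf32> from vector<4x16xf32>
/// folds to
///   %e = vector.extract %v[3] : vector<16xf32> from vector<8x16xf32>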
1854 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1855  // TODO: Canonicalization for dynamic position not implemented yet.
1856  if (extractOp.hasDynamicPosition())
1857  return Value();
1858 
1859  auto extractStridedSliceOp =
1860  extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1861  if (!extractStridedSliceOp)
1862  return Value();
1863 
1864  // 0-D vectors not supported.
1865  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1866  if (hasZeroDimVectors(extractStridedSliceOp))
1867  return Value();
1868 
1869  // Return if 'extractStridedSliceOp' has non-unit strides.
1870  if (extractStridedSliceOp.hasNonUnitStrides())
1871  return Value();
1872 
1873  // Trim offsets for dimensions fully extracted.
1874  auto sliceOffsets =
1875  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1876  while (!sliceOffsets.empty()) {
1877  size_t lastOffset = sliceOffsets.size() - 1;
1878  if (sliceOffsets.back() != 0 ||
1879  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1880  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1881  break;
1882  sliceOffsets.pop_back();
1883  }
1884  unsigned destinationRank = 0;
1885  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1886  destinationRank = vecType.getRank();
1887  // The dimensions of the result need to be untouched by the
1888  // extractStridedSlice op.
1889  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1890  sliceOffsets.size())
1891  return Value();
1892 
1893  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1894  assert(extractedPos.size() >= sliceOffsets.size());
1895  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1896  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1897  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1898 
1899  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1900  OpBuilder b(extractOp.getContext());
1901  extractOp.setStaticPosition(extractedPos);
1902  return extractOp.getResult();
1903 }
1904 
1905 /// Fold extract_op fed from a chain of insertStridedSlice ops.
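/// An illustrative (hypothetical) example: when the extract position falls
/// entirely inside the inserted chunk, extract directly from the inserted
/// value, e.g.
///   %i = vector.insert_strided_slice %src, %dst
///          {offsets = [2, 0], strides = [1, 1]}
///          : vector<2x16xf32> into vector<8x16xf32>
///   %e = vector.extract %i[3] : vector<16xf32> from vector<8x16xf32>
/// folds to
///   %e = vector.extract %src[1] : vector<16xf32> from vector<2x16xf32>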
1906 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1907  // TODO: Canonicalization for dynamic position not implemented yet.
1908  if (extractOp.hasDynamicPosition())
1909  return Value();
1910 
1911  int64_t destinationRank =
1912  llvm::isa<VectorType>(extractOp.getType())
1913  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1914  : 0;
1915  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1916  if (!insertOp)
1917  return Value();
1918 
1919  // 0-D vectors not supported.
1920  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1921  if (hasZeroDimVectors(insertOp))
1922  return Value();
1923 
1924  while (insertOp) {
1925  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1926  insertOp.getSourceVectorType().getRank();
1927  if (destinationRank > insertOp.getSourceVectorType().getRank())
1928  return Value();
1929  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1930  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1931 
1932  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1933  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1934  }))
1935  return Value();
1936  bool disjoint = false;
1937  SmallVector<int64_t, 4> offsetDiffs;
1938  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1939  int64_t start = insertOffsets[dim];
1940  int64_t size =
1941  (dim < insertRankDiff)
1942  ? 1
1943  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1944  int64_t end = start + size;
1945  int64_t offset = extractOffsets[dim];
1946  // Check if the start of the extract offset is in the interval inserted.
1947  if (start <= offset && offset < end) {
1948  if (dim >= insertRankDiff)
1949  offsetDiffs.push_back(offset - start);
1950  continue;
1951  }
1952  disjoint = true;
1953  break;
1954  }
1955  // The extracted chunk overlaps with the inserted vector.
1956  if (!disjoint) {
1957  // If any of the inner dimensions are only partially inserted we have a
1958  // partial overlap.
1959  int64_t srcRankDiff =
1960  insertOp.getSourceVectorType().getRank() - destinationRank;
1961  for (int64_t i = 0; i < destinationRank; i++) {
1962  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1963  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1964  insertRankDiff))
1965  return Value();
1966  }
1967  extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1968  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1969  OpBuilder b(extractOp.getContext());
1970  extractOp.setStaticPosition(offsetDiffs);
1971  return extractOp.getResult();
1972  }
1973  // If the chunk extracted is disjoint from the chunk inserted, keep
1974  // looking in the insert chain.
1975  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1976  }
1977  return Value();
1978 }
1979 
1980 /// Try to fold the extraction of a scalar from a vector defined by
1981 /// vector.from_elements. E.g.:
1982 ///
1983 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1984 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1985 /// ==> fold to %a
1986 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1987  // Dynamic extractions cannot be folded.
1988  if (extractOp.hasDynamicPosition())
1989  return {};
1990 
1991  // Look for extract(from_elements).
1992  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
1993  if (!fromElementsOp)
1994  return {};
1995 
1996  // Scalable vectors are not supported.
1997  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1998  if (vecType.isScalable())
1999  return {};
2000 
2001  // Only extractions of scalars are supported.
2002  int64_t rank = vecType.getRank();
2003  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2004  if (extractOp.getType() != vecType.getElementType())
2005  return {};
2006  assert(static_cast<int64_t>(indices.size()) == rank &&
2007  "unexpected number of indices");
2008 
2009  // Compute flattened/linearized index and fold to operand.
2010  int flatIndex = 0;
2011  int stride = 1;
2012  for (int i = rank - 1; i >= 0; --i) {
2013  flatIndex += indices[i] * stride;
2014  stride *= vecType.getDimSize(i);
2015  }
2016  return fromElementsOp.getElements()[flatIndex];
2017 }
2018 
2019 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2020 /// then fold it.
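/// An illustrative (hypothetical) example of folding constant dynamic indices:
///   %c1 = arith.constant 1 : index
///   %e  = vector.extract %v[%c1, 0] : f32 from vector<4x8xf32>
/// is rewritten in place to
///   %e  = vector.extract %v[1, 0] : f32 from vector<4x8xf32>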
2021 template <typename OpType, typename AdaptorType>
2022 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2023  SmallVectorImpl<Value> &operands) {
2024  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2025  OperandRange dynamicPosition = op.getDynamicPosition();
2026  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2027  ArrayRef<int64_t> vectorShape;
2028  if constexpr (std::is_same_v<OpType, ExtractOp>)
2029  vectorShape = op.getSourceVectorType().getShape();
2030  else
2031  vectorShape = op.getDestVectorType().getShape();
2032 
2033  // If there are no dynamic position operands, there is nothing to fold.
2034  if (!dynamicPosition.size())
2035  return {};
2036 
2037  // `index` is used to iterate over the `dynamicPosition`.
2038  unsigned index = 0;
2039 
2040  // `opChange` indicates whether `op` should be updated in place.
2041  bool opChange = false;
2042  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2043  if (ShapedType::isStatic(staticPosition[i]))
2044  continue;
2045  Attribute positionAttr = dynamicPositionAttr[index];
2046  Value position = dynamicPosition[index++];
2047  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2048  int64_t value = attr.getInt();
2049  // Do not fold if the value is out of bounds (-1 signifies a poison
2050  // value rather than OOB index).
2051  if (value >= -1 && value < vectorShape[i]) {
2052  staticPosition[i] = attr.getInt();
2053  opChange = true;
2054  continue;
2055  }
2056  }
2057  operands.push_back(position);
2058  }
2059 
2060  if (opChange) {
2061  op.setStaticPosition(staticPosition);
2062  op.getOperation()->setOperands(operands);
2063  // Return the original result to indicate an in-place folding happened.
2064  return op.getResult();
2065  }
2066  return {};
2067 }
2068 
2069 /// Fold an insert or extract operation into a poison value when a poison index
2070 /// is found at any dimension of the static position.
2071 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2072  ArrayRef<int64_t> staticPos,
2073  int64_t poisonVal) {
2074  if (!is_contained(staticPos, poisonVal))
2075  return {};
2076 
2077  return ub::PoisonAttr::get(context);
2078 }
2079 
2080 /// Fold a vector extract whose source is a poison value.
2081 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2082  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2083  return srcAttr;
2084 
2085  return {};
2086 }
2087 
2088 /// Fold a vector extract extracting from a DenseElementsAttr.
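/// An illustrative (hypothetical) example, for a non-scalable source with
/// static positions:
///   %cst = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : vector<2x2xf32>
///   %e   = vector.extract %cst[1] : vector<2xf32> from vector<2x2xf32>
/// folds to the constant dense<[2.0, 3.0]> : vector<2xf32>.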
2089 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2090  Attribute srcAttr) {
2091  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2092  if (!denseAttr) {
2093  return {};
2094  }
2095 
2096  if (denseAttr.isSplat()) {
2097  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2098  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2099  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2100  return newAttr;
2101  }
2102 
2103  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2104  if (vecTy.isScalable())
2105  return {};
2106 
2107  if (extractOp.hasDynamicPosition()) {
2108  return {};
2109  }
2110 
2111  // Materializing subsets of a large constant array can generally lead to an
2112  // explosion in IR size because of the many different combinations of subsets
2113  // that can exist. However, vector.extract is a restricted form of subset
2114  // extraction where only non-overlapping (or identical) subsets of a given
2115  // rank can be extracted. Because of this property, the IR size can increase
2116  // at most by `rank * size(array)` when a single constant array is extracted
2117  // by multiple extracts.
2118 
2119  // Calculate the linearized position of the continuous chunk of elements to
2120  // extract.
2121  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2122  copy(extractOp.getStaticPosition(), completePositions.begin());
2123  int64_t startPos =
2124  linearize(completePositions, computeStrides(vecTy.getShape()));
2125  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2126 
2127  TypedAttr newAttr;
2128  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2129  SmallVector<Attribute> elementValues(
2130  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2131  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2132  } else {
2133  newAttr = *denseValuesBegin;
2134  }
2135 
2136  return newAttr;
2137 }
2138 
2139 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2140  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2141  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2142  // mismatch).
2143  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2144  return getSource();
2145  if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2146  return res;
2147  // Fold `arith.constant` indices into the `vector.extract` operation.
2148  // Do not stop here as this fold may enable subsequent folds that require
2149  // constant indices.
2150  SmallVector<Value> operands = {getSource()};
2151  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2152 
2153  if (auto res = foldPoisonIndexInsertExtractOp(
2154  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2155  return res;
2156  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2157  return res;
2158  if (succeeded(foldExtractOpFromExtractChain(*this)))
2159  return getResult();
2160  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2161  return res;
2162  if (auto res = foldExtractFromBroadcast(*this))
2163  return res;
2164  if (auto res = foldExtractFromShuffle(*this))
2165  return res;
2166  if (auto res = foldExtractFromShapeCast(*this))
2167  return res;
2168  if (auto val = foldExtractFromExtractStrided(*this))
2169  return val;
2170  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2171  return val;
2172  if (auto val = foldScalarExtractFromFromElements(*this))
2173  return val;
2174 
2175  return inplaceFolded;
2176 }
2177 
2178 namespace {
2179 
2180 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2181 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2182 public:
2183  using OpRewritePattern::OpRewritePattern;
2184 
2185  LogicalResult matchAndRewrite(ExtractOp extractOp,
2186  PatternRewriter &rewriter) const override {
2187 
2188  Operation *defOp = extractOp.getSource().getDefiningOp();
2189  VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2190  if (!defOp || !isBroadcastLike(defOp) || !outType)
2191  return failure();
2192 
2193  Value source = defOp->getOperand(0);
2194  if (isBroadcastableTo(source.getType(), outType) !=
2195  BroadcastableToResult::Success)
2196  return failure();
2197 
2198  rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2199  return success();
2200  }
2201 };
2202 
2203 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
2204 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2205 public:
2206  using OpRewritePattern::OpRewritePattern;
2207 
2208  LogicalResult matchAndRewrite(ExtractOp extractOp,
2209  PatternRewriter &rewriter) const override {
2210  auto createMaskOp =
2211  extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2212  if (!createMaskOp)
2213  return failure();
2214 
2215  VectorType extractedMaskType =
2216  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2217 
2218  if (!extractedMaskType)
2219  return failure();
2220 
2221  auto maskOperands = createMaskOp.getOperands();
2222  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2223  VectorType maskType = createMaskOp.getVectorType();
2224 
2225  bool containsUnknownDims = false;
2226  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2227 
2228  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2229  dimIdx++) {
2230  int64_t pos = extractOpPos[dimIdx];
2231  Value operand = maskOperands[dimIdx];
2232  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2233  if (!constantOp) {
2234  // Bounds of this dim unknown.
2235  containsUnknownDims = true;
2236  continue;
2237  }
2238 
2239  int64_t createMaskBound =
2240  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2241 
2242  if (pos != ShapedType::kDynamic) {
2243  // If any position is outside the range from the `create_mask`, then the
2244  // extracted mask will be all-false.
2245  allFalse |= pos >= createMaskBound;
2246  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2247  // This dim is not all-true and since this is a dynamic index we don't
2248  // know if the extraction is within the true or false region.
2249  // Note: Zero dims have already been handled via getMaskFormat().
2250  containsUnknownDims = true;
2251  }
2252  }
2253 
2254  if (allFalse) {
2255  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2256  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2257  } else if (!containsUnknownDims) {
2258  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2259  extractOp, extractedMaskType,
2260  maskOperands.drop_front(extractOpPos.size()));
2261  } else {
2262  return failure();
2263  }
2264  return success();
2265  }
2266 };
2267 
2268 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2269 // does not change.
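// An illustrative (hypothetical) example where the extracted vector has the
// same element count as the shape_cast source:
//   %sc = vector.shape_cast %v : vector<8xf32> to vector<1x2x4xf32>
//   %e  = vector.extract %sc[0] : vector<2x4xf32> from vector<1x2x4xf32>
// is rewritten to
//   %e  = vector.shape_cast %v : vector<8xf32> to vector<2x4xf32>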
2270 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2271  PatternRewriter &rewriter) {
2272  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2273  if (!castOp)
2274  return failure();
2275 
2276  VectorType sourceType = castOp.getSourceVectorType();
2277  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2278  if (!targetType)
2279  return failure();
2280 
2281  if (sourceType.getNumElements() != targetType.getNumElements())
2282  return failure();
2283 
2284  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2285  castOp.getSource());
2286  return success();
2287 }
2288 
2289 /// Try to canonicalize the extraction of a subvector from a vector defined by
2290 /// vector.from_elements. E.g.:
2291 ///
2292 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2293 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2294 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2295 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2296  PatternRewriter &rewriter) {
2297  // Dynamic positions are not supported.
2298  if (extractOp.hasDynamicPosition())
2299  return failure();
2300 
2301  // Scalar extracts are handled by the folder.
2302  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2303  if (!resultType)
2304  return failure();
2305 
2306  // Look for extracts from a from_elements op.
2307  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2308  if (!fromElementsOp)
2309  return failure();
2310  VectorType inputType = fromElementsOp.getType();
2311 
2312  // Scalable vectors are not supported.
2313  if (resultType.isScalable() || inputType.isScalable())
2314  return failure();
2315 
2316  // Compute the position of first extracted element and flatten/linearize the
2317  // position.
2318  SmallVector<int64_t> firstElementPos =
2319  llvm::to_vector(extractOp.getStaticPosition());
2320  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2321  int flatIndex = 0;
2322  int stride = 1;
2323  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2324  flatIndex += firstElementPos[i] * stride;
2325  stride *= inputType.getDimSize(i);
2326  }
2327 
2328  // Replace the op with a smaller from_elements op.
2329  rewriter.replaceOpWithNewOp<FromElementsOp>(
2330  extractOp, resultType,
2331  fromElementsOp.getElements().slice(flatIndex,
2332  resultType.getNumElements()));
2333  return success();
2334 }
2335 
2336 } // namespace
2337 
2338 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2339  MLIRContext *context) {
2340  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2341  results.add(foldExtractFromShapeCastToShapeCast);
2342  results.add(foldExtractFromFromElements);
2343 }
2344 
2345 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2346  SmallVectorImpl<int64_t> &results) {
2347  for (auto attr : arrayAttr)
2348  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2349 }
2350 
2351 //===----------------------------------------------------------------------===//
2352 // FmaOp
2353 //===----------------------------------------------------------------------===//
2354 
2355 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2356  return llvm::to_vector<4>(getVectorType().getShape());
2357 }
2358 
2359 //===----------------------------------------------------------------------===//
2360 // ToElementsOp
2361 //===----------------------------------------------------------------------===//
2362 
2363 /// Returns true if all the `operands` are defined by `defOp`.
2364 /// Otherwise, returns false.
2365 static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2366  if (operands.empty())
2367  return false;
2368 
2369  return llvm::all_of(operands, [&](Value operand) {
2370  Operation *currentDef = operand.getDefiningOp();
2371  return currentDef == defOp;
2372  });
2373 }
2374 
2375 /// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2376 /// (%e0, %e1, ...). For example:
2377 ///
2378 /// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2379 /// %1:3 = vector.to_elements %0 : vector<3xf32>
2380 /// user_op %1#0, %1#1, %1#2
2381 ///
2382 /// becomes:
2383 ///
2384 /// user_op %a, %b, %c
2385 ///
2386 static LogicalResult
2387 foldToElementsFromElements(ToElementsOp toElementsOp,
2388  SmallVectorImpl<OpFoldResult> &results) {
2389  auto fromElementsOp =
2390  toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2391  if (!fromElementsOp)
2392  return failure();
2393 
2394  llvm::append_range(results, fromElementsOp.getElements());
2395  return success();
2396 }
2397 
2398 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2399  SmallVectorImpl<OpFoldResult> &results) {
2400  return foldToElementsFromElements(*this, results);
2401 }
2402 
2403 LogicalResult
2404 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2405  ToElementsOp::Adaptor adaptor,
2406  SmallVectorImpl<Type> &inferredReturnTypes) {
2407  auto vecType = cast<VectorType>(adaptor.getSource().getType());
2408  Type elType = vecType.getElementType();
2409  inferredReturnTypes.append(vecType.getNumElements(), elType);
2410  return success();
2411 }
2412 
2413 //===----------------------------------------------------------------------===//
2414 // FromElementsOp
2415 //===----------------------------------------------------------------------===//
2416 
2417 /// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2418 ///
2419 /// Case #1: Input and output vectors are the same.
2420 ///
2421 /// %0:3 = vector.to_elements %a : vector<3xf32>
2422 /// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2423 /// user_op %1
2424 ///
2425 /// becomes:
2426 ///
2427 /// user_op %a
2428 ///
2429 static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2430  OperandRange fromElemsOperands = fromElementsOp.getElements();
2431  if (fromElemsOperands.empty())
2432  return {};
2433 
2434  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2435  if (!toElementsOp)
2436  return {};
2437 
2438  if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2439  return {};
2440 
2441  // Case #1: Input and output vectors are the same. Forward the input vector.
2442  Value toElementsInput = toElementsOp.getSource();
2443  if (fromElementsOp.getType() == toElementsInput.getType() &&
2444  llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2445  return toElementsInput;
2446  }
2447 
2448  // TODO: Support cases with different input and output shapes and different
2449  // number of elements.
2450 
2451  return {};
2452 }
2453 
2454 /// Fold vector.from_elements to a constant when all operands are constants.
2455 /// Example:
2456 /// %c1 = arith.constant 1 : i32
2457 /// %c2 = arith.constant 2 : i32
2458 /// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2459 /// =>
2460 /// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2461 ///
2462 static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2463  ArrayRef<Attribute> elements) {
2464  // Check for null or poison attributes before any processing.
2465  if (llvm::any_of(elements, [](Attribute attr) {
2466  return !attr || isa<ub::PoisonAttrInterface>(attr);
2467  }))
2468  return {};
2469 
2470  // DenseElementsAttr only supports int/index/float/complex types.
2471  auto destVecType = fromElementsOp.getDest().getType();
2472  auto destEltType = destVecType.getElementType();
2473  if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2474  return {};
2475 
2476  // Convert integer attributes to the target type if needed, leave others
2477  // unchanged.
2478  auto convertedElements =
2479  llvm::map_to_vector(elements, [&](Attribute attr) -> Attribute {
2480  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
2481  return convertIntegerAttr(intAttr, destEltType);
2482  }
2483  return attr; // Non-integer attributes (FloatAttr, etc.) returned
2484  // unchanged
2485  });
2486 
2487  return DenseElementsAttr::get(destVecType, convertedElements);
2488 }
2489 
2490 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2491  if (auto res = foldFromElementsToElements(*this))
2492  return res;
2493  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2494  return res;
2495 
2496  return {};
2497 }
2498 
2499 /// Rewrite vector.from_elements as vector.broadcast if the elements are the
2500 /// same. Example:
2501 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2502 /// =>
2503 /// %0 = vector.broadcast %a : f32 to vector<3xf32>
2504 static LogicalResult
2505 rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2506  PatternRewriter &rewriter) {
2507  if (!llvm::all_equal(fromElementsOp.getElements()))
2508  return failure();
2509  rewriter.replaceOpWithNewOp<BroadcastOp>(
2510  fromElementsOp, fromElementsOp.getType(),
2511  fromElementsOp.getElements().front());
2512  return success();
2513 }
2514 
2515 /// Rewrite from_elements on multiple scalar extracts as a shape_cast
2516 /// on a single extract. Example:
2517 /// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2518 /// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2519 /// %2 = vector.from_elements %0, %1 : vector<2xi8>
2520 ///
2521 /// becomes
2522 /// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2523 /// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2524 ///
2525 /// The requirements for this to be valid are
2526 ///
2527 /// i) The elements are extracted from the same vector (%source).
2528 ///
2529 /// ii) The elements form a suffix of %source. Specifically, the number
2530 /// of elements is the same as the product of the last N dimension sizes
2531 /// of %source, for some N.
2532 ///
2533 /// iii) The elements are extracted contiguously in ascending order.
2534 
2535 class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2536 
2537  using OpRewritePattern::OpRewritePattern;
2538 
2539  LogicalResult matchAndRewrite(FromElementsOp fromElements,
2540  PatternRewriter &rewriter) const override {
2541 
2542  // Handled by `rewriteFromElementsAsBroadcast`.
2543  if (fromElements.getType().getNumElements() == 1)
2544  return failure();
2545 
2546  // The common source that all elements are extracted from, if one exists.
2547  TypedValue<VectorType> source;
2548  // The position of the combined extract operation, if one is created.
2549  ArrayRef<int64_t> combinedPosition;
2550  // The expected index of extraction of the current element in the loop, if
2551  // elements are extracted contiguously in ascending order.
2552  SmallVector<int64_t> expectedPosition;
2553 
2554  for (auto [insertIndex, element] :
2555  llvm::enumerate(fromElements.getElements())) {
2556 
2557  // Check that the element is from a vector.extract operation.
2558  auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2559  if (!extractOp) {
2560  return rewriter.notifyMatchFailure(fromElements,
2561  "element not from vector.extract");
2562  }
2563 
2564  // Check condition (i) by checking that all elements have the same source
2565  // as the first element.
2566  if (insertIndex == 0) {
2567  source = extractOp.getSource();
2568  } else if (extractOp.getSource() != source) {
2569  return rewriter.notifyMatchFailure(fromElements,
2570  "element from different vector");
2571  }
2572 
2573  ArrayRef<int64_t> position = extractOp.getStaticPosition();
2574  int64_t rank = position.size();
2575  assert(rank == source.getType().getRank() &&
2576  "scalar extract must have full rank position");
2577 
2578  // Check condition (ii) by checking that the position that the first
2579  // element is extracted from has sufficient trailing 0s. For example, in
2580  //
2581  // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2582  // [...]
2583  // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2584  //
2585  // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2586  // elements, which is the number of elements of %elms, so this is valid.
2587  if (insertIndex == 0) {
2588  const int64_t numElms = fromElements.getType().getNumElements();
2589  int64_t numSuffixElms = 1;
2590  int64_t index = rank;
2591  while (index > 0 && position[index - 1] == 0 &&
2592  numSuffixElms < numElms) {
2593  numSuffixElms *= source.getType().getDimSize(index - 1);
2594  --index;
2595  }
2596  if (numSuffixElms != numElms) {
2597  return rewriter.notifyMatchFailure(
2598  fromElements, "elements do not form a suffix of source");
2599  }
2600  expectedPosition = llvm::to_vector(position);
2601  combinedPosition = position.drop_back(rank - index);
2602  }
2603 
2604  // Check condition (iii).
2605  else if (expectedPosition != position) {
2606  return rewriter.notifyMatchFailure(
2607  fromElements, "elements not in ascending order (static order)");
2608  }
2609  increment(expectedPosition, source.getType().getShape());
2610  }
2611 
2612  auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2613  fromElements.getLoc(), source, combinedPosition);
2614 
2615  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2616  fromElements, fromElements.getType(), extracted);
2617 
2618  return success();
2619  }
2620 
2621  /// Increments n-D `indices` by 1 starting from the innermost dimension.
2622  static void increment(MutableArrayRef<int64_t> indices,
2623  ArrayRef<int64_t> shape) {
2624  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2625  indices[dim] += 1;
2626  if (indices[dim] < shape[dim])
2627  break;
2628  indices[dim] = 0;
2629  }
2630  }
2631 };
2632 
2633 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2634  MLIRContext *context) {
2635  results.add(rewriteFromElementsAsBroadcast);
2636  results.add<FromElementsToShapeCast>(context);
2637 }
2638 
2639 //===----------------------------------------------------------------------===//
2640 // BroadcastOp
2641 //===----------------------------------------------------------------------===//
2642 
2643 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2644  SetIntRangeFn setResultRanges) {
2645  setResultRanges(getResult(), argRanges.front());
2646 }
2647 
2648 std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2649  return llvm::to_vector<4>(getResultVectorType().getShape());
2650 }
2651 
2652 /// Return the dimensions of the result vector that were formerly ones in the
2653 /// source tensor and thus correspond to "dim-1" broadcasting.
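/// A small illustrative (hypothetical) example: for srcShape = (1, 4) and
/// dstShape = (3, 2, 4), the leading dstShape dimension 0 comes from rank
/// extension, while dstShape dimension 1 corresponds to a "dim-1" broadcast of
/// the unit source dimension, so the returned set is {1}.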
2654 static llvm::SetVector<int64_t>
2655 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2656  ArrayRef<int64_t> dstShape) {
2657  int64_t rankDiff = dstShape.size() - srcShape.size();
2658  int64_t dstDim = rankDiff;
2659  llvm::SetVector<int64_t> res;
2660  for (auto [s1, s2] :
2661  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2662  if (s1 != s2) {
2663  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2664  res.insert(dstDim);
2665  }
2666  ++dstDim;
2667  }
2668  return res;
2669 }
2670 
2671 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2672  // Scalar broadcast is without any unit dim broadcast.
2673  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2674  if (!srcVectorType)
2675  return {};
2676  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2677  getResultVectorType().getShape());
2678 }
2679 
2680 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2681 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2682 /// This requires (and asserts) that the broadcast is free of "dim-1"
2683 /// broadcasting.
2684 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2685 /// vector.transpose may be inserted to make the broadcast possible.
2686 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2687 /// the helper will assert. This means:
2688 /// 1. `dstShape` must not be empty.
2689 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2690 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2691 ///    must match the `value` shape.
2692 Value BroadcastOp::createOrFoldBroadcastOp(
2693  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2694  const llvm::SetVector<int64_t> &broadcastedDims) {
2695  assert(!dstShape.empty() && "unexpected empty dst shape");
2696 
2697  // Step 1. Well-formedness check.
2698  SmallVector<int64_t> checkShape;
2699  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2700  if (broadcastedDims.contains(i))
2701  continue;
2702  checkShape.push_back(dstShape[i]);
2703  }
2704  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2705  "ill-formed broadcastedDims contains values not confined to "
2706  "destVectorShape");
2707 
2708  Location loc = value.getLoc();
2709  Type elementType = getElementTypeOrSelf(value.getType());
2710  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2711  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2712 
2713  // Step 2. If scalar -> dstShape broadcast, just do it.
2714  if (!srcVectorType) {
2715  assert(checkShape.empty() &&
2716  "ill-formed createOrFoldBroadcastOp arguments");
2717  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2718  }
2719 
2720  assert(srcVectorType.getShape().equals(checkShape) &&
2721  "ill-formed createOrFoldBroadcastOp arguments");
2722 
2723  // Step 3. Since vector.broadcast only allows creating leading dims,
2724  // vector -> dstShape broadcast may require a transpose.
2725  // Traverse the dims in order and construct:
2726  // 1. The leading entries of the broadcastShape that is guaranteed to be
2727  // achievable by a simple broadcast.
2728  // 2. The induced permutation for the subsequent vector.transpose that will
2729  // bring us from `broadcastShape` back to the desired `dstShape`.
2730  // If the induced permutation is not the identity, create a vector.transpose.
2731  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2732  broadcastShape.reserve(dstShape.size());
2733  // Consider the example:
2734  // srcShape = 2x4
2735  // dstShape = 1x2x3x4x5
2736  // broadcastedDims = [0, 2, 4]
2737  //
2738  // We want to build:
2739  // broadcastShape = 1x3x5x2x4
2740  // permutation = [0, 2, 4, 1, 3]
2741  // ---V--- -----V-----
2742  // leading broadcast part src shape part
2743  //
2744  // Note that the trailing dims of broadcastShape are exactly the srcShape
2745  // by construction.
2746  // nextSrcShapeDim is used to keep track of where in the permutation the
2747  // "src shape part" occurs.
2748  int64_t nextSrcShapeDim = broadcastedDims.size();
2749  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2750  if (broadcastedDims.contains(i)) {
2751  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2752  // bring it to the head of the broadcastShape.
2753  // It will need to be permuted back from `broadcastShape.size() - 1` into
2754  // position `i`.
2755  broadcastShape.push_back(dstShape[i]);
2756  permutation[i] = broadcastShape.size() - 1;
2757  } else {
2758  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2759  // shape and needs to be permuted into position `i`.
2760  // Don't touch `broadcastShape` here, the whole srcShape will be
2761  // appended after.
2762  permutation[i] = nextSrcShapeDim++;
2763  }
2764  }
2765  // 3.c. Append the srcShape.
2766  llvm::append_range(broadcastShape, srcVectorType.getShape());
2767 
2768  // Ensure there are no "dim-1" broadcasts.
2769  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2770  .empty() &&
2771  "unexpected \"dim-1\" broadcast");
2772 
2773  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2774  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2775  vector::BroadcastableToResult::Success &&
2776  "must be broadcastable");
2777  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2778  // Step 4. If we find any dimension that indeed needs to be permuted,
2779  // immediately return a new vector.transpose.
2780  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2781  if (permutation[i] != i)
2782  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2783  // Otherwise return res.
2784  return res;
2785 }
2786 
2787 BroadcastableToResult mlir::vector::isBroadcastableTo(
2788  Type srcType, VectorType dstVectorType,
2789  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2790  // Broadcast scalar to vector of the same element type.
2791  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2792  srcType == getElementTypeOrSelf(dstVectorType))
2793  return BroadcastableToResult::Success;
2794  // From now on, only vectors broadcast.
2795  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2796  if (!srcVectorType)
2797  return BroadcastableToResult::SourceTypeNotAVector;
2798 
2799  int64_t srcRank = srcVectorType.getRank();
2800  int64_t dstRank = dstVectorType.getRank();
2801  if (srcRank > dstRank)
2802  return BroadcastableToResult::SourceRankHigher;
2803  // Source has an exact match or singleton value for all trailing dimensions
2804  // (all leading dimensions are simply duplicated).
2805  int64_t lead = dstRank - srcRank;
2806  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2807  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2808  // encountered?
2809  bool foundMismatchingDims = false;
2810 
2811  // Check fixed-width dims.
2812  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2813  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2814  if (srcDim != 1 && srcDim != dstDim)
2815  foundMismatchingDims = true;
2816 
2817  // Check scalable flags.
2818  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2819  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2820  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2821  // 1 -> [N] is fine, everything else should be rejected when mixing
2822  // fixed-width and scalable dims
2823  (srcDimScalableFlag != dstDimScalableFlag &&
2824  (srcDim != 1 || srcDimScalableFlag)))
2825  foundMismatchingDims = true;
2826 
2827  if (foundMismatchingDims) {
2828  if (mismatchingDims != nullptr) {
2829  mismatchingDims->first.dim = srcDim;
2830  mismatchingDims->first.isScalable = srcDimScalableFlag;
2831 
2832  mismatchingDims->second.dim = dstDim;
2833  mismatchingDims->second.isScalable = dstDimScalableFlag;
2834  }
2835  return BroadcastableToResult::DimensionMismatch;
2836  }
2837  }
2838 
2839  return BroadcastableToResult::Success;
2840 }
2841 
2842 LogicalResult BroadcastOp::verify() {
2843  std::pair<VectorDim, VectorDim> mismatchingDims;
2844  BroadcastableToResult res = isBroadcastableTo(
2845  getSourceType(), getResultVectorType(), &mismatchingDims);
2846  if (res == BroadcastableToResult::Success)
2847  return success();
2848  if (res == BroadcastableToResult::SourceRankHigher)
2849  return emitOpError("source rank higher than destination rank");
2850  if (res == BroadcastableToResult::DimensionMismatch) {
2851  return emitOpError("dimension mismatch (")
2852  << (mismatchingDims.first.isScalable ? "[" : "")
2853  << mismatchingDims.first.dim
2854  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2855  << (mismatchingDims.second.isScalable ? "[" : "")
2856  << mismatchingDims.second.dim
2857  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2858  }
2859  if (res == BroadcastableToResult::SourceTypeNotAVector)
2860  return emitOpError("source type is not a vector");
2861  llvm_unreachable("unexpected vector.broadcast op error");
2862 }
2863 
2864 // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2865 // with broadcast's result type and shape_cast only adds or removes ones in the
2866 // leading dimensions.
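// An illustrative (hypothetical) example where the shape_cast only adds a
// leading unit dimension:
//   %sc = vector.shape_cast %x : vector<4xf32> to vector<1x4xf32>
//   %b  = vector.broadcast %sc : vector<1x4xf32> to vector<2x4xf32>
// folds to
//   %b  = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>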
2867 static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
2868  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2869  if (!srcShapeCast)
2870  return failure();
2871 
2872  VectorType srcType = srcShapeCast.getSourceVectorType();
2873  VectorType destType = broadcastOp.getResultVectorType();
2874  // Check type compatibility.
2875  if (vector::isBroadcastableTo(srcType, destType) !=
2876  BroadcastableToResult::Success)
2877  return failure();
2878 
2879  ArrayRef<int64_t> srcShape = srcType.getShape();
2880  ArrayRef<int64_t> shapecastShape =
2881  srcShapeCast.getResultVectorType().getShape();
2882  // Trailing dimensions should be the same if shape_cast only alters the
2883  // leading dimensions.
2884  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
2885  if (!llvm::equal(srcShape.take_back(numTrailingDims),
2886  shapecastShape.take_back(numTrailingDims)))
2887  return failure();
2888 
2889  assert(all_of(srcShape.drop_back(numTrailingDims),
2890  [](int64_t E) { return E == 1; }) &&
2891  all_of(shapecastShape.drop_back(numTrailingDims),
2892  [](int64_t E) { return E == 1; }) &&
2893  "ill-formed shape_cast");
2894 
2895  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
2896  return success();
2897 }
2898 
2899 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2900  if (getSourceType() == getResultVectorType())
2901  return getSource();
2902  if (succeeded(foldBroadcastOfShapeCast(*this)))
2903  return getResult();
2904 
2905  if (!adaptor.getSource())
2906  return {};
2907  auto vectorType = getResultVectorType();
2908  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2909  if (vectorType.getElementType() != attr.getType())
2910  return {};
2911  return DenseElementsAttr::get(vectorType, attr);
2912  }
2913  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2914  if (vectorType.getElementType() != attr.getType())
2915  return {};
2916  return DenseElementsAttr::get(vectorType, attr);
2917  }
2918  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2919  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2920  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
2921  return ub::PoisonAttr::get(getContext());
2922  return {};
2923 }
2924 
2925 namespace {
2926 
2927 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
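//
// A minimal illustrative sketch (hypothetical IR, for exposition only):
//   %0 = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.broadcast %0 : vector<2x4xf32> to vector<3x2x4xf32>
// folds to
//   %1 = vector.broadcast %x : vector<4xf32> to vector<3x2x4xf32>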
2928 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2930 
2931  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2932  PatternRewriter &rewriter) const override {
2933  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2934  if (!srcBroadcast)
2935  return failure();
2936  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2937  broadcastOp.getResultVectorType(),
2938  srcBroadcast.getSource());
2939  return success();
2940  }
2941 };
2942 } // namespace
2943 
2944 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2945  MLIRContext *context) {
2946  // BroadcastToShapeCast is not a default canonicalization; it is opted into
2947  // by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2948  results.add<BroadcastFolder>(context);
2949 }
2950 
2951 //===----------------------------------------------------------------------===//
2952 // ShuffleOp
2953 //===----------------------------------------------------------------------===//
2954 
2955 LogicalResult ShuffleOp::verify() {
2956  VectorType resultType = getResultVectorType();
2957  VectorType v1Type = getV1VectorType();
2958  VectorType v2Type = getV2VectorType();
2959  // Verify ranks.
2960  int64_t resRank = resultType.getRank();
2961  int64_t v1Rank = v1Type.getRank();
2962  int64_t v2Rank = v2Type.getRank();
2963  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2964  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2965  if (!wellFormed0DCase && !wellFormedNDCase)
2966  return emitOpError("rank mismatch");
2967 
2968  // Verify all but leading dimension sizes.
2969  for (int64_t r = 1; r < v1Rank; ++r) {
2970  int64_t resDim = resultType.getDimSize(r);
2971  int64_t v1Dim = v1Type.getDimSize(r);
2972  int64_t v2Dim = v2Type.getDimSize(r);
2973  if (resDim != v1Dim || v1Dim != v2Dim)
2974  return emitOpError("dimension mismatch");
2975  }
2976  // Verify mask length.
2977  ArrayRef<int64_t> mask = getMask();
2978  int64_t maskLength = mask.size();
2979  if (maskLength <= 0)
2980  return emitOpError("invalid mask length");
2981  if (maskLength != resultType.getDimSize(0))
2982  return emitOpError("mask length mismatch");
2983  // Verify all indices.
2984  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2985  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2986  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2987  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2988  return emitOpError("mask index #") << (idx + 1) << " out of range";
2989  }
2990  return success();
2991 }
2992 
2993 LogicalResult
2994 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2995  ShuffleOp::Adaptor adaptor,
2996  SmallVectorImpl<Type> &inferredReturnTypes) {
2997  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2998  auto v1Rank = v1Type.getRank();
2999  // Construct resulting type: leading dimension matches mask
3000  // length, all trailing dimensions match the operands.
3001  SmallVector<int64_t, 4> shape;
3002  shape.reserve(v1Rank);
3003  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3004  // In the 0-D case there is no trailing shape to append.
3005  if (v1Rank > 0)
3006  llvm::append_range(shape, v1Type.getShape().drop_front());
3007  inferredReturnTypes.push_back(
3008  VectorType::get(shape, v1Type.getElementType()));
3009  return success();
3010 }
3011 
3012 template <typename T>
3013 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3014  T expected = begin;
3015  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3016  return value == expected++;
3017  });
3018 }
3019 
3020 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3021  auto v1Type = getV1VectorType();
3022  auto v2Type = getV2VectorType();
3023 
3024  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3025  "Vector shuffle does not support scalable vectors");
3026 
3027  // For consistency, the 0-D shuffle return type is 1-D, so this cannot be a
3028  // fold; it must be canonicalized into a vector.broadcast instead.
3029  if (v1Type.getRank() == 0)
3030  return {};
3031 
3032  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3033  auto mask = getMask();
3034  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3035  return getV1();
3036  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3037  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3038  return getV2();
3039 
3040  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3041  if (!v1Attr || !v2Attr)
3042  return {};
3043 
3044  // Fold shuffle poison, poison -> poison.
3045  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3046  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3047  if (isV1Poison && isV2Poison)
3048  return ub::PoisonAttr::get(getContext());
3049 
3050  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3051  // manipulation.
3052  if (v1Type.getRank() != 1)
3053  return {};
3054 
3055  // Poison input attributes need special handling as they are not
3056  // DenseElementsAttr. If an index is poison, we select the first element of
3057  // the first non-poison input.
3058  SmallVector<Attribute> v1Elements, v2Elements;
3059  Attribute poisonElement;
3060  if (!isV2Poison) {
3061  auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3062  if (!v2DenseAttr)
3063  return {};
3064  v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3065  poisonElement = v2Elements[0];
3066  }
3067  if (!isV1Poison) {
3068  auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3069  if (!v1DenseAttr)
3070  return {};
3071  v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3072  poisonElement = v1Elements[0];
3073  }
3074 
3075  SmallVector<Attribute> results;
3076  int64_t v1Size = v1Type.getDimSize(0);
3077  for (int64_t maskIdx : mask) {
3078  Attribute indexedElm;
3079  // TODO: Return a partial poison vector when supported by the UB dialect.
3080  if (maskIdx == ShuffleOp::kPoisonIndex) {
3081  indexedElm = poisonElement;
3082  } else {
3083  if (maskIdx < v1Size)
3084  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3085  else
3086  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3087  }
3088 
3089  results.push_back(indexedElm);
3090  }
3091 
3092  return DenseElementsAttr::get(getResultVectorType(), results);
3093 }
3094 
3095 namespace {
3096 
3097 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3098 // to a broadcast.
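//
// A minimal illustrative sketch (hypothetical IR, for exposition only), with
// %a, %b : vector<f32>:
//   %0 = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
// is rewritten to
//   %0 = vector.broadcast %a : vector<f32> to vector<1xf32>
// (a [1] mask selects %b instead).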
3099 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3101 
3102  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3103  PatternRewriter &rewriter) const override {
3104  VectorType v1VectorType = shuffleOp.getV1VectorType();
3105  ArrayRef<int64_t> mask = shuffleOp.getMask();
3106  if (v1VectorType.getRank() > 0)
3107  return failure();
3108  if (mask.size() != 1)
3109  return failure();
3110  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3111  if (mask[0] == 0)
3112  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3113  shuffleOp.getV1());
3114  else
3115  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3116  shuffleOp.getV2());
3117  return success();
3118  }
3119 };
3120 
3121 /// Consider the defining operation `defOp` of `value`. If `defOp` is a
3122 /// vector.splat or a vector.broadcast with a scalar operand, return the scalar
3123 /// value that is splatted. Otherwise return null.
3124 ///
3125 /// Examples:
3126 ///
3127 /// scalar_source --> vector.splat --> value - return scalar_source
3128 /// scalar_source --> vector.broadcast --> value - return scalar_source
3129 static Value getScalarSplatSource(Value value) {
3130  // Block argument:
3131  Operation *defOp = value.getDefiningOp();
3132  if (!defOp)
3133  return {};
3134 
3135  // Splat:
3136  if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3137  return splat.getInput();
3138 
3139  auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3140 
3141  // Not broadcast (and not splat):
3142  if (!broadcast)
3143  return {};
3144 
3145  // Broadcast of a vector:
3146  if (isa<VectorType>(broadcast.getSourceType()))
3147  return {};
3148 
3149  // Broadcast of a scalar:
3150  return broadcast.getSource();
3151 }
3152 
3153 /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
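/// A minimal illustrative sketch (hypothetical IR, for exposition only):
///   %s = vector.broadcast %x : f32 to vector<4xf32>
///   %0 = vector.shuffle %s, %s [0, 4, 2, 6] : vector<4xf32>, vector<4xf32>
/// is rewritten to
///   %0 = vector.broadcast %x : f32 to vector<4xf32>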
3154 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3155 public:
3157 
3158  LogicalResult matchAndRewrite(ShuffleOp op,
3159  PatternRewriter &rewriter) const override {
3160  Value splat = getScalarSplatSource(op.getV1());
3161  if (!splat || getScalarSplatSource(op.getV2()) != splat)
3162  return failure();
3163 
3164  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3165  return success();
3166  }
3167 };
3168 
3169 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3170 /// vector.interleave.
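///
/// A minimal illustrative sketch (hypothetical IR, for exposition only):
///   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///          : vector<4xf32>, vector<4xf32>
/// is rewritten to
///   %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>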
3171 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3172 public:
3174 
3175  LogicalResult matchAndRewrite(ShuffleOp op,
3176  PatternRewriter &rewriter) const override {
3177  VectorType resultType = op.getResultVectorType();
3178  if (resultType.isScalable())
3179  return rewriter.notifyMatchFailure(
3180  op, "ShuffleOp can't represent a scalable interleave");
3181 
3182  if (resultType.getRank() != 1)
3183  return rewriter.notifyMatchFailure(
3184  op, "ShuffleOp can't represent an n-D interleave");
3185 
3186  VectorType sourceType = op.getV1VectorType();
3187  if (sourceType != op.getV2VectorType() ||
3188  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3189  return rewriter.notifyMatchFailure(
3190  op, "ShuffleOp types don't match an interleave");
3191  }
3192 
3193  ArrayRef<int64_t> shuffleMask = op.getMask();
3194  int64_t resultVectorSize = resultType.getNumElements();
3195  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3196  int64_t maskValueA = shuffleMask[i * 2];
3197  int64_t maskValueB = shuffleMask[(i * 2) + 1];
3198  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3199  return rewriter.notifyMatchFailure(op,
3200  "ShuffleOp mask not interleaving");
3201  }
3202 
3203  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3204  return success();
3205  }
3206 };
3207 
3208 } // namespace
3209 
3210 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3211  MLIRContext *context) {
3212  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3213  context);
3214 }
3215 
3216 //===----------------------------------------------------------------------===//
3217 // InsertOp
3218 //===----------------------------------------------------------------------===//
3219 
3220 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3221  SetIntRangeFn setResultRanges) {
3222  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3223 }
3224 
3225 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3226  Value source, Value dest) {
3227  auto vectorTy = cast<VectorType>(dest.getType());
3228  build(builder, result, source, dest,
3229  SmallVector<int64_t>(vectorTy.getRank(), 0));
3230 }
3231 
3232 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3233  Value source, Value dest, int64_t position) {
3234  build(builder, result, source, dest, ArrayRef<int64_t>{position});
3235 }
3236 
3237 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3238  Value source, Value dest, OpFoldResult position) {
3239  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3240 }
3241 
3242 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3243  Value source, Value dest,
3244  ArrayRef<int64_t> position) {
3245  SmallVector<OpFoldResult> posVals;
3246  posVals.reserve(position.size());
3247  llvm::transform(position, std::back_inserter(posVals),
3248  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3249  build(builder, result, source, dest, posVals);
3250 }
3251 
3252 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3253  Value source, Value dest,
3254  ArrayRef<OpFoldResult> position) {
3255  SmallVector<int64_t> staticPos;
3256  SmallVector<Value> dynamicPos;
3257  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3258  build(builder, result, source, dest, dynamicPos,
3259  builder.getDenseI64ArrayAttr(staticPos));
3260 }
3261 
3262 LogicalResult InsertOp::verify() {
3263  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3264  if (srcTy.getRank() == 0)
3265  return emitError(
3266  "expected a scalar instead of a 0-d vector as the source operand");
3267 
3268  SmallVector<OpFoldResult> position = getMixedPosition();
3269  auto destVectorType = getDestVectorType();
3270  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3271  return emitOpError(
3272  "expected position attribute of rank no greater than dest vector rank");
3273  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3274  if (srcVectorType &&
3275  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3276  static_cast<unsigned>(destVectorType.getRank())))
3277  return emitOpError("expected position attribute rank + source rank to "
3278  "match dest vector rank");
3279  if (!srcVectorType &&
3280  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3281  return emitOpError(
3282  "expected position attribute rank to match the dest vector rank");
3283  for (auto [idx, pos] : llvm::enumerate(position)) {
3284  if (auto attr = dyn_cast<Attribute>(pos)) {
3285  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3286  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3287  destVectorType.getDimSize(idx))) {
3288  return emitOpError("expected position attribute #")
3289  << (idx + 1)
3290  << " to be a non-negative integer smaller than the "
3291  "corresponding "
3292  "dest vector dimension";
3293  }
3294  }
3295  }
3296  return success();
3297 }
3298 
3299 // Calculate the linearized position of the contiguous chunk of elements to
3300 // insert, based on the shape of the value to insert and the positions to insert
3301 // at.
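//
// A minimal worked example (hypothetical values, for exposition only): with
// destTy = vector<2x3xf32> and positions = [1], the completed position is
// [1, 0], the row-major strides are [3, 1], and the linearized insert
// position is 1*3 + 0*1 = 3.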
3302 static int64_t calculateInsertPosition(VectorType destTy,
3303  ArrayRef<int64_t> positions) {
3304  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3305  assert(positions.size() <= completePositions.size() &&
3306  "positions size must be less than or equal to destTy rank");
3307  copy(positions, completePositions.begin());
3308  return linearize(completePositions, computeStrides(destTy.getShape()));
3309 }
3310 
3311 namespace {
3312 
3313 // If insertOp is only inserting unit dimensions it can be transformed to a
3314 // broadcast.
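//
// A minimal illustrative sketch (hypothetical IR, for exposition only):
//   %0 = vector.insert %v, %dest[0, 0] : vector<4xf32> into vector<1x1x4xf32>
// is rewritten to
//   %0 = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>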
3315 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3316 public:
3318 
3319  LogicalResult matchAndRewrite(InsertOp insertOp,
3320  PatternRewriter &rewriter) const override {
3321  auto srcVecType =
3322  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3323  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3324  srcVecType.getNumElements())
3325  return failure();
3326  rewriter.replaceOpWithNewOp<BroadcastOp>(
3327  insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3328  return success();
3329  }
3330 };
3331 
3332 /// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
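///
/// A minimal illustrative sketch (hypothetical IR, for exposition only):
///   %s = vector.broadcast %x : f32 to vector<4xf32>
///   %d = vector.broadcast %x : f32 to vector<2x4xf32>
///   %0 = vector.insert %s, %d[0] : vector<4xf32> into vector<2x4xf32>
/// is rewritten to
///   %0 = vector.broadcast %x : f32 to vector<2x4xf32>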
3333 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3334 public:
3336 
3337  LogicalResult matchAndRewrite(InsertOp op,
3338  PatternRewriter &rewriter) const override {
3339 
3340  Value splat = getScalarSplatSource(op.getValueToStore());
3341  if (!splat || getScalarSplatSource(op.getDest()) != splat)
3342  return failure();
3343 
3344  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3345  return success();
3346  }
3347 };
3348 
3349 /// Pattern to optimize a chain of insertions.
3350 ///
3351 /// This pattern identifies chains of vector.insert operations that:
3352 /// 1. Only insert values at static positions.
3353 /// 2. Completely initialize all elements in the resulting vector.
3354 /// 3. All intermediate insert operations have only one use.
3355 ///
3356 /// When these conditions are met, the entire chain can be replaced with a
3357 /// single vector.from_elements operation.
3358 ///
3359 /// To keep this pattern simple, and avoid spending too much time on matching
3360 /// fragmented insert chains, this pattern only considers the last insert op in
3361 /// the chain.
3362 ///
3363 /// Example transformation:
3364 /// %poison = ub.poison : vector<2xi32>
3365 /// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3366 /// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3367 /// ->
3368 /// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3369 class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3370 public:
3372  LogicalResult matchAndRewrite(InsertOp op,
3373  PatternRewriter &rewriter) const override {
3374 
3375  VectorType destTy = op.getDestVectorType();
3376  if (destTy.isScalable())
3377  return failure();
3378  // Ensure this is the trailing vector.insert op in a chain of inserts.
3379  for (Operation *user : op.getResult().getUsers())
3380  if (auto insertOp = dyn_cast<InsertOp>(user))
3381  if (insertOp.getDest() == op.getResult())
3382  return failure();
3383 
3384  InsertOp currentOp = op;
3385  SmallVector<InsertOp> chainInsertOps;
3386  while (currentOp) {
3387  // Check cond 1: Dynamic position is not supported.
3388  if (currentOp.hasDynamicPosition())
3389  return failure();
3390 
3391  chainInsertOps.push_back(currentOp);
3392  currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3393  // Check cond 3: Intermediate inserts have only one use to avoid an
3394  // explosion of vectors.
3395  if (currentOp && !currentOp->hasOneUse())
3396  return failure();
3397  }
3398 
3399  int64_t vectorSize = destTy.getNumElements();
3400  int64_t initializedCount = 0;
3401  SmallVector<bool> initializedDestIdxs(vectorSize, false);
3402  SmallVector<int64_t> pendingInsertPos;
3403  SmallVector<int64_t> pendingInsertSize;
3404  SmallVector<Value> pendingInsertValues;
3405 
3406  for (auto insertOp : chainInsertOps) {
3407  // This pattern cannot handle poison indices.
3408  if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3409  return failure();
3410 
3411  // Calculate the linearized position for inserting elements.
3412  int64_t insertBeginPosition =
3413  calculateInsertPosition(destTy, insertOp.getStaticPosition());
3414 
3415  // The valueToStore operand may be a vector or a scalar. Need to handle
3416  // both cases.
3417  int64_t insertSize = 1;
3418  if (auto srcVectorType =
3419  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3420  insertSize = srcVectorType.getNumElements();
3421 
3422  assert(insertBeginPosition + insertSize <= vectorSize &&
3423  "insert would overflow the vector");
3424 
3425  for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3426  insertBeginPosition + insertSize)) {
3427  if (initializedDestIdxs[index])
3428  continue;
3429  initializedDestIdxs[index] = true;
3430  ++initializedCount;
3431  }
3432 
3433  // Defer creating any new ops until we are sure the pattern will
3434  // succeed.
3435  pendingInsertPos.push_back(insertBeginPosition);
3436  pendingInsertSize.push_back(insertSize);
3437  pendingInsertValues.push_back(insertOp.getValueToStore());
3438 
3439  if (initializedCount == vectorSize)
3440  break;
3441  }
3442 
3443  // Check cond 2: all positions must be initialized.
3444  if (initializedCount != vectorSize)
3445  return failure();
3446 
3447  SmallVector<Value> elements(vectorSize);
3448  for (auto [insertBeginPosition, insertSize, valueToStore] :
3449  llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3450  pendingInsertValues))) {
3451  auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3452 
3453  if (!srcVectorType) {
3454  elements[insertBeginPosition] = valueToStore;
3455  continue;
3456  }
3457 
3458  SmallVector<Type> elementToInsertTypes(insertSize,
3459  srcVectorType.getElementType());
3460  // Get all elements from the vector in row-major order.
3461  auto elementsToInsert = vector::ToElementsOp::create(
3462  rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3463  for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3464  elements[insertBeginPosition + linearIdx] =
3465  elementsToInsert.getResult(linearIdx);
3466  }
3467  }
3468 
3469  rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3470  return success();
3471  }
3472 };
3473 
3474 } // namespace
3475 
3476 static Attribute
3477 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3478  Attribute dstAttr,
3479  int64_t maxVectorSizeFoldThreshold) {
3480  if (insertOp.hasDynamicPosition())
3481  return {};
3482 
3483  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3484  if (!denseDst)
3485  return {};
3486 
3487  if (!srcAttr) {
3488  return {};
3489  }
3490 
3491  VectorType destTy = insertOp.getDestVectorType();
3492  if (destTy.isScalable())
3493  return {};
3494 
3495  // Make sure we do not create too many large constants.
3496  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3497  !insertOp->hasOneUse())
3498  return {};
3499 
3500  // Calculate the linearized position for inserting elements.
3501  int64_t insertBeginPosition =
3502  calculateInsertPosition(destTy, insertOp.getStaticPosition());
3503  SmallVector<Attribute> insertedValues;
3504  Type destEltType = destTy.getElementType();
3505 
3506  /// Converts integer attributes to the expected type if there's a mismatch.
3507  /// Non-integer attributes are left unchanged.
3508  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3509  for (auto value : denseSource.getValues<Attribute>())
3510  if (auto intAttr = dyn_cast<IntegerAttr>(value))
3511  insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
3512  else
3513  insertedValues.push_back(value); // Non-integer attributes unchanged
3514  } else {
3515  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
3516  insertedValues.push_back(convertIntegerAttr(intAttr, destEltType));
3517  else
3518  insertedValues.push_back(srcAttr); // Non-integer attributes unchanged
3519  }
3520 
3521  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3522  copy(insertedValues, allValues.begin() + insertBeginPosition);
3523  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3524 
3525  return newAttr;
3526 }
3527 
3528 /// Folder to replace the `dest` operand of the insert op with the root dest of
3529 /// the insert op use chain.
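///
/// A minimal illustrative sketch (hypothetical IR, for exposition only):
///   %0 = vector.insert %a, %dest[1] : f32 into vector<4xf32>
///   %1 = vector.insert %b, %0[1] : f32 into vector<4xf32>
/// Since the second insert overwrites the same position, its `dest` operand
/// is folded to use %dest directly:
///   %1 = vector.insert %b, %dest[1] : f32 into vector<4xf32>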
3530 static Value foldInsertUseChain(InsertOp insertOp) {
3531  auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3532  if (!destInsert)
3533  return {};
3534 
3535  if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3536  return {};
3537 
3538  insertOp.setOperand(1, destInsert.getDest());
3539  return insertOp.getResult();
3540 }
3541 
3542 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3543  MLIRContext *context) {
3544  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3545  InsertChainFullyInitialized>(context);
3546 }
3547 
3548 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3549  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3550  // unless the source vector constant has a single use.
3551  constexpr int64_t vectorSizeFoldThreshold = 256;
3552  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3553  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3554  // (type mismatch).
3555  if (getNumIndices() == 0 && getValueToStoreType() == getType())
3556  return getValueToStore();
3557  // Fold `arith.constant` indices into the `vector.insert` operation.
3558  // Do not stop here as this fold may enable subsequent folds that require
3559  // constant indices.
3560  SmallVector<Value> operands = {getValueToStore(), getDest()};
3561  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3562 
3563  if (auto res = foldInsertUseChain(*this))
3564  return res;
3565  if (auto res = foldPoisonIndexInsertExtractOp(
3566  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3567  return res;
3568  if (auto res = foldDenseElementsAttrDestInsertOp(
3569  *this, adaptor.getValueToStore(), adaptor.getDest(),
3570  vectorSizeFoldThreshold)) {
3571  return res;
3572  }
3573 
3574  return inplaceFolded;
3575 }
3576 
3577 //===----------------------------------------------------------------------===//
3578 // InsertStridedSliceOp
3579 //===----------------------------------------------------------------------===//
3580 
3581 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3582  Value source, Value dest,
3583  ArrayRef<int64_t> offsets,
3584  ArrayRef<int64_t> strides) {
3585  result.addOperands({source, dest});
3586  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3587  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3588  result.addTypes(dest.getType());
3589  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3590  offsetsAttr);
3591  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3592  stridesAttr);
3593 }
3594 
3595 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3596 template <typename OpType>
3597 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3598  ArrayAttr arrayAttr,
3599  ArrayRef<int64_t> shape,
3600  StringRef attrName) {
3601  if (arrayAttr.size() > shape.size())
3602  return op.emitOpError("expected ")
3603  << attrName << " attribute of rank no greater than vector rank";
3604  return success();
3605 }
3606 
3607 // Returns success if all integers in `arrayAttr` are in the interval
3608 // [min, max). If `halfOpen` is false, the admissible interval is [min, max]
3609 // instead.
3610 template <typename OpType>
3611 static LogicalResult
3612 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3613  int64_t max, StringRef attrName,
3614  bool halfOpen = true) {
3615  for (auto attr : arrayAttr) {
3616  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3617  auto upper = max;
3618  if (!halfOpen)
3619  upper += 1;
3620  if (val < min || val >= upper)
3621  return op.emitOpError("expected ") << attrName << " to be confined to ["
3622  << min << ", " << upper << ")";
3623  }
3624  return success();
3625 }
3626 
3627 // Returns success if, for each index i, `arrayAttr[i]` is in the interval
3628 // [min, shape[i]). If `halfOpen` is false, the admissible interval is
3629 // [min, shape[i]] instead.
3630 template <typename OpType>
3631 static LogicalResult
3632 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3633  ArrayRef<int64_t> shape, StringRef attrName,
3634  bool halfOpen = true, int64_t min = 0) {
3635  for (auto [index, attrDimPair] :
3636  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3637  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3638  int64_t max = std::get<1>(attrDimPair);
3639  if (!halfOpen)
3640  max += 1;
3641  if (val < min || val >= max)
3642  return op.emitOpError("expected ")
3643  << attrName << " dimension " << index << " to be confined to ["
3644  << min << ", " << max << ")";
3645  }
3646  return success();
3647 }
3648 
3649 // Returns success if, for each index i = 0..shape.size()-1, the value
3650 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3651 // is in the interval [min, shape[i]).
3652 // If `halfOpen` is false, the admissible interval is [min, shape[i]]
3653 // instead.
3654 template <typename OpType>
3655 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3656  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3657  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3658  bool halfOpen = true, int64_t min = 1) {
3659  assert(arrayAttr1.size() <= shape.size());
3660  assert(arrayAttr2.size() <= shape.size());
3661  for (auto [index, it] :
3662  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3663  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3664  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3665  int64_t max = std::get<2>(it);
3666  if (!halfOpen)
3667  max += 1;
3668  if (val1 + val2 < 0 || val1 + val2 >= max)
3669  return op.emitOpError("expected sum(")
3670  << attrName1 << ", " << attrName2 << ") dimension " << index
3671  << " to be confined to [" << min << ", " << max << ")";
3672  }
3673  return success();
3674 }
3675 
3676 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3677  MLIRContext *context) {
3678  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3679  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3680  });
3681  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3682 }
3683 
3684 LogicalResult InsertStridedSliceOp::verify() {
3685  auto sourceVectorType = getSourceVectorType();
3686  auto destVectorType = getDestVectorType();
3687  auto offsets = getOffsetsAttr();
3688  auto strides = getStridesAttr();
3689  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3690  return emitOpError(
3691  "expected offsets of same size as destination vector rank");
3692  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3693  return emitOpError("expected strides of same size as source vector rank");
3694  if (sourceVectorType.getRank() > destVectorType.getRank())
3695  return emitOpError(
3696  "expected source rank to be no greater than destination rank");
3697 
3698  auto sourceShape = sourceVectorType.getShape();
3699  auto destShape = destVectorType.getShape();
3700  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3701  destShape.size() - sourceShape.size(), 0);
3702  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3703  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3704  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3705  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3706  offName)) ||
3707  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3708  /*max=*/1, stridesName,
3709  /*halfOpen=*/false)) ||
3710  failed(isSumOfIntegerArrayAttrConfinedToShape(
3711  *this, offsets,
3712  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3713  offName, "source vector shape",
3714  /*halfOpen=*/false, /*min=*/1)))
3715  return failure();
3716 
3717  unsigned rankDiff = destShape.size() - sourceShape.size();
3718  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3719  if (sourceVectorType.getScalableDims()[idx] !=
3720  destVectorType.getScalableDims()[idx + rankDiff]) {
3721  return emitOpError("mismatching scalable flags (at source vector idx=")
3722  << idx << ")";
3723  }
3724  if (sourceVectorType.getScalableDims()[idx]) {
3725  auto sourceSize = sourceShape[idx];
3726  auto destSize = destShape[idx + rankDiff];
3727  if (sourceSize != destSize) {
3728  return emitOpError("expected size at idx=")
3729  << idx
3730  << (" to match the corresponding base size from the input "
3731  "vector (")
3732  << sourceSize << (" vs ") << destSize << (")");
3733  }
3734  }
3735  }
3736 
3737  return success();
3738 }
3739 
3740 namespace {
3741 /// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
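///
/// A minimal illustrative sketch (hypothetical IR, for exposition only):
///   %v = vector.broadcast %x : f32 to vector<4xf32>
///   %d = vector.broadcast %x : f32 to vector<16xf32>
///   %0 = vector.insert_strided_slice %v, %d
///          {offsets = [4], strides = [1]} : vector<4xf32> into vector<16xf32>
/// is rewritten to just %d.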
3742 class FoldInsertStridedSliceSplat final
3743  : public OpRewritePattern<InsertStridedSliceOp> {
3744 public:
3746 
3747  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3748  PatternRewriter &rewriter) const override {
3749 
3750  auto dst = insertStridedSliceOp.getDest();
3751  auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3752  if (!splat || getScalarSplatSource(dst) != splat)
3753  return failure();
3754 
3755  rewriter.replaceOp(insertStridedSliceOp, dst);
3756  return success();
3757  }
3758 };
3759 
3760 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3761 /// to dst.
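///
/// A minimal illustrative sketch (hypothetical IR, for exposition only):
///   %e = vector.extract_strided_slice %dst
///          {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8xf32> to vector<4xf32>
///   %0 = vector.insert_strided_slice %e, %dst
///          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// is rewritten to just %dst.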
3762 class FoldInsertStridedSliceOfExtract final
3763  : public OpRewritePattern<InsertStridedSliceOp> {
3764 public:
3766 
3767  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3768  PatternRewriter &rewriter) const override {
3769  auto extractStridedSliceOp =
3770  insertStridedSliceOp.getValueToStore()
3771  .getDefiningOp<vector::ExtractStridedSliceOp>();
3772 
3773  if (!extractStridedSliceOp)
3774  return failure();
3775 
3776  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3777  return failure();
3778 
3779  // Check that the extract and the insert have the same strides and offsets.
3780  if (extractStridedSliceOp.getStrides() !=
3781  insertStridedSliceOp.getStrides() ||
3782  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3783  return failure();
3784 
3785  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3786  return success();
3787  }
3788 };
3789 
3790 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3791 // ConstantOp.
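//
// A minimal illustrative sketch (hypothetical IR, for exposition only):
//   %d = arith.constant dense<0> : vector<4xi32>
//   %s = arith.constant dense<1> : vector<2xi32>
//   %0 = vector.insert_strided_slice %s, %d
//          {offsets = [1], strides = [1]} : vector<2xi32> into vector<4xi32>
// is replaced by
//   %0 = arith.constant dense<[0, 1, 1, 0]> : vector<4xi32>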
3792 class InsertStridedSliceConstantFolder final
3793  : public OpRewritePattern<InsertStridedSliceOp> {
3794 public:
3796 
3797  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3798  // unless the source vector constant has a single use.
3799  static constexpr int64_t vectorSizeFoldThreshold = 256;
3800 
3801  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3802  PatternRewriter &rewriter) const override {
3803  // Return if 'InsertOp' operand is not defined by a compatible vector
3804  // ConstantOp.
3805  TypedValue<VectorType> destVector = op.getDest();
3806  Attribute vectorDestCst;
3807  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3808  return failure();
3809 
3810  VectorType destTy = destVector.getType();
3811  if (destTy.isScalable())
3812  return failure();
3813 
3814  // Make sure we do not create too many large constants.
3815  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3816  !destVector.hasOneUse())
3817  return failure();
3818 
3819  TypedValue<VectorType> sourceValue = op.getValueToStore();
3820  Attribute sourceCst;
3821  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3822  return failure();
3823 
3824  // TODO: Support poison.
3825  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3826  return failure();
3827 
3828  // TODO: Handle non-unit strides when they become available.
3829  if (op.hasNonUnitStrides())
3830  return failure();
3831 
3832  VectorType sliceVecTy = sourceValue.getType();
3833  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3834  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3835  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3836  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3837 
3838  // Calculate the destination element indices by enumerating all slice
3839  // positions within the destination and linearizing them. The enumeration
3840  // order is lexicographic, which yields a sequence of monotonically
3841  // increasing linearized position indices.
3842  // Because the destination may have higher dimensionality than the slice,
3843  // we keep track of two overlapping sets of positions and offsets.
3844  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3845  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3846  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3847  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3848  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3849  MutableArrayRef<int64_t> currSlicePosition(
3850  currDestPosition.begin() + rankDifference, currDestPosition.end());
3851  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3852  offsets.end());
3853  do {
3854  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3855  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3856  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3857  "Invalid slice element");
3858  newValues[linearizedPosition] = *sliceValuesIt;
3859  ++sliceValuesIt;
3860  } while (succeeded(
3861  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3862 
3863  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3864  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3865  return success();
3866  }
3867 };
3868 
3869 } // namespace
3870 
3871 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3872  RewritePatternSet &results, MLIRContext *context) {
3873  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3874  InsertStridedSliceConstantFolder>(context);
3875 }
3876 
3877 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3878  if (getSourceVectorType() == getDestVectorType())
3879  return getValueToStore();
3880  return {};
3881 }
3882 
3883 //===----------------------------------------------------------------------===//
3884 // OuterProductOp
3885 //===----------------------------------------------------------------------===//
3886 
3887 /// Build an op without mask, use the type of `acc` as the return type.
3888 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3889  Value lhs, Value rhs, Value acc) {
3890  result.addOperands({lhs, rhs, acc});
3891  result.addTypes(acc.getType());
3892 }
3893 
3894 void OuterProductOp::print(OpAsmPrinter &p) {
3895  p << " " << getLhs() << ", " << getRhs();
3896  if (getAcc()) {
3897  p << ", " << getAcc();
3898  p.printOptionalAttrDict((*this)->getAttrs());
3899  }
3900  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3901 }
3902 
3903 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3904  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3905  Type tLHS, tRHS;
3906  if (parser.parseOperandList(operandsInfo) ||
3907  parser.parseOptionalAttrDict(result.attributes) ||
3908  parser.parseColonType(tLHS) || parser.parseComma() ||
3909  parser.parseType(tRHS))
3910  return failure();
3911  if (operandsInfo.size() < 2)
3912  return parser.emitError(parser.getNameLoc(),
3913  "expected at least 2 operands");
3914  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3915  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3916  if (!vLHS)
3917  return parser.emitError(parser.getNameLoc(),
3918  "expected vector type for operand #1");
3919 
3920  VectorType resType;
3921  if (vRHS) {
3922  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3923  vRHS.getScalableDims()[0]};
3924  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3925  vLHS.getElementType(), scalableDimsRes);
3926  } else {
3927  // Scalar RHS operand
3928  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3929  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3930  scalableDimsRes);
3931  }
3932 
3933  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3934  result.attributes.append(
3935  OuterProductOp::getKindAttrName(result.name),
3936  CombiningKindAttr::get(result.getContext(),
3937  OuterProductOp::getDefaultKind()));
3938  }
3939 
3940  return failure(
3941  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3942  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3943  (operandsInfo.size() > 2 &&
3944  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3945  parser.addTypeToList(resType, result.types));
3946 }
3947 
3948 LogicalResult OuterProductOp::verify() {
3949  Type tRHS = getOperandTypeRHS();
3950  VectorType vLHS = getOperandVectorTypeLHS(),
3951  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3952  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3953 
3954  if (vLHS.getRank() != 1)
3955  return emitOpError("expected 1-d vector for operand #1");
3956 
3957  if (vRHS) {
3958  // Proper OUTER operation.
3959  if (vRHS.getRank() != 1)
3960  return emitOpError("expected 1-d vector for operand #2");
3961  if (vRES.getRank() != 2)
3962  return emitOpError("expected 2-d vector result");
3963  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3964  return emitOpError("expected #1 operand dim to match result dim #1");
3965  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3966  return emitOpError("expected #2 operand dim to match result dim #2");
3967  if (vLHS.isScalable() && !vRHS.isScalable()) {
3968  // This restriction reflects what's currently supported in terms of
3969  // scalable vectors. However, we could relax this if there's a use case.
3970  return emitOpError(
3971  "expected either both or only #2 operand dim to be scalable");
3972  }
3973  } else {
3974  // An AXPY operation.
3975  if (vRES.getRank() != 1)
3976  return emitOpError("expected 1-d vector result");
3977  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3978  return emitOpError("expected #1 operand dim to match result dim #1");
3979  }
3980 
3981  if (vACC && vACC != vRES)
3982  return emitOpError("expected operand #3 of same type as result type");
3983 
3984  // Verify supported combining kind.
3985  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3986  return emitOpError("unsupported outerproduct type");
3987 
3988  return success();
3989 }
3990 
3991 // MaskableOpInterface methods.
3992 
3993 /// Returns the mask type expected by this operation. Mostly used for
3994 /// verification purposes. It requires the operation to be vectorized.
3995 Type OuterProductOp::getExpectedMaskType() {
3996  auto vecType = this->getResultVectorType();
3997  return VectorType::get(vecType.getShape(),
3998  IntegerType::get(vecType.getContext(), /*width=*/1),
3999  vecType.getScalableDims());
4000 }
4001 
4002 //===----------------------------------------------------------------------===//
4003 // ExtractStridedSliceOp
4004 //===----------------------------------------------------------------------===//
4005 
4006 // Inference works as follows:
4007 // 1. Add 'sizes' from prefix of dims in 'offsets'.
4008 // 2. Add sizes from 'vectorType' for remaining dims.
4009 // Scalable flags are inherited from 'vectorType'.
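//
// A minimal worked example (hypothetical types, for exposition only): for a
// source of type vector<4x8x16xf32> with sizes = [2, 4] (and offsets/strides
// of matching rank 2), the inferred result type is vector<2x4x16xf32>.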
4010 static Type inferStridedSliceOpResultType(VectorType vectorType,
4011  ArrayAttr offsets, ArrayAttr sizes,
4012  ArrayAttr strides) {
4013  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4014  SmallVector<int64_t, 4> shape;
4015  shape.reserve(vectorType.getRank());
4016  unsigned idx = 0;
4017  for (unsigned e = offsets.size(); idx < e; ++idx)
4018  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4019  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4020  shape.push_back(vectorType.getShape()[idx]);
4021 
4022  return VectorType::get(shape, vectorType.getElementType(),
4023  vectorType.getScalableDims());
4024 }
4025 
4026 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4027  Value source, ArrayRef<int64_t> offsets,
4028  ArrayRef<int64_t> sizes,
4029  ArrayRef<int64_t> strides) {
4030  result.addOperands(source);
4031  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4032  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4033  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4034  result.addTypes(
4035  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4036  offsetsAttr, sizesAttr, stridesAttr));
4037  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4038  offsetsAttr);
4039  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4040  sizesAttr);
4041  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4042  stridesAttr);
4043 }
4044 
4045 LogicalResult ExtractStridedSliceOp::verify() {
4046  auto type = getSourceVectorType();
4047  auto offsets = getOffsetsAttr();
4048  auto sizes = getSizesAttr();
4049  auto strides = getStridesAttr();
4050  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4051  return emitOpError(
4052  "expected offsets, sizes and strides attributes of same size");
4053 
4054  auto shape = type.getShape();
4055  auto offName = getOffsetsAttrName();
4056  auto sizesName = getSizesAttrName();
4057  auto stridesName = getStridesAttrName();
4058  if (failed(
4059  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4060  failed(
4061  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4062  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4063  stridesName)) ||
4064  failed(
4065  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4066  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4067  /*halfOpen=*/false,
4068  /*min=*/1)) ||
4069  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4070  /*max=*/1, stridesName,
4071  /*halfOpen=*/false)) ||
4072  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4073  shape, offName, sizesName,
4074  /*halfOpen=*/false)))
4075  return failure();
4076 
4077  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4078  offsets, sizes, strides);
4079  if (getResult().getType() != resultType)
4080  return emitOpError("expected result type to be ") << resultType;
4081 
4082  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4083  if (type.getScalableDims()[idx]) {
4084  auto inputDim = type.getShape()[idx];
4085  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4086  if (inputDim != inputSize)
4087  return emitOpError("expected size at idx=")
4088  << idx
4089  << (" to match the corresponding base size from the input "
4090  "vector (")
4091  << inputSize << (" vs ") << inputDim << (")");
4092  }
4093  }
4094 
4095  return success();
4096 }
4097 
4098 // When the source of the ExtractStridedSliceOp comes from a chain of
4099 // InsertStridedSliceOps, try to use the source of one of those inserts if we
4100 // can detect that the extracted vector is a subset of an inserted vector.
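//
// A minimal illustrative sketch (hypothetical IR, for exposition only):
//   %0 = vector.insert_strided_slice %src, %dst
//          {offsets = [4], strides = [1]} : vector<4xf32> into vector<16xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [5], sizes = [2], strides = [1]}
//          : vector<16xf32> to vector<2xf32>
// The extracted range [5, 7) lies inside the inserted range [4, 8), so the
// extract is rewritten to read directly from %src with offsets = [1].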
4101 static LogicalResult
4102 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4103  // Helper to extract integer out of ArrayAttr.
4104  auto getElement = [](ArrayAttr array, int idx) {
4105  return llvm::cast<IntegerAttr>(array[idx]).getInt();
4106  };
4107  ArrayAttr extractOffsets = op.getOffsets();
4108  ArrayAttr extractStrides = op.getStrides();
4109  ArrayAttr extractSizes = op.getSizes();
4110  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4111  while (insertOp) {
4112  if (op.getSourceVectorType().getRank() !=
4113  insertOp.getSourceVectorType().getRank())
4114  return failure();
4115  ArrayAttr insertOffsets = insertOp.getOffsets();
4116  ArrayAttr insertStrides = insertOp.getStrides();
4117  // If the rank of the extract is greater than the rank of the insert, we are
4118  // likely extracting a partial chunk of the inserted vector; bail out.
4119  if (extractOffsets.size() > insertOffsets.size())
4120  return failure();
4121  bool partialOverlap = false;
4122  bool disjoint = false;
4123  SmallVector<int64_t, 4> offsetDiffs;
4124  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4125  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4126  return failure();
4127  int64_t start = getElement(insertOffsets, dim);
4128  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4129  int64_t offset = getElement(extractOffsets, dim);
4130  int64_t size = getElement(extractSizes, dim);
4131  // Check if the start of the extract offset is in the interval inserted.
4132  if (start <= offset && offset < end) {
4133  // If the extract interval overlaps but is not fully included, we may
4134  // have a partial overlap that will prevent any folding.
4135  if (offset + size > end)
4136  partialOverlap = true;
4137  offsetDiffs.push_back(offset - start);
4138  continue;
4139  }
4140  disjoint = true;
4141  break;
4142  }
4143  // The extracted chunk is a subset of the inserted chunk.
4144  if (!disjoint && !partialOverlap) {
4145  op.setOperand(insertOp.getValueToStore());
4146  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4147  OpBuilder b(op.getContext());
4148  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4149  return success();
4150  }
4151  // If the chunk extracted is disjoint from the chunk inserted, keep looking
4152  // in the insert chain.
4153  if (disjoint)
4154  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4155  else {
4156  // The extracted vector partially overlaps the inserted vector, so we
4157  // cannot fold.
4158  return failure();
4159  }
4160  }
4161  return failure();
4162 }
4163 
4164 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
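//
// A minimal illustrative sketch (hypothetical IR, for exposition only):
//   %c = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %0 = vector.extract_strided_slice %c
//          {offsets = [1], sizes = [2], strides = [1]}
//          : vector<4xi32> to vector<2xi32>
// folds to the constant dense<[1, 2]> : vector<2xi32>.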
4165 static OpFoldResult
4166 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4167  Attribute foldInput) {
4168 
4169  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4170  if (!dense)
4171  return {};
4172 
4173  // TODO: Handle non-unit strides when they become available.
4174  if (op.hasNonUnitStrides())
4175  return {};
4176 
4177  VectorType sourceVecTy = op.getSourceVectorType();
4178  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4179  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4180 
4181  VectorType sliceVecTy = op.getType();
4182  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4183  int64_t rank = sliceVecTy.getRank();
4184 
4185  // Expand offsets and sizes to match the vector rank.
4186  SmallVector<int64_t, 4> offsets(rank, 0);
4187  copy(getI64SubArray(op.getOffsets()), offsets.begin());
4188 
4189  SmallVector<int64_t, 4> sizes(sourceShape);
4190  copy(getI64SubArray(op.getSizes()), sizes.begin());
4191 
4192  // Calculate the slice elements by enumerating all slice positions and
4193  // linearizing them. The enumeration order is lexicographic which yields a
4194  // sequence of monotonically increasing linearized position indices.
4195  const auto denseValuesBegin = dense.value_begin<Attribute>();
4196  SmallVector<Attribute> sliceValues;
4197  sliceValues.reserve(sliceVecTy.getNumElements());
4198  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4199  do {
4200  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4201  assert(linearizedPosition < sourceVecTy.getNumElements() &&
4202  "Invalid index");
4203  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4204  } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4205 
4206  assert(static_cast<int64_t>(sliceValues.size()) ==
4207  sliceVecTy.getNumElements() &&
4208  "Invalid number of slice elements");
4209  return DenseElementsAttr::get(sliceVecTy, sliceValues);
4210 }
4211 
4212 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4213  if (getSourceVectorType() == getResult().getType())
4214  return getSource();
4215  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4216  return getResult();
4217 
4218  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4219  if (auto splat =
4220  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4221  return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4222 
4223  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4224  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4225 }
4226 
4227 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4228  populateFromInt64AttrArray(getOffsets(), results);
4229 }
4230 
4231 namespace {
4232 
4233 // Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4234 // CreateMaskOp.
4235 //
4236 // Example:
4237 //
4238 // %mask = vector.create_mask %ub : vector<16xi1>
4239 // %slice = vector.extract_strided_slice [%offset] [8] [1]
4240 //
4241 // to
4242 //
4243 // %new_ub = arith.subi %ub, %offset
4244 // %mask = vector.create_mask %new_ub : vector<8xi1>
4245 class StridedSliceCreateMaskFolder final
4246  : public OpRewritePattern<ExtractStridedSliceOp> {
4248 
4249 public:
4250  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4251  PatternRewriter &rewriter) const override {
4252  Location loc = extractStridedSliceOp.getLoc();
4253  // Return if 'extractStridedSliceOp' operand is not defined by a
4254  // CreateMaskOp.
4255  auto createMaskOp =
4256  extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4257  if (!createMaskOp)
4258  return failure();
4259  // Return if 'extractStridedSliceOp' has non-unit strides.
4260  if (extractStridedSliceOp.hasNonUnitStrides())
4261  return failure();
4262  // Gather constant mask dimension sizes.
4263  SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4264  // Gather strided slice offsets and sizes.
4265  SmallVector<int64_t> sliceOffsets;
4266  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4267  sliceOffsets);
4268  SmallVector<int64_t> sliceSizes;
4269  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4270 
4271  // Compute slice of vector mask region.
4272  SmallVector<Value> sliceMaskDimSizes;
4273  sliceMaskDimSizes.reserve(maskDimSizes.size());
4274  // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4275  // only iterate on the leading dim sizes. The tail accounts for the
4276  // remaining dim sizes.
4277  for (auto [maskDimSize, sliceOffset, sliceSize] :
4278  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4279  // No need to clamp on min/max values, because create_mask has clamping
4280  // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4281  // greater than the vector dim size.
4282  IntegerAttr offsetAttr =
4283  rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4284  Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4285  Value sliceMaskDimSize =
4286  arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4287  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4288  }
4289  // Add unchanged dimensions.
4290  llvm::append_range(
4291  sliceMaskDimSizes,
4292  llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4293  // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4294  // region.
4295  rewriter.replaceOpWithNewOp<CreateMaskOp>(
4296  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4297  sliceMaskDimSizes);
4298  return success();
4299  }
4300 };
4301 
4302 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4303 // ConstantMaskOp.
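// Illustrative example (not taken from the test suite; value names are
// made up):
//
//   %mask  = vector.constant_mask [8] : vector<16xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [4], sizes = [8], strides = [1]}
//       : vector<16xi1> to vector<8xi1>
//
// folds to
//
//   %slice = vector.constant_mask [4] : vector<8xi1>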
4304 class StridedSliceConstantMaskFolder final
4305  : public OpRewritePattern<ExtractStridedSliceOp> {
4306 public:
4307  using OpRewritePattern::OpRewritePattern;
4308 
4309  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4310  PatternRewriter &rewriter) const override {
4311  // Return if 'extractStridedSliceOp' operand is not defined by a
4312  // ConstantMaskOp.
4313  auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4314  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4315  if (!constantMaskOp)
4316  return failure();
4317  // Return if 'extractStridedSliceOp' has non-unit strides.
4318  if (extractStridedSliceOp.hasNonUnitStrides())
4319  return failure();
4320  // Gather constant mask dimension sizes.
4321  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4322  // Gather strided slice offsets and sizes.
4323  SmallVector<int64_t> sliceOffsets;
4324  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4325  sliceOffsets);
4326  SmallVector<int64_t> sliceSizes;
4327  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4328 
4329  // Compute slice of vector mask region.
4330  SmallVector<int64_t> sliceMaskDimSizes;
4331  sliceMaskDimSizes.reserve(maskDimSizes.size());
4332  for (auto [maskDimSize, sliceOffset, sliceSize] :
4333  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4334  int64_t sliceMaskDimSize = std::max(
4335  static_cast<int64_t>(0),
4336  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4337  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4338  }
4339  // Add unchanged dimensions.
4340  if (sliceMaskDimSizes.size() < maskDimSizes.size())
4341  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4342  sliceMaskDimSizes.push_back(maskDimSizes[i]);
4343  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4344  // region is a conjunction of mask dim intervals).
4345  if (llvm::is_contained(sliceMaskDimSizes, 0))
4346  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4347 
4348  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4349  // region.
4350  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4351  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4352  sliceMaskDimSizes);
4353  return success();
4354  }
4355 };
4356 
4357 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4358 // BroadcastOp(ExtractStridedSliceOp).
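// Illustrative example (not taken from the test suite):
//
//   %b = vector.broadcast %v : vector<1x8xf32> to vector<4x8xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x4xf32>
//
// becomes
//
//   %e = vector.extract_strided_slice %v
//       {offsets = [0, 2], sizes = [1, 4], strides = [1, 1]}
//       : vector<1x8xf32> to vector<1x4xf32>
//   %s = vector.broadcast %e : vector<1x4xf32> to vector<2x4xf32>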
4359 class StridedSliceBroadcast final
4360  : public OpRewritePattern<ExtractStridedSliceOp> {
4361 public:
4362  using OpRewritePattern::OpRewritePattern;
4363 
4364  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4365  PatternRewriter &rewriter) const override {
4366  auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4367  if (!broadcast)
4368  return failure();
4369  auto srcVecType =
4370  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4371  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4372  auto dstVecType = llvm::cast<VectorType>(op.getType());
4373  unsigned dstRank = dstVecType.getRank();
4374  unsigned rankDiff = dstRank - srcRank;
4375  // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4376  // (n -> m with n > m). If they are originally both broadcasted *and*
4377  // sliced, this can be simplified to just broadcasting.
4378  bool needsSlice = false;
4379  for (unsigned i = 0; i < srcRank; i++) {
4380  if (srcVecType.getDimSize(i) != 1 &&
4381  srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4382  needsSlice = true;
4383  break;
4384  }
4385  }
4386  Value source = broadcast.getSource();
4387  if (needsSlice) {
4388  SmallVector<int64_t> offsets =
4389  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4390  SmallVector<int64_t> sizes =
4391  getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4392  for (unsigned i = 0; i < srcRank; i++) {
4393  if (srcVecType.getDimSize(i) == 1) {
4394  // In case this dimension was broadcasted *and* sliced, the offset
4395  // and size need to be updated now that there is no broadcast before
4396  // the slice.
4397  offsets[i] = 0;
4398  sizes[i] = 1;
4399  }
4400  }
4401  source = ExtractStridedSliceOp::create(
4402  rewriter, op->getLoc(), source, offsets, sizes,
4403  getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4404  }
4405  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4406  return success();
4407  }
4408 };
4409 
4410 /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
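/// Illustrative example (not taken from the test suite):
///
///   %splat = vector.broadcast %f : f32 to vector<8x4xf32>
///   %slice = vector.extract_strided_slice %splat
///       {offsets = [2, 0], sizes = [4, 4], strides = [1, 1]}
///       : vector<8x4xf32> to vector<4x4xf32>
///
/// becomes
///
///   %slice = vector.broadcast %f : f32 to vector<4x4xf32>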
4411 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4412 public:
4413  using OpRewritePattern::OpRewritePattern;
4414 
4415  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4416  PatternRewriter &rewriter) const override {
4417 
4418  Value splat = getScalarSplatSource(op.getSource());
4419  if (!splat)
4420  return failure();
4421  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4422  return success();
4423  }
4424 };
4425 
4426 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4427 /// slice is contiguous, into extract and shape_cast.
4428 ///
4429 /// Example:
4430 /// Before:
4431 /// %1 = vector.extract_strided_slice %arg0 {
4432 /// offsets = [0, 0, 0, 0, 0],
4433 /// sizes = [1, 1, 1, 1, 8],
4434 /// strides = [1, 1, 1, 1, 1]
4435 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4436 /// After:
4437 /// %0 = vector.extract %arg0[0, 0, 0, 0]
4438 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
4439 /// %1 = vector.shape_cast %0
4440 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
4441 ///
4442 class ContiguousExtractStridedSliceToExtract final
4443  : public OpRewritePattern<ExtractStridedSliceOp> {
4444 public:
4445  using OpRewritePattern::OpRewritePattern;
4446 
4447  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4448  PatternRewriter &rewriter) const override {
4449  if (op.hasNonUnitStrides())
4450  return failure();
4451  Value source = op.getOperand();
4452  auto sourceType = cast<VectorType>(source.getType());
4453  if (sourceType.isScalable() || sourceType.getRank() == 0)
4454  return failure();
4455 
4456  // Compute the number of offsets to pass to ExtractOp::build. That is the
4457  // difference between the source rank and the desired slice rank. We walk
4458  // the dimensions from innermost out, and stop when the next slice dimension
4459  // is not full-size.
4460  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4461  int numOffsets;
4462  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4463  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4464  break;
4465  }
4466 
4467  // If the created extract op would have no offsets, then this whole
4468  // extract_strided_slice is the identity and should have been handled by
4469  // other canonicalizations.
4470  if (numOffsets == 0)
4471  return failure();
4472 
4473  // If not even the inner-most dimension is full-size, this op can't be
4474  // rewritten as an ExtractOp.
4475  if (numOffsets == sourceType.getRank() &&
4476  static_cast<int>(sizes.size()) == sourceType.getRank())
4477  return failure();
4478 
4479  // The outer dimensions must have unit size.
4480  for (int i = 0; i < numOffsets; ++i) {
4481  if (sizes[i] != 1)
4482  return failure();
4483  }
4484 
4485  // Avoid generating slices that have leading unit dimensions: the shape_cast
4486  // op created below would otherwise be lowered through the inefficient
4487  // generic fallback patterns (ShapeCastOpRewritePattern).
4488  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4489  sizes[numOffsets] == 1) {
4490  ++numOffsets;
4491  }
4492 
4493  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4494  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4495  Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4496  extractOffsets);
4497  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4498  return success();
4499  }
4500 };
4501 
4502 } // namespace
4503 
4504 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4505  RewritePatternSet &results, MLIRContext *context) {
4506  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
4507  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4508  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4509  StridedSliceBroadcast, StridedSliceSplat,
4510  ContiguousExtractStridedSliceToExtract>(context);
4511 }
4512 
4513 //===----------------------------------------------------------------------===//
4514 // TransferReadOp
4515 //===----------------------------------------------------------------------===//
4516 
4517 /// 1. Builder that defaults the padding to ub.poison and sets an empty mask (variant with attrs).
4518 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4519  VectorType vectorType, Value source,
4520  ValueRange indices, std::optional<Value> padding,
4521  AffineMapAttr permutationMapAttr,
4522  /*optional*/ ArrayAttr inBoundsAttr) {
4523 
4524  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4525  if (!padding)
4526  padding = ub::PoisonOp::create(builder, result.location, elemType);
4527  build(builder, result, vectorType, source, indices, permutationMapAttr,
4528  *padding, /*mask=*/Value(), inBoundsAttr);
4529 }
4530 
4531 /// 2. Builder that defaults the padding to ub.poison and sets an empty mask (variant without attrs).
4532 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4533  VectorType vectorType, Value source,
4534  ValueRange indices, std::optional<Value> padding,
4535  AffineMap permutationMap,
4536  std::optional<ArrayRef<bool>> inBounds) {
4537  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4538  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4539  ? builder.getBoolArrayAttr(inBounds.value())
4540  : builder.getBoolArrayAttr(
4541  SmallVector<bool>(vectorType.getRank(), false));
4542  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4543  if (!padding)
4544  padding = ub::PoisonOp::create(builder, result.location, elemType);
4545  build(builder, result, vectorType, source, indices, *padding,
4546  permutationMapAttr, inBoundsAttr);
4547 }
4548 
4549 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4550 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4551  VectorType vectorType, Value source,
4552  ValueRange indices, std::optional<Value> padding,
4553  std::optional<ArrayRef<bool>> inBounds) {
4554  AffineMap permutationMap = getTransferMinorIdentityMap(
4555  llvm::cast<ShapedType>(source.getType()), vectorType);
4556  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4557  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4558  ? builder.getBoolArrayAttr(inBounds.value())
4559  : builder.getBoolArrayAttr(
4560  SmallVector<bool>(vectorType.getRank(), false));
4561  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4562  if (!padding)
4563  padding = ub::PoisonOp::create(builder, result.location, elemType);
4564  build(builder, result, vectorType, source, indices, permutationMapAttr,
4565  *padding,
4566  /*mask=*/Value(), inBoundsAttr);
4567 }
4568 
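// Illustrative examples of what the verifier below accepts and rejects
// (sketch, not from the upstream tests):
//
//   affine_map<(d0, d1) -> (d1, 0)>    // OK: projected permutation with zero.
//   affine_map<(d0, d1) -> (d0, d0)>   // rejected: d0 appears more than once.
//   affine_map<(d0, d1) -> (d0 + d1)>  // rejected: result is neither a single
//                                      // dim nor the zero constant.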
4569 template <typename EmitFun>
4570 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4571  EmitFun emitOpError) {
4572  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4573  for (auto expr : permutationMap.getResults()) {
4574  auto dim = dyn_cast<AffineDimExpr>(expr);
4575  auto zero = dyn_cast<AffineConstantExpr>(expr);
4576  if (zero) {
4577  if (zero.getValue() != 0) {
4578  return emitOpError(
4579  "requires a projected permutation_map (at most one dim or the zero "
4580  "constant can appear in each result)");
4581  }
4582  continue;
4583  }
4584  if (!dim) {
4585  return emitOpError("requires a projected permutation_map (at most one "
4586  "dim or the zero constant can appear in each result)");
4587  }
4588  if (seen[dim.getPosition()]) {
4589  return emitOpError(
4590  "requires a permutation_map that is a permutation (found one dim "
4591  "used more than once)");
4592  }
4593  seen[dim.getPosition()] = true;
4594  }
4595  return success();
4596 }
4597 
4598 static LogicalResult
4599 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4600  VectorType vectorType, VectorType maskType,
4601  VectorType inferredMaskType, AffineMap permutationMap,
4602  ArrayAttr inBounds) {
4603  if (op->hasAttr("masked")) {
4604  return op->emitOpError("masked attribute has been removed. "
4605  "Use in_bounds instead.");
4606  }
4607 
4608  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4609  return op->emitOpError(
4610  "requires source to be a memref or ranked tensor type");
4611 
4612  auto elementType = shapedType.getElementType();
4613  DataLayout dataLayout = DataLayout::closest(op);
4614  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4615  // Memref or tensor has vector element type.
4616  unsigned sourceVecSize =
4617  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4618  vectorElementType.getShape().back();
4619  unsigned resultVecSize =
4620  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4621  vectorType.getShape().back();
4622  if (resultVecSize % sourceVecSize != 0)
4623  return op->emitOpError(
4624  "requires the bitwidth of the minor 1-D vector to be an integral "
4625  "multiple of the bitwidth of the minor 1-D vector of the source");
4626 
4627  unsigned sourceVecEltRank = vectorElementType.getRank();
4628  unsigned resultVecRank = vectorType.getRank();
4629  if (sourceVecEltRank > resultVecRank)
4630  return op->emitOpError(
4631  "requires source vector element and vector result ranks to match.");
4632  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4633  // Check that permutation map results match 'rankOffset' of vector type.
4634  if (permutationMap.getNumResults() != rankOffset)
4635  return op->emitOpError("requires a permutation_map with result dims of "
4636  "the same rank as the vector type");
4637 
4638  if (maskType)
4639  return op->emitOpError("does not support masks with vector element type");
4640  } else {
4641  // Memref or tensor has scalar element type.
4642  unsigned minorSize =
4643  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4644  unsigned resultVecSize =
4645  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4646  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4647  return op->emitOpError(
4648  "requires the bitwidth of the minor 1-D vector to be an integral "
4649  "multiple of the bitwidth of the source element type");
4650 
4651  // Check that permutation map results match rank of vector type.
4652  if (permutationMap.getNumResults() != vectorType.getRank())
4653  return op->emitOpError("requires a permutation_map with result dims of "
4654  "the same rank as the vector type");
4655  }
4656 
4657  if (permutationMap.getNumSymbols() != 0)
4658  return op->emitOpError("requires permutation_map without symbols");
4659 
4660  if (permutationMap.getNumInputs() != shapedType.getRank())
4661  return op->emitOpError("requires a permutation_map with input dims of the "
4662  "same rank as the source type");
4663 
4664  if (maskType && maskType != inferredMaskType)
4665  return op->emitOpError("inferred mask type (")
4666  << inferredMaskType << ") and mask operand type (" << maskType
4667  << ") don't match";
4668 
4669  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4670  return op->emitOpError("expects the in_bounds attr of same rank "
4671  "as permutation_map results: ")
4672  << AffineMapAttr::get(permutationMap)
4673  << " vs inBounds of size: " << inBounds.size();
4674 
4675  return success();
4676 }
4677 
4678 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4679  SmallVector<StringRef, 3> elidedAttrs;
4680  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4681  if (op.getPermutationMap().isMinorIdentity())
4682  elidedAttrs.push_back(op.getPermutationMapAttrName());
4683  // Elide in_bounds attribute if all dims are out-of-bounds.
4684  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4685  elidedAttrs.push_back(op.getInBoundsAttrName());
4686  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4687 }
4688 
4689 void TransferReadOp::print(OpAsmPrinter &p) {
4690  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4691  if (getMask())
4692  p << ", " << getMask();
4693  printTransferAttrs(p, *this);
4694  p << " : " << getShapedType() << ", " << getVectorType();
4695 }
4696 
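// Illustrative sketch of the mask-type inference below (not from the source
// comments): for
//
//   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>
//   vector type     = vector<4x8xi32>
//
// the unused dim d0 is compressed away, the remaining permutation is
// inverted, and composing it with the vector shape [4, 8] yields the mask
// shape [8, 4], i.e. a mask of type vector<8x4xi1> laid out in the order of
// the accessed source dims.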
4697 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4698  AffineMap permMap) {
4699  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4700  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4701  assert(invPermMap && "Inversed permutation map couldn't be computed");
4702  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4703 
4704  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4705  // 0-D mask into a single-element 1-D mask.
4706  if (maskShape.empty())
4707  maskShape.push_back(1);
4708 
4709  SmallVector<bool> scalableDims =
4710  applyPermutationMap(invPermMap, vecType.getScalableDims());
4711 
4712  return VectorType::get(maskShape, i1Type, scalableDims);
4713 }
4714 
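// The custom textual form handled below looks like, e.g. (illustrative,
// value names made up):
//
//   %r = vector.transfer_read %src[%i, %j], %pad, %mask
//       {in_bounds = [true, false]} : memref<?x?xf32>, vector<4x8xf32>
//
// where the mask operand and the permutation_map attribute are optional; the
// mask type is not spelled out but reconstructed from the vector type and the
// permutation map via inferTransferOpMaskType.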
4715 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4716  auto &builder = parser.getBuilder();
4717  SMLoc typesLoc;
4718  OpAsmParser::UnresolvedOperand sourceInfo;
4719  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4720  OpAsmParser::UnresolvedOperand paddingInfo;
4721  SmallVector<Type, 2> types;
4722  OpAsmParser::UnresolvedOperand maskInfo;
4723  // Parsing with support for paddingValue.
4724  if (parser.parseOperand(sourceInfo) ||
4725  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4726  parser.parseComma() || parser.parseOperand(paddingInfo))
4727  return failure();
4728  ParseResult hasMask = parser.parseOptionalComma();
4729  if (hasMask.succeeded()) {
4730  if (parser.parseOperand(maskInfo))
4731  return failure();
4732  }
4733  if (parser.parseOptionalAttrDict(result.attributes) ||
4734  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4735  return failure();
4736  if (types.size() != 2)
4737  return parser.emitError(typesLoc, "requires two types");
4738  auto indexType = builder.getIndexType();
4739  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4740  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4741  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4742  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4743  if (!vectorType)
4744  return parser.emitError(typesLoc, "requires vector type");
4745  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4746  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4747  AffineMap permMap;
4748  if (!permMapAttr) {
4749  if (shapedType.getRank() <
4750  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4751  return parser.emitError(typesLoc,
4752  "expected a custom permutation_map when "
4753  "rank(source) != rank(destination)");
4754  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4755  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4756  } else {
4757  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4758  }
4759  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4760  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4761  if (!inBoundsAttr) {
4762  result.addAttribute(inBoundsAttrName,
4763  builder.getBoolArrayAttr(
4764  SmallVector<bool>(permMap.getNumResults(), false)));
4765  }
4766  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4767  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4768  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4769  result.operands))
4770  return failure();
4771  if (hasMask.succeeded()) {
4772  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4773  return parser.emitError(
4774  maskInfo.location, "does not support masks with vector element type");
4775  if (vectorType.getRank() != permMap.getNumResults()) {
4776  return parser.emitError(typesLoc,
4777  "expected the same rank for the vector and the "
4778  "results of the permutation map");
4779  }
4780  // Instead of adding the mask type as an op type, compute it based on the
4781  // vector type and the permutation map (to keep the type signature small).
4782  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4783  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4784  return failure();
4785  }
4786  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4787  builder.getDenseI32ArrayAttr(
4788  {1, static_cast<int32_t>(indexInfo.size()), 1,
4789  static_cast<int32_t>(hasMask.succeeded())}));
4790  return parser.addTypeToList(vectorType, result.types);
4791 }
4792 
4793 LogicalResult TransferReadOp::verify() {
4794  // Consistency of elemental types in source and vector.
4795  ShapedType shapedType = getShapedType();
4796  VectorType vectorType = getVectorType();
4797  VectorType maskType = getMaskType();
4798  auto paddingType = getPadding().getType();
4799  auto permutationMap = getPermutationMap();
4800  VectorType inferredMaskType =
4801  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4802  : VectorType();
4803  auto sourceElementType = shapedType.getElementType();
4804 
4805  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4806  return emitOpError("requires ") << shapedType.getRank() << " indices";
4807 
4808  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4809  shapedType, vectorType, maskType,
4810  inferredMaskType, permutationMap, getInBounds())))
4811  return failure();
4812 
4813  if (auto sourceVectorElementType =
4814  llvm::dyn_cast<VectorType>(sourceElementType)) {
4815  // Source has vector element type.
4816  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4817  if (sourceVectorElementType != paddingType)
4818  return emitOpError(
4819  "requires source element type and padding type to match.");
4820 
4821  } else {
4822  // Check that 'paddingType' is valid to store in a vector type.
4823  if (!VectorType::isValidElementType(paddingType))
4824  return emitOpError("requires valid padding vector elemental type");
4825 
4826  // Check that padding type and vector element types match.
4827  if (paddingType != sourceElementType)
4828  return emitOpError(
4829  "requires formal padding and source of the same elemental type");
4830  }
4831 
4832  return verifyPermutationMap(permutationMap,
4833  [&](Twine t) { return emitOpError(t); });
4834 }
4835 
4836 // MaskableOpInterface methods.
4837 
4838 /// Returns the mask type expected by this operation. Mostly used for
4839 /// verification purposes. It requires the operation to be vectorized.
4840 Type TransferReadOp::getExpectedMaskType() {
4841  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4842 }
4843 
4844 //===----------------------------------------------------------------------===//
4845 // TransferReadOp: VectorTransferOpInterface methods.
4846 //===----------------------------------------------------------------------===//
4847 VectorType TransferReadOp::getVectorType() {
4848  return cast<VectorType>(getVector().getType());
4849 }
4850 
4851 template <typename TransferOp>
4852 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4853  // TODO: support more aggressive createOrFold on:
4854  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4855  if (op.getShapedType().isDynamicDim(indicesIdx))
4856  return false;
4857  Value index = op.getIndices()[indicesIdx];
4858  std::optional<int64_t> cstOp = getConstantIntValue(index);
4859  if (!cstOp.has_value())
4860  return false;
4861 
4862  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4863  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4864 
4865  return cstOp.value() + vectorSize <= sourceSize;
4866 }
4867 
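// Illustrative sketch of the in_bounds folding below (not from the source
// comments): for
//
//   %r = vector.transfer_read %m[%c0, %c0], %pad
//       : memref<4x8xf32>, vector<4x8xf32>
//
// both accessed dims are statically known to fit (0 + 4 <= 4 and 0 + 8 <= 8),
// so the op is updated in place with in_bounds = [true, true].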
4868 template <typename TransferOp>
4869 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4870  // TODO: support 0-d corner case.
4871  // TODO: Be less conservative.
4872  if (op.getTransferRank() == 0)
4873  return failure();
4874  AffineMap permutationMap = op.getPermutationMap();
4875  bool changed = false;
4876  SmallVector<bool, 4> newInBounds;
4877  newInBounds.reserve(op.getTransferRank());
4878  // Idxs of non-bcast dims - used when analysing bcast dims.
4879  SmallVector<unsigned> nonBcastDims;
4880 
4881  // 1. Process non-broadcast dims
4882  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4883  // 1.1. Already marked as in-bounds, nothing to see here.
4884  if (op.isDimInBounds(i)) {
4885  newInBounds.push_back(true);
4886  continue;
4887  }
4888  // 1.2. Currently out-of-bounds, check whether we can statically determine
4889  // it is inBounds.
4890  bool inBounds = false;
4891  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4892  if (dimExpr) {
4893  inBounds = isInBounds(op, /*resultIdx=*/i,
4894  /*indicesIdx=*/dimExpr.getPosition());
4895  nonBcastDims.push_back(i);
4896  }
4897 
4898  newInBounds.push_back(inBounds);
4899  // We commit the pattern if it is "more inbounds".
4900  changed |= inBounds;
4901  }
4902 
4903  // 2. Handle broadcast dims
4904  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4905  // "in bounds" as well.
4906  bool allNonBcastDimsInBounds = llvm::all_of(
4907  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4908  if (allNonBcastDimsInBounds) {
4909  for (size_t idx : permutationMap.getBroadcastDims()) {
4910  changed |= !newInBounds[idx];
4911  newInBounds[idx] = true;
4912  }
4913  }
4914 
4915  if (!changed)
4916  return failure();
4917  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4918  OpBuilder b(op.getContext());
4919  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4920  return success();
4921 }
4922 
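// Drops the mask of a transfer op when it is statically known to be all-true.
// Illustrative example (not taken from the test suite):
//
//   %m = vector.constant_mask [4] : vector<4xi1>
//   %r = vector.transfer_read %t[%c0], %pad, %m : tensor<8xf32>, vector<4xf32>
//
// folds to
//
//   %r = vector.transfer_read %t[%c0], %pad : tensor<8xf32>, vector<4xf32>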
4923 template <typename TransferOp>
4924 static LogicalResult foldTransferFullMask(TransferOp op) {
4925  auto mask = op.getMask();
4926  if (!mask)
4927  return failure();
4928 
4929  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4930  return failure();
4931 
4932  op.getMaskMutable().clear();
4933  return success();
4934 }
4935 
4936 /// ```
4937 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4938 /// : vector<1x4xf32>, tensor<4x4xf32>
4939 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4940 /// : tensor<4x4xf32>, vector<1x4xf32>
4941 /// ```
4942 /// -> Folds into
4943 /// ```
4944 /// %v0
4945 /// ```
4946 static Value foldRAW(TransferReadOp readOp) {
4947  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4948  return {};
4949  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4950  while (defWrite) {
4951  if (checkSameValueRAW(defWrite, readOp))
4952  return defWrite.getVector();
4953  if (!isDisjointTransferIndices(
4954  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4955  cast<VectorTransferOpInterface>(readOp.getOperation())))
4956  break;
4957  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4958  }
4959  return {};
4960 }
4961 
4962 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4963  if (Value vec = foldRAW(*this))
4964  return vec;
4965  /// transfer_read(memrefcast) -> transfer_read
4966  if (succeeded(foldTransferInBoundsAttribute(*this)))
4967  return getResult();
4968  if (succeeded(foldTransferFullMask(*this)))
4969  return getResult();
4970  if (succeeded(memref::foldMemRefCast(*this)))
4971  return getResult();
4972  if (succeeded(tensor::foldTensorCast(*this)))
4973  return getResult();
4974  return OpFoldResult();
4975 }
4976 
4977 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4978  return llvm::to_vector<4>(getVectorType().getShape());
4979 }
4980 
4981 void TransferReadOp::getEffects(
4982  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4983  &effects) {
4984  if (llvm::isa<MemRefType>(getShapedType()))
4985  effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
4986  SideEffects::DefaultResource::get());
4987 }
4988 
4989 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4990  if (hasPureTensorSemantics())
4991  return Speculation::Speculatable;
4992  return Speculation::NotSpeculatable;
4993 }
4994 
4995 namespace {
4996 /// Store to load forwarding for transfer operations with permutation maps.
4997 /// Even if the permutation maps are different we can still propagate the store
4998 /// into the load if the size of the dimensions read and written match. Then we
4999 /// can replace the transfer_read + transfer_write by vector.broadcast and
5000 /// vector.transpose.
5001 /// Example:
5002 /// ```
5003 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5004 /// {in_bounds = [true, true],
5005 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5006 /// vector<4x1xf32>, tensor<4x4x4xf32>
5007 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5008 /// {in_bounds = [true, true, true, true],
5009 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5010 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5011 /// ```
5012 /// To:
5013 /// ```
5014 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
5015 /// %r = vector.transpose %0, [3, 0, 2, 1] :
5016 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5017 /// ```
5018 struct TransferReadAfterWriteToBroadcast
5019  : public OpRewritePattern<TransferReadOp> {
5020  using OpRewritePattern::OpRewritePattern;
5021 
5022  LogicalResult matchAndRewrite(TransferReadOp readOp,
5023  PatternRewriter &rewriter) const override {
5024  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5025  if (!defWrite)
5026  return failure();
5027  // Bail if we need an alias analysis.
5028  if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5029  return failure();
5030  // Bail if we need a bounds analysis.
5031  if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5032  return failure();
5033  // TODO: If the written transfer chunk is a superset of the read transfer
5034  // chunk we could do an extract_strided_slice.
5035  if (readOp.getTransferChunkAccessed() !=
5036  defWrite.getTransferChunkAccessed())
5037  return failure();
5038  // TODO: Support cases where a dim is explicitly written but implicitly
5039  // read (i.e., a unit dim that is rank reduced).
5040  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
5041  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
5042  return failure();
5043  // This pattern should only catch the broadcast case, the non-broadcast case
5044  // should be done separately to keep application conditions clean and
5045  // separate.
5046  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
5047  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
5048  bool bcast = !readMap.getBroadcastDims().empty() ||
5049  !writeMap.getBroadcastDims().empty();
5050  if (!bcast)
5051  return failure();
5052  // At this point, we know we have a bcast.
5053  // Bail in the masked case (too complex atm and needed to properly account
5054  // for padding).
5055  if (readOp.getMask() || defWrite.getMask())
5056  return failure();
5057  // If indices are not the same a shift may be required, bail.
5058  if (readOp.getIndices() != defWrite.getIndices())
5059  return failure();
5060 
5061  Value vec = defWrite.getVector();
5062  // TODO: loop through the chain of transfer_write if we can prove that they
5063  // don't overlap with the transfer_read. This requires improving
5064  // `isDisjointTransferIndices` helper.
5065  AffineMap map = readMap.compose(writeMap);
5066  if (map.getNumResults() == 0)
5067  return failure();
5068  // Calculate the permutation to apply to go from the vector stored to the
5069  // vector read.
5070  SmallVector<unsigned> permutation;
5071  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
5072  return failure();
5073 
5074  Location loc = readOp.getLoc();
5075  // Calculate the broadcast shape by applying the reverse permutation to the
5076  // final shape we want.
5077  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
5078  SmallVector<int64_t> broadcastShape(destShape.size());
5079  SmallVector<bool> broadcastScalableFlags(destShape.size());
5080  for (const auto &pos : llvm::enumerate(permutation)) {
5081  broadcastShape[pos.value()] = destShape[pos.index()];
5082  broadcastScalableFlags[pos.value()] =
5083  readOp.getVectorType().getScalableDims()[pos.index()];
5084  }
5085  VectorType broadcastedType = VectorType::get(
5086  broadcastShape, defWrite.getVectorType().getElementType(),
5087  broadcastScalableFlags);
5088  vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
5089  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
5090  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
5091  transposePerm);
5092  return success();
5093  }
5094 };
5095 } // namespace
5096 
5097 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5098  MLIRContext *context) {
5099  results.add<TransferReadAfterWriteToBroadcast>(context);
5100 }
5101 
5102 //===----------------------------------------------------------------------===//
5103 // TransferWriteOp
5104 //===----------------------------------------------------------------------===//
5105 
5106 /// 1. Builder with type inference.
5107 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5108  Value vector, Value dest, ValueRange indices,
5109  AffineMapAttr permutationMapAttr,
5110  /*optional*/ Value mask,
5111  /*optional*/ ArrayAttr inBoundsAttr) {
5112  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5113  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5114  mask, inBoundsAttr);
5115 }
5116 
5117 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
5118 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5119  Value vector, Value dest, ValueRange indices,
5120  AffineMapAttr permutationMapAttr,
5121  /*optional*/ ArrayAttr inBoundsAttr) {
5122  build(builder, result, vector, dest, indices, permutationMapAttr,
5123  /*mask=*/Value(), inBoundsAttr);
5124 }
5125 
5126 /// 3. Builder with type inference that sets an empty mask (variant without
5127 /// attrs)
5128 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5129  Value vector, Value dest, ValueRange indices,
5130  AffineMap permutationMap,
5131  std::optional<ArrayRef<bool>> inBounds) {
5132  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5133  auto inBoundsAttr =
5134  (inBounds && !inBounds.value().empty())
5135  ? builder.getBoolArrayAttr(inBounds.value())
5136  : builder.getBoolArrayAttr(SmallVector<bool>(
5137  llvm::cast<VectorType>(vector.getType()).getRank(), false));
5138  build(builder, result, vector, dest, indices, permutationMapAttr,
5139  /*mask=*/Value(), inBoundsAttr);
5140 }
5141 
5142 /// 4. Builder with type inference that sets an empty mask and sets permutation
5143 /// map to 'getMinorIdentityMap'.
5144 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5145  Value vector, Value dest, ValueRange indices,
5146  std::optional<ArrayRef<bool>> inBounds) {
5147  auto vectorType = llvm::cast<VectorType>(vector.getType());
5148  AffineMap permutationMap = getTransferMinorIdentityMap(
5149  llvm::cast<ShapedType>(dest.getType()), vectorType);
5150  build(builder, result, vector, dest, indices, permutationMap, inBounds);
5151 }
5152 
5153 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5154  OperationState &result) {
5155  auto &builder = parser.getBuilder();
5156  SMLoc typesLoc;
5157  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5158  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5159  SmallVector<Type, 2> types;
5160  OpAsmParser::UnresolvedOperand maskInfo;
5161  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5162  parser.parseOperand(sourceInfo) ||
5163  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5164  return failure();
5165  ParseResult hasMask = parser.parseOptionalComma();
5166  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5167  return failure();
5168  if (parser.parseOptionalAttrDict(result.attributes) ||
5169  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5170  return failure();
5171  if (types.size() != 2)
5172  return parser.emitError(typesLoc, "requires two types");
5173  auto indexType = builder.getIndexType();
5174  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5175  if (!vectorType)
5176  return parser.emitError(typesLoc, "requires vector type");
5177  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5178  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5179  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5180  auto permMapAttrName =
5181  TransferWriteOp::getPermutationMapAttrName(result.name);
5182  auto permMapAttr = result.attributes.get(permMapAttrName);
5183  AffineMap permMap;
5184  if (!permMapAttr) {
5185  if (shapedType.getRank() <
5186  getEffectiveVectorRankForXferOp(shapedType, vectorType))
5187  return parser.emitError(typesLoc,
5188  "expected a custom permutation_map when "
5189  "rank(source) != rank(destination)");
5190  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5191  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5192  } else {
5193  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5194  }
5195  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5196  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5197  if (!inBoundsAttr) {
5198  result.addAttribute(inBoundsAttrName,
5199  builder.getBoolArrayAttr(
5200  SmallVector<bool>(permMap.getNumResults(), false)));
5201  }
5202  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5203  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5204  parser.resolveOperands(indexInfo, indexType, result.operands))
5205  return failure();
5206  if (hasMask.succeeded()) {
5207  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5208  return parser.emitError(
5209  maskInfo.location, "does not support masks with vector element type");
5210  if (vectorType.getRank() != permMap.getNumResults()) {
5211  return parser.emitError(typesLoc,
5212  "expected the same rank for the vector and the "
5213  "results of the permutation map");
5214  }
5215  auto maskType = inferTransferOpMaskType(vectorType, permMap);
5216  if (parser.resolveOperand(maskInfo, maskType, result.operands))
5217  return failure();
5218  }
5219  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5220  builder.getDenseI32ArrayAttr(
5221  {1, 1, static_cast<int32_t>(indexInfo.size()),
5222  static_cast<int32_t>(hasMask.succeeded())}));
5223  return failure(llvm::isa<RankedTensorType>(shapedType) &&
5224  parser.addTypeToList(shapedType, result.types));
5225 }
5226 
5227 void TransferWriteOp::print(OpAsmPrinter &p) {
5228  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5229  if (getMask())
5230  p << ", " << getMask();
5231  printTransferAttrs(p, *this);
5232  p << " : " << getVectorType() << ", " << getShapedType();
5233 }
5234 
5235 LogicalResult TransferWriteOp::verify() {
5236  // Consistency of elemental types in shape and vector.
5237  ShapedType shapedType = getShapedType();
5238  VectorType vectorType = getVectorType();
5239  VectorType maskType = getMaskType();
5240  auto permutationMap = getPermutationMap();
5241  VectorType inferredMaskType =
5242  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5243  : VectorType();
5244 
5245  if (llvm::size(getIndices()) != shapedType.getRank())
5246  return emitOpError("requires ") << shapedType.getRank() << " indices";
5247 
5248  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5249  // as the semantics is unclear. This can be revisited later if necessary.
5250  if (hasBroadcastDim())
5251  return emitOpError("should not have broadcast dimensions");
5252 
5253  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5254  shapedType, vectorType, maskType,
5255  inferredMaskType, permutationMap, getInBounds())))
5256  return failure();
5257 
5258  return verifyPermutationMap(permutationMap,
5259  [&](Twine t) { return emitOpError(t); });
5260 }
5261 
5262 //===----------------------------------------------------------------------===//
5263 // TransferWriteOp: MaskableOpInterface methods.
5264 //===----------------------------------------------------------------------===//
5265 
5266 /// Returns the mask type expected by this operation. Mostly used for
5267 /// verification purposes.
5268 Type TransferWriteOp::getExpectedMaskType() {
5269  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5270 }
5271 
5272 //===----------------------------------------------------------------------===//
5273 // TransferWriteOp: VectorTransferOpInterface methods.
5274 //===----------------------------------------------------------------------===//
5275 Value TransferWriteOp::getVector() { return getOperand(0); }
5276 VectorType TransferWriteOp::getVectorType() {
5277  return cast<VectorType>(getValueToStore().getType());
5278 }
5279 
5280 //===----------------------------------------------------------------------===//
5281 // TransferWriteOp: fold methods.
5282 //===----------------------------------------------------------------------===//
5283 /// Fold:
5284 /// ```
5285 /// %t1 = ...
5286 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5287 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5288 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5289 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5290 /// ```
5291 ///
5292 /// into:
5293 ///
5294 /// ```
5295 /// %t0
5296 /// ```
5297 ///
5298 /// The producer of t1 may or may not be DCE'd depending on whether it is a
5299 /// block argument or has side effects.
5300 static LogicalResult foldReadInitWrite(TransferWriteOp write,
5301  ArrayRef<Attribute>,
5302  SmallVectorImpl<OpFoldResult> &results) {
5303  // TODO: support 0-d corner case.
5304  if (write.getTransferRank() == 0)
5305  return failure();
5306  auto rankedTensorType =
5307  llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5308  // If not operating on tensors, bail.
5309  if (!rankedTensorType)
5310  return failure();
5311  // If no read, bail.
5312  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5313  if (!read)
5314  return failure();
5315  // TODO: support 0-d corner case.
5316  if (read.getTransferRank() == 0)
5317  return failure();
5318  // For now, only accept minor identity. Future: composition is minor identity.
5319  if (!read.getPermutationMap().isMinorIdentity() ||
5320  !write.getPermutationMap().isMinorIdentity())
5321  return failure();
5322  // Bail on mismatching ranks.
5323  if (read.getTransferRank() != write.getTransferRank())
5324  return failure();
5325  // Bail on potential out-of-bounds accesses.
5326  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5327  return failure();
5328  // Tensor types must be the same.
5329  if (read.getBase().getType() != rankedTensorType)
5330  return failure();
5331  // Vector types must be the same.
5332  if (read.getVectorType() != write.getVectorType())
5333  return failure();
5334  // Vector and Tensor shapes must match.
5335  if (read.getVectorType().getShape() != rankedTensorType.getShape())
5336  return failure();
5337  // If any index is nonzero.
5338  auto isNotConstantZero = [](Value v) {
5339  auto cstOp = getConstantIntValue(v);
5340  return !cstOp.has_value() || cstOp.value() != 0;
5341  };
5342  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5343  llvm::any_of(write.getIndices(), isNotConstantZero))
5344  return failure();
5345  // Success.
5346  results.push_back(read.getBase());
5347  return success();
5348 }
5349 
5350 static bool checkSameValueWAR(vector::TransferReadOp read,
5351  vector::TransferWriteOp write) {
5352  return read.getBase() == write.getBase() &&
5353  read.getIndices() == write.getIndices() &&
5354  read.getPermutationMap() == write.getPermutationMap() &&
5355  read.getVectorType() == write.getVectorType() && !read.getMask() &&
5356  !write.getMask();
5357 }
5358 /// Fold transfer_write write after read:
5359 /// ```
5360 /// %t0 = ...
5361 /// %v = vector.transfer_read %t0[%c0...] :
5362 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5363 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
5364 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5365 /// ```
5366 ///
5367 /// into:
5368 ///
5369 /// ```
5370 /// %t0
5371 /// ```
5372 static LogicalResult foldWAR(TransferWriteOp write,
5373  SmallVectorImpl<OpFoldResult> &results) {
5374  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5375  return failure();
5376  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5377  if (!read)
5378  return failure();
5379 
5380  if (!checkSameValueWAR(read, write))
5381  return failure();
5382  results.push_back(read.getBase());
5383  return success();
5384 }
5385 
5386 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5387  SmallVectorImpl<OpFoldResult> &results) {
5388  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5389  return success();
5390  if (succeeded(foldWAR(*this, results)))
5391  return success();
5392  if (succeeded(foldTransferInBoundsAttribute(*this)))
5393  return success();
5394  if (succeeded(foldTransferFullMask(*this)))
5395  return success();
5396  return memref::foldMemRefCast(*this);
5397 }
5398 
5399 //===----------------------------------------------------------------------===//
5400 // TransferWriteOp: other methods.
5401 //===----------------------------------------------------------------------===//
5402 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5403  return llvm::to_vector<4>(getVectorType().getShape());
5404 }
5405 
5406 void TransferWriteOp::getEffects(
5407  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5408  &effects) {
5409  if (llvm::isa<MemRefType>(getShapedType()))
5410  effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5411  SideEffects::DefaultResource::get());
5412 }
5413 
5414 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5415  if (hasPureTensorSemantics())
5416  return Speculation::Speculatable;
5417  return Speculation::NotSpeculatable;
5418 }
5419 
5420 namespace {
5421 /// Remove a dead transfer write from the SSA chain so that it can be
5422 /// eliminated by DCE.
5423 /// ```
5424 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5425 /// : vector<1x4xf32>, tensor<4x4xf32>
5426 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5427 /// : vector<1x4xf32>, tensor<4x4xf32>
5428 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5429 /// : vector<1x4xf32>, tensor<4x4xf32>
5430 /// ```
5431 ///
5432 /// into:
5433 ///
5434 /// ```
5435 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5436 /// : vector<1x4xf32>, tensor<4x4xf32>
5437 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5438 /// : vector<1x4xf32>, tensor<4x4xf32>
5439 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5440 /// : vector<1x4xf32>, tensor<4x4xf32>
5441 /// ```
5442 ///
5443 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5444 /// any other uses.
5445 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5446 public:
5447  using OpRewritePattern::OpRewritePattern;
5448  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5449  PatternRewriter &rewriter) const override {
5450  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5451  return failure();
5452  vector::TransferWriteOp writeToModify = writeOp;
5453 
5454  auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5455  while (defWrite) {
5456  if (checkSameValueWAW(writeOp, defWrite)) {
5457  rewriter.modifyOpInPlace(writeToModify, [&]() {
5458  writeToModify.getBaseMutable().assign(defWrite.getBase());
5459  });
5460  return success();
5461  }
5462  if (!isDisjointTransferIndices(
5463  cast<VectorTransferOpInterface>(defWrite.getOperation()),
5464  cast<VectorTransferOpInterface>(writeOp.getOperation())))
5465  break;
5466  // If the previous write op doesn't have any other use we can safely look
5467  // at the previous store to see if it can be removed.
5468  if (!defWrite->hasOneUse())
5469  break;
5470  writeToModify = defWrite;
5471  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5472  }
5473  return failure();
5474  }
5475 };
5476 
5477 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5478 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5479 /// overwritten and inserted into another tensor. After this rewrite, the
5480 /// operations bufferize in-place since all of them work on the same slice.
5481 ///
5482 /// For example:
5483 /// ```mlir
5484 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5485 /// : vector<8x16xf32>, tensor<8x16xf32>
5486 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5487 /// : tensor<8x16xf32> to tensor<?x?xf32>
5488 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5489 /// : tensor<?x?xf32> into tensor<27x37xf32>
5490 /// ```
5491 /// folds to
5492 /// ```mlir
5493 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5494 /// : tensor<27x37xf32> to tensor<?x?xf32>
5495 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5496 /// : vector<8x16xf32>, tensor<?x?xf32>
5497 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5498 /// : tensor<?x?xf32> into tensor<27x37xf32>
5499 /// ```
5500 struct SwapExtractSliceOfTransferWrite
5501  : public OpRewritePattern<tensor::InsertSliceOp> {
5502 public:
5503  using OpRewritePattern::OpRewritePattern;
5504 
5505  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5506  PatternRewriter &rewriter) const override {
5507  if (!insertOp.hasUnitStride())
5508  return failure();
5509  auto extractOp =
5510  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5511  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5512  return failure();
5513  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5514  if (!transferOp || !transferOp->hasOneUse())
5515  return failure();
5516 
5517  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5518  // rank-reducing.
5519  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5520  return rewriter.notifyMatchFailure(insertOp,
5521  "use-def chain is rank-reducing");
5522  }
5523 
5524  // Fail if tensor::ExtractSliceOp has non-zero offset.
5525  if (!extractOp.hasZeroOffset()) {
5526  return rewriter.notifyMatchFailure(insertOp,
5527  "ExtractSliceOp has non-zero offset");
5528  }
5529 
5530  // Fail if vector::TransferWriteOp has non-zero offset.
5531  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5532  return getConstantIntValue(value) == static_cast<int64_t>(0);
5533  })) {
5534  return rewriter.notifyMatchFailure(insertOp,
5535  "TransferWriteOp has non-zero offset");
5536  }
5537 
5538  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5539  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5540  return rewriter.notifyMatchFailure(
5541  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5542  }
5543 
5544  for (auto [insertSize, extractSize] :
5545  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5546  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5547  return rewriter.notifyMatchFailure(
5548  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5549  }
5550  }
5551 
5552  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5553  assert(transferOp.getVectorType().hasStaticShape() &&
5554  "expected vector to have a static shape");
5555  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5556  SmallVector<int64_t> resultShape = applyPermutationMap(
5557  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5558  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5559  return rewriter.notifyMatchFailure(
5560  insertOp, "TransferWriteOp may not write the full tensor.");
5561  }
5562 
5563  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5564  // Set all in_bounds to false and let the folder infer them.
5565  SmallVector<bool> newInBounds(vectorShape.size(), false);
5566  auto newExtractOp = tensor::ExtractSliceOp::create(
5567  rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5568  insertOp.getDest(), insertOp.getMixedOffsets(),
5569  insertOp.getMixedSizes(), insertOp.getMixedStrides());
5570  auto newTransferWriteOp = TransferWriteOp::create(
5571  rewriter, transferOp.getLoc(), transferOp.getVector(),
5572  newExtractOp.getResult(), transferOp.getIndices(),
5573  transferOp.getPermutationMapAttr(),
5574  rewriter.getBoolArrayAttr(newInBounds));
5575  rewriter.modifyOpInPlace(insertOp, [&]() {
5576  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5577  });
5578  return success();
5579  }
5580 };
5581 
5582 } // namespace
5583 
5584 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5585  MLIRContext *context) {
5586  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5587 }
5588 
5589 //===----------------------------------------------------------------------===//
5590 // LoadOp
5591 //===----------------------------------------------------------------------===//
5592 
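// For illustration (sketch, not from the upstream tests): loading a
// vector<4xf32> from memref<8x16xf32, strided<[32, 2]>> is rejected because
// the innermost stride is 2, whereas memref<8x16xf32, strided<[32, 1]>> (or
// any memref whose last dim has unit stride) is accepted.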
5593 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5594  VectorType vecTy,
5595  MemRefType memRefTy) {
5596  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5597  // need any strides limitations.
5598  if (!vecTy.isScalable() &&
5599  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5600  return success();
5601 
5602  if (!memRefTy.isLastDimUnitStride())
5603  return op->emitOpError("most minor memref dim must have unit stride");
5604  return success();
5605 }
5606 
5607 LogicalResult vector::LoadOp::verify() {
5608  VectorType resVecTy = getVectorType();
5609  MemRefType memRefTy = getMemRefType();
5610 
5611  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5612  return failure();
5613 
5614  if (memRefTy.getRank() < resVecTy.getRank())
5615  return emitOpError(
5616  "destination memref has lower rank than the result vector");
5617 
5618  // Checks for vector memrefs.
5619  Type memElemTy = memRefTy.getElementType();
5620  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5621  if (memVecTy != resVecTy)
5622  return emitOpError("base memref and result vector types should match");
5623  memElemTy = memVecTy.getElementType();
5624  }
5625 
5626  if (resVecTy.getElementType() != memElemTy)
5627  return emitOpError("base and result element types should match");
5628  if (llvm::size(getIndices()) != memRefTy.getRank())
5629  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5630  return success();
5631 }
5632 
5633 OpFoldResult LoadOp::fold(FoldAdaptor) {
5634  if (succeeded(memref::foldMemRefCast(*this)))
5635  return getResult();
5636  return OpFoldResult();
5637 }
5638 
5639 std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5640  return llvm::to_vector<4>(getVectorType().getShape());
5641 }
5642 
5643 //===----------------------------------------------------------------------===//
5644 // StoreOp
5645 //===----------------------------------------------------------------------===//
5646 
5647 LogicalResult vector::StoreOp::verify() {
5648  VectorType valueVecTy = getVectorType();
5649  MemRefType memRefTy = getMemRefType();
5650 
5651  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5652  return failure();
5653 
5654  if (memRefTy.getRank() < valueVecTy.getRank())
5655  return emitOpError("source memref has lower rank than the vector to store");
5656 
5657  // Checks for vector memrefs.
5658  Type memElemTy = memRefTy.getElementType();
5659  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5660  if (memVecTy != valueVecTy)
5661  return emitOpError(
5662  "base memref and valueToStore vector types should match");
5663  memElemTy = memVecTy.getElementType();
5664  }
5665 
5666  if (valueVecTy.getElementType() != memElemTy)
5667  return emitOpError("base and valueToStore element type should match");
5668  if (llvm::size(getIndices()) != memRefTy.getRank())
5669  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5670  return success();
5671 }
5672 
5673 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5674  SmallVectorImpl<OpFoldResult> &results) {
5675  return memref::foldMemRefCast(*this);
5676 }
5677 
5678 std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5679  return llvm::to_vector<4>(getVectorType().getShape());
5680 }
5681 
5682 //===----------------------------------------------------------------------===//
5683 // MaskedLoadOp
5684 //===----------------------------------------------------------------------===//
5685 
5686 LogicalResult MaskedLoadOp::verify() {
5687  VectorType maskVType = getMaskVectorType();
5688  VectorType passVType = getPassThruVectorType();
5689  VectorType resVType = getVectorType();
5690  MemRefType memType = getMemRefType();
5691 
5692  if (resVType.getElementType() != memType.getElementType())
5693  return emitOpError("base and result element type should match");
5694  if (llvm::size(getIndices()) != memType.getRank())
5695  return emitOpError("requires ") << memType.getRank() << " indices";
5696  if (resVType.getShape() != maskVType.getShape())
5697  return emitOpError("expected result shape to match mask shape");
5698  if (resVType != passVType)
5699  return emitOpError("expected pass_thru of same type as result type");
5700  return success();
5701 }
5702 
5703 namespace {
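/// Folds vector.maskedload ops whose mask is statically known. A minimal
/// sketch of the rewrite (the names and types below are illustrative, not
/// taken from a test):
///
///   %0 = vector.maskedload %base[%c0], %mask, %pass_thru
///       : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
///
/// With an all-true %mask this becomes
///   %0 = vector.load %base[%c0] : memref<16xf32>, vector<8xf32>
/// and with an all-false %mask the op is replaced by %pass_thru.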
5704 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5705 public:
5707  LogicalResult matchAndRewrite(MaskedLoadOp load,
5708  PatternRewriter &rewriter) const override {
5709  switch (getMaskFormat(load.getMask())) {
5710  case MaskFormat::AllTrue:
5711  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5712  load, load.getType(), load.getBase(), load.getIndices());
5713  return success();
5714  case MaskFormat::AllFalse:
5715  rewriter.replaceOp(load, load.getPassThru());
5716  return success();
5717  case MaskFormat::Unknown:
5718  return failure();
5719  }
5720  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5721  }
5722 };
5723 } // namespace
5724 
5725 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5726  MLIRContext *context) {
5727  results.add<MaskedLoadFolder>(context);
5728 }
5729 
5730 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5731  if (succeeded(memref::foldMemRefCast(*this)))
5732  return getResult();
5733  return OpFoldResult();
5734 }
5735 
5736 //===----------------------------------------------------------------------===//
5737 // MaskedStoreOp
5738 //===----------------------------------------------------------------------===//
5739 
5740 LogicalResult MaskedStoreOp::verify() {
5741  VectorType maskVType = getMaskVectorType();
5742  VectorType valueVType = getVectorType();
5743  MemRefType memType = getMemRefType();
5744 
5745  if (valueVType.getElementType() != memType.getElementType())
5746  return emitOpError("base and valueToStore element type should match");
5747  if (llvm::size(getIndices()) != memType.getRank())
5748  return emitOpError("requires ") << memType.getRank() << " indices";
5749  if (valueVType.getShape() != maskVType.getShape())
5750  return emitOpError("expected valueToStore shape to match mask shape");
5751  return success();
5752 }
5753 
5754 namespace {
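/// Folds vector.maskedstore ops whose mask is statically known. A minimal
/// sketch of the rewrite (illustrative IR, not from a test):
///
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<16xf32>, vector<8xi1>, vector<8xf32>
///
/// With an all-true %mask this becomes
///   vector.store %value, %base[%c0] : memref<16xf32>, vector<8xf32>
/// and with an all-false %mask the store is erased.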
5755 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5756 public:
5758  LogicalResult matchAndRewrite(MaskedStoreOp store,
5759  PatternRewriter &rewriter) const override {
5760  switch (getMaskFormat(store.getMask())) {
5761  case MaskFormat::AllTrue:
5762  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5763  store, store.getValueToStore(), store.getBase(), store.getIndices());
5764  return success();
5765  case MaskFormat::AllFalse:
5766  rewriter.eraseOp(store);
5767  return success();
5768  case MaskFormat::Unknown:
5769  return failure();
5770  }
5771  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5772  }
5773 };
5774 } // namespace
5775 
5776 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5777  MLIRContext *context) {
5778  results.add<MaskedStoreFolder>(context);
5779 }
5780 
5781 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5782  SmallVectorImpl<OpFoldResult> &results) {
5783  return memref::foldMemRefCast(*this);
5784 }
5785 
5786 //===----------------------------------------------------------------------===//
5787 // GatherOp
5788 //===----------------------------------------------------------------------===//
5789 
5790 LogicalResult GatherOp::verify() {
5791  VectorType indVType = getIndexVectorType();
5792  VectorType maskVType = getMaskVectorType();
5793  VectorType resVType = getVectorType();
5794  ShapedType baseType = getBaseType();
5795 
5796  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5797  return emitOpError("requires base to be a memref or ranked tensor type");
5798 
5799  if (resVType.getElementType() != baseType.getElementType())
5800  return emitOpError("base and result element type should match");
5801  if (llvm::size(getOffsets()) != baseType.getRank())
5802  return emitOpError("requires ") << baseType.getRank() << " indices";
5803  if (resVType.getShape() != indVType.getShape())
5804  return emitOpError("expected result dim to match indices dim");
5805  if (resVType.getShape() != maskVType.getShape())
5806  return emitOpError("expected result dim to match mask dim");
5807  if (resVType != getPassThruVectorType())
5808  return emitOpError("expected pass_thru of same type as result type");
5809  return success();
5810 }
5811 
5812 // MaskableOpInterface methods.
5813 
5814 /// Returns the mask type expected by this operation. Mostly used for
5815 /// verification purposes. It requires the operation to be vectorized.
5816 Type GatherOp::getExpectedMaskType() {
5817  auto vecType = this->getIndexVectorType();
5818  return VectorType::get(vecType.getShape(),
5819  IntegerType::get(vecType.getContext(), /*width=*/1),
5820  vecType.getScalableDims());
5821 }
5822 
5823 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5824  return llvm::to_vector<4>(getVectorType().getShape());
5825 }
5826 
5827 /// Check if `indexVec` is a constant 1D vector of consecutive values [0, 1, 2, ...].
5828 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5829  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5830  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5831  return failure();
5832 
5833  if (indexVec.getDefiningOp<StepOp>())
5834  return success();
5835 
5836  DenseIntElementsAttr elements;
5837  if (!matchPattern(indexVec, m_Constant(&elements)))
5838  return failure();
5839 
5840  return success(
5841  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5842 }
5843 
5844 namespace {
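/// Folds vector.gather ops with an all-false mask to their pass_thru operand;
/// an all-true mask has no unmasked gather equivalent and is left as is.
/// Illustrative IR (not from a test):
///
///   %0 = vector.gather %base[%c0][%indices], %all_false, %pass_thru
///       : memref<16xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
///         into vector<8xf32>
///   // ---> folds to %pass_thru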
5845 class GatherFolder final : public OpRewritePattern<GatherOp> {
5846 public:
5848  LogicalResult matchAndRewrite(GatherOp gather,
5849  PatternRewriter &rewriter) const override {
5850  switch (getMaskFormat(gather.getMask())) {
5851  case MaskFormat::AllTrue:
5852  return failure(); // no unmasked equivalent
5853  case MaskFormat::AllFalse:
5854  rewriter.replaceOp(gather, gather.getPassThru());
5855  return success();
5856  case MaskFormat::Unknown:
5857  return failure();
5858  }
5859  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5860  }
5861 };
5862 
5863 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5864 /// maskedload. Only 1D fixed vectors are supported for now.
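///
/// A sketch of the rewrite, assuming illustrative IR (not from a test):
///
///   %idx = vector.step : vector<8xindex>
///   %0 = vector.gather %base[%c0][%idx], %mask, %pass_thru
///       : memref<16xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
///         into vector<8xf32>
///
/// becomes
///
///   %0 = vector.maskedload %base[%c0], %mask, %pass_thru
///       : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>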
5865 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5866 public:
5868  LogicalResult matchAndRewrite(GatherOp op,
5869  PatternRewriter &rewriter) const override {
5870  if (!isa<MemRefType>(op.getBase().getType()))
5871  return rewriter.notifyMatchFailure(op, "base must be of memref type");
5872 
5873  if (failed(isZeroBasedContiguousSeq(op.getIndices())))
5874  return failure();
5875 
5876  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5877  op.getOffsets(), op.getMask(),
5878  op.getPassThru());
5879  return success();
5880  }
5881 };
5882 } // namespace
5883 
5884 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5885  MLIRContext *context) {
5886  results.add<GatherFolder, FoldContiguousGather>(context);
5887 }
5888 
5889 //===----------------------------------------------------------------------===//
5890 // ScatterOp
5891 //===----------------------------------------------------------------------===//
5892 
5893 LogicalResult ScatterOp::verify() {
5894  VectorType indVType = getIndexVectorType();
5895  VectorType maskVType = getMaskVectorType();
5896  VectorType valueVType = getVectorType();
5897  MemRefType memType = getMemRefType();
5898 
5899  if (valueVType.getElementType() != memType.getElementType())
5900  return emitOpError("base and valueToStore element type should match");
5901  if (llvm::size(getOffsets()) != memType.getRank())
5902  return emitOpError("requires ") << memType.getRank() << " indices";
5903  if (valueVType.getShape() != indVType.getShape())
5904  return emitOpError("expected valueToStore dim to match indices dim");
5905  if (valueVType.getShape() != maskVType.getShape())
5906  return emitOpError("expected valueToStore dim to match mask dim");
5907  return success();
5908 }
5909 
5910 namespace {
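/// Erases vector.scatter ops whose mask is statically all-false (nothing is
/// written, so the op has no effect); an all-true mask has no unmasked
/// equivalent and is left as is. Illustrative IR (not from a test):
///
///   vector.scatter %base[%c0][%idx], %all_false, %value
///       : memref<16xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
///   // ---> erased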
5911 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5912 public:
5914  LogicalResult matchAndRewrite(ScatterOp scatter,
5915  PatternRewriter &rewriter) const override {
5916  switch (getMaskFormat(scatter.getMask())) {
5917  case MaskFormat::AllTrue:
5918  return failure(); // no unmasked equivalent
5919  case MaskFormat::AllFalse:
5920  rewriter.eraseOp(scatter);
5921  return success();
5922  case MaskFormat::Unknown:
5923  return failure();
5924  }
5925  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5926  }
5927 };
5928 
5929 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5930 /// maskedstore. Only 1D fixed vectors are supported for now.
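///
/// A sketch of the rewrite, assuming illustrative IR (not from a test):
///
///   vector.scatter %base[%c0][%step], %mask, %value
///       : memref<16xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
///
/// becomes, when %step is the sequence [0, 1, 2, ...],
///
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<16xf32>, vector<8xi1>, vector<8xf32>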
5931 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5932 public:
5934  LogicalResult matchAndRewrite(ScatterOp op,
5935  PatternRewriter &rewriter) const override {
5936  if (failed(isZeroBasedContiguousSeq(op.getIndices())))
5937  return failure();
5938 
5939  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5940  op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
5941  return success();
5942  }
5943 };
5944 } // namespace
5945 
5946 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5947  MLIRContext *context) {
5948  results.add<ScatterFolder, FoldContiguousScatter>(context);
5949 }
5950 
5951 //===----------------------------------------------------------------------===//
5952 // ExpandLoadOp
5953 //===----------------------------------------------------------------------===//
5954 
5955 LogicalResult ExpandLoadOp::verify() {
5956  VectorType maskVType = getMaskVectorType();
5957  VectorType passVType = getPassThruVectorType();
5958  VectorType resVType = getVectorType();
5959  MemRefType memType = getMemRefType();
5960 
5961  if (resVType.getElementType() != memType.getElementType())
5962  return emitOpError("base and result element type should match");
5963  if (llvm::size(getIndices()) != memType.getRank())
5964  return emitOpError("requires ") << memType.getRank() << " indices";
5965  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5966  return emitOpError("expected result dim to match mask dim");
5967  if (resVType != passVType)
5968  return emitOpError("expected pass_thru of same type as result type");
5969  return success();
5970 }
5971 
5972 namespace {
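/// Folds vector.expandload ops with a statically known mask: an all-true mask
/// degenerates to a contiguous vector.load and an all-false mask folds to the
/// pass_thru operand. Illustrative IR (not from a test):
///
///   %0 = vector.expandload %base[%c0], %mask, %pass_thru
///       : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>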
5973 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5974 public:
5976  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5977  PatternRewriter &rewriter) const override {
5978  switch (getMaskFormat(expand.getMask())) {
5979  case MaskFormat::AllTrue:
5980  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5981  expand, expand.getType(), expand.getBase(), expand.getIndices());
5982  return success();
5983  case MaskFormat::AllFalse:
5984  rewriter.replaceOp(expand, expand.getPassThru());
5985  return success();
5986  case MaskFormat::Unknown:
5987  return failure();
5988  }
5989  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5990  }
5991 };
5992 } // namespace
5993 
5994 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5995  MLIRContext *context) {
5996  results.add<ExpandLoadFolder>(context);
5997 }
5998 
5999 //===----------------------------------------------------------------------===//
6000 // CompressStoreOp
6001 //===----------------------------------------------------------------------===//
6002 
6003 LogicalResult CompressStoreOp::verify() {
6004  VectorType maskVType = getMaskVectorType();
6005  VectorType valueVType = getVectorType();
6006  MemRefType memType = getMemRefType();
6007 
6008  if (valueVType.getElementType() != memType.getElementType())
6009  return emitOpError("base and valueToStore element type should match");
6010  if (llvm::size(getIndices()) != memType.getRank())
6011  return emitOpError("requires ") << memType.getRank() << " indices";
6012  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6013  return emitOpError("expected valueToStore dim to match mask dim");
6014  return success();
6015 }
6016 
6017 namespace {
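/// Folds vector.compressstore ops with a statically known mask: an all-true
/// mask degenerates to a contiguous vector.store and an all-false mask lets
/// the op be erased. Illustrative IR (not from a test):
///
///   vector.compressstore %base[%c0], %mask, %value
///       : memref<16xf32>, vector<8xi1>, vector<8xf32>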
6018 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6019 public:
6021  LogicalResult matchAndRewrite(CompressStoreOp compress,
6022  PatternRewriter &rewriter) const override {
6023  switch (getMaskFormat(compress.getMask())) {
6024  case MaskFormat::AllTrue:
6025  rewriter.replaceOpWithNewOp<vector::StoreOp>(
6026  compress, compress.getValueToStore(), compress.getBase(),
6027  compress.getIndices());
6028  return success();
6029  case MaskFormat::AllFalse:
6030  rewriter.eraseOp(compress);
6031  return success();
6032  case MaskFormat::Unknown:
6033  return failure();
6034  }
6035  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6036  }
6037 };
6038 } // namespace
6039 
6040 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6041  MLIRContext *context) {
6042  results.add<CompressStoreFolder>(context);
6043 }
6044 
6045 //===----------------------------------------------------------------------===//
6046 // ShapeCastOp
6047 //===----------------------------------------------------------------------===//
6048 
6049 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6050  SetIntRangeFn setResultRanges) {
6051  setResultRanges(getResult(), argRanges.front());
6052 }
6053 
6054 LogicalResult ShapeCastOp::verify() {
6055 
6056  VectorType sourceType = getSourceVectorType();
6057  VectorType resultType = getResultVectorType();
6058 
6059  // Check that element type is preserved
6060  if (sourceType.getElementType() != resultType.getElementType())
6061  return emitOpError("has different source and result element types");
6062 
6063  // Check that number of elements is preserved
6064  int64_t sourceNElms = sourceType.getNumElements();
6065  int64_t resultNElms = resultType.getNumElements();
6066  if (sourceNElms != resultNElms) {
6067  return emitOpError() << "has different number of elements at source ("
6068  << sourceNElms << ") and result (" << resultNElms
6069  << ")";
6070  }
6071 
6072  // Check that (non-)scalability is preserved
6073  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6074  int64_t resultNScalableDims = resultType.getNumScalableDims();
6075  if (sourceNScalableDims != resultNScalableDims)
6076  return emitOpError() << "has different number of scalable dims at source ("
6077  << sourceNScalableDims << ") and result ("
6078  << resultNScalableDims << ")";
6079 
6080  return success();
6081 }
6082 
6083 /// Return true if `transpose` does not permute a pair of non-unit dims.
6084 /// By `order preserving` we mean that the flattened versions of the input and
6085 /// output vectors are (numerically) identical. In other words `transpose` is
6086 /// effectively a shape cast.
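///
/// For example (illustrative shapes): permutation [0, 2, 1] on
/// vector<4x1x8xf32> is order preserving, since dim 1 is a unit dim, whereas
/// [1, 0] on vector<2x3xf32> is not, since it swaps two non-unit dims.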
6087 static bool isOrderPreserving(TransposeOp transpose) {
6088  ArrayRef<int64_t> permutation = transpose.getPermutation();
6089  VectorType sourceType = transpose.getSourceVectorType();
6090  ArrayRef<int64_t> inShape = sourceType.getShape();
6091  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6092  auto isNonScalableUnitDim = [&](int64_t dim) {
6093  return inShape[dim] == 1 && !inDimIsScalable[dim];
6094  };
6095  int64_t current = 0;
6096  for (auto p : permutation) {
6097  if (!isNonScalableUnitDim(p)) {
6098  if (p < current) {
6099  return false;
6100  }
6101  current = p;
6102  }
6103  }
6104  return true;
6105 }
6106 
6107 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6108 
6109  VectorType resultType = getType();
6110 
6111  // No-op shape cast.
6112  if (getSource().getType() == resultType)
6113  return getSource();
6114 
6115  // shape_cast(shape_cast(x)) -> shape_cast(x)
6116  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6117  setOperand(precedingShapeCast.getSource());
6118  return getResult();
6119  }
6120 
6121  // shape_cast(transpose(x)) -> shape_cast(x)
6122  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6123  if (isOrderPreserving(transpose)) {
6124  setOperand(transpose.getVector());
6125  return getResult();
6126  }
6127  return {};
6128  }
6129 
6130  // Y = shape_cast(broadcast(X))
6131  // -> X, if X and Y have same type
6132  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6133  if (bcastOp.getSourceType() == resultType)
6134  return bcastOp.getSource();
6135  }
6136 
6137  // shape_cast(constant) -> constant
6138  if (auto denseAttr =
6139  dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6140  return denseAttr.reshape(getType());
6141 
6142  // shape_cast(poison) -> poison
6143  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6144  return ub::PoisonAttr::get(getContext());
6145 
6146  return {};
6147 }
6148 
6149 namespace {
6150 
6151 /// Helper function that computes a new vector type based on the input vector
6152 /// type by removing the trailing one dims:
6153 ///
6154 /// vector<4x1x1xi1> --> vector<4x1xi1>
6155 ///
6156 static VectorType trimTrailingOneDims(VectorType oldType) {
6157  ArrayRef<int64_t> oldShape = oldType.getShape();
6158  ArrayRef<int64_t> newShape = oldShape;
6159 
6160  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6161  ArrayRef<bool> newScalableDims = oldScalableDims;
6162 
6163  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6164  newShape = newShape.drop_back(1);
6165  newScalableDims = newScalableDims.drop_back(1);
6166  }
6167 
6168  // Make sure we have at least 1 dimension.
6169  // TODO: Add support for 0-D vectors.
6170  if (newShape.empty()) {
6171  newShape = oldShape.take_back();
6172  newScalableDims = oldScalableDims.take_back();
6173  }
6174 
6175  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6176 }
6177 
6178 /// Folds qualifying shape_cast(create_mask) into a new create_mask
6179 ///
6180 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6181 /// dimension. If the input vector comes from `vector.create_mask` for which
6182 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6183 /// to fold shape_cast into create_mask.
6184 ///
6185 /// BEFORE:
6186 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6187 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6188 /// AFTER:
6189 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6190 class ShapeCastCreateMaskFolderTrailingOneDim final
6191  : public OpRewritePattern<ShapeCastOp> {
6192 public:
6194 
6195  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6196  PatternRewriter &rewriter) const override {
6197  Value shapeOpSrc = shapeOp->getOperand(0);
6198  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6199  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6200  if (!createMaskOp && !constantMaskOp)
6201  return failure();
6202 
6203  VectorType shapeOpResTy = shapeOp.getResultVectorType();
6204  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6205 
6206  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6207  if (newVecType != shapeOpResTy)
6208  return failure();
6209 
6210  auto numDimsToDrop =
6211  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6212 
6213  // No unit dims to drop
6214  if (!numDimsToDrop)
6215  return failure();
6216 
6217  if (createMaskOp) {
6218  auto maskOperands = createMaskOp.getOperands();
6219  auto numMaskOperands = maskOperands.size();
6220 
6221  // Check every mask dim size to see whether it can be dropped
6222  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6223  --i) {
6224  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6225  if (!constant || (constant.value() != 1))
6226  return failure();
6227  }
6228  SmallVector<Value> newMaskOperands =
6229  maskOperands.drop_back(numDimsToDrop);
6230 
6231  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6232  newMaskOperands);
6233  return success();
6234  }
6235 
6236  if (constantMaskOp) {
6237  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6238  auto numMaskOperands = maskDimSizes.size();
6239 
6240  // Check every mask dim size to see whether it can be dropped
6241  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6242  --i) {
6243  if (maskDimSizes[i] != 1)
6244  return failure();
6245  }
6246 
6247  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6248  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6249  newMaskOperands);
6250  return success();
6251  }
6252 
6253  return failure();
6254  }
6255 };
6256 
6257 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6258 /// i) Y = ShapeCast(X), or
6259 /// ii) Y = Broadcast(X)
6260 /// If both (i) and (ii) are possible, (i) is chosen.
6261 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6262 public:
6264 
6265  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6266  PatternRewriter &rewriter) const override {
6267  auto broadcastOp =
6268  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6269  if (!broadcastOp)
6270  return failure();
6271 
6272  auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6273  bool srcIsScalar = !srcVectorType;
6274 
6275  // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6276  // Example:
6277  // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6278  // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6279  // to
6280  // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6281  if (srcVectorType) {
6282  if (srcVectorType.getNumElements() ==
6283  shapeCastOp.getResultVectorType().getNumElements()) {
6284  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6285  shapeCastOp, shapeCastOp.getResultVectorType(),
6286  broadcastOp.getSource());
6287  return success();
6288  }
6289  }
6290 
6291  // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6292  // Example
6293  // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6294  // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6295  // to
6296  // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6297  VectorType dstVectorType = shapeCastOp.getResultVectorType();
6298  if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6299  BroadcastableToResult::Success) {
6300  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6301  shapeCastOp, dstVectorType, broadcastOp.getSource());
6302  return success();
6303  }
6304  return failure();
6305  }
6306 };
6307 
6308 } // namespace
6309 
6310 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6311  MLIRContext *context) {
6312  results
6313  .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6314  context);
6315 }
6316 
6317 //===----------------------------------------------------------------------===//
6318 // VectorBitCastOp
6319 //===----------------------------------------------------------------------===//
6320 
6321 LogicalResult BitCastOp::verify() {
6322  auto sourceVectorType = getSourceVectorType();
6323  auto resultVectorType = getResultVectorType();
6324 
6325  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6326  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6327  return emitOpError("dimension size mismatch at: ") << i;
6328  }
6329 
6330  DataLayout dataLayout = DataLayout::closest(*this);
6331  auto sourceElementBits =
6332  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6333  auto resultElementBits =
6334  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6335 
6336  if (sourceVectorType.getRank() == 0) {
6337  if (sourceElementBits != resultElementBits)
6338  return emitOpError("source/result bitwidth of the 0-D vector element "
6339  "types must be equal");
6340  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6341  resultElementBits * resultVectorType.getShape().back()) {
6342  return emitOpError(
6343  "source/result bitwidth of the minor 1-D vectors must be equal");
6344  }
6345 
6346  return success();
6347 }
6348 
6349 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6350  // Nop cast.
6351  if (getSource().getType() == getResult().getType())
6352  return getSource();
6353 
6354  // Canceling bitcasts.
6355  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6356  if (getResult().getType() == otherOp.getSource().getType())
6357  return otherOp.getSource();
6358 
6359  setOperand(otherOp.getSource());
6360  return getResult();
6361  }
6362 
6363  Attribute sourceConstant = adaptor.getSource();
6364  if (!sourceConstant)
6365  return {};
6366 
6367  Type srcElemType = getSourceVectorType().getElementType();
6368  Type dstElemType = getResultVectorType().getElementType();
6369 
6370  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6371  if (floatPack.isSplat()) {
6372  auto splat = floatPack.getSplatValue<FloatAttr>();
6373 
6374  // Casting fp16 into fp32.
6375  if (srcElemType.isF16() && dstElemType.isF32()) {
6376  uint32_t bits = static_cast<uint32_t>(
6377  splat.getValue().bitcastToAPInt().getZExtValue());
6378  // Duplicate the 16-bit pattern.
6379  bits = (bits << 16) | (bits & 0xffff);
6380  APInt intBits(32, bits);
6381  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6382  return DenseElementsAttr::get(getResultVectorType(), floatBits);
6383  }
6384  }
6385  }
6386 
6387  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6388  if (intPack.isSplat()) {
6389  auto splat = intPack.getSplatValue<IntegerAttr>();
6390 
6391  if (llvm::isa<IntegerType>(dstElemType)) {
6392  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6393  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6394 
6395  // Casting to a larger integer bit width.
6396  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6397  APInt intBits = splat.getValue().zext(dstBitWidth);
6398 
6399  // Duplicate the lower width element.
6400  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6401  intBits = (intBits << srcBitWidth) | intBits;
6402  return DenseElementsAttr::get(getResultVectorType(), intBits);
6403  }
6404  }
6405  }
6406  }
6407 
6408  return {};
6409 }
6410 
6411 //===----------------------------------------------------------------------===//
6412 // TypeCastOp
6413 //===----------------------------------------------------------------------===//
6414 
6415 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6416  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6417  SmallVector<int64_t, 8> res(memRefType.getShape());
6418  if (vectorType)
6419  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6420  return res;
6421 }
6422 
6423 /// Build the canonical memRefType with a single vector.
6424 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6425 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6426  Value source) {
6427  result.addOperands(source);
6428  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6429  VectorType vectorType =
6430  VectorType::get(extractShape(memRefType),
6431  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6432  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6433  memRefType.getMemorySpace()));
6434 }
6435 
6436 LogicalResult TypeCastOp::verify() {
6437  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6438  if (!canonicalType.getLayout().isIdentity())
6439  return emitOpError("expects operand to be a memref with identity layout");
6440  if (!getResultMemRefType().getLayout().isIdentity())
6441  return emitOpError("expects result to be a memref with identity layout");
6442  if (getResultMemRefType().getMemorySpace() !=
6443  getMemRefType().getMemorySpace())
6444  return emitOpError("expects result in same memory space");
6445 
6446  auto sourceType = getMemRefType();
6447  auto resultType = getResultMemRefType();
6448  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6449  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6450  return emitOpError(
6451  "expects result and operand with same underlying scalar type: ")
6452  << resultType;
6453  if (extractShape(sourceType) != extractShape(resultType))
6454  return emitOpError(
6455  "expects concatenated result and operand shapes to be equal: ")
6456  << resultType;
6457  return success();
6458 }
6459 
6460 //===----------------------------------------------------------------------===//
6461 // TransposeOp
6462 //===----------------------------------------------------------------------===//
6463 
6464 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6465  Value vector, ArrayRef<int64_t> permutation) {
6466  VectorType vt = llvm::cast<VectorType>(vector.getType());
6467  SmallVector<int64_t, 4> transposedShape(vt.getRank());
6468  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6469  for (unsigned i = 0; i < permutation.size(); ++i) {
6470  transposedShape[i] = vt.getShape()[permutation[i]];
6471  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6472  }
6473 
6474  result.addOperands(vector);
6475  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6476  transposedScalableDims));
6477  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6478  builder.getDenseI64ArrayAttr(permutation));
6479 }
6480 
6481 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6482  // Eliminate splat constant transpose ops.
6483  if (auto splat =
6484  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6485  return splat.reshape(getResultVectorType());
6486 
6487  // Eliminate poison transpose ops.
6488  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6489  return ub::PoisonAttr::get(getContext());
6490 
6491  // Eliminate identity transposes and, more generally, any transpose that
6492  // preserves the shape without permuting elements.
6493  //
6494  // Examples of what to fold:
6495  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6496  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6497  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6498  //
6499  // Example of what NOT to fold:
6500  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6501  //
6502  if (getSourceVectorType() == getResultVectorType() &&
6503  isOrderPreserving(*this))
6504  return getVector();
6505 
6506  return {};
6507 }
6508 
6509 LogicalResult vector::TransposeOp::verify() {
6510  VectorType vectorType = getSourceVectorType();
6511  VectorType resultType = getResultVectorType();
6512  int64_t rank = resultType.getRank();
6513  if (vectorType.getRank() != rank)
6514  return emitOpError("vector result rank mismatch: ") << rank;
6515  // Verify transposition array.
6516  ArrayRef<int64_t> perm = getPermutation();
6517  int64_t size = perm.size();
6518  if (rank != size)
6519  return emitOpError("transposition length mismatch: ") << size;
6520  SmallVector<bool, 8> seen(rank, false);
6521  for (const auto &ta : llvm::enumerate(perm)) {
6522  if (ta.value() < 0 || ta.value() >= rank)
6523  return emitOpError("transposition index out of range: ") << ta.value();
6524  if (seen[ta.value()])
6525  return emitOpError("duplicate position index: ") << ta.value();
6526  seen[ta.value()] = true;
6527  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6528  return emitOpError("dimension size mismatch at: ") << ta.value();
6529  }
6530  return success();
6531 }
6532 
6533 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6534  return llvm::to_vector<4>(getResultVectorType().getShape());
6535 }
6536 
6537 void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6538  SetIntRangeFn setResultRanges) {
6539  setResultRanges(getResult(), argRanges.front());
6540 }
6541 
6542 namespace {
6543 
6544 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
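//
// A sketch of the rewrite (illustrative IR, not from a test):
//
//   %0 = vector.transpose %arg, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
//   %1 = vector.transpose %0, [2, 1, 0] : vector<3x2x4xf32> to vector<4x2x3xf32>
//
// becomes
//
//   %1 = vector.transpose %arg, [2, 0, 1] : vector<2x3x4xf32> to vector<4x2x3xf32>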
6545 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6546 public:
6548 
6549  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6550  PatternRewriter &rewriter) const override {
6551  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6552  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6553  ArrayRef<int64_t> permutation2) {
6554  SmallVector<int64_t, 4> result;
6555  for (auto index : permutation2)
6556  result.push_back(permutation1[index]);
6557  return result;
6558  };
6559 
6560  // Return if the input of 'transposeOp' is not defined by another transpose.
6561  vector::TransposeOp parentTransposeOp =
6562  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6563  if (!parentTransposeOp)
6564  return failure();
6565 
6566  SmallVector<int64_t, 4> permutation = composePermutations(
6567  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6568  // Replace 'transposeOp' with a new transpose operation.
6569  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6570  transposeOp, transposeOp.getResult().getType(),
6571  parentTransposeOp.getVector(), permutation);
6572  return success();
6573  }
6574 };
6575 
6576 /// Replace transpose(splat-like(v)) with broadcast(v)
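///
/// A sketch of the rewrite (illustrative IR, not from a test):
///
///   %0 = vector.broadcast %s : f32 to vector<4x8xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///
/// becomes
///
///   %1 = vector.broadcast %s : f32 to vector<8x4xf32>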
6577 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6578 public:
6580 
6581  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6582  PatternRewriter &rewriter) const override {
6583  Value splat = getScalarSplatSource(transposeOp.getVector());
6584  if (!splat)
6585  return failure();
6586 
6587  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6588  transposeOp, transposeOp.getResultVectorType(), splat);
6589  return success();
6590  }
6591 };
6592 
6593 /// Folds transpose(create_mask) into a new transposed create_mask.
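///
/// A sketch of the rewrite (illustrative IR, not from a test):
///
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
///
/// becomes
///
///   %t = vector.create_mask %b, %a : vector<8x4xi1>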
6594 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6595 public:
6597 
6598  LogicalResult matchAndRewrite(TransposeOp transpOp,
6599  PatternRewriter &rewriter) const override {
6600  Value transposeSrc = transpOp.getVector();
6601  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6602  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6603  if (!createMaskOp && !constantMaskOp)
6604  return failure();
6605 
6606  // Get the transpose permutation and apply it to the vector.create_mask or
6607  // vector.constant_mask operands.
6608  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6609 
6610  if (createMaskOp) {
6611  auto maskOperands = createMaskOp.getOperands();
6612  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6613  applyPermutationToVector(newOperands, permutation);
6614 
6615  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6616  transpOp, transpOp.getResultVectorType(), newOperands);
6617  return success();
6618  }
6619 
6620  // ConstantMaskOp case.
6621  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6622  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6623 
6624  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6625  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6626  return success();
6627  }
6628 };
6629 
6630 /// Folds transpose(shape_cast) into a new shape_cast.
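///
/// Only applies when the transpose is order preserving (it merely moves unit
/// dims around). A sketch of the rewrite (illustrative IR, not from a test):
///
///   %0 = vector.shape_cast %arg : vector<6xf32> to vector<1x6xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<1x6xf32> to vector<6x1xf32>
///
/// becomes
///
///   %1 = vector.shape_cast %arg : vector<6xf32> to vector<6x1xf32>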
6631 class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6632 public:
6634 
6635  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6636  PatternRewriter &rewriter) const override {
6637  auto shapeCastOp =
6638  transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6639  if (!shapeCastOp)
6640  return failure();
6641  if (!isOrderPreserving(transposeOp))
6642  return failure();
6643 
6644  VectorType resultType = transposeOp.getType();
6645 
6646  // We don't need to check isValidShapeCast at this point, because it is
6647  // guaranteed that merging the transpose into the shape_cast results in a
6648  // valid shape_cast, because the transpose just inserts/removes unit dims.
6649 
6650  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6651  shapeCastOp.getSource());
6652  return success();
6653  }
6654 };
6655 
6656 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6657 /// 'order preserving', where 'order preserving' means the flattened
6658 /// inputs and outputs of the transpose have identical (numerical) values.
6659 ///
6660 /// Example:
6661 /// ```
6662 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6663 /// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6664 /// to vector<8x1xi32>
6665 /// ```
6666 /// can be rewritten as the equivalent
6667 /// ```
6668 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6669 /// ```
6670 /// The algorithm works by partitioning dimensions into groups that can be
6671 /// locally permuted while preserving order, and checks that the transpose
6672 /// only permutes within these groups.
6673 ///
6674 /// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6675 /// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6676 /// broadcasting from 1x1x4x1x1x7.
6677 ///                   ^^^ ^ ^^^ ^
6678 ///          groups:    0  1  2  3
6679 /// Order preserving permutations for this example are ones that only permute
6680 /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6681 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6682 public:
6684  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6685  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6686 
6687  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6688  PatternRewriter &rewriter) const override {
6689 
6690  vector::BroadcastOp broadcast =
6691  transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6692  if (!broadcast) {
6693  return rewriter.notifyMatchFailure(transpose,
6694  "not preceded by a broadcast");
6695  }
6696 
6697  auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6698  VectorType outputType = transpose.getResultVectorType();
6699 
6700  // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6701  bool inputIsScalar = !inputType;
6702  if (inputIsScalar) {
6703  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6704  broadcast.getSource());
6705  return success();
6706  }
6707 
6708  ArrayRef<int64_t> permutation = transpose.getPermutation();
6709  ArrayRef<int64_t> inputShape = inputType.getShape();
6710  int64_t inputRank = inputType.getRank();
6711  int64_t outputRank = transpose.getType().getRank();
6712  int64_t deltaRank = outputRank - inputRank;
6713 
6714  int low = 0;
6715  for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6716  bool notOne = inputShape[inputIndex] != 1;
6717  bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6718  bool groupEndFound = notOne || prevNotOne;
6719  if (groupEndFound) {
6720  int high = inputIndex + deltaRank;
6721  // Return failure if not all permutation destinations for indices in
6722  // [low, high) are in [low, high), i.e. the permutation is not local to
6723  // the group.
6724  for (int i = low; i < high; ++i) {
6725  if (permutation[i] < low || permutation[i] >= high) {
6726  return rewriter.notifyMatchFailure(
6727  transpose, "permutation not local to group");
6728  }
6729  }
6730  low = high;
6731  }
6732  }
6733 
6734  // We don't need to check the final group [low, outputRank) because if it is
6735  // not locally bound, there must be a preceding group that already failed
6736  // the check (impossible to have just 1 non-locally bound group).
6737 
6738  // The preceding logic also ensures that at this point, the output of the
6739  // transpose is definitely broadcastable from the input shape, assert so:
6740  assert(vector::isBroadcastableTo(inputType, outputType) ==
6741  vector::BroadcastableToResult::Success &&
6742  "not broadcastable directly to transpose output");
6743 
6744  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6745  broadcast.getSource());
6746 
6747  return success();
6748  }
6749 };
6750 
6751 } // namespace
6752 
6753 void vector::TransposeOp::getCanonicalizationPatterns(
6754  RewritePatternSet &results, MLIRContext *context) {
6755  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6756  FoldTransposeSplat, FoldTransposeBroadcast>(context);
6757 }
6758 
6759 //===----------------------------------------------------------------------===//
6760 // ConstantMaskOp
6761 //===----------------------------------------------------------------------===//
6762 
6763 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6764  VectorType type, ConstantMaskKind kind) {
6765  assert(kind == ConstantMaskKind::AllTrue ||
6766  kind == ConstantMaskKind::AllFalse);
6767  build(builder, result, type,
6768  kind == ConstantMaskKind::AllTrue
6769  ? type.getShape()
6770  : SmallVector<int64_t>(type.getRank(), 0));
6771 }
6772 
6773 LogicalResult ConstantMaskOp::verify() {
6774  auto resultType = llvm::cast<VectorType>(getResult().getType());
6775  // Check the corner case of 0-D vectors first.
6776  if (resultType.getRank() == 0) {
6777  if (getMaskDimSizes().size() != 1)
6778  return emitError("array attr must have length 1 for 0-D vectors");
6779  auto dim = getMaskDimSizes()[0];
6780  if (dim != 0 && dim != 1)
6781  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6782  return success();
6783  }
6784 
6785  // Verify that array attr size matches the rank of the vector result.
6786  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6787  return emitOpError(
6788  "must specify array attr of size equal vector result rank");
6789  // Verify that each array attr element is in bounds of corresponding vector
6790  // result dimension size.
6791  auto resultShape = resultType.getShape();
6792  auto resultScalableDims = resultType.getScalableDims();
6793  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6794  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6795  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6796  return emitOpError(
6797  "array attr of size out of bounds of vector result dimension size");
6798  if (resultScalableDims[index] && maskDimSize != 0 &&
6799  maskDimSize != resultShape[index])
6800  return emitOpError(
6801  "only supports 'none set' or 'all set' scalable dimensions");
6802  }
6803  // Verify that if one mask dim size is zero, they all should be zero (because
6804  // the mask region is a conjunction of each mask dimension interval).
6805  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6806  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6807  if (anyZeros && !allZeros)
6808  return emitOpError("expected all mask dim sizes to be zeros, "
6809  "as a result of conjunction with zero mask dim");
6810  return success();
6811 }
6812 
6813 bool ConstantMaskOp::isAllOnesMask() {
6814  auto resultType = getVectorType();
6815  // Check the corner case of 0-D vectors first.
6816  if (resultType.getRank() == 0) {
6817  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6818  return getMaskDimSizes()[0] == 1;
6819  }
6820  for (const auto [resultSize, maskDimSize] :
6821  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6822  if (maskDimSize < resultSize)
6823  return false;
6824  }
6825  return true;
6826 }
6827 
6828 OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
6829  ArrayRef<int64_t> bounds = getMaskDimSizes();
6830  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
6831 
6832  auto createBoolSplat = [&](bool x) {
6833  return SplatElementsAttr::get(getVectorType(),
6834  BoolAttr::get(getContext(), x));
6835  };
6836 
6837  // Check the corner case of 0-D vectors first.
6838  if (vectorSizes.empty()) {
6839  assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
6840  return createBoolSplat(bounds[0] == 1);
6841  }
6842  // Fold vector.constant_mask to splat if possible.
6843  if (bounds == vectorSizes)
6844  return createBoolSplat(true);
6845  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
6846  return createBoolSplat(false);
6847  return OpFoldResult();
6848 }
6849 
6850 //===----------------------------------------------------------------------===//
6851 // CreateMaskOp
6852 //===----------------------------------------------------------------------===//
6853 
6854 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6855  VectorType type,
6856  ArrayRef<OpFoldResult> mixedOperands) {
6857  SmallVector<Value> operands =
6858  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6859  build(builder, result, type, operands);
6860 }
6861 
6862 LogicalResult CreateMaskOp::verify() {
6863  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6864  // Verify that an operand was specified for each result vector dimension.
6865  if (vectorType.getRank() == 0) {
6866  if (getNumOperands() != 1)
6867  return emitOpError(
6868  "must specify exactly one operand for 0-D create_mask");
6869  } else if (getNumOperands() !=
6870  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6871  return emitOpError(
6872  "must specify an operand for each result vector dimension");
6873  }
6874  return success();
6875 }
6876 
6877 namespace {
6878 
6879 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6880 ///
6881 /// Ex 1:
6882 /// %c2 = arith.constant 2 : index
6883 /// %c3 = arith.constant 3 : index
6884 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6885 /// Becomes:
6886 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6887 ///
6888 /// Ex 2:
6889 /// %c_neg_1 = arith.constant -1 : index
6890 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6891 /// becomes:
6892 /// vector.constant_mask [0] : vector<[8]xi1>
6893 ///
6894 /// Ex 3:
6895 /// %c8 = arith.constant 8 : index
6896 /// %c16 = arith.constant 16 : index
6897 /// %0 = vector.vscale
6898 /// %1 = arith.muli %0, %c16 : index
6899 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6900 /// becomes:
6901 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6902 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6903 public:
6905 
6906  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6907  PatternRewriter &rewriter) const override {
6908  VectorType maskType = createMaskOp.getVectorType();
6909  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6910  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6911 
6912  // Special case: Rank zero shape.
6913  constexpr std::array<int64_t, 1> rankZeroShape{1};
6914  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6915  if (maskType.getRank() == 0) {
6916  maskTypeDimSizes = rankZeroShape;
6917  maskTypeDimScalableFlags = rankZeroScalableDims;
6918  }
6919 
6920  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6921  // collect the `constantDims` (for the ConstantMaskOp).
6922  SmallVector<int64_t, 4> constantDims;
6923  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6924  if (auto intSize = getConstantIntValue(dimSize)) {
6925  // Constant value.
6926  // If the mask dim is non-scalable this can be any value.
6927  // If the mask dim is scalable only zero (all-false) is supported.
6928  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6929  return failure();
6930  constantDims.push_back(*intSize);
6931  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6932  // Constant vscale multiple (e.g. 4 x vscale).
6933  // Must be all-true to fold to a ConstantMask.
6934  if (vscaleMultiplier < maskTypeDimSizes[i])
6935  return failure();
6936  constantDims.push_back(*vscaleMultiplier);
6937  } else {
6938  return failure();
6939  }
6940  }
6941 
6942  // Clamp values to constant_mask bounds.
6943  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6944  value = std::clamp<int64_t>(value, 0, maskDimSize);
6945 
6946  // If one of dim sizes is zero, set all dims to zero.
6947  if (llvm::is_contained(constantDims, 0))
6948  constantDims.assign(constantDims.size(), 0);
6949 
6950  // Replace 'createMaskOp' with ConstantMaskOp.
6951  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6952  constantDims);
6953  return success();
6954  }
6955 };
6956 
6957 } // namespace
6958 
6959 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6960  MLIRContext *context) {
6961  results.add<CreateMaskFolder>(context);
6962 }
6963 
6964 //===----------------------------------------------------------------------===//
6965 // MaskOp
6966 //===----------------------------------------------------------------------===//
6967 
6968 void MaskOp::build(
6969  OpBuilder &builder, OperationState &result, Value mask,
6970  Operation *maskableOp,
6971  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6972  assert(maskRegionBuilder &&
6973  "builder callback for 'maskRegion' must be present");
6974 
6975  result.addOperands(mask);
6976  OpBuilder::InsertionGuard guard(builder);
6977  Region *maskRegion = result.addRegion();
6978  builder.createBlock(maskRegion);
6979  maskRegionBuilder(builder, maskableOp);
6980 }
6981 
6982 void MaskOp::build(
6983  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6984  Value mask, Operation *maskableOp,
6985  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6986  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6987  maskRegionBuilder);
6988 }
6989 
6990 void MaskOp::build(
6991  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6992  Value mask, Value passthru, Operation *maskableOp,
6993  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6994  build(builder, result, mask, maskableOp, maskRegionBuilder);
6995  if (passthru)
6996  result.addOperands(passthru);
6997  result.addTypes(resultTypes);
6998 }
6999 
7000 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7001  // Create the op region.
7002  result.regions.reserve(1);
7003  Region &maskRegion = *result.addRegion();
7004 
7005  auto &builder = parser.getBuilder();
7006 
7007  // Parse all the operands.
7008  OpAsmParser::UnresolvedOperand mask;
7009  if (parser.parseOperand(mask))
7010  return failure();
7011 
7012  // Optional passthru operand.
7013  OpAsmParser::UnresolvedOperand passthru;
7014  ParseResult parsePassthru = parser.parseOptionalComma();
7015  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7016  return failure();
7017 
7018  // Parse op region.
7019  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7020  return failure();
7021 
7022  MaskOp::ensureTerminator(maskRegion, builder, result.location);
7023 
7024  // Parse the optional attribute list.
7025  if (parser.parseOptionalAttrDict(result.attributes))
7026  return failure();
7027 
7028  // Parse all the types.
7029  Type maskType;
7030  if (parser.parseColonType(maskType))
7031  return failure();
7032 
7033  SmallVector<Type> resultTypes;
7034  if (parser.parseOptionalArrowTypeList(resultTypes))
7035  return failure();
7036  result.types.append(resultTypes);
7037 
7038  // Resolve operands.
7039  if (parser.resolveOperand(mask, maskType, result.operands))
7040  return failure();
7041 
7042  if (parsePassthru.succeeded()) {
7043  if (resultTypes.empty())
7044  return parser.emitError(
7045  parser.getNameLoc(),
7046  "expects a result if passthru operand is provided");
7047 
7048  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7049  return failure();
7050  }
7051 
7052  return success();
7053 }
7054 
7055 void MaskOp::print(OpAsmPrinter &p) {
7056  p << " " << getMask();
7057  if (getPassthru())
7058  p << ", " << getPassthru();
7059 
7060  // Print single masked operation and skip terminator.
7061  p << " { ";
7062  Block *singleBlock = &getMaskRegion().getBlocks().front();
7063  if (singleBlock && !singleBlock->getOperations().empty())
7064  p.printCustomOrGenericOp(&singleBlock->front());
7065  p << " }";
7066 
7067  p.printOptionalAttrDict(getOperation()->getAttrs());
7068 
7069  p << " : " << getMask().getType();
7070  if (getNumResults() > 0)
7071  p << " -> " << getResultTypes();
7072 }
7073 
7074 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7075  // 1. For an empty `vector.mask`, create a default terminator.
7076  if (region.empty() || region.front().empty()) {
7077  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7078  MaskOp>::ensureTerminator(region, builder, loc);
7079  return;
7080  }
7081 
7082  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7083  Block &block = region.front();
7084  if (isa<vector::YieldOp>(block.back()))
7085  return;
7086 
7087  // 3. For a non-empty `vector.mask` without an explicit terminator:
7088 
7089  // Create default terminator if the number of masked operations is not
7090  // one. This case will trigger a verification failure.
7091  if (block.getOperations().size() != 1) {
7092  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7093  MaskOp>::ensureTerminator(region, builder, loc);
7094  return;
7095  }
7096 
7097  // Create a terminator that yields the results from the masked operation.
7098  OpBuilder opBuilder(builder.getContext());
7099  Operation *maskedOp = &block.front();
7100  opBuilder.setInsertionPointToEnd(&block);
7101  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7102 }
7103 
7104 LogicalResult MaskOp::verify() {
7105  // Structural checks.
7106  Block &block = getMaskRegion().getBlocks().front();
7107  if (block.getOperations().empty())
7108  return emitOpError("expects a terminator within the mask region");
7109 
7110  unsigned numMaskRegionOps = block.getOperations().size();
7111  if (numMaskRegionOps > 2)
7112  return emitOpError("expects only one operation to mask");
7113 
7114  // Terminator checks.
7115  auto terminator = dyn_cast<vector::YieldOp>(block.back());
7116  if (!terminator)
7117  return emitOpError("expects a terminator within the mask region");
7118 
7119  if (terminator->getNumOperands() != getNumResults())
7120  return emitOpError(
7121  "expects number of results to match mask region yielded values");
7122 
7123  // Empty vector.mask. Nothing else to check.
7124  if (numMaskRegionOps == 1)
7125  return success();
7126 
7127  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7128  if (!maskableOp)
7129  return emitOpError("expects a MaskableOpInterface within the mask region");
7130 
7131  // Result checks.
7132  if (maskableOp->getNumResults() != getNumResults())
7133  return emitOpError("expects number of results to match maskable operation "
7134  "number of results");
7135 
7136  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7137  return emitOpError("expects all the results from the MaskableOpInterface "
7138  "to match all the values returned by the terminator");
7139 
7140  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7141  return emitOpError(
7142  "expects result type to match maskable operation result type");
7143 
7144  if (llvm::count_if(maskableOp->getResultTypes(),
7145  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7146  return emitOpError("multiple vector results not supported");
7147 
7148  // Mask checks.
7149  Type expectedMaskType = maskableOp.getExpectedMaskType();
7150  if (getMask().getType() != expectedMaskType)
7151  return emitOpError("expects a ")
7152  << expectedMaskType << " mask for the maskable operation";
7153 
7154  // Passthru checks.
7155  Value passthru = getPassthru();
7156  if (passthru) {
7157  if (!maskableOp.supportsPassthru())
7158  return emitOpError(
7159  "doesn't expect a passthru argument for this maskable operation");
7160 
7161  if (maskableOp->getNumResults() != 1)
7162  return emitOpError("expects result when passthru argument is provided");
7163 
7164  if (passthru.getType() != maskableOp->getResultTypes()[0])
7165  return emitOpError("expects passthru type to match result type");
7166  }
7167 
7168  return success();
7169 }
7170 
7171 /// Folds empty `vector.mask` with no passthru operand and with or without
7172 /// return values. For example:
7173 ///
7174 /// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7175 /// vector<8xi1> -> vector<8xf32>
7176 /// %1 = user_op %0 : vector<8xf32>
7177 ///
7178 /// becomes:
7179 ///
7180 /// %0 = user_op %a : vector<8xf32>
7181 ///
7182 /// Empty `vector.mask` ops with a passthru operand are handled by the
7183 /// canonicalizer, as that requires creating new operations.
7184 
7185 static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7186  SmallVectorImpl<OpFoldResult> &results) {
7187  if (!maskOp.isEmpty() || maskOp.hasPassthru())
7188  return failure();
7189 
7190  Block *block = maskOp.getMaskBlock();
7191  auto terminator = cast<vector::YieldOp>(block->front());
7192  if (terminator.getNumOperands() == 0) {
7193  // `vector.mask` has no results, just remove the `vector.mask`.
7194  return success();
7195  }
7196 
7197  // `vector.mask` has results, propagate the results.
7198  llvm::append_range(results, terminator.getOperands());
7199  return success();
7200 }
7201 
7202 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7203  SmallVectorImpl<OpFoldResult> &results) {
7204  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7205  return success();
7206 
7207  MaskFormat maskFormat = getMaskFormat(getMask());
7208  if (maskFormat != MaskFormat::AllTrue)
7209  return failure();
7210 
7211  // Move maskable operation outside of the `vector.mask` region.
7212  Operation *maskableOp = getMaskableOp();
7213  maskableOp->dropAllUses();
7214  maskableOp->moveBefore(getOperation());
7215 
7216  llvm::append_range(results, maskableOp->getResults());
7217  return success();
7218 }
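// For illustration, when the mask is statically known to be all-true the
// wrapper is dropped and the masked op is used directly:
//
//   %m = vector.constant_mask [16] : vector<16xi1>
//   %0 = vector.mask %m { vector.multi_reduction <add>, %v, %acc [0]
//          : vector<16xf32> to f32 } : vector<16xi1> -> f32
//
// folds to:
//
//   %0 = vector.multi_reduction <add>, %v, %acc [0] : vector<16xf32> to f32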
7219 
7220 /// Canonicalize empty `vector.mask` operations that can't be handled in
7221 /// `MaskOp::fold` as they require creating new operations.
7222 ///
7223 /// Example 1: Empty `vector.mask` with passthru operand.
7224 ///
7225 /// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7226 /// vector<8xi1> -> vector<8xf32>
7227 ///
7228 /// becomes:
7229 ///
7230 /// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7231 ///
7232 class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7233  using OpRewritePattern::OpRewritePattern;
7234 
7235  LogicalResult matchAndRewrite(MaskOp maskOp,
7236  PatternRewriter &rewriter) const override {
7237  if (!maskOp.isEmpty())
7238  return failure();
7239 
7240  if (!maskOp.hasPassthru())
7241  return failure();
7242 
7243  Block *block = maskOp.getMaskBlock();
7244  auto terminator = cast<vector::YieldOp>(block->front());
7245  assert(terminator.getNumOperands() == 1 &&
7246  "expected one result when passthru is provided");
7247 
7248  rewriter.replaceOpWithNewOp<arith::SelectOp>(
7249  maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7250  terminator.getOperand(0), maskOp.getPassthru());
7251 
7252  return success();
7253  }
7254 };
7255 
7256 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7257  MLIRContext *context) {
7258  results.add<CanonializeEmptyMaskOp>(context);
7259 }
7260 
7261 // MaskingOpInterface definitions.
7262 
7263 /// Returns the operation masked by this 'vector.mask'.
7264 Operation *MaskOp::getMaskableOp() {
7265  Block *block = getMaskBlock();
7266  if (block->getOperations().size() < 2)
7267  return nullptr;
7268 
7269  return &block->front();
7270 }
7271 
7272 /// Returns true if 'vector.mask' has a passthru value.
7273 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7274 
7275 //===----------------------------------------------------------------------===//
7276 // ScanOp
7277 //===----------------------------------------------------------------------===//
7278 
7279 LogicalResult ScanOp::verify() {
7280  VectorType srcType = getSourceType();
7281  VectorType initialType = getInitialValueType();
7282  // Check reduction dimension < rank.
7283  int64_t srcRank = srcType.getRank();
7284  int64_t reductionDim = getReductionDim();
7285  if (reductionDim >= srcRank)
7286  return emitOpError("reduction dimension ")
7287  << reductionDim << " has to be less than " << srcRank;
7288 
7289  // Check that rank(initial_value) = rank(src) - 1.
7290  int64_t initialValueRank = initialType.getRank();
7291  if (initialValueRank != srcRank - 1)
7292  return emitOpError("initial value rank ")
7293  << initialValueRank << " has to be equal to " << srcRank - 1;
7294 
7295  // Check shapes of initial value and src.
7296  ArrayRef<int64_t> srcShape = srcType.getShape();
7297  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7298  SmallVector<int64_t> expectedShape;
7299  for (int i = 0; i < srcRank; i++) {
7300  if (i != reductionDim)
7301  expectedShape.push_back(srcShape[i]);
7302  }
7303  if (!llvm::equal(initialValueShapes, expectedShape)) {
7304  return emitOpError("incompatible input/initial value shapes");
7305  }
7306 
7307  // Verify supported reduction kind.
7308  Type eltType = getDestType().getElementType();
7309  if (!isSupportedCombiningKind(getKind(), eltType))
7310  return emitOpError("unsupported reduction type ")
7311  << eltType << " for kind '" << stringifyCombiningKind(getKind())
7312  << "'";
7313 
7314  return success();
7315 }
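// For illustration, a scan over dimension 1 of a 4-D source must pass an
// initial value whose shape is the source shape with dimension 1 removed
// (the exact attribute syntax may vary slightly between MLIR versions):
//
//   %dst, %accum = vector.scan <add>, %src, %init
//       {inclusive = true, reduction_dim = 1 : i64}
//       : vector<4x8x16x32xf32>, vector<4x16x32xf32>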
7316 
7317 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7318     RewritePatternSet &patterns, PatternBenefit benefit) {
7319  patterns
7320  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7321  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7322  StridedSliceConstantMaskFolder, TransposeFolder>(
7323  patterns.getContext(), benefit);
7324 }
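// A minimal usage sketch, assuming a pass that has an MLIRContext `ctx` and a
// root operation `op`; the greedy driver entry point and the required
// GreedyPatternRewriteDriver.h include may differ across MLIR revisions:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorToVectorCanonicalizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();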
7325 
7326 //===----------------------------------------------------------------------===//
7327 // SplatOp
7328 //===----------------------------------------------------------------------===//
7329 
7330 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
7331  auto constOperand = adaptor.getInput();
7332  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
7333  return {};
7334 
7335  // SplatElementsAttr::get treats a single value for the second arg as a splat.
7336  return SplatElementsAttr::get(getType(), {constOperand});
7337 }
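// For illustration, a splat of a foldable scalar constant
//
//   %c = arith.constant 1.0 : f32
//   %v = vector.splat %c : vector<4xf32>
//
// folds to the splat attribute dense<1.0> of type vector<4xf32>, which the
// folding framework can then materialize as a single arith.constant.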
7338 
7339 // Canonicalizer for vector.splat. It always gets canonicalized to a
7340 // vector.broadcast.
7341 class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
7342 public:
7343  using OpRewritePattern::OpRewritePattern;
7344  LogicalResult matchAndRewrite(SplatOp splatOp,
7345  PatternRewriter &rewriter) const override {
7346  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
7347  splatOp.getOperand());
7348  return success();
7349  }
7350 };
7351 void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
7352  MLIRContext *context) {
7353  results.add<SplatToBroadcastPattern>(context);
7354 }
7355 
7356 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7357  SetIntRangeFn setResultRanges) {
7358  setResultRanges(getResult(), argRanges.front());
7359 }
7360 
7361 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7362  CombiningKind kind, Value v1, Value acc,
7363  arith::FastMathFlagsAttr fastmath,
7364  Value mask) {
7365  Type t1 = getElementTypeOrSelf(v1.getType());
7366  Type tAcc = getElementTypeOrSelf(acc.getType());
7367  Value result;
7368 
7369  switch (kind) {
7370  case CombiningKind::ADD:
7371  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7372  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7373  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7374  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7375  else
7376  llvm_unreachable("invalid value types for ADD reduction");
7377  break;
7378  case CombiningKind::AND:
7379  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7380  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7381  break;
7382  case CombiningKind::MAXNUMF:
7383  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7384  "expected float values");
7385  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7386  break;
7387  case CombiningKind::MAXIMUMF:
7388  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7389  "expected float values");
7390  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7391  break;
7392  case CombiningKind::MINNUMF:
7393  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7394  "expected float values");
7395  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7396  break;
7397  case CombiningKind::MINIMUMF:
7398  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7399  "expected float values");
7400  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7401  break;
7402  case CombiningKind::MAXSI:
7403  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7404  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7405  break;
7406  case CombiningKind::MINSI:
7407  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7408  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7409  break;
7410  case CombiningKind::MAXUI:
7411  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7412  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7413  break;
7414  case CombiningKind::MINUI:
7415  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7416  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7417  break;
7418  case CombiningKind::MUL:
7419  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7420  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7421  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7422  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7423  else
7424  llvm_unreachable("invalid value types for MUL reduction");
7425  break;
7426  case CombiningKind::OR:
7427  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7428  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7429  break;
7430  case CombiningKind::XOR:
7431  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7432  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7433  break;
7434  };
7435 
7436  assert(result && "unknown CombiningKind");
7437  return selectPassthru(b, mask, result, acc);
7438 }
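// A minimal usage sketch (variable names are illustrative): combining a newly
// computed partial value with an accumulator while lowering a masked
// reduction. With a non-null `mask`, masked-off lanes keep the accumulator
// value via the trailing select emitted by selectPassthru.
//
//   Value combined = makeArithReduction(rewriter, loc, CombiningKind::ADD,
//                                       partial, acc,
//                                       /*fastmath=*/nullptr, /*mask=*/mask);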
7439 
7440 //===----------------------------------------------------------------------===//
7441 // StepOp
7442 //===----------------------------------------------------------------------===//
7443 
7444 void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7445  SetIntRangeFn setResultRanges) {
7446  auto resultType = cast<VectorType>(getType());
7447  if (resultType.isScalable()) {
7448  return;
7449  }
7450  unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7451  APInt zero(bitwidth, 0);
7452  APInt high(bitwidth, resultType.getDimSize(0) - 1);
7453  ConstantIntRanges result = {zero, high, zero, high};
7454  setResultRanges(getResult(), result);
7455 }
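// For example, `vector.step : vector<4xindex>` produces the values 0, 1, 2, 3,
// so the inferred range for the fixed-length result is [0, 3] for both the
// signed and unsigned interpretations; no range is inferred for scalable
// results.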
7456 
7457 //===----------------------------------------------------------------------===//
7458 // Vector Masking Utilities
7459 //===----------------------------------------------------------------------===//
7460 
7461 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7462 /// as masked operation.
7463 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7464  Operation *maskableOp) {
7465  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7466  Block *insBlock = builder.getInsertionBlock();
7467  // Create a block and move the op to that block.
7468  insBlock->getOperations().splice(
7469  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7470  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7471 }
7472 
7473 /// Creates a vector.mask operation around a maskable operation. Returns the
7474 /// vector.mask operation if the mask provided is valid. Otherwise, returns
7475 /// the maskable operation itself.
7476 Operation *mlir::vector::maskOperation(OpBuilder &builder,
7477  Operation *maskableOp, Value mask,
7478  Value passthru) {
7479  if (!mask)
7480  return maskableOp;
7481  if (passthru)
7482  return MaskOp::create(builder, maskableOp->getLoc(),
7483  maskableOp->getResultTypes(), mask, passthru,
7484  maskableOp, createMaskOpRegion);
7485  return MaskOp::create(builder, maskableOp->getLoc(),
7486  maskableOp->getResultTypes(), mask, maskableOp,
7487  createMaskOpRegion);
7488 }
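// A minimal usage sketch (names are illustrative): wrapping a newly created op
// in a `vector.mask` only when a mask value is actually present. When
// `maskValue` is null the op is returned unchanged, so callers can handle the
// masked and unmasked cases uniformly through `rootOp->getResults()`.
//
//   Operation *rootOp = maskOperation(rewriter, newReadOp, maskValue);
//   rewriter.replaceOp(origOp, rootOp->getResults());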
7489 
7490 /// Creates a vector select operation that picks values from `newValue` or
7491 /// `passthru` for each result vector lane based on `mask`. This utility is used
7492 /// to propagate the pass-thru value of vector.mask or for cases where only the
7493 /// pass-thru value propagation is needed. VP intrinsics do not support
7494 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
7495 /// are usually able to match op + select patterns and fold them into native
7496 /// target instructions.
7497 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7498  Value newValue, Value passthru) {
7499  if (!mask)
7500  return newValue;
7501 
7502  return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7503  mask, newValue, passthru);
7504 }
7505 
7506 //===----------------------------------------------------------------------===//
7507 // TableGen'd op method definitions
7508 //===----------------------------------------------------------------------===//
7509 
7510 #define GET_ATTRDEF_CLASSES
7511 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7512 
7513 #define GET_OP_CLASSES
7514 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"