VectorOps.cpp
1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/IR/ValueRange.h"
39 #include "mlir/Support/LLVM.h"
41 #include "llvm/ADT/ArrayRef.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/SmallVector.h"
44 #include "llvm/ADT/StringSet.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Casting.h"
47 
48 #include <cassert>
49 #include <cstdint>
50 #include <numeric>
51 
52 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
53 // Pull in all enum type and utility function definitions.
54 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
55 
56 using namespace mlir;
57 using namespace mlir::vector;
58 
59 /// Helper enum to classify mask value.
60 enum class MaskFormat {
61  AllTrue = 0,
62  AllFalse = 1,
63  Unknown = 2,
64 };
65 
66 /// Helper method to classify a mask value. Currently, the method
67 /// looks "under the hood" of a constant value with dense attributes
68 /// and a constant mask operation (since the client may be called at
69 /// various stages during progressive lowering).
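/// For example (illustrative IR), the following mask values classify as:
///   %m0 = arith.constant dense<true> : vector<4xi1>                      // AllTrue
///   %m1 = arith.constant dense<false> : vector<4xi1>                     // AllFalse
///   %m2 = arith.constant dense<[true, false, true, true]> : vector<4xi1> // Unknown
///   %m3 = vector.constant_mask [4] : vector<4xi1>                        // AllTrue
///   %m4 = vector.constant_mask [0] : vector<4xi1>                        // AllFalse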
70 static MaskFormat getMaskFormat(Value mask) {
71  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
72  // Inspect constant dense values. We count up for bits that
73  // are set, count down for bits that are cleared, and bail
74  // when a mix is detected.
75  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
76  int64_t val = 0;
77  for (bool b : denseElts.getValues<bool>())
78  if (b && val >= 0)
79  val++;
80  else if (!b && val <= 0)
81  val--;
82  else
83  return MaskFormat::Unknown;
84  if (val > 0)
85  return MaskFormat::AllTrue;
86  if (val < 0)
87  return MaskFormat::AllFalse;
88  }
89  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
90  // Inspect constant mask index. If the index exceeds the
91  // dimension size, all bits are set. If the index is zero
92  // or less, no bits are set.
93  ArrayRef<int64_t> masks = m.getMaskDimSizes();
94  auto shape = m.getType().getShape();
95  bool allTrue = true;
96  bool allFalse = true;
97  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
98  if (maskIdx < dimSize)
99  allTrue = false;
100  if (maskIdx > 0)
101  allFalse = false;
102  }
103  if (allTrue)
104  return MaskFormat::AllTrue;
105  if (allFalse)
106  return MaskFormat::AllFalse;
107  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
108  // Finds all-false create_masks. An all-true create_mask requires all
109  // dims to be constants, so that'll be folded to a constant_mask, then
110  // detected in the constant_mask case.
111  auto maskOperands = m.getOperands();
112  for (Value operand : maskOperands) {
113  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
114  int64_t dimSize =
115  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
116  if (dimSize <= 0)
117  return MaskFormat::AllFalse;
118  }
119  }
120  return MaskFormat::Unknown;
121  }
122  return MaskFormat::Unknown;
123 }
124 
125 /// Default callback to build a region with a 'vector.yield' terminator with no
126 /// arguments.
127 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
128  builder.create<vector::YieldOp>(loc);
129 }
130 
131 // Helper for verifying combining kinds in contractions and reductions.
132 static bool isSupportedCombiningKind(CombiningKind combiningKind,
133  Type elementType) {
134  switch (combiningKind) {
135  case CombiningKind::ADD:
136  case CombiningKind::MUL:
137  return elementType.isIntOrIndexOrFloat();
138  case CombiningKind::MINUI:
139  case CombiningKind::MINSI:
140  case CombiningKind::MAXUI:
141  case CombiningKind::MAXSI:
142  case CombiningKind::AND:
143  case CombiningKind::OR:
144  case CombiningKind::XOR:
145  return elementType.isIntOrIndex();
146  case CombiningKind::MINNUMF:
147  case CombiningKind::MAXNUMF:
148  case CombiningKind::MINIMUMF:
149  case CombiningKind::MAXIMUMF:
150  return llvm::isa<FloatType>(elementType);
151  }
152  return false;
153 }
154 
155 /// Returns the effective rank of the vector to read/write for Xfer Ops
156 ///
157 /// When the element type of the shaped type is _a scalar_, this will simply
158 /// return the rank of the vector (the result for xfer_read or the value to
159 /// store for xfer_write).
160 ///
161 /// When the element type of the base shaped type is _a vector_, returns the
162 /// difference between the rank of the original vector type and the rank of
163 /// the element type of the shaped type.
164 ///
165 /// EXAMPLE 1 (element type is _a scalar_):
166 /// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167 /// - shapedType.getElementType() = f32 (rank 0)
168 /// - vectorType.getRank() = 2
169 /// - Result = 2 - 0 = 2
170 ///
171 /// EXAMPLE 2 (element type is _a vector_):
172 /// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
173 /// - shapedType.getElementType() = vector<20xf32> (rank 1)
174 /// - vectorType.getRank() = 1
175 /// - Result = 1 - 1 = 0
176 ///
177 /// This is used to determine the number of minor dimensions for identity maps
178 /// in vector transfer Ops.
179 static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
180  VectorType vectorType) {
181  unsigned elementVectorRank = 0;
182  VectorType elementVectorType =
183  llvm::dyn_cast<VectorType>(shapedType.getElementType());
184  if (elementVectorType)
185  elementVectorRank += elementVectorType.getRank();
186  return vectorType.getRank() - elementVectorRank;
187 }
188 
189 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
190  VectorType vectorType) {
191  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
192  // TODO: replace once we have 0-d vectors.
193  if (shapedType.getRank() == 0 &&
194  vectorType.getShape() == ArrayRef<int64_t>{1})
195  return AffineMap::get(
196  /*numDims=*/0, /*numSymbols=*/0,
197  getAffineConstantExpr(0, shapedType.getContext()));
198  return AffineMap::getMinorIdentityMap(
199  shapedType.getRank(),
200  getEffectiveVectorRankForXferOp(shapedType, vectorType),
201  shapedType.getContext());
202 }
203 
204 /// Check if `write` is of a constant splat and the masked `read` is padded with
205 /// the same splat value -- meaning it could be the same value as the initial
206 /// constant splat.
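/// For example (illustrative IR, with %mask and the f32 constant %pad0 = 0.0
/// defined elsewhere):
///   %splat = arith.constant dense<0.0> : vector<4xf32>
///   %w = vector.transfer_write %splat, %t[%i], %mask : vector<4xf32>, tensor<8xf32>
///   %r = vector.transfer_read %w[%i], %pad0, %mask : tensor<8xf32>, vector<4xf32>
/// The masked-off lanes of %r take the padding 0.0, which equals the splat, so
/// (provided the remaining RAW conditions below hold) %r can be forwarded from
/// %splat.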
207 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
208  vector::TransferReadOp read) {
209  auto readMask = read.getMask();
210  auto writeMask = write.getMask();
211  // Check if the masks are consistent. The splat value could be the same if the
212  // read is masked (and padded with the splat value), and the write is unmasked
213  // or has the same mask. Note this does not allow the case where the write is
214  // masked and the read is unmasked, as then the read could be of more elements
215  // than the write (which may not be the same value).
216  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
217  if (!couldBeSameSplat)
218  return false;
219  // Check for constant splat (as the source of the write).
220  DenseElementsAttr splatAttr;
221  if (!matchPattern(write.getVector(),
222  m_Constant<DenseElementsAttr>(&splatAttr)) ||
223  !splatAttr.isSplat()) {
224  return false;
225  }
226  // The padding of the read and the constant splat value must be the same.
227  Attribute padAttr;
228  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
229  return false;
230  return padAttr == splatAttr.getSplatValue<Attribute>();
231 }
232 
233 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
234  vector::TransferReadOp read) {
235  return !defWrite.hasOutOfBoundsDim() &&
236  defWrite.getIndices() == read.getIndices() &&
237  defWrite.getVectorType() == read.getVectorType() &&
238  defWrite.getPermutationMap() == read.getPermutationMap() &&
239  ((!defWrite.getMask() && !read.getMask()) ||
240  isSplatWriteConsistentWithMaskedRead(defWrite, read));
241 }
242 
243 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
244  vector::TransferWriteOp priorWrite) {
245  return priorWrite.getIndices() == write.getIndices() &&
246  priorWrite.getMask() == write.getMask() &&
247  priorWrite.getVectorType() == write.getVectorType() &&
248  priorWrite.getPermutationMap() == write.getPermutationMap();
249 }
250 
251 bool mlir::vector::isDisjointTransferIndices(
252  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
253  bool testDynamicValueUsingBounds) {
254  // For simplicity only look at transfer of same type.
255  if (transferA.getVectorType() != transferB.getVectorType())
256  return false;
257  unsigned rankOffset = transferA.getLeadingShapedRank();
258  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
259  Value indexA = transferA.getIndices()[i];
260  Value indexB = transferB.getIndices()[i];
261  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
262  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
263 
264  if (i < rankOffset) {
265  // For leading dimensions, if we can prove that index are different we
266  // know we are accessing disjoint slices.
267  if (cstIndexA.has_value() && cstIndexB.has_value()) {
268  if (*cstIndexA != *cstIndexB)
269  return true;
270  continue;
271  }
272  if (testDynamicValueUsingBounds) {
273  // First try to see if we can fully compose and simplify the affine
274  // expression as a fast track.
275  FailureOr<uint64_t> delta =
276  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
277  if (succeeded(delta) && *delta != 0)
278  return true;
279 
280  FailureOr<bool> testEqual =
281  ValueBoundsConstraintSet::areEqual(indexA, indexB);
282  if (succeeded(testEqual) && !testEqual.value())
283  return true;
284  }
285  } else {
286  // For this dimension, we slice a part of the memref, so we need to make sure
287  // the intervals accessed don't overlap.
288  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
289  if (cstIndexA.has_value() && cstIndexB.has_value()) {
290  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
291  if (distance >= vectorDim)
292  return true;
293  continue;
294  }
295  if (testDynamicValueUsingBounds) {
296  // First try to see if we can fully compose and simplify the affine
297  // expression as a fast track.
298  FailureOr<int64_t> delta =
299  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
300  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
301  return true;
302 
303  FailureOr<int64_t> computeDelta =
304  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
305  if (succeeded(computeDelta)) {
306  if (std::abs(computeDelta.value()) >= vectorDim)
307  return true;
308  }
309  }
310  }
311  }
312  return false;
313 }
314 
315 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
316  VectorTransferOpInterface transferB,
317  bool testDynamicValueUsingBounds) {
318  if (transferA.getBase() != transferB.getBase())
319  return false;
320  return isDisjointTransferIndices(transferA, transferB,
321  testDynamicValueUsingBounds);
322 }
323 
324 // Helper to iterate over n-D vector slice elements. Calculate the next
325 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
326 // Modifies the `position` in place. Returns a failure when `position` becomes
327 // the end position.
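// For example, with shape = [2, 3] and offsets = [0, 0], repeatedly calling
// this starting from position [0, 0] yields:
//   [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> failure (end reached).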
328 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
329  ArrayRef<int64_t> shape,
330  ArrayRef<int64_t> offsets) {
331  for (auto [posInDim, dimSize, offsetInDim] :
332  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
333  ++posInDim;
334  if (posInDim < dimSize + offsetInDim)
335  return success();
336 
337  // Carry the overflow to the next loop iteration.
338  posInDim = offsetInDim;
339  }
340 
341  return failure();
342 }
343 
344 /// Returns the integer numbers in `values`. `values` are expected to be
345 /// constant operations.
346 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
347  SmallVector<int64_t> ints;
348  llvm::transform(values, std::back_inserter(ints), [](Value value) {
349  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
350  assert(constOp && "Unexpected non-constant index");
351  return constOp.value();
352  });
353  return ints;
354 }
355 
356 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
357 /// be constant operations.
358 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
359  SmallVector<int64_t> ints;
360  llvm::transform(
361  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
362  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
363  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
364  });
365  return ints;
366 }
367 
368 /// Convert `foldResults` into Values. Integer attributes are converted to
369 /// constant op.
370 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
371  ArrayRef<OpFoldResult> foldResults) {
372  SmallVector<Value> values;
373  llvm::transform(foldResults, std::back_inserter(values),
374  [&](OpFoldResult foldResult) {
375  if (auto attr = dyn_cast<Attribute>(foldResult))
376  return builder
377  .create<arith::ConstantIndexOp>(
378  loc, cast<IntegerAttr>(attr).getInt())
379  .getResult();
380 
381  return cast<Value>(foldResult);
382  });
383  return values;
384 }
385 
386 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
387  if (value.getDefiningOp<vector::VectorScaleOp>())
388  return 1;
389  auto mul = value.getDefiningOp<arith::MulIOp>();
390  if (!mul)
391  return {};
392  auto lhs = mul.getLhs();
393  auto rhs = mul.getRhs();
394  if (lhs.getDefiningOp<vector::VectorScaleOp>())
395  return getConstantIntValue(rhs);
396  if (rhs.getDefiningOp<vector::VectorScaleOp>())
397  return getConstantIntValue(lhs);
398  return {};
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // CombiningKindAttr
403 //===----------------------------------------------------------------------===//
404 
405 namespace mlir {
406 namespace vector {
407 namespace detail {
408 struct BitmaskEnumStorage : public AttributeStorage {
409  using KeyTy = uint64_t;
410 
411  BitmaskEnumStorage(KeyTy val) : value(val) {}
412 
413  bool operator==(const KeyTy &key) const { return value == key; }
414 
415  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
416  const KeyTy &key) {
417  return new (allocator.allocate<BitmaskEnumStorage>())
418  BitmaskEnumStorage(key);
419  }
420 
421  KeyTy value = 0;
422 };
423 } // namespace detail
424 } // namespace vector
425 } // namespace mlir
426 
427 //===----------------------------------------------------------------------===//
428 // VectorDialect
429 //===----------------------------------------------------------------------===//
430 
431 namespace {
432 /// This class defines the interface for handling inlining with vector dialect
433 /// operations.
434 struct VectorInlinerInterface : public DialectInlinerInterface {
435  using DialectInlinerInterface::DialectInlinerInterface;
436 
437  /// All vector dialect ops can be inlined.
438  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
439  return true;
440  }
441 };
442 } // namespace
443 
444 void VectorDialect::initialize() {
445  addAttributes<
446 #define GET_ATTRDEF_LIST
447 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
448  >();
449 
450  addOperations<
451 #define GET_OP_LIST
452 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
453  >();
454 
455  addInterfaces<VectorInlinerInterface>();
456 
457  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
458  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
459  YieldOp>();
460  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
461  TransferWriteOp>();
462  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
463  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
464  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
465 }
466 
467 /// Materialize a single constant operation from a given attribute value with
468 /// the desired resultant type.
469 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
470  Attribute value, Type type,
471  Location loc) {
472  if (isa<ub::PoisonAttrInterface>(value))
473  return value.getDialect().materializeConstant(builder, value, type, loc);
474 
475  return arith::ConstantOp::materialize(builder, value, type, loc);
476 }
477 
478 Type vector::getVectorSubscriptType(Builder &builder) {
479  return builder.getIntegerType(64);
480 }
481 
482 Attribute vector::getVectorSubscriptAttr(Builder &builder,
483  ArrayRef<int64_t> values) {
484  return builder.getI64ArrayAttr(values);
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // MultiDimReductionOp
489 //===----------------------------------------------------------------------===//
490 
491 void vector::MultiDimReductionOp::build(OpBuilder &builder,
492  OperationState &result, Value source,
493  Value acc, ArrayRef<bool> reductionMask,
494  CombiningKind kind) {
495  SmallVector<int64_t> reductionDims;
496  for (const auto &en : llvm::enumerate(reductionMask))
497  if (en.value())
498  reductionDims.push_back(en.index());
499  build(builder, result, kind, source, acc, reductionDims);
500 }
501 
502 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
503  // Single parallel dim, this is a noop.
504  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
505  return getSource();
506  return {};
507 }
508 
509 std::optional<SmallVector<int64_t, 4>>
510 MultiDimReductionOp::getShapeForUnroll() {
511  return llvm::to_vector<4>(getSourceVectorType().getShape());
512 }
513 
514 LogicalResult MultiDimReductionOp::verify() {
515  SmallVector<int64_t> targetShape;
516  SmallVector<bool> scalableDims;
517  Type inferredReturnType;
518  auto sourceScalableDims = getSourceVectorType().getScalableDims();
519  for (auto [dimIdx, dimSize] :
520  llvm::enumerate(getSourceVectorType().getShape()))
521  if (!llvm::any_of(getReductionDims(),
522  [dimIdx = dimIdx](int64_t reductionDimIdx) {
523  return reductionDimIdx == static_cast<int64_t>(dimIdx);
524  })) {
525  targetShape.push_back(dimSize);
526  scalableDims.push_back(sourceScalableDims[dimIdx]);
527  }
528  // TODO: update to also allow 0-d vectors when available.
529  if (targetShape.empty())
530  inferredReturnType = getSourceVectorType().getElementType();
531  else
532  inferredReturnType = VectorType::get(
533  targetShape, getSourceVectorType().getElementType(), scalableDims);
534  if (getType() != inferredReturnType)
535  return emitOpError() << "destination type " << getType()
536  << " is incompatible with source type "
537  << getSourceVectorType();
538 
539  return success();
540 }
541 
542 /// Returns the mask type expected by this operation.
543 Type MultiDimReductionOp::getExpectedMaskType() {
544  auto vecType = getSourceVectorType();
545  return VectorType::get(vecType.getShape(),
546  IntegerType::get(vecType.getContext(), /*width=*/1),
547  vecType.getScalableDims());
548 }
549 
550 namespace {
551 // Only unit dimensions that are being reduced are folded. If the dimension is
552 // unit, but not reduced, it is not folded, thereby keeping the output type the
553 // same. If not all dimensions which are reduced are of unit dimension, this
554 // transformation does nothing. This is just a generalization of
555 // ElideSingleElementReduction for ReductionOp.
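// For example (illustrative IR):
//   %r = vector.multi_reduction <add>, %v, %acc [0] : vector<1x4xf32> to vector<4xf32>
// reduces only a unit dimension, so it can be rewritten as a shape_cast of %v
// combined with %acc; a reduction over a non-unit dimension is left unchanged.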
556 struct ElideUnitDimsInMultiDimReduction
557  : public OpRewritePattern<MultiDimReductionOp> {
558  using OpRewritePattern::OpRewritePattern;
559 
560  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
561  PatternRewriter &rewriter) const override {
562  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
563  for (const auto &dim : enumerate(shape)) {
564  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
565  return failure();
566  }
567 
568  // Vector mask setup.
569  OpBuilder::InsertionGuard guard(rewriter);
570  Operation *rootOp;
571  Value mask;
572  if (reductionOp.isMasked()) {
573  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
574  rootOp = reductionOp.getMaskingOp();
575  mask = reductionOp.getMaskingOp().getMask();
576  } else {
577  rootOp = reductionOp;
578  }
579 
580  Location loc = reductionOp.getLoc();
581  Value acc = reductionOp.getAcc();
582  Value cast;
583  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
584  if (mask) {
585  VectorType newMaskType =
586  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
587  dstVecType.getScalableDims());
588  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
589  }
590  cast = rewriter.create<vector::ShapeCastOp>(
591  loc, reductionOp.getDestType(), reductionOp.getSource());
592  } else {
593  // This means we are reducing all the dimensions, and all reduction
594  // dimensions are of size 1. So a simple extraction would do.
595  if (mask)
596  mask = rewriter.create<vector::ExtractOp>(loc, mask);
597  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
598  }
599 
600  Value result =
601  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
602  cast, /*fastmath=*/nullptr, mask);
603  rewriter.replaceOp(rootOp, result);
604  return success();
605  }
606 };
607 } // namespace
608 
609 void MultiDimReductionOp::getCanonicalizationPatterns(
610  RewritePatternSet &results, MLIRContext *context) {
611  results.add<ElideUnitDimsInMultiDimReduction>(context);
612 }
613 
614 //===----------------------------------------------------------------------===//
615 // ReductionOp
616 //===----------------------------------------------------------------------===//
617 
618 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
619  CombiningKind kind, Value vector,
620  arith::FastMathFlags fastMathFlags) {
621  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
622 }
623 
624 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
625  CombiningKind kind, Value vector, Value acc,
626  arith::FastMathFlags fastMathFlags) {
627  build(builder, result,
628  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
629  acc, fastMathFlags);
630 }
631 
632 LogicalResult ReductionOp::verify() {
633  // Verify for 0-D and 1-D vector.
634  int64_t rank = getSourceVectorType().getRank();
635  if (rank > 1)
636  return emitOpError("unsupported reduction rank: ") << rank;
637 
638  // Verify supported reduction kind.
639  Type eltType = getDest().getType();
640  if (!isSupportedCombiningKind(getKind(), eltType))
641  return emitOpError("unsupported reduction type '")
642  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
643  << "'";
644 
645  return success();
646 }
647 
648 // MaskableOpInterface methods.
649 
650 /// Returns the mask type expected by this operation.
651 Type ReductionOp::getExpectedMaskType() {
652  auto vecType = getSourceVectorType();
653  return VectorType::get(vecType.getShape(),
654  IntegerType::get(vecType.getContext(), /*width=*/1),
655  vecType.getScalableDims());
656 }
657 
658 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
659  OpBuilder &builder, Location loc,
660  Value vector) {
661  switch (op) {
662  case arith::AtomicRMWKind::addf:
663  case arith::AtomicRMWKind::addi:
664  return builder.create<vector::ReductionOp>(vector.getLoc(),
665  CombiningKind::ADD, vector);
666  case arith::AtomicRMWKind::mulf:
667  case arith::AtomicRMWKind::muli:
668  return builder.create<vector::ReductionOp>(vector.getLoc(),
669  CombiningKind::MUL, vector);
670  case arith::AtomicRMWKind::minimumf:
671  return builder.create<vector::ReductionOp>(vector.getLoc(),
672  CombiningKind::MINIMUMF, vector);
673  case arith::AtomicRMWKind::mins:
674  return builder.create<vector::ReductionOp>(vector.getLoc(),
675  CombiningKind::MINSI, vector);
676  case arith::AtomicRMWKind::minu:
677  return builder.create<vector::ReductionOp>(vector.getLoc(),
678  CombiningKind::MINUI, vector);
679  case arith::AtomicRMWKind::maximumf:
680  return builder.create<vector::ReductionOp>(vector.getLoc(),
681  CombiningKind::MAXIMUMF, vector);
682  case arith::AtomicRMWKind::maxs:
683  return builder.create<vector::ReductionOp>(vector.getLoc(),
684  CombiningKind::MAXSI, vector);
685  case arith::AtomicRMWKind::maxu:
686  return builder.create<vector::ReductionOp>(vector.getLoc(),
687  CombiningKind::MAXUI, vector);
688  case arith::AtomicRMWKind::andi:
689  return builder.create<vector::ReductionOp>(vector.getLoc(),
690  CombiningKind::AND, vector);
691  case arith::AtomicRMWKind::ori:
692  return builder.create<vector::ReductionOp>(vector.getLoc(),
693  CombiningKind::OR, vector);
694  // TODO: Add remaining reduction operations.
695  default:
696  (void)emitOptionalError(loc, "Reduction operation type not supported");
697  break;
698  }
699  return nullptr;
700 }
701 
702 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
703  return llvm::to_vector<4>(getSourceVectorType().getShape());
704 }
705 
706 namespace {
707 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
708  using OpRewritePattern::OpRewritePattern;
709 
710  LogicalResult matchAndRewrite(ReductionOp reductionOp,
711  PatternRewriter &rewriter) const override {
712  // Vector mask setup.
713  OpBuilder::InsertionGuard guard(rewriter);
714  auto maskableOp =
715  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
716  Operation *rootOp;
717  Value mask;
718  if (maskableOp.isMasked()) {
719  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
720  rootOp = maskableOp.getMaskingOp();
721  mask = maskableOp.getMaskingOp().getMask();
722  } else {
723  rootOp = reductionOp;
724  }
725 
726  auto vectorType = reductionOp.getSourceVectorType();
727  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
728  return failure();
729 
730  Location loc = reductionOp.getLoc();
731  if (mask)
732  mask = rewriter.create<ExtractOp>(loc, mask);
733  Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
734 
735  if (Value acc = reductionOp.getAcc())
736  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
737  result, acc,
738  reductionOp.getFastmathAttr(), mask);
739 
740  rewriter.replaceOp(rootOp, result);
741  return success();
742  }
743 };
744 } // namespace
745 
746 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
747  MLIRContext *context) {
748  results.add<ElideSingleElementReduction>(context);
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // ContractionOp
753 //===----------------------------------------------------------------------===//
754 
755 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
756  Value lhs, Value rhs, Value acc,
757  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
758  ArrayRef<IteratorType> iteratorTypes) {
759  result.addOperands({lhs, rhs, acc});
760  result.addTypes(acc.getType());
761  result.addAttribute(
762  getIndexingMapsAttrName(result.name),
763  builder.getAffineMapArrayAttr(
764  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
765  result.addAttribute(
766  getIteratorTypesAttrName(result.name),
767  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
768  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
769  return IteratorTypeAttr::get(builder.getContext(), t);
770  }))));
771 }
772 
773 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
774  Value lhs, Value rhs, Value acc,
775  ArrayAttr indexingMaps,
776  ArrayAttr iteratorTypes) {
777  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
778  ContractionOp::getDefaultKind());
779 }
780 
781 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
782  Value lhs, Value rhs, Value acc,
783  ArrayAttr indexingMaps,
784  ArrayAttr iteratorTypes, CombiningKind kind) {
785  result.addOperands({lhs, rhs, acc});
786  result.addTypes(acc.getType());
787  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
788  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
789  result.addAttribute(getKindAttrName(result.name),
790  CombiningKindAttr::get(builder.getContext(), kind));
791 }
792 
793 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
794  OpAsmParser::UnresolvedOperand lhsInfo;
795  OpAsmParser::UnresolvedOperand rhsInfo;
796  OpAsmParser::UnresolvedOperand accInfo;
797  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
798  SmallVector<Type, 2> types;
799  Type resultType;
800  auto loc = parser.getCurrentLocation();
801  DictionaryAttr dictAttr;
802  // TODO: Unify linalg op attribute parsing.
803  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
804  parser.parseComma() || parser.parseOperand(rhsInfo) ||
805  parser.parseComma() || parser.parseOperand(accInfo) ||
806  parser.parseTrailingOperandList(masksInfo) ||
807  parser.parseOptionalAttrDict(result.attributes) ||
808  parser.parseColonTypeList(types) ||
809  parser.parseKeywordType("into", resultType) ||
810  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
811  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
812  parser.resolveOperand(accInfo, resultType, result.operands) ||
813  parser.addTypeToList(resultType, result.types))
814  return failure();
815  result.attributes.append(dictAttr.getValue().begin(),
816  dictAttr.getValue().end());
817 
818  // Convert the array of strings into an array of IteratorType enums. This is
819  // needed because tests still use the old format, in which the 'iterator_types'
820  // attribute is represented as an array of strings.
821  // TODO: Remove this conversion once tests are fixed.
822  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
823  result.attributes.get(getIteratorTypesAttrName(result.name)));
824  if (!iteratorTypes) {
825  return parser.emitError(loc)
826  << "expected " << getIteratorTypesAttrName(result.name)
827  << " array attribute";
828  }
829 
830  SmallVector<Attribute> iteratorTypeAttrs;
831 
832  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
833  auto maybeIteratorType = symbolizeIteratorType(s);
834  if (!maybeIteratorType.has_value())
835  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
836 
837  iteratorTypeAttrs.push_back(
838  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
839  }
840  result.attributes.set(getIteratorTypesAttrName(result.name),
841  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
842 
843  if (!result.attributes.get(getKindAttrName(result.name))) {
844  result.addAttribute(
845  getKindAttrName(result.name),
846  CombiningKindAttr::get(result.getContext(),
847  ContractionOp::getDefaultKind()));
848  }
849  if (masksInfo.empty())
850  return success();
851  if (masksInfo.size() != 2)
852  return parser.emitError(parser.getNameLoc(),
853  "expected zero or exactly 2 vector mask operands");
854  auto lhsType = llvm::cast<VectorType>(types[0]);
855  auto rhsType = llvm::cast<VectorType>(types[1]);
856  auto maskElementType = parser.getBuilder().getI1Type();
857  std::array<VectorType, 2> maskTypes = {
858  VectorType::Builder(lhsType).setElementType(maskElementType),
859  VectorType::Builder(rhsType).setElementType(maskElementType)};
860  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
861  return failure();
862  return success();
863 }
864 
865 void ContractionOp::print(OpAsmPrinter &p) {
866  // TODO: Unify printing code with linalg ops.
867  auto attrNames = getTraitAttrNames();
868  llvm::StringSet<> traitAttrsSet;
869  traitAttrsSet.insert_range(attrNames);
870  SmallVector<NamedAttribute> attrs;
871  for (auto attr : (*this)->getAttrs()) {
872  if (attr.getName() == getIteratorTypesAttrName()) {
873  auto iteratorTypes =
874  llvm::cast<ArrayAttr>(attr.getValue())
875  .getAsValueRange<IteratorTypeAttr, IteratorType>();
876  // Convert IteratorType enums into the string representation. This is
877  // needed because tests still use the old format when 'iterator_types'
878  // attribute is represented as an array of strings.
879  // TODO: Remove this conversion once tests are fixed.
880  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
881  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
882  return StringAttr::get(getContext(), stringifyIteratorType(t));
883  }));
884 
885  attrs.emplace_back(getIteratorTypesAttrName(),
886  ArrayAttr::get(getContext(), iteratorTypeNames));
887  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
888  attrs.push_back(attr);
889  }
890 
891  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
892  p << " " << dictAttr << " " << getLhs() << ", ";
893  p << getRhs() << ", " << getAcc();
894 
895  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
896  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
897  << getResultType();
898 }
899 
900 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
901  const std::vector<std::pair<int64_t, int64_t>> &map) {
902  for (auto &dimPair : map) {
903  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
904  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
905  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
906  return false;
907  }
908  return true;
909 }
910 
911 static LogicalResult verifyOutputShape(
912  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
913  Type resType,
914  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
915  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
916  DenseSet<int64_t> lhsContractingDimSet;
917  DenseSet<int64_t> rhsContractingDimSet;
918  for (auto &dimPair : contractingDimMap) {
919  lhsContractingDimSet.insert(dimPair.first);
920  rhsContractingDimSet.insert(dimPair.second);
921  }
922  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
923  llvm::make_second_range(batchDimMap));
924 
925  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
926  SmallVector<int64_t, 4> expectedResultDims;
927  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
928  if (lhsContractingDimSet.count(i) > 0)
929  continue;
930  expectedResultDims.push_back(lhsType.getDimSize(i));
931  }
932 
933  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
934  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
935  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
936  continue;
937  expectedResultDims.push_back(rhsType.getDimSize(i));
938  }
939 
940  // Verify 'expectedResultDims'.
941  if (expectedResultDims.empty()) {
942  // No batch or free dimension implies a scalar result.
943  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
944  return op.emitOpError("invalid accumulator/result vector shape");
945  } else {
946  // At least one batch or free dimension implies a vector result.
947  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
948  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
949  if (!resVectorType || !accVectorType)
950  return op.emitOpError("invalid accumulator/result vector shape");
951 
952  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
953  // types fully define the result vector type. This assumes the affine maps
954  // are well-formed, which must have been verified already.
955  MLIRContext *ctx = op.getContext();
956  AffineMap lhsMap = op.getIndexingMapsArray()[0];
957  AffineMap rhsMap = op.getIndexingMapsArray()[1];
958  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
959  return op.emitOpError(
960  "expected all dimensions to be either a LHS or a RHS dimension");
961  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
962  for (auto pair :
963  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
964  VectorType v = pair.first;
965  auto map = pair.second;
966  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
967  unsigned pos = map.getDimPosition(idx);
968  if (!extents[pos])
969  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
970  }
971  }
972  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
973  return op.emitOpError("expected all dimensions to get an extent as "
974  "either a LHS or a RHS dimension");
975 
976  AffineMap resMap = op.getIndexingMapsArray()[2];
977  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
978  /*symbolCount=*/0, extents, ctx);
979  // Compose the resMap with the extentsMap, which is a constant map.
980  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
981  assert(llvm::all_of(expectedMap.getResults(),
982  llvm::IsaPred<AffineConstantExpr>) &&
983  "expected constant extent along all dimensions.");
984  // Extract the expected shape and build the type.
985  auto expectedShape = llvm::to_vector<4>(
986  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
987  return cast<AffineConstantExpr>(e).getValue();
988  }));
989  auto expected =
990  VectorType::get(expectedShape, resVectorType.getElementType(),
991  resVectorType.getScalableDims());
992  if (resVectorType != expected || accVectorType != expected)
993  return op.emitOpError(
994  "invalid accumulator/result vector shape, expected: ")
995  << expected;
996  }
997  return success();
998 }
999 
1000 LogicalResult ContractionOp::verify() {
1001  VectorType lhsType = getLhsType();
1002  VectorType rhsType = getRhsType();
1003  Type accType = getAccType();
1004  Type resType = getResultType();
1005 
1006  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1007  if (!lhsType.getElementType().isSignlessInteger())
1008  return emitOpError("only supports signless integer types");
1009  }
1010 
1011  // Verify that an indexing map was specified for each vector operand.
1012  if (getIndexingMapsArray().size() != 3)
1013  return emitOpError("expected an indexing map for each vector operand");
1014 
1015  // Verify that each index map has 'numIterators' inputs, no symbols, and
1016  // that the number of map outputs equals the rank of its associated
1017  // vector operand.
1018  unsigned numIterators = getIteratorTypes().getValue().size();
1019  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1020  auto index = it.index();
1021  auto map = it.value();
1022  if (map.getNumSymbols() != 0)
1023  return emitOpError("expected indexing map ")
1024  << index << " to have no symbols";
1025  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1026  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1027  // Verify that the map has the right number of inputs, outputs, and indices.
1028  // This also correctly accounts for (..) -> () for rank-0 results.
1029  if (map.getNumDims() != numIterators)
1030  return emitOpError("expected indexing map ")
1031  << index << " to have " << numIterators << " number of inputs";
1032  if (map.getNumResults() != rank)
1033  return emitOpError("expected indexing map ")
1034  << index << " to have " << rank << " number of outputs";
1035  if (!map.isProjectedPermutation())
1036  return emitOpError("expected indexing map ")
1037  << index << " to be a projected permutation of its inputs";
1038  }
1039 
1040  auto contractingDimMap = getContractingDimMap();
1041  auto batchDimMap = getBatchDimMap();
1042 
1043  // Verify at least one contracting dimension pair was specified.
1044  if (contractingDimMap.empty())
1045  return emitOpError("expected at least one contracting dimension pair");
1046 
1047  // Verify contracting dimension map was properly constructed.
1048  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1049  return emitOpError("invalid contracting dimension map");
1050 
1051  // Verify batch dimension map was properly constructed.
1052  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1053  return emitOpError("invalid batch dimension map");
1054 
1055  // Verify 'accType' and 'resType' shape.
1056  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1057  contractingDimMap, batchDimMap)))
1058  return failure();
1059 
1060  // Verify supported combining kind.
1061  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1062  auto elementType = vectorType ? vectorType.getElementType() : resType;
1063  if (!isSupportedCombiningKind(getKind(), elementType))
1064  return emitOpError("unsupported contraction type");
1065 
1066  return success();
1067 }
1068 
1069 // MaskableOpInterface methods.
1070 
1071 /// Returns the mask type expected by this operation. Mostly used for
1072 /// verification purposes. It requires the operation to be vectorized.
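/// For example (illustrative), for a matmul-like contraction over iterators
/// (m, n, k) with %lhs : vector<4x8xf32> indexed by (m, k) and
/// %rhs : vector<8x16xf32> indexed by (k, n), the expected mask type is
/// vector<4x16x8xi1>, i.e. the sizes of the full iteration space in
/// (m, n, k) order.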
1073 Type ContractionOp::getExpectedMaskType() {
1074  auto indexingMaps = this->getIndexingMapsArray();
1075  AffineMap lhsIdxMap = indexingMaps[0];
1076  AffineMap rhsIdxMap = indexingMaps[1];
1077  VectorType lhsType = this->getLhsType();
1078  VectorType rhsType = this->getRhsType();
1079 
1080  unsigned numVecDims = lhsIdxMap.getNumDims();
1081  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1082  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1083 
1084  // Using the information in the indexing maps, extract the size of each
1085  // dimension in the vector.contract operation from the two input operands.
1086  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1087  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1088  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1089  lhsType.getScalableDims()[dimIdx];
1090  }
1091  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1092  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1093  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1094  rhsType.getScalableDims()[dimIdx];
1095  }
1096 
1097  assert(!ShapedType::isDynamicShape(maskShape) &&
1098  "Mask shape couldn't be computed");
1099 
1100  return VectorType::get(maskShape,
1101  IntegerType::get(lhsType.getContext(), /*width=*/1),
1102  maskShapeScalableDims);
1103 }
1104 
1105 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1106  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1107  getIteratorTypesAttrName(), getKindAttrName()};
1108 }
1109 
1110 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1111  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1112  if (targetExpr == map.getResult(i))
1113  return i;
1114  return -1;
1115 }
1116 
1117 static std::vector<std::pair<int64_t, int64_t>>
1118 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1119  IteratorType targetIteratorType, MLIRContext *context) {
1120  std::vector<std::pair<int64_t, int64_t>> dimMap;
1121  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1122  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1123  if (iteratorType != targetIteratorType)
1124  continue;
1125  // Search lhs/rhs map results for 'targetExpr'.
1126  auto targetExpr = getAffineDimExpr(it.index(), context);
1127  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1128  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1129  if (lhsDim >= 0 && rhsDim >= 0)
1130  dimMap.emplace_back(lhsDim, rhsDim);
1131  }
1132  return dimMap;
1133 }
1134 
1135 void ContractionOp::getIterationBounds(
1136  SmallVectorImpl<int64_t> &iterationBounds) {
1137  auto lhsShape = getLhsType().getShape();
1138  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1139  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1140  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1141  // Search lhs/rhs map results for 'targetExpr'.
1142  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1143  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1144  if (iteratorType == IteratorType::reduction) {
1145  // Get reduction dim size from lhs shape (same size in rhsShape).
1146  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1147  assert(lhsDimIndex >= 0);
1148  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1149  continue;
1150  }
1151  // Get parallel dimension size from result shape.
1152  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1153  assert(resDimIndex >= 0);
1154  assert(resVectorType != nullptr);
1155  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1156  }
1157 }
1158 
1159 void ContractionOp::getIterationIndexMap(
1160  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1161  unsigned numMaps = getIndexingMapsArray().size();
1162  iterationIndexMap.resize(numMaps);
1163  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1164  auto index = it.index();
1165  auto map = it.value();
1166  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1167  auto dim = cast<AffineDimExpr>(map.getResult(i));
1168  iterationIndexMap[index][dim.getPosition()] = i;
1169  }
1170  }
1171 }
1172 
1173 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1174  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1175  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1176  getContext());
1177 }
1178 
1179 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1180  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1181  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1182  getContext());
1183 }
1184 
1185 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1186  SmallVector<int64_t, 4> shape;
1187  getIterationBounds(shape);
1188  return shape;
1189 }
1190 
1191 /// Return a fused vector::ContractionOp which represents a pattern such as:
1192 ///
1193 /// ```mlir
1194 /// %c0 = arith.constant 0: ...
1195 /// %c = vector.contract %a, %b, %c0: ...
1196 /// %e = add %c, %d: ...
1197 /// ```
1198 ///
1199 /// by:
1200 ///
1201 /// ```mlir
1202 /// %e = vector.contract %a, %b, %d: ...
1203 /// ```
1204 ///
1205 /// Return null if the canonicalization does not apply.
1206 // TODO: This should be a folding of Add into Contract in core but while they
1207 // live in different dialects, it is not possible without unnatural
1208 // dependencies.
1209 template <typename AddOpType>
1210 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1211  using OpRewritePattern<AddOpType>::OpRewritePattern;
1212 
1213  LogicalResult matchAndRewrite(AddOpType addOp,
1214  PatternRewriter &rewriter) const override {
1215  auto canonicalize = [&](Value maybeContraction,
1216  Value otherOperand) -> vector::ContractionOp {
1217  vector::ContractionOp contractionOp =
1218  dyn_cast_or_null<vector::ContractionOp>(
1219  maybeContraction.getDefiningOp());
1220  if (!contractionOp)
1221  return vector::ContractionOp();
1222  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1223  contractionOp.getAcc().getDefiningOp())) {
1224  if (maybeZero.getValue() ==
1225  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1226  IRMapping bvm;
1227  bvm.map(contractionOp.getAcc(), otherOperand);
1228  auto newContraction =
1229  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1230  rewriter.replaceOp(addOp, newContraction.getResult());
1231  return newContraction;
1232  }
1233  }
1234  return vector::ContractionOp();
1235  };
1236 
1237  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1238  vector::ContractionOp contract = canonicalize(a, b);
1239  contract = contract ? contract : canonicalize(b, a);
1240  return contract ? success() : failure();
1241  }
1242 };
1243 
1244 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1245  MLIRContext *context) {
1246  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1247  CanonicalizeContractAdd<arith::AddFOp>>(context);
1248 }
1249 
1250 //===----------------------------------------------------------------------===//
1251 // ExtractElementOp
1252 //===----------------------------------------------------------------------===//
1253 
1254 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1255  SetIntRangeFn setResultRanges) {
1256  setResultRanges(getResult(), argRanges.front());
1257 }
1258 
1259 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1260  Value source) {
1261  result.addOperands({source});
1262  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1263 }
1264 
1265 LogicalResult vector::ExtractElementOp::verify() {
1266  VectorType vectorType = getSourceVectorType();
1267  if (vectorType.getRank() == 0) {
1268  if (getPosition())
1269  return emitOpError("expected position to be empty with 0-D vector");
1270  return success();
1271  }
1272  if (vectorType.getRank() != 1)
1273  return emitOpError("unexpected >1 vector rank");
1274  if (!getPosition())
1275  return emitOpError("expected position for 1-D vector");
1276  return success();
1277 }
1278 
1279 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1280  // Skip folding 0-D vectors (no position operand) for now.
1281  if (!adaptor.getPosition())
1282  return {};
1283 
1284  // Fold extractelement (splat X) -> X.
1285  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1286  return splat.getInput();
1287 
1288  // Fold extractelement(broadcast(X)) -> X.
1289  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1290  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1291  return broadcast.getSource();
1292 
1293  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1294  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1295  if (!pos || !src)
1296  return {};
1297 
1298  auto srcElements = src.getValues<Attribute>();
1299 
1300  uint64_t posIdx = pos.getInt();
1301  if (posIdx >= srcElements.size())
1302  return {};
1303 
1304  return srcElements[posIdx];
1305 }
1306 
1307 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1308 // `poisonValue`.
1309 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1310  int64_t maxIndex) {
1311  return index == poisonValue || (index >= 0 && index < maxIndex);
1312 }
1313 
1314 //===----------------------------------------------------------------------===//
1315 // ExtractOp
1316 //===----------------------------------------------------------------------===//
1317 
1318 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1319  SetIntRangeFn setResultRanges) {
1320  setResultRanges(getResult(), argRanges.front());
1321 }
1322 
1323 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1324  Value source) {
1325  auto vectorTy = cast<VectorType>(source.getType());
1326  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1327 }
1328 
1329 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1330  Value source, int64_t position) {
1331  build(builder, result, source, ArrayRef<int64_t>{position});
1332 }
1333 
1334 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1335  Value source, OpFoldResult position) {
1336  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1337 }
1338 
1339 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1340  Value source, ArrayRef<int64_t> position) {
1341  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1342  builder.getDenseI64ArrayAttr(position));
1343 }
1344 
1345 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1346  Value source, ArrayRef<OpFoldResult> position) {
1347  SmallVector<int64_t> staticPos;
1348  SmallVector<Value> dynamicPos;
1349  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1350  build(builder, result, source, dynamicPos,
1351  builder.getDenseI64ArrayAttr(staticPos));
1352 }
1353 
1354 LogicalResult
1355 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1356  ExtractOp::Adaptor adaptor,
1357  SmallVectorImpl<Type> &inferredReturnTypes) {
1358  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1359  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1360  vectorType.getRank()) {
1361  inferredReturnTypes.push_back(vectorType.getElementType());
1362  } else {
1363  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1364  vectorType.getRank());
1365  inferredReturnTypes.push_back(VectorType::get(
1366  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1367  vectorType.getScalableDims().drop_front(n)));
1368  }
1369  return success();
1370 }
1371 
1372 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1373  // Allow extracting 1-element vectors instead of scalars.
1374  auto isCompatible = [](TypeRange l, TypeRange r) {
1375  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1376  return vectorType && vectorType.getShape().equals({1}) &&
1377  vectorType.getElementType() == r.front();
1378  };
1379  if (l.size() == 1 && r.size() == 1 &&
1380  (isCompatible(l, r) || isCompatible(r, l)))
1381  return true;
1382  return l == r;
1383 }
1384 
1385 LogicalResult vector::ExtractOp::verify() {
1386  // Note: This check must come before getMixedPosition() to prevent a crash.
1387  auto dynamicMarkersCount =
1388  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1389  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1390  return emitOpError(
1391  "mismatch between dynamic and static positions (kDynamic marker but no "
1392  "corresponding dynamic position) -- this can only happen due to an "
1393  "incorrect fold/rewrite");
1394  auto position = getMixedPosition();
1395  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1396  return emitOpError(
1397  "expected position attribute of rank no greater than vector rank");
1398  for (auto [idx, pos] : llvm::enumerate(position)) {
1399  if (auto attr = dyn_cast<Attribute>(pos)) {
1400  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1401  if (!isValidPositiveIndexOrPoison(
1402  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1403  return emitOpError("expected position attribute #")
1404  << (idx + 1)
1405  << " to be a non-negative integer smaller than the "
1406  "corresponding vector dimension or poison (-1)";
1407  }
1408  }
1409  }
1410  return success();
1411 }
1412 
1413 template <typename IntType>
1414 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1415  return llvm::to_vector<4>(llvm::map_range(
1416  arrayAttr.getAsRange<IntegerAttr>(),
1417  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1418 }
1419 
1420 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1421 /// positions.
1422 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1423  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1424  return failure();
1425 
1426  // TODO: Canonicalization for dynamic position not implemented yet.
1427  if (extractOp.hasDynamicPosition())
1428  return failure();
1429 
1430  SmallVector<int64_t> globalPosition;
1431  ExtractOp currentOp = extractOp;
1432  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1433  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1434  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1435  currentOp = nextOp;
1436  // TODO: Canonicalization for dynamic position not implemented yet.
1437  if (currentOp.hasDynamicPosition())
1438  return failure();
1439  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1440  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1441  }
1442  extractOp.setOperand(0, currentOp.getVector());
1443  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1444  OpBuilder b(extractOp.getContext());
1445  std::reverse(globalPosition.begin(), globalPosition.end());
1446  extractOp.setStaticPosition(globalPosition);
1447  return success();
1448 }
1449 
1450 namespace {
1451 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1452 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1453 /// Compose TransposeOp permutations as we walk back.
1454 /// This helper class keeps an updated extraction position `extractPosition`
1455 /// with extra trailing sentinels.
1456 /// The sentinels encode the internal transposition status of the result vector.
1457 /// As we iterate, extractPosition is permuted and updated.
1458 class ExtractFromInsertTransposeChainState {
1459 public:
1460  ExtractFromInsertTransposeChainState(ExtractOp e);
1461 
1462  /// Iterate over producing insert and transpose ops until we find a fold.
1463  Value fold();
1464 
1465 private:
1466  /// Return true if the vector at position `a` is contained within the vector
1467  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1468  /// is a prefix of `b`.
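 /// For example, a = [1, 2] is contained within b = [1, 2, 3], while
 /// a = [1, 3] is not.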
1469  template <typename ContainerA, typename ContainerB>
1470  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1471  return a.size() <= b.size() &&
1472  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1473  }
1474 
1475  /// Return true if the vector at position `a` intersects the vector at
1476  /// position `b`. Under insert/extract semantics, this is the same as equality
1477  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1478  /// Comparison is on the common prefix (i.e. zip).
1479  template <typename ContainerA, typename ContainerB>
1480  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1481  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1482  if (elemA < 0 || elemB < 0)
1483  continue;
1484  if (elemA != elemB)
1485  return false;
1486  }
1487  return true;
1488  }
1489 
1490  /// Folding is only possible in the absence of an internal permutation in the
1491  /// result vector.
1492  bool canFold() {
1493  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1494  }
1495 
1496  // Helper to get the next defining op of interest.
1497  void updateStateForNextIteration(Value v) {
1498  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1499  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1500  };
1501 
1502  // Case 1. If we hit a transpose, just compose the map and iterate.
1503  // Invariant: insert + transpose do not change rank, we can always compose.
1504  LogicalResult handleTransposeOp();
1505 
1506  // Case 2: the insert position matches extractPosition exactly, early return.
1507  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1508 
1509  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1510  /// portion of the source of the insert.
1511  /// Example:
1512  /// ```
1513  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1514  /// // extractPosition == [1, 2, 3]
1515  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1516  /// // can fold to vector.extract %source[0, 3]
1517  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1518  /// ```
1519  /// To traverse through %source, we need to set the leading dims to 0 and
1520  /// drop the extra leading dims.
1521  /// This method updates the internal state.
1522  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1523 
1524  /// Try to fold in place to extract(source, extractPosition) and return the
1525  /// folded result. Return null if folding is not possible (e.g. due to an
1526  /// internal transposition in the result).
1527  Value tryToFoldExtractOpInPlace(Value source);
1528 
1529  ExtractOp extractOp;
1530  int64_t vectorRank;
1531  int64_t extractedRank;
1532 
1533  InsertOp nextInsertOp;
1534  TransposeOp nextTransposeOp;
1535 
1536  /// Sentinel values that encode the internal permutation status of the result.
1537  /// They are set to (-1, ... , -k) at the beginning and appended to
1538  /// `extractPosition`.
1539  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1540  /// ensure that there is no internal transposition.
1541  /// Internal transposition cannot be accounted for with a folding pattern.
1542  // TODO: We could relax the internal transposition with an extra transposition
1543  // operation in a future canonicalizer.
1544  SmallVector<int64_t> sentinels;
1545  SmallVector<int64_t> extractPosition;
1546 };
1547 } // namespace
1548 
1549 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1550  ExtractOp e)
1551  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1552  extractedRank(extractOp.getNumIndices()) {
1553  assert(vectorRank >= extractedRank && "Extracted position overflow");
1554  sentinels.reserve(vectorRank - extractedRank);
1555  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1556  sentinels.push_back(-(i + 1));
1557  extractPosition.assign(extractOp.getStaticPosition().begin(),
1558  extractOp.getStaticPosition().end());
1559  llvm::append_range(extractPosition, sentinels);
1560 }
1561 
1562 // Case 1. If we hit a transpose, just compose the map and iterate.
1563 // Invariant: insert + transpose do not change rank, we can always compose.
1564 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1565  // TODO: Canonicalization for dynamic position not implemented yet.
1566  if (extractOp.hasDynamicPosition())
1567  return failure();
1568 
1569  if (!nextTransposeOp)
1570  return failure();
1571  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1572  nextTransposeOp.getPermutation(), extractOp.getContext()));
1573  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1574  return success();
1575 }
1576 
1577 // Case 2: the insert position matches extractPosition exactly, early return.
1578 LogicalResult
1579 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1580  Value &res) {
1581  // TODO: Canonicalization for dynamic position not implemented yet.
1582  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1583  return failure();
1584 
1585  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1586  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1587  return failure();
1588  // Case 2.a. early-exit fold.
1589  res = nextInsertOp.getValueToStore();
1590  // Case 2.b. if internal transposition is present, canFold will be false.
1591  return success(canFold());
1592 }
1593 
1594 /// Case 3: if inserted position is a prefix of extractPosition,
1595 /// extract a portion of the source of the insertion.
1596 /// This method updates the internal state.
1597 LogicalResult
1598 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1599  // TODO: Canonicalization for dynamic position not implemented yet.
1600  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1601  return failure();
1602 
1603  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1604  if (!isContainedWithin(insertedPos, extractPosition))
1605  return failure();
1606  // Set leading dims to zero.
1607  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1608  // Drop extra leading dims.
1609  extractPosition.erase(extractPosition.begin(),
1610  extractPosition.begin() + insertedPos.size());
1611  extractedRank = extractPosition.size() - sentinels.size();
1612  // Case 3.a. early-exit fold (break and delegate to post-while path).
1613  res = nextInsertOp.getValueToStore();
1614  // Case 3.b. if internal transposition is present, canFold will be false.
1615  return success();
1616 }
1617 
1618 /// Try to fold in place to extract(source, extractPosition) and return the
1619 /// folded result. Return null if folding is not possible (e.g. due to an
1620 /// internal transposition in the result).
1621 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1622  Value source) {
1623  // TODO: Canonicalization for dynamic position not implemented yet.
1624  if (extractOp.hasDynamicPosition())
1625  return Value();
1626 
1627  // If we can't fold (either internal transposition, or nothing to fold), bail.
1628  bool nothingToFold = (source == extractOp.getVector());
1629  if (nothingToFold || !canFold())
1630  return Value();
1631 
1632  // Otherwise, fold by updating the op inplace and return its result.
1633  OpBuilder b(extractOp.getContext());
1634  extractOp.setStaticPosition(
1635  ArrayRef(extractPosition).take_front(extractedRank));
1636  extractOp.getVectorMutable().assign(source);
1637  return extractOp.getResult();
1638 }
1639 
1640 /// Iterate over producing insert and transpose ops until we find a fold.
1641 Value ExtractFromInsertTransposeChainState::fold() {
1642  // TODO: Canonicalization for dynamic position not implemented yet.
1643  if (extractOp.hasDynamicPosition())
1644  return Value();
1645 
1646  Value valueToExtractFrom = extractOp.getVector();
1647  updateStateForNextIteration(valueToExtractFrom);
1648  while (nextInsertOp || nextTransposeOp) {
1649  // Case 1. If we hit a transpose, just compose the map and iterate.
1650  // Invariant: insert + transpose do not change rank, we can always compose.
1651  if (succeeded(handleTransposeOp())) {
1652  valueToExtractFrom = nextTransposeOp.getVector();
1653  updateStateForNextIteration(valueToExtractFrom);
1654  continue;
1655  }
1656 
1657  Value result;
1658  // Case 2: the positions match exactly.
1659  if (succeeded(handleInsertOpWithMatchingPos(result)))
1660  return result;
1661 
1662  // Case 3: if the inserted position is a prefix of extractPosition, we can
1663  // just extract a portion of the source of the insert.
1664  if (succeeded(handleInsertOpWithPrefixPos(result)))
1665  return tryToFoldExtractOpInPlace(result);
1666 
1667  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1668  // values. This is a more difficult case and we bail.
1669  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1670  if (isContainedWithin(extractPosition, insertedPos) ||
1671  intersectsWhereNonNegative(extractPosition, insertedPos))
1672  return Value();
1673 
1674  // Case 5: No intersection, we forward the extract to insertOp.dest().
1675  valueToExtractFrom = nextInsertOp.getDest();
1676  updateStateForNextIteration(valueToExtractFrom);
1677  }
1678  // If after all this we can fold, go for it.
1679  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1680 }
1681 
1682 /// Returns true if the operation has a 0-D vector type operand or result.
1683 static bool hasZeroDimVectors(Operation *op) {
1684  auto hasZeroDimVectorType = [](Type type) -> bool {
1685  auto vecType = dyn_cast<VectorType>(type);
1686  return vecType && vecType.getRank() == 0;
1687  };
1688 
1689  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1690  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1691 }
1692 
1693 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
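// For illustration, a hypothetical case where the broadcast source is a
// scalar (names and shapes are made up):
//
//   %b = vector.broadcast %s : f32 to vector<4x8xf32>
//   %e = vector.extract %b[1, 3] : f32 from vector<4x8xf32>
//
// folds directly to %s.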
1694 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1695  Operation *defOp = extractOp.getVector().getDefiningOp();
1696  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1697  return Value();
1698 
1699  Value source = defOp->getOperand(0);
1700  if (extractOp.getType() == source.getType())
1701  return source;
1702  auto getRank = [](Type type) {
1703  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1704  : 0;
1705  };
1706 
1707  // If splat or broadcast from a scalar, just return the source scalar.
1708  unsigned broadcastSrcRank = getRank(source.getType());
1709  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1710  return source;
1711 
1712  unsigned extractResultRank = getRank(extractOp.getType());
1713  if (extractResultRank > broadcastSrcRank)
1714  return Value();
1715  // Check that the dimensions of the result haven't been broadcasted.
1716  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1717  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1718  if (extractVecType && broadcastVecType &&
1719  extractVecType.getShape() !=
1720  broadcastVecType.getShape().take_back(extractResultRank))
1721  return Value();
1722 
1723  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1724  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1725 
1726  // Detect all the positions that come from "dim-1" broadcasting.
1727  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1728  // extract position to `0` when extracting from the source operand.
1729  llvm::SetVector<int64_t> broadcastedUnitDims =
1730  broadcastOp.computeBroadcastedUnitDims();
1731  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1732  OpBuilder b(extractOp.getContext());
1733  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1734  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1735  if (broadcastedUnitDims.contains(i))
1736  extractPos[i] = b.getIndexAttr(0);
1737  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1738  // matching extract position when extracting from the source operand.
1739  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1740  extractPos.erase(extractPos.begin(),
1741  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1742  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1743  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1744  extractOp->setOperands(
1745  llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1746  extractOp.setStaticPosition(staticPos);
1747  return extractOp.getResult();
1748 }
1749 
1750 /// Fold extractOp coming from ShuffleOp.
1751 ///
1752 /// Example:
1753 ///
1754 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1755 /// : vector<8xf32>, vector<8xf32>
1756 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1757 /// ->
1758 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1759 ///
1760 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1761  // Dynamic positions are not folded as the resulting code would be more
1762  // complex than the input code.
1763  if (extractOp.hasDynamicPosition())
1764  return Value();
1765 
1766  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1767  if (!shuffleOp)
1768  return Value();
1769 
1770  // TODO: 0-D or multi-dimensional vectors not supported yet.
1771  if (shuffleOp.getResultVectorType().getRank() != 1)
1772  return Value();
1773 
1774  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1775  auto shuffleMask = shuffleOp.getMask();
1776  int64_t extractIdx = extractOp.getStaticPosition()[0];
1777  int64_t shuffleIdx = shuffleMask[extractIdx];
1778 
1779  // Find the shuffled vector to extract from based on the shuffle index.
1780  if (shuffleIdx < inputVecSize) {
1781  extractOp.setOperand(0, shuffleOp.getV1());
1782  extractOp.setStaticPosition({shuffleIdx});
1783  } else {
1784  extractOp.setOperand(0, shuffleOp.getV2());
1785  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1786  }
1787 
1788  return extractOp.getResult();
1789 }
1790 
1791 // Fold extractOp with source coming from ShapeCast op.
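// For illustration, a hypothetical fold (shapes are made up):
//
//   %sc = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
//   %e  = vector.extract %sc[1, 2] : f32 from vector<4x4xf32>
//
// The position [1, 2] linearizes to 1 * 4 + 2 = 6 in the cast result, which
// delinearizes to [0, 6] on the 2x8 source, so the op is updated in place to:
//
//   %e = vector.extract %src[0, 6] : f32 from vector<2x8xf32>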
1792 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1793  // TODO: Canonicalization for dynamic position not implemented yet.
1794  if (extractOp.hasDynamicPosition())
1795  return Value();
1796 
1797  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1798  if (!shapeCastOp)
1799  return Value();
1800 
1801  // Get the nth dimension size starting from lowest dimension.
1802  auto getDimReverse = [](VectorType type, int64_t n) {
1803  return type.getShape().take_back(n + 1).front();
1804  };
1805  int64_t destinationRank =
1806  llvm::isa<VectorType>(extractOp.getType())
1807  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1808  : 0;
1809  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1810  return Value();
1811  if (destinationRank > 0) {
1812  auto destinationType =
1813  llvm::cast<VectorType>(extractOp.getResult().getType());
1814  for (int64_t i = 0; i < destinationRank; i++) {
1815  // The lowest dimension of the destination must match the lowest
1816  // dimension of the shapecast op source.
1817  // TODO: This case could be supported in a canonicalization pattern.
1818  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1819  getDimReverse(destinationType, i))
1820  return Value();
1821  }
1822  }
1823  // Extract the strides associated with the extract op vector source. Then use
1824  // this to calculate a linearized position for the extract.
1825  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1826  std::reverse(extractedPos.begin(), extractedPos.end());
1827  SmallVector<int64_t, 4> strides;
1828  int64_t stride = 1;
1829  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1830  strides.push_back(stride);
1831  stride *=
1832  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1833  }
1834 
1835  int64_t position = linearize(extractedPos, strides);
1836  // Then extract the strides associated to the shapeCast op vector source and
1837  // delinearize the position using those strides.
1838  SmallVector<int64_t, 4> newStrides;
1839  int64_t numDimension =
1840  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1841  stride = 1;
1842  for (int64_t i = 0; i < numDimension; i++) {
1843  newStrides.push_back(stride);
1844  stride *=
1845  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1846  }
1847  std::reverse(newStrides.begin(), newStrides.end());
1848  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1849  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1850  OpBuilder b(extractOp.getContext());
1851  extractOp.setStaticPosition(newPosition);
1852  extractOp.setOperand(0, shapeCastOp.getSource());
1853  return extractOp.getResult();
1854 }
1855 
1856 /// Fold an ExtractOp from ExtractStridedSliceOp.
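// For illustration, a hypothetical fold (shapes and offsets are made up):
//
//   %s = vector.extract_strided_slice %src
//          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
//          : vector<4x8xf32> to vector<2x4xf32>
//   %e = vector.extract %s[0, 1] : f32 from vector<2x4xf32>
//
// folds in place to:
//
//   %e = vector.extract %src[1, 3] : f32 from vector<4x8xf32>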
1857 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1858  // TODO: Canonicalization for dynamic position not implemented yet.
1859  if (extractOp.hasDynamicPosition())
1860  return Value();
1861 
1862  auto extractStridedSliceOp =
1863  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1864  if (!extractStridedSliceOp)
1865  return Value();
1866 
1867  // 0-D vectors not supported.
1868  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1869  if (hasZeroDimVectors(extractStridedSliceOp))
1870  return Value();
1871 
1872  // Return if 'extractStridedSliceOp' has non-unit strides.
1873  if (extractStridedSliceOp.hasNonUnitStrides())
1874  return Value();
1875 
1876  // Trim offsets for dimensions fully extracted.
1877  auto sliceOffsets =
1878  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1879  while (!sliceOffsets.empty()) {
1880  size_t lastOffset = sliceOffsets.size() - 1;
1881  if (sliceOffsets.back() != 0 ||
1882  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1883  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1884  break;
1885  sliceOffsets.pop_back();
1886  }
1887  unsigned destinationRank = 0;
1888  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1889  destinationRank = vecType.getRank();
1890  // The dimensions of the result need to be untouched by the
1891  // extractStridedSlice op.
1892  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1893  sliceOffsets.size())
1894  return Value();
1895 
1896  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1897  assert(extractedPos.size() >= sliceOffsets.size());
1898  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1899  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1900  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1901 
1902  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1903  OpBuilder b(extractOp.getContext());
1904  extractOp.setStaticPosition(extractedPos);
1905  return extractOp.getResult();
1906 }
1907 
1908 /// Fold extract_op fed from a chain of insertStridedSlice ops.
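// For illustration, a hypothetical fold (shapes and offsets are made up):
//
//   %i = vector.insert_strided_slice %small, %big
//          {offsets = [2, 2], strides = [1, 1]}
//          : vector<2x2xf32> into vector<4x4xf32>
//   %e = vector.extract %i[3, 2] : f32 from vector<4x4xf32>
//
// The extract position falls inside the inserted chunk, so this folds to:
//
//   %e = vector.extract %small[1, 0] : f32 from vector<2x2xf32>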
1909 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1910  // TODO: Canonicalization for dynamic position not implemented yet.
1911  if (extractOp.hasDynamicPosition())
1912  return Value();
1913 
1914  int64_t destinationRank =
1915  llvm::isa<VectorType>(extractOp.getType())
1916  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1917  : 0;
1918  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1919  if (!insertOp)
1920  return Value();
1921 
1922  // 0-D vectors not supported.
1923  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1924  if (hasZeroDimVectors(insertOp))
1925  return Value();
1926 
1927  while (insertOp) {
1928  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1929  insertOp.getSourceVectorType().getRank();
1930  if (destinationRank > insertOp.getSourceVectorType().getRank())
1931  return Value();
1932  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1933  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1934 
1935  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1936  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1937  }))
1938  return Value();
1939  bool disjoint = false;
1940  SmallVector<int64_t, 4> offsetDiffs;
1941  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1942  int64_t start = insertOffsets[dim];
1943  int64_t size =
1944  (dim < insertRankDiff)
1945  ? 1
1946  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1947  int64_t end = start + size;
1948  int64_t offset = extractOffsets[dim];
1949  // Check if the start of the extract offset is in the interval inserted.
1950  if (start <= offset && offset < end) {
1951  if (dim >= insertRankDiff)
1952  offsetDiffs.push_back(offset - start);
1953  continue;
1954  }
1955  disjoint = true;
1956  break;
1957  }
1958  // The extracted element chunk overlaps with the vector inserted.
1959  if (!disjoint) {
1960  // If any of the inner dimensions are only partially inserted we have a
1961  // partial overlap.
1962  int64_t srcRankDiff =
1963  insertOp.getSourceVectorType().getRank() - destinationRank;
1964  for (int64_t i = 0; i < destinationRank; i++) {
1965  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1966  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1967  insertRankDiff))
1968  return Value();
1969  }
1970  extractOp.getVectorMutable().assign(insertOp.getValueToStore());
1971  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1972  OpBuilder b(extractOp.getContext());
1973  extractOp.setStaticPosition(offsetDiffs);
1974  return extractOp.getResult();
1975  }
1976  // If the chunk extracted is disjoint from the chunk inserted, keep
1977  // looking in the insert chain.
1978  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1979  }
1980  return Value();
1981 }
1982 
1983 /// Try to fold the extraction of a scalar from a vector defined by
1984 /// vector.from_elements. E.g.:
1985 ///
1986 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1987 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1988 /// ==> fold to %a
1989 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1990  // Dynamic extractions cannot be folded.
1991  if (extractOp.hasDynamicPosition())
1992  return {};
1993 
1994  // Look for extract(from_elements).
1995  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1996  if (!fromElementsOp)
1997  return {};
1998 
1999  // Scalable vectors are not supported.
2000  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2001  if (vecType.isScalable())
2002  return {};
2003 
2004  // Only extractions of scalars are supported.
2005  int64_t rank = vecType.getRank();
2006  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2007  if (extractOp.getType() != vecType.getElementType())
2008  return {};
2009  assert(static_cast<int64_t>(indices.size()) == rank &&
2010  "unexpected number of indices");
2011 
2012  // Compute flattened/linearized index and fold to operand.
2013  int flatIndex = 0;
2014  int stride = 1;
2015  for (int i = rank - 1; i >= 0; --i) {
2016  flatIndex += indices[i] * stride;
2017  stride *= vecType.getDimSize(i);
2018  }
2019  return fromElementsOp.getElements()[flatIndex];
2020 }
2021 
2022 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2023 /// then fold it.
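// For illustration, a hypothetical fold (names are made up):
//
//   %c1 = arith.constant 1 : index
//   %e  = vector.extract %v[%c1, %i] : f32 from vector<4x8xf32>
//
// The constant index is folded into the static position while the truly
// dynamic index %i is kept as an operand:
//
//   %e = vector.extract %v[1, %i] : f32 from vector<4x8xf32>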
2024 template <typename OpType, typename AdaptorType>
2025 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2026  SmallVectorImpl<Value> &operands) {
2027  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2028  OperandRange dynamicPosition = op.getDynamicPosition();
2029  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2030  ArrayRef<int64_t> vectorShape;
2031  if constexpr (std::is_same_v<OpType, ExtractOp>)
2032  vectorShape = op.getSourceVectorType().getShape();
2033  else
2034  vectorShape = op.getDestVectorType().getShape();
2035 
2036  // If there are no dynamic position operands, there is nothing to fold.
2037  if (!dynamicPosition.size())
2038  return {};
2039 
2040  // `index` is used to iterate over the `dynamicPosition`.
2041  unsigned index = 0;
2042 
2043  // `opChange` is a flag. If it is true, it means to update `op` in place.
2044  bool opChange = false;
2045  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2046  if (!ShapedType::isDynamic(staticPosition[i]))
2047  continue;
2048  Attribute positionAttr = dynamicPositionAttr[index];
2049  Value position = dynamicPosition[index++];
2050  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2051  int64_t value = attr.getInt();
2052  // Do not fold if the value is out of bounds (-1 signifies a poison
2053  // value rather than OOB index).
2054  if (value >= -1 && value < vectorShape[i]) {
2055  staticPosition[i] = attr.getInt();
2056  opChange = true;
2057  continue;
2058  }
2059  }
2060  operands.push_back(position);
2061  }
2062 
2063  if (opChange) {
2064  op.setStaticPosition(staticPosition);
2065  op.getOperation()->setOperands(operands);
2066  // Return the original result to indicate an in-place folding happened.
2067  return op.getResult();
2068  }
2069  return {};
2070 }
2071 
2072 /// Fold an insert or extract operation into a poison value when a poison index
2073 /// is found at any dimension of the static position.
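// For illustration, a hypothetical extract with a poison index:
//
//   %e = vector.extract %v[-1, 0] : f32 from vector<4x8xf32>
//
// has -1 in its static position and therefore folds to a ub.poison value.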
2074 static OpFoldResult foldPoisonIndexInsertExtractOp(MLIRContext *context,
2075  ArrayRef<int64_t> staticPos,
2076  int64_t poisonVal) {
2077  if (!is_contained(staticPos, poisonVal))
2078  return {};
2079 
2080  return ub::PoisonAttr::get(context);
2081 }
2082 
2083 /// Fold a vector extract whose source is a poison value.
2084 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2085  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2086  return srcAttr;
2087 
2088  return {};
2089 }
2090 
2091 /// Fold a vector extract extracting from a DenseElementsAttr.
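// For illustration, a hypothetical constant source (values are made up):
//
//   %cst = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : vector<2x2xf32>
//   %e   = vector.extract %cst[1] : vector<2xf32> from vector<2x2xf32>
//
// folds to the constant dense<[2.0, 3.0]> : vector<2xf32>.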
2092 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2093  Attribute srcAttr) {
2094  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2095  if (!denseAttr) {
2096  return {};
2097  }
2098 
2099  if (denseAttr.isSplat()) {
2100  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2101  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2102  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2103  return newAttr;
2104  }
2105 
2106  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2107  if (vecTy.isScalable())
2108  return {};
2109 
2110  if (extractOp.hasDynamicPosition()) {
2111  return {};
2112  }
2113 
2114  // Materializing subsets of a large constant array can generally lead to
2115  // explosion in IR size because of different combination of subsets that
2116  // can exist. However, vector.extract is a restricted form of subset
2117  // extract where you can only extract non-overlapping (or the same) subset for
2118  // a given rank of the subset. Because of this property, the IR size can only
2119  // increase at most by `rank * size(array)` from a single constant array being
2120  // extracted by multiple extracts.
2121 
2122  // Calculate the linearized position of the continuous chunk of elements to
2123  // extract.
2124  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2125  copy(extractOp.getStaticPosition(), completePositions.begin());
2126  int64_t startPos =
2127  linearize(completePositions, computeStrides(vecTy.getShape()));
2128  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2129 
2130  TypedAttr newAttr;
2131  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2132  SmallVector<Attribute> elementValues(
2133  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2134  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2135  } else {
2136  newAttr = *denseValuesBegin;
2137  }
2138 
2139  return newAttr;
2140 }
2141 
2142 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2143  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2144  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2145  // mismatch).
2146  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2147  return getVector();
2148  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2149  return res;
2150  // Fold `arith.constant` indices into the `vector.extract` operation.
2151  // Do not stop here as this fold may enable subsequent folds that require
2152  // constant indices.
2153  SmallVector<Value> operands = {getVector()};
2154  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2155 
2156  if (auto res = foldPoisonIndexInsertExtractOp(
2157  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2158  return res;
2159  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
2160  return res;
2161  if (succeeded(foldExtractOpFromExtractChain(*this)))
2162  return getResult();
2163  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2164  return res;
2165  if (auto res = foldExtractFromBroadcast(*this))
2166  return res;
2167  if (auto res = foldExtractFromShuffle(*this))
2168  return res;
2169  if (auto res = foldExtractFromShapeCast(*this))
2170  return res;
2171  if (auto val = foldExtractFromExtractStrided(*this))
2172  return val;
2173  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2174  return val;
2175  if (auto val = foldScalarExtractFromFromElements(*this))
2176  return val;
2177 
2178  return inplaceFolded;
2179 }
2180 
2181 namespace {
2182 
2183 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
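// For illustration, a hypothetical rewrite (shapes are made up):
//
//   %b = vector.broadcast %src : vector<8xf32> to vector<2x4x8xf32>
//   %e = vector.extract %b[1] : vector<4x8xf32> from vector<2x4x8xf32>
//
// becomes:
//
//   %e = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>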
2184 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2185 public:
2186  using OpRewritePattern::OpRewritePattern;
2187 
2188  LogicalResult matchAndRewrite(ExtractOp extractOp,
2189  PatternRewriter &rewriter) const override {
2190  Operation *defOp = extractOp.getVector().getDefiningOp();
2191  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2192  return failure();
2193 
2194  Value source = defOp->getOperand(0);
2195  if (extractOp.getType() == source.getType())
2196  return failure();
2197  auto getRank = [](Type type) {
2198  return llvm::isa<VectorType>(type)
2199  ? llvm::cast<VectorType>(type).getRank()
2200  : 0;
2201  };
2202  unsigned broadcastSrcRank = getRank(source.getType());
2203  unsigned extractResultRank = getRank(extractOp.getType());
2204  // We only consider the case where the rank of the source is less than or
2205  // equal to the rank of the extract dst. The other cases are handled in the
2206  // folding patterns.
2207  if (extractResultRank < broadcastSrcRank)
2208  return failure();
2209  // For scalar result, the input can only be a rank-0 vector, which will
2210  // be handled by the folder.
2211  if (extractResultRank == 0)
2212  return failure();
2213 
2214  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2215  extractOp, extractOp.getType(), source);
2216  return success();
2217  }
2218 };
2219 
2220 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
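// For illustration, a hypothetical rewrite (shapes and bounds are made up):
//
//   %c2 = arith.constant 2 : index
//   %m  = vector.create_mask %c2, %d : vector<4x8xi1>
//   %e  = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
//
// Row 1 lies within the first two masked rows, so this becomes:
//
//   %e = vector.create_mask %d : vector<8xi1>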
2221 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2222 public:
2223  using OpRewritePattern::OpRewritePattern;
2224 
2225  LogicalResult matchAndRewrite(ExtractOp extractOp,
2226  PatternRewriter &rewriter) const override {
2227  auto createMaskOp =
2228  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2229  if (!createMaskOp)
2230  return failure();
2231 
2232  VectorType extractedMaskType =
2233  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2234 
2235  if (!extractedMaskType)
2236  return failure();
2237 
2238  auto maskOperands = createMaskOp.getOperands();
2239  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2240  VectorType maskType = createMaskOp.getVectorType();
2241 
2242  bool containsUnknownDims = false;
2243  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2244 
2245  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2246  dimIdx++) {
2247  int64_t pos = extractOpPos[dimIdx];
2248  Value operand = maskOperands[dimIdx];
2249  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2250  if (!constantOp) {
2251  // Bounds of this dim unknown.
2252  containsUnknownDims = true;
2253  continue;
2254  }
2255 
2256  int64_t createMaskBound =
2257  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2258 
2259  if (pos != ShapedType::kDynamic) {
2260  // If any position is outside the range from the `create_mask`, then the
2261  // extracted mask will be all-false.
2262  allFalse |= pos >= createMaskBound;
2263  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2264  // This dim is not all-true and since this is a dynamic index we don't
2265  // know if the extraction is within the true or false region.
2266  // Note: Zero dims have already been handled via getMaskFormat().
2267  containsUnknownDims = true;
2268  }
2269  }
2270 
2271  if (allFalse) {
2272  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2273  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2274  } else if (!containsUnknownDims) {
2275  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2276  extractOp, extractedMaskType,
2277  maskOperands.drop_front(extractOpPos.size()));
2278  } else {
2279  return failure();
2280  }
2281  return success();
2282  }
2283 };
2284 
2285 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2286 // does not change.
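// For illustration, a hypothetical rewrite (shapes are made up):
//
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<1x8xf32>
//   %e  = vector.extract %sc[0] : vector<8xf32> from vector<1x8xf32>
//
// becomes:
//
//   %e = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>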
2287 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2288  PatternRewriter &rewriter) {
2289  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2290  if (!castOp)
2291  return failure();
2292 
2293  VectorType sourceType = castOp.getSourceVectorType();
2294  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2295  if (!targetType)
2296  return failure();
2297 
2298  if (sourceType.getNumElements() != targetType.getNumElements())
2299  return failure();
2300 
2301  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2302  castOp.getSource());
2303  return success();
2304 }
2305 
2306 /// Try to canonicalize the extraction of a subvector from a vector defined by
2307 /// vector.from_elements. E.g.:
2308 ///
2309 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2310 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2311 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2312 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2313  PatternRewriter &rewriter) {
2314  // Dynamic positions are not supported.
2315  if (extractOp.hasDynamicPosition())
2316  return failure();
2317 
2318  // Scalar extracts are handled by the folder.
2319  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2320  if (!resultType)
2321  return failure();
2322 
2323  // Look for extracts from a from_elements op.
2324  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2325  if (!fromElementsOp)
2326  return failure();
2327  VectorType inputType = fromElementsOp.getType();
2328 
2329  // Scalable vectors are not supported.
2330  if (resultType.isScalable() || inputType.isScalable())
2331  return failure();
2332 
2333  // Compute the position of first extracted element and flatten/linearize the
2334  // position.
2335  SmallVector<int64_t> firstElementPos =
2336  llvm::to_vector(extractOp.getStaticPosition());
2337  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2338  int flatIndex = 0;
2339  int stride = 1;
2340  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2341  flatIndex += firstElementPos[i] * stride;
2342  stride *= inputType.getDimSize(i);
2343  }
2344 
2345  // Replace the op with a smaller from_elements op.
2346  rewriter.replaceOpWithNewOp<FromElementsOp>(
2347  extractOp, resultType,
2348  fromElementsOp.getElements().slice(flatIndex,
2349  resultType.getNumElements()));
2350  return success();
2351 }
2352 
2353 } // namespace
2354 
2355 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2356  MLIRContext *context) {
2357  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2358  results.add(foldExtractFromShapeCastToShapeCast);
2359  results.add(foldExtractFromFromElements);
2360 }
2361 
2362 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2363  SmallVectorImpl<int64_t> &results) {
2364  for (auto attr : arrayAttr)
2365  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2366 }
2367 
2368 //===----------------------------------------------------------------------===//
2369 // FmaOp
2370 //===----------------------------------------------------------------------===//
2371 
2372 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2373  return llvm::to_vector<4>(getVectorType().getShape());
2374 }
2375 
2376 //===----------------------------------------------------------------------===//
2377 // ToElementsOp
2378 //===----------------------------------------------------------------------===//
2379 
2380 /// Returns true if all the `operands` are defined by `defOp`.
2381 /// Otherwise, returns false.
2382 static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2383  if (operands.empty())
2384  return false;
2385 
2386  return llvm::all_of(operands, [&](Value operand) {
2387  Operation *currentDef = operand.getDefiningOp();
2388  return currentDef == defOp;
2389  });
2390 }
2391 
2392 /// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2393 /// (%e0, %e1, ...). For example:
2394 ///
2395 /// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2396 /// %1:3 = vector.to_elements %0 : vector<3xf32>
2397 /// user_op %1#0, %1#1, %1#2
2398 ///
2399 /// becomes:
2400 ///
2401 /// user_op %a, %b, %c
2402 ///
2403 static LogicalResult
2404 foldToElementsFromElements(ToElementsOp toElementsOp,
2405  SmallVectorImpl<OpFoldResult> &results) {
2406  auto fromElementsOp =
2407  toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2408  if (!fromElementsOp)
2409  return failure();
2410 
2411  llvm::append_range(results, fromElementsOp.getElements());
2412  return success();
2413 }
2414 
2415 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2416  SmallVectorImpl<OpFoldResult> &results) {
2417  return foldToElementsFromElements(*this, results);
2418 }
2419 
2420 //===----------------------------------------------------------------------===//
2421 // FromElementsOp
2422 //===----------------------------------------------------------------------===//
2423 
2424 /// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2425 ///
2426 /// Case #1: Input and output vectors are the same.
2427 ///
2428 /// %0:3 = vector.to_elements %a : vector<3xf32>
2429 /// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2430 /// user_op %1
2431 ///
2432 /// becomes:
2433 ///
2434 /// user_op %a
2435 ///
2436 static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2437  OperandRange fromElemsOperands = fromElementsOp.getElements();
2438  if (fromElemsOperands.empty())
2439  return {};
2440 
2441  auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2442  if (!toElementsOp)
2443  return {};
2444 
2445  if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2446  return {};
2447 
2448  // Case #1: Input and output vectors are the same. Forward the input vector.
2449  Value toElementsInput = toElementsOp.getSource();
2450  if (fromElementsOp.getType() == toElementsInput.getType() &&
2451  llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2452  return toElementsInput;
2453  }
2454 
2455  // TODO: Support cases with different input and output shapes and different
2456  // number of elements.
2457 
2458  return {};
2459 }
2460 
2461 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2462  return foldFromElementsToElements(*this);
2463 }
2464 
2465 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2466 /// same SSA value. E.g.:
2467 ///
2468 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2469 /// ==> rewrite to vector.splat %a : vector<3xf32>
2470 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2471  PatternRewriter &rewriter) {
2472  if (!llvm::all_equal(fromElementsOp.getElements()))
2473  return failure();
2474  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2475  fromElementsOp.getElements().front());
2476  return success();
2477 }
2478 
2479 /// Rewrite from_elements on multiple scalar extracts as a shape_cast
2480 /// on a single extract. Example:
2481 /// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2482 /// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2483 /// %2 = vector.from_elements %0, %1 : vector<2xi8>
2484 ///
2485 /// becomes
2486 /// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2487 /// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2488 ///
2489 /// The requirements for this to be valid are
2490 ///
2491 /// i) The elements are extracted from the same vector (%source).
2492 ///
2493 /// ii) The elements form a suffix of %source. Specifically, the number
2494 /// of elements is the same as the product of the last N dimension sizes
2495 /// of %source, for some N.
2496 ///
2497 /// iii) The elements are extracted contiguously in ascending order.
2498 
2499 class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2500 
2501  using OpRewritePattern::OpRewritePattern;
2502 
2503  LogicalResult matchAndRewrite(FromElementsOp fromElements,
2504  PatternRewriter &rewriter) const override {
2505 
2506  // Handled by `rewriteFromElementsAsSplat`
2507  if (fromElements.getType().getNumElements() == 1)
2508  return failure();
2509 
2510  // The common source that all elements are extracted from, if one exists.
2511  TypedValue<VectorType> source;
2512  // The position of the combined extract operation, if one is created.
2513  ArrayRef<int64_t> combinedPosition;
2514  // The expected index of extraction of the current element in the loop, if
2515  // elements are extracted contiguously in ascending order.
2516  SmallVector<int64_t> expectedPosition;
2517 
2518  for (auto [insertIndex, element] :
2519  llvm::enumerate(fromElements.getElements())) {
2520 
2521  // Check that the element is from a vector.extract operation.
2522  auto extractOp =
2523  dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
2524  if (!extractOp) {
2525  return rewriter.notifyMatchFailure(fromElements,
2526  "element not from vector.extract");
2527  }
2528 
2529  // Check condition (i) by checking that all elements have the same source
2530  // as the first element.
2531  if (insertIndex == 0) {
2532  source = extractOp.getVector();
2533  } else if (extractOp.getVector() != source) {
2534  return rewriter.notifyMatchFailure(fromElements,
2535  "element from different vector");
2536  }
2537 
2538  ArrayRef<int64_t> position = extractOp.getStaticPosition();
2539  int64_t rank = position.size();
2540  assert(rank == source.getType().getRank() &&
2541  "scalar extract must have full rank position");
2542 
2543  // Check condition (ii) by checking that the position that the first
2544  // element is extracted from has sufficient trailing 0s. For example, in
2545  //
2546  // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2547  // [...]
2548  // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2549  //
2550  // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2551  // elements, which is the number of elements of %elms, so this is valid.
2552  if (insertIndex == 0) {
2553  const int64_t numElms = fromElements.getType().getNumElements();
2554  int64_t numSuffixElms = 1;
2555  int64_t index = rank;
2556  while (index > 0 && position[index - 1] == 0 &&
2557  numSuffixElms < numElms) {
2558  numSuffixElms *= source.getType().getDimSize(index - 1);
2559  --index;
2560  }
2561  if (numSuffixElms != numElms) {
2562  return rewriter.notifyMatchFailure(
2563  fromElements, "elements do not form a suffix of source");
2564  }
2565  expectedPosition = llvm::to_vector(position);
2566  combinedPosition = position.drop_back(rank - index);
2567  }
2568 
2569  // Check condition (iii).
2570  else if (expectedPosition != position) {
2571  return rewriter.notifyMatchFailure(
2572  fromElements, "elements not in ascending order (static order)");
2573  }
2574  increment(expectedPosition, source.getType().getShape());
2575  }
2576 
2577  auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2578  fromElements.getLoc(), source, combinedPosition);
2579 
2580  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2581  fromElements, fromElements.getType(), extracted);
2582 
2583  return success();
2584  }
2585 
2586  /// Increments n-D `indices` by 1 starting from the innermost dimension.
2587  static void increment(MutableArrayRef<int64_t> indices,
2588  ArrayRef<int64_t> shape) {
2589  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2590  indices[dim] += 1;
2591  if (indices[dim] < shape[dim])
2592  break;
2593  indices[dim] = 0;
2594  }
2595  }
2596 };
2597 
2598 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2599  MLIRContext *context) {
2600  results.add(rewriteFromElementsAsSplat);
2601  results.add<FromElementsToShapeCast>(context);
2602 }
2603 
2604 //===----------------------------------------------------------------------===//
2605 // BroadcastOp
2606 //===----------------------------------------------------------------------===//
2607 
2608 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2609  SetIntRangeFn setResultRanges) {
2610  setResultRanges(getResult(), argRanges.front());
2611 }
2612 
2613 std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2614  return llvm::to_vector<4>(getResultVectorType().getShape());
2615 }
2616 
2617 /// Return the dimensions of the result vector that were formerly ones in the
2618 /// source tensor and thus correspond to "dim-1" broadcasting.
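// For illustration (hypothetical shapes): broadcasting vector<1x4xf32> to
// vector<2x3x4xf32> stretches the source unit dim aligned with result dim 1
// from 1 to 3, so this returns {1}; the new leading dim 0 is not reported.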
2619 static llvm::SetVector<int64_t>
2620 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2621  ArrayRef<int64_t> dstShape) {
2622  int64_t rankDiff = dstShape.size() - srcShape.size();
2623  int64_t dstDim = rankDiff;
2624  llvm::SetVector<int64_t> res;
2625  for (auto [s1, s2] :
2626  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2627  if (s1 != s2) {
2628  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2629  res.insert(dstDim);
2630  }
2631  ++dstDim;
2632  }
2633  return res;
2634 }
2635 
2636 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2637  // Scalar broadcast is without any unit dim broadcast.
2638  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2639  if (!srcVectorType)
2640  return {};
2641  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2642  getResultVectorType().getShape());
2643 }
2644 
2645 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2646 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2647 /// This requires (and asserts) that the broadcast is free of "dim-1"
2648 /// broadcasting.
2649 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2650 /// vector.transpose may be inserted to make the broadcast possible.
2651 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2652 /// the helper will assert. This means:
2653 /// 1. `dstShape` must not be empty.
2654 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2655 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2656 ///    must match the `value` shape.
2657 Value BroadcastOp::createOrFoldBroadcastOp(
2658  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2659  const llvm::SetVector<int64_t> &broadcastedDims) {
2660  assert(!dstShape.empty() && "unexpected empty dst shape");
2661 
2662  // Well-formedness check.
2663  SmallVector<int64_t> checkShape;
2664  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2665  if (broadcastedDims.contains(i))
2666  continue;
2667  checkShape.push_back(dstShape[i]);
2668  }
2669  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2670  "ill-formed broadcastedDims contains values not confined to "
2671  "destVectorShape");
2672 
2673  Location loc = value.getLoc();
2674  Type elementType = getElementTypeOrSelf(value.getType());
2675  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2676  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2677 
2678  // Step 2. If scalar -> dstShape broadcast, just do it.
2679  if (!srcVectorType) {
2680  assert(checkShape.empty() &&
2681  "ill-formed createOrFoldBroadcastOp arguments");
2682  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2683  }
2684 
2685  assert(srcVectorType.getShape().equals(checkShape) &&
2686  "ill-formed createOrFoldBroadcastOp arguments");
2687 
2688  // Step 3. Since vector.broadcast only allows creating leading dims,
2689  // vector -> dstShape broadcast may require a transpose.
2690  // Traverse the dims in order and construct:
2691  // 1. The leading entries of the broadcastShape that is guaranteed to be
2692  // achievable by a simple broadcast.
2693  // 2. The induced permutation for the subsequent vector.transpose that will
2694  // bring us from `broadcastShape` back to the desired `dstShape`.
2695  // If the induced permutation is not the identity, create a vector.transpose.
2696  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2697  broadcastShape.reserve(dstShape.size());
2698  // Consider the example:
2699  // srcShape = 2x4
2700  // dstShape = 1x2x3x4x5
2701  // broadcastedDims = [0, 2, 4]
2702  //
2703  // We want to build:
2704  // broadcastShape = 1x3x5x2x4
2705  // permutation = [0, 2, 4, 1, 3]
2706  // ---V--- -----V-----
2707  // leading broadcast part src shape part
2708  //
2709  // Note that the trailing dims of broadcastShape are exactly the srcShape
2710  // by construction.
2711  // nextSrcShapeDim is used to keep track of where in the permutation the
2712  // "src shape part" occurs.
2713  int64_t nextSrcShapeDim = broadcastedDims.size();
2714  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2715  if (broadcastedDims.contains(i)) {
2716  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2717  // bring it to the head of the broadcastShape.
2718  // It will need to be permuted back from `broadcastShape.size() - 1` into
2719  // position `i`.
2720  broadcastShape.push_back(dstShape[i]);
2721  permutation[i] = broadcastShape.size() - 1;
2722  } else {
2723  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2724  // shape and needs to be permuted into position `i`.
2725  // Don't touch `broadcastShape` here, the whole srcShape will be
2726  // appended after.
2727  permutation[i] = nextSrcShapeDim++;
2728  }
2729  }
2730  // 3.c. Append the srcShape.
2731  llvm::append_range(broadcastShape, srcVectorType.getShape());
2732 
2733  // Ensure there are no "dim-1" broadcasts.
2734  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2735  .empty() &&
2736  "unexpected \"dim-1\" broadcast");
2737 
2738  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2739  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2740  vector::BroadcastableToResult::Success &&
2741  "must be broadcastable");
2742  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2743  // Step 4. If we find any dimension that indeed needs to be permuted,
2744  // immediately return a new vector.transpose.
2745  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2746  if (permutation[i] != i)
2747  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2748  // Otherwise return res.
2749  return res;
2750 }
2751 
2752 BroadcastableToResult mlir::vector::isBroadcastableTo(
2753  Type srcType, VectorType dstVectorType,
2754  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2755  // Broadcast scalar to vector of the same element type.
2756  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2757  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2758  return BroadcastableToResult::Success;
2759  // From now on, only vectors broadcast.
2760  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2761  if (!srcVectorType)
2762  return BroadcastableToResult::SourceTypeNotAVector;
2763 
2764  int64_t srcRank = srcVectorType.getRank();
2765  int64_t dstRank = dstVectorType.getRank();
2766  if (srcRank > dstRank)
2767  return BroadcastableToResult::SourceRankHigher;
2768  // Source has an exact match or singleton value for all trailing dimensions
2769  // (all leading dimensions are simply duplicated).
2770  int64_t lead = dstRank - srcRank;
2771  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2772  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2773  // encountered?
2774  bool foundMismatchingDims = false;
2775 
2776  // Check fixed-width dims.
2777  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2778  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2779  if (srcDim != 1 && srcDim != dstDim)
2780  foundMismatchingDims = true;
2781 
2782  // Check scalable flags.
2783  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2784  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2785  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2786  // 1 -> [N] is fine, everything else should be rejected when mixing
2787  // fixed-width and scalable dims
2788  (srcDimScalableFlag != dstDimScalableFlag &&
2789  (srcDim != 1 || srcDimScalableFlag)))
2790  foundMismatchingDims = true;
2791 
2792  if (foundMismatchingDims) {
2793  if (mismatchingDims != nullptr) {
2794  mismatchingDims->first.dim = srcDim;
2795  mismatchingDims->first.isScalable = srcDimScalableFlag;
2796 
2797  mismatchingDims->second.dim = dstDim;
2798  mismatchingDims->second.isScalable = dstDimScalableFlag;
2799  }
2800  return BroadcastableToResult::DimensionMismatch;
2801  }
2802  }
2803 
2804  return BroadcastableToResult::Success;
2805 }
2806 
2807 LogicalResult BroadcastOp::verify() {
2808  std::pair<VectorDim, VectorDim> mismatchingDims;
2809  BroadcastableToResult res = isBroadcastableTo(
2810  getSourceType(), getResultVectorType(), &mismatchingDims);
2811  if (res == BroadcastableToResult::Success)
2812  return success();
2813  if (res == BroadcastableToResult::SourceRankHigher)
2814  return emitOpError("source rank higher than destination rank");
2815  if (res == BroadcastableToResult::DimensionMismatch) {
2816  return emitOpError("dimension mismatch (")
2817  << (mismatchingDims.first.isScalable ? "[" : "")
2818  << mismatchingDims.first.dim
2819  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2820  << (mismatchingDims.second.isScalable ? "[" : "")
2821  << mismatchingDims.second.dim
2822  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2823  }
2824  if (res == BroadcastableToResult::SourceTypeNotAVector)
2825  return emitOpError("source type is not a vector");
2826  llvm_unreachable("unexpected vector.broadcast op error");
2827 }
2828 
2829 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2830  if (getSourceType() == getResultVectorType())
2831  return getSource();
2832  if (!adaptor.getSource())
2833  return {};
2834  auto vectorType = getResultVectorType();
2835  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2836  if (vectorType.getElementType() != attr.getType())
2837  return {};
2838  return DenseElementsAttr::get(vectorType, attr);
2839  }
2840  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2841  if (vectorType.getElementType() != attr.getType())
2842  return {};
2843  return DenseElementsAttr::get(vectorType, attr);
2844  }
2845  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2846  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2847  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
2848  return ub::PoisonAttr::get(getContext());
2849  return {};
2850 }
2851 
2852 namespace {
2853 
2854 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
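// For illustration, a hypothetical rewrite (shapes are made up):
//
//   %0 = vector.broadcast %x : f32 to vector<4xf32>
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<2x4xf32>
//
// becomes:
//
//   %1 = vector.broadcast %x : f32 to vector<2x4xf32>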
2855 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2856  using OpRewritePattern::OpRewritePattern;
2857 
2858  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2859  PatternRewriter &rewriter) const override {
2860  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2861  if (!srcBroadcast)
2862  return failure();
2863  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2864  broadcastOp.getResultVectorType(),
2865  srcBroadcast.getSource());
2866  return success();
2867  }
2868 };
2869 } // namespace
2870 
2871 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2872  MLIRContext *context) {
2873  // BroadcastToShapeCast is not a default canonicalization; it is opt-in and
2874  // enabled by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2875  results.add<BroadcastFolder>(context);
2876 }
2877 
2878 //===----------------------------------------------------------------------===//
2879 // ShuffleOp
2880 //===----------------------------------------------------------------------===//
2881 
2882 LogicalResult ShuffleOp::verify() {
2883  VectorType resultType = getResultVectorType();
2884  VectorType v1Type = getV1VectorType();
2885  VectorType v2Type = getV2VectorType();
2886  // Verify ranks.
2887  int64_t resRank = resultType.getRank();
2888  int64_t v1Rank = v1Type.getRank();
2889  int64_t v2Rank = v2Type.getRank();
2890  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2891  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2892  if (!wellFormed0DCase && !wellFormedNDCase)
2893  return emitOpError("rank mismatch");
2894 
2895  // Verify all but leading dimension sizes.
2896  for (int64_t r = 1; r < v1Rank; ++r) {
2897  int64_t resDim = resultType.getDimSize(r);
2898  int64_t v1Dim = v1Type.getDimSize(r);
2899  int64_t v2Dim = v2Type.getDimSize(r);
2900  if (resDim != v1Dim || v1Dim != v2Dim)
2901  return emitOpError("dimension mismatch");
2902  }
2903  // Verify mask length.
2904  ArrayRef<int64_t> mask = getMask();
2905  int64_t maskLength = mask.size();
2906  if (maskLength <= 0)
2907  return emitOpError("invalid mask length");
2908  if (maskLength != resultType.getDimSize(0))
2909  return emitOpError("mask length mismatch");
2910  // Verify all indices.
2911  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2912  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2913  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2914  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2915  return emitOpError("mask index #") << (idx + 1) << " out of range";
2916  }
2917  return success();
2918 }
2919 
2920 LogicalResult
2921 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2922  ShuffleOp::Adaptor adaptor,
2923  SmallVectorImpl<Type> &inferredReturnTypes) {
2924  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2925  auto v1Rank = v1Type.getRank();
2926  // Construct resulting type: leading dimension matches mask
2927  // length, all trailing dimensions match the operands.
2928  SmallVector<int64_t, 4> shape;
2929  shape.reserve(v1Rank);
2930  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2931  // In the 0-D case there is no trailing shape to append.
2932  if (v1Rank > 0)
2933  llvm::append_range(shape, v1Type.getShape().drop_front());
2934  inferredReturnTypes.push_back(
2935  VectorType::get(shape, v1Type.getElementType()));
2936  return success();
2937 }
2938 
2939 template <typename T>
2940 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2941  T expected = begin;
2942  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2943  return value == expected++;
2944  });
2945 }
2946 
2947 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2948  auto v1Type = getV1VectorType();
2949  auto v2Type = getV2VectorType();
2950 
2951  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2952  "Vector shuffle does not support scalable vectors");
2953 
2954  // For consistency, the 0-D shuffle result type is 1-D; this cannot be handled
2955  // as a fold and must instead be canonicalized into a vector.broadcast.
2956  if (v1Type.getRank() == 0)
2957  return {};
2958 
2959  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2960  auto mask = getMask();
2961  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
2962  return getV1();
2963  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2964  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
2965  return getV2();
2966 
2967  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2968  if (!v1Attr || !v2Attr)
2969  return {};
2970 
2971  // Fold shuffle poison, poison -> poison.
2972  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2973  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2974  if (isV1Poison && isV2Poison)
2975  return ub::PoisonAttr::get(getContext());
2976 
2977  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2978  // manipulation.
2979  if (v1Type.getRank() != 1)
2980  return {};
2981 
2982  // Poison input attributes need special handling as they are not
2983  // DenseElementsAttr. If an index is poison, we select the first element of
2984  // the first non-poison input.
2985  SmallVector<Attribute> v1Elements, v2Elements;
2986  Attribute poisonElement;
2987  if (!isV2Poison) {
2988  v2Elements =
2989  to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2990  poisonElement = v2Elements[0];
2991  }
2992  if (!isV1Poison) {
2993  v1Elements =
2994  to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2995  poisonElement = v1Elements[0];
2996  }
2997 
2998  SmallVector<Attribute> results;
2999  int64_t v1Size = v1Type.getDimSize(0);
3000  for (int64_t maskIdx : mask) {
3001  Attribute indexedElm;
3002  // TODO: Return a partial poison vector when supported by the UB dialect.
3003  if (maskIdx == ShuffleOp::kPoisonIndex) {
3004  indexedElm = poisonElement;
3005  } else {
3006  if (maskIdx < v1Size)
3007  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3008  else
3009  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3010  }
3011 
3012  results.push_back(indexedElm);
3013  }
3014 
3015  return DenseElementsAttr::get(getResultVectorType(), results);
3016 }
3017 
3018 namespace {
3019 
3020 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3021 // to a broadcast.
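// Illustrative IR (added for exposition; values are hypothetical):
//   %r = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
// is rewritten to:
//   %r = vector.broadcast %a : vector<f32> to vector<1xf32>
// A mask of [1] selects and broadcasts %b instead.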
3022 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3024 
3025  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3026  PatternRewriter &rewriter) const override {
3027  VectorType v1VectorType = shuffleOp.getV1VectorType();
3028  ArrayRef<int64_t> mask = shuffleOp.getMask();
3029  if (v1VectorType.getRank() > 0)
3030  return failure();
3031  if (mask.size() != 1)
3032  return failure();
3033  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3034  if (mask[0] == 0)
3035  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3036  shuffleOp.getV1());
3037  else
3038  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3039  shuffleOp.getV2());
3040  return success();
3041  }
3042 };
3043 
3044 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3045 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3046 public:
3048 
3049  LogicalResult matchAndRewrite(ShuffleOp op,
3050  PatternRewriter &rewriter) const override {
3051  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
3052  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
3053 
3054  if (!v1Splat || !v2Splat)
3055  return failure();
3056 
3057  if (v1Splat.getInput() != v2Splat.getInput())
3058  return failure();
3059 
3060  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
3061  return success();
3062  }
3063 };
3064 
3065 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3066 /// vector.interleave.
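/// Illustrative IR (added for exposition; values are hypothetical and the
/// printed form of vector.interleave is approximate):
///   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///        : vector<4xi32>, vector<4xi32>
/// is rewritten to:
///   %0 = vector.interleave %a, %b : vector<4xi32> -> vector<8xi32>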
3067 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3068 public:
3070 
3071  LogicalResult matchAndRewrite(ShuffleOp op,
3072  PatternRewriter &rewriter) const override {
3073  VectorType resultType = op.getResultVectorType();
3074  if (resultType.isScalable())
3075  return rewriter.notifyMatchFailure(
3076  op, "ShuffleOp can't represent a scalable interleave");
3077 
3078  if (resultType.getRank() != 1)
3079  return rewriter.notifyMatchFailure(
3080  op, "ShuffleOp can't represent an n-D interleave");
3081 
3082  VectorType sourceType = op.getV1VectorType();
3083  if (sourceType != op.getV2VectorType() ||
3084  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3085  return rewriter.notifyMatchFailure(
3086  op, "ShuffleOp types don't match an interleave");
3087  }
3088 
3089  ArrayRef<int64_t> shuffleMask = op.getMask();
3090  int64_t resultVectorSize = resultType.getNumElements();
3091  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3092  int64_t maskValueA = shuffleMask[i * 2];
3093  int64_t maskValueB = shuffleMask[(i * 2) + 1];
3094  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3095  return rewriter.notifyMatchFailure(op,
3096  "ShuffleOp mask not interleaving");
3097  }
3098 
3099  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3100  return success();
3101  }
3102 };
3103 
3104 } // namespace
3105 
3106 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3107  MLIRContext *context) {
3108  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3109  context);
3110 }
3111 
3112 //===----------------------------------------------------------------------===//
3113 // InsertElementOp
3114 //===----------------------------------------------------------------------===//
3115 
3116 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3117  SetIntRangeFn setResultRanges) {
3118  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3119 }
3120 
3121 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
3122  Value source, Value dest) {
3123  build(builder, result, source, dest, {});
3124 }
3125 
3126 LogicalResult InsertElementOp::verify() {
3127  auto dstVectorType = getDestVectorType();
3128  if (dstVectorType.getRank() == 0) {
3129  if (getPosition())
3130  return emitOpError("expected position to be empty with 0-D vector");
3131  return success();
3132  }
3133  if (dstVectorType.getRank() != 1)
3134  return emitOpError("unexpected >1 vector rank");
3135  if (!getPosition())
3136  return emitOpError("expected position for 1-D vector");
3137  return success();
3138 }
3139 
3140 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
3141  // Skip the 0-D vector here.
3142  if (!adaptor.getPosition())
3143  return {};
3144 
3145  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
3146  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
3147  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
3148  if (!src || !dst || !pos)
3149  return {};
3150 
3151  if (src.getType() != getDestVectorType().getElementType())
3152  return {};
3153 
3154  auto dstElements = dst.getValues<Attribute>();
3155 
3156  SmallVector<Attribute> results(dstElements);
3157 
3158  uint64_t posIdx = pos.getInt();
3159  if (posIdx >= results.size())
3160  return {};
3161  results[posIdx] = src;
3162 
3163  return DenseElementsAttr::get(getDestVectorType(), results);
3164 }
3165 
3166 //===----------------------------------------------------------------------===//
3167 // InsertOp
3168 //===----------------------------------------------------------------------===//
3169 
3170 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3171  SetIntRangeFn setResultRanges) {
3172  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3173 }
3174 
3175 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3176  Value source, Value dest) {
3177  auto vectorTy = cast<VectorType>(dest.getType());
3178  build(builder, result, source, dest,
3179  SmallVector<int64_t>(vectorTy.getRank(), 0));
3180 }
3181 
3182 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3183  Value source, Value dest, int64_t position) {
3184  build(builder, result, source, dest, ArrayRef<int64_t>{position});
3185 }
3186 
3187 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3188  Value source, Value dest, OpFoldResult position) {
3189  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3190 }
3191 
3192 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3193  Value source, Value dest,
3194  ArrayRef<int64_t> position) {
3195  SmallVector<OpFoldResult> posVals;
3196  posVals.reserve(position.size());
3197  llvm::transform(position, std::back_inserter(posVals),
3198  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3199  build(builder, result, source, dest, posVals);
3200 }
3201 
3202 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3203  Value source, Value dest,
3204  ArrayRef<OpFoldResult> position) {
3205  SmallVector<int64_t> staticPos;
3206  SmallVector<Value> dynamicPos;
3207  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3208  build(builder, result, source, dest, dynamicPos,
3209  builder.getDenseI64ArrayAttr(staticPos));
3210 }
3211 
3212 LogicalResult InsertOp::verify() {
3213  SmallVector<OpFoldResult> position = getMixedPosition();
3214  auto destVectorType = getDestVectorType();
3215  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3216  return emitOpError(
3217  "expected position attribute of rank no greater than dest vector rank");
3218  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3219  if (srcVectorType &&
3220  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3221  static_cast<unsigned>(destVectorType.getRank())))
3222  return emitOpError("expected position attribute rank + source rank to "
3223  "match dest vector rank");
3224  if (!srcVectorType &&
3225  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3226  return emitOpError(
3227  "expected position attribute rank to match the dest vector rank");
3228  for (auto [idx, pos] : llvm::enumerate(position)) {
3229  if (auto attr = dyn_cast<Attribute>(pos)) {
3230  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3231  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3232  destVectorType.getDimSize(idx))) {
3233  return emitOpError("expected position attribute #")
3234  << (idx + 1)
3235  << " to be a non-negative integer smaller than the "
3236  "corresponding "
3237  "dest vector dimension";
3238  }
3239  }
3240  }
3241  return success();
3242 }
3243 
3244 namespace {
3245 
3246 // If insertOp is only inserting unit dimensions it can be transformed to a
3247 // broadcast.
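// Illustrative IR (added for exposition; values are hypothetical). The element
// counts match, so only unit dimensions are being inserted:
//   %1 = vector.insert %v, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// is rewritten to:
//   %1 = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>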
3248 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3249 public:
3251 
3252  LogicalResult matchAndRewrite(InsertOp insertOp,
3253  PatternRewriter &rewriter) const override {
3254  auto srcVecType =
3255  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3256  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3257  srcVecType.getNumElements())
3258  return failure();
3259  rewriter.replaceOpWithNewOp<BroadcastOp>(
3260  insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3261  return success();
3262  }
3263 };
3264 
3265 /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
3266 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3267 public:
3269 
3270  LogicalResult matchAndRewrite(InsertOp op,
3271  PatternRewriter &rewriter) const override {
3272  auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
3273  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3274 
3275  if (!srcSplat || !dstSplat)
3276  return failure();
3277 
3278  if (srcSplat.getInput() != dstSplat.getInput())
3279  return failure();
3280 
3281  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3282  return success();
3283  }
3284 };
3285 
3286 } // namespace
3287 
3288 static Attribute
3289 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3290  Attribute dstAttr,
3291  int64_t maxVectorSizeFoldThreshold) {
3292  if (insertOp.hasDynamicPosition())
3293  return {};
3294 
3295  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3296  if (!denseDst)
3297  return {};
3298 
3299  if (!srcAttr) {
3300  return {};
3301  }
3302 
3303  VectorType destTy = insertOp.getDestVectorType();
3304  if (destTy.isScalable())
3305  return {};
3306 
3307  // Make sure we do not create too many large constants.
3308  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3309  !insertOp->hasOneUse())
3310  return {};
3311 
3312  // Calculate the linearized position of the contiguous chunk of elements to
3313  // insert.
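  // Worked example (added for exposition; values are hypothetical): inserting a
  // vector<4xf32> at position [1, 2] into a vector<2x3x4xf32> gives
  // completePositions = [1, 2, 0], strides = [12, 4, 1], and
  // insertBeginPosition = 1 * 12 + 2 * 4 + 0 * 1 = 20.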
3314  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3315  copy(insertOp.getStaticPosition(), completePositions.begin());
3316  int64_t insertBeginPosition =
3317  linearize(completePositions, computeStrides(destTy.getShape()));
3318 
3319  SmallVector<Attribute> insertedValues;
3320  Type destEltType = destTy.getElementType();
3321 
3322  /// Converts the attribute to an IntegerAttr of the expected type if there's
3323  /// a type mismatch.
3324  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
3325  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3326  if (intAttr.getType() != expectedType)
3327  return IntegerAttr::get(expectedType, intAttr.getInt());
3328  }
3329  return attr;
3330  };
3331 
3332  // The `convertIntegerAttr` method specifically handles the case
3333  // for `llvm.mlir.constant` which can hold an attribute with a
3334  // different type than the return type.
3335  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3336  for (auto value : denseSource.getValues<Attribute>())
3337  insertedValues.push_back(convertIntegerAttr(value, destEltType));
3338  } else {
3339  insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
3340  }
3341 
3342  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3343  copy(insertedValues, allValues.begin() + insertBeginPosition);
3344  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3345 
3346  return newAttr;
3347 }
3348 
3349 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3350  MLIRContext *context) {
3351  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3352 }
3353 
3354 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3355  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3356  // unless the source vector constant has a single use.
3357  constexpr int64_t vectorSizeFoldThreshold = 256;
3358  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3359  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3360  // (type mismatch).
3361  if (getNumIndices() == 0 && getValueToStoreType() == getType())
3362  return getValueToStore();
3363  // Fold `arith.constant` indices into the `vector.insert` operation.
3364  // Do not stop here as this fold may enable subsequent folds that require
3365  // constant indices.
3366  SmallVector<Value> operands = {getValueToStore(), getDest()};
3367  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3368 
3369  if (auto res = foldPoisonIndexInsertExtractOp(
3370  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3371  return res;
3372  if (auto res = foldDenseElementsAttrDestInsertOp(
3373  *this, adaptor.getValueToStore(), adaptor.getDest(),
3374  vectorSizeFoldThreshold)) {
3375  return res;
3376  }
3377 
3378  return inplaceFolded;
3379 }
3380 
3381 //===----------------------------------------------------------------------===//
3382 // InsertStridedSliceOp
3383 //===----------------------------------------------------------------------===//
3384 
3385 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3386  Value source, Value dest,
3387  ArrayRef<int64_t> offsets,
3388  ArrayRef<int64_t> strides) {
3389  result.addOperands({source, dest});
3390  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3391  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3392  result.addTypes(dest.getType());
3393  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3394  offsetsAttr);
3395  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3396  stridesAttr);
3397 }
3398 
3399 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3400 template <typename OpType>
3401 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3402  ArrayAttr arrayAttr,
3403  ArrayRef<int64_t> shape,
3404  StringRef attrName) {
3405  if (arrayAttr.size() > shape.size())
3406  return op.emitOpError("expected ")
3407  << attrName << " attribute of rank no greater than vector rank";
3408  return success();
3409 }
3410 
3411 // Checks that all integers in `arrayAttr` are confined to the given interval.
3412 // If `halfOpen` is true then the admissible interval is [min, max).
3413 // Otherwise, the admissible interval is [min, max].
3414 template <typename OpType>
3415 static LogicalResult
3416 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3417  int64_t max, StringRef attrName,
3418  bool halfOpen = true) {
3419  for (auto attr : arrayAttr) {
3420  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3421  auto upper = max;
3422  if (!halfOpen)
3423  upper += 1;
3424  if (val < min || val >= upper)
3425  return op.emitOpError("expected ") << attrName << " to be confined to ["
3426  << min << ", " << upper << ")";
3427  }
3428  return success();
3429 }
3430 
3431 // Checks that, for each dimension i, the integer `arrayAttr[i]` is confined to
3432 // the interval [min, shape[i]). If `halfOpen` is false, the admissible
3433 // interval is [min, shape[i]] instead.
3434 template <typename OpType>
3435 static LogicalResult
3436 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3437  ArrayRef<int64_t> shape, StringRef attrName,
3438  bool halfOpen = true, int64_t min = 0) {
3439  for (auto [index, attrDimPair] :
3440  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3441  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3442  int64_t max = std::get<1>(attrDimPair);
3443  if (!halfOpen)
3444  max += 1;
3445  if (val < min || val >= max)
3446  return op.emitOpError("expected ")
3447  << attrName << " dimension " << index << " to be confined to ["
3448  << min << ", " << max << ")";
3449  }
3450  return success();
3451 }
3452 
3453 // Checks that, for all indices i = 0..shape.size()-1, the sum
3454 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3455 // is confined to the interval with upper bound shape[i].
3456 // If `halfOpen` is true then the admissible interval is [min, shape[i]).
3457 // Otherwise, the admissible interval is [min, shape[i]].
3458 template <typename OpType>
3459 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3460  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3461  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3462  bool halfOpen = true, int64_t min = 1) {
3463  assert(arrayAttr1.size() <= shape.size());
3464  assert(arrayAttr2.size() <= shape.size());
3465  for (auto [index, it] :
3466  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3467  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3468  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3469  int64_t max = std::get<2>(it);
3470  if (!halfOpen)
3471  max += 1;
3472  if (val1 + val2 < 0 || val1 + val2 >= max)
3473  return op.emitOpError("expected sum(")
3474  << attrName1 << ", " << attrName2 << ") dimension " << index
3475  << " to be confined to [" << min << ", " << max << ")";
3476  }
3477  return success();
3478 }
3479 
3480 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3481  MLIRContext *context) {
3482  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3483  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3484  });
3485  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3486 }
3487 
3488 LogicalResult InsertStridedSliceOp::verify() {
3489  auto sourceVectorType = getSourceVectorType();
3490  auto destVectorType = getDestVectorType();
3491  auto offsets = getOffsetsAttr();
3492  auto strides = getStridesAttr();
3493  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3494  return emitOpError(
3495  "expected offsets of same size as destination vector rank");
3496  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3497  return emitOpError("expected strides of same size as source vector rank");
3498  if (sourceVectorType.getRank() > destVectorType.getRank())
3499  return emitOpError(
3500  "expected source rank to be no greater than destination rank");
3501 
3502  auto sourceShape = sourceVectorType.getShape();
3503  auto destShape = destVectorType.getShape();
3504  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3505  destShape.size() - sourceShape.size(), 0);
3506  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3507  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3508  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3509  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3510  offName)) ||
3511  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3512  /*max=*/1, stridesName,
3513  /*halfOpen=*/false)) ||
3514  failed(isSumOfIntegerArrayAttrConfinedToShape(
3515  *this, offsets,
3516  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3517  offName, "source vector shape",
3518  /*halfOpen=*/false, /*min=*/1)))
3519  return failure();
3520 
3521  unsigned rankDiff = destShape.size() - sourceShape.size();
3522  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3523  if (sourceVectorType.getScalableDims()[idx] !=
3524  destVectorType.getScalableDims()[idx + rankDiff]) {
3525  return emitOpError("mismatching scalable flags (at source vector idx=")
3526  << idx << ")";
3527  }
3528  if (sourceVectorType.getScalableDims()[idx]) {
3529  auto sourceSize = sourceShape[idx];
3530  auto destSize = destShape[idx + rankDiff];
3531  if (sourceSize != destSize) {
3532  return emitOpError("expected size at idx=")
3533  << idx
3534  << (" to match the corresponding base size from the input "
3535  "vector (")
3536  << sourceSize << (" vs ") << destSize << (")");
3537  }
3538  }
3539  }
3540 
3541  return success();
3542 }
3543 
3544 namespace {
3545 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3546 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
3547 class FoldInsertStridedSliceSplat final
3548  : public OpRewritePattern<InsertStridedSliceOp> {
3549 public:
3551 
3552  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3553  PatternRewriter &rewriter) const override {
3554  auto srcSplatOp =
3555  insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
3556  auto destSplatOp =
3557  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3558 
3559  if (!srcSplatOp || !destSplatOp)
3560  return failure();
3561 
3562  if (srcSplatOp.getInput() != destSplatOp.getInput())
3563  return failure();
3564 
3565  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3566  return success();
3567  }
3568 };
3569 
3570 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3571 /// to dst.
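/// Illustrative IR (added for exposition; values are hypothetical, and the
/// offsets and strides of the two ops must match):
///   %0 = vector.extract_strided_slice %dst
///        {offsets = [0], sizes = [2], strides = [1]}
///        : vector<4xf32> to vector<2xf32>
///   %1 = vector.insert_strided_slice %0, %dst
///        {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
/// folds to %dst.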
3572 class FoldInsertStridedSliceOfExtract final
3573  : public OpRewritePattern<InsertStridedSliceOp> {
3574 public:
3576 
3577  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3578  PatternRewriter &rewriter) const override {
3579  auto extractStridedSliceOp =
3580  insertStridedSliceOp.getValueToStore()
3581  .getDefiningOp<vector::ExtractStridedSliceOp>();
3582 
3583  if (!extractStridedSliceOp)
3584  return failure();
3585 
3586  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3587  return failure();
3588 
3589  // Check if they have the same strides and offsets.
3590  if (extractStridedSliceOp.getStrides() !=
3591  insertStridedSliceOp.getStrides() ||
3592  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3593  return failure();
3594 
3595  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3596  return success();
3597  }
3598 };
3599 
3600 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3601 // ConstantOp.
3602 class InsertStridedSliceConstantFolder final
3603  : public OpRewritePattern<InsertStridedSliceOp> {
3604 public:
3606 
3607  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3608  // unless the source vector constant has a single use.
3609  static constexpr int64_t vectorSizeFoldThreshold = 256;
3610 
3611  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3612  PatternRewriter &rewriter) const override {
3613  // Return if 'InsertOp' operand is not defined by a compatible vector
3614  // ConstantOp.
3615  TypedValue<VectorType> destVector = op.getDest();
3616  Attribute vectorDestCst;
3617  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3618  return failure();
3619 
3620  VectorType destTy = destVector.getType();
3621  if (destTy.isScalable())
3622  return failure();
3623 
3624  // Make sure we do not create too many large constants.
3625  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3626  !destVector.hasOneUse())
3627  return failure();
3628 
3629  TypedValue<VectorType> sourceValue = op.getValueToStore();
3630  Attribute sourceCst;
3631  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3632  return failure();
3633 
3634  // TODO: Support poison.
3635  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3636  return failure();
3637 
3638  // TODO: Handle non-unit strides when they become available.
3639  if (op.hasNonUnitStrides())
3640  return failure();
3641 
3642  VectorType sliceVecTy = sourceValue.getType();
3643  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3644  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3645  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3646  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3647 
3648  // Calculate the destination element indices by enumerating all slice
3649  // positions within the destination and linearizing them. The enumeration
3650  // order is lexicographic, which yields a sequence of monotonically
3651  // increasing linearized position indices.
3652  // Because the destination may have higher dimensionality than the slice,
3653  // we keep track of two overlapping sets of positions and offsets.
3654  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3655  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3656  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3657  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3658  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3659  MutableArrayRef<int64_t> currSlicePosition(
3660  currDestPosition.begin() + rankDifference, currDestPosition.end());
3661  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3662  offsets.end());
3663  do {
3664  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3665  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3666  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3667  "Invalid slice element");
3668  newValues[linearizedPosition] = *sliceValuesIt;
3669  ++sliceValuesIt;
3670  } while (succeeded(
3671  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3672 
3673  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3674  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3675  return success();
3676  }
3677 };
3678 
3679 } // namespace
3680 
3681 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3682  RewritePatternSet &results, MLIRContext *context) {
3683  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3684  InsertStridedSliceConstantFolder>(context);
3685 }
3686 
3687 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3688  if (getSourceVectorType() == getDestVectorType())
3689  return getValueToStore();
3690  return {};
3691 }
3692 
3693 //===----------------------------------------------------------------------===//
3694 // OuterProductOp
3695 //===----------------------------------------------------------------------===//
3696 
3697 /// Build an op without a mask, using the type of `acc` as the return type.
3698 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3699  Value lhs, Value rhs, Value acc) {
3700  result.addOperands({lhs, rhs, acc});
3701  result.addTypes(acc.getType());
3702 }
3703 
3704 void OuterProductOp::print(OpAsmPrinter &p) {
3705  p << " " << getLhs() << ", " << getRhs();
3706  if (getAcc()) {
3707  p << ", " << getAcc();
3708  p.printOptionalAttrDict((*this)->getAttrs());
3709  }
3710  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3711 }
3712 
3713 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3714  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3715  Type tLHS, tRHS;
3716  if (parser.parseOperandList(operandsInfo) ||
3717  parser.parseOptionalAttrDict(result.attributes) ||
3718  parser.parseColonType(tLHS) || parser.parseComma() ||
3719  parser.parseType(tRHS))
3720  return failure();
3721  if (operandsInfo.size() < 2)
3722  return parser.emitError(parser.getNameLoc(),
3723  "expected at least 2 operands");
3724  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3725  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3726  if (!vLHS)
3727  return parser.emitError(parser.getNameLoc(),
3728  "expected vector type for operand #1");
3729 
3730  VectorType resType;
3731  if (vRHS) {
3732  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3733  vRHS.getScalableDims()[0]};
3734  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3735  vLHS.getElementType(), scalableDimsRes);
3736  } else {
3737  // Scalar RHS operand
3738  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3739  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3740  scalableDimsRes);
3741  }
3742 
3743  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3744  result.attributes.append(
3745  OuterProductOp::getKindAttrName(result.name),
3746  CombiningKindAttr::get(result.getContext(),
3747  OuterProductOp::getDefaultKind()));
3748  }
3749 
3750  return failure(
3751  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3752  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3753  (operandsInfo.size() > 2 &&
3754  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3755  parser.addTypeToList(resType, result.types));
3756 }
3757 
3758 LogicalResult OuterProductOp::verify() {
3759  Type tRHS = getOperandTypeRHS();
3760  VectorType vLHS = getOperandVectorTypeLHS(),
3761  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3762  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3763 
3764  if (vLHS.getRank() != 1)
3765  return emitOpError("expected 1-d vector for operand #1");
3766 
3767  if (vRHS) {
3768  // Proper OUTER operation.
3769  if (vRHS.getRank() != 1)
3770  return emitOpError("expected 1-d vector for operand #2");
3771  if (vRES.getRank() != 2)
3772  return emitOpError("expected 2-d vector result");
3773  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3774  return emitOpError("expected #1 operand dim to match result dim #1");
3775  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3776  return emitOpError("expected #2 operand dim to match result dim #2");
3777  if (vLHS.isScalable() && !vRHS.isScalable()) {
3778  // This restriction reflects what's currently supported in terms of
3779  // scalable vectors. However, we could relax this if there's a use case.
3780  return emitOpError(
3781  "expected either both or only #2 operand dim to be scalable");
3782  }
3783  } else {
3784  // An AXPY operation.
3785  if (vRES.getRank() != 1)
3786  return emitOpError("expected 1-d vector result");
3787  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3788  return emitOpError("expected #1 operand dim to match result dim #1");
3789  }
3790 
3791  if (vACC && vACC != vRES)
3792  return emitOpError("expected operand #3 of same type as result type");
3793 
3794  // Verify supported combining kind.
3795  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3796  return emitOpError("unsupported outerproduct type");
3797 
3798  return success();
3799 }
3800 
3801 // MaskableOpInterface methods.
3802 
3803 /// Returns the mask type expected by this operation. Mostly used for
3804 /// verification purposes. It requires the operation to be vectorized.
3805 Type OuterProductOp::getExpectedMaskType() {
3806  auto vecType = this->getResultVectorType();
3807  return VectorType::get(vecType.getShape(),
3808  IntegerType::get(vecType.getContext(), /*width=*/1),
3809  vecType.getScalableDims());
3810 }
3811 
3812 //===----------------------------------------------------------------------===//
3813 // ExtractStridedSliceOp
3814 //===----------------------------------------------------------------------===//
3815 
3816 // Inference works as follows:
3817 // 1. For the leading dims covered by 'offsets', take the corresponding 'sizes'.
3818 // 2. For the remaining dims, take the sizes from 'vectorType'.
3819 // Scalable flags are inherited from 'vectorType'.
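// For example (added for exposition; values are hypothetical): with
// 'vectorType' = vector<4x8x16xf32>, 'offsets' = [0, 2], 'sizes' = [2, 4] and
// 'strides' = [1, 1], the inferred result type is vector<2x4x16xf32>.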
3820 static Type inferStridedSliceOpResultType(VectorType vectorType,
3821  ArrayAttr offsets, ArrayAttr sizes,
3822  ArrayAttr strides) {
3823  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3824  SmallVector<int64_t, 4> shape;
3825  shape.reserve(vectorType.getRank());
3826  unsigned idx = 0;
3827  for (unsigned e = offsets.size(); idx < e; ++idx)
3828  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3829  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3830  shape.push_back(vectorType.getShape()[idx]);
3831 
3832  return VectorType::get(shape, vectorType.getElementType(),
3833  vectorType.getScalableDims());
3834 }
3835 
3836 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3837  Value source, ArrayRef<int64_t> offsets,
3838  ArrayRef<int64_t> sizes,
3839  ArrayRef<int64_t> strides) {
3840  result.addOperands(source);
3841  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3842  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3843  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3844  result.addTypes(
3845  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3846  offsetsAttr, sizesAttr, stridesAttr));
3847  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3848  offsetsAttr);
3849  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3850  sizesAttr);
3851  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3852  stridesAttr);
3853 }
3854 
3855 LogicalResult ExtractStridedSliceOp::verify() {
3856  auto type = getSourceVectorType();
3857  auto offsets = getOffsetsAttr();
3858  auto sizes = getSizesAttr();
3859  auto strides = getStridesAttr();
3860  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3861  return emitOpError(
3862  "expected offsets, sizes and strides attributes of same size");
3863 
3864  auto shape = type.getShape();
3865  auto offName = getOffsetsAttrName();
3866  auto sizesName = getSizesAttrName();
3867  auto stridesName = getStridesAttrName();
3868  if (failed(
3869  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3870  failed(
3871  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3872  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3873  stridesName)) ||
3874  failed(
3875  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3876  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3877  /*halfOpen=*/false,
3878  /*min=*/1)) ||
3879  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3880  /*max=*/1, stridesName,
3881  /*halfOpen=*/false)) ||
3882  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3883  shape, offName, sizesName,
3884  /*halfOpen=*/false)))
3885  return failure();
3886 
3887  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3888  offsets, sizes, strides);
3889  if (getResult().getType() != resultType)
3890  return emitOpError("expected result type to be ") << resultType;
3891 
3892  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3893  if (type.getScalableDims()[idx]) {
3894  auto inputDim = type.getShape()[idx];
3895  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3896  if (inputDim != inputSize)
3897  return emitOpError("expected size at idx=")
3898  << idx
3899  << (" to match the corresponding base size from the input "
3900  "vector (")
3901  << inputSize << (" vs ") << inputDim << (")");
3902  }
3903  }
3904 
3905  return success();
3906 }
3907 
3908 // When the source of an ExtractStridedSliceOp comes from a chain of
3909 // InsertStridedSliceOps, try to use the source of one of those inserts if we
3910 // can detect that the extracted vector is a subset of an inserted vector.
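// Illustrative IR (added for exposition; values are hypothetical). The
// extracted chunk is fully contained in the inserted chunk:
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//        : vector<2xf32> into vector<8xf32>
//   %1 = vector.extract_strided_slice %0
//        {offsets = [2], sizes = [2], strides = [1]} : vector<8xf32> to vector<2xf32>
// is rewritten so that %1 extracts from %src with offsets [0] instead.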
3911 static LogicalResult
3912 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3913  // Helper to extract integer out of ArrayAttr.
3914  auto getElement = [](ArrayAttr array, int idx) {
3915  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3916  };
3917  ArrayAttr extractOffsets = op.getOffsets();
3918  ArrayAttr extractStrides = op.getStrides();
3919  ArrayAttr extractSizes = op.getSizes();
3920  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3921  while (insertOp) {
3922  if (op.getSourceVectorType().getRank() !=
3923  insertOp.getSourceVectorType().getRank())
3924  return failure();
3925  ArrayAttr insertOffsets = insertOp.getOffsets();
3926  ArrayAttr insertStrides = insertOp.getStrides();
3927  // If the rank of extract is greater than the rank of insert, we are likely
3928  // extracting a partial chunk of the vector inserted.
3929  if (extractOffsets.size() > insertOffsets.size())
3930  return failure();
3931  bool partialOverlap = false;
3932  bool disjoint = false;
3933  SmallVector<int64_t, 4> offsetDiffs;
3934  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3935  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3936  return failure();
3937  int64_t start = getElement(insertOffsets, dim);
3938  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3939  int64_t offset = getElement(extractOffsets, dim);
3940  int64_t size = getElement(extractSizes, dim);
3941  // Check if the start of the extract offset is in the interval inserted.
3942  if (start <= offset && offset < end) {
3943  // If the extract interval overlaps but is not fully included we may
3944  // have a partial overlap that will prevent any folding.
3945  if (offset + size > end)
3946  partialOverlap = true;
3947  offsetDiffs.push_back(offset - start);
3948  continue;
3949  }
3950  disjoint = true;
3951  break;
3952  }
3953  // The extract element chunk is a subset of the insert element.
3954  if (!disjoint && !partialOverlap) {
3955  op.setOperand(insertOp.getValueToStore());
3956  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3957  OpBuilder b(op.getContext());
3958  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3959  return success();
3960  }
3961  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3962  // in the insert chain.
3963  if (disjoint)
3964  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3965  else {
3966  // The extracted vector partially overlap the inserted vector, we cannot
3967  // fold.
3968  return failure();
3969  }
3970  }
3971  return failure();
3972 }
3973 
3974 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3975 static OpFoldResult
3976 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
3977  Attribute foldInput) {
3978 
3979  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
3980  if (!dense)
3981  return {};
3982 
3983  // TODO: Handle non-unit strides when they become available.
3984  if (op.hasNonUnitStrides())
3985  return {};
3986 
3987  VectorType sourceVecTy = op.getSourceVectorType();
3988  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3989  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3990 
3991  VectorType sliceVecTy = op.getType();
3992  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3993  int64_t rank = sliceVecTy.getRank();
3994 
3995  // Expand offsets and sizes to match the vector rank.
3996  SmallVector<int64_t, 4> offsets(rank, 0);
3997  copy(getI64SubArray(op.getOffsets()), offsets.begin());
3998 
3999  SmallVector<int64_t, 4> sizes(sourceShape);
4000  copy(getI64SubArray(op.getSizes()), sizes.begin());
4001 
4002  // Calculate the slice elements by enumerating all slice positions and
4003  // linearizing them. The enumeration order is lexicographic which yields a
4004  // sequence of monotonically increasing linearized position indices.
4005  const auto denseValuesBegin = dense.value_begin<Attribute>();
4006  SmallVector<Attribute> sliceValues;
4007  sliceValues.reserve(sliceVecTy.getNumElements());
4008  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4009  do {
4010  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4011  assert(linearizedPosition < sourceVecTy.getNumElements() &&
4012  "Invalid index");
4013  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4014  } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4015 
4016  assert(static_cast<int64_t>(sliceValues.size()) ==
4017  sliceVecTy.getNumElements() &&
4018  "Invalid number of slice elements");
4019  return DenseElementsAttr::get(sliceVecTy, sliceValues);
4020 }
4021 
4022 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4023  if (getSourceVectorType() == getResult().getType())
4024  return getVector();
4025  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4026  return getResult();
4027 
4028  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4029  if (auto splat =
4030  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
4031  return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4032 
4033  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4034  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
4035 }
4036 
4037 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4038  populateFromInt64AttrArray(getOffsets(), results);
4039 }
4040 
4041 namespace {
4042 
4043 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4044 // ConstantMaskOp.
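// Illustrative IR (added for exposition; values are hypothetical):
//   %mask = vector.constant_mask [4] : vector<8xi1>
//   %0 = vector.extract_strided_slice %mask
//        {offsets = [2], sizes = [4], strides = [1]} : vector<8xi1> to vector<4xi1>
// is rewritten to:
//   %0 = vector.constant_mask [2] : vector<4xi1>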
4045 class StridedSliceConstantMaskFolder final
4046  : public OpRewritePattern<ExtractStridedSliceOp> {
4047 public:
4049 
4050  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4051  PatternRewriter &rewriter) const override {
4052  // Return if 'extractStridedSliceOp' operand is not defined by a
4053  // ConstantMaskOp.
4054  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
4055  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4056  if (!constantMaskOp)
4057  return failure();
4058  // Return if 'extractStridedSliceOp' has non-unit strides.
4059  if (extractStridedSliceOp.hasNonUnitStrides())
4060  return failure();
4061  // Gather constant mask dimension sizes.
4062  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4063  // Gather strided slice offsets and sizes.
4064  SmallVector<int64_t, 4> sliceOffsets;
4065  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4066  sliceOffsets);
4067  SmallVector<int64_t, 4> sliceSizes;
4068  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4069 
4070  // Compute slice of vector mask region.
4071  SmallVector<int64_t, 4> sliceMaskDimSizes;
4072  sliceMaskDimSizes.reserve(maskDimSizes.size());
4073  for (auto [maskDimSize, sliceOffset, sliceSize] :
4074  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4075  int64_t sliceMaskDimSize = std::max(
4076  static_cast<int64_t>(0),
4077  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4078  sliceMaskDimSizes.push_back(sliceMaskDimSize);
4079  }
4080  // Add unchanged dimensions.
4081  if (sliceMaskDimSizes.size() < maskDimSizes.size())
4082  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4083  sliceMaskDimSizes.push_back(maskDimSizes[i]);
4084  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4085  // region is a conjunction of mask dim intervals).
4086  if (llvm::is_contained(sliceMaskDimSizes, 0))
4087  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4088 
4089  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4090  // region.
4091  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4092  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4093  sliceMaskDimSizes);
4094  return success();
4095  }
4096 };
4097 
4098 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4099 // BroadcastOp(ExtractStridedSliceOp).
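// Illustrative IR (added for exposition; values are hypothetical). The inner
// dimension of the broadcast source matches the extracted slice, so only the
// broadcast needs to be rewritten:
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
//   %s = vector.extract_strided_slice %b
//        {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]}
//        : vector<2x4xf32> to vector<1x4xf32>
// is rewritten to:
//   %s = vector.broadcast %src : vector<4xf32> to vector<1x4xf32>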
4100 class StridedSliceBroadcast final
4101  : public OpRewritePattern<ExtractStridedSliceOp> {
4102 public:
4104 
4105  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4106  PatternRewriter &rewriter) const override {
4107  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
4108  if (!broadcast)
4109  return failure();
4110  auto srcVecType =
4111  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4112  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4113  auto dstVecType = llvm::cast<VectorType>(op.getType());
4114  unsigned dstRank = dstVecType.getRank();
4115  unsigned rankDiff = dstRank - srcRank;
4116  // Check if the innermost dimensions of the source of the broadcast are the
4117  // same as the destination of the extract. If this is the case we can just
4118  // use a broadcast as the original dimensions are untouched.
4119  bool lowerDimMatch = true;
4120  for (unsigned i = 0; i < srcRank; i++) {
4121  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4122  lowerDimMatch = false;
4123  break;
4124  }
4125  }
4126  Value source = broadcast.getSource();
4127  // If the inner dimensions don't match, it means we need to extract from the
4128  // source of the original broadcast and then broadcast the extracted value.
4129  // We also need to handle degenerate cases where the source is effectively
4130  // just a single scalar.
4131  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
4132  if (!lowerDimMatch && !isScalarSrc) {
4133  source = rewriter.create<ExtractStridedSliceOp>(
4134  op->getLoc(), source,
4135  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
4136  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
4137  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
4138  }
4139  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4140  return success();
4141  }
4142 };
4143 
4144 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
4145 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4146 public:
4148 
4149  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4150  PatternRewriter &rewriter) const override {
4151  auto splat = op.getVector().getDefiningOp<SplatOp>();
4152  if (!splat)
4153  return failure();
4154  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
4155  return success();
4156  }
4157 };
4158 
4159 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4160 /// slice is contiguous, into extract and shape_cast.
4161 ///
4162 /// Example:
4163 /// Before:
4164 /// %1 = vector.extract_strided_slice %arg0 {
4165 /// offsets = [0, 0, 0, 0, 0],
4166 /// sizes = [1, 1, 1, 1, 8],
4167 /// strides = [1, 1, 1, 1, 1]
4168 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4169 /// After:
4170 /// %0 = vector.extract %arg0[0, 0, 0, 0]
4171 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
4172 /// %1 = vector.shape_cast %0
4173 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
4174 ///
4175 class ContiguousExtractStridedSliceToExtract final
4176  : public OpRewritePattern<ExtractStridedSliceOp> {
4177 public:
4179 
4180  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4181  PatternRewriter &rewriter) const override {
4182  if (op.hasNonUnitStrides())
4183  return failure();
4184  Value source = op.getOperand();
4185  auto sourceType = cast<VectorType>(source.getType());
4186  if (sourceType.isScalable() || sourceType.getRank() == 0)
4187  return failure();
4188 
4189  // Compute the number of offsets to pass to ExtractOp::build. That is the
4190  // difference between the source rank and the desired slice rank. We walk
4191  // the dimensions from innermost out, and stop when the next slice dimension
4192  // is not full-size.
4193  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4194  int numOffsets;
4195  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4196  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4197  break;
4198  }
4199 
4200  // If the created extract op would have no offsets, then this whole
4201  // extract_strided_slice is the identity and should have been handled by
4202  // other canonicalizations.
4203  if (numOffsets == 0)
4204  return failure();
4205 
4206  // If not even the inner-most dimension is full-size, this op can't be
4207  // rewritten as an ExtractOp.
4208  if (numOffsets == sourceType.getRank() &&
4209  static_cast<int>(sizes.size()) == sourceType.getRank())
4210  return failure();
4211 
4212  // The outer dimensions must have unit size.
4213  for (int i = 0; i < numOffsets; ++i) {
4214  if (sizes[i] != 1)
4215  return failure();
4216  }
4217 
4218  // Avoid generating slices that have leading unit dimensions. The shape_cast
4219  // op that we create below would otherwise go through the slow generic
4220  // fallback patterns (ShapeCastOpRewritePattern).
4221  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4222  sizes[numOffsets] == 1) {
4223  ++numOffsets;
4224  }
4225 
4226  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4227  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4228  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
4229  extractOffsets);
4230  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4231  return success();
4232  }
4233 };
4234 
4235 } // namespace
4236 
4237 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4238  RewritePatternSet &results, MLIRContext *context) {
4239  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
4240  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4241  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4242  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4243  context);
4244 }
4245 
4246 //===----------------------------------------------------------------------===//
4247 // TransferReadOp
4248 //===----------------------------------------------------------------------===//
4249 
4250 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4251 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4252  VectorType vectorType, Value source,
4253  ValueRange indices, AffineMapAttr permutationMapAttr,
4254  /*optional*/ ArrayAttr inBoundsAttr) {
4255  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4256  Value padding = builder.create<arith::ConstantOp>(
4257  result.location, elemType, builder.getZeroAttr(elemType));
4258  build(builder, result, vectorType, source, indices, permutationMapAttr,
4259  padding, /*mask=*/Value(), inBoundsAttr);
4260 }
4261 
4262 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4263 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4264  VectorType vectorType, Value source,
4265  ValueRange indices, AffineMap permutationMap,
4266  std::optional<ArrayRef<bool>> inBounds) {
4267  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4268  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4269  ? builder.getBoolArrayAttr(inBounds.value())
4270  : builder.getBoolArrayAttr(
4271  SmallVector<bool>(vectorType.getRank(), false));
4272  build(builder, result, vectorType, source, indices, permutationMapAttr,
4273  inBoundsAttr);
4274 }
4275 
4276 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4277 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4278  VectorType vectorType, Value source,
4279  ValueRange indices, Value padding,
4280  std::optional<ArrayRef<bool>> inBounds) {
4281  AffineMap permutationMap = getTransferMinorIdentityMap(
4282  llvm::cast<ShapedType>(source.getType()), vectorType);
4283  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4284  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4285  ? builder.getBoolArrayAttr(inBounds.value())
4286  : builder.getBoolArrayAttr(
4287  SmallVector<bool>(vectorType.getRank(), false));
4288  build(builder, result, vectorType, source, indices, permutationMapAttr,
4289  padding,
4290  /*mask=*/Value(), inBoundsAttr);
4291 }
4292 
4293 /// 4. Builder that sets padding to zero and permutation map to
4294 /// 'getMinorIdentityMap'.
4295 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4296  VectorType vectorType, Value source,
4297  ValueRange indices,
4298  std::optional<ArrayRef<bool>> inBounds) {
4299  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4300  Value padding = builder.create<arith::ConstantOp>(
4301  result.location, elemType, builder.getZeroAttr(elemType));
4302  build(builder, result, vectorType, source, indices, padding, inBounds);
4303 }
4304 
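/// Checks that 'permutationMap' is a projected permutation, i.e. every result
/// is either a distinct input dim or the zero constant (a broadcast). For
/// example (illustrative), affine_map<(d0, d1, d2) -> (d2, 0)> is accepted,
/// while affine_map<(d0, d1) -> (d0, d0)> (dim used twice) and
/// affine_map<(d0, d1) -> (d0 + d1)> (not a dim or zero) are rejected.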
4305 template <typename EmitFun>
4306 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4307  EmitFun emitOpError) {
4308  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4309  for (auto expr : permutationMap.getResults()) {
4310  auto dim = dyn_cast<AffineDimExpr>(expr);
4311  auto zero = dyn_cast<AffineConstantExpr>(expr);
4312  if (zero) {
4313  if (zero.getValue() != 0) {
4314  return emitOpError(
4315  "requires a projected permutation_map (at most one dim or the zero "
4316  "constant can appear in each result)");
4317  }
4318  continue;
4319  }
4320  if (!dim) {
4321  return emitOpError("requires a projected permutation_map (at most one "
4322  "dim or the zero constant can appear in each result)");
4323  }
4324  if (seen[dim.getPosition()]) {
4325  return emitOpError(
4326  "requires a permutation_map that is a permutation (found one dim "
4327  "used more than once)");
4328  }
4329  seen[dim.getPosition()] = true;
4330  }
4331  return success();
4332 }
4333 
4334 static LogicalResult
4335 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4336  VectorType vectorType, VectorType maskType,
4337  VectorType inferredMaskType, AffineMap permutationMap,
4338  ArrayAttr inBounds) {
4339  if (op->hasAttr("masked")) {
4340  return op->emitOpError("masked attribute has been removed. "
4341  "Use in_bounds instead.");
4342  }
4343 
4344  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4345  return op->emitOpError(
4346  "requires source to be a memref or ranked tensor type");
4347 
4348  auto elementType = shapedType.getElementType();
4349  DataLayout dataLayout = DataLayout::closest(op);
4350  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4351  // Memref or tensor has vector element type.
4352  unsigned sourceVecSize =
4353  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4354  vectorElementType.getShape().back();
4355  unsigned resultVecSize =
4356  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4357  vectorType.getShape().back();
4358  if (resultVecSize % sourceVecSize != 0)
4359  return op->emitOpError(
4360  "requires the bitwidth of the minor 1-D vector to be an integral "
4361  "multiple of the bitwidth of the minor 1-D vector of the source");
4362 
4363  unsigned sourceVecEltRank = vectorElementType.getRank();
4364  unsigned resultVecRank = vectorType.getRank();
4365  if (sourceVecEltRank > resultVecRank)
4366  return op->emitOpError(
4367  "requires source vector element and vector result ranks to match.");
4368  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4369  // Check that permutation map results match 'rankOffset' of vector type.
4370  if (permutationMap.getNumResults() != rankOffset)
4371  return op->emitOpError("requires a permutation_map with result dims of "
4372  "the same rank as the vector type");
4373 
4374  if (maskType)
4375  return op->emitOpError("does not support masks with vector element type");
4376  } else {
4377  // Memref or tensor has scalar element type.
4378  unsigned minorSize =
4379  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4380  unsigned resultVecSize =
4381  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4382  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4383  return op->emitOpError(
4384  "requires the bitwidth of the minor 1-D vector to be an integral "
4385  "multiple of the bitwidth of the source element type");
4386 
4387  // Check that permutation map results match rank of vector type.
4388  if (permutationMap.getNumResults() != vectorType.getRank())
4389  return op->emitOpError("requires a permutation_map with result dims of "
4390  "the same rank as the vector type");
4391  }
4392 
4393  if (permutationMap.getNumSymbols() != 0)
4394  return op->emitOpError("requires permutation_map without symbols");
4395 
4396  if (permutationMap.getNumInputs() != shapedType.getRank())
4397  return op->emitOpError("requires a permutation_map with input dims of the "
4398  "same rank as the source type");
4399 
4400  if (maskType && maskType != inferredMaskType)
4401  return op->emitOpError("inferred mask type (")
4402  << inferredMaskType << ") and mask operand type (" << maskType
4403  << ") don't match";
4404 
4405  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4406  return op->emitOpError("expects the in_bounds attr of same rank "
4407  "as permutation_map results: ")
4408  << AffineMapAttr::get(permutationMap)
4409  << " vs inBounds of size: " << inBounds.size();
4410 
4411  return success();
4412 }
4413 
4414 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4415  SmallVector<StringRef, 3> elidedAttrs;
4416  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4417  if (op.getPermutationMap().isMinorIdentity())
4418  elidedAttrs.push_back(op.getPermutationMapAttrName());
4419  // Elide in_bounds attribute if all dims are out-of-bounds.
4420  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4421  elidedAttrs.push_back(op.getInBoundsAttrName());
4422  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4423 }
4424 
4425 void TransferReadOp::print(OpAsmPrinter &p) {
4426  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4427  if (getMask())
4428  p << ", " << getMask();
4429  printTransferAttrs(p, *this);
4430  p << " : " << getShapedType() << ", " << getVectorType();
4431 }
4432 
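/// Note (illustrative example): the mask shape is expressed in source
/// (memref/tensor) dimension order, i.e. the vector shape permuted back through
/// the inverse of the permutation map. E.g. for
/// permutation_map = affine_map<(d0, d1) -> (d1, d0)> and a transfer of
/// vector<4x8xf32>, the inferred mask type is vector<8x4xi1>.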
4433 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4434  AffineMap permMap) {
4435  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4436  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4437  assert(invPermMap && "Inversed permutation map couldn't be computed");
4438  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4439 
4440  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4441  // 0-D mask into a single-element 1-D mask.
4442  if (maskShape.empty())
4443  maskShape.push_back(1);
4444 
4445  SmallVector<bool> scalableDims =
4446  applyPermutationMap(invPermMap, vecType.getScalableDims());
4447 
4448  return VectorType::get(maskShape, i1Type, scalableDims);
4449 }
4450 
4451 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4452  auto &builder = parser.getBuilder();
4453  SMLoc typesLoc;
4454  OpAsmParser::UnresolvedOperand sourceInfo;
4455  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4456  OpAsmParser::UnresolvedOperand paddingInfo;
4457  SmallVector<Type, 2> types;
4458  OpAsmParser::UnresolvedOperand maskInfo;
4459  // Parsing with support for paddingValue.
4460  if (parser.parseOperand(sourceInfo) ||
4461  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4462  parser.parseComma() || parser.parseOperand(paddingInfo))
4463  return failure();
4464  ParseResult hasMask = parser.parseOptionalComma();
4465  if (hasMask.succeeded()) {
4466  if (parser.parseOperand(maskInfo))
4467  return failure();
4468  }
4469  if (parser.parseOptionalAttrDict(result.attributes) ||
4470  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4471  return failure();
4472  if (types.size() != 2)
4473  return parser.emitError(typesLoc, "requires two types");
4474  auto indexType = builder.getIndexType();
4475  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4476  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4477  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4478  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4479  if (!vectorType)
4480  return parser.emitError(typesLoc, "requires vector type");
4481  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4482  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4483  AffineMap permMap;
4484  if (!permMapAttr) {
4485  if (shapedType.getRank() <
4486  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4487  return parser.emitError(typesLoc,
4488  "expected a custom permutation_map when "
4489  "rank(source) != rank(destination)");
4490  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4491  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4492  } else {
4493  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4494  }
4495  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4496  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4497  if (!inBoundsAttr) {
4498  result.addAttribute(inBoundsAttrName,
4499  builder.getBoolArrayAttr(
4500  SmallVector<bool>(permMap.getNumResults(), false)));
4501  }
4502  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4503  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4504  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4505  result.operands))
4506  return failure();
4507  if (hasMask.succeeded()) {
4508  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4509  return parser.emitError(
4510  maskInfo.location, "does not support masks with vector element type");
4511  if (vectorType.getRank() != permMap.getNumResults()) {
4512  return parser.emitError(typesLoc,
4513  "expected the same rank for the vector and the "
4514  "results of the permutation map");
4515  }
4516  // Instead of adding the mask type as an op type, compute it based on the
4517  // vector type and the permutation map (to keep the type signature small).
4518  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4519  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4520  return failure();
4521  }
4522  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4523  builder.getDenseI32ArrayAttr(
4524  {1, static_cast<int32_t>(indexInfo.size()), 1,
4525  static_cast<int32_t>(hasMask.succeeded())}));
4526  return parser.addTypeToList(vectorType, result.types);
4527 }
4528 
4529 LogicalResult TransferReadOp::verify() {
4530  // Consistency of elemental types in source and vector.
4531  ShapedType shapedType = getShapedType();
4532  VectorType vectorType = getVectorType();
4533  VectorType maskType = getMaskType();
4534  auto paddingType = getPadding().getType();
4535  auto permutationMap = getPermutationMap();
4536  VectorType inferredMaskType =
4537  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4538  : VectorType();
4539  auto sourceElementType = shapedType.getElementType();
4540 
4541  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4542  return emitOpError("requires ") << shapedType.getRank() << " indices";
4543 
4544  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4545  shapedType, vectorType, maskType,
4546  inferredMaskType, permutationMap, getInBounds())))
4547  return failure();
4548 
4549  if (auto sourceVectorElementType =
4550  llvm::dyn_cast<VectorType>(sourceElementType)) {
4551  // Source has vector element type.
4552  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4553  if (sourceVectorElementType != paddingType)
4554  return emitOpError(
4555  "requires source element type and padding type to match.");
4556 
4557  } else {
4558  // Check that 'paddingType' is valid to store in a vector type.
4559  if (!VectorType::isValidElementType(paddingType))
4560  return emitOpError("requires valid padding vector elemental type");
4561 
4562  // Check that padding type and vector element types match.
4563  if (paddingType != sourceElementType)
4564  return emitOpError(
4565  "requires formal padding and source of the same elemental type");
4566  }
4567 
4568  return verifyPermutationMap(permutationMap,
4569  [&](Twine t) { return emitOpError(t); });
4570 }
4571 
4572 // MaskableOpInterface methods.
4573 
4574 /// Returns the mask type expected by this operation. Mostly used for
4575 /// verification purposes. It requires the operation to be vectorized.
4576 Type TransferReadOp::getExpectedMaskType() {
4577  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4578 }
4579 
4580 //===----------------------------------------------------------------------===//
4581 // TransferReadOp: VectorTransferOpInterface methods.
4582 //===----------------------------------------------------------------------===//
4583 VectorType TransferReadOp::getVectorType() {
4584  return cast<VectorType>(getVector().getType());
4585 }
4586 
4587 template <typename TransferOp>
4588 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4589  // TODO: support more aggressive createOrFold on:
4590  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4591  if (op.getShapedType().isDynamicDim(indicesIdx))
4592  return false;
4593  Value index = op.getIndices()[indicesIdx];
4594  std::optional<int64_t> cstOp = getConstantIntValue(index);
4595  if (!cstOp.has_value())
4596  return false;
4597 
4598  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4599  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4600 
4601  return cstOp.value() + vectorSize <= sourceSize;
4602 }
4603 
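/// Promote out-of-bounds dims of a transfer op to in_bounds when the access is
/// provably within the static source bounds. For example (illustrative),
/// assuming %c2 is a constant index:
/// ```mlir
/// %v = vector.transfer_read %mem[%c2], %pad : memref<8xf32>, vector<4xf32>
/// ```
/// has 2 + 4 <= 8, so its in_bounds attribute can be folded to [true].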
4604 template <typename TransferOp>
4605 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4606  // TODO: support 0-d corner case.
4607  // TODO: Be less conservative.
4608  if (op.getTransferRank() == 0)
4609  return failure();
4610  AffineMap permutationMap = op.getPermutationMap();
4611  bool changed = false;
4612  SmallVector<bool, 4> newInBounds;
4613  newInBounds.reserve(op.getTransferRank());
4614  // Indices of non-broadcast dims - used when analysing broadcast dims.
4615  SmallVector<unsigned> nonBcastDims;
4616 
4617  // 1. Process non-broadcast dims
4618  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4619  // 1.1. Already marked as in-bounds, nothing to see here.
4620  if (op.isDimInBounds(i)) {
4621  newInBounds.push_back(true);
4622  continue;
4623  }
4624  // 1.2. Currently out-of-bounds, check whether we can statically determine
4625  // it is inBounds.
4626  bool inBounds = false;
4627  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4628  if (dimExpr) {
4629  inBounds = isInBounds(op, /*resultIdx=*/i,
4630  /*indicesIdx=*/dimExpr.getPosition());
4631  nonBcastDims.push_back(i);
4632  }
4633 
4634  newInBounds.push_back(inBounds);
4635  // We commit the pattern if it is "more inbounds".
4636  changed |= inBounds;
4637  }
4638 
4639  // 2. Handle broadcast dims
4640  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4641  // "in bounds" as well.
4642  bool allNonBcastDimsInBounds = llvm::all_of(
4643  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4644  if (allNonBcastDimsInBounds) {
4645  for (size_t idx : permutationMap.getBroadcastDims()) {
4646  changed |= !newInBounds[idx];
4647  newInBounds[idx] = true;
4648  }
4649  }
4650 
4651  if (!changed)
4652  return failure();
4653  // OpBuilder is only used as a helper to build a BoolArrayAttr.
4654  OpBuilder b(op.getContext());
4655  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4656  return success();
4657 }
4658 
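/// Drop the mask operand of a transfer op when the mask is statically known to
/// be all-true, e.g. (illustrative) a mask produced by
/// `vector.constant_mask [4] : vector<4xi1>`.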
4659 template <typename TransferOp>
4660 static LogicalResult foldTransferFullMask(TransferOp op) {
4661  auto mask = op.getMask();
4662  if (!mask)
4663  return failure();
4664 
4665  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4666  return failure();
4667 
4668  op.getMaskMutable().clear();
4669  return success();
4670 }
4671 
4672 /// ```
4673 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4674 /// : vector<1x4xf32>, tensor<4x4xf32>
4675 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4676 /// : tensor<4x4xf32>, vector<1x4xf32>
4677 /// ```
4678 /// -> Folds into
4679 /// ```
4680 /// %v0
4681 /// ```
4682 static Value foldRAW(TransferReadOp readOp) {
4683  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4684  return {};
4685  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4686  while (defWrite) {
4687  if (checkSameValueRAW(defWrite, readOp))
4688  return defWrite.getVector();
4689  if (!isDisjointTransferIndices(
4690  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4691  cast<VectorTransferOpInterface>(readOp.getOperation())))
4692  break;
4693  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4694  }
4695  return {};
4696 }
4697 
4698 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4699  if (Value vec = foldRAW(*this))
4700  return vec;
4701  /// transfer_read(memrefcast) -> transfer_read
4702  if (succeeded(foldTransferInBoundsAttribute(*this)))
4703  return getResult();
4704  if (succeeded(foldTransferFullMask(*this)))
4705  return getResult();
4706  if (succeeded(memref::foldMemRefCast(*this)))
4707  return getResult();
4708  if (succeeded(tensor::foldTensorCast(*this)))
4709  return getResult();
4710  return OpFoldResult();
4711 }
4712 
4713 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4714  return llvm::to_vector<4>(getVectorType().getShape());
4715 }
4716 
4717 void TransferReadOp::getEffects(
4718  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4719  &effects) {
4720  if (llvm::isa<MemRefType>(getShapedType()))
4721  effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
4722  SideEffects::DefaultResource::get());
4723 }
4724 
4725 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4726  if (hasPureTensorSemantics())
4727  return Speculation::Speculatable;
4728  return Speculation::NotSpeculatable;
4729 }
4730 
4731 namespace {
4732 /// Store to load forwarding for transfer operations with permutation maps.
4733 /// Even if the permutation maps are different we can still propagate the store
4734 /// into the load if the size of the dimensions read and written match. Then we
4735 /// can replace the transfer_read + transfer_write by vector.broadcast and
4736 /// vector.transpose.
4737 /// Example:
4738 /// ```
4739 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4740 /// {in_bounds = [true, true],
4741 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4742 /// vector<4x1xf32>, tensor<4x4x4xf32>
4743 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4744 /// {in_bounds = [true, true, true, true],
4745 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4746 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4747 /// ```
4748 /// To:
4749 /// ```
4750 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4751 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4752 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4753 /// ```
4754 struct TransferReadAfterWriteToBroadcast
4755  : public OpRewritePattern<TransferReadOp> {
4756  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
4757 
4758  LogicalResult matchAndRewrite(TransferReadOp readOp,
4759  PatternRewriter &rewriter) const override {
4760  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4761  if (!defWrite)
4762  return failure();
4763  // Bail if we need an alias analysis.
4764  if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
4765  return failure();
4766  // Bail if we need a bounds analysis.
4767  if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
4768  return failure();
4769  // TODO: If the written transfer chunk is a superset of the read transfer
4770  // chunk we could do an extract_strided_slice.
4771  if (readOp.getTransferChunkAccessed() !=
4772  defWrite.getTransferChunkAccessed())
4773  return failure();
4774  // TODO: Support cases where a dim is explicitly written but implicitly
4775  // read (i.e., a unit dim that is rank reduced).
4776  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4777  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4778  return failure();
4779  // This pattern should only catch the broadcast case; the non-broadcast case
4780  // should be handled separately to keep the application conditions clean and
4781  // separate.
4782  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4783  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4784  bool bcast = !readMap.getBroadcastDims().empty() ||
4785  !writeMap.getBroadcastDims().empty();
4786  if (!bcast)
4787  return failure();
4788  // At this point, we know we have a bcast.
4789  // Bail in the masked case (too complex atm and needed to properly account
4790  // for padding).
4791  if (readOp.getMask() || defWrite.getMask())
4792  return failure();
4793  // If indices are not the same a shift may be required, bail.
4794  if (readOp.getIndices() != defWrite.getIndices())
4795  return failure();
4796 
4797  Value vec = defWrite.getVector();
4798  // TODO: loop through the chain of transfer_write if we can prove that they
4799  // don't overlap with the transfer_read. This requires improving
4800  // `isDisjointTransferIndices` helper.
4801  AffineMap map = readMap.compose(writeMap);
4802  if (map.getNumResults() == 0)
4803  return failure();
4804  // Calculate the permutation to apply to go from the vector stored to the
4805  // vector read.
4806  SmallVector<unsigned> permutation;
4807  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4808  return failure();
4809 
4810  Location loc = readOp.getLoc();
4811  // Calculate the broadcast shape by applying the reverse permutation to the
4812  // final shape we want.
4813  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4814  SmallVector<int64_t> broadcastShape(destShape.size());
4815  SmallVector<bool> broadcastScalableFlags(destShape.size());
4816  for (const auto &pos : llvm::enumerate(permutation)) {
4817  broadcastShape[pos.value()] = destShape[pos.index()];
4818  broadcastScalableFlags[pos.value()] =
4819  readOp.getVectorType().getScalableDims()[pos.index()];
4820  }
4821  VectorType broadcastedType = VectorType::get(
4822  broadcastShape, defWrite.getVectorType().getElementType(),
4823  broadcastScalableFlags);
4824  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4825  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4826  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4827  transposePerm);
4828  return success();
4829  }
4830 };
4831 } // namespace
4832 
4833 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4834  MLIRContext *context) {
4835  results.add<TransferReadAfterWriteToBroadcast>(context);
4836 }
4837 
4838 //===----------------------------------------------------------------------===//
4839 // TransferWriteOp
4840 //===----------------------------------------------------------------------===//
4841 
4842 /// 1. Builder with type inference.
4843 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4844  Value vector, Value dest, ValueRange indices,
4845  AffineMapAttr permutationMapAttr,
4846  /*optional*/ Value mask,
4847  /*optional*/ ArrayAttr inBoundsAttr) {
4848  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4849  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4850  mask, inBoundsAttr);
4851 }
4852 
4853 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4854 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4855  Value vector, Value dest, ValueRange indices,
4856  AffineMapAttr permutationMapAttr,
4857  /*optional*/ ArrayAttr inBoundsAttr) {
4858  build(builder, result, vector, dest, indices, permutationMapAttr,
4859  /*mask=*/Value(), inBoundsAttr);
4860 }
4861 
4862 /// 3. Builder with type inference that sets an empty mask (variant without
4863 /// attrs).
4864 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4865  Value vector, Value dest, ValueRange indices,
4866  AffineMap permutationMap,
4867  std::optional<ArrayRef<bool>> inBounds) {
4868  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4869  auto inBoundsAttr =
4870  (inBounds && !inBounds.value().empty())
4871  ? builder.getBoolArrayAttr(inBounds.value())
4872  : builder.getBoolArrayAttr(SmallVector<bool>(
4873  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4874  build(builder, result, vector, dest, indices, permutationMapAttr,
4875  /*mask=*/Value(), inBoundsAttr);
4876 }
4877 
4878 /// 4. Builder with type inference that sets an empty mask and sets permutation
4879 /// map to 'getMinorIdentityMap'.
4880 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4881  Value vector, Value dest, ValueRange indices,
4882  std::optional<ArrayRef<bool>> inBounds) {
4883  auto vectorType = llvm::cast<VectorType>(vector.getType());
4884  AffineMap permutationMap = getTransferMinorIdentityMap(
4885  llvm::cast<ShapedType>(dest.getType()), vectorType);
4886  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4887 }
4888 
4889 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4890  OperationState &result) {
4891  auto &builder = parser.getBuilder();
4892  SMLoc typesLoc;
4893  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4894  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4895  SmallVector<Type, 2> types;
4896  OpAsmParser::UnresolvedOperand maskInfo;
4897  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4898  parser.parseOperand(sourceInfo) ||
4899  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4900  return failure();
4901  ParseResult hasMask = parser.parseOptionalComma();
4902  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4903  return failure();
4904  if (parser.parseOptionalAttrDict(result.attributes) ||
4905  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4906  return failure();
4907  if (types.size() != 2)
4908  return parser.emitError(typesLoc, "requires two types");
4909  auto indexType = builder.getIndexType();
4910  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4911  if (!vectorType)
4912  return parser.emitError(typesLoc, "requires vector type");
4913  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4914  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4915  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4916  auto permMapAttrName =
4917  TransferWriteOp::getPermutationMapAttrName(result.name);
4918  auto permMapAttr = result.attributes.get(permMapAttrName);
4919  AffineMap permMap;
4920  if (!permMapAttr) {
4921  if (shapedType.getRank() <
4922  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4923  return parser.emitError(typesLoc,
4924  "expected a custom permutation_map when "
4925  "rank(source) != rank(destination)");
4926  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4927  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4928  } else {
4929  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4930  }
4931  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4932  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4933  if (!inBoundsAttr) {
4934  result.addAttribute(inBoundsAttrName,
4935  builder.getBoolArrayAttr(
4936  SmallVector<bool>(permMap.getNumResults(), false)));
4937  }
4938  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4939  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4940  parser.resolveOperands(indexInfo, indexType, result.operands))
4941  return failure();
4942  if (hasMask.succeeded()) {
4943  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4944  return parser.emitError(
4945  maskInfo.location, "does not support masks with vector element type");
4946  if (vectorType.getRank() != permMap.getNumResults()) {
4947  return parser.emitError(typesLoc,
4948  "expected the same rank for the vector and the "
4949  "results of the permutation map");
4950  }
4951  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4952  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4953  return failure();
4954  }
4955  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4956  builder.getDenseI32ArrayAttr(
4957  {1, 1, static_cast<int32_t>(indexInfo.size()),
4958  static_cast<int32_t>(hasMask.succeeded())}));
4959  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4960  parser.addTypeToList(shapedType, result.types));
4961 }
4962 
4963 void TransferWriteOp::print(OpAsmPrinter &p) {
4964  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
4965  if (getMask())
4966  p << ", " << getMask();
4967  printTransferAttrs(p, *this);
4968  p << " : " << getVectorType() << ", " << getShapedType();
4969 }
4970 
4971 LogicalResult TransferWriteOp::verify() {
4972  // Consistency of elemental types in shape and vector.
4973  ShapedType shapedType = getShapedType();
4974  VectorType vectorType = getVectorType();
4975  VectorType maskType = getMaskType();
4976  auto permutationMap = getPermutationMap();
4977  VectorType inferredMaskType =
4978  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4979  : VectorType();
4980 
4981  if (llvm::size(getIndices()) != shapedType.getRank())
4982  return emitOpError("requires ") << shapedType.getRank() << " indices";
4983 
4984  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4985  // as the semantics is unclear. This can be revisited later if necessary.
4986  if (hasBroadcastDim())
4987  return emitOpError("should not have broadcast dimensions");
4988 
4989  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4990  shapedType, vectorType, maskType,
4991  inferredMaskType, permutationMap, getInBounds())))
4992  return failure();
4993 
4994  return verifyPermutationMap(permutationMap,
4995  [&](Twine t) { return emitOpError(t); });
4996 }
4997 
4998 //===----------------------------------------------------------------------===//
4999 // TransferWriteOp: MaskableOpInterface methods.
5000 //===----------------------------------------------------------------------===//
5001 
5002 /// Returns the mask type expected by this operation. Mostly used for
5003 /// verification purposes.
5004 Type TransferWriteOp::getExpectedMaskType() {
5005  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5006 }
5007 
5008 //===----------------------------------------------------------------------===//
5009 // TransferWriteOp: VectorTransferOpInterface methods.
5010 //===----------------------------------------------------------------------===//
5011 Value TransferWriteOp::getVector() { return getOperand(0); }
5012 VectorType TransferWriteOp::getVectorType() {
5013  return cast<VectorType>(getValueToStore().getType());
5014 }
5015 
5016 //===----------------------------------------------------------------------===//
5017 // TransferWriteOp: fold methods.
5018 //===----------------------------------------------------------------------===//
5019 /// Fold:
5020 /// ```
5021 /// %t1 = ...
5022 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5023 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5024 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5025 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5026 /// ```
5027 ///
5028 /// into:
5029 ///
5030 /// ```
5031 /// %t0
5032 /// ```
5033 ///
5034 /// The producer of t1 may or may not be DCE'd depending on whether it is a
5035 /// block argument or has side effects.
5036 static LogicalResult foldReadInitWrite(TransferWriteOp write,
5037  ArrayRef<Attribute>,
5038  SmallVectorImpl<OpFoldResult> &results) {
5039  // TODO: support 0-d corner case.
5040  if (write.getTransferRank() == 0)
5041  return failure();
5042  auto rankedTensorType =
5043  llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5044  // If not operating on tensors, bail.
5045  if (!rankedTensorType)
5046  return failure();
5047  // If no read, bail.
5048  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5049  if (!read)
5050  return failure();
5051  // TODO: support 0-d corner case.
5052  if (read.getTransferRank() == 0)
5053  return failure();
5054  // For now, only accept minor identity maps. Future: also accept compositions
5055  // that are a minor identity.
5055  if (!read.getPermutationMap().isMinorIdentity() ||
5056  !write.getPermutationMap().isMinorIdentity())
5057  return failure();
5058  // Bail on mismatching ranks.
5059  if (read.getTransferRank() != write.getTransferRank())
5060  return failure();
5061  // Bail on potential out-of-bounds accesses.
5062  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5063  return failure();
5064  // Tensor types must be the same.
5065  if (read.getBase().getType() != rankedTensorType)
5066  return failure();
5067  // Vector types must be the same.
5068  if (read.getVectorType() != write.getVectorType())
5069  return failure();
5070  // Vector and Tensor shapes must match.
5071  if (read.getVectorType().getShape() != rankedTensorType.getShape())
5072  return failure();
5073  // If any index is nonzero.
5074  auto isNotConstantZero = [](Value v) {
5075  auto cstOp = getConstantIntValue(v);
5076  return !cstOp.has_value() || cstOp.value() != 0;
5077  };
5078  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5079  llvm::any_of(write.getIndices(), isNotConstantZero))
5080  return failure();
5081  // Success.
5082  results.push_back(read.getBase());
5083  return success();
5084 }
5085 
5086 static bool checkSameValueWAR(vector::TransferReadOp read,
5087  vector::TransferWriteOp write) {
5088  return read.getBase() == write.getBase() &&
5089  read.getIndices() == write.getIndices() &&
5090  read.getPermutationMap() == write.getPermutationMap() &&
5091  read.getVectorType() == write.getVectorType() && !read.getMask() &&
5092  !write.getMask();
5093 }
5094 /// Fold transfer_write write after read:
5095 /// ```
5096 /// %t0 = ...
5097 /// %v = vector.transfer_read %t0[%c0...] :
5098 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
5099 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
5100 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
5101 /// ```
5102 ///
5103 /// into:
5104 ///
5105 /// ```
5106 /// %t0
5107 /// ```
5108 static LogicalResult foldWAR(TransferWriteOp write,
5109  SmallVectorImpl<OpFoldResult> &results) {
5110  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5111  return failure();
5112  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5113  if (!read)
5114  return failure();
5115 
5116  if (!checkSameValueWAR(read, write))
5117  return failure();
5118  results.push_back(read.getBase());
5119  return success();
5120 }
5121 
5122 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5123  SmallVectorImpl<OpFoldResult> &results) {
5124  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5125  return success();
5126  if (succeeded(foldWAR(*this, results)))
5127  return success();
5128  if (succeeded(foldTransferInBoundsAttribute(*this)))
5129  return success();
5130  if (succeeded(foldTransferFullMask(*this)))
5131  return success();
5132  return memref::foldMemRefCast(*this);
5133 }
5134 
5135 //===----------------------------------------------------------------------===//
5136 // TransferWriteOp: other methods.
5137 //===----------------------------------------------------------------------===//
5138 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5139  return llvm::to_vector<4>(getVectorType().getShape());
5140 }
5141 
5142 void TransferWriteOp::getEffects(
5143  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5144  &effects) {
5145  if (llvm::isa<MemRefType>(getShapedType()))
5146  effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5147  SideEffects::DefaultResource::get());
5148 }
5149 
5150 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5151  if (hasPureTensorSemantics())
5152  return Speculation::Speculatable;
5153  return Speculation::NotSpeculatable;
5154 }
5155 
5156 namespace {
5157 /// Remove a dead transfer write from the SSA chain so that it can be
5158 /// eliminated by DCE:
5159 /// ```
5160 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5161 /// : vector<1x4xf32>, tensor<4x4xf32>
5162 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5163 /// : vector<1x4xf32>, tensor<4x4xf32>
5164 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5165 /// : vector<1x4xf32>, tensor<4x4xf32>
5166 /// ```
5167 ///
5168 /// into:
5169 ///
5170 /// ```
5171 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5172 /// : vector<1x4xf32>, tensor<4x4xf32>
5173 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5174 /// : vector<1x4xf32>, tensor<4x4xf32>
5175 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5176 /// : vector<1x4xf32>, tensor<4x4xf32>
5177 /// ```
5178 ///
5179 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5180 /// any other uses.
5181 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5182 public:
5183  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
5184  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5185  PatternRewriter &rewriter) const override {
5186  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5187  return failure();
5188  vector::TransferWriteOp writeToModify = writeOp;
5189 
5190  auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5191  while (defWrite) {
5192  if (checkSameValueWAW(writeOp, defWrite)) {
5193  rewriter.modifyOpInPlace(writeToModify, [&]() {
5194  writeToModify.getBaseMutable().assign(defWrite.getBase());
5195  });
5196  return success();
5197  }
5198  if (!isDisjointTransferIndices(
5199  cast<VectorTransferOpInterface>(defWrite.getOperation()),
5200  cast<VectorTransferOpInterface>(writeOp.getOperation())))
5201  break;
5202  // If the previous write op doesn't have any other use we can safely look
5203  // at the previous store to see if it can be removed.
5204  if (!defWrite->hasOneUse())
5205  break;
5206  writeToModify = defWrite;
5207  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5208  }
5209  return failure();
5210  }
5211 };
5212 
5213 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5214 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5215 /// overwritten and inserted into another tensor. After this rewrite, the
5216 /// operations bufferize in-place since all of them work on the same slice.
5217 ///
5218 /// For example:
5219 /// ```mlir
5220 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5221 /// : vector<8x16xf32>, tensor<8x16xf32>
5222 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5223 /// : tensor<8x16xf32> to tensor<?x?xf32>
5224 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5225 /// : tensor<?x?xf32> into tensor<27x37xf32>
5226 /// ```
5227 /// folds to
5228 /// ```mlir
5229 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5230 /// : tensor<27x37xf32> to tensor<?x?xf32>
5231 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5232 /// : vector<8x16xf32>, tensor<?x?xf32>
5233 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5234 /// : tensor<?x?xf32> into tensor<27x37xf32>
5235 /// ```
5236 struct SwapExtractSliceOfTransferWrite
5237  : public OpRewritePattern<tensor::InsertSliceOp> {
5238 public:
5239  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
5240 
5241  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5242  PatternRewriter &rewriter) const override {
5243  if (!insertOp.hasUnitStride())
5244  return failure();
5245  auto extractOp =
5246  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5247  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5248  return failure();
5249  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5250  if (!transferOp || !transferOp->hasOneUse())
5251  return failure();
5252 
5253  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5254  // rank-reducing.
5255  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5256  return rewriter.notifyMatchFailure(insertOp,
5257  "use-def chain is rank-reducing");
5258  }
5259 
5260  // Fail if tensor::ExtractSliceOp has non-zero offset.
5261  if (!extractOp.hasZeroOffset()) {
5262  return rewriter.notifyMatchFailure(insertOp,
5263  "ExtractSliceOp has non-zero offset");
5264  }
5265 
5266  // Fail if vector::TransferWriteOp has non-zero offset.
5267  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5268  return getConstantIntValue(value) == static_cast<int64_t>(0);
5269  })) {
5270  return rewriter.notifyMatchFailure(insertOp,
5271  "TranferWriteOp has non-zero offset");
5272  }
5273 
5274  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5275  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5276  return rewriter.notifyMatchFailure(
5277  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5278  }
5279 
5280  for (auto [insertSize, extractSize] :
5281  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5282  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5283  return rewriter.notifyMatchFailure(
5284  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5285  }
5286  }
5287 
5288  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5289  assert(transferOp.getVectorType().hasStaticShape() &&
5290  "expected vector to have a static shape");
5291  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5292  SmallVector<int64_t> resultShape = applyPermutationMap(
5293  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5294  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5295  return rewriter.notifyMatchFailure(
5296  insertOp, "TransferWriteOp may not write the full tensor.");
5297  }
5298 
5299  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5300  // Set all in_bounds to false and let the folder infer them.
5301  SmallVector<bool> newInBounds(vectorShape.size(), false);
5302  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5303  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5304  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5305  insertOp.getMixedStrides());
5306  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5307  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5308  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5309  rewriter.getBoolArrayAttr(newInBounds));
5310  rewriter.modifyOpInPlace(insertOp, [&]() {
5311  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5312  });
5313  return success();
5314  }
5315 };
5316 
5317 } // namespace
5318 
5319 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5320  MLIRContext *context) {
5321  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5322 }
5323 
5324 //===----------------------------------------------------------------------===//
5325 // LoadOp
5326 //===----------------------------------------------------------------------===//
5327 
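/// Verifies that, unless the vector is equivalent to a scalar (rank 0 or a
/// single element), the most minor memref dimension has unit stride. For
/// example (illustrative), loading vector<8xf32> from memref<4x8xf32> with the
/// identity layout is accepted, while memref<4x8xf32, strided<[16, 2]>> is
/// rejected because its most minor dimension has stride 2.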
5328 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5329  VectorType vecTy,
5330  MemRefType memRefTy) {
5331  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5332  // need any stride limitations.
5333  if (!vecTy.isScalable() &&
5334  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5335  return success();
5336 
5337  if (!memRefTy.isLastDimUnitStride())
5338  return op->emitOpError("most minor memref dim must have unit stride");
5339  return success();
5340 }
5341 
5342 LogicalResult vector::LoadOp::verify() {
5343  VectorType resVecTy = getVectorType();
5344  MemRefType memRefTy = getMemRefType();
5345 
5346  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5347  return failure();
5348 
5349  if (memRefTy.getRank() < resVecTy.getRank())
5350  return emitOpError(
5351  "destination memref has lower rank than the result vector");
5352 
5353  // Checks for vector memrefs.
5354  Type memElemTy = memRefTy.getElementType();
5355  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5356  if (memVecTy != resVecTy)
5357  return emitOpError("base memref and result vector types should match");
5358  memElemTy = memVecTy.getElementType();
5359  }
5360 
5361  if (resVecTy.getElementType() != memElemTy)
5362  return emitOpError("base and result element types should match");
5363  if (llvm::size(getIndices()) != memRefTy.getRank())
5364  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5365  return success();
5366 }
5367 
5368 OpFoldResult LoadOp::fold(FoldAdaptor) {
5369  if (succeeded(memref::foldMemRefCast(*this)))
5370  return getResult();
5371  return OpFoldResult();
5372 }
5373 
5374 std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5375  return llvm::to_vector<4>(getVectorType().getShape());
5376 }
5377 
5378 //===----------------------------------------------------------------------===//
5379 // StoreOp
5380 //===----------------------------------------------------------------------===//
5381 
5382 LogicalResult vector::StoreOp::verify() {
5383  VectorType valueVecTy = getVectorType();
5384  MemRefType memRefTy = getMemRefType();
5385 
5386  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5387  return failure();
5388 
5389  if (memRefTy.getRank() < valueVecTy.getRank())
5390  return emitOpError("source memref has lower rank than the vector to store");
5391 
5392  // Checks for vector memrefs.
5393  Type memElemTy = memRefTy.getElementType();
5394  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5395  if (memVecTy != valueVecTy)
5396  return emitOpError(
5397  "base memref and valueToStore vector types should match");
5398  memElemTy = memVecTy.getElementType();
5399  }
5400 
5401  if (valueVecTy.getElementType() != memElemTy)
5402  return emitOpError("base and valueToStore element type should match");
5403  if (llvm::size(getIndices()) != memRefTy.getRank())
5404  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5405  return success();
5406 }
5407 
5408 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5409  SmallVectorImpl<OpFoldResult> &results) {
5410  return memref::foldMemRefCast(*this);
5411 }
5412 
5413 std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5414  return llvm::to_vector<4>(getVectorType().getShape());
5415 }
5416 
5417 //===----------------------------------------------------------------------===//
5418 // MaskedLoadOp
5419 //===----------------------------------------------------------------------===//
5420 
5421 LogicalResult MaskedLoadOp::verify() {
5422  VectorType maskVType = getMaskVectorType();
5423  VectorType passVType = getPassThruVectorType();
5424  VectorType resVType = getVectorType();
5425  MemRefType memType = getMemRefType();
5426 
5427  if (resVType.getElementType() != memType.getElementType())
5428  return emitOpError("base and result element type should match");
5429  if (llvm::size(getIndices()) != memType.getRank())
5430  return emitOpError("requires ") << memType.getRank() << " indices";
5431  if (resVType.getShape() != maskVType.getShape())
5432  return emitOpError("expected result shape to match mask shape");
5433  if (resVType != passVType)
5434  return emitOpError("expected pass_thru of same type as result type");
5435  return success();
5436 }
5437 
5438 namespace {
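/// Fold maskedloads with a statically known mask: an all-true mask folds to a
/// plain vector.load, an all-false mask folds to the pass_thru value. For
/// example (illustrative):
/// ```mlir
/// %0 = vector.maskedload %base[%c0], %all_true, %pass_thru
///     : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
/// ```
/// becomes
/// ```mlir
/// %0 = vector.load %base[%c0] : memref<16xf32>, vector<4xf32>
/// ```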
5439 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5440 public:
5441  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
5442  LogicalResult matchAndRewrite(MaskedLoadOp load,
5443  PatternRewriter &rewriter) const override {
5444  switch (getMaskFormat(load.getMask())) {
5445  case MaskFormat::AllTrue:
5446  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5447  load, load.getType(), load.getBase(), load.getIndices());
5448  return success();
5449  case MaskFormat::AllFalse:
5450  rewriter.replaceOp(load, load.getPassThru());
5451  return success();
5452  case MaskFormat::Unknown:
5453  return failure();
5454  }
5455  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5456  }
5457 };
5458 } // namespace
5459 
5460 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5461  MLIRContext *context) {
5462  results.add<MaskedLoadFolder>(context);
5463 }
5464 
5465 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5466  if (succeeded(memref::foldMemRefCast(*this)))
5467  return getResult();
5468  return OpFoldResult();
5469 }
5470 
5471 //===----------------------------------------------------------------------===//
5472 // MaskedStoreOp
5473 //===----------------------------------------------------------------------===//
5474 
5475 LogicalResult MaskedStoreOp::verify() {
5476  VectorType maskVType = getMaskVectorType();
5477  VectorType valueVType = getVectorType();
5478  MemRefType memType = getMemRefType();
5479 
5480  if (valueVType.getElementType() != memType.getElementType())
5481  return emitOpError("base and valueToStore element type should match");
5482  if (llvm::size(getIndices()) != memType.getRank())
5483  return emitOpError("requires ") << memType.getRank() << " indices";
5484  if (valueVType.getShape() != maskVType.getShape())
5485  return emitOpError("expected valueToStore shape to match mask shape");
5486  return success();
5487 }
5488 
5489 namespace {
5490 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5491 public:
5492  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
5493  LogicalResult matchAndRewrite(MaskedStoreOp store,
5494  PatternRewriter &rewriter) const override {
5495  switch (getMaskFormat(store.getMask())) {
5496  case MaskFormat::AllTrue:
5497  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5498  store, store.getValueToStore(), store.getBase(), store.getIndices());
5499  return success();
5500  case MaskFormat::AllFalse:
5501  rewriter.eraseOp(store);
5502  return success();
5503  case MaskFormat::Unknown:
5504  return failure();
5505  }
5506  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5507  }
5508 };
5509 } // namespace
5510 
5511 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5512  MLIRContext *context) {
5513  results.add<MaskedStoreFolder>(context);
5514 }
5515 
5516 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5517  SmallVectorImpl<OpFoldResult> &results) {
5518  return memref::foldMemRefCast(*this);
5519 }
5520 
5521 //===----------------------------------------------------------------------===//
5522 // GatherOp
5523 //===----------------------------------------------------------------------===//
5524 
5525 LogicalResult GatherOp::verify() {
5526  VectorType indVType = getIndexVectorType();
5527  VectorType maskVType = getMaskVectorType();
5528  VectorType resVType = getVectorType();
5529  ShapedType baseType = getBaseType();
5530 
5531  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5532  return emitOpError("requires base to be a memref or ranked tensor type");
5533 
5534  if (resVType.getElementType() != baseType.getElementType())
5535  return emitOpError("base and result element type should match");
5536  if (llvm::size(getIndices()) != baseType.getRank())
5537  return emitOpError("requires ") << baseType.getRank() << " indices";
5538  if (resVType.getShape() != indVType.getShape())
5539  return emitOpError("expected result dim to match indices dim");
5540  if (resVType.getShape() != maskVType.getShape())
5541  return emitOpError("expected result dim to match mask dim");
5542  if (resVType != getPassThruVectorType())
5543  return emitOpError("expected pass_thru of same type as result type");
5544  return success();
5545 }
5546 
5547 // MaskableOpInterface methods.
5548 
5549 /// Returns the mask type expected by this operation. Mostly used for
5550 /// verification purposes. It requires the operation to be vectorized.
5551 Type GatherOp::getExpectedMaskType() {
5552  auto vecType = this->getIndexVectorType();
5553  return VectorType::get(vecType.getShape(),
5554  IntegerType::get(vecType.getContext(), /*width=*/1),
5555  vecType.getScalableDims());
5556 }
5557 
5558 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5559  return llvm::to_vector<4>(getVectorType().getShape());
5560 }
5561 
5562 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
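/// For example (illustrative), `%idx = vector.step : vector<4xindex>` or
/// `%idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>` qualify.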
5563 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5564  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5565  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5566  return failure();
5567 
5568  if (indexVec.getDefiningOp<StepOp>())
5569  return success();
5570 
5571  DenseIntElementsAttr elements;
5572  if (!matchPattern(indexVec, m_Constant(&elements)))
5573  return failure();
5574 
5575  return success(
5576  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5577 }
5578 
5579 namespace {
5580 class GatherFolder final : public OpRewritePattern<GatherOp> {
5581 public:
5582  using OpRewritePattern<GatherOp>::OpRewritePattern;
5583  LogicalResult matchAndRewrite(GatherOp gather,
5584  PatternRewriter &rewriter) const override {
5585  switch (getMaskFormat(gather.getMask())) {
5586  case MaskFormat::AllTrue:
5587  return failure(); // no unmasked equivalent
5588  case MaskFormat::AllFalse:
5589  rewriter.replaceOp(gather, gather.getPassThru());
5590  return success();
5591  case MaskFormat::Unknown:
5592  return failure();
5593  }
5594  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5595  }
5596 };
5597 
5598 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5599 /// maskedload. Only 1D fixed vectors are supported for now.
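/// For example (illustrative), with %indices = [0, 1, 2, 3]:
/// ```mlir
/// %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
///     : memref<16xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
/// ```
/// becomes
/// ```mlir
/// %0 = vector.maskedload %base[%c0], %mask, %pass_thru
///     : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
/// ```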
5600 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5601 public:
5602  using OpRewritePattern<GatherOp>::OpRewritePattern;
5603  LogicalResult matchAndRewrite(GatherOp op,
5604  PatternRewriter &rewriter) const override {
5605  if (!isa<MemRefType>(op.getBase().getType()))
5606  return rewriter.notifyMatchFailure(op, "base must be of memref type");
5607 
5608  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5609  return failure();
5610 
5611  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5612  op.getIndices(), op.getMask(),
5613  op.getPassThru());
5614  return success();
5615  }
5616 };
5617 } // namespace
5618 
5619 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5620  MLIRContext *context) {
5621  results.add<GatherFolder, FoldContiguousGather>(context);
5622 }
5623 
5624 //===----------------------------------------------------------------------===//
5625 // ScatterOp
5626 //===----------------------------------------------------------------------===//
5627 
5628 LogicalResult ScatterOp::verify() {
5629  VectorType indVType = getIndexVectorType();
5630  VectorType maskVType = getMaskVectorType();
5631  VectorType valueVType = getVectorType();
5632  MemRefType memType = getMemRefType();
5633 
5634  if (valueVType.getElementType() != memType.getElementType())
5635  return emitOpError("base and valueToStore element type should match");
5636  if (llvm::size(getIndices()) != memType.getRank())
5637  return emitOpError("requires ") << memType.getRank() << " indices";
5638  if (valueVType.getShape() != indVType.getShape())
5639  return emitOpError("expected valueToStore dim to match indices dim");
5640  if (valueVType.getShape() != maskVType.getShape())
5641  return emitOpError("expected valueToStore dim to match mask dim");
5642  return success();
5643 }
5644 
5645 namespace {
5646 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5647 public:
5648  using OpRewritePattern<ScatterOp>::OpRewritePattern;
5649  LogicalResult matchAndRewrite(ScatterOp scatter,
5650  PatternRewriter &rewriter) const override {
5651  switch (getMaskFormat(scatter.getMask())) {
5652  case MaskFormat::AllTrue:
5653  return failure(); // no unmasked equivalent
5654  case MaskFormat::AllFalse:
5655  rewriter.eraseOp(scatter);
5656  return success();
5657  case MaskFormat::Unknown:
5658  return failure();
5659  }
5660  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5661  }
5662 };
5663 
5664 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5665 /// maskedstore. Only 1D fixed vectors are supported for now.
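///
/// Illustrative sketch (operand names assumed):
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
///   vector.scatter %base[%c0][%idx], %mask, %value
///       : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
/// is rewritten to
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<?xf32>, vector<4xi1>, vector<4xf32>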
5666 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5667 public:
5669  LogicalResult matchAndRewrite(ScatterOp op,
5670  PatternRewriter &rewriter) const override {
5671  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5672  return failure();
5673 
5674  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5675  op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5676  return success();
5677  }
5678 };
5679 } // namespace
5680 
5681 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5682  MLIRContext *context) {
5683  results.add<ScatterFolder, FoldContiguousScatter>(context);
5684 }
5685 
5686 //===----------------------------------------------------------------------===//
5687 // ExpandLoadOp
5688 //===----------------------------------------------------------------------===//
5689 
5690 LogicalResult ExpandLoadOp::verify() {
5691  VectorType maskVType = getMaskVectorType();
5692  VectorType passVType = getPassThruVectorType();
5693  VectorType resVType = getVectorType();
5694  MemRefType memType = getMemRefType();
5695 
5696  if (resVType.getElementType() != memType.getElementType())
5697  return emitOpError("base and result element type should match");
5698  if (llvm::size(getIndices()) != memType.getRank())
5699  return emitOpError("requires ") << memType.getRank() << " indices";
5700  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5701  return emitOpError("expected result dim to match mask dim");
5702  if (resVType != passVType)
5703  return emitOpError("expected pass_thru of same type as result type");
5704  return success();
5705 }
5706 
5707 namespace {
5708 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5709 public:
5711  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5712  PatternRewriter &rewriter) const override {
5713  switch (getMaskFormat(expand.getMask())) {
5714  case MaskFormat::AllTrue:
5715  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5716  expand, expand.getType(), expand.getBase(), expand.getIndices());
5717  return success();
5718  case MaskFormat::AllFalse:
5719  rewriter.replaceOp(expand, expand.getPassThru());
5720  return success();
5721  case MaskFormat::Unknown:
5722  return failure();
5723  }
5724  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5725  }
5726 };
5727 } // namespace
5728 
5729 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5730  MLIRContext *context) {
5731  results.add<ExpandLoadFolder>(context);
5732 }
5733 
5734 //===----------------------------------------------------------------------===//
5735 // CompressStoreOp
5736 //===----------------------------------------------------------------------===//
5737 
5738 LogicalResult CompressStoreOp::verify() {
5739  VectorType maskVType = getMaskVectorType();
5740  VectorType valueVType = getVectorType();
5741  MemRefType memType = getMemRefType();
5742 
5743  if (valueVType.getElementType() != memType.getElementType())
5744  return emitOpError("base and valueToStore element type should match");
5745  if (llvm::size(getIndices()) != memType.getRank())
5746  return emitOpError("requires ") << memType.getRank() << " indices";
5747  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5748  return emitOpError("expected valueToStore dim to match mask dim");
5749  return success();
5750 }
5751 
5752 namespace {
5753 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5754 public:
5756  LogicalResult matchAndRewrite(CompressStoreOp compress,
5757  PatternRewriter &rewriter) const override {
5758  switch (getMaskFormat(compress.getMask())) {
5759  case MaskFormat::AllTrue:
5760  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5761  compress, compress.getValueToStore(), compress.getBase(),
5762  compress.getIndices());
5763  return success();
5764  case MaskFormat::AllFalse:
5765  rewriter.eraseOp(compress);
5766  return success();
5767  case MaskFormat::Unknown:
5768  return failure();
5769  }
5770  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5771  }
5772 };
5773 } // namespace
5774 
5775 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5776  MLIRContext *context) {
5777  results.add<CompressStoreFolder>(context);
5778 }
5779 
5780 //===----------------------------------------------------------------------===//
5781 // ShapeCastOp
5782 //===----------------------------------------------------------------------===//
5783 
5784 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5785  SetIntRangeFn setResultRanges) {
5786  setResultRanges(getResult(), argRanges.front());
5787 }
5788 
5789 LogicalResult ShapeCastOp::verify() {
5790 
5791  VectorType sourceType = getSourceVectorType();
5792  VectorType resultType = getResultVectorType();
5793 
5794  // Check that element type is preserved
5795  if (sourceType.getElementType() != resultType.getElementType())
5796  return emitOpError("has different source and result element types");
5797 
5798  // Check that number of elements is preserved
5799  int64_t sourceNElms = sourceType.getNumElements();
5800  int64_t resultNElms = resultType.getNumElements();
5801  if (sourceNElms != resultNElms) {
5802  return emitOpError() << "has different number of elements at source ("
5803  << sourceNElms << ") and result (" << resultNElms
5804  << ")";
5805  }
5806 
5807  // Check that (non-)scalability is preserved
5808  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
5809  int64_t resultNScalableDims = resultType.getNumScalableDims();
5810  if (sourceNScalableDims != resultNScalableDims)
5811  return emitOpError() << "has different number of scalable dims at source ("
5812  << sourceNScalableDims << ") and result ("
5813  << resultNScalableDims << ")";
5814 
5815  return success();
5816 }
5817 
5818 /// Return true if `transpose` does not permute a pair of non-unit dims.
5819 /// By `order preserving` we mean that the flattened versions of the input and
5820 /// output vectors are (numerically) identical. In other words `transpose` is
5821 /// effectively a shape cast.
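///
/// For example (illustrative): permutation [1, 0, 2] on vector<1x8x4xf32> is
/// order preserving, because the only dimension moved past another is the
/// leading non-scalable unit dim; permutation [1, 0] on vector<2x8xf32> is
/// not, because two non-unit dims are exchanged.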
5822 static bool isOrderPreserving(TransposeOp transpose) {
5823  ArrayRef<int64_t> permutation = transpose.getPermutation();
5824  VectorType sourceType = transpose.getSourceVectorType();
5825  ArrayRef<int64_t> inShape = sourceType.getShape();
5826  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5827  auto isNonScalableUnitDim = [&](int64_t dim) {
5828  return inShape[dim] == 1 && !inDimIsScalable[dim];
5829  };
5830  int64_t current = 0;
5831  for (auto p : permutation) {
5832  if (!isNonScalableUnitDim(p)) {
5833  if (p < current) {
5834  return false;
5835  }
5836  current = p;
5837  }
5838  }
5839  return true;
5840 }
5841 
5842 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5843 
5844  VectorType resultType = getType();
5845 
5846  // No-op shape cast.
5847  if (getSource().getType() == resultType)
5848  return getSource();
5849 
5850  // shape_cast(shape_cast(x)) -> shape_cast(x)
5851  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
5852  setOperand(precedingShapeCast.getSource());
5853  return getResult();
5854  }
5855 
5856  // shape_cast(transpose(x)) -> shape_cast(x)
5857  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5858  // This folder does
5859  // shape_cast(transpose) -> shape_cast
5860  // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5861  // shape_cast -> shape_cast(transpose)
5862  // i.e. the complete opposite. When paired, these 2 patterns can cause
5863  // infinite cycles in pattern rewriting.
5864  // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5865  // vectors, so by disabling this folder for scalable vectors the
5866  // cycle is avoided.
5867  // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5868  // still needed. If it's not, then we can fold here.
5869  if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
5870  setOperand(transpose.getVector());
5871  return getResult();
5872  }
5873  return {};
5874  }
5875 
5876  // Y = shape_cast(broadcast(X))
5877  // -> X, if X and Y have same type
5878  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5879  if (bcastOp.getSourceType() == resultType)
5880  return bcastOp.getSource();
5881  }
5882 
5883  // shape_cast(constant) -> constant
5884  if (auto splatAttr =
5885  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5886  return splatAttr.reshape(getType());
5887 
5888  // shape_cast(poison) -> poison
5889  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
5890  return ub::PoisonAttr::get(getContext());
5891  }
5892 
5893  return {};
5894 }
5895 
5896 namespace {
5897 
5898 /// Helper function that computes a new vector type based on the input vector
5899 /// type by removing the trailing one dims:
5900 ///
5901 /// vector<4x1x1xi1> --> vector<4x1xi1>
5902 ///
5903 static VectorType trimTrailingOneDims(VectorType oldType) {
5904  ArrayRef<int64_t> oldShape = oldType.getShape();
5905  ArrayRef<int64_t> newShape = oldShape;
5906 
5907  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5908  ArrayRef<bool> newScalableDims = oldScalableDims;
5909 
5910  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5911  newShape = newShape.drop_back(1);
5912  newScalableDims = newScalableDims.drop_back(1);
5913  }
5914 
5915  // Make sure we have at least 1 dimension.
5916  // TODO: Add support for 0-D vectors.
5917  if (newShape.empty()) {
5918  newShape = oldShape.take_back();
5919  newScalableDims = oldScalableDims.take_back();
5920  }
5921 
5922  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5923 }
5924 
5925 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5926 ///
5927 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5928 /// dimension. If the input vector comes from `vector.create_mask` for which
5929 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5930 /// to fold shape_cast into create_mask.
5931 ///
5932 /// BEFORE:
5933 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5934 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5935 /// AFTER:
5936 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5937 class ShapeCastCreateMaskFolderTrailingOneDim final
5938  : public OpRewritePattern<ShapeCastOp> {
5939 public:
5941 
5942  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5943  PatternRewriter &rewriter) const override {
5944  Value shapeOpSrc = shapeOp->getOperand(0);
5945  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5946  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5947  if (!createMaskOp && !constantMaskOp)
5948  return failure();
5949 
5950  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5951  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5952 
5953  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5954  if (newVecType != shapeOpResTy)
5955  return failure();
5956 
5957  auto numDimsToDrop =
5958  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5959 
5960  // No unit dims to drop
5961  if (!numDimsToDrop)
5962  return failure();
5963 
5964  if (createMaskOp) {
5965  auto maskOperands = createMaskOp.getOperands();
5966  auto numMaskOperands = maskOperands.size();
5967 
5968  // Check every mask dim size to see whether it can be dropped
5969  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5970  --i) {
5971  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5972  if (!constant || (constant.value() != 1))
5973  return failure();
5974  }
5975  SmallVector<Value> newMaskOperands =
5976  maskOperands.drop_back(numDimsToDrop);
5977 
5978  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5979  newMaskOperands);
5980  return success();
5981  }
5982 
5983  if (constantMaskOp) {
5984  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5985  auto numMaskOperands = maskDimSizes.size();
5986 
5987  // Check every mask dim size to see whether it can be dropped
5988  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5989  --i) {
5990  if (maskDimSizes[i] != 1)
5991  return failure();
5992  }
5993 
5994  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5995  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5996  newMaskOperands);
5997  return success();
5998  }
5999 
6000  return failure();
6001  }
6002 };
6003 
6004 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6005 /// i) Y = ShapeCast(X), or
6006 /// ii) Y = Broadcast(X)
6007 /// If both (i) and (ii) are possible, (i) is chosen.
6008 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6009 public:
6011 
6012  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6013  PatternRewriter &rewriter) const override {
6014  auto broadcastOp =
6015  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6016  if (!broadcastOp)
6017  return failure();
6018 
6019  auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6020  bool srcIsScalar = !srcVectorType;
6021 
6022  // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6023  // Example:
6024  // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6025  // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6026  // to
6027  // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6028  if (srcVectorType) {
6029  if (srcVectorType.getNumElements() ==
6030  shapeCastOp.getResultVectorType().getNumElements()) {
6031  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6032  shapeCastOp, shapeCastOp.getResultVectorType(),
6033  broadcastOp.getSource());
6034  return success();
6035  }
6036  }
6037 
6038  // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6039  // Example
6040  // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6041  // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6042  // to
6043  // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6044  VectorType dstVectorType = shapeCastOp.getResultVectorType();
6045  if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6046  BroadcastableToResult::Success) {
6047  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6048  shapeCastOp, dstVectorType, broadcastOp.getSource());
6049  return success();
6050  }
6051  return failure();
6052  }
6053 };
6054 
6055 } // namespace
6056 
6057 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6058  MLIRContext *context) {
6059  results
6060  .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6061  context);
6062 }
6063 
6064 //===----------------------------------------------------------------------===//
6065 // VectorBitCastOp
6066 //===----------------------------------------------------------------------===//
6067 
6068 LogicalResult BitCastOp::verify() {
6069  auto sourceVectorType = getSourceVectorType();
6070  auto resultVectorType = getResultVectorType();
6071 
6072  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6073  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6074  return emitOpError("dimension size mismatch at: ") << i;
6075  }
6076 
6077  DataLayout dataLayout = DataLayout::closest(*this);
6078  auto sourceElementBits =
6079  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6080  auto resultElementBits =
6081  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6082 
6083  if (sourceVectorType.getRank() == 0) {
6084  if (sourceElementBits != resultElementBits)
6085  return emitOpError("source/result bitwidth of the 0-D vector element "
6086  "types must be equal");
6087  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6088  resultElementBits * resultVectorType.getShape().back()) {
6089  return emitOpError(
6090  "source/result bitwidth of the minor 1-D vectors must be equal");
6091  }
6092 
6093  return success();
6094 }
6095 
6096 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6097  // Nop cast.
6098  if (getSource().getType() == getResult().getType())
6099  return getSource();
6100 
6101  // Canceling bitcasts.
6102  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6103  if (getResult().getType() == otherOp.getSource().getType())
6104  return otherOp.getSource();
6105 
6106  setOperand(otherOp.getSource());
6107  return getResult();
6108  }
6109 
6110  Attribute sourceConstant = adaptor.getSource();
6111  if (!sourceConstant)
6112  return {};
6113 
6114  Type srcElemType = getSourceVectorType().getElementType();
6115  Type dstElemType = getResultVectorType().getElementType();
6116 
6117  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6118  if (floatPack.isSplat()) {
6119  auto splat = floatPack.getSplatValue<FloatAttr>();
6120 
6121  // Casting fp16 into fp32.
6122  if (srcElemType.isF16() && dstElemType.isF32()) {
6123  uint32_t bits = static_cast<uint32_t>(
6124  splat.getValue().bitcastToAPInt().getZExtValue());
6125  // Duplicate the 16-bit pattern.
6126  bits = (bits << 16) | (bits & 0xffff);
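  // Illustrative example: a splat of f16 1.0 (bit pattern 0x3C00) yields
  // 32-bit lanes with bit pattern 0x3C003C00.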
6127  APInt intBits(32, bits);
6128  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6129  return DenseElementsAttr::get(getResultVectorType(), floatBits);
6130  }
6131  }
6132  }
6133 
6134  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6135  if (intPack.isSplat()) {
6136  auto splat = intPack.getSplatValue<IntegerAttr>();
6137 
6138  if (llvm::isa<IntegerType>(dstElemType)) {
6139  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6140  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6141 
6142  // Casting to a larger integer bit width.
6143  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6144  APInt intBits = splat.getValue().zext(dstBitWidth);
6145 
6146  // Duplicate the lower width element.
6147  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6148  intBits = (intBits << srcBitWidth) | intBits;
6149  return DenseElementsAttr::get(getResultVectorType(), intBits);
6150  }
6151  }
6152  }
6153  }
6154 
6155  return {};
6156 }
6157 
6158 //===----------------------------------------------------------------------===//
6159 // TypeCastOp
6160 //===----------------------------------------------------------------------===//
6161 
6162 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6163  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6164  SmallVector<int64_t, 8> res(memRefType.getShape());
6165  if (vectorType)
6166  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6167  return res;
6168 }
6169 
6170 /// Build the canonical memRefType with a single vector.
6171 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6172 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6173  Value source) {
6174  result.addOperands(source);
6175  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6176  VectorType vectorType =
6177  VectorType::get(extractShape(memRefType),
6178  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6179  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6180  memRefType.getMemorySpace()));
6181 }
6182 
6183 LogicalResult TypeCastOp::verify() {
6184  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6185  if (!canonicalType.getLayout().isIdentity())
6186  return emitOpError("expects operand to be a memref with identity layout");
6187  if (!getResultMemRefType().getLayout().isIdentity())
6188  return emitOpError("expects result to be a memref with identity layout");
6189  if (getResultMemRefType().getMemorySpace() !=
6190  getMemRefType().getMemorySpace())
6191  return emitOpError("expects result in same memory space");
6192 
6193  auto sourceType = getMemRefType();
6194  auto resultType = getResultMemRefType();
6195  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6196  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6197  return emitOpError(
6198  "expects result and operand with same underlying scalar type: ")
6199  << resultType;
6200  if (extractShape(sourceType) != extractShape(resultType))
6201  return emitOpError(
6202  "expects concatenated result and operand shapes to be equal: ")
6203  << resultType;
6204  return success();
6205 }
6206 
6207 //===----------------------------------------------------------------------===//
6208 // TransposeOp
6209 //===----------------------------------------------------------------------===//
6210 
6211 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6212  Value vector, ArrayRef<int64_t> permutation) {
6213  VectorType vt = llvm::cast<VectorType>(vector.getType());
6214  SmallVector<int64_t, 4> transposedShape(vt.getRank());
6215  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6216  for (unsigned i = 0; i < permutation.size(); ++i) {
6217  transposedShape[i] = vt.getShape()[permutation[i]];
6218  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6219  }
6220 
6221  result.addOperands(vector);
6222  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6223  transposedScalableDims));
6224  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6225  builder.getDenseI64ArrayAttr(permutation));
6226 }
6227 
6228 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6229  // Eliminate splat constant transpose ops.
6230  if (auto splat =
6231  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6232  return splat.reshape(getResultVectorType());
6233 
6234  // Eliminate poison transpose ops.
6235  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6236  return ub::PoisonAttr::get(getContext());
6237 
6238  // Eliminate identity transposes, and more generally any transposes that
6239  // preserve the shape without permuting elements.
6240  //
6241  // Examples of what to fold:
6242  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6243  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6244  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6245  //
6246  // Example of what NOT to fold:
6247  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6248  //
6249  if (getSourceVectorType() == getResultVectorType() &&
6250  isOrderPreserving(*this))
6251  return getVector();
6252 
6253  return {};
6254 }
6255 
6256 LogicalResult vector::TransposeOp::verify() {
6257  VectorType vectorType = getSourceVectorType();
6258  VectorType resultType = getResultVectorType();
6259  int64_t rank = resultType.getRank();
6260  if (vectorType.getRank() != rank)
6261  return emitOpError("vector result rank mismatch: ") << rank;
6262  // Verify transposition array.
6263  ArrayRef<int64_t> perm = getPermutation();
6264  int64_t size = perm.size();
6265  if (rank != size)
6266  return emitOpError("transposition length mismatch: ") << size;
6267  SmallVector<bool, 8> seen(rank, false);
6268  for (const auto &ta : llvm::enumerate(perm)) {
6269  if (ta.value() < 0 || ta.value() >= rank)
6270  return emitOpError("transposition index out of range: ") << ta.value();
6271  if (seen[ta.value()])
6272  return emitOpError("duplicate position index: ") << ta.value();
6273  seen[ta.value()] = true;
6274  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6275  return emitOpError("dimension size mismatch at: ") << ta.value();
6276  }
6277  return success();
6278 }
6279 
6280 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6281  return llvm::to_vector<4>(getResultVectorType().getShape());
6282 }
6283 
6284 namespace {
6285 
6286 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
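// For example (hypothetical IR), a transpose with permutation [1, 2, 0] of a
// transpose with permutation [2, 0, 1] composes to [0, 1, 2]:
//   %0 = vector.transpose %arg, [2, 0, 1] : vector<2x3x4xf32> to vector<4x2x3xf32>
//   %1 = vector.transpose %0, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
// is rewritten to
//   %1 = vector.transpose %arg, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>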
6287 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6288 public:
6290 
6291  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6292  PatternRewriter &rewriter) const override {
6293  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6294  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6295  ArrayRef<int64_t> permutation2) {
6296  SmallVector<int64_t, 4> result;
6297  for (auto index : permutation2)
6298  result.push_back(permutation1[index]);
6299  return result;
6300  };
6301 
6302  // Return if the input of 'transposeOp' is not defined by another transpose.
6303  vector::TransposeOp parentTransposeOp =
6304  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6305  if (!parentTransposeOp)
6306  return failure();
6307 
6308  SmallVector<int64_t, 4> permutation = composePermutations(
6309  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6310  // Replace 'transposeOp' with a new transpose operation.
6311  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6312  transposeOp, transposeOp.getResult().getType(),
6313  parentTransposeOp.getVector(), permutation);
6314  return success();
6315  }
6316 };
6317 
6318 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
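// E.g. (illustrative):
//   %s = vector.splat %x : vector<4x2xf32>
//   %t = vector.transpose %s, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// is rewritten to
//   %t = vector.splat %x : vector<2x4xf32>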
6319 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6320 public:
6322 
6323  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6324  PatternRewriter &rewriter) const override {
6325  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6326  if (!splatOp)
6327  return failure();
6328 
6329  rewriter.replaceOpWithNewOp<vector::SplatOp>(
6330  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6331  return success();
6332  }
6333 };
6334 
6335 /// Folds transpose(create_mask) into a new transposed create_mask.
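///
/// Illustrative sketch (operand names assumed):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// is rewritten to
///   %t = vector.create_mask %b, %a : vector<8x4xi1>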
6336 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6337 public:
6339 
6340  LogicalResult matchAndRewrite(TransposeOp transpOp,
6341  PatternRewriter &rewriter) const override {
6342  Value transposeSrc = transpOp.getVector();
6343  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6344  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6345  if (!createMaskOp && !constantMaskOp)
6346  return failure();
6347 
6348  // Get the transpose permutation and apply it to the vector.create_mask or
6349  // vector.constant_mask operands.
6350  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6351 
6352  if (createMaskOp) {
6353  auto maskOperands = createMaskOp.getOperands();
6354  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6355  applyPermutationToVector(newOperands, permutation);
6356 
6357  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6358  transpOp, transpOp.getResultVectorType(), newOperands);
6359  return success();
6360  }
6361 
6362  // ConstantMaskOp case.
6363  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6364  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6365 
6366  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6367  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6368  return success();
6369  }
6370 };
6371 
6372 /// Folds transpose(shape_cast) into a new shape_cast.
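///
/// Illustrative sketch: when the transpose only moves unit dims, e.g.
///   %sc = vector.shape_cast %v : vector<4xf32> to vector<4x1xf32>
///   %t = vector.transpose %sc, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
/// the pair is rewritten to
///   %t = vector.shape_cast %v : vector<4xf32> to vector<1x4xf32>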
6373 class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6374 public:
6376 
6377  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6378  PatternRewriter &rewriter) const override {
6379  auto shapeCastOp =
6380  transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6381  if (!shapeCastOp)
6382  return failure();
6383  if (!isOrderPreserving(transposeOp))
6384  return failure();
6385 
6386  VectorType resultType = transposeOp.getType();
6387 
6388  // We don't need to check isValidShapeCast at this point, because it is
6389  // guaranteed that merging the transpose into the the shape_cast is a valid
6390  // shape_cast, because the transpose just inserts/removes ones.
6391 
6392  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6393  shapeCastOp.getSource());
6394  return success();
6395  }
6396 };
6397 
6398 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6399 /// 'order preserving', where 'order preserving' means the flattened
6400 /// inputs and outputs of the transpose have identical (numerical) values.
6401 ///
6402 /// Example:
6403 /// ```
6404 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6405 /// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6406 /// to vector<8x1xi32>
6407 /// ```
6408 /// can be rewritten as the equivalent
6409 /// ```
6410 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6411 /// ```
6412 /// The algorithm works by partitioning dimensions into groups that can be
6413 /// locally permuted while preserving order, and checks that the transpose
6414 /// only permutes within these groups.
6415 ///
6416 /// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6417 /// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6418 /// broadcasting from 1x1x4x1x1x7.
6419 ///                   ^^^ ^ ^^^ ^
6420 ///          groups:   0  1  2  3
6421 /// Order preserving permutations for this example are ones that only permute
6422 /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6423 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6424 public:
6426  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6427  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6428 
6429  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6430  PatternRewriter &rewriter) const override {
6431 
6432  vector::BroadcastOp broadcast =
6433  transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6434  if (!broadcast) {
6435  return rewriter.notifyMatchFailure(transpose,
6436  "not preceded by a broadcast");
6437  }
6438 
6439  auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6440  VectorType outputType = transpose.getResultVectorType();
6441 
6442  // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6443  bool inputIsScalar = !inputType;
6444  if (inputIsScalar) {
6445  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6446  broadcast.getSource());
6447  return success();
6448  }
6449 
6450  ArrayRef<int64_t> permutation = transpose.getPermutation();
6451  ArrayRef<int64_t> inputShape = inputType.getShape();
6452  int64_t inputRank = inputType.getRank();
6453  int64_t outputRank = transpose.getType().getRank();
6454  int64_t deltaRank = outputRank - inputRank;
6455 
6456  int low = 0;
6457  for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6458  bool notOne = inputShape[inputIndex] != 1;
6459  bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6460  bool groupEndFound = notOne || prevNotOne;
6461  if (groupEndFound) {
6462  int high = inputIndex + deltaRank;
6463  // Return failure if not all permutation destinations for indices in
6464  // [low, high) are in [low, high), i.e. the permutation is not local to
6465  // the group.
6466  for (int i = low; i < high; ++i) {
6467  if (permutation[i] < low || permutation[i] >= high) {
6468  return rewriter.notifyMatchFailure(
6469  transpose, "permutation not local to group");
6470  }
6471  }
6472  low = high;
6473  }
6474  }
6475 
6476  // We don't need to check the final group [low, outputRank) because if it is
6477  // not locally bound, there must be a preceding group that already failed
6478  // the check (impossible to have just 1 non-locally bound group).
6479 
6480  // The preceding logic also ensures that at this point, the output of the
6481  // transpose is definitely broadcastable from the input shape, assert so:
6482  assert(vector::isBroadcastableTo(inputType, outputType) ==
6483  vector::BroadcastableToResult::Success &&
6484  "not broadcastable directly to transpose output");
6485 
6486  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6487  broadcast.getSource());
6488 
6489  return success();
6490  }
6491 };
6492 
6493 } // namespace
6494 
6495 void vector::TransposeOp::getCanonicalizationPatterns(
6496  RewritePatternSet &results, MLIRContext *context) {
6497  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6498  FoldTransposeSplat, FoldTransposeBroadcast>(context);
6499 }
6500 
6501 //===----------------------------------------------------------------------===//
6502 // ConstantMaskOp
6503 //===----------------------------------------------------------------------===//
6504 
6505 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6506  VectorType type, ConstantMaskKind kind) {
6507  assert(kind == ConstantMaskKind::AllTrue ||
6508  kind == ConstantMaskKind::AllFalse);
6509  build(builder, result, type,
6510  kind == ConstantMaskKind::AllTrue
6511  ? type.getShape()
6512  : SmallVector<int64_t>(type.getRank(), 0));
6513 }
6514 
6515 LogicalResult ConstantMaskOp::verify() {
6516  auto resultType = llvm::cast<VectorType>(getResult().getType());
6517  // Check the corner case of 0-D vectors first.
6518  if (resultType.getRank() == 0) {
6519  if (getMaskDimSizes().size() != 1)
6520  return emitError("array attr must have length 1 for 0-D vectors");
6521  auto dim = getMaskDimSizes()[0];
6522  if (dim != 0 && dim != 1)
6523  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6524  return success();
6525  }
6526 
6527  // Verify that array attr size matches the rank of the vector result.
6528  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6529  return emitOpError(
6530  "must specify array attr of size equal vector result rank");
6531  // Verify that each array attr element is in bounds of corresponding vector
6532  // result dimension size.
6533  auto resultShape = resultType.getShape();
6534  auto resultScalableDims = resultType.getScalableDims();
6535  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6536  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6537  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6538  return emitOpError(
6539  "array attr of size out of bounds of vector result dimension size");
6540  if (resultScalableDims[index] && maskDimSize != 0 &&
6541  maskDimSize != resultShape[index])
6542  return emitOpError(
6543  "only supports 'none set' or 'all set' scalable dimensions");
6544  }
6545  // Verify that if one mask dim size is zero, they all should be zero (because
6546  // the mask region is a conjunction of each mask dimension interval).
6547  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6548  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6549  if (anyZeros && !allZeros)
6550  return emitOpError("expected all mask dim sizes to be zeros, "
6551  "as a result of conjunction with zero mask dim");
6552  return success();
6553 }
6554 
6555 bool ConstantMaskOp::isAllOnesMask() {
6556  auto resultType = getVectorType();
6557  // Check the corner case of 0-D vectors first.
6558  if (resultType.getRank() == 0) {
6559  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6560  return getMaskDimSizes()[0] == 1;
6561  }
6562  for (const auto [resultSize, maskDimSize] :
6563  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6564  if (maskDimSize < resultSize)
6565  return false;
6566  }
6567  return true;
6568 }
6569 
6570 //===----------------------------------------------------------------------===//
6571 // CreateMaskOp
6572 //===----------------------------------------------------------------------===//
6573 
6574 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6575  VectorType type,
6576  ArrayRef<OpFoldResult> mixedOperands) {
6577  SmallVector<Value> operands =
6578  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6579  build(builder, result, type, operands);
6580 }
6581 
6582 LogicalResult CreateMaskOp::verify() {
6583  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6584  // Verify that an operand was specified for each result vector dimension.
6585  if (vectorType.getRank() == 0) {
6586  if (getNumOperands() != 1)
6587  return emitOpError(
6588  "must specify exactly one operand for 0-D create_mask");
6589  } else if (getNumOperands() !=
6590  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6591  return emitOpError(
6592  "must specify an operand for each result vector dimension");
6593  }
6594  return success();
6595 }
6596 
6597 namespace {
6598 
6599 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6600 ///
6601 /// Ex 1:
6602 /// %c2 = arith.constant 2 : index
6603 /// %c3 = arith.constant 3 : index
6604 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6605 /// Becomes:
6606 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6607 ///
6608 /// Ex 2:
6609 /// %c_neg_1 = arith.constant -1 : index
6610 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6611 /// becomes:
6612 /// vector.constant_mask [0] : vector<[8]xi1>
6613 ///
6614 /// Ex 3:
6615 /// %c8 = arith.constant 8 : index
6616 /// %c16 = arith.constant 16 : index
6617 /// %0 = vector.vscale
6618 /// %1 = arith.muli %0, %c16 : index
6619 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6620 /// becomes:
6621 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6622 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6623 public:
6625 
6626  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6627  PatternRewriter &rewriter) const override {
6628  VectorType maskType = createMaskOp.getVectorType();
6629  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6630  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6631 
6632  // Special case: Rank zero shape.
6633  constexpr std::array<int64_t, 1> rankZeroShape{1};
6634  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6635  if (maskType.getRank() == 0) {
6636  maskTypeDimSizes = rankZeroShape;
6637  maskTypeDimScalableFlags = rankZeroScalableDims;
6638  }
6639 
6640  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6641  // collect the `constantDims` (for the ConstantMaskOp).
6642  SmallVector<int64_t, 4> constantDims;
6643  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6644  if (auto intSize = getConstantIntValue(dimSize)) {
6645  // Constant value.
6646  // If the mask dim is non-scalable this can be any value.
6647  // If the mask dim is scalable only zero (all-false) is supported.
6648  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6649  return failure();
6650  constantDims.push_back(*intSize);
6651  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6652  // Constant vscale multiple (e.g. 4 x vscale).
6653  // Must be all-true to fold to a ConstantMask.
6654  if (vscaleMultiplier < maskTypeDimSizes[i])
6655  return failure();
6656  constantDims.push_back(*vscaleMultiplier);
6657  } else {
6658  return failure();
6659  }
6660  }
6661 
6662  // Clamp values to constant_mask bounds.
6663  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6664  value = std::clamp<int64_t>(value, 0, maskDimSize);
6665 
6666  // If one of dim sizes is zero, set all dims to zero.
6667  if (llvm::is_contained(constantDims, 0))
6668  constantDims.assign(constantDims.size(), 0);
6669 
6670  // Replace 'createMaskOp' with ConstantMaskOp.
6671  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6672  constantDims);
6673  return success();
6674  }
6675 };
6676 
6677 } // namespace
6678 
6679 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6680  MLIRContext *context) {
6681  results.add<CreateMaskFolder>(context);
6682 }
6683 
6684 //===----------------------------------------------------------------------===//
6685 // MaskOp
6686 //===----------------------------------------------------------------------===//
6687 
6688 void MaskOp::build(
6689  OpBuilder &builder, OperationState &result, Value mask,
6690  Operation *maskableOp,
6691  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6692  assert(maskRegionBuilder &&
6693  "builder callback for 'maskRegion' must be present");
6694 
6695  result.addOperands(mask);
6696  OpBuilder::InsertionGuard guard(builder);
6697  Region *maskRegion = result.addRegion();
6698  builder.createBlock(maskRegion);
6699  maskRegionBuilder(builder, maskableOp);
6700 }
6701 
6702 void MaskOp::build(
6703  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6704  Value mask, Operation *maskableOp,
6705  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6706  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6707  maskRegionBuilder);
6708 }
6709 
6710 void MaskOp::build(
6711  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6712  Value mask, Value passthru, Operation *maskableOp,
6713  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6714  build(builder, result, mask, maskableOp, maskRegionBuilder);
6715  if (passthru)
6716  result.addOperands(passthru);
6717  result.addTypes(resultTypes);
6718 }
6719 
6720 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6721  // Create the op region.
6722  result.regions.reserve(1);
6723  Region &maskRegion = *result.addRegion();
6724 
6725  auto &builder = parser.getBuilder();
6726 
6727  // Parse all the operands.
6728  OpAsmParser::UnresolvedOperand mask;
6729  if (parser.parseOperand(mask))
6730  return failure();
6731 
6732  // Optional passthru operand.
6733  OpAsmParser::UnresolvedOperand passthru;
6734  ParseResult parsePassthru = parser.parseOptionalComma();
6735  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6736  return failure();
6737 
6738  // Parse op region.
6739  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6740  return failure();
6741 
6742  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6743 
6744  // Parse the optional attribute list.
6745  if (parser.parseOptionalAttrDict(result.attributes))
6746  return failure();
6747 
6748  // Parse all the types.
6749  Type maskType;
6750  if (parser.parseColonType(maskType))
6751  return failure();
6752 
6753  SmallVector<Type> resultTypes;
6754  if (parser.parseOptionalArrowTypeList(resultTypes))
6755  return failure();
6756  result.types.append(resultTypes);
6757 
6758  // Resolve operands.
6759  if (parser.resolveOperand(mask, maskType, result.operands))
6760  return failure();
6761 
6762  if (parsePassthru.succeeded()) {
6763  if (resultTypes.empty())
6764  return parser.emitError(
6765  parser.getNameLoc(),
6766  "expects a result if passthru operand is provided");
6767 
6768  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6769  return failure();
6770  }
6771 
6772  return success();
6773 }
6774 
6775 void MaskOp::print(OpAsmPrinter &p) {
6776  p << " " << getMask();
6777  if (getPassthru())
6778  p << ", " << getPassthru();
6779 
6780  // Print single masked operation and skip terminator.
6781  p << " { ";
6782  Block *singleBlock = &getMaskRegion().getBlocks().front();
6783  if (singleBlock && !singleBlock->getOperations().empty())
6784  p.printCustomOrGenericOp(&singleBlock->front());
6785  p << " }";
6786 
6787  p.printOptionalAttrDict(getOperation()->getAttrs());
6788 
6789  p << " : " << getMask().getType();
6790  if (getNumResults() > 0)
6791  p << " -> " << getResultTypes();
6792 }
6793 
6794 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6795  // 1. For an empty `vector.mask`, create a default terminator.
6796  if (region.empty() || region.front().empty()) {
6797  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6798  MaskOp>::ensureTerminator(region, builder, loc);
6799  return;
6800  }
6801 
6802  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
6803  Block &block = region.front();
6804  if (isa<vector::YieldOp>(block.back()))
6805  return;
6806 
6807  // 3. For a non-empty `vector.mask` without an explicit terminator:
6808 
6809  // Create default terminator if the number of masked operations is not
6810  // one. This case will trigger a verification failure.
6811  if (block.getOperations().size() != 1) {
6812  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6813  MaskOp>::ensureTerminator(region, builder, loc);
6814  return;
6815  }
6816 
6817  // Create a terminator that yields the results from the masked operation.
6818  OpBuilder opBuilder(builder.getContext());
6819  Operation *maskedOp = &block.front();
6820  opBuilder.setInsertionPointToEnd(&block);
6821  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6822 }
6823 
6824 LogicalResult MaskOp::verify() {
6825  // Structural checks.
6826  Block &block = getMaskRegion().getBlocks().front();
6827  if (block.getOperations().empty())
6828  return emitOpError("expects a terminator within the mask region");
6829 
6830  unsigned numMaskRegionOps = block.getOperations().size();
6831  if (numMaskRegionOps > 2)
6832  return emitOpError("expects only one operation to mask");
6833 
6834  // Terminator checks.
6835  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6836  if (!terminator)
6837  return emitOpError("expects a terminator within the mask region");
6838 
6839  if (terminator->getNumOperands() != getNumResults())
6840  return emitOpError(
6841  "expects number of results to match mask region yielded values");
6842 
6843  // Empty vector.mask. Nothing else to check.
6844  if (numMaskRegionOps == 1)
6845  return success();
6846 
6847  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6848  if (!maskableOp)
6849  return emitOpError("expects a MaskableOpInterface within the mask region");
6850 
6851  // Result checks.
6852  if (maskableOp->getNumResults() != getNumResults())
6853  return emitOpError("expects number of results to match maskable operation "
6854  "number of results");
6855 
6856  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
6857  return emitOpError("expects all the results from the MaskableOpInterface "
6858  "to match all the values returned by the terminator");
6859 
6860  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6861  return emitOpError(
6862  "expects result type to match maskable operation result type");
6863 
6864  if (llvm::count_if(maskableOp->getResultTypes(),
6865  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6866  return emitOpError("multiple vector results not supported");
6867 
6868  // Mask checks.
6869  Type expectedMaskType = maskableOp.getExpectedMaskType();
6870  if (getMask().getType() != expectedMaskType)
6871  return emitOpError("expects a ")
6872  << expectedMaskType << " mask for the maskable operation";
6873 
6874  // Passthru checks.
6875  Value passthru = getPassthru();
6876  if (passthru) {
6877  if (!maskableOp.supportsPassthru())
6878  return emitOpError(
6879  "doesn't expect a passthru argument for this maskable operation");
6880 
6881  if (maskableOp->getNumResults() != 1)
6882  return emitOpError("expects result when passthru argument is provided");
6883 
6884  if (passthru.getType() != maskableOp->getResultTypes()[0])
6885  return emitOpError("expects passthru type to match result type");
6886  }
6887 
6888  return success();
6889 }
6890 
6891 /// Folds empty `vector.mask` with no passthru operand and with or without
6892 /// return values. For example:
6893 ///
6894 /// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6895 /// vector<8xi1> -> vector<8xf32>
6896 /// %1 = user_op %0 : vector<8xf32>
6897 ///
6898 /// becomes:
6899 ///
6900 /// %0 = user_op %a : vector<8xf32>
6901 ///
6902 /// Empty `vector.mask` ops with a passthru operand are handled by the
6903 /// canonicalizer, as they require creating new operations.
6904 
6905 static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6906  SmallVectorImpl<OpFoldResult> &results) {
6907  if (!maskOp.isEmpty() || maskOp.hasPassthru())
6908  return failure();
6909 
6910  Block *block = maskOp.getMaskBlock();
6911  auto terminator = cast<vector::YieldOp>(block->front());
6912  if (terminator.getNumOperands() == 0) {
6913  // `vector.mask` has no results, just remove the `vector.mask`.
6914  return success();
6915  }
6916 
6917  // `vector.mask` has results, propagate the results.
6918  llvm::append_range(results, terminator.getOperands());
6919  return success();
6920 }
6921 
6922 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6923  SmallVectorImpl<OpFoldResult> &results) {
6924  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
6925  return success();
6926 
6927  MaskFormat maskFormat = getMaskFormat(getMask());
6928  if (maskFormat != MaskFormat::AllTrue)
6929  return failure();
6930 
6931  // Move maskable operation outside of the `vector.mask` region.
6932  Operation *maskableOp = getMaskableOp();
6933  maskableOp->dropAllUses();
6934  maskableOp->moveBefore(getOperation());
6935 
6936  llvm::append_range(results, maskableOp->getResults());
6937  return success();
6938 }
6939 
6940 /// Canonicalize empty `vector.mask` operations that can't be handled in
6941 /// `MaskOp::fold` as they require creating new operations.
6942 ///
6943 /// Example 1: Empty `vector.mask` with passthru operand.
6944 ///
6945 /// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
6946 /// vector<8xi1> -> vector<8xf32>
6947 ///
6948 /// becomes:
6949 ///
6950 /// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
6951 ///
6952 class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
6954 
6955  LogicalResult matchAndRewrite(MaskOp maskOp,
6956  PatternRewriter &rewriter) const override {
6957  if (!maskOp.isEmpty())
6958  return failure();
6959 
6960  if (!maskOp.hasPassthru())
6961  return failure();
6962 
6963  Block *block = maskOp.getMaskBlock();
6964  auto terminator = cast<vector::YieldOp>(block->front());
6965  assert(terminator.getNumOperands() == 1 &&
6966  "expected one result when passthru is provided");
6967 
6968  rewriter.replaceOpWithNewOp<arith::SelectOp>(
6969  maskOp, maskOp.getResultTypes(), maskOp.getMask(),
6970  terminator.getOperand(0), maskOp.getPassthru());
6971 
6972  return success();
6973  }
6974 };
6975 
6976 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6977  MLIRContext *context) {
6978  results.add<CanonializeEmptyMaskOp>(context);
6979 }
6980 
6981 // MaskingOpInterface definitions.
6982 
6983 /// Returns the operation masked by this 'vector.mask'.
6984 Operation *MaskOp::getMaskableOp() {
6985  Block *block = getMaskBlock();
6986  if (block->getOperations().size() < 2)
6987  return nullptr;
6988 
6989  return &block->front();
6990 }
6991 
6992 /// Returns true if 'vector.mask' has a passthru value.
6993 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6994 
6995 //===----------------------------------------------------------------------===//
6996 // ScanOp
6997 //===----------------------------------------------------------------------===//
6998 
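// A well-formed example of the op verified below (illustrative only, names
// assumed):
//   %res:2 = vector.scan <add>, %src, %init {inclusive = true, reduction_dim = 1}
//       : vector<4x8x16xf32>, vector<4x16xf32>
// i.e. the initial value has the source shape with the reduction dim removed.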
6999 LogicalResult ScanOp::verify() {
7000  VectorType srcType = getSourceType();
7001  VectorType initialType = getInitialValueType();
7002  // Check reduction dimension < rank.
7003  int64_t srcRank = srcType.getRank();
7004  int64_t reductionDim = getReductionDim();
7005  if (reductionDim >= srcRank)
7006  return emitOpError("reduction dimension ")
7007  << reductionDim << " has to be less than " << srcRank;
7008 
7009  // Check that rank(initial_value) = rank(src) - 1.
7010  int64_t initialValueRank = initialType.getRank();
7011  if (initialValueRank != srcRank - 1)
7012  return emitOpError("initial value rank ")
7013  << initialValueRank << " has to be equal to " << srcRank - 1;
7014 
7015  // Check shapes of initial value and src.
7016  ArrayRef<int64_t> srcShape = srcType.getShape();
7017  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7018  SmallVector<int64_t> expectedShape;
7019  for (int i = 0; i < srcRank; i++) {
7020  if (i != reductionDim)
7021  expectedShape.push_back(srcShape[i]);
7022  }
7023  if (!llvm::equal(initialValueShapes, expectedShape)) {
7024  return emitOpError("incompatible input/initial value shapes");
7025  }
7026 
7027  // Verify supported reduction kind.
7028  Type eltType = getDestType().getElementType();
7029  if (!isSupportedCombiningKind(getKind(), eltType))
7030  return emitOpError("unsupported reduction type ")
7031  << eltType << " for kind '" << stringifyCombiningKind(getKind())
7032  << "'";
7033 
7034  return success();
7035 }
7036 
7037 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7038  RewritePatternSet &patterns, PatternBenefit benefit) {
7039  patterns
7040  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7041  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7042  StridedSliceConstantMaskFolder, TransposeFolder>(
7043  patterns.getContext(), benefit);
7044 }
7045 
7046 //===----------------------------------------------------------------------===//
7047 // SplatOp
7048 //===----------------------------------------------------------------------===//
7049 
7050 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
7051  auto constOperand = adaptor.getInput();
7052  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
7053  return {};
7054 
7055  // SplatElementsAttr::get treats a single value for the second arg as a splat.
7056  return SplatElementsAttr::get(getType(), {constOperand});
7057 }
7058 
7059 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7060  SetIntRangeFn setResultRanges) {
7061  setResultRanges(getResult(), argRanges.front());
7062 }
7063 
7064 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7065  CombiningKind kind, Value v1, Value acc,
7066  arith::FastMathFlagsAttr fastmath,
7067  Value mask) {
7068  Type t1 = getElementTypeOrSelf(v1.getType());
7069  Type tAcc = getElementTypeOrSelf(acc.getType());
7070  Value result;
7071 
7072  switch (kind) {
7073  case CombiningKind::ADD:
7074  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7075  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7076  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7077  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7078  else
7079  llvm_unreachable("invalid value types for ADD reduction");
7080  break;
7081  case CombiningKind::AND:
7082  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7083  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7084  break;
7085  case CombiningKind::MAXNUMF:
7086  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7087  "expected float values");
7088  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7089  break;
7090  case CombiningKind::MAXIMUMF:
7091  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7092  "expected float values");
7093  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7094  break;
7095  case CombiningKind::MINNUMF:
7096  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7097  "expected float values");
7098  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7099  break;
7100  case CombiningKind::MINIMUMF:
7101  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7102  "expected float values");
7103  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7104  break;
7105  case CombiningKind::MAXSI:
7106  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7107  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7108  break;
7109  case CombiningKind::MINSI:
7110  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7111  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7112  break;
7113  case CombiningKind::MAXUI:
7114  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7115  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7116  break;
7117  case CombiningKind::MINUI:
7118  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7119  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7120  break;
7121  case CombiningKind::MUL:
7122  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7123  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7124  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7125  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7126  else
7127  llvm_unreachable("invalid value types for MUL reduction");
7128  break;
7129  case CombiningKind::OR:
7130  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7131  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7132  break;
7133  case CombiningKind::XOR:
7134  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7135  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7136  break;
7137  }
7138 
7139  assert(result && "unknown CombiningKind");
7140  return selectPassthru(b, mask, result, acc);
7141 }
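// Hedged usage sketch (editor's addition; `b`, `loc`, `partial`, `acc`, and
// `mask` are assumed to be in scope): combine a partial result with an
// accumulator using a masked ADD reduction.
//
//   Value combined = mlir::vector::makeArithReduction(
//       b, loc, mlir::vector::CombiningKind::ADD, partial, acc,
//       /*fastmath=*/nullptr, /*mask=*/mask);
//
// With a null mask the combined value is returned directly; otherwise
// selectPassthru wraps it in an arith.select that keeps `acc` in the
// masked-out lanes.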
7142 
7143 //===----------------------------------------------------------------------===//
7144 // Vector Masking Utilities
7145 //===----------------------------------------------------------------------===//
7146 
7147 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7148 /// as the masked operation.
7149 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7150  Operation *maskableOp) {
7151  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7152  Block *insBlock = builder.getInsertionBlock();
7153  // Create a block and move the op to that block.
7154  insBlock->getOperations().splice(
7155  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7156  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
7157 }
7158 
7159 /// Creates a vector.mask operation around a maskable operation. Returns the
7160 /// vector.mask operation if a mask is provided; otherwise, returns the
7161 /// maskable operation itself.
7162 Operation *mlir::vector::maskOperation(OpBuilder &builder,
7163  Operation *maskableOp, Value mask,
7164  Value passthru) {
7165  if (!mask)
7166  return maskableOp;
7167  if (passthru)
7168  return builder.create<MaskOp>(maskableOp->getLoc(),
7169  maskableOp->getResultTypes(), mask, passthru,
7170  maskableOp, createMaskOpRegion);
7171  return builder.create<MaskOp>(maskableOp->getLoc(),
7172  maskableOp->getResultTypes(), mask, maskableOp,
7173  createMaskOpRegion);
7174 }
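// Hedged usage sketch (editor's addition; `builder`, `mask`, and `op`, an
// `Operation *` expected to be maskable, are assumed to be in scope):
//
//   Operation *masked = mlir::vector::maskOperation(builder, op, mask);
//
// If `mask` is null, `op` is returned unchanged; supplying the optional
// `passthru` argument creates the vector.mask with a pass-thru operand
// instead.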
7175 
7176 /// Creates a vector select operation that picks values from `newValue` or
7177 /// `passthru` for each result vector lane based on `mask`. This utility is used
7178 /// to propagate the pass-thru value of vector.mask or for cases where only the
7179 /// pass-thru value propagation is needed. VP intrinsics do not support
7180 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
7181 /// are usually able to match op + select patterns and fold them into native
7182 /// target instructions.
7183 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7184  Value newValue, Value passthru) {
7185  if (!mask)
7186  return newValue;
7187 
7188  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
7189  mask, newValue, passthru);
7190 }
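// Illustrative result (editor's sketch): for a vector mask the utility emits
// something like
//   %r = arith.select %mask, %new, %passthru : vector<8xi1>, vector<8xf32>
// so that masked-out lanes keep the pass-thru value.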
7191 
7192 //===----------------------------------------------------------------------===//
7193 // TableGen'd op method definitions
7194 //===----------------------------------------------------------------------===//
7195 
7196 #define GET_ATTRDEF_CLASSES
7197 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7198 
7199 #define GET_OP_CLASSES
7200 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"