//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
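///
/// Illustrative classifications (a sketch, not exhaustive):
///   - arith.constant dense<true>  : vector<4xi1>  --> AllTrue
///   - arith.constant dense<false> : vector<4xi1>  --> AllFalse
///   - vector.constant_mask [0]    : vector<4xi1>  --> AllFalse
///   - any mask whose bits cannot be proven uniform --> Unknown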
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<vector::YieldOp>(loc);
}

// Helper for verifying combining kinds in contractions and reductions.
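// For instance, ADD and MUL are accepted for int, index, and float element
// types; the bitwise and signed/unsigned min/max kinds require int or index;
// the float min/max kinds require a float element type.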
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}

/// Returns the effective rank of the vector to read/write for Xfer Ops
///
/// When the element type of the shaped type is _a scalar_, this will simply
/// return the rank of the vector (the result for xfer_read or the value to
/// store for xfer_write).
///
/// When the element type of the base shaped type is _a vector_, returns the
/// difference between the original vector type and the element type of the
/// shaped type.
///
/// EXAMPLE 1 (element type is _a scalar_):
/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
/// - shapedType.getElementType() = f32 (rank 0)
/// - vectorType.getRank() = 2
/// - Result = 2 - 0 = 2
///
/// EXAMPLE 2 (element type is _a vector_):
/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
/// - shapedType.getElementType() = vector<20xf32> (rank 1)
/// - vectorType.getRank() = 1
/// - Result = 1 - 1 = 0
///
/// This is used to determine the number of minor dimensions for identity maps
/// in vector transfer Ops.
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(),
      getEffectiveVectorRankForXferOp(shapedType, vectorType),
      shapedType.getContext());
}

/// Check if `write` is of a constant splat and the masked `read` is padded with
/// the same splat value -- meaning it could be the same value as the initial
/// constant splat.
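///
/// Illustrative IR (a simplified sketch, not the only shape this matches):
///
/// ```mlir
/// %cst = arith.constant dense<0.0> : vector<4xf32>
/// %w = vector.transfer_write %cst, %t[%i], %mask
///        : vector<4xf32>, tensor<8xf32>
/// %r = vector.transfer_read %w[%i], %pad_zero, %mask
///        : tensor<8xf32>, vector<4xf32>
/// ```
///
/// Here the masked-off lanes of the read are padded with the same constant the
/// write splats, so the read may produce the original splat value.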
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // Check if the masks are consistent. The splat value could be the same if the
  // read is masked (and padded with the splat value), and the write is unmasked
  // or has the same mask. Note this does not allow the case where the write is
  // masked and the read is unmasked, as then the read could be of more elements
  // than the write (which may not be the same value).
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for constant splat (as the source of the write).
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat()) {
    return false;
  }
  // The padding of the read and the constant splat value must be the same.
  Attribute padAttr;
  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
    return false;
  return padAttr == splatAttr.getSplatValue<Attribute>();
}

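/// Return true (conservatively) when the value read by `read` is the same as
/// the value written by `defWrite`: same indices, vector type and permutation
/// map, no out-of-bounds dims, and either both accesses are unmasked or the
/// masked read is known to observe the written splat value.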
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are different
      // we know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}

// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
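// E.g. with shape = [2, 3] and offsets = [0, 0], successive calls step the
// position through [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]; the next call
// wraps past the end and returns failure.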
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();

    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }

  return failure();
}

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return cast<Value>(foldResult);
                  });
  return values;
}

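/// Return the constant multiplier of vector.vscale in `value`: 1 when `value`
/// is vector.vscale itself, or the constant operand when `value` is an
/// arith.muli of vector.vscale and a constant; std::nullopt otherwise.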
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
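//
// For example (illustrative):
//   %r = vector.multi_reduction <add>, %src, %acc [0]
//          : vector<1x4xf32> to vector<4xf32>
// reduces only a unit dimension, so it can be rewritten as a shape_cast of
// %src to vector<4xf32> combined with %acc via the ADD kind.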
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = rewriter.create<ExtractOp>(loc, mask);
    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(parser.getBuilder().getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
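///
/// For example (an illustrative sketch, assuming the canonical matmul maps
/// (m, k) and (k, n) for lhs and rhs): with lhs : vector<4x8xf32> and
/// rhs : vector<8x16xf32>, the expected mask type is vector<4x16x8xi1>,
/// i.e. the mask is shaped by the iteration space (m, n, k) rather than by
/// the result.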
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
/// %c0 = vector.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast(X)) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}

// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source) {
  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

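/// Infer the result type of an extract: extracting with k positions from a
/// rank-n source drops the k leading dimensions (yielding the element type
/// when k == n). For example, extracting at two positions from
/// vector<2x3x4xf32> yields vector<4xf32>.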
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
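///
/// For example (illustrative):
///   %0 = vector.extract %v[0] : vector<3x4xf32> from vector<2x3x4xf32>
///   %1 = vector.extract %0[1] : vector<4xf32> from vector<3x4xf32>
/// folds in place to:
///   %1 = vector.extract %v[0, 1] : vector<4xf32> from vector<2x3x4xf32>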
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}

1449 namespace {
1450 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1451 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1452 /// Compose TransposeOp permutations as we walk back.
1453 /// This helper class keeps an updated extraction position `extractPosition`
1454 /// with extra trailing sentinels.
1455 /// The sentinels encode the internal transposition status of the result vector.
1456 /// As we iterate, extractPosition is permuted and updated.
1457 class ExtractFromInsertTransposeChainState {
1458 public:
1459  ExtractFromInsertTransposeChainState(ExtractOp e);
1460 
1461  /// Iterate over producing insert and transpose ops until we find a fold.
1462  Value fold();
1463 
1464 private:
1465  /// Return true if the vector at position `a` is contained within the vector
1466  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1467  /// is a prefix of `b`.
1468  template <typename ContainerA, typename ContainerB>
1469  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1470  return a.size() <= b.size() &&
1471  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1472  }
1473 
1474  /// Return true if the vector at position `a` intersects the vector at
1475  /// position `b`. Under insert/extract semantics, this is the same as equality
1476  /// of all entries of `a` that are >=0 with the corresponding entries of `b`.
1477  /// Comparison is on the common prefix (i.e. zip).
1478  template <typename ContainerA, typename ContainerB>
1479  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1480  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1481  if (elemA < 0 || elemB < 0)
1482  continue;
1483  if (elemA != elemB)
1484  return false;
1485  }
1486  return true;
1487  }
1488 
1489  /// Folding is only possible in the absence of an internal permutation in the
1490  /// result vector.
1491  bool canFold() {
1492  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1493  }
1494 
1495  // Helper to get the next defining op of interest.
1496  void updateStateForNextIteration(Value v) {
1497  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1498  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1499  };
1500 
1501  // Case 1. If we hit a transpose, just compose the map and iterate.
1502  // Invariant: insert + transpose do not change rank, we can always compose.
1503  LogicalResult handleTransposeOp();
1504 
1505  // Case 2: the insert position matches extractPosition exactly, early return.
1506  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1507 
1508  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1509  /// portion of the source of the insert.
1510  /// Example:
1511  /// ```
1512  /// %ins = vector.insert %source, %dest[1] : vector<3x4x5> into vector<2x3x4x5>
1513  /// // extractPosition == [1, 2, 3]
1514  /// %ext = vector.extract %ins[1, 2, 3] : vector<5> from vector<2x3x4x5>
1515  /// // can fold to:
1516  /// %ext = vector.extract %source[2, 3] : vector<5> from vector<3x4x5>
1517  /// ```
1518  /// To traverse through %source, we need to set the leading dims to 0 and
1519  /// drop the extra leading dims.
1520  /// This method updates the internal state.
1521  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1522 
1523  /// Try to fold in place to extract(source, extractPosition) and return the
1524  /// folded result. Return null if folding is not possible (e.g. due to an
1525  /// internal transposition in the result).
1526  Value tryToFoldExtractOpInPlace(Value source);
1527 
1528  ExtractOp extractOp;
1529  int64_t vectorRank;
1530  int64_t extractedRank;
1531 
1532  InsertOp nextInsertOp;
1533  TransposeOp nextTransposeOp;
1534 
1535  /// Sentinel values that encode the internal permutation status of the result.
1536  /// They are set to (-1, ... , -k) at the beginning and appended to
1537  /// `extractPosition`.
1538  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1539  /// ensure that there is no internal transposition.
1540  /// Internal transposition cannot be accounted for with a folding pattern.
1541  // TODO: We could relax the internal transposition with an extra transposition
1542  // operation in a future canonicalizer.
1543  SmallVector<int64_t> sentinels;
1544  SmallVector<int64_t> extractPosition;
1545 };
1546 } // namespace
1547 
1548 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1549  ExtractOp e)
1550  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1551  extractedRank(extractOp.getNumIndices()) {
1552  assert(vectorRank >= extractedRank && "Extracted position overflow");
1553  sentinels.reserve(vectorRank - extractedRank);
1554  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1555  sentinels.push_back(-(i + 1));
1556  extractPosition.assign(extractOp.getStaticPosition().begin(),
1557  extractOp.getStaticPosition().end());
1558  llvm::append_range(extractPosition, sentinels);
1559 }
1560 
1561 // Case 1. If we hit a transpose, just compose the map and iterate.
1562 // Invariant: insert + transpose do not change rank, we can always compose.
1563 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1564  // TODO: Canonicalization for dynamic position not implemented yet.
1565  if (extractOp.hasDynamicPosition())
1566  return failure();
1567 
1568  if (!nextTransposeOp)
1569  return failure();
1570  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1571  nextTransposeOp.getPermutation(), extractOp.getContext()));
1572  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1573  return success();
1574 }
1575 
1576 // Case 2: the insert position matches extractPosition exactly, early return.
1577 LogicalResult
1578 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1579  Value &res) {
1580  // TODO: Canonicalization for dynamic position not implemented yet.
1581  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1582  return failure();
1583 
1584  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1585  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1586  return failure();
1587  // Case 2.a. early-exit fold.
1588  res = nextInsertOp.getValueToStore();
1589  // Case 2.b. if internal transposition is present, canFold will be false.
1590  return success(canFold());
1591 }
1592 
1593 /// Case 3: if inserted position is a prefix of extractPosition,
1594 /// extract a portion of the source of the insertion.
1595 /// This method updates the internal state.
1596 LogicalResult
1597 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1598  // TODO: Canonicalization for dynamic position not implemented yet.
1599  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1600  return failure();
1601 
1602  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1603  if (!isContainedWithin(insertedPos, extractPosition))
1604  return failure();
1605  // Set leading dims to zero.
1606  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1607  // Drop extra leading dims.
1608  extractPosition.erase(extractPosition.begin(),
1609  extractPosition.begin() + insertedPos.size());
1610  extractedRank = extractPosition.size() - sentinels.size();
1611  // Case 3.a. early-exit fold (break and delegate to post-while path).
1612  res = nextInsertOp.getValueToStore();
1613  // Case 3.b. if internal transposition is present, canFold will be false.
1614  return success();
1615 }
1616 
1617 /// Try to fold in place to extract(source, extractPosition) and return the
1618 /// folded result. Return null if folding is not possible (e.g. due to an
1619 /// internal transposition in the result).
1620 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1621  Value source) {
1622  // TODO: Canonicalization for dynamic position not implemented yet.
1623  if (extractOp.hasDynamicPosition())
1624  return Value();
1625 
1626  // If we can't fold (either internal transposition, or nothing to fold), bail.
1627  bool nothingToFold = (source == extractOp.getVector());
1628  if (nothingToFold || !canFold())
1629  return Value();
1630 
1631  // Otherwise, fold by updating the op inplace and return its result.
1632  OpBuilder b(extractOp.getContext());
1633  extractOp.setStaticPosition(
1634  ArrayRef(extractPosition).take_front(extractedRank));
1635  extractOp.getVectorMutable().assign(source);
1636  return extractOp.getResult();
1637 }
1638 
1639 /// Iterate over producing insert and transpose ops until we find a fold.
1640 Value ExtractFromInsertTransposeChainState::fold() {
1641  // TODO: Canonicalization for dynamic position not implemented yet.
1642  if (extractOp.hasDynamicPosition())
1643  return Value();
1644 
1645  Value valueToExtractFrom = extractOp.getVector();
1646  updateStateForNextIteration(valueToExtractFrom);
1647  while (nextInsertOp || nextTransposeOp) {
1648  // Case 1. If we hit a transpose, just compose the map and iterate.
1649  // Invariant: insert + transpose do not change rank, we can always compose.
1650  if (succeeded(handleTransposeOp())) {
1651  valueToExtractFrom = nextTransposeOp.getVector();
1652  updateStateForNextIteration(valueToExtractFrom);
1653  continue;
1654  }
1655 
1656  Value result;
1657  // Case 2: the positions match exactly.
1658  if (succeeded(handleInsertOpWithMatchingPos(result)))
1659  return result;
1660 
1661  // Case 3: if the inserted position is a prefix of extractPosition, we can
1662  // just extract a portion of the source of the insert.
1663  if (succeeded(handleInsertOpWithPrefixPos(result)))
1664  return tryToFoldExtractOpInPlace(result);
1665 
1666  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1667  // values. This is a more difficult case and we bail.
1668  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1669  if (isContainedWithin(extractPosition, insertedPos) ||
1670  intersectsWhereNonNegative(extractPosition, insertedPos))
1671  return Value();
1672 
1673  // Case 5: No intersection, we forward the extract to insertOp.dest().
1674  valueToExtractFrom = nextInsertOp.getDest();
1675  updateStateForNextIteration(valueToExtractFrom);
1676  }
1677  // If after all this we can fold, go for it.
1678  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1679 }
1680 
1681 /// Returns true if the operation has a 0-D vector type operand or result.
1682 static bool hasZeroDimVectors(Operation *op) {
1683  auto hasZeroDimVectorType = [](Type type) -> bool {
1684  auto vecType = dyn_cast<VectorType>(type);
1685  return vecType && vecType.getRank() == 0;
1686  };
1687 
1688  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1689  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1690 }
1691 
1692 /// Fold an ExtractOp whose source comes from a BroadcastOp or SplatOp.
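/// Illustrative sketch (editorial addition, hypothetical SSA names):
///   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
///   %e = vector.extract %b[1, 2] : vector<4xf32> from vector<2x3x4xf32>
/// folds to %src, because the extracted slice is one of the duplicated copies
/// of the broadcast source.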
1693 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1694  Operation *defOp = extractOp.getVector().getDefiningOp();
1695  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1696  return Value();
1697 
1698  Value source = defOp->getOperand(0);
1699  if (extractOp.getType() == source.getType())
1700  return source;
1701  auto getRank = [](Type type) {
1702  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1703  : 0;
1704  };
1705 
1706  // If splat or broadcast from a scalar, just return the source scalar.
1707  unsigned broadcastSrcRank = getRank(source.getType());
1708  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1709  return source;
1710 
1711  unsigned extractResultRank = getRank(extractOp.getType());
1712  if (extractResultRank > broadcastSrcRank)
1713  return Value();
1714  // Check that the dimensions of the result haven't been broadcasted.
1715  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1716  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1717  if (extractVecType && broadcastVecType &&
1718  extractVecType.getShape() !=
1719  broadcastVecType.getShape().take_back(extractResultRank))
1720  return Value();
1721 
1722  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1723  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1724 
1725  // Detect all the positions that come from "dim-1" broadcasting.
1726  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1727  // extract position to `0` when extracting from the source operand.
1728  llvm::SetVector<int64_t> broadcastedUnitDims =
1729  broadcastOp.computeBroadcastedUnitDims();
1730  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1731  OpBuilder b(extractOp.getContext());
1732  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1733  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1734  if (broadcastedUnitDims.contains(i))
1735  extractPos[i] = b.getIndexAttr(0);
1736  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1737  // matching extract position when extracting from the source operand.
1738  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1739  extractPos.erase(extractPos.begin(),
1740  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1741  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1742  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1743  extractOp->setOperands(
1744  llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1745  extractOp.setStaticPosition(staticPos);
1746  return extractOp.getResult();
1747 }
1748 
1749 /// Fold extractOp coming from ShuffleOp.
1750 ///
1751 /// Example:
1752 ///
1753 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1754 /// : vector<8xf32>, vector<8xf32>
1755 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1756 /// ->
1757 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1758 ///
1759 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1760  // Dynamic positions are not folded as the resulting code would be more
1761  // complex than the input code.
1762  if (extractOp.hasDynamicPosition())
1763  return Value();
1764 
1765  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1766  if (!shuffleOp)
1767  return Value();
1768 
1769  // TODO: 0-D or multi-dimensional vectors not supported yet.
1770  if (shuffleOp.getResultVectorType().getRank() != 1)
1771  return Value();
1772 
1773  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1774  auto shuffleMask = shuffleOp.getMask();
1775  int64_t extractIdx = extractOp.getStaticPosition()[0];
1776  int64_t shuffleIdx = shuffleMask[extractIdx];
1777 
1778  // Find the shuffled vector to extract from based on the shuffle index.
1779  if (shuffleIdx < inputVecSize) {
1780  extractOp.setOperand(0, shuffleOp.getV1());
1781  extractOp.setStaticPosition({shuffleIdx});
1782  } else {
1783  extractOp.setOperand(0, shuffleOp.getV2());
1784  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1785  }
1786 
1787  return extractOp.getResult();
1788 }
1789 
1790 // Fold extractOp with source coming from ShapeCast op.
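// Illustrative sketch (editorial addition, hypothetical SSA names), showing the
// linearize/delinearize of the extract position performed below:
//   %sc = vector.shape_cast %src : vector<4x4xf32> to vector<2x2x4xf32>
//   %e  = vector.extract %sc[1, 0] : vector<4xf32> from vector<2x2x4xf32>
// folds in place to:
//   %e  = vector.extract %src[2] : vector<4xf32> from vector<4x4xf32>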
1791 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1792  // TODO: Canonicalization for dynamic position not implemented yet.
1793  if (extractOp.hasDynamicPosition())
1794  return Value();
1795 
1796  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1797  if (!shapeCastOp)
1798  return Value();
1799 
1800  // Get the nth dimension size starting from lowest dimension.
1801  auto getDimReverse = [](VectorType type, int64_t n) {
1802  return type.getShape().take_back(n + 1).front();
1803  };
1804  int64_t destinationRank =
1805  llvm::isa<VectorType>(extractOp.getType())
1806  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1807  : 0;
1808  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1809  return Value();
1810  if (destinationRank > 0) {
1811  auto destinationType =
1812  llvm::cast<VectorType>(extractOp.getResult().getType());
1813  for (int64_t i = 0; i < destinationRank; i++) {
1814  // The lowest dimension of the destination must match the lowest
1815  // dimension of the shapecast op source.
1816  // TODO: This case could be supported in a canonicalization pattern.
1817  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1818  getDimReverse(destinationType, i))
1819  return Value();
1820  }
1821  }
1822  // Extract the strides associated with the extract op vector source. Then use
1823  // this to calculate a linearized position for the extract.
1824  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1825  std::reverse(extractedPos.begin(), extractedPos.end());
1826  SmallVector<int64_t, 4> strides;
1827  int64_t stride = 1;
1828  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1829  strides.push_back(stride);
1830  stride *=
1831  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1832  }
1833 
1834  int64_t position = linearize(extractedPos, strides);
1835  // Then extract the strides associated with the shapeCast op vector source and
1836  // delinearize the position using those strides.
1837  SmallVector<int64_t, 4> newStrides;
1838  int64_t numDimension =
1839  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1840  stride = 1;
1841  for (int64_t i = 0; i < numDimension; i++) {
1842  newStrides.push_back(stride);
1843  stride *=
1844  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1845  }
1846  std::reverse(newStrides.begin(), newStrides.end());
1847  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1848  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1849  OpBuilder b(extractOp.getContext());
1850  extractOp.setStaticPosition(newPosition);
1851  extractOp.setOperand(0, shapeCastOp.getSource());
1852  return extractOp.getResult();
1853 }
1854 
1855 /// Fold an ExtractOp from ExtractStridedSliceOp.
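/// Illustrative sketch (editorial addition, hypothetical SSA names): the slice
/// offsets are simply added to the extract position.
///   %s = vector.extract_strided_slice %v
///          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
///          : vector<4x8xf32> to vector<2x4xf32>
///   %e = vector.extract %s[0, 1] : f32 from vector<2x4xf32>
/// folds in place to:
///   %e = vector.extract %v[1, 3] : f32 from vector<4x8xf32>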
1856 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1857  // TODO: Canonicalization for dynamic position not implemented yet.
1858  if (extractOp.hasDynamicPosition())
1859  return Value();
1860 
1861  auto extractStridedSliceOp =
1862  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1863  if (!extractStridedSliceOp)
1864  return Value();
1865 
1866  // 0-D vectors not supported.
1867  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1868  if (hasZeroDimVectors(extractStridedSliceOp))
1869  return Value();
1870 
1871  // Return if 'extractStridedSliceOp' has non-unit strides.
1872  if (extractStridedSliceOp.hasNonUnitStrides())
1873  return Value();
1874 
1875  // Trim offsets for dimensions fully extracted.
1876  auto sliceOffsets =
1877  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1878  while (!sliceOffsets.empty()) {
1879  size_t lastOffset = sliceOffsets.size() - 1;
1880  if (sliceOffsets.back() != 0 ||
1881  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1882  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1883  break;
1884  sliceOffsets.pop_back();
1885  }
1886  unsigned destinationRank = 0;
1887  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1888  destinationRank = vecType.getRank();
1889  // The dimensions of the result need to be untouched by the
1890  // extractStridedSlice op.
1891  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1892  sliceOffsets.size())
1893  return Value();
1894 
1895  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1896  assert(extractedPos.size() >= sliceOffsets.size());
1897  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1898  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1899  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1900 
1901  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1902  OpBuilder b(extractOp.getContext());
1903  extractOp.setStaticPosition(extractedPos);
1904  return extractOp.getResult();
1905 }
1906 
1907 /// Fold extract_op fed from a chain of insertStridedSlice ops.
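/// Illustrative sketch (editorial addition, hypothetical SSA names): when the
/// extracted chunk lies entirely inside the inserted slice, extract directly
/// from the inserted value.
///   %i = vector.insert_strided_slice %src, %dst
///          {offsets = [2, 0], strides = [1, 1]}
///          : vector<2x8xf32> into vector<4x8xf32>
///   %e = vector.extract %i[3] : vector<8xf32> from vector<4x8xf32>
/// folds in place to:
///   %e = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>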
1908 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1909  // TODO: Canonicalization for dynamic position not implemented yet.
1910  if (extractOp.hasDynamicPosition())
1911  return Value();
1912 
1913  int64_t destinationRank =
1914  llvm::isa<VectorType>(extractOp.getType())
1915  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1916  : 0;
1917  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1918  if (!insertOp)
1919  return Value();
1920 
1921  // 0-D vectors not supported.
1922  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1923  if (hasZeroDimVectors(insertOp))
1924  return Value();
1925 
1926  while (insertOp) {
1927  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1928  insertOp.getSourceVectorType().getRank();
1929  if (destinationRank > insertOp.getSourceVectorType().getRank())
1930  return Value();
1931  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1932  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1933 
1934  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1935  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1936  }))
1937  return Value();
1938  bool disjoint = false;
1939  SmallVector<int64_t, 4> offsetDiffs;
1940  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1941  int64_t start = insertOffsets[dim];
1942  int64_t size =
1943  (dim < insertRankDiff)
1944  ? 1
1945  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1946  int64_t end = start + size;
1947  int64_t offset = extractOffsets[dim];
1948  // Check if the start of the extract offset is in the interval inserted.
1949  if (start <= offset && offset < end) {
1950  if (dim >= insertRankDiff)
1951  offsetDiffs.push_back(offset - start);
1952  continue;
1953  }
1954  disjoint = true;
1955  break;
1956  }
1957  // The extracted element chunk overlaps with the inserted vector.
1958  if (!disjoint) {
1959  // If any of the inner dimensions are only partially inserted we have a
1960  // partial overlap.
1961  int64_t srcRankDiff =
1962  insertOp.getSourceVectorType().getRank() - destinationRank;
1963  for (int64_t i = 0; i < destinationRank; i++) {
1964  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1965  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1966  insertRankDiff))
1967  return Value();
1968  }
1969  extractOp.getVectorMutable().assign(insertOp.getValueToStore());
1970  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1971  OpBuilder b(extractOp.getContext());
1972  extractOp.setStaticPosition(offsetDiffs);
1973  return extractOp.getResult();
1974  }
1975  // If the chunk extracted is disjoint from the chunk inserted, keep
1976  // looking in the insert chain.
1977  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1978  }
1979  return Value();
1980 }
1981 
1982 /// Try to fold the extraction of a scalar from a vector defined by
1983 /// vector.from_elements. E.g.:
1984 ///
1985 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1986 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1987 /// ==> fold to %a
1988 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1989  // Dynamic extractions cannot be folded.
1990  if (extractOp.hasDynamicPosition())
1991  return {};
1992 
1993  // Look for extract(from_elements).
1994  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1995  if (!fromElementsOp)
1996  return {};
1997 
1998  // Scalable vectors are not supported.
1999  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2000  if (vecType.isScalable())
2001  return {};
2002 
2003  // Only extractions of scalars are supported.
2004  int64_t rank = vecType.getRank();
2005  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2006  if (extractOp.getType() != vecType.getElementType())
2007  return {};
2008  assert(static_cast<int64_t>(indices.size()) == rank &&
2009  "unexpected number of indices");
2010 
2011  // Compute flattened/linearized index and fold to operand.
2012  int flatIndex = 0;
2013  int stride = 1;
2014  for (int i = rank - 1; i >= 0; --i) {
2015  flatIndex += indices[i] * stride;
2016  stride *= vecType.getDimSize(i);
2017  }
2018  return fromElementsOp.getElements()[flatIndex];
2019 }
2020 
2021 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2022 /// then fold it.
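/// Illustrative sketch (editorial addition, hypothetical SSA names):
///   %c1 = arith.constant 1 : index
///   %e  = vector.extract %v[%c1] : f32 from vector<4xf32>
/// becomes, once the constant dynamic index is folded into the static position:
///   %e  = vector.extract %v[1] : f32 from vector<4xf32>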
2023 template <typename OpType, typename AdaptorType>
2024 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2025  SmallVectorImpl<Value> &operands) {
2026  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2027  OperandRange dynamicPosition = op.getDynamicPosition();
2028  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2029  ArrayRef<int64_t> vectorShape;
2030  if constexpr (std::is_same_v<OpType, ExtractOp>)
2031  vectorShape = op.getSourceVectorType().getShape();
2032  else
2033  vectorShape = op.getDestVectorType().getShape();
2034 
2035  // If there are no dynamic position operands, there is nothing to fold.
2036  if (!dynamicPosition.size())
2037  return {};
2038 
2039  // `index` is used to iterate over the `dynamicPosition`.
2040  unsigned index = 0;
2041 
2042  // `opChange` is a flag. If it is true, it means to update `op` in place.
2043  bool opChange = false;
2044  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2045  if (!ShapedType::isDynamic(staticPosition[i]))
2046  continue;
2047  Attribute positionAttr = dynamicPositionAttr[index];
2048  Value position = dynamicPosition[index++];
2049  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2050  int64_t value = attr.getInt();
2051  // Do not fold if the value is out of bounds (-1 signifies a poison
2052  // value rather than OOB index).
2053  if (value >= -1 && value < vectorShape[i]) {
2054  staticPosition[i] = attr.getInt();
2055  opChange = true;
2056  continue;
2057  }
2058  }
2059  operands.push_back(position);
2060  }
2061 
2062  if (opChange) {
2063  op.setStaticPosition(staticPosition);
2064  op.getOperation()->setOperands(operands);
2065  return op.getResult();
2066  }
2067  return {};
2068 }
2069 
2070 /// Fold an insert or extract operation into a poison value when a poison index
2071 /// is found at any dimension of the static position.
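/// Illustrative sketch (editorial addition, hypothetical SSA name %v):
///   %e = vector.extract %v[-1] : f32 from vector<4xf32>
/// folds to a ub.poison value, since -1 is the poison index sentinel.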
2072 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2073  ArrayRef<int64_t> staticPos,
2074  int64_t poisonVal) {
2075  if (!is_contained(staticPos, poisonVal))
2076  return {};
2077 
2078  return ub::PoisonAttr::get(context);
2079 }
2080 
2081 /// Fold a vector extract whose source is a poison value.
2082 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2083  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2084  return srcAttr;
2085 
2086  return {};
2087 }
2088 
2089 /// Fold a vector extract extracting from a DenseElementsAttr.
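/// Illustrative sketch (editorial addition, hypothetical SSA names):
///   %cst = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : vector<2x2xf32>
///   %e   = vector.extract %cst[1, 0] : f32 from vector<2x2xf32>
/// folds to the scalar constant 2.0 (linearized position 1 * 2 + 0 = 2).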
2090 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2091  Attribute srcAttr) {
2092  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2093  if (!denseAttr) {
2094  return {};
2095  }
2096 
2097  if (denseAttr.isSplat()) {
2098  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2099  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2100  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2101  return newAttr;
2102  }
2103 
2104  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2105  if (vecTy.isScalable())
2106  return {};
2107 
2108  if (extractOp.hasDynamicPosition()) {
2109  return {};
2110  }
2111 
2112  // Materializing subsets of a large constant array can generally lead to
2113  // an explosion in IR size because of the different combinations of subsets that
2114  // can exist. However, vector.extract is a restricted form of subset
2115  // extract where you can only extract non-overlapping (or the same) subset for
2116  // a given rank of the subset. Because of this property, the IR size can only
2117  // increase at most by `rank * size(array)` from a single constant array being
2118  // extracted by multiple extracts.
2119 
2120  // Calculate the linearized position of the continuous chunk of elements to
2121  // extract.
2122  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2123  copy(extractOp.getStaticPosition(), completePositions.begin());
2124  int64_t startPos =
2125  linearize(completePositions, computeStrides(vecTy.getShape()));
2126  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2127 
2128  TypedAttr newAttr;
2129  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2130  SmallVector<Attribute> elementValues(
2131  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2132  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2133  } else {
2134  newAttr = *denseValuesBegin;
2135  }
2136 
2137  return newAttr;
2138 }
2139 
2140 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2141  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2142  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2143  // mismatch).
2144  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2145  return getVector();
2146  if (auto res = foldPoisonIndexInsertExtractOp(
2147  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2148  return res;
2149  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2150  return res;
2151  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
2152  return res;
2153  if (succeeded(foldExtractOpFromExtractChain(*this)))
2154  return getResult();
2155  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2156  return res;
2157  if (auto res = foldExtractFromBroadcast(*this))
2158  return res;
2159  if (auto res = foldExtractFromShuffle(*this))
2160  return res;
2161  if (auto res = foldExtractFromShapeCast(*this))
2162  return res;
2163  if (auto val = foldExtractFromExtractStrided(*this))
2164  return val;
2165  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2166  return val;
2167  if (auto val = foldScalarExtractFromFromElements(*this))
2168  return val;
2169  SmallVector<Value> operands = {getVector()};
2170  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2171  return val;
2172  return OpFoldResult();
2173 }
2174 
2175 namespace {
2176 
2177 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
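// Illustrative sketch (editorial addition, hypothetical SSA names):
//   %b = vector.broadcast %s : f32 to vector<2x4xf32>
//   %e = vector.extract %b[0] : vector<4xf32> from vector<2x4xf32>
// rewrites to:
//   %e = vector.broadcast %s : f32 to vector<4xf32>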
2178 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2179 public:
2180  using OpRewritePattern::OpRewritePattern;
2181 
2182  LogicalResult matchAndRewrite(ExtractOp extractOp,
2183  PatternRewriter &rewriter) const override {
2184  Operation *defOp = extractOp.getVector().getDefiningOp();
2185  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2186  return failure();
2187 
2188  Value source = defOp->getOperand(0);
2189  if (extractOp.getType() == source.getType())
2190  return failure();
2191  auto getRank = [](Type type) {
2192  return llvm::isa<VectorType>(type)
2193  ? llvm::cast<VectorType>(type).getRank()
2194  : 0;
2195  };
2196  unsigned broadcastSrcRank = getRank(source.getType());
2197  unsigned extractResultRank = getRank(extractOp.getType());
2198  // We only consider the case where the rank of the source is less than or
2199  // equal to the rank of the extract dst. The other cases are handled in the
2200  // folding patterns.
2201  if (extractResultRank < broadcastSrcRank)
2202  return failure();
2203  // For scalar result, the input can only be a rank-0 vector, which will
2204  // be handled by the folder.
2205  if (extractResultRank == 0)
2206  return failure();
2207 
2208  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2209  extractOp, extractOp.getType(), source);
2210  return success();
2211  }
2212 };
2213 
2214 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
2215 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2216 public:
2217  using OpRewritePattern::OpRewritePattern;
2218 
2219  LogicalResult matchAndRewrite(ExtractOp extractOp,
2220  PatternRewriter &rewriter) const override {
2221  auto createMaskOp =
2222  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2223  if (!createMaskOp)
2224  return failure();
2225 
2226  VectorType extractedMaskType =
2227  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2228 
2229  if (!extractedMaskType)
2230  return failure();
2231 
2232  auto maskOperands = createMaskOp.getOperands();
2233  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2234  VectorType maskType = createMaskOp.getVectorType();
2235 
2236  bool containsUnknownDims = false;
2237  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2238 
2239  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2240  dimIdx++) {
2241  int64_t pos = extractOpPos[dimIdx];
2242  Value operand = maskOperands[dimIdx];
2243  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2244  if (!constantOp) {
2245  // Bounds of this dim unknown.
2246  containsUnknownDims = true;
2247  continue;
2248  }
2249 
2250  int64_t createMaskBound =
2251  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2252 
2253  if (pos != ShapedType::kDynamic) {
2254  // If any position is outside the range from the `create_mask`, then the
2255  // extracted mask will be all-false.
2256  allFalse |= pos >= createMaskBound;
2257  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2258  // This dim is not all-true and since this is a dynamic index we don't
2259  // know if the extraction is within the true or false region.
2260  // Note: Zero dims have already been handled via getMaskFormat().
2261  containsUnknownDims = true;
2262  }
2263  }
2264 
2265  if (allFalse) {
2266  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2267  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2268  } else if (!containsUnknownDims) {
2269  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2270  extractOp, extractedMaskType,
2271  maskOperands.drop_front(extractOpPos.size()));
2272  } else {
2273  return failure();
2274  }
2275  return success();
2276  }
2277 };
2278 
2279 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2280 // does not change.
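// Illustrative sketch (editorial addition, hypothetical SSA names):
//   %sc = vector.shape_cast %v : vector<8xf32> to vector<1x2x4xf32>
//   %e  = vector.extract %sc[0] : vector<2x4xf32> from vector<1x2x4xf32>
// rewrites to:
//   %e  = vector.shape_cast %v : vector<8xf32> to vector<2x4xf32>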
2281 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2282  PatternRewriter &rewriter) {
2283  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2284  if (!castOp)
2285  return failure();
2286 
2287  VectorType sourceType = castOp.getSourceVectorType();
2288  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2289  if (!targetType)
2290  return failure();
2291 
2292  if (sourceType.getNumElements() != targetType.getNumElements())
2293  return failure();
2294 
2295  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2296  castOp.getSource());
2297  return success();
2298 }
2299 
2300 /// Try to canonicalize the extraction of a subvector from a vector defined by
2301 /// vector.from_elements. E.g.:
2302 ///
2303 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2304 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2305 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2306 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2307  PatternRewriter &rewriter) {
2308  // Dynamic positions are not supported.
2309  if (extractOp.hasDynamicPosition())
2310  return failure();
2311 
2312  // Scalar extracts are handled by the folder.
2313  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2314  if (!resultType)
2315  return failure();
2316 
2317  // Look for extracts from a from_elements op.
2318  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2319  if (!fromElementsOp)
2320  return failure();
2321  VectorType inputType = fromElementsOp.getType();
2322 
2323  // Scalable vectors are not supported.
2324  if (resultType.isScalable() || inputType.isScalable())
2325  return failure();
2326 
2327  // Compute the position of first extracted element and flatten/linearize the
2328  // position.
2329  SmallVector<int64_t> firstElementPos =
2330  llvm::to_vector(extractOp.getStaticPosition());
2331  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2332  int flatIndex = 0;
2333  int stride = 1;
2334  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2335  flatIndex += firstElementPos[i] * stride;
2336  stride *= inputType.getDimSize(i);
2337  }
2338 
2339  // Replace the op with a smaller from_elements op.
2340  rewriter.replaceOpWithNewOp<FromElementsOp>(
2341  extractOp, resultType,
2342  fromElementsOp.getElements().slice(flatIndex,
2343  resultType.getNumElements()));
2344  return success();
2345 }
2346 
2347 } // namespace
2348 
2349 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2350  MLIRContext *context) {
2351  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2352  results.add(foldExtractFromShapeCastToShapeCast);
2353  results.add(foldExtractFromFromElements);
2354 }
2355 
2356 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2357  SmallVectorImpl<int64_t> &results) {
2358  for (auto attr : arrayAttr)
2359  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2360 }
2361 
2362 //===----------------------------------------------------------------------===//
2363 // FmaOp
2364 //===----------------------------------------------------------------------===//
2365 
2366 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2367  return llvm::to_vector<4>(getVectorType().getShape());
2368 }
2369 
2370 //===----------------------------------------------------------------------===//
2371 // FromElementsOp
2372 //===----------------------------------------------------------------------===//
2373 
2374 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2375 /// same SSA value. E.g.:
2376 ///
2377 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2378 /// ==> rewrite to vector.splat %a : vector<3xf32>
2379 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2380  PatternRewriter &rewriter) {
2381  if (!llvm::all_equal(fromElementsOp.getElements()))
2382  return failure();
2383  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2384  fromElementsOp.getElements().front());
2385  return success();
2386 }
2387 
2388 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2389  MLIRContext *context) {
2390  results.add(rewriteFromElementsAsSplat);
2391 }
2392 
2393 //===----------------------------------------------------------------------===//
2394 // BroadcastOp
2395 //===----------------------------------------------------------------------===//
2396 
2397 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2398  SetIntRangeFn setResultRanges) {
2399  setResultRanges(getResult(), argRanges.front());
2400 }
2401 
2402 /// Return the dimensions of the result vector that were formerly ones in the
2403 /// source tensor and thus correspond to "dim-1" broadcasting.
2406  ArrayRef<int64_t> dstShape) {
2407  int64_t rankDiff = dstShape.size() - srcShape.size();
2408  int64_t dstDim = rankDiff;
2409  llvm::SetVector<int64_t> res;
2410  for (auto [s1, s2] :
2411  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2412  if (s1 != s2) {
2413  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2414  res.insert(dstDim);
2415  }
2416  ++dstDim;
2417  }
2418  return res;
2419 }
2420 
2421 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2422  // A scalar broadcast involves no unit-dim broadcasting.
2423  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2424  if (!srcVectorType)
2425  return {};
2426  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2427  getResultVectorType().getShape());
2428 }
2429 
2430 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2431 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2432 /// This requires (and asserts) that the broadcast is free of "dim-1"
2433 /// broadcasting.
2434 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2435 /// vector.transpose may be inserted to make the broadcast possible.
2436 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2437 /// the helper will assert. This means:
2438 /// 1. `dstShape` must not be empty.
2439 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2440 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2441 ///    must match the `value` shape.
2442 Value BroadcastOp::createOrFoldBroadcastOp(
2443  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2444  const llvm::SetVector<int64_t> &broadcastedDims) {
2445  assert(!dstShape.empty() && "unexpected empty dst shape");
2446 
2447  // Step 1. Well-formedness check.
2448  SmallVector<int64_t> checkShape;
2449  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2450  if (broadcastedDims.contains(i))
2451  continue;
2452  checkShape.push_back(dstShape[i]);
2453  }
2454  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2455  "ill-formed broadcastedDims contains values not confined to "
2456  "destVectorShape");
2457 
2458  Location loc = value.getLoc();
2459  Type elementType = getElementTypeOrSelf(value.getType());
2460  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2461  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2462 
2463  // Step 2. If scalar -> dstShape broadcast, just do it.
2464  if (!srcVectorType) {
2465  assert(checkShape.empty() &&
2466  "ill-formed createOrFoldBroadcastOp arguments");
2467  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2468  }
2469 
2470  assert(srcVectorType.getShape().equals(checkShape) &&
2471  "ill-formed createOrFoldBroadcastOp arguments");
2472 
2473  // Step 3. Since vector.broadcast only allows creating leading dims,
2474  // vector -> dstShape broadcast may require a transpose.
2475  // Traverse the dims in order and construct:
2476  // 1. The leading entries of the broadcastShape that is guaranteed to be
2477  // achievable by a simple broadcast.
2478  // 2. The induced permutation for the subsequent vector.transpose that will
2479  // bring us from `broadcastShape` back to the desired `dstShape`.
2480  // If the induced permutation is not the identity, create a vector.transpose.
2481  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2482  broadcastShape.reserve(dstShape.size());
2483  // Consider the example:
2484  // srcShape = 2x4
2485  // dstShape = 1x2x3x4x5
2486  // broadcastedDims = [0, 2, 4]
2487  //
2488  // We want to build:
2489  // broadcastShape = 1x3x5x2x4
2490  // permutation = [0, 2, 4, 1, 3]
2491  // ---V--- -----V-----
2492  // leading broadcast part src shape part
2493  //
2494  // Note that the trailing dims of broadcastShape are exactly the srcShape
2495  // by construction.
2496  // nextSrcShapeDim is used to keep track of where in the permutation the
2497  // "src shape part" occurs.
2498  int64_t nextSrcShapeDim = broadcastedDims.size();
2499  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2500  if (broadcastedDims.contains(i)) {
2501  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2502  // bring it to the head of the broadcastShape.
2503  // It will need to be permuted back from `broadcastShape.size() - 1` into
2504  // position `i`.
2505  broadcastShape.push_back(dstShape[i]);
2506  permutation[i] = broadcastShape.size() - 1;
2507  } else {
2508  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2509  // shape and needs to be permuted into position `i`.
2510  // Don't touch `broadcastShape` here, the whole srcShape will be
2511  // appended after.
2512  permutation[i] = nextSrcShapeDim++;
2513  }
2514  }
2515  // 3.c. Append the srcShape.
2516  llvm::append_range(broadcastShape, srcVectorType.getShape());
2517 
2518  // Ensure there are no "dim-1" broadcasts.
2519  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2520  .empty() &&
2521  "unexpected \"dim-1\" broadcast");
2522 
2523  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2524  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2525  vector::BroadcastableToResult::Success &&
2526  "must be broadcastable");
2527  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2528  // Step 4. If we find any dimension that indeed needs to be permuted,
2529  // immediately return a new vector.transpose.
2530  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2531  if (permutation[i] != i)
2532  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2533  // Otherwise return res.
2534  return res;
2535 }
2536 
2537 BroadcastableToResult mlir::vector::isBroadcastableTo(
2538  Type srcType, VectorType dstVectorType,
2539  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2540  // Broadcast scalar to vector of the same element type.
2541  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2542  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2543  return BroadcastableToResult::Success;
2544  // From now on, only vectors broadcast.
2545  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2546  if (!srcVectorType)
2547  return BroadcastableToResult::SourceTypeNotAVector;
2548 
2549  int64_t srcRank = srcVectorType.getRank();
2550  int64_t dstRank = dstVectorType.getRank();
2551  if (srcRank > dstRank)
2552  return BroadcastableToResult::SourceRankHigher;
2553  // Source has an exact match or singleton value for all trailing dimensions
2554  // (all leading dimensions are simply duplicated).
2555  int64_t lead = dstRank - srcRank;
2556  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2557  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2558  // encountered?
2559  bool foundMismatchingDims = false;
2560 
2561  // Check fixed-width dims.
2562  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2563  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2564  if (srcDim != 1 && srcDim != dstDim)
2565  foundMismatchingDims = true;
2566 
2567  // Check scalable flags.
2568  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2569  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2570  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2571  // 1 -> [N] is fine, everything else should be rejected when mixing
2572  // fixed-width and scalable dims
2573  (srcDimScalableFlag != dstDimScalableFlag &&
2574  (srcDim != 1 || srcDimScalableFlag)))
2575  foundMismatchingDims = true;
2576 
2577  if (foundMismatchingDims) {
2578  if (mismatchingDims != nullptr) {
2579  mismatchingDims->first.dim = srcDim;
2580  mismatchingDims->first.isScalable = srcDimScalableFlag;
2581 
2582  mismatchingDims->second.dim = dstDim;
2583  mismatchingDims->second.isScalable = dstDimScalableFlag;
2584  }
2585  return BroadcastableToResult::DimensionMismatch;
2586  }
2587  }
2588 
2589  return BroadcastableToResult::Success;
2590 }
2591 
2592 LogicalResult BroadcastOp::verify() {
2593  std::pair<VectorDim, VectorDim> mismatchingDims;
2594  BroadcastableToResult res = isBroadcastableTo(
2595  getSourceType(), getResultVectorType(), &mismatchingDims);
2596  if (res == BroadcastableToResult::Success)
2597  return success();
2598  if (res == BroadcastableToResult::SourceRankHigher)
2599  return emitOpError("source rank higher than destination rank");
2600  if (res == BroadcastableToResult::DimensionMismatch) {
2601  return emitOpError("dimension mismatch (")
2602  << (mismatchingDims.first.isScalable ? "[" : "")
2603  << mismatchingDims.first.dim
2604  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2605  << (mismatchingDims.second.isScalable ? "[" : "")
2606  << mismatchingDims.second.dim
2607  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2608  }
2609  if (res == BroadcastableToResult::SourceTypeNotAVector)
2610  return emitOpError("source type is not a vector");
2611  llvm_unreachable("unexpected vector.broadcast op error");
2612 }
2613 
2614 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2615  if (getSourceType() == getResultVectorType())
2616  return getSource();
2617  if (!adaptor.getSource())
2618  return {};
2619  auto vectorType = getResultVectorType();
2620  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2621  if (vectorType.getElementType() != attr.getType())
2622  return {};
2623  return DenseElementsAttr::get(vectorType, attr);
2624  }
2625  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2626  if (vectorType.getElementType() != attr.getType())
2627  return {};
2628  return DenseElementsAttr::get(vectorType, attr);
2629  }
2630  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2631  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2632  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
2633  return ub::PoisonAttr::get(getContext());
2634  return {};
2635 }
2636 
2637 namespace {
2638 
2639 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2640 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2641  using OpRewritePattern::OpRewritePattern;
2642 
2643  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2644  PatternRewriter &rewriter) const override {
2645  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2646  if (!srcBroadcast)
2647  return failure();
2648  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2649  broadcastOp.getResultVectorType(),
2650  srcBroadcast.getSource());
2651  return success();
2652  }
2653 };
2654 } // namespace
2655 
2656 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2657  MLIRContext *context) {
2658  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2659  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2660  results.add<BroadcastFolder>(context);
2661 }
2662 
2663 //===----------------------------------------------------------------------===//
2664 // ShuffleOp
2665 //===----------------------------------------------------------------------===//
2666 
2667 LogicalResult ShuffleOp::verify() {
2668  VectorType resultType = getResultVectorType();
2669  VectorType v1Type = getV1VectorType();
2670  VectorType v2Type = getV2VectorType();
2671  // Verify ranks.
2672  int64_t resRank = resultType.getRank();
2673  int64_t v1Rank = v1Type.getRank();
2674  int64_t v2Rank = v2Type.getRank();
2675  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2676  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2677  if (!wellFormed0DCase && !wellFormedNDCase)
2678  return emitOpError("rank mismatch");
2679 
2680  // Verify all but leading dimension sizes.
2681  for (int64_t r = 1; r < v1Rank; ++r) {
2682  int64_t resDim = resultType.getDimSize(r);
2683  int64_t v1Dim = v1Type.getDimSize(r);
2684  int64_t v2Dim = v2Type.getDimSize(r);
2685  if (resDim != v1Dim || v1Dim != v2Dim)
2686  return emitOpError("dimension mismatch");
2687  }
2688  // Verify mask length.
2689  ArrayRef<int64_t> mask = getMask();
2690  int64_t maskLength = mask.size();
2691  if (maskLength <= 0)
2692  return emitOpError("invalid mask length");
2693  if (maskLength != resultType.getDimSize(0))
2694  return emitOpError("mask length mismatch");
2695  // Verify all indices.
2696  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2697  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2698  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2699  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2700  return emitOpError("mask index #") << (idx + 1) << " out of range";
2701  }
2702  return success();
2703 }
2704 
2705 LogicalResult
2706 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2707  ShuffleOp::Adaptor adaptor,
2708  SmallVectorImpl<Type> &inferredReturnTypes) {
2709  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2710  auto v1Rank = v1Type.getRank();
2711  // Construct resulting type: leading dimension matches mask
2712  // length, all trailing dimensions match the operands.
2713  SmallVector<int64_t> shape;
2714  shape.reserve(v1Rank);
2715  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2716  // In the 0-D case there is no trailing shape to append.
2717  if (v1Rank > 0)
2718  llvm::append_range(shape, v1Type.getShape().drop_front());
2719  inferredReturnTypes.push_back(
2720  VectorType::get(shape, v1Type.getElementType()));
2721  return success();
2722 }
2723 
2724 template <typename T>
2725 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2726  T expected = begin;
2727  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2728  return value == expected++;
2729  });
2730 }
2731 
2732 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2733  auto v1Type = getV1VectorType();
2734  auto v2Type = getV2VectorType();
2735 
2736  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2737  "Vector shuffle does not support scalable vectors");
2738 
2739  // For consistency: the 0-D shuffle return type is 1-D, so this cannot be
2740  // handled as a fold and must instead be a canonicalization to vector.broadcast.
2741  if (v1Type.getRank() == 0)
2742  return {};
2743 
2744  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2745  auto mask = getMask();
2746  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
2747  return getV1();
2748  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2749  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
2750  return getV2();
2751 
2752  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2753  if (!v1Attr || !v2Attr)
2754  return {};
2755 
2756  // Fold shuffle poison, poison -> poison.
2757  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2758  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2759  if (isV1Poison && isV2Poison)
2760  return ub::PoisonAttr::get(getContext());
2761 
2762  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2763  // manipulation.
2764  if (v1Type.getRank() != 1)
2765  return {};
2766 
2767  // Poison input attributes need special handling as they are not
2768  // DenseElementsAttr. If an index is poison, we select the first element of
2769  // the first non-poison input.
2770  SmallVector<Attribute> v1Elements, v2Elements;
2771  Attribute poisonElement;
2772  if (!isV2Poison) {
2773  v2Elements =
2774  to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2775  poisonElement = v2Elements[0];
2776  }
2777  if (!isV1Poison) {
2778  v1Elements =
2779  to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2780  poisonElement = v1Elements[0];
2781  }
2782 
2783  SmallVector<Attribute> results;
2784  int64_t v1Size = v1Type.getDimSize(0);
2785  for (int64_t maskIdx : mask) {
2786  Attribute indexedElm;
2787  // TODO: Return a partial poison vector when supported by the UB dialect.
2788  if (maskIdx == ShuffleOp::kPoisonIndex) {
2789  indexedElm = poisonElement;
2790  } else {
2791  if (maskIdx < v1Size)
2792  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2793  else
2794  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2795  }
2796 
2797  results.push_back(indexedElm);
2798  }
2799 
2800  return DenseElementsAttr::get(getResultVectorType(), results);
2801 }
2802 
2803 namespace {
2804 
2805 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2806 // to a broadcast.
2807 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2808  using OpRewritePattern::OpRewritePattern;
2809 
2810  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2811  PatternRewriter &rewriter) const override {
2812  VectorType v1VectorType = shuffleOp.getV1VectorType();
2813  ArrayRef<int64_t> mask = shuffleOp.getMask();
2814  if (v1VectorType.getRank() > 0)
2815  return failure();
2816  if (mask.size() != 1)
2817  return failure();
2818  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2819  if (mask[0] == 0)
2820  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2821  shuffleOp.getV1());
2822  else
2823  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2824  shuffleOp.getV2());
2825  return success();
2826  }
2827 };
2828 
2829 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2830 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2831 public:
2832  using OpRewritePattern::OpRewritePattern;
2833 
2834  LogicalResult matchAndRewrite(ShuffleOp op,
2835  PatternRewriter &rewriter) const override {
2836  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2837  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2838 
2839  if (!v1Splat || !v2Splat)
2840  return failure();
2841 
2842  if (v1Splat.getInput() != v2Splat.getInput())
2843  return failure();
2844 
2845  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2846  return success();
2847  }
2848 };
2849 
2850 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2851 /// vector.interleave.
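/// Illustrative sketch (editorial addition, hypothetical SSA names; the exact
/// vector.interleave assembly may differ slightly):
///   %s = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///          : vector<4xf32>, vector<4xf32>
/// rewrites to:
///   %s = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>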
2852 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2853 public:
2854  using OpRewritePattern::OpRewritePattern;
2855 
2856  LogicalResult matchAndRewrite(ShuffleOp op,
2857  PatternRewriter &rewriter) const override {
2858  VectorType resultType = op.getResultVectorType();
2859  if (resultType.isScalable())
2860  return rewriter.notifyMatchFailure(
2861  op, "ShuffleOp can't represent a scalable interleave");
2862 
2863  if (resultType.getRank() != 1)
2864  return rewriter.notifyMatchFailure(
2865  op, "ShuffleOp can't represent an n-D interleave");
2866 
2867  VectorType sourceType = op.getV1VectorType();
2868  if (sourceType != op.getV2VectorType() ||
2869  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2870  return rewriter.notifyMatchFailure(
2871  op, "ShuffleOp types don't match an interleave");
2872  }
2873 
2874  ArrayRef<int64_t> shuffleMask = op.getMask();
2875  int64_t resultVectorSize = resultType.getNumElements();
2876  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2877  int64_t maskValueA = shuffleMask[i * 2];
2878  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2879  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2880  return rewriter.notifyMatchFailure(op,
2881  "ShuffleOp mask not interleaving");
2882  }
2883 
2884  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2885  return success();
2886  }
2887 };
2888 
2889 } // namespace
2890 
2891 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2892  MLIRContext *context) {
2893  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2894  context);
2895 }
2896 
2897 //===----------------------------------------------------------------------===//
2898 // InsertElementOp
2899 //===----------------------------------------------------------------------===//
2900 
2901 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2902  SetIntRangeFn setResultRanges) {
2903  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2904 }
2905 
2906 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2907  Value source, Value dest) {
2908  build(builder, result, source, dest, {});
2909 }
2910 
2911 LogicalResult InsertElementOp::verify() {
2912  auto dstVectorType = getDestVectorType();
2913  if (dstVectorType.getRank() == 0) {
2914  if (getPosition())
2915  return emitOpError("expected position to be empty with 0-D vector");
2916  return success();
2917  }
2918  if (dstVectorType.getRank() != 1)
2919  return emitOpError("unexpected >1 vector rank");
2920  if (!getPosition())
2921  return emitOpError("expected position for 1-D vector");
2922  return success();
2923 }
2924 
2925 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2926  // Skip the 0-D vector here.
2927  if (!adaptor.getPosition())
2928  return {};
2929 
2930  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2931  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2932  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2933  if (!src || !dst || !pos)
2934  return {};
2935 
2936  if (src.getType() != getDestVectorType().getElementType())
2937  return {};
2938 
2939  auto dstElements = dst.getValues<Attribute>();
2940 
2941  SmallVector<Attribute> results(dstElements);
2942 
2943  uint64_t posIdx = pos.getInt();
2944  if (posIdx >= results.size())
2945  return {};
2946  results[posIdx] = src;
2947 
2948  return DenseElementsAttr::get(getDestVectorType(), results);
2949 }
2950 
2951 //===----------------------------------------------------------------------===//
2952 // InsertOp
2953 //===----------------------------------------------------------------------===//
2954 
2955 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2956  SetIntRangeFn setResultRanges) {
2957  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2958 }
2959 
2960 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2961  Value source, Value dest) {
2962  auto vectorTy = cast<VectorType>(dest.getType());
2963  build(builder, result, source, dest,
2964  SmallVector<int64_t>(vectorTy.getRank(), 0));
2965 }
2966 
2967 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2968  Value source, Value dest, int64_t position) {
2969  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2970 }
2971 
2972 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2973  Value source, Value dest, OpFoldResult position) {
2974  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2975 }
2976 
2977 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2978  Value source, Value dest,
2979  ArrayRef<int64_t> position) {
2980  SmallVector<OpFoldResult> posVals;
2981  posVals.reserve(position.size());
2982  llvm::transform(position, std::back_inserter(posVals),
2983  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2984  build(builder, result, source, dest, posVals);
2985 }
2986 
2987 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2988  Value source, Value dest,
2989  ArrayRef<OpFoldResult> position) {
2990  SmallVector<int64_t> staticPos;
2991  SmallVector<Value> dynamicPos;
2992  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2993  build(builder, result, source, dest, dynamicPos,
2994  builder.getDenseI64ArrayAttr(staticPos));
2995 }
2996 
2997 LogicalResult InsertOp::verify() {
2998  SmallVector<OpFoldResult> position = getMixedPosition();
2999  auto destVectorType = getDestVectorType();
3000  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3001  return emitOpError(
3002  "expected position attribute of rank no greater than dest vector rank");
3003  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3004  if (srcVectorType &&
3005  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3006  static_cast<unsigned>(destVectorType.getRank())))
3007  return emitOpError("expected position attribute rank + source rank to "
3008  "match dest vector rank");
3009  if (!srcVectorType &&
3010  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3011  return emitOpError(
3012  "expected position attribute rank to match the dest vector rank");
3013  for (auto [idx, pos] : llvm::enumerate(position)) {
3014  if (auto attr = dyn_cast<Attribute>(pos)) {
3015  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3016  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3017  destVectorType.getDimSize(idx))) {
3018  return emitOpError("expected position attribute #")
3019  << (idx + 1)
3020  << " to be a non-negative integer smaller than the "
3021  "corresponding "
3022  "dest vector dimension";
3023  }
3024  }
3025  }
3026  return success();
3027 }
3028 
3029 namespace {
3030 
3031 // If insertOp is only inserting unit dimensions it can be transformed to a
3032 // broadcast.
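// Illustrative example (editor's sketch; shapes are made up):
//   %r = vector.insert %src, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// becomes
//   %r = vector.broadcast %src : vector<4xf32> to vector<1x1x4xf32>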
3033 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3034 public:
3036 
3037  LogicalResult matchAndRewrite(InsertOp insertOp,
3038  PatternRewriter &rewriter) const override {
3039  auto srcVecType =
3040  llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3041  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3042  srcVecType.getNumElements())
3043  return failure();
3044  rewriter.replaceOpWithNewOp<BroadcastOp>(
3045  insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3046  return success();
3047  }
3048 };
3049 
3050 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
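/// Illustrative example (editor's sketch; shapes are made up):
///   %v = vector.splat %x : vector<4xf32>
///   %w = vector.splat %x : vector<2x4xf32>
///   %r = vector.insert %v, %w [0] : vector<4xf32> into vector<2x4xf32>
/// folds to
///   %r = vector.splat %x : vector<2x4xf32>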
3051 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3052 public:
3054 
3055  LogicalResult matchAndRewrite(InsertOp op,
3056  PatternRewriter &rewriter) const override {
3057  auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
3058  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3059 
3060  if (!srcSplat || !dstSplat)
3061  return failure();
3062 
3063  if (srcSplat.getInput() != dstSplat.getInput())
3064  return failure();
3065 
3066  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3067  return success();
3068  }
3069 };
3070 
3071 } // namespace
3072 
3073 static Attribute
3074 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3075  Attribute dstAttr,
3076  int64_t maxVectorSizeFoldThreshold) {
3077  if (insertOp.hasDynamicPosition())
3078  return {};
3079 
3080  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3081  if (!denseDst)
3082  return {};
3083 
3084  if (!srcAttr) {
3085  return {};
3086  }
3087 
3088  VectorType destTy = insertOp.getDestVectorType();
3089  if (destTy.isScalable())
3090  return {};
3091 
3092  // Make sure we do not create too many large constants.
3093  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3094  !insertOp->hasOneUse())
3095  return {};
3096 
3097  // Calculate the linearized position of the contiguous chunk of elements to
3098  // insert.
3099  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3100  copy(insertOp.getStaticPosition(), completePositions.begin());
3101  int64_t insertBeginPosition =
3102  linearize(completePositions, computeStrides(destTy.getShape()));
3103 
3104  SmallVector<Attribute> insertedValues;
3105  Type destEltType = destTy.getElementType();
3106 
3107  /// Converts `attr` to an IntegerAttr of the expected type if there is
3108  /// a type mismatch.
3109  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
3110  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3111  if (intAttr.getType() != expectedType)
3112  return IntegerAttr::get(expectedType, intAttr.getInt());
3113  }
3114  return attr;
3115  };
3116 
3117  // The `convertIntegerAttr` method specifically handles the case
3118  // for `llvm.mlir.constant` which can hold an attribute with a
3119  // different type than the return type.
3120  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3121  for (auto value : denseSource.getValues<Attribute>())
3122  insertedValues.push_back(convertIntegerAttr(value, destEltType));
3123  } else {
3124  insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
3125  }
3126 
3127  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3128  copy(insertedValues, allValues.begin() + insertBeginPosition);
3129  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3130 
3131  return newAttr;
3132 }
3133 
3134 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3135  MLIRContext *context) {
3136  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3137 }
3138 
3139 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3140  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3141  // unless the source vector constant has a single use.
3142  constexpr int64_t vectorSizeFoldThreshold = 256;
3143  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3144  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3145  // (type mismatch).
3146  if (getNumIndices() == 0 && getValueToStoreType() == getType())
3147  return getValueToStore();
3148  SmallVector<Value> operands = {getValueToStore(), getDest()};
3149  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3150  return val;
3151  if (auto res = foldPoisonIndexInsertExtractOp(
3152  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3153  return res;
3154  if (auto res = foldDenseElementsAttrDestInsertOp(
3155  *this, adaptor.getValueToStore(), adaptor.getDest(),
3156  vectorSizeFoldThreshold)) {
3157  return res;
3158  }
3159 
3160  return {};
3161 }
3162 
3163 //===----------------------------------------------------------------------===//
3164 // InsertStridedSliceOp
3165 //===----------------------------------------------------------------------===//
3166 
3167 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3168  Value source, Value dest,
3169  ArrayRef<int64_t> offsets,
3170  ArrayRef<int64_t> strides) {
3171  result.addOperands({source, dest});
3172  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3173  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3174  result.addTypes(dest.getType());
3175  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3176  offsetsAttr);
3177  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3178  stridesAttr);
3179 }
3180 
3181 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3182 template <typename OpType>
3183 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3184  ArrayAttr arrayAttr,
3185  ArrayRef<int64_t> shape,
3186  StringRef attrName) {
3187  if (arrayAttr.size() > shape.size())
3188  return op.emitOpError("expected ")
3189  << attrName << " attribute of rank no greater than vector rank";
3190  return success();
3191 }
3192 
3193 // Returns success if all integers in `arrayAttr` are in the admissible
3194 // interval. If `halfOpen` is true then the admissible interval is [min, max).
3195 // Otherwise, the admissible interval is [min, max].
3196 template <typename OpType>
3197 static LogicalResult
3198 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3199  int64_t max, StringRef attrName,
3200  bool halfOpen = true) {
3201  for (auto attr : arrayAttr) {
3202  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3203  auto upper = max;
3204  if (!halfOpen)
3205  upper += 1;
3206  if (val < min || val >= upper)
3207  return op.emitOpError("expected ") << attrName << " to be confined to ["
3208  << min << ", " << upper << ")";
3209  }
3210  return success();
3211 }
3212 
3213 // Returns success if, for each dimension i, the integer `arrayAttr[i]` is in
3214 // the admissible interval. If `halfOpen` is true then the admissible interval
3215 // is [min, shape[i]). Otherwise, the admissible interval is [min, shape[i]].
3216 template <typename OpType>
3217 static LogicalResult
3218 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3219  ArrayRef<int64_t> shape, StringRef attrName,
3220  bool halfOpen = true, int64_t min = 0) {
3221  for (auto [index, attrDimPair] :
3222  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3223  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3224  int64_t max = std::get<1>(attrDimPair);
3225  if (!halfOpen)
3226  max += 1;
3227  if (val < min || val >= max)
3228  return op.emitOpError("expected ")
3229  << attrName << " dimension " << index << " to be confined to ["
3230  << min << ", " << max << ")";
3231  }
3232  return success();
3233 }
3234 
3235 // Returns success if, for all indices i = 0..shape.size()-1, the value
3236 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3237 // is in the admissible interval.
3238 // If `halfOpen` is true then the admissible interval is [min, shape[i]).
3239 // Otherwise, the admissible interval is [min, shape[i]].
3240 template <typename OpType>
3241 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3242  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3243  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3244  bool halfOpen = true, int64_t min = 1) {
3245  assert(arrayAttr1.size() <= shape.size());
3246  assert(arrayAttr2.size() <= shape.size());
3247  for (auto [index, it] :
3248  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3249  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3250  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3251  int64_t max = std::get<2>(it);
3252  if (!halfOpen)
3253  max += 1;
3254  if (val1 + val2 < 0 || val1 + val2 >= max)
3255  return op.emitOpError("expected sum(")
3256  << attrName1 << ", " << attrName2 << ") dimension " << index
3257  << " to be confined to [" << min << ", " << max << ")";
3258  }
3259  return success();
3260 }
3261 
3262 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3263  MLIRContext *context) {
3264  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3265  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3266  });
3267  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3268 }
3269 
3270 LogicalResult InsertStridedSliceOp::verify() {
3271  auto sourceVectorType = getSourceVectorType();
3272  auto destVectorType = getDestVectorType();
3273  auto offsets = getOffsetsAttr();
3274  auto strides = getStridesAttr();
3275  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3276  return emitOpError(
3277  "expected offsets of same size as destination vector rank");
3278  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3279  return emitOpError("expected strides of same size as source vector rank");
3280  if (sourceVectorType.getRank() > destVectorType.getRank())
3281  return emitOpError(
3282  "expected source rank to be no greater than destination rank");
3283 
3284  auto sourceShape = sourceVectorType.getShape();
3285  auto destShape = destVectorType.getShape();
3286  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3287  destShape.size() - sourceShape.size(), 0);
3288  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3289  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3290  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3291  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3292  offName)) ||
3293  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3294  /*max=*/1, stridesName,
3295  /*halfOpen=*/false)) ||
3296  failed(isSumOfIntegerArrayAttrConfinedToShape(
3297  *this, offsets,
3298  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3299  offName, "source vector shape",
3300  /*halfOpen=*/false, /*min=*/1)))
3301  return failure();
3302 
3303  unsigned rankDiff = destShape.size() - sourceShape.size();
3304  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3305  if (sourceVectorType.getScalableDims()[idx] !=
3306  destVectorType.getScalableDims()[idx + rankDiff]) {
3307  return emitOpError("mismatching scalable flags (at source vector idx=")
3308  << idx << ")";
3309  }
3310  if (sourceVectorType.getScalableDims()[idx]) {
3311  auto sourceSize = sourceShape[idx];
3312  auto destSize = destShape[idx + rankDiff];
3313  if (sourceSize != destSize) {
3314  return emitOpError("expected size at idx=")
3315  << idx
3316  << (" to match the corresponding base size from the input "
3317  "vector (")
3318  << sourceSize << (" vs ") << destSize << (")");
3319  }
3320  }
3321  }
3322 
3323  return success();
3324 }
3325 
3326 namespace {
3327 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3328 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
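/// Illustrative example (editor's sketch; shapes are made up):
///   %a = vector.splat %x : vector<2xf32>
///   %b = vector.splat %x : vector<4xf32>
///   %r = vector.insert_strided_slice %a, %b {offsets = [1], strides = [1]}
///       : vector<2xf32> into vector<4xf32>
/// folds %r to %b.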
3329 class FoldInsertStridedSliceSplat final
3330  : public OpRewritePattern<InsertStridedSliceOp> {
3331 public:
3333 
3334  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3335  PatternRewriter &rewriter) const override {
3336  auto srcSplatOp =
3337  insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
3338  auto destSplatOp =
3339  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3340 
3341  if (!srcSplatOp || !destSplatOp)
3342  return failure();
3343 
3344  if (srcSplatOp.getInput() != destSplatOp.getInput())
3345  return failure();
3346 
3347  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3348  return success();
3349  }
3350 };
3351 
3352 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3353 /// to dst.
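/// Illustrative example (editor's sketch; shapes are made up):
///   %e = vector.extract_strided_slice %dst
///       {offsets = [2], sizes = [4], strides = [1]}
///       : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %e, %dst {offsets = [2], strides = [1]}
///       : vector<4xf32> into vector<8xf32>
/// folds %r to %dst.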
3354 class FoldInsertStridedSliceOfExtract final
3355  : public OpRewritePattern<InsertStridedSliceOp> {
3356 public:
3358 
3359  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3360  PatternRewriter &rewriter) const override {
3361  auto extractStridedSliceOp =
3362  insertStridedSliceOp.getValueToStore()
3363  .getDefiningOp<vector::ExtractStridedSliceOp>();
3364 
3365  if (!extractStridedSliceOp)
3366  return failure();
3367 
3368  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3369  return failure();
3370 
3371  // Check if they have the same strides and offsets.
3372  if (extractStridedSliceOp.getStrides() !=
3373  insertStridedSliceOp.getStrides() ||
3374  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3375  return failure();
3376 
3377  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3378  return success();
3379  }
3380 };
3381 
3382 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3383 // ConstantOp.
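// Illustrative example (editor's sketch; values and shapes are made up):
//   %c0 = arith.constant dense<0.0> : vector<4xf32>
//   %c1 = arith.constant dense<1.0> : vector<2xf32>
//   %r = vector.insert_strided_slice %c1, %c0 {offsets = [1], strides = [1]}
//       : vector<2xf32> into vector<4xf32>
// becomes
//   %r = arith.constant dense<[0.0, 1.0, 1.0, 0.0]> : vector<4xf32>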
3384 class InsertStridedSliceConstantFolder final
3385  : public OpRewritePattern<InsertStridedSliceOp> {
3386 public:
3388 
3389  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3390  // unless the source vector constant has a single use.
3391  static constexpr int64_t vectorSizeFoldThreshold = 256;
3392 
3393  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3394  PatternRewriter &rewriter) const override {
3395  // Return if the 'InsertStridedSliceOp' destination operand is not defined
3396  // by a compatible vector ConstantOp.
3397  TypedValue<VectorType> destVector = op.getDest();
3398  Attribute vectorDestCst;
3399  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3400  return failure();
3401 
3402  VectorType destTy = destVector.getType();
3403  if (destTy.isScalable())
3404  return failure();
3405 
3406  // Make sure we do not create too many large constants.
3407  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3408  !destVector.hasOneUse())
3409  return failure();
3410 
3411  TypedValue<VectorType> sourceValue = op.getValueToStore();
3412  Attribute sourceCst;
3413  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3414  return failure();
3415 
3416  // TODO: Support poison.
3417  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3418  return failure();
3419 
3420  // TODO: Handle non-unit strides when they become available.
3421  if (op.hasNonUnitStrides())
3422  return failure();
3423 
3424  VectorType sliceVecTy = sourceValue.getType();
3425  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3426  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3427  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3428  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3429 
3430  // Calculate the destination element indices by enumerating all slice
3431  // positions within the destination and linearizing them. The enumeration
3432  // order is lexicographic which yields a sequence of monotonically
3433  // increasing linearized position indices.
3434  // Because the destination may have higher dimensionality than the slice,
3435  // we keep track of two overlapping sets of positions and offsets.
3436  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3437  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3438  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3439  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3440  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3441  MutableArrayRef<int64_t> currSlicePosition(
3442  currDestPosition.begin() + rankDifference, currDestPosition.end());
3443  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3444  offsets.end());
3445  do {
3446  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3447  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3448  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3449  "Invalid slice element");
3450  newValues[linearizedPosition] = *sliceValuesIt;
3451  ++sliceValuesIt;
3452  } while (succeeded(
3453  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3454 
3455  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3456  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3457  return success();
3458  }
3459 };
3460 
3461 } // namespace
3462 
3463 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3464  RewritePatternSet &results, MLIRContext *context) {
3465  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3466  InsertStridedSliceConstantFolder>(context);
3467 }
3468 
3469 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3470  if (getSourceVectorType() == getDestVectorType())
3471  return getValueToStore();
3472  return {};
3473 }
3474 
3475 //===----------------------------------------------------------------------===//
3476 // OuterProductOp
3477 //===----------------------------------------------------------------------===//
3478 
3479 /// Build an op without mask, use the type of `acc` as the return type.
3480 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3481  Value lhs, Value rhs, Value acc) {
3482  result.addOperands({lhs, rhs, acc});
3483  result.addTypes(acc.getType());
3484 }
3485 
3487  p << " " << getLhs() << ", " << getRhs();
3488  if (getAcc()) {
3489  p << ", " << getAcc();
3490  p.printOptionalAttrDict((*this)->getAttrs());
3491  }
3492  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3493 }
3494 
3495 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3496  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3497  Type tLHS, tRHS;
3498  if (parser.parseOperandList(operandsInfo) ||
3499  parser.parseOptionalAttrDict(result.attributes) ||
3500  parser.parseColonType(tLHS) || parser.parseComma() ||
3501  parser.parseType(tRHS))
3502  return failure();
3503  if (operandsInfo.size() < 2)
3504  return parser.emitError(parser.getNameLoc(),
3505  "expected at least 2 operands");
3506  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3507  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3508  if (!vLHS)
3509  return parser.emitError(parser.getNameLoc(),
3510  "expected vector type for operand #1");
3511 
3512  VectorType resType;
3513  if (vRHS) {
3514  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3515  vRHS.getScalableDims()[0]};
3516  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3517  vLHS.getElementType(), scalableDimsRes);
3518  } else {
3519  // Scalar RHS operand
3520  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3521  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3522  scalableDimsRes);
3523  }
3524 
3525  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3526  result.attributes.append(
3527  OuterProductOp::getKindAttrName(result.name),
3528  CombiningKindAttr::get(result.getContext(),
3529  OuterProductOp::getDefaultKind()));
3530  }
3531 
3532  return failure(
3533  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3534  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3535  (operandsInfo.size() > 2 &&
3536  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3537  parser.addTypeToList(resType, result.types));
3538 }
3539 
3540 LogicalResult OuterProductOp::verify() {
3541  Type tRHS = getOperandTypeRHS();
3542  VectorType vLHS = getOperandVectorTypeLHS(),
3543  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3544  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3545 
3546  if (vLHS.getRank() != 1)
3547  return emitOpError("expected 1-d vector for operand #1");
3548 
3549  if (vRHS) {
3550  // Proper OUTER operation.
3551  if (vRHS.getRank() != 1)
3552  return emitOpError("expected 1-d vector for operand #2");
3553  if (vRES.getRank() != 2)
3554  return emitOpError("expected 2-d vector result");
3555  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3556  return emitOpError("expected #1 operand dim to match result dim #1");
3557  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3558  return emitOpError("expected #2 operand dim to match result dim #2");
3559  if (vLHS.isScalable() && !vRHS.isScalable()) {
3560  // This restriction reflects what's currently supported in terms of
3561  // scalable vectors. However, we could relax this if there's a use case.
3562  return emitOpError(
3563  "expected either both or only #2 operand dim to be scalable");
3564  }
3565  } else {
3566  // An AXPY operation.
3567  if (vRES.getRank() != 1)
3568  return emitOpError("expected 1-d vector result");
3569  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3570  return emitOpError("expected #1 operand dim to match result dim #1");
3571  }
3572 
3573  if (vACC && vACC != vRES)
3574  return emitOpError("expected operand #3 of same type as result type");
3575 
3576  // Verify supported combining kind.
3577  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3578  return emitOpError("unsupported outerproduct type");
3579 
3580  return success();
3581 }
3582 
3583 // MaskableOpInterface methods.
3584 
3585 /// Returns the mask type expected by this operation. Mostly used for
3586 /// verification purposes. It requires the operation to be vectorized."
3587 Type OuterProductOp::getExpectedMaskType() {
3588  auto vecType = this->getResultVectorType();
3589  return VectorType::get(vecType.getShape(),
3590  IntegerType::get(vecType.getContext(), /*width=*/1),
3591  vecType.getScalableDims());
3592 }
3593 
3594 //===----------------------------------------------------------------------===//
3595 // ExtractStridedSliceOp
3596 //===----------------------------------------------------------------------===//
3597 
3598 // Inference works as follows:
3599 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3600 // 2. Add sizes from 'vectorType' for remaining dims.
3601 // Scalable flags are inherited from 'vectorType'.
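//
// Illustrative example (editor's sketch): for a source type vector<4x8x16xf32>
// with offsets/sizes/strides of length 2 and sizes = [2, 4], the inferred
// result type is vector<2x4x16xf32>.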
3602 static Type inferStridedSliceOpResultType(VectorType vectorType,
3603  ArrayAttr offsets, ArrayAttr sizes,
3604  ArrayAttr strides) {
3605  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3606  SmallVector<int64_t, 4> shape;
3607  shape.reserve(vectorType.getRank());
3608  unsigned idx = 0;
3609  for (unsigned e = offsets.size(); idx < e; ++idx)
3610  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3611  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3612  shape.push_back(vectorType.getShape()[idx]);
3613 
3614  return VectorType::get(shape, vectorType.getElementType(),
3615  vectorType.getScalableDims());
3616 }
3617 
3618 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3619  Value source, ArrayRef<int64_t> offsets,
3620  ArrayRef<int64_t> sizes,
3621  ArrayRef<int64_t> strides) {
3622  result.addOperands(source);
3623  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3624  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3625  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3626  result.addTypes(
3627  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3628  offsetsAttr, sizesAttr, stridesAttr));
3629  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3630  offsetsAttr);
3631  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3632  sizesAttr);
3633  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3634  stridesAttr);
3635 }
3636 
3637 LogicalResult ExtractStridedSliceOp::verify() {
3638  auto type = getSourceVectorType();
3639  auto offsets = getOffsetsAttr();
3640  auto sizes = getSizesAttr();
3641  auto strides = getStridesAttr();
3642  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3643  return emitOpError(
3644  "expected offsets, sizes and strides attributes of same size");
3645 
3646  auto shape = type.getShape();
3647  auto offName = getOffsetsAttrName();
3648  auto sizesName = getSizesAttrName();
3649  auto stridesName = getStridesAttrName();
3650  if (failed(
3651  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3652  failed(
3653  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3654  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3655  stridesName)) ||
3656  failed(
3657  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3658  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3659  /*halfOpen=*/false,
3660  /*min=*/1)) ||
3661  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3662  /*max=*/1, stridesName,
3663  /*halfOpen=*/false)) ||
3664  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3665  shape, offName, sizesName,
3666  /*halfOpen=*/false)))
3667  return failure();
3668 
3669  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3670  offsets, sizes, strides);
3671  if (getResult().getType() != resultType)
3672  return emitOpError("expected result type to be ") << resultType;
3673 
3674  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3675  if (type.getScalableDims()[idx]) {
3676  auto inputDim = type.getShape()[idx];
3677  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3678  if (inputDim != inputSize)
3679  return emitOpError("expected size at idx=")
3680  << idx
3681  << (" to match the corresponding base size from the input "
3682  "vector (")
3683  << inputSize << (" vs ") << inputDim << (")");
3684  }
3685  }
3686 
3687  return success();
3688 }
3689 
3690 // When the source of an ExtractStridedSliceOp comes from a chain of
3691 // InsertStridedSliceOps, try to use the source of one of those inserts if we
3692 // can detect that the extracted vector is a subset of an inserted vector.
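//
// Illustrative example (editor's sketch; shapes are made up):
//   %i = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//       : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [3], sizes = [2], strides = [1]}
//       : vector<8xf32> to vector<2xf32>
// can be folded so that %e extracts directly from %src with offsets = [1].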
3693 static LogicalResult
3694 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3695  // Helper to extract integer out of ArrayAttr.
3696  auto getElement = [](ArrayAttr array, int idx) {
3697  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3698  };
3699  ArrayAttr extractOffsets = op.getOffsets();
3700  ArrayAttr extractStrides = op.getStrides();
3701  ArrayAttr extractSizes = op.getSizes();
3702  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3703  while (insertOp) {
3704  if (op.getSourceVectorType().getRank() !=
3705  insertOp.getSourceVectorType().getRank())
3706  return failure();
3707  ArrayAttr insertOffsets = insertOp.getOffsets();
3708  ArrayAttr insertStrides = insertOp.getStrides();
3709  // If the rank of extract is greater than the rank of insert, we are likely
3710  // extracting a partial chunk of the vector inserted.
3711  if (extractOffsets.size() > insertOffsets.size())
3712  return failure();
3713  bool partialOverlap = false;
3714  bool disjoint = false;
3715  SmallVector<int64_t, 4> offsetDiffs;
3716  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3717  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3718  return failure();
3719  int64_t start = getElement(insertOffsets, dim);
3720  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3721  int64_t offset = getElement(extractOffsets, dim);
3722  int64_t size = getElement(extractSizes, dim);
3723  // Check if the start of the extract offset is in the interval inserted.
3724  if (start <= offset && offset < end) {
3725  // If the extract interval overlaps but is not fully included we may
3726  // have a partial overlap that will prevent any folding.
3727  if (offset + size > end)
3728  partialOverlap = true;
3729  offsetDiffs.push_back(offset - start);
3730  continue;
3731  }
3732  disjoint = true;
3733  break;
3734  }
3735  // The extract element chunk is a subset of the insert element.
3736  if (!disjoint && !partialOverlap) {
3737  op.setOperand(insertOp.getValueToStore());
3738  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3739  OpBuilder b(op.getContext());
3740  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3741  return success();
3742  }
3743  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3744  // in the insert chain.
3745  if (disjoint)
3746  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3747  else {
3748  // The extracted vector partially overlaps the inserted vector, so we cannot
3749  // fold.
3750  return failure();
3751  }
3752  }
3753  return failure();
3754 }
3755 
3756 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3757 static OpFoldResult
3758 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
3759  Attribute foldInput) {
3760 
3761  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
3762  if (!dense)
3763  return {};
3764 
3765  // TODO: Handle non-unit strides when they become available.
3766  if (op.hasNonUnitStrides())
3767  return {};
3768 
3769  VectorType sourceVecTy = op.getSourceVectorType();
3770  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3771  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3772 
3773  VectorType sliceVecTy = op.getType();
3774  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3775  int64_t rank = sliceVecTy.getRank();
3776 
3777  // Expand offsets and sizes to match the vector rank.
3778  SmallVector<int64_t, 4> offsets(rank, 0);
3779  copy(getI64SubArray(op.getOffsets()), offsets.begin());
3780 
3781  SmallVector<int64_t, 4> sizes(sourceShape);
3782  copy(getI64SubArray(op.getSizes()), sizes.begin());
3783 
3784  // Calculate the slice elements by enumerating all slice positions and
3785  // linearizing them. The enumeration order is lexicographic which yields a
3786  // sequence of monotonically increasing linearized position indices.
3787  const auto denseValuesBegin = dense.value_begin<Attribute>();
3788  SmallVector<Attribute> sliceValues;
3789  sliceValues.reserve(sliceVecTy.getNumElements());
3790  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3791  do {
3792  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3793  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3794  "Invalid index");
3795  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3796  } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3797 
3798  assert(static_cast<int64_t>(sliceValues.size()) ==
3799  sliceVecTy.getNumElements() &&
3800  "Invalid number of slice elements");
3801  return DenseElementsAttr::get(sliceVecTy, sliceValues);
3802 }
3803 
3804 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3805  if (getSourceVectorType() == getResult().getType())
3806  return getVector();
3807  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3808  return getResult();
3809 
3810  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3811  if (auto splat =
3812  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
3813  return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
3814 
3815  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3816  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
3817 }
3818 
3819 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3820  populateFromInt64AttrArray(getOffsets(), results);
3821 }
3822 
3823 namespace {
3824 
3825 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3826 // ConstantMaskOp.
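// Illustrative example (editor's sketch; shapes are made up):
//   %m = vector.constant_mask [2] : vector<4xi1>
//   %s = vector.extract_strided_slice %m
//       {offsets = [1], sizes = [2], strides = [1]}
//       : vector<4xi1> to vector<2xi1>
// becomes
//   %s = vector.constant_mask [1] : vector<2xi1>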
3827 class StridedSliceConstantMaskFolder final
3828  : public OpRewritePattern<ExtractStridedSliceOp> {
3829 public:
3831 
3832  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3833  PatternRewriter &rewriter) const override {
3834  // Return if 'extractStridedSliceOp' operand is not defined by a
3835  // ConstantMaskOp.
3836  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3837  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3838  if (!constantMaskOp)
3839  return failure();
3840  // Return if 'extractStridedSliceOp' has non-unit strides.
3841  if (extractStridedSliceOp.hasNonUnitStrides())
3842  return failure();
3843  // Gather constant mask dimension sizes.
3844  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3845  // Gather strided slice offsets and sizes.
3846  SmallVector<int64_t, 4> sliceOffsets;
3847  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3848  sliceOffsets);
3849  SmallVector<int64_t, 4> sliceSizes;
3850  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3851 
3852  // Compute slice of vector mask region.
3853  SmallVector<int64_t, 4> sliceMaskDimSizes;
3854  sliceMaskDimSizes.reserve(maskDimSizes.size());
3855  for (auto [maskDimSize, sliceOffset, sliceSize] :
3856  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3857  int64_t sliceMaskDimSize = std::max(
3858  static_cast<int64_t>(0),
3859  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3860  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3861  }
3862  // Add unchanged dimensions.
3863  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3864  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3865  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3866  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3867  // region is a conjunction of mask dim intervals).
3868  if (llvm::is_contained(sliceMaskDimSizes, 0))
3869  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3870 
3871  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3872  // region.
3873  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3874  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3875  sliceMaskDimSizes);
3876  return success();
3877  }
3878 };
3879 
3880 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3881 // BroadcastOp(ExtractStridedSliceOp).
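// Illustrative example (editor's sketch of the case where the inner dimensions
// match; shapes are made up):
//   %b = vector.broadcast %v : vector<4xf32> to vector<2x4xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]}
//       : vector<2x4xf32> to vector<1x4xf32>
// becomes
//   %s = vector.broadcast %v : vector<4xf32> to vector<1x4xf32>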
3882 class StridedSliceBroadcast final
3883  : public OpRewritePattern<ExtractStridedSliceOp> {
3884 public:
3886 
3887  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3888  PatternRewriter &rewriter) const override {
3889  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3890  if (!broadcast)
3891  return failure();
3892  auto srcVecType =
3893  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3894  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3895  auto dstVecType = llvm::cast<VectorType>(op.getType());
3896  unsigned dstRank = dstVecType.getRank();
3897  unsigned rankDiff = dstRank - srcRank;
3898  // Check if the innermost dimensions of the source of the broadcast are the
3899  // same as the destination of the extract. If this is the case we can just
3900  // use a broadcast as the original dimensions are untouched.
3901  bool lowerDimMatch = true;
3902  for (unsigned i = 0; i < srcRank; i++) {
3903  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3904  lowerDimMatch = false;
3905  break;
3906  }
3907  }
3908  Value source = broadcast.getSource();
3909  // If the inner dimensions don't match, it means we need to extract from the
3910  // source of the original broadcast and then broadcast the extracted value.
3911  // We also need to handle degenerate cases where the source is effectively
3912  // just a single scalar.
3913  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3914  if (!lowerDimMatch && !isScalarSrc) {
3915  source = rewriter.create<ExtractStridedSliceOp>(
3916  op->getLoc(), source,
3917  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3918  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3919  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3920  }
3921  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3922  return success();
3923  }
3924 };
3925 
3926 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
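/// Illustrative example (editor's sketch; shapes are made up):
///   %v = vector.splat %x : vector<8xf32>
///   %s = vector.extract_strided_slice %v
///       {offsets = [2], sizes = [4], strides = [1]}
///       : vector<8xf32> to vector<4xf32>
/// becomes
///   %s = vector.splat %x : vector<4xf32>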
3927 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3928 public:
3930 
3931  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3932  PatternRewriter &rewriter) const override {
3933  auto splat = op.getVector().getDefiningOp<SplatOp>();
3934  if (!splat)
3935  return failure();
3936  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3937  return success();
3938  }
3939 };
3940 
3941 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3942 /// slice is contiguous, into extract and shape_cast.
3943 ///
3944 /// Example:
3945 /// Before:
3946 /// %1 = vector.extract_strided_slice %arg0 {
3947 /// offsets = [0, 0, 0, 0, 0],
3948 /// sizes = [1, 1, 1, 1, 8],
3949 /// strides = [1, 1, 1, 1, 1]
3950 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3951 /// After:
3952 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3953 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3954 /// %1 = vector.shape_cast %0
3955 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3956 ///
3957 class ContiguousExtractStridedSliceToExtract final
3958  : public OpRewritePattern<ExtractStridedSliceOp> {
3959 public:
3961 
3962  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3963  PatternRewriter &rewriter) const override {
3964  if (op.hasNonUnitStrides())
3965  return failure();
3966  Value source = op.getOperand();
3967  auto sourceType = cast<VectorType>(source.getType());
3968  if (sourceType.isScalable() || sourceType.getRank() == 0)
3969  return failure();
3970 
3971  // Compute the number of offsets to pass to ExtractOp::build. That is the
3972  // difference between the source rank and the desired slice rank. We walk
3973  // the dimensions from innermost out, and stop when the next slice dimension
3974  // is not full-size.
3975  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3976  int numOffsets;
3977  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3978  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3979  break;
3980  }
3981 
3982  // If the created extract op would have no offsets, then this whole
3983  // extract_strided_slice is the identity and should have been handled by
3984  // other canonicalizations.
3985  if (numOffsets == 0)
3986  return failure();
3987 
3988  // If not even the inner-most dimension is full-size, this op can't be
3989  // rewritten as an ExtractOp.
3990  if (numOffsets == sourceType.getRank() &&
3991  static_cast<int>(sizes.size()) == sourceType.getRank())
3992  return failure();
3993 
3994  // The outer dimensions must have unit size.
3995  for (int i = 0; i < numOffsets; ++i) {
3996  if (sizes[i] != 1)
3997  return failure();
3998  }
3999 
4000  // Avoid generating slices that have leading unit dimensions. The shape_cast
4001  // op that we create below would otherwise go through the inefficient generic
4002  // fallback pattern (ShapeCastOpRewritePattern).
4003  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4004  sizes[numOffsets] == 1) {
4005  ++numOffsets;
4006  }
4007 
4008  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4009  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4010  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
4011  extractOffsets);
4012  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4013  return success();
4014  }
4015 };
4016 
4017 } // namespace
4018 
4019 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4020  RewritePatternSet &results, MLIRContext *context) {
4021  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
4022  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4023  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4024  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4025  context);
4026 }
4027 
4028 //===----------------------------------------------------------------------===//
4029 // TransferReadOp
4030 //===----------------------------------------------------------------------===//
4031 
4032 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4033 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4034  VectorType vectorType, Value source,
4035  ValueRange indices, AffineMapAttr permutationMapAttr,
4036  /*optional*/ ArrayAttr inBoundsAttr) {
4037  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4038  Value padding = builder.create<arith::ConstantOp>(
4039  result.location, elemType, builder.getZeroAttr(elemType));
4040  build(builder, result, vectorType, source, indices, permutationMapAttr,
4041  padding, /*mask=*/Value(), inBoundsAttr);
4042 }
4043 
4044 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4045 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4046  VectorType vectorType, Value source,
4047  ValueRange indices, AffineMap permutationMap,
4048  std::optional<ArrayRef<bool>> inBounds) {
4049  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4050  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4051  ? builder.getBoolArrayAttr(inBounds.value())
4052  : builder.getBoolArrayAttr(
4053  SmallVector<bool>(vectorType.getRank(), false));
4054  build(builder, result, vectorType, source, indices, permutationMapAttr,
4055  inBoundsAttr);
4056 }
4057 
4058 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4059 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4060  VectorType vectorType, Value source,
4061  ValueRange indices, Value padding,
4062  std::optional<ArrayRef<bool>> inBounds) {
4063  AffineMap permutationMap = getTransferMinorIdentityMap(
4064  llvm::cast<ShapedType>(source.getType()), vectorType);
4065  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4066  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4067  ? builder.getBoolArrayAttr(inBounds.value())
4068  : builder.getBoolArrayAttr(
4069  SmallVector<bool>(vectorType.getRank(), false));
4070  build(builder, result, vectorType, source, indices, permutationMapAttr,
4071  padding,
4072  /*mask=*/Value(), inBoundsAttr);
4073 }
4074 
4075 /// 4. Builder that sets padding to zero and permutation map to
4076 /// 'getMinorIdentityMap'.
4077 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4078  VectorType vectorType, Value source,
4079  ValueRange indices,
4080  std::optional<ArrayRef<bool>> inBounds) {
4081  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4082  Value padding = builder.create<arith::ConstantOp>(
4083  result.location, elemType, builder.getZeroAttr(elemType));
4084  build(builder, result, vectorType, source, indices, padding, inBounds);
4085 }
4086 
4087 template <typename EmitFun>
4088 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4089  EmitFun emitOpError) {
4090  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4091  for (auto expr : permutationMap.getResults()) {
4092  auto dim = dyn_cast<AffineDimExpr>(expr);
4093  auto zero = dyn_cast<AffineConstantExpr>(expr);
4094  if (zero) {
4095  if (zero.getValue() != 0) {
4096  return emitOpError(
4097  "requires a projected permutation_map (at most one dim or the zero "
4098  "constant can appear in each result)");
4099  }
4100  continue;
4101  }
4102  if (!dim) {
4103  return emitOpError("requires a projected permutation_map (at most one "
4104  "dim or the zero constant can appear in each result)");
4105  }
4106  if (seen[dim.getPosition()]) {
4107  return emitOpError(
4108  "requires a permutation_map that is a permutation (found one dim "
4109  "used more than once)");
4110  }
4111  seen[dim.getPosition()] = true;
4112  }
4113  return success();
4114 }
4115 
4116 static LogicalResult
4117 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4118  VectorType vectorType, VectorType maskType,
4119  VectorType inferredMaskType, AffineMap permutationMap,
4120  ArrayAttr inBounds) {
4121  if (op->hasAttr("masked")) {
4122  return op->emitOpError("masked attribute has been removed. "
4123  "Use in_bounds instead.");
4124  }
4125 
4126  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4127  return op->emitOpError(
4128  "requires source to be a memref or ranked tensor type");
4129 
4130  auto elementType = shapedType.getElementType();
4131  DataLayout dataLayout = DataLayout::closest(op);
4132  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4133  // Memref or tensor has vector element type.
4134  unsigned sourceVecSize =
4135  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4136  vectorElementType.getShape().back();
4137  unsigned resultVecSize =
4138  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4139  vectorType.getShape().back();
4140  if (resultVecSize % sourceVecSize != 0)
4141  return op->emitOpError(
4142  "requires the bitwidth of the minor 1-D vector to be an integral "
4143  "multiple of the bitwidth of the minor 1-D vector of the source");
4144 
4145  unsigned sourceVecEltRank = vectorElementType.getRank();
4146  unsigned resultVecRank = vectorType.getRank();
4147  if (sourceVecEltRank > resultVecRank)
4148  return op->emitOpError(
4149  "requires source vector element and vector result ranks to match.");
4150  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4151  // Check that permutation map results match 'rankOffset' of vector type.
4152  if (permutationMap.getNumResults() != rankOffset)
4153  return op->emitOpError("requires a permutation_map with result dims of "
4154  "the same rank as the vector type");
4155 
4156  if (maskType)
4157  return op->emitOpError("does not support masks with vector element type");
4158  } else {
4159  // Memref or tensor has scalar element type.
4160  unsigned minorSize =
4161  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4162  unsigned resultVecSize =
4163  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4164  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4165  return op->emitOpError(
4166  "requires the bitwidth of the minor 1-D vector to be an integral "
4167  "multiple of the bitwidth of the source element type");
4168 
4169  // Check that permutation map results match rank of vector type.
4170  if (permutationMap.getNumResults() != vectorType.getRank())
4171  return op->emitOpError("requires a permutation_map with result dims of "
4172  "the same rank as the vector type");
4173  }
4174 
4175  if (permutationMap.getNumSymbols() != 0)
4176  return op->emitOpError("requires permutation_map without symbols");
4177 
4178  if (permutationMap.getNumInputs() != shapedType.getRank())
4179  return op->emitOpError("requires a permutation_map with input dims of the "
4180  "same rank as the source type");
4181 
4182  if (maskType && maskType != inferredMaskType)
4183  return op->emitOpError("inferred mask type (")
4184  << inferredMaskType << ") and mask operand type (" << maskType
4185  << ") don't match";
4186 
4187  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4188  return op->emitOpError("expects the in_bounds attr of same rank "
4189  "as permutation_map results: ")
4190  << AffineMapAttr::get(permutationMap)
4191  << " vs inBounds of size: " << inBounds.size();
4192 
4193  return success();
4194 }
4195 
4196 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4197  SmallVector<StringRef, 3> elidedAttrs;
4198  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4199  if (op.getPermutationMap().isMinorIdentity())
4200  elidedAttrs.push_back(op.getPermutationMapAttrName());
4201  // Elide in_bounds attribute if all dims are out-of-bounds.
4202  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4203  elidedAttrs.push_back(op.getInBoundsAttrName());
4204  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4205 }
4206 
4208  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4209  if (getMask())
4210  p << ", " << getMask();
4211  printTransferAttrs(p, *this);
4212  p << " : " << getShapedType() << ", " << getVectorType();
4213 }
4214 
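// Illustrative example (editor's sketch): with a minor identity permutation map
// (d0, d1) -> (d0, d1), a transfer of vector<2x[4]xf32> gets the inferred mask
// type vector<2x[4]xi1>; a permuted map permutes the mask shape accordingly.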
4215 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4216  AffineMap permMap) {
4217  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4218  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4219  assert(invPermMap && "Inversed permutation map couldn't be computed");
4220  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4221 
4222  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4223  // 0-D mask into a single-element 1-D mask.
4224  if (maskShape.empty())
4225  maskShape.push_back(1);
4226 
4227  SmallVector<bool> scalableDims =
4228  applyPermutationMap(invPermMap, vecType.getScalableDims());
4229 
4230  return VectorType::get(maskShape, i1Type, scalableDims);
4231 }
4232 
4233 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4234  auto &builder = parser.getBuilder();
4235  SMLoc typesLoc;
4236  OpAsmParser::UnresolvedOperand sourceInfo;
4237  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4238  OpAsmParser::UnresolvedOperand paddingInfo;
4239  SmallVector<Type, 2> types;
4240  OpAsmParser::UnresolvedOperand maskInfo;
4241  // Parsing with support for paddingValue.
4242  if (parser.parseOperand(sourceInfo) ||
4243  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4244  parser.parseComma() || parser.parseOperand(paddingInfo))
4245  return failure();
4246  ParseResult hasMask = parser.parseOptionalComma();
4247  if (hasMask.succeeded()) {
4248  if (parser.parseOperand(maskInfo))
4249  return failure();
4250  }
4251  if (parser.parseOptionalAttrDict(result.attributes) ||
4252  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4253  return failure();
4254  if (types.size() != 2)
4255  return parser.emitError(typesLoc, "requires two types");
4256  auto indexType = builder.getIndexType();
4257  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4258  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4259  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4260  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4261  if (!vectorType)
4262  return parser.emitError(typesLoc, "requires vector type");
4263  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4264  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4265  AffineMap permMap;
4266  if (!permMapAttr) {
4267  if (shapedType.getRank() <
4268  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4269  return parser.emitError(typesLoc,
4270  "expected a custom permutation_map when "
4271  "rank(source) != rank(destination)");
4272  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4273  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4274  } else {
4275  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4276  }
4277  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4278  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4279  if (!inBoundsAttr) {
4280  result.addAttribute(inBoundsAttrName,
4281  builder.getBoolArrayAttr(
4282  SmallVector<bool>(permMap.getNumResults(), false)));
4283  }
4284  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4285  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4286  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4287  result.operands))
4288  return failure();
4289  if (hasMask.succeeded()) {
4290  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4291  return parser.emitError(
4292  maskInfo.location, "does not support masks with vector element type");
4293  if (vectorType.getRank() != permMap.getNumResults()) {
4294  return parser.emitError(typesLoc,
4295  "expected the same rank for the vector and the "
4296  "results of the permutation map");
4297  }
4298  // Instead of adding the mask type as an op type, compute it based on the
4299  // vector type and the permutation map (to keep the type signature small).
4300  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4301  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4302  return failure();
4303  }
4304  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4305  builder.getDenseI32ArrayAttr(
4306  {1, static_cast<int32_t>(indexInfo.size()), 1,
4307  static_cast<int32_t>(hasMask.succeeded())}));
4308  return parser.addTypeToList(vectorType, result.types);
4309 }
4310 
4311 LogicalResult TransferReadOp::verify() {
4312  // Consistency of elemental types in source and vector.
4313  ShapedType shapedType = getShapedType();
4314  VectorType vectorType = getVectorType();
4315  VectorType maskType = getMaskType();
4316  auto paddingType = getPadding().getType();
4317  auto permutationMap = getPermutationMap();
4318  VectorType inferredMaskType =
4319  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4320  : VectorType();
4321  auto sourceElementType = shapedType.getElementType();
4322 
4323  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4324  return emitOpError("requires ") << shapedType.getRank() << " indices";
4325 
4326  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4327  shapedType, vectorType, maskType,
4328  inferredMaskType, permutationMap, getInBounds())))
4329  return failure();
4330 
4331  if (auto sourceVectorElementType =
4332  llvm::dyn_cast<VectorType>(sourceElementType)) {
4333  // Source has vector element type.
4334  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4335  if (sourceVectorElementType != paddingType)
4336  return emitOpError(
4337  "requires source element type and padding type to match.");
4338 
4339  } else {
4340  // Check that 'paddingType' is valid to store in a vector type.
4341  if (!VectorType::isValidElementType(paddingType))
4342  return emitOpError("requires valid padding vector elemental type");
4343 
4344  // Check that padding type and vector element types match.
4345  if (paddingType != sourceElementType)
4346  return emitOpError(
4347  "requires formal padding and source of the same elemental type");
4348  }
4349 
4350  return verifyPermutationMap(permutationMap,
4351  [&](Twine t) { return emitOpError(t); });
4352 }
4353 
4354 // MaskableOpInterface methods.
4355 
4356 /// Returns the mask type expected by this operation. Mostly used for
4357 /// verification purposes. It requires the operation to be vectorized.
4358 Type TransferReadOp::getExpectedMaskType() {
4359  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4360 }
4361 
4362 //===----------------------------------------------------------------------===//
4363 // TransferReadOp: VectorTransferOpInterface methods.
4364 //===----------------------------------------------------------------------===//
4365 VectorType TransferReadOp::getVectorType() {
4366  return cast<VectorType>(getVector().getType());
4367 }
4368 
4369 template <typename TransferOp>
4370 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4371  // TODO: support more aggressive createOrFold on:
4372  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4373  if (op.getShapedType().isDynamicDim(indicesIdx))
4374  return false;
4375  Value index = op.getIndices()[indicesIdx];
4376  std::optional<int64_t> cstOp = getConstantIntValue(index);
4377  if (!cstOp.has_value())
4378  return false;
4379 
4380  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4381  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4382 
4383  return cstOp.value() + vectorSize <= sourceSize;
4384 }
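// For example (illustrative values): a transfer of vector<4xf32> from
// memref<8xf32> at constant index 2 satisfies 2 + 4 <= 8 and is known to be
// in-bounds, whereas the same transfer at constant index 5 is not (5 + 4 > 8).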
4385 
4386 template <typename TransferOp>
4387 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4388  // TODO: support 0-d corner case.
4389  // TODO: Be less conservative.
4390  if (op.getTransferRank() == 0)
4391  return failure();
4392  AffineMap permutationMap = op.getPermutationMap();
4393  bool changed = false;
4394  SmallVector<bool, 4> newInBounds;
4395  newInBounds.reserve(op.getTransferRank());
4396  // Idxs of non-bcast dims - used when analysing bcast dims.
4397  SmallVector<unsigned> nonBcastDims;
4398 
4399  // 1. Process non-broadcast dims
4400  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4401  // 1.1. Already marked as in-bounds, nothing to see here.
4402  if (op.isDimInBounds(i)) {
4403  newInBounds.push_back(true);
4404  continue;
4405  }
4406  // 1.2. Currently out-of-bounds, check whether we can statically determine
4407  // it is inBounds.
4408  bool inBounds = false;
4409  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4410  if (dimExpr) {
4411  inBounds = isInBounds(op, /*resultIdx=*/i,
4412  /*indicesIdx=*/dimExpr.getPosition());
4413  nonBcastDims.push_back(i);
4414  }
4415 
4416  newInBounds.push_back(inBounds);
4417  // We commit the pattern if it is "more inbounds".
4418  changed |= inBounds;
4419  }
4420 
4421  // 2. Handle broadcast dims
4422  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4423  // "in bounds" as well.
4424  bool allNonBcastDimsInBounds = llvm::all_of(
4425  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4426  if (allNonBcastDimsInBounds) {
4427  for (size_t idx : permutationMap.getBroadcastDims()) {
4428  changed |= !newInBounds[idx];
4429  newInBounds[idx] = true;
4430  }
4431  }
4432 
4433  if (!changed)
4434  return failure();
4435  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4436  OpBuilder b(op.getContext());
4437  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4438  return success();
4439 }
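// For example (illustrative): a transfer_read of vector<4xf32> from
// memref<8xf32> at index %c0 that currently carries in_bounds = [false] is
// rewritten with in_bounds = [true], since 0 + 4 <= 8 can be proven
// statically.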
4440 
4441 template <typename TransferOp>
4442 static LogicalResult foldTransferFullMask(TransferOp op) {
4443  auto mask = op.getMask();
4444  if (!mask)
4445  return failure();
4446 
4447  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4448  return failure();
4449 
4450  op.getMaskMutable().clear();
4451  return success();
4452 }
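// For example (illustrative): a transfer_read whose mask is produced by
// vector.constant_mask [4] : vector<4xi1> is classified as all-true by
// getMaskFormat, so the mask operand is dropped and the unmasked transfer
// remains.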
4453 
4454 /// ```
4455 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4456 /// : vector<1x4xf32>, tensor<4x4xf32>
4457 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4458 /// : tensor<4x4xf32>, vector<1x4xf32>
4459 /// ```
4460 /// -> Folds into
4461 /// ```
4462 /// %v0
4463 /// ```
4464 static Value foldRAW(TransferReadOp readOp) {
4465  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4466  return {};
4467  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4468  while (defWrite) {
4469  if (checkSameValueRAW(defWrite, readOp))
4470  return defWrite.getVector();
4471  if (!isDisjointTransferIndices(
4472  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4473  cast<VectorTransferOpInterface>(readOp.getOperation())))
4474  break;
4475  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4476  }
4477  return {};
4478 }
4479 
4480 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4481  if (Value vec = foldRAW(*this))
4482  return vec;
4483  /// transfer_read(memrefcast) -> transfer_read
4484  if (succeeded(foldTransferInBoundsAttribute(*this)))
4485  return getResult();
4486  if (succeeded(foldTransferFullMask(*this)))
4487  return getResult();
4488  if (succeeded(memref::foldMemRefCast(*this)))
4489  return getResult();
4490  if (succeeded(tensor::foldTensorCast(*this)))
4491  return getResult();
4492  return OpFoldResult();
4493 }
4494 
4495 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4496  return llvm::to_vector<4>(getVectorType().getShape());
4497 }
4498 
4499 void TransferReadOp::getEffects(
4500  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4501  &effects) {
4502  if (llvm::isa<MemRefType>(getShapedType()))
4503  effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
4504  SideEffects::DefaultResource::get());
4505 }
4506 
4507 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4508  if (hasPureTensorSemantics())
4509  return Speculation::Speculatable;
4510  return Speculation::NotSpeculatable;
4511 }
4512 
4513 namespace {
4514 /// Store to load forwarding for transfer operations with permutation maps.
4515 /// Even if the permutation maps are different we can still propagate the store
4516 /// into the load if the size of the dimensions read and written match. Then we
4517 /// can replace the transfer_read + transfer_write by vector.broadcast and
4518 /// vector.transpose.
4519 /// Example:
4520 /// ```
4521 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4522 /// {in_bounds = [true, true],
4523 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4524 /// vector<4x1xf32>, tensor<4x4x4xf32>
4525 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4526 /// {in_bounds = [true, true, true, true],
4527 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4528 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4529 /// ```
4530 /// To:
4531 /// ```
4532 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4533 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4534 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4535 /// ```
4536 struct TransferReadAfterWriteToBroadcast
4537  : public OpRewritePattern<TransferReadOp> {
4538  using OpRewritePattern::OpRewritePattern;
4539 
4540  LogicalResult matchAndRewrite(TransferReadOp readOp,
4541  PatternRewriter &rewriter) const override {
4542  if (readOp.hasOutOfBoundsDim() ||
4543  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4544  return failure();
4545  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4546  if (!defWrite)
4547  return failure();
4548  // TODO: If the written transfer chunk is a superset of the read transfer
4549  // chunk we could do an extract_strided_slice.
4550  if (readOp.getTransferChunkAccessed() !=
4551  defWrite.getTransferChunkAccessed())
4552  return failure();
4553  // TODO: Support cases where a dim is explicitly written but implicitly
4554  // read (i.e., a unit dim that is rank reduced).
4555  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4556  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4557  return failure();
4558  if (readOp.getIndices() != defWrite.getIndices() ||
4559  readOp.getMask() != defWrite.getMask())
4560  return failure();
4561  Value vec = defWrite.getVector();
4562  // TODO: loop through the chain of transfer_write if we can prove that they
4563  // don't overlap with the transfer_read. This requires improving
4564  // `isDisjointTransferIndices` helper.
4565  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4566  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4567  AffineMap map = readMap.compose(writeMap);
4568  if (map.getNumResults() == 0)
4569  return failure();
4570  // Calculate the permutation to apply to go from the vector stored to the
4571  // vector read.
4572  SmallVector<unsigned> permutation;
4573  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4574  return failure();
4575 
4576  Location loc = readOp.getLoc();
4577  // Calculate the broadcast shape by applying the reverse permutation to the
4578  // final shape we want.
4579  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4580  SmallVector<int64_t> broadcastShape(destShape.size());
4581  SmallVector<bool> broadcastScalableFlags(destShape.size());
4582  for (const auto &pos : llvm::enumerate(permutation)) {
4583  broadcastShape[pos.value()] = destShape[pos.index()];
4584  broadcastScalableFlags[pos.value()] =
4585  readOp.getVectorType().getScalableDims()[pos.index()];
4586  }
4587  VectorType broadcastedType = VectorType::get(
4588  broadcastShape, defWrite.getVectorType().getElementType(),
4589  broadcastScalableFlags);
4590  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4591  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4592  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4593  transposePerm);
4594  return success();
4595  }
4596 };
4597 } // namespace
4598 
4599 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4600  MLIRContext *context) {
4601  results.add<TransferReadAfterWriteToBroadcast>(context);
4602 }
4603 
4604 //===----------------------------------------------------------------------===//
4605 // TransferWriteOp
4606 //===----------------------------------------------------------------------===//
4607 
4608 /// 1. Builder with type inference.
4609 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4610  Value vector, Value dest, ValueRange indices,
4611  AffineMapAttr permutationMapAttr,
4612  /*optional*/ Value mask,
4613  /*optional*/ ArrayAttr inBoundsAttr) {
4614  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4615  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4616  mask, inBoundsAttr);
4617 }
4618 
4619 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4620 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4621  Value vector, Value dest, ValueRange indices,
4622  AffineMapAttr permutationMapAttr,
4623  /*optional*/ ArrayAttr inBoundsAttr) {
4624  build(builder, result, vector, dest, indices, permutationMapAttr,
4625  /*mask=*/Value(), inBoundsAttr);
4626 }
4627 
4628 /// 3. Builder with type inference that sets an empty mask (variant without
4629 /// attrs)
4630 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4631  Value vector, Value dest, ValueRange indices,
4632  AffineMap permutationMap,
4633  std::optional<ArrayRef<bool>> inBounds) {
4634  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4635  auto inBoundsAttr =
4636  (inBounds && !inBounds.value().empty())
4637  ? builder.getBoolArrayAttr(inBounds.value())
4638  : builder.getBoolArrayAttr(SmallVector<bool>(
4639  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4640  build(builder, result, vector, dest, indices, permutationMapAttr,
4641  /*mask=*/Value(), inBoundsAttr);
4642 }
4643 
4644 /// 4. Builder with type inference that sets an empty mask and sets permutation
4645 /// map to 'getMinorIdentityMap'.
4646 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4647  Value vector, Value dest, ValueRange indices,
4648  std::optional<ArrayRef<bool>> inBounds) {
4649  auto vectorType = llvm::cast<VectorType>(vector.getType());
4650  AffineMap permutationMap = getTransferMinorIdentityMap(
4651  llvm::cast<ShapedType>(dest.getType()), vectorType);
4652  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4653 }
4654 
4655 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4656  OperationState &result) {
4657  auto &builder = parser.getBuilder();
4658  SMLoc typesLoc;
4659  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4660  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4661  SmallVector<Type, 2> types;
4662  OpAsmParser::UnresolvedOperand maskInfo;
4663  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4664  parser.parseOperand(sourceInfo) ||
4665  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4666  return failure();
4667  ParseResult hasMask = parser.parseOptionalComma();
4668  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4669  return failure();
4670  if (parser.parseOptionalAttrDict(result.attributes) ||
4671  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4672  return failure();
4673  if (types.size() != 2)
4674  return parser.emitError(typesLoc, "requires two types");
4675  auto indexType = builder.getIndexType();
4676  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4677  if (!vectorType)
4678  return parser.emitError(typesLoc, "requires vector type");
4679  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4680  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4681  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4682  auto permMapAttrName =
4683  TransferWriteOp::getPermutationMapAttrName(result.name);
4684  auto permMapAttr = result.attributes.get(permMapAttrName);
4685  AffineMap permMap;
4686  if (!permMapAttr) {
4687  if (shapedType.getRank() <
4688  getEffectiveVectorRankForXferOp(shapedType, vectorType))
4689  return parser.emitError(typesLoc,
4690  "expected a custom permutation_map when "
4691  "rank(source) != rank(destination)");
4692  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4693  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4694  } else {
4695  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4696  }
4697  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4698  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4699  if (!inBoundsAttr) {
4700  result.addAttribute(inBoundsAttrName,
4701  builder.getBoolArrayAttr(
4702  SmallVector<bool>(permMap.getNumResults(), false)));
4703  }
4704  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4705  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4706  parser.resolveOperands(indexInfo, indexType, result.operands))
4707  return failure();
4708  if (hasMask.succeeded()) {
4709  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4710  return parser.emitError(
4711  maskInfo.location, "does not support masks with vector element type");
4712  if (vectorType.getRank() != permMap.getNumResults()) {
4713  return parser.emitError(typesLoc,
4714  "expected the same rank for the vector and the "
4715  "results of the permutation map");
4716  }
4717  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4718  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4719  return failure();
4720  }
4721  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4722  builder.getDenseI32ArrayAttr(
4723  {1, 1, static_cast<int32_t>(indexInfo.size()),
4724  static_cast<int32_t>(hasMask.succeeded())}));
4725  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4726  parser.addTypeToList(shapedType, result.types));
4727 }
4728 
4729 void TransferWriteOp::print(OpAsmPrinter &p) {
4730  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
4731  if (getMask())
4732  p << ", " << getMask();
4733  printTransferAttrs(p, *this);
4734  p << " : " << getVectorType() << ", " << getShapedType();
4735 }
4736 
4737 LogicalResult TransferWriteOp::verify() {
4738  // Consistency of elemental types in shape and vector.
4739  ShapedType shapedType = getShapedType();
4740  VectorType vectorType = getVectorType();
4741  VectorType maskType = getMaskType();
4742  auto permutationMap = getPermutationMap();
4743  VectorType inferredMaskType =
4744  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4745  : VectorType();
4746 
4747  if (llvm::size(getIndices()) != shapedType.getRank())
4748  return emitOpError("requires ") << shapedType.getRank() << " indices";
4749 
4750  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4751  // as the semantics is unclear. This can be revisited later if necessary.
4752  if (hasBroadcastDim())
4753  return emitOpError("should not have broadcast dimensions");
4754 
4755  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4756  shapedType, vectorType, maskType,
4757  inferredMaskType, permutationMap, getInBounds())))
4758  return failure();
4759 
4760  return verifyPermutationMap(permutationMap,
4761  [&](Twine t) { return emitOpError(t); });
4762 }
4763 
4764 //===----------------------------------------------------------------------===//
4765 // TransferWriteOp: MaskableOpInterface methods.
4766 //===----------------------------------------------------------------------===//
4767 
4768 /// Returns the mask type expected by this operation. Mostly used for
4769 /// verification purposes.
4770 Type TransferWriteOp::getExpectedMaskType() {
4771  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4772 }
4773 
4774 //===----------------------------------------------------------------------===//
4775 // TransferWriteOp: VectorTransferOpInterface methods.
4776 //===----------------------------------------------------------------------===//
4777 Value TransferWriteOp::getVector() { return getOperand(0); }
4778 VectorType TransferWriteOp::getVectorType() {
4779  return cast<VectorType>(getValueToStore().getType());
4780 }
4781 
4782 //===----------------------------------------------------------------------===//
4783 // TransferWriteOp: fold methods.
4784 //===----------------------------------------------------------------------===//
4785 /// Fold:
4786 /// ```
4787 /// %t1 = ...
4788 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4789 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4790 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4791 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4792 /// ```
4793 ///
4794 /// into:
4795 ///
4796 /// ```
4797 /// %t0
4798 /// ```
4799 ///
4800 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4801 /// block argument or has side effects.
4802 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4803  ArrayRef<Attribute>,
4804  SmallVectorImpl<OpFoldResult> &results) {
4805  // TODO: support 0-d corner case.
4806  if (write.getTransferRank() == 0)
4807  return failure();
4808  auto rankedTensorType =
4809  llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
4810  // If not operating on tensors, bail.
4811  if (!rankedTensorType)
4812  return failure();
4813  // If no read, bail.
4814  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4815  if (!read)
4816  return failure();
4817  // TODO: support 0-d corner case.
4818  if (read.getTransferRank() == 0)
4819  return failure();
4820  // For now, only accept minor identity maps. Future: also accept cases where
4821  // the composition of the two maps is a minor identity.
4821  if (!read.getPermutationMap().isMinorIdentity() ||
4822  !write.getPermutationMap().isMinorIdentity())
4823  return failure();
4824  // Bail on mismatching ranks.
4825  if (read.getTransferRank() != write.getTransferRank())
4826  return failure();
4827  // Bail on potential out-of-bounds accesses.
4828  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4829  return failure();
4830  // Tensor types must be the same.
4831  if (read.getBase().getType() != rankedTensorType)
4832  return failure();
4833  // Vector types must be the same.
4834  if (read.getVectorType() != write.getVectorType())
4835  return failure();
4836  // Vector and Tensor shapes must match.
4837  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4838  return failure();
4839  // If any index is nonzero.
4840  auto isNotConstantZero = [](Value v) {
4841  auto cstOp = getConstantIntValue(v);
4842  return !cstOp.has_value() || cstOp.value() != 0;
4843  };
4844  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4845  llvm::any_of(write.getIndices(), isNotConstantZero))
4846  return failure();
4847  // Success.
4848  results.push_back(read.getBase());
4849  return success();
4850 }
4851 
4852 static bool checkSameValueWAR(vector::TransferReadOp read,
4853  vector::TransferWriteOp write) {
4854  return read.getBase() == write.getBase() &&
4855  read.getIndices() == write.getIndices() &&
4856  read.getPermutationMap() == write.getPermutationMap() &&
4857  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4858  !write.getMask();
4859 }
4860 /// Fold transfer_write write after read:
4861 /// ```
4862 /// %t0 = ...
4863 /// %v = vector.transfer_read %t0[%c0...] :
4864 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4865 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4866 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4867 /// ```
4868 ///
4869 /// into:
4870 ///
4871 /// ```
4872 /// %t0
4873 /// ```
4874 static LogicalResult foldWAR(TransferWriteOp write,
4875  SmallVectorImpl<OpFoldResult> &results) {
4876  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
4877  return failure();
4878  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4879  if (!read)
4880  return failure();
4881 
4882  if (!checkSameValueWAR(read, write))
4883  return failure();
4884  results.push_back(read.getBase());
4885  return success();
4886 }
4887 
4888 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4889  SmallVectorImpl<OpFoldResult> &results) {
4890  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4891  return success();
4892  if (succeeded(foldWAR(*this, results)))
4893  return success();
4894  if (succeeded(foldTransferInBoundsAttribute(*this)))
4895  return success();
4896  if (succeeded(foldTransferFullMask(*this)))
4897  return success();
4898  return memref::foldMemRefCast(*this);
4899 }
4900 
4901 //===----------------------------------------------------------------------===//
4902 // TransferWriteOp: other methods.
4903 //===----------------------------------------------------------------------===//
4904 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4905  return llvm::to_vector<4>(getVectorType().getShape());
4906 }
4907 
4908 void TransferWriteOp::getEffects(
4909  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4910  &effects) {
4911  if (llvm::isa<MemRefType>(getShapedType()))
4912  effects.emplace_back(MemoryEffects::Write::get(), &getValueToStoreMutable(),
4913  SideEffects::DefaultResource::get());
4914 }
4915 
4916 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4917  if (hasPureTensorSemantics())
4918  return Speculation::Speculatable;
4919  return Speculation::NotSpeculatable;
4920 }
4921 
4922 namespace {
4923 /// Remove dead transfer write from the SSA chain so that it can be eliminated
4924 /// by DCE.
4925 /// ```
4926 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4927 /// : vector<1x4xf32>, tensor<4x4xf32>
4928 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4929 /// : vector<1x4xf32>, tensor<4x4xf32>
4930 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4931 /// : vector<1x4xf32>, tensor<4x4xf32>
4932 /// ```
4933 ///
4934 /// into:
4935 ///
4936 /// ```
4937 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4938 /// : vector<1x4xf32>, tensor<4x4xf32>
4939 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4940 /// : vector<1x4xf32>, tensor<4x4xf32>
4941 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4942 /// : vector<1x4xf32>, tensor<4x4xf32>
4943 /// ```
4944 ///
4945 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4946 /// any other uses.
4947 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4948 public:
4949  using OpRewritePattern::OpRewritePattern;
4950  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4951  PatternRewriter &rewriter) const override {
4952  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4953  return failure();
4954  vector::TransferWriteOp writeToModify = writeOp;
4955 
4956  auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4957  while (defWrite) {
4958  if (checkSameValueWAW(writeOp, defWrite)) {
4959  rewriter.modifyOpInPlace(writeToModify, [&]() {
4960  writeToModify.getBaseMutable().assign(defWrite.getBase());
4961  });
4962  return success();
4963  }
4964  if (!isDisjointTransferIndices(
4965  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4966  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4967  break;
4968  // If the previous write op doesn't have any other use we can safely look
4969  // at the previous store to see if it can be removed.
4970  if (!defWrite->hasOneUse())
4971  break;
4972  writeToModify = defWrite;
4973  defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4974  }
4975  return failure();
4976  }
4977 };
4978 
4979 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4980 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4981 /// overwritten and inserted into another tensor. After this rewrite, the
4982 /// operations bufferize in-place since all of them work on the same slice.
4983 ///
4984 /// For example:
4985 /// ```mlir
4986 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4987 /// : vector<8x16xf32>, tensor<8x16xf32>
4988 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4989 /// : tensor<8x16xf32> to tensor<?x?xf32>
4990 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4991 /// : tensor<?x?xf32> into tensor<27x37xf32>
4992 /// ```
4993 /// folds to
4994 /// ```mlir
4995 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4996 /// : tensor<27x37xf32> to tensor<?x?xf32>
4997 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4998 /// : vector<8x16xf32>, tensor<?x?xf32>
4999 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5000 /// : tensor<?x?xf32> into tensor<27x37xf32>
5001 /// ```
5002 struct SwapExtractSliceOfTransferWrite
5003  : public OpRewritePattern<tensor::InsertSliceOp> {
5004 public:
5005  using OpRewritePattern::OpRewritePattern;
5006 
5007  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5008  PatternRewriter &rewriter) const override {
5009  if (!insertOp.hasUnitStride())
5010  return failure();
5011  auto extractOp =
5012  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5013  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5014  return failure();
5015  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5016  if (!transferOp || !transferOp->hasOneUse())
5017  return failure();
5018 
5019  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5020  // rank-reducing.
5021  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5022  return rewriter.notifyMatchFailure(insertOp,
5023  "use-def chain is rank-reducing");
5024  }
5025 
5026  // Fail if tensor::ExtractSliceOp has non-zero offset.
5027  if (!extractOp.hasZeroOffset()) {
5028  return rewriter.notifyMatchFailure(insertOp,
5029  "ExtractSliceOp has non-zero offset");
5030  }
5031 
5032  // Fail if tensor::TransferWriteOp has non-zero offset.
5033  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5034  return getConstantIntValue(value) == static_cast<int64_t>(0);
5035  })) {
5036  return rewriter.notifyMatchFailure(insertOp,
5037  "TransferWriteOp has non-zero offset");
5038  }
5039 
5040  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5041  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5042  return rewriter.notifyMatchFailure(
5043  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5044  }
5045 
5046  for (auto [insertSize, extractSize] :
5047  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5048  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5049  return rewriter.notifyMatchFailure(
5050  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5051  }
5052  }
5053 
5054  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5055  assert(transferOp.getVectorType().hasStaticShape() &&
5056  "expected vector to have a static shape");
5057  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5058  SmallVector<int64_t> resultShape = applyPermutationMap(
5059  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5060  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5061  return rewriter.notifyMatchFailure(
5062  insertOp, "TransferWriteOp may not write the full tensor.");
5063  }
5064 
5065  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5066  // Set all in_bounds to false and let the folder infer them.
5067  SmallVector<bool> newInBounds(vectorShape.size(), false);
5068  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5069  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5070  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5071  insertOp.getMixedStrides());
5072  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5073  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5074  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5075  rewriter.getBoolArrayAttr(newInBounds));
5076  rewriter.modifyOpInPlace(insertOp, [&]() {
5077  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5078  });
5079  return success();
5080  }
5081 };
5082 
5083 } // namespace
5084 
5085 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5086  MLIRContext *context) {
5087  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5088 }
5089 
5090 //===----------------------------------------------------------------------===//
5091 // LoadOp
5092 //===----------------------------------------------------------------------===//
5093 
5094 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5095  VectorType vecTy,
5096  MemRefType memRefTy) {
5097  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5098  // need any strides limitations.
5099  if (!vecTy.isScalable() &&
5100  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5101  return success();
5102 
5103  if (!memRefTy.isLastDimUnitStride())
5104  return op->emitOpError("most minor memref dim must have unit stride");
5105  return success();
5106 }
5107 
5108 LogicalResult vector::LoadOp::verify() {
5109  VectorType resVecTy = getVectorType();
5110  MemRefType memRefTy = getMemRefType();
5111 
5112  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5113  return failure();
5114 
5115  if (memRefTy.getRank() < resVecTy.getRank())
5116  return emitOpError(
5117  "destination memref has lower rank than the result vector");
5118 
5119  // Checks for vector memrefs.
5120  Type memElemTy = memRefTy.getElementType();
5121  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5122  if (memVecTy != resVecTy)
5123  return emitOpError("base memref and result vector types should match");
5124  memElemTy = memVecTy.getElementType();
5125  }
5126 
5127  if (resVecTy.getElementType() != memElemTy)
5128  return emitOpError("base and result element types should match");
5129  if (llvm::size(getIndices()) != memRefTy.getRank())
5130  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5131  return success();
5132 }
5133 
5134 OpFoldResult LoadOp::fold(FoldAdaptor) {
5135  if (succeeded(memref::foldMemRefCast(*this)))
5136  return getResult();
5137  return OpFoldResult();
5138 }
5139 
5140 //===----------------------------------------------------------------------===//
5141 // StoreOp
5142 //===----------------------------------------------------------------------===//
5143 
5144 LogicalResult vector::StoreOp::verify() {
5145  VectorType valueVecTy = getVectorType();
5146  MemRefType memRefTy = getMemRefType();
5147 
5148  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5149  return failure();
5150 
5151  if (memRefTy.getRank() < valueVecTy.getRank())
5152  return emitOpError("source memref has lower rank than the vector to store");
5153 
5154  // Checks for vector memrefs.
5155  Type memElemTy = memRefTy.getElementType();
5156  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5157  if (memVecTy != valueVecTy)
5158  return emitOpError(
5159  "base memref and valueToStore vector types should match");
5160  memElemTy = memVecTy.getElementType();
5161  }
5162 
5163  if (valueVecTy.getElementType() != memElemTy)
5164  return emitOpError("base and valueToStore element type should match");
5165  if (llvm::size(getIndices()) != memRefTy.getRank())
5166  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5167  return success();
5168 }
5169 
5170 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5171  SmallVectorImpl<OpFoldResult> &results) {
5172  return memref::foldMemRefCast(*this);
5173 }
5174 
5175 //===----------------------------------------------------------------------===//
5176 // MaskedLoadOp
5177 //===----------------------------------------------------------------------===//
5178 
5179 LogicalResult MaskedLoadOp::verify() {
5180  VectorType maskVType = getMaskVectorType();
5181  VectorType passVType = getPassThruVectorType();
5182  VectorType resVType = getVectorType();
5183  MemRefType memType = getMemRefType();
5184 
5185  if (resVType.getElementType() != memType.getElementType())
5186  return emitOpError("base and result element type should match");
5187  if (llvm::size(getIndices()) != memType.getRank())
5188  return emitOpError("requires ") << memType.getRank() << " indices";
5189  if (resVType.getShape() != maskVType.getShape())
5190  return emitOpError("expected result shape to match mask shape");
5191  if (resVType != passVType)
5192  return emitOpError("expected pass_thru of same type as result type");
5193  return success();
5194 }
5195 
5196 namespace {
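/// Folds vector.maskedload with a statically known mask: an all-true mask
/// turns the op into a plain vector.load, and an all-false mask replaces the
/// op with its pass_thru operand.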
5197 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5198 public:
5200  LogicalResult matchAndRewrite(MaskedLoadOp load,
5201  PatternRewriter &rewriter) const override {
5202  switch (getMaskFormat(load.getMask())) {
5203  case MaskFormat::AllTrue:
5204  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5205  load, load.getType(), load.getBase(), load.getIndices());
5206  return success();
5207  case MaskFormat::AllFalse:
5208  rewriter.replaceOp(load, load.getPassThru());
5209  return success();
5210  case MaskFormat::Unknown:
5211  return failure();
5212  }
5213  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5214  }
5215 };
5216 } // namespace
5217 
5218 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5219  MLIRContext *context) {
5220  results.add<MaskedLoadFolder>(context);
5221 }
5222 
5223 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5224  if (succeeded(memref::foldMemRefCast(*this)))
5225  return getResult();
5226  return OpFoldResult();
5227 }
5228 
5229 //===----------------------------------------------------------------------===//
5230 // MaskedStoreOp
5231 //===----------------------------------------------------------------------===//
5232 
5233 LogicalResult MaskedStoreOp::verify() {
5234  VectorType maskVType = getMaskVectorType();
5235  VectorType valueVType = getVectorType();
5236  MemRefType memType = getMemRefType();
5237 
5238  if (valueVType.getElementType() != memType.getElementType())
5239  return emitOpError("base and valueToStore element type should match");
5240  if (llvm::size(getIndices()) != memType.getRank())
5241  return emitOpError("requires ") << memType.getRank() << " indices";
5242  if (valueVType.getShape() != maskVType.getShape())
5243  return emitOpError("expected valueToStore shape to match mask shape");
5244  return success();
5245 }
5246 
5247 namespace {
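/// Folds vector.maskedstore with a statically known mask: an all-true mask
/// turns the op into a plain vector.store, and an all-false mask makes the
/// store a no-op that is simply erased.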
5248 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5249 public:
5251  LogicalResult matchAndRewrite(MaskedStoreOp store,
5252  PatternRewriter &rewriter) const override {
5253  switch (getMaskFormat(store.getMask())) {
5254  case MaskFormat::AllTrue:
5255  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5256  store, store.getValueToStore(), store.getBase(), store.getIndices());
5257  return success();
5258  case MaskFormat::AllFalse:
5259  rewriter.eraseOp(store);
5260  return success();
5261  case MaskFormat::Unknown:
5262  return failure();
5263  }
5264  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5265  }
5266 };
5267 } // namespace
5268 
5269 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5270  MLIRContext *context) {
5271  results.add<MaskedStoreFolder>(context);
5272 }
5273 
5274 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5275  SmallVectorImpl<OpFoldResult> &results) {
5276  return memref::foldMemRefCast(*this);
5277 }
5278 
5279 //===----------------------------------------------------------------------===//
5280 // GatherOp
5281 //===----------------------------------------------------------------------===//
5282 
5283 LogicalResult GatherOp::verify() {
5284  VectorType indVType = getIndexVectorType();
5285  VectorType maskVType = getMaskVectorType();
5286  VectorType resVType = getVectorType();
5287  ShapedType baseType = getBaseType();
5288 
5289  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5290  return emitOpError("requires base to be a memref or ranked tensor type");
5291 
5292  if (resVType.getElementType() != baseType.getElementType())
5293  return emitOpError("base and result element type should match");
5294  if (llvm::size(getIndices()) != baseType.getRank())
5295  return emitOpError("requires ") << baseType.getRank() << " indices";
5296  if (resVType.getShape() != indVType.getShape())
5297  return emitOpError("expected result dim to match indices dim");
5298  if (resVType.getShape() != maskVType.getShape())
5299  return emitOpError("expected result dim to match mask dim");
5300  if (resVType != getPassThruVectorType())
5301  return emitOpError("expected pass_thru of same type as result type");
5302  return success();
5303 }
5304 
5305 // MaskableOpInterface methods.
5306 
5307 /// Returns the mask type expected by this operation. Mostly used for
5308 /// verification purposes. It requires the operation to be vectorized.
5309 Type GatherOp::getExpectedMaskType() {
5310  auto vecType = this->getIndexVectorType();
5311  return VectorType::get(vecType.getShape(),
5312  IntegerType::get(vecType.getContext(), /*width=*/1),
5313  vecType.getScalableDims());
5314 }
5315 
5316 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5317  return llvm::to_vector<4>(getVectorType().getShape());
5318 }
5319 
5320 /// Check if `indexVec` is a constant 1D vec of consecutive values [0, 1, 2, ...]
5321 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5322  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5323  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5324  return failure();
5325 
5326  if (indexVec.getDefiningOp<StepOp>())
5327  return success();
5328 
5329  DenseIntElementsAttr elements;
5330  if (!matchPattern(indexVec, m_Constant(&elements)))
5331  return failure();
5332 
5333  return success(
5334  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5335 }
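// For example (illustrative): the result of vector.step, or a constant
// dense<[0, 1, 2, 3]> : vector<4xindex>, qualifies; dense<[0, 2, 4, 6]> does
// not, and neither does any scalable or multi-dimensional index vector.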
5336 
5337 namespace {
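/// Folds vector.gather with a statically known mask: an all-false mask
/// replaces the op with its pass_thru operand. An all-true mask has no
/// unmasked gather equivalent, so that case is left untouched.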
5338 class GatherFolder final : public OpRewritePattern<GatherOp> {
5339 public:
5341  LogicalResult matchAndRewrite(GatherOp gather,
5342  PatternRewriter &rewriter) const override {
5343  switch (getMaskFormat(gather.getMask())) {
5344  case MaskFormat::AllTrue:
5345  return failure(); // no unmasked equivalent
5346  case MaskFormat::AllFalse:
5347  rewriter.replaceOp(gather, gather.getPassThru());
5348  return success();
5349  case MaskFormat::Unknown:
5350  return failure();
5351  }
5352  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5353  }
5354 };
5355 
5356 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5357 /// maskedload. Only 1D fixed vectors are supported for now.
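/// For example (illustrative): a gather whose index vector is produced by
/// vector.step (offsets [0, 1, 2, ...]) reads a contiguous chunk starting at
/// the base indices, which is exactly what vector.maskedload does, so the
/// gather is rewritten into a maskedload with the same mask and pass_thru.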
5358 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5359 public:
5361  LogicalResult matchAndRewrite(GatherOp op,
5362  PatternRewriter &rewriter) const override {
5363  if (!isa<MemRefType>(op.getBase().getType()))
5364  return rewriter.notifyMatchFailure(op, "base must be of memref type");
5365 
5366  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5367  return failure();
5368 
5369  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5370  op.getIndices(), op.getMask(),
5371  op.getPassThru());
5372  return success();
5373  }
5374 };
5375 } // namespace
5376 
5377 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5378  MLIRContext *context) {
5379  results.add<GatherFolder, FoldContiguousGather>(context);
5380 }
5381 
5382 //===----------------------------------------------------------------------===//
5383 // ScatterOp
5384 //===----------------------------------------------------------------------===//
5385 
5386 LogicalResult ScatterOp::verify() {
5387  VectorType indVType = getIndexVectorType();
5388  VectorType maskVType = getMaskVectorType();
5389  VectorType valueVType = getVectorType();
5390  MemRefType memType = getMemRefType();
5391 
5392  if (valueVType.getElementType() != memType.getElementType())
5393  return emitOpError("base and valueToStore element type should match");
5394  if (llvm::size(getIndices()) != memType.getRank())
5395  return emitOpError("requires ") << memType.getRank() << " indices";
5396  if (valueVType.getShape() != indVType.getShape())
5397  return emitOpError("expected valueToStore dim to match indices dim");
5398  if (valueVType.getShape() != maskVType.getShape())
5399  return emitOpError("expected valueToStore dim to match mask dim");
5400  return success();
5401 }
5402 
5403 namespace {
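/// Folds vector.scatter with a statically known mask: an all-false mask makes
/// the scatter a no-op that is erased. An all-true mask has no unmasked
/// scatter equivalent, so that case is left untouched.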
5404 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5405 public:
5407  LogicalResult matchAndRewrite(ScatterOp scatter,
5408  PatternRewriter &rewriter) const override {
5409  switch (getMaskFormat(scatter.getMask())) {
5410  case MaskFormat::AllTrue:
5411  return failure(); // no unmasked equivalent
5412  case MaskFormat::AllFalse:
5413  rewriter.eraseOp(scatter);
5414  return success();
5415  case MaskFormat::Unknown:
5416  return failure();
5417  }
5418  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5419  }
5420 };
5421 
5422 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5423 /// maskedstore. Only 1D fixed vectors are supported for now.
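/// For example (illustrative): a scatter whose index vector is produced by
/// vector.step (offsets [0, 1, 2, ...]) writes a contiguous chunk starting at
/// the base indices, which is exactly what vector.maskedstore does, so the
/// scatter is rewritten into a maskedstore with the same mask and value.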
5424 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5425 public:
5427  LogicalResult matchAndRewrite(ScatterOp op,
5428  PatternRewriter &rewriter) const override {
5429  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5430  return failure();
5431 
5432  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5433  op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5434  return success();
5435  }
5436 };
5437 } // namespace
5438 
5439 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5440  MLIRContext *context) {
5441  results.add<ScatterFolder, FoldContiguousScatter>(context);
5442 }
5443 
5444 //===----------------------------------------------------------------------===//
5445 // ExpandLoadOp
5446 //===----------------------------------------------------------------------===//
5447 
5448 LogicalResult ExpandLoadOp::verify() {
5449  VectorType maskVType = getMaskVectorType();
5450  VectorType passVType = getPassThruVectorType();
5451  VectorType resVType = getVectorType();
5452  MemRefType memType = getMemRefType();
5453 
5454  if (resVType.getElementType() != memType.getElementType())
5455  return emitOpError("base and result element type should match");
5456  if (llvm::size(getIndices()) != memType.getRank())
5457  return emitOpError("requires ") << memType.getRank() << " indices";
5458  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5459  return emitOpError("expected result dim to match mask dim");
5460  if (resVType != passVType)
5461  return emitOpError("expected pass_thru of same type as result type");
5462  return success();
5463 }
5464 
5465 namespace {
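/// Folds vector.expandload with a statically known mask: an all-true mask
/// expands nothing and becomes a plain vector.load, and an all-false mask
/// replaces the op with its pass_thru operand.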
5466 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5467 public:
5469  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5470  PatternRewriter &rewriter) const override {
5471  switch (getMaskFormat(expand.getMask())) {
5472  case MaskFormat::AllTrue:
5473  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5474  expand, expand.getType(), expand.getBase(), expand.getIndices());
5475  return success();
5476  case MaskFormat::AllFalse:
5477  rewriter.replaceOp(expand, expand.getPassThru());
5478  return success();
5479  case MaskFormat::Unknown:
5480  return failure();
5481  }
5482  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5483  }
5484 };
5485 } // namespace
5486 
5487 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5488  MLIRContext *context) {
5489  results.add<ExpandLoadFolder>(context);
5490 }
5491 
5492 //===----------------------------------------------------------------------===//
5493 // CompressStoreOp
5494 //===----------------------------------------------------------------------===//
5495 
5496 LogicalResult CompressStoreOp::verify() {
5497  VectorType maskVType = getMaskVectorType();
5498  VectorType valueVType = getVectorType();
5499  MemRefType memType = getMemRefType();
5500 
5501  if (valueVType.getElementType() != memType.getElementType())
5502  return emitOpError("base and valueToStore element type should match");
5503  if (llvm::size(getIndices()) != memType.getRank())
5504  return emitOpError("requires ") << memType.getRank() << " indices";
5505  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5506  return emitOpError("expected valueToStore dim to match mask dim");
5507  return success();
5508 }
5509 
5510 namespace {
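/// Folds vector.compressstore with a statically known mask: an all-true mask
/// compresses nothing and becomes a plain vector.store, and an all-false mask
/// stores nothing, so the op is erased.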
5511 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5512 public:
5514  LogicalResult matchAndRewrite(CompressStoreOp compress,
5515  PatternRewriter &rewriter) const override {
5516  switch (getMaskFormat(compress.getMask())) {
5517  case MaskFormat::AllTrue:
5518  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5519  compress, compress.getValueToStore(), compress.getBase(),
5520  compress.getIndices());
5521  return success();
5522  case MaskFormat::AllFalse:
5523  rewriter.eraseOp(compress);
5524  return success();
5525  case MaskFormat::Unknown:
5526  return failure();
5527  }
5528  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5529  }
5530 };
5531 } // namespace
5532 
5533 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5534  MLIRContext *context) {
5535  results.add<CompressStoreFolder>(context);
5536 }
5537 
5538 //===----------------------------------------------------------------------===//
5539 // ShapeCastOp
5540 //===----------------------------------------------------------------------===//
5541 
5542 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5543  SetIntRangeFn setResultRanges) {
5544  setResultRanges(getResult(), argRanges.front());
5545 }
5546 
5547 LogicalResult ShapeCastOp::verify() {
5548 
5549  VectorType sourceType = getSourceVectorType();
5550  VectorType resultType = getResultVectorType();
5551 
5552  // Check that element type is preserved
5553  if (sourceType.getElementType() != resultType.getElementType())
5554  return emitOpError("has different source and result element types");
5555 
5556  // Check that number of elements is preserved
5557  int64_t sourceNElms = sourceType.getNumElements();
5558  int64_t resultNElms = resultType.getNumElements();
5559  if (sourceNElms != resultNElms) {
5560  return emitOpError() << "has different number of elements at source ("
5561  << sourceNElms << ") and result (" << resultNElms
5562  << ")";
5563  }
5564 
5565  // Check that (non-)scalability is preserved
5566  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
5567  int64_t resultNScalableDims = resultType.getNumScalableDims();
5568  if (sourceNScalableDims != resultNScalableDims)
5569  return emitOpError() << "has different number of scalable dims at source ("
5570  << sourceNScalableDims << ") and result ("
5571  << resultNScalableDims << ")";
5572 
5573  return success();
5574 }
5575 
5576 /// Return true if `transpose` does not permute a pair of non-unit dims.
5577 /// By `order preserving` we mean that the flattened versions of the input and
5578 /// output vectors are (numerically) identical. In other words `transpose` is
5579 /// effectively a shape cast.
5580 static bool isOrderPreserving(TransposeOp transpose) {
5581  ArrayRef<int64_t> permutation = transpose.getPermutation();
5582  VectorType sourceType = transpose.getSourceVectorType();
5583  ArrayRef<int64_t> inShape = sourceType.getShape();
5584  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5585  auto isNonScalableUnitDim = [&](int64_t dim) {
5586  return inShape[dim] == 1 && !inDimIsScalable[dim];
5587  };
5588  int64_t current = 0;
5589  for (auto p : permutation) {
5590  if (!isNonScalableUnitDim(p)) {
5591  if (p < current) {
5592  return false;
5593  }
5594  current = p;
5595  }
5596  }
5597  return true;
5598 }
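// For example (illustrative values): on vector<4x1x8xf32>, permutation
// [1, 0, 2] only moves the unit dim and is order preserving, whereas
// [2, 0, 1] swaps the two non-unit dims (4 and 8) and is not.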
5599 
5600 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5601 
5602  VectorType resultType = getType();
5603 
5604  // No-op shape cast.
5605  if (getSource().getType() == resultType)
5606  return getSource();
5607 
5608  // shape_cast(shape_cast(x)) -> shape_cast(x)
5609  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
5610  setOperand(precedingShapeCast.getSource());
5611  return getResult();
5612  }
5613 
5614  // shape_cast(transpose(x)) -> shape_cast(x)
5615  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5616  // This folder does
5617  // shape_cast(transpose) -> shape_cast
5618  // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5619  // shape_cast -> shape_cast(transpose)
5620  // i.e. the complete opposite. When paired, these 2 patterns can cause
5621  // infinite cycles in pattern rewriting.
5622  // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5623  // vectors, so by disabling this folder for scalable vectors the
5624  // cycle is avoided.
5625  // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5626  // still needed. If it's not, then we can fold here.
5627  if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
5628  setOperand(transpose.getVector());
5629  return getResult();
5630  }
5631  return {};
5632  }
5633 
5634  // Y = shape_cast(broadcast(X))
5635  // -> X, if X and Y have same type
5636  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5637  if (bcastOp.getSourceType() == resultType)
5638  return bcastOp.getSource();
5639  }
5640 
5641  // shape_cast(constant) -> constant
5642  if (auto splatAttr =
5643  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5644  return splatAttr.reshape(getType());
5645 
5646  // shape_cast(poison) -> poison
5647  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
5648  return ub::PoisonAttr::get(getContext());
5649  }
5650 
5651  return {};
5652 }
5653 
5654 namespace {
5655 
5656 /// Helper function that computes a new vector type based on the input vector
5657 /// type by removing the trailing one dims:
5658 ///
5659 /// vector<4x1x1xi1> --> vector<4x1xi1>
5660 ///
5661 static VectorType trimTrailingOneDims(VectorType oldType) {
5662  ArrayRef<int64_t> oldShape = oldType.getShape();
5663  ArrayRef<int64_t> newShape = oldShape;
5664 
5665  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5666  ArrayRef<bool> newScalableDims = oldScalableDims;
5667 
5668  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5669  newShape = newShape.drop_back(1);
5670  newScalableDims = newScalableDims.drop_back(1);
5671  }
5672 
5673  // Make sure we have at least 1 dimension.
5674  // TODO: Add support for 0-D vectors.
5675  if (newShape.empty()) {
5676  newShape = oldShape.take_back();
5677  newScalableDims = oldScalableDims.take_back();
5678  }
5679 
5680  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5681 }
5682 
5683 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5684 ///
5685 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5686 /// dimension. If the input vector comes from `vector.create_mask` for which
5687 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5688 /// to fold shape_cast into create_mask.
5689 ///
5690 /// BEFORE:
5691 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5692 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5693 /// AFTER:
5694 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5695 class ShapeCastCreateMaskFolderTrailingOneDim final
5696  : public OpRewritePattern<ShapeCastOp> {
5697 public:
5698  using OpRewritePattern::OpRewritePattern;
5699 
5700  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5701  PatternRewriter &rewriter) const override {
5702  Value shapeOpSrc = shapeOp->getOperand(0);
5703  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5704  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5705  if (!createMaskOp && !constantMaskOp)
5706  return failure();
5707 
5708  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5709  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5710 
5711  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5712  if (newVecType != shapeOpResTy)
5713  return failure();
5714 
5715  auto numDimsToDrop =
5716  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5717 
5718  // No unit dims to drop
5719  if (!numDimsToDrop)
5720  return failure();
5721 
5722  if (createMaskOp) {
5723  auto maskOperands = createMaskOp.getOperands();
5724  auto numMaskOperands = maskOperands.size();
5725 
5726  // Check every mask dim size to see whether it can be dropped
5727  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5728  --i) {
5729  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5730  if (!constant || (constant.value() != 1))
5731  return failure();
5732  }
5733  SmallVector<Value> newMaskOperands =
5734  maskOperands.drop_back(numDimsToDrop);
5735 
5736  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5737  newMaskOperands);
5738  return success();
5739  }
5740 
5741  if (constantMaskOp) {
5742  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5743  auto numMaskOperands = maskDimSizes.size();
5744 
5745  // Check every mask dim size to see whether it can be dropped
5746  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5747  --i) {
5748  if (maskDimSizes[i] != 1)
5749  return failure();
5750  }
5751 
5752  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5753  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5754  newMaskOperands);
5755  return success();
5756  }
5757 
5758  return failure();
5759  }
5760 };
5761 
5762 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5763 /// i) Y = ShapeCast(X), or
5764 /// ii) Y = Broadcast(X)
5765 /// If both (i) and (ii) are possible, (i) is chosen.
5766 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5767 public:
5768  using OpRewritePattern::OpRewritePattern;
5769 
5770  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5771  PatternRewriter &rewriter) const override {
5772  auto broadcastOp =
5773  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5774  if (!broadcastOp)
5775  return failure();
5776 
5777  auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
5778  bool srcIsScalar = !srcVectorType;
5779 
5780  // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5781  // Example:
5782  // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5783  // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5784  // to
5785  // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5786  if (srcVectorType) {
5787  if (srcVectorType.getNumElements() ==
5788  shapeCastOp.getResultVectorType().getNumElements()) {
5789  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5790  shapeCastOp, shapeCastOp.getResultVectorType(),
5791  broadcastOp.getSource());
5792  return success();
5793  }
5794  }
5795 
5796  // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
5797  // Example
5798  // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
5799  // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
5800  // to
5801  // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
5802  VectorType dstVectorType = shapeCastOp.getResultVectorType();
5803  if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
5804  BroadcastableToResult::Success) {
5805  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5806  shapeCastOp, dstVectorType, broadcastOp.getSource());
5807  return success();
5808  }
5809  return failure();
5810  }
5811 };
5812 
5813 } // namespace
5814 
5815 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5816  MLIRContext *context) {
5817  results
5818  .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5819  context);
5820 }
5821 
5822 //===----------------------------------------------------------------------===//
5823 // VectorBitCastOp
5824 //===----------------------------------------------------------------------===//
5825 
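// The leading dimensions must match and the total bit width of the minor
// dimension must be preserved, e.g. vector<4x8xf16> can be bitcast to
// vector<4x4xf32> since 8 * 16 == 4 * 32 bits.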
5826 LogicalResult BitCastOp::verify() {
5827  auto sourceVectorType = getSourceVectorType();
5828  auto resultVectorType = getResultVectorType();
5829 
5830  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5831  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5832  return emitOpError("dimension size mismatch at: ") << i;
5833  }
5834 
5835  DataLayout dataLayout = DataLayout::closest(*this);
5836  auto sourceElementBits =
5837  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5838  auto resultElementBits =
5839  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5840 
5841  if (sourceVectorType.getRank() == 0) {
5842  if (sourceElementBits != resultElementBits)
5843  return emitOpError("source/result bitwidth of the 0-D vector element "
5844  "types must be equal");
5845  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5846  resultElementBits * resultVectorType.getShape().back()) {
5847  return emitOpError(
5848  "source/result bitwidth of the minor 1-D vectors must be equal");
5849  }
5850 
5851  return success();
5852 }
5853 
5854 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5856  // No-op cast.
5856  if (getSource().getType() == getResult().getType())
5857  return getSource();
5858 
5859  // Canceling bitcasts.
5860  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5861  if (getResult().getType() == otherOp.getSource().getType())
5862  return otherOp.getSource();
5863 
5864  setOperand(otherOp.getSource());
5865  return getResult();
5866  }
5867 
5868  Attribute sourceConstant = adaptor.getSource();
5869  if (!sourceConstant)
5870  return {};
5871 
5872  Type srcElemType = getSourceVectorType().getElementType();
5873  Type dstElemType = getResultVectorType().getElementType();
5874 
5875  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5876  if (floatPack.isSplat()) {
5877  auto splat = floatPack.getSplatValue<FloatAttr>();
5878 
5879  // Casting fp16 into fp32.
5880  if (srcElemType.isF16() && dstElemType.isF32()) {
5881  uint32_t bits = static_cast<uint32_t>(
5882  splat.getValue().bitcastToAPInt().getZExtValue());
5883  // Duplicate the 16-bit pattern.
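  // E.g. an f16 splat of 1.0 (bits 0x3C00) yields the f32 bit pattern
  // 0x3C003C00, i.e. two copies of the 16-bit value.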
5884  bits = (bits << 16) | (bits & 0xffff);
5885  APInt intBits(32, bits);
5886  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5887  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5888  }
5889  }
5890  }
5891 
5892  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5893  if (intPack.isSplat()) {
5894  auto splat = intPack.getSplatValue<IntegerAttr>();
5895 
5896  if (llvm::isa<IntegerType>(dstElemType)) {
5897  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5898  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5899 
5900  // Casting to a larger integer bit width.
5901  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5902  APInt intBits = splat.getValue().zext(dstBitWidth);
5903 
5904  // Duplicate the lower width element.
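  // E.g. an i8 splat of 0x01 bitcast to i32 yields the splat 0x01010101
  // (four copies of the 8-bit pattern).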
5905  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5906  intBits = (intBits << srcBitWidth) | intBits;
5907  return DenseElementsAttr::get(getResultVectorType(), intBits);
5908  }
5909  }
5910  }
5911  }
5912 
5913  return {};
5914 }
5915 
5916 //===----------------------------------------------------------------------===//
5917 // TypeCastOp
5918 //===----------------------------------------------------------------------===//
5919 
5920 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5921  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5922  SmallVector<int64_t, 8> res(memRefType.getShape());
5923  if (vectorType)
5924  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5925  return res;
5926 }
5927 
5928 /// Build the canonical memRefType with a single vector.
5929 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5930 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5931  Value source) {
5932  result.addOperands(source);
5933  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5934  VectorType vectorType =
5935  VectorType::get(extractShape(memRefType),
5936  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5937  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5938  memRefType.getMemorySpace()));
5939 }
5940 
5941 LogicalResult TypeCastOp::verify() {
5942  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
5943  if (!canonicalType.getLayout().isIdentity())
5944  return emitOpError("expects operand to be a memref with identity layout");
5945  if (!getResultMemRefType().getLayout().isIdentity())
5946  return emitOpError("expects result to be a memref with identity layout");
5947  if (getResultMemRefType().getMemorySpace() !=
5948  getMemRefType().getMemorySpace())
5949  return emitOpError("expects result in same memory space");
5950 
5951  auto sourceType = getMemRefType();
5952  auto resultType = getResultMemRefType();
5953  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5954  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5955  return emitOpError(
5956  "expects result and operand with same underlying scalar type: ")
5957  << resultType;
5958  if (extractShape(sourceType) != extractShape(resultType))
5959  return emitOpError(
5960  "expects concatenated result and operand shapes to be equal: ")
5961  << resultType;
5962  return success();
5963 }
5964 
5965 //===----------------------------------------------------------------------===//
5966 // TransposeOp
5967 //===----------------------------------------------------------------------===//
5968 
5969 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5970  Value vector, ArrayRef<int64_t> permutation) {
5971  VectorType vt = llvm::cast<VectorType>(vector.getType());
5972  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5973  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5974  for (unsigned i = 0; i < permutation.size(); ++i) {
5975  transposedShape[i] = vt.getShape()[permutation[i]];
5976  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5977  }
5978 
5979  result.addOperands(vector);
5980  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5981  transposedScalableDims));
5982  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5983  builder.getDenseI64ArrayAttr(permutation));
5984 }
5985 
5986 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5987  // Eliminate splat constant transpose ops.
5988  if (auto splat =
5989  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
5990  return splat.reshape(getResultVectorType());
5991 
5992  // Eliminate poison transpose ops.
5993  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
5994  return ub::PoisonAttr::get(getContext());
5995 
5996  // Eliminate identity transposes, and more generally any transposes that
5997  // preserve the shape without permuting elements.
5998  //
5999  // Examples of what to fold:
6000  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6001  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6002  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6003  //
6004  // Example of what NOT to fold:
6005  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6006  //
6007  if (getSourceVectorType() == getResultVectorType() &&
6008  isOrderPreserving(*this))
6009  return getVector();
6010 
6011  return {};
6012 }
6013 
6014 LogicalResult vector::TransposeOp::verify() {
6015  VectorType vectorType = getSourceVectorType();
6016  VectorType resultType = getResultVectorType();
6017  int64_t rank = resultType.getRank();
6018  if (vectorType.getRank() != rank)
6019  return emitOpError("vector result rank mismatch: ") << rank;
6020  // Verify transposition array.
6021  ArrayRef<int64_t> perm = getPermutation();
6022  int64_t size = perm.size();
6023  if (rank != size)
6024  return emitOpError("transposition length mismatch: ") << size;
6025  SmallVector<bool, 8> seen(rank, false);
6026  for (const auto &ta : llvm::enumerate(perm)) {
6027  if (ta.value() < 0 || ta.value() >= rank)
6028  return emitOpError("transposition index out of range: ") << ta.value();
6029  if (seen[ta.value()])
6030  return emitOpError("duplicate position index: ") << ta.value();
6031  seen[ta.value()] = true;
6032  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6033  return emitOpError("dimension size mismatch at: ") << ta.value();
6034  }
6035  return success();
6036 }
6037 
6038 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6039  return llvm::to_vector<4>(getResultVectorType().getShape());
6040 }
6041 
6042 namespace {
6043 
6044 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
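// E.g.
//   %0 = vector.transpose %arg, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %1 = vector.transpose %0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
// is rewritten to
//   %1 = vector.transpose %arg, [0, 1] : vector<2x3xf32> to vector<2x3xf32>
// which the TransposeOp folder can then remove as an identity transpose.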
6045 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6046 public:
6047  using OpRewritePattern::OpRewritePattern;
6048 
6049  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6050  PatternRewriter &rewriter) const override {
6051  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6052  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6053  ArrayRef<int64_t> permutation2) {
6054  SmallVector<int64_t, 4> result;
6055  for (auto index : permutation2)
6056  result.push_back(permutation1[index]);
6057  return result;
6058  };
6059 
6060  // Return if the input of 'transposeOp' is not defined by another transpose.
6061  vector::TransposeOp parentTransposeOp =
6062  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6063  if (!parentTransposeOp)
6064  return failure();
6065 
6066  SmallVector<int64_t, 4> permutation = composePermutations(
6067  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6068  // Replace 'transposeOp' with a new transpose operation.
6069  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6070  transposeOp, transposeOp.getResult().getType(),
6071  parentTransposeOp.getVector(), permutation);
6072  return success();
6073  }
6074 };
6075 
6076 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6077 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6078 public:
6079  using OpRewritePattern::OpRewritePattern;
6080 
6081  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6082  PatternRewriter &rewriter) const override {
6083  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6084  if (!splatOp)
6085  return failure();
6086 
6087  rewriter.replaceOpWithNewOp<vector::SplatOp>(
6088  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6089  return success();
6090  }
6091 };
6092 
6093 /// Folds transpose(create_mask) into a new transposed create_mask.
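/// E.g.
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes
///   %t = vector.create_mask %b, %a : vector<8x4xi1>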
6094 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6095 public:
6096  using OpRewritePattern::OpRewritePattern;
6097 
6098  LogicalResult matchAndRewrite(TransposeOp transpOp,
6099  PatternRewriter &rewriter) const override {
6100  Value transposeSrc = transpOp.getVector();
6101  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6102  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6103  if (!createMaskOp && !constantMaskOp)
6104  return failure();
6105 
6106  // Get the transpose permutation and apply it to the vector.create_mask or
6107  // vector.constant_mask operands.
6108  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6109 
6110  if (createMaskOp) {
6111  auto maskOperands = createMaskOp.getOperands();
6112  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6113  applyPermutationToVector(newOperands, permutation);
6114 
6115  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6116  transpOp, transpOp.getResultVectorType(), newOperands);
6117  return success();
6118  }
6119 
6120  // ConstantMaskOp case.
6121  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6122  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6123 
6124  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6125  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6126  return success();
6127  }
6128 };
6129 
6130 /// Folds transpose(shape_cast) into a new shape_cast.
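/// E.g.
///   %0 = vector.shape_cast %arg : vector<4xf32> to vector<1x4xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
/// becomes
///   %1 = vector.shape_cast %arg : vector<4xf32> to vector<4x1xf32>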
6131 class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6132 public:
6133  using OpRewritePattern::OpRewritePattern;
6134 
6135  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6136  PatternRewriter &rewriter) const override {
6137  auto shapeCastOp =
6138  transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6139  if (!shapeCastOp)
6140  return failure();
6141  if (!isOrderPreserving(transposeOp))
6142  return failure();
6143 
6144  VectorType resultType = transposeOp.getType();
6145 
6146  // We don't need to check isValidShapeCast at this point: merging the
6147  // transpose into the shape_cast is guaranteed to yield a valid shape_cast,
6148  // because the transpose only inserts/removes unit dims.
6149 
6150  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6151  shapeCastOp.getSource());
6152  return success();
6153  }
6154 };
6155 
6156 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6157 /// 'order preserving', where 'order preserving' means the flattened
6158 /// inputs and outputs of the transpose have identical (numerical) values.
6159 ///
6160 /// Example:
6161 /// ```
6162 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6163 /// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6164 /// to vector<8x1xi32>
6165 /// ```
6166 /// can be rewritten as the equivalent
6167 /// ```
6168 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6169 /// ```
6170 /// The algorithm works by partitioning dimensions into groups that can be
6171 /// locally permuted while preserving order, and checks that the transpose
6172 /// only permutes within these groups.
6173 ///
6174 /// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6175 /// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6176 /// broadcasting from 1x1x4x1x1x7.
6177 /// ^^^ ^ ^^^ ^
6178 /// groups: 0 1 2 3
6179 /// Order preserving permutations for this example are ones that only permute
6180 /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6181 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6182 public:
6183  using OpRewritePattern::OpRewritePattern;
6184  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6185  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6186 
6187  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6188  PatternRewriter &rewriter) const override {
6189 
6190  vector::BroadcastOp broadcast =
6191  transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6192  if (!broadcast) {
6193  return rewriter.notifyMatchFailure(transpose,
6194  "not preceded by a broadcast");
6195  }
6196 
6197  auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6198  VectorType outputType = transpose.getResultVectorType();
6199 
6200  // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6201  bool inputIsScalar = !inputType;
6202  if (inputIsScalar) {
6203  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6204  broadcast.getSource());
6205  return success();
6206  }
6207 
6208  ArrayRef<int64_t> permutation = transpose.getPermutation();
6209  ArrayRef<int64_t> inputShape = inputType.getShape();
6210  int64_t inputRank = inputType.getRank();
6211  int64_t outputRank = transpose.getType().getRank();
6212  int64_t deltaRank = outputRank - inputRank;
6213 
6214  int low = 0;
6215  for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6216  bool notOne = inputShape[inputIndex] != 1;
6217  bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6218  bool groupEndFound = notOne || prevNotOne;
6219  if (groupEndFound) {
6220  int high = inputIndex + deltaRank;
6221  // Return failure if not all permutation destinations for indices in
6222  // [low, high) are in [low, high), i.e. the permutation is not local to
6223  // the group.
6224  for (int i = low; i < high; ++i) {
6225  if (permutation[i] < low || permutation[i] >= high) {
6226  return rewriter.notifyMatchFailure(
6227  transpose, "permutation not local to group");
6228  }
6229  }
6230  low = high;
6231  }
6232  }
6233 
6234  // We don't need to check the final group [low, outputRank) because if it is
6235  // not locally bound, there must be a preceding group that already failed
6236  // the check (impossible to have just 1 non-locally bound group).
6237 
6238  // The preceding logic also ensures that at this point, the output of the
6239  // transpose is definitely broadcastable from the input shape, assert so:
6240  assert(vector::isBroadcastableTo(inputType, outputType) ==
6241  vector::BroadcastableToResult::Success &&
6242  "not broadcastable directly to transpose output");
6243 
6244  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6245  broadcast.getSource());
6246 
6247  return success();
6248  }
6249 };
6250 
6251 } // namespace
6252 
6253 void vector::TransposeOp::getCanonicalizationPatterns(
6254  RewritePatternSet &results, MLIRContext *context) {
6255  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6256  FoldTransposeSplat, FoldTransposeBroadcast>(context);
6257 }
6258 
6259 //===----------------------------------------------------------------------===//
6260 // ConstantMaskOp
6261 //===----------------------------------------------------------------------===//
6262 
6263 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6264  VectorType type, ConstantMaskKind kind) {
6265  assert(kind == ConstantMaskKind::AllTrue ||
6266  kind == ConstantMaskKind::AllFalse);
6267  build(builder, result, type,
6268  kind == ConstantMaskKind::AllTrue
6269  ? type.getShape()
6270  : SmallVector<int64_t>(type.getRank(), 0));
6271 }
6272 
6273 LogicalResult ConstantMaskOp::verify() {
6274  auto resultType = llvm::cast<VectorType>(getResult().getType());
6275  // Check the corner case of 0-D vectors first.
6276  if (resultType.getRank() == 0) {
6277  if (getMaskDimSizes().size() != 1)
6278  return emitError("array attr must have length 1 for 0-D vectors");
6279  auto dim = getMaskDimSizes()[0];
6280  if (dim != 0 && dim != 1)
6281  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6282  return success();
6283  }
6284 
6285  // Verify that array attr size matches the rank of the vector result.
6286  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6287  return emitOpError(
6288  "must specify array attr of size equal vector result rank");
6289  // Verify that each array attr element is in bounds of the corresponding vector
6290  // result dimension size.
6291  auto resultShape = resultType.getShape();
6292  auto resultScalableDims = resultType.getScalableDims();
6293  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6294  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6295  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6296  return emitOpError(
6297  "array attr of size out of bounds of vector result dimension size");
6298  if (resultScalableDims[index] && maskDimSize != 0 &&
6299  maskDimSize != resultShape[index])
6300  return emitOpError(
6301  "only supports 'none set' or 'all set' scalable dimensions");
6302  }
6303  // Verify that if one mask dim size is zero, they all should be zero (because
6304  // the mask region is a conjunction of each mask dimension interval).
6305  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6306  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6307  if (anyZeros && !allZeros)
6308  return emitOpError("expected all mask dim sizes to be zeros, "
6309  "as a result of conjunction with zero mask dim");
6310  return success();
6311 }
6312 
6313 bool ConstantMaskOp::isAllOnesMask() {
6314  auto resultType = getVectorType();
6315  // Check the corner case of 0-D vectors first.
6316  if (resultType.getRank() == 0) {
6317  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6318  return getMaskDimSizes()[0] == 1;
6319  }
6320  for (const auto [resultSize, maskDimSize] :
6321  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6322  if (maskDimSize < resultSize)
6323  return false;
6324  }
6325  return true;
6326 }
6327 
6328 //===----------------------------------------------------------------------===//
6329 // CreateMaskOp
6330 //===----------------------------------------------------------------------===//
6331 
6332 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6333  VectorType type,
6334  ArrayRef<OpFoldResult> mixedOperands) {
6335  SmallVector<Value> operands =
6336  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6337  build(builder, result, type, operands);
6338 }
6339 
6340 LogicalResult CreateMaskOp::verify() {
6341  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6342  // Verify that an operand was specified for each result vector dimension.
6343  if (vectorType.getRank() == 0) {
6344  if (getNumOperands() != 1)
6345  return emitOpError(
6346  "must specify exactly one operand for 0-D create_mask");
6347  } else if (getNumOperands() !=
6348  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6349  return emitOpError(
6350  "must specify an operand for each result vector dimension");
6351  }
6352  return success();
6353 }
6354 
6355 namespace {
6356 
6357 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6358 ///
6359 /// Ex 1:
6360 /// %c2 = arith.constant 2 : index
6361 /// %c3 = arith.constant 3 : index
6362 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6363 /// Becomes:
6364 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6365 ///
6366 /// Ex 2:
6367 /// %c_neg_1 = arith.constant -1 : index
6368 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6369 /// becomes:
6370 /// vector.constant_mask [0] : vector<[8]xi1>
6371 ///
6372 /// Ex 3:
6373 /// %c8 = arith.constant 8 : index
6374 /// %c16 = arith.constant 16 : index
6375 /// %0 = vector.vscale
6376 /// %1 = arith.muli %0, %c16 : index
6377 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6378 /// becomes:
6379 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6380 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6381 public:
6382  using OpRewritePattern::OpRewritePattern;
6383 
6384  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6385  PatternRewriter &rewriter) const override {
6386  VectorType maskType = createMaskOp.getVectorType();
6387  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6388  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6389 
6390  // Special case: Rank zero shape.
6391  constexpr std::array<int64_t, 1> rankZeroShape{1};
6392  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6393  if (maskType.getRank() == 0) {
6394  maskTypeDimSizes = rankZeroShape;
6395  maskTypeDimScalableFlags = rankZeroScalableDims;
6396  }
6397 
6398  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6399  // collect the `constantDims` (for the ConstantMaskOp).
6400  SmallVector<int64_t, 4> constantDims;
6401  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6402  if (auto intSize = getConstantIntValue(dimSize)) {
6403  // Constant value.
6404  // If the mask dim is non-scalable this can be any value.
6405  // If the mask dim is scalable only zero (all-false) is supported.
6406  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6407  return failure();
6408  constantDims.push_back(*intSize);
6409  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6410  // Constant vscale multiple (e.g. 4 x vscale).
6411  // Must be all-true to fold to a ConstantMask.
6412  if (vscaleMultiplier < maskTypeDimSizes[i])
6413  return failure();
6414  constantDims.push_back(*vscaleMultiplier);
6415  } else {
6416  return failure();
6417  }
6418  }
6419 
6420  // Clamp values to constant_mask bounds.
6421  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6422  value = std::clamp<int64_t>(value, 0, maskDimSize);
6423 
6424  // If one of the dim sizes is zero, set all dims to zero.
6425  if (llvm::is_contained(constantDims, 0))
6426  constantDims.assign(constantDims.size(), 0);
6427 
6428  // Replace 'createMaskOp' with ConstantMaskOp.
6429  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6430  constantDims);
6431  return success();
6432  }
6433 };
6434 
6435 } // namespace
6436 
6437 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6438  MLIRContext *context) {
6439  results.add<CreateMaskFolder>(context);
6440 }
6441 
6442 //===----------------------------------------------------------------------===//
6443 // MaskOp
6444 //===----------------------------------------------------------------------===//
6445 
6446 void MaskOp::build(
6447  OpBuilder &builder, OperationState &result, Value mask,
6448  Operation *maskableOp,
6449  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6450  assert(maskRegionBuilder &&
6451  "builder callback for 'maskRegion' must be present");
6452 
6453  result.addOperands(mask);
6454  OpBuilder::InsertionGuard guard(builder);
6455  Region *maskRegion = result.addRegion();
6456  builder.createBlock(maskRegion);
6457  maskRegionBuilder(builder, maskableOp);
6458 }
6459 
6460 void MaskOp::build(
6461  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6462  Value mask, Operation *maskableOp,
6463  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6464  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6465  maskRegionBuilder);
6466 }
6467 
6468 void MaskOp::build(
6469  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6470  Value mask, Value passthru, Operation *maskableOp,
6471  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6472  build(builder, result, mask, maskableOp, maskRegionBuilder);
6473  if (passthru)
6474  result.addOperands(passthru);
6475  result.addTypes(resultTypes);
6476 }
6477 
6478 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6479  // Create the op region.
6480  result.regions.reserve(1);
6481  Region &maskRegion = *result.addRegion();
6482 
6483  auto &builder = parser.getBuilder();
6484 
6485  // Parse all the operands.
6486  OpAsmParser::UnresolvedOperand mask;
6487  if (parser.parseOperand(mask))
6488  return failure();
6489 
6490  // Optional passthru operand.
6491  OpAsmParser::UnresolvedOperand passthru;
6492  ParseResult parsePassthru = parser.parseOptionalComma();
6493  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6494  return failure();
6495 
6496  // Parse op region.
6497  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6498  return failure();
6499 
6500  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6501 
6502  // Parse the optional attribute list.
6503  if (parser.parseOptionalAttrDict(result.attributes))
6504  return failure();
6505 
6506  // Parse all the types.
6507  Type maskType;
6508  if (parser.parseColonType(maskType))
6509  return failure();
6510 
6511  SmallVector<Type> resultTypes;
6512  if (parser.parseOptionalArrowTypeList(resultTypes))
6513  return failure();
6514  result.types.append(resultTypes);
6515 
6516  // Resolve operands.
6517  if (parser.resolveOperand(mask, maskType, result.operands))
6518  return failure();
6519 
6520  if (parsePassthru.succeeded()) {
6521  if (resultTypes.empty())
6522  return parser.emitError(
6523  parser.getNameLoc(),
6524  "expects a result if passthru operand is provided");
6525 
6526  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6527  return failure();
6528  }
6529 
6530  return success();
6531 }
6532 
6533 void MaskOp::print(OpAsmPrinter &p) {
6534  p << " " << getMask();
6535  if (getPassthru())
6536  p << ", " << getPassthru();
6537 
6538  // Print single masked operation and skip terminator.
6539  p << " { ";
6540  Block *singleBlock = &getMaskRegion().getBlocks().front();
6541  if (singleBlock && !singleBlock->getOperations().empty())
6542  p.printCustomOrGenericOp(&singleBlock->front());
6543  p << " }";
6544 
6545  p.printOptionalAttrDict(getOperation()->getAttrs());
6546 
6547  p << " : " << getMask().getType();
6548  if (getNumResults() > 0)
6549  p << " -> " << getResultTypes();
6550 }
6551 
6552 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6553  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6554  MaskOp>::ensureTerminator(region, builder, loc);
6555  // Keep the default yield terminator if the number of masked operations is not
6556  // the expected one. This case will trigger a verification failure.
6557  Block &block = region.front();
6558  if (block.getOperations().size() != 2)
6559  return;
6560 
6561  // Replace default yield terminator with a new one that returns the results
6562  // from the masked operation.
6563  OpBuilder opBuilder(builder.getContext());
6564  Operation *maskedOp = &block.front();
6565  Operation *oldYieldOp = &block.back();
6566  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6567 
6568  // Empty vector.mask op.
6569  if (maskedOp == oldYieldOp)
6570  return;
6571 
6572  opBuilder.setInsertionPoint(oldYieldOp);
6573  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6574  oldYieldOp->dropAllReferences();
6575  oldYieldOp->erase();
6576 }
6577 
6578 LogicalResult MaskOp::verify() {
6579  // Structural checks.
6580  Block &block = getMaskRegion().getBlocks().front();
6581  if (block.getOperations().empty())
6582  return emitOpError("expects a terminator within the mask region");
6583 
6584  unsigned numMaskRegionOps = block.getOperations().size();
6585  if (numMaskRegionOps > 2)
6586  return emitOpError("expects only one operation to mask");
6587 
6588  // Terminator checks.
6589  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6590  if (!terminator)
6591  return emitOpError("expects a terminator within the mask region");
6592 
6593  if (terminator->getNumOperands() != getNumResults())
6594  return emitOpError(
6595  "expects number of results to match mask region yielded values");
6596 
6597  // Empty vector.mask. Nothing else to check.
6598  if (numMaskRegionOps == 1)
6599  return success();
6600 
6601  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6602  if (!maskableOp)
6603  return emitOpError("expects a MaskableOpInterface within the mask region");
6604 
6605  // Result checks.
6606  if (maskableOp->getNumResults() != getNumResults())
6607  return emitOpError("expects number of results to match maskable operation "
6608  "number of results");
6609 
6610  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6611  return emitOpError(
6612  "expects result type to match maskable operation result type");
6613 
6614  if (llvm::count_if(maskableOp->getResultTypes(),
6615  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6616  return emitOpError("multiple vector results not supported");
6617 
6618  // Mask checks.
6619  Type expectedMaskType = maskableOp.getExpectedMaskType();
6620  if (getMask().getType() != expectedMaskType)
6621  return emitOpError("expects a ")
6622  << expectedMaskType << " mask for the maskable operation";
6623 
6624  // Passthru checks.
6625  Value passthru = getPassthru();
6626  if (passthru) {
6627  if (!maskableOp.supportsPassthru())
6628  return emitOpError(
6629  "doesn't expect a passthru argument for this maskable operation");
6630 
6631  if (maskableOp->getNumResults() != 1)
6632  return emitOpError("expects result when passthru argument is provided");
6633 
6634  if (passthru.getType() != maskableOp->getResultTypes()[0])
6635  return emitOpError("expects passthru type to match result type");
6636  }
6637 
6638  return success();
6639 }
6640 
6641 /// Folds vector.mask ops with an all-true mask.
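/// E.g. (schematically) an all-true mask applied to a reduction:
///   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
///        : vector<8xi1> -> f32
/// folds by hoisting the masked reduction before the vector.mask op and
/// replacing the mask results with the reduction results.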
6642 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6643  SmallVectorImpl<OpFoldResult> &results) {
6644  MaskFormat maskFormat = getMaskFormat(getMask());
6645  if (isEmpty())
6646  return failure();
6647 
6648  if (maskFormat != MaskFormat::AllTrue)
6649  return failure();
6650 
6651  // Move maskable operation outside of the `vector.mask` region.
6652  Operation *maskableOp = getMaskableOp();
6653  maskableOp->dropAllUses();
6654  maskableOp->moveBefore(getOperation());
6655 
6656  llvm::append_range(results, maskableOp->getResults());
6657  return success();
6658 }
6659 
6660 // Elides empty vector.mask operations with or without return values. Propagates
6661 // the values yielded by the vector.yield terminator, if any, or erases the op,
6662 // otherwise.
6663 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6664  using OpRewritePattern::OpRewritePattern;
6665 
6666  LogicalResult matchAndRewrite(MaskOp maskOp,
6667  PatternRewriter &rewriter) const override {
6668  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6669  if (maskingOp.getMaskableOp())
6670  return failure();
6671 
6672  if (!maskOp.isEmpty())
6673  return failure();
6674 
6675  Block *block = maskOp.getMaskBlock();
6676  auto terminator = cast<vector::YieldOp>(block->front());
6677  if (terminator.getNumOperands() == 0)
6678  rewriter.eraseOp(maskOp);
6679  else
6680  rewriter.replaceOp(maskOp, terminator.getOperands());
6681 
6682  return success();
6683  }
6684 };
6685 
6686 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6687  MLIRContext *context) {
6688  results.add<ElideEmptyMaskOp>(context);
6689 }
6690 
6691 // MaskingOpInterface definitions.
6692 
6693 /// Returns the operation masked by this 'vector.mask'.
6694 Operation *MaskOp::getMaskableOp() {
6695  Block *block = getMaskBlock();
6696  if (block->getOperations().size() < 2)
6697  return nullptr;
6698 
6699  return &block->front();
6700 }
6701 
6702 /// Returns true if 'vector.mask' has a passthru value.
6703 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6704 
6705 //===----------------------------------------------------------------------===//
6706 // ScanOp
6707 //===----------------------------------------------------------------------===//
6708 
6709 LogicalResult ScanOp::verify() {
6710  VectorType srcType = getSourceType();
6711  VectorType initialType = getInitialValueType();
6712  // Check reduction dimension < rank.
6713  int64_t srcRank = srcType.getRank();
6714  int64_t reductionDim = getReductionDim();
6715  if (reductionDim >= srcRank)
6716  return emitOpError("reduction dimension ")
6717  << reductionDim << " has to be less than " << srcRank;
6718 
6719  // Check that rank(initial_value) = rank(src) - 1.
6720  int64_t initialValueRank = initialType.getRank();
6721  if (initialValueRank != srcRank - 1)
6722  return emitOpError("initial value rank ")
6723  << initialValueRank << " has to be equal to " << srcRank - 1;
6724 
6725  // Check shapes of initial value and src.
6726  ArrayRef<int64_t> srcShape = srcType.getShape();
6727  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6728  SmallVector<int64_t> expectedShape;
6729  for (int i = 0; i < srcRank; i++) {
6730  if (i != reductionDim)
6731  expectedShape.push_back(srcShape[i]);
6732  }
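  // E.g. scanning dim 1 of vector<2x3x4xf32> requires an initial value of
  // type vector<2x4xf32> (the source shape with the reduction dim removed).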
6733  if (!llvm::equal(initialValueShapes, expectedShape)) {
6734  return emitOpError("incompatible input/initial value shapes");
6735  }
6736 
6737  // Verify supported reduction kind.
6738  Type eltType = getDestType().getElementType();
6739  if (!isSupportedCombiningKind(getKind(), eltType))
6740  return emitOpError("unsupported reduction type ")
6741  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6742  << "'";
6743 
6744  return success();
6745 }
6746 
6747 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6748  RewritePatternSet &patterns, PatternBenefit benefit) {
6749  patterns
6750  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6751  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6752  StridedSliceConstantMaskFolder, TransposeFolder>(
6753  patterns.getContext(), benefit);
6754 }
6755 
6756 //===----------------------------------------------------------------------===//
6757 // SplatOp
6758 //===----------------------------------------------------------------------===//
6759 
6760 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6761  auto constOperand = adaptor.getInput();
6762  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6763  return {};
6764 
6765  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6766  return SplatElementsAttr::get(getType(), {constOperand});
6767 }
6768 
6769 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6770  SetIntRangeFn setResultRanges) {
6771  setResultRanges(getResult(), argRanges.front());
6772 }
6773 
6774 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6775  CombiningKind kind, Value v1, Value acc,
6776  arith::FastMathFlagsAttr fastmath,
6777  Value mask) {
6778  Type t1 = getElementTypeOrSelf(v1.getType());
6779  Type tAcc = getElementTypeOrSelf(acc.getType());
6780  Value result;
6781 
6782  switch (kind) {
6783  case CombiningKind::ADD:
6784  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6785  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6786  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6787  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6788  else
6789  llvm_unreachable("invalid value types for ADD reduction");
6790  break;
6791  case CombiningKind::AND:
6792  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6793  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6794  break;
6795  case CombiningKind::MAXNUMF:
6796  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6797  "expected float values");
6798  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6799  break;
6800  case CombiningKind::MAXIMUMF:
6801  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6802  "expected float values");
6803  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6804  break;
6805  case CombiningKind::MINNUMF:
6806  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6807  "expected float values");
6808  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6809  break;
6810  case CombiningKind::MINIMUMF:
6811  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6812  "expected float values");
6813  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6814  break;
6815  case CombiningKind::MAXSI:
6816  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6817  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6818  break;
6819  case CombiningKind::MINSI:
6820  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6821  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6822  break;
6823  case CombiningKind::MAXUI:
6824  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6825  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6826  break;
6827  case CombiningKind::MINUI:
6828  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6829  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6830  break;
6831  case CombiningKind::MUL:
6832  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6833  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6834  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6835  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6836  else
6837  llvm_unreachable("invalid value types for MUL reduction");
6838  break;
6839  case CombiningKind::OR:
6840  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6841  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6842  break;
6843  case CombiningKind::XOR:
6844  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6845  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6846  break;
6847  };
6848 
6849  assert(result && "unknown CombiningKind");
6850  return selectPassthru(b, mask, result, acc);
6851 }
6852 
6853 //===----------------------------------------------------------------------===//
6854 // Vector Masking Utilities
6855 //===----------------------------------------------------------------------===//
6856 
6857 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6858 /// as masked operation.
6859 static void createMaskOpRegion(OpBuilder &builder,
6860  Operation *maskableOp) {
6861  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6862  Block *insBlock = builder.getInsertionBlock();
6863  // Create a block and move the op to that block.
6864  insBlock->getOperations().splice(
6865  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6866  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6867 }
6868 
6869 /// Creates a vector.mask operation around a maskable operation. Returns the
6870 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6871 /// the maskable operation itself.
6872 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6873  Operation *maskableOp, Value mask,
6874  Value passthru) {
6875  if (!mask)
6876  return maskableOp;
6877  if (passthru)
6878  return builder.create<MaskOp>(maskableOp->getLoc(),
6879  maskableOp->getResultTypes(), mask, passthru,
6880  maskableOp, createMaskOpRegion);
6881  return builder.create<MaskOp>(maskableOp->getLoc(),
6882  maskableOp->getResultTypes(), mask, maskableOp,
6883  createMaskOpRegion);
6884 }
6885 
6886 /// Creates a vector select operation that picks values from `newValue` or
6887 /// `passthru` for each result vector lane based on `mask`. This utility is used
6888 /// to propagate the pass-thru value of vector.mask or for cases where only the
6889 /// pass-thru value propagation is needed. VP intrinsics do not support
6890 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
6891 /// are usually able to match op + select patterns and fold them into native
6892 /// target instructions.
6893 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6894  Value newValue, Value passthru) {
6895  if (!mask)
6896  return newValue;
6897 
6898  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6899  mask, newValue, passthru);
6900 }
6901 
6902 //===----------------------------------------------------------------------===//
6903 // TableGen'd op method definitions
6904 //===----------------------------------------------------------------------===//
6905 
6906 #define GET_ATTRDEF_CLASSES
6907 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6908 
6909 #define GET_OP_CLASSES
6910 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static SmallVector< Value > computeStrides(Location loc, RewriterBase &rewriter, ValueRange dynamicBasis, ArrayRef< int64_t > staticBasis, bool knownNonNegative)
Given a basis (in static and dynamic components), return the sequence of suffix products of the basis...
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition: VectorOps.cpp:69
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
Definition: VectorOps.cpp:1117
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1421
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
Definition: VectorOps.cpp:1682
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:4196
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
Definition: VectorOps.cpp:1988
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1856
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1693
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
Definition: VectorOps.cpp:2082
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:2356
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:131
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
Definition: VectorOps.cpp:2072
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:59
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:3262
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
Definition: VectorOps.cpp:178
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
Definition: VectorOps.cpp:327
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:899
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
Definition: VectorOps.cpp:2405
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
Definition: VectorOps.cpp:2024
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:2725
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:3198
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:1109
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:3218
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
Definition: VectorOps.cpp:206
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:3241
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
Definition: VectorOps.cpp:3074
static LogicalResult foldTransferFullMask(TransferOp op)
Definition: VectorOps.cpp:4442
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1413
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
Definition: VectorOps.cpp:2379
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
Definition: VectorOps.cpp:1308
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
Definition: VectorOps.cpp:3758
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:4088
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:910
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:3183
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1791
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:4117
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:4370
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:3602
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
Definition: VectorOps.cpp:1759
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:4387
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1908
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
Definition: VectorOps.cpp:2090
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
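A minimal, illustrative sketch (not part of the file above; variable names are invented) of what getMinorIdentityMap produces: with 3 dims and 2 results the map keeps the two most minor dimensions.

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  // Keep the two most minor of three dimensions: (d0, d1, d2) -> (d1, d2).
  mlir::AffineMap map =
      mlir::AffineMap::getMinorIdentityMap(/*dims=*/3, /*results=*/2, &ctx);
  map.dump();
  return 0;
}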
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:216
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
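As an illustrative sketch of the getPermutationMap and compose entries above (variable names invented), composing a permutation map with itself applies the permutation twice:

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/ArrayRef.h"

int main() {
  mlir::MLIRContext ctx;
  // [1, 2, 0] encodes the permutation (d0, d1, d2) -> (d1, d2, d0).
  unsigned permutation[] = {1, 2, 0};
  mlir::AffineMap perm = mlir::AffineMap::getPermutationMap(
      llvm::ArrayRef<unsigned>(permutation), &ctx);
  // perm.compose(perm) applies the permutation twice:
  // (d0, d1, d2) -> (d2, d0, d1).
  mlir::AffineMap twice = perm.compose(perm);
  perm.dump();
  twice.dump();
  return 0;
}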
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:58
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation & back()
Definition: Block.h:152
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:53
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:266
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
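A minimal sketch of the splat queries listed above (only builtin types are used, so no extra dialects need to be loaded; names are invented):

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  auto vecTy = mlir::VectorType::get({4}, b.getF32Type());
  // All elements are equal, so the attribute is stored as a splat.
  float values[] = {1.0f, 1.0f, 1.0f, 1.0f};
  auto attr =
      mlir::DenseElementsAttr::get(vecTy, llvm::ArrayRef<float>(values));
  bool isSplat = attr.isSplat();                      // true
  auto splat = attr.getSplatValue<mlir::FloatAttr>(); // FloatAttr of 1.0
  (void)isSplat;
  (void)splat;
  return 0;
}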
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:439
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:586
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isF32() const
Definition: Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:282
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:289
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta between the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:391
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:481
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:126
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
Definition: VectorOps.cpp:385
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:188
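A minimal sketch (invented types and names) of the default map this helper builds for a 1-D read/write of a 2-D memref:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  auto memrefTy = mlir::MemRefType::get({4, 8}, b.getF32Type());
  auto vecTy = mlir::VectorType::get({8}, b.getF32Type());
  // The minor identity keeps only the innermost memref dimension:
  // (d0, d1) -> (d1).
  mlir::AffineMap map =
      mlir::vector::getTransferMinorIdentityMap(memrefTy, vecTy);
  map.dump();
  return 0;
}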
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition: VectorOps.h:60
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2537
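A minimal sketch of the broadcastability query above; the exact enumerator name of BroadcastableToResult is an assumption here, and all other names are invented:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  auto dstTy = mlir::VectorType::get({4}, b.getF32Type());
  // A scalar f32 should be broadcastable to vector<4xf32>.
  auto result = mlir::vector::isBroadcastableTo(b.getF32Type(), dstTy);
  // Assumed enumerator name for the success case.
  bool ok = result == mlir::vector::BroadcastableToResult::Success;
  return ok ? 0 : 1;
}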
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4215
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
Definition: VectorOps.cpp:250
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:314
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-writes the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:345
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:369
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:657
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:68
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:477
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:23
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:930
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
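A worked sketch of computeStrides and linearize as listed above (standalone example, not from the file): for sizes {2, 3, 4} the row-major strides are {12, 4, 1}, so offsets {1, 2, 3} linearize to 1*12 + 2*4 + 3*1 = 23.

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

int main() {
  // Row-major strides for sizes {2, 3, 4} are {12, 4, 1}.
  llvm::SmallVector<int64_t> strides = mlir::computeStrides({2, 3, 4});
  // 1*12 + 2*4 + 3*1 == 23.
  int64_t flat = mlir::linearize({1, 2, 3}, strides);
  return flat == 23 ? 0 : 1;
}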
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1209
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1212
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:412
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:414