1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Support/LLVM.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify a mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of a constant value with dense attributes
66 /// and a constant mask operation (since the client may be called at
67 /// various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayRef<int64_t> masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  if (maskIdx < dimSize)
97  allTrue = false;
98  if (maskIdx > 0)
99  allFalse = false;
100  }
101  if (allTrue)
102  return MaskFormat::AllTrue;
103  if (allFalse)
104  return MaskFormat::AllFalse;
105  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
106  // Finds all-false create_masks. An all-true create_mask requires all
107  // dims to be constants, so that'll be folded to a constant_mask, then
108  // detected in the constant_mask case.
109  auto maskOperands = m.getOperands();
110  for (Value operand : maskOperands) {
111  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
112  int64_t dimSize =
113  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
114  if (dimSize <= 0)
115  return MaskFormat::AllFalse;
116  }
117  }
118  return MaskFormat::Unknown;
119  }
120  return MaskFormat::Unknown;
121 }
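// For illustration, a few masks and the format the classification above
// assigns to them (%n stands for a value that is not a compile-time constant):
//
//   %t = arith.constant dense<true> : vector<4xi1>   // AllTrue
//   %f = vector.constant_mask [0] : vector<4xi1>     // AllFalse
//   %c0 = arith.constant 0 : index
//   %z = vector.create_mask %c0 : vector<4xi1>       // AllFalse
//   %u = vector.create_mask %n : vector<4xi1>        // Unknown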
122 
123 /// Default callback to build a region with a 'vector.yield' terminator with no
124 /// arguments.
125 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
126  builder.create<vector::YieldOp>(loc);
127 }
128 
129 // Helper for verifying combining kinds in contractions and reductions.
130 static bool isSupportedCombiningKind(CombiningKind combiningKind,
131  Type elementType) {
132  switch (combiningKind) {
133  case CombiningKind::ADD:
134  case CombiningKind::MUL:
135  return elementType.isIntOrIndexOrFloat();
136  case CombiningKind::MINUI:
137  case CombiningKind::MINSI:
138  case CombiningKind::MAXUI:
139  case CombiningKind::MAXSI:
140  case CombiningKind::AND:
141  case CombiningKind::OR:
142  case CombiningKind::XOR:
143  return elementType.isIntOrIndex();
144  case CombiningKind::MINNUMF:
145  case CombiningKind::MAXNUMF:
146  case CombiningKind::MINIMUMF:
147  case CombiningKind::MAXIMUMF:
148  return llvm::isa<FloatType>(elementType);
149  }
150  return false;
151 }
152 
153 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154  VectorType vectorType) {
155  int64_t elementVectorRank = 0;
156  VectorType elementVectorType =
157  llvm::dyn_cast<VectorType>(shapedType.getElementType());
158  if (elementVectorType)
159  elementVectorRank += elementVectorType.getRank();
160  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161  // TODO: replace once we have 0-d vectors.
162  if (shapedType.getRank() == 0 &&
163  vectorType.getShape() == ArrayRef<int64_t>{1})
164  return AffineMap::get(
165  /*numDims=*/0, /*numSymbols=*/0,
166  getAffineConstantExpr(0, shapedType.getContext()));
167  return AffineMap::getMinorIdentityMap(
168  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169  shapedType.getContext());
170 }
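// As a rough illustration of the maps produced above: a transfer of
// vector<16xf32> to/from memref<4x8x16xf32> gets the minor identity map
// (d0, d1, d2) -> (d2), while a 0-d transfer of vector<1xf32> to/from
// memref<f32> gets the constant map () -> (0).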
171 
172 /// Check if `write` is of a constant splat and the masked `read` is padded with
173 /// the same splat value -- meaning it could be the same value as the initial
174 /// constant splat.
175 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
176  vector::TransferReadOp read) {
177  auto readMask = read.getMask();
178  auto writeMask = write.getMask();
179  // Check if the masks are consistent. The splat value could be the same if the
180  // read is masked (and padded with the splat value), and the write is unmasked
181  // or has the same mask. Note this does not allow the case where the write is
182  // masked and the read is unmasked, as then the read could be of more elements
183  // than the write (which may not be the same value).
184  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185  if (!couldBeSameSplat)
186  return false;
187  // Check for constant splat (as the source of the write).
188  DenseElementsAttr splatAttr;
189  if (!matchPattern(write.getVector(),
190  m_Constant<DenseElementsAttr>(&splatAttr)) ||
191  !splatAttr.isSplat()) {
192  return false;
193  }
194  // The padding of the read and the constant splat value must be the same.
195  Attribute padAttr;
196  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
197  return false;
198  return padAttr == splatAttr.getSplatValue<Attribute>();
199 }
200 
201 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
202  vector::TransferReadOp read) {
203  return !defWrite.hasOutOfBoundsDim() &&
204  defWrite.getIndices() == read.getIndices() &&
205  defWrite.getVectorType() == read.getVectorType() &&
206  defWrite.getPermutationMap() == read.getPermutationMap() &&
207  ((!defWrite.getMask() && !read.getMask()) ||
208  isSplatWriteConsistentWithMaskedRead(defWrite, read));
209 }
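// For illustration, a RAW pair that satisfies the check above (same indices
// and permutation map, in-bounds, unmasked; %i and %pad are placeholders):
//
//   %w = vector.transfer_write %v, %t[%i] {in_bounds = [true]}
//       : vector<4xf32>, tensor<16xf32>
//   %r = vector.transfer_read %w[%i], %pad {in_bounds = [true]}
//       : tensor<16xf32>, vector<4xf32>
//
// Here %r reads back exactly the written value, so a store-to-load forwarding
// pattern may replace %r with %v.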
210 
211 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
212  vector::TransferWriteOp priorWrite) {
213  return priorWrite.getIndices() == write.getIndices() &&
214  priorWrite.getMask() == write.getMask() &&
215  priorWrite.getVectorType() == write.getVectorType() &&
216  priorWrite.getPermutationMap() == write.getPermutationMap();
217 }
218 
219 bool mlir::vector::isDisjointTransferIndices(
220  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221  bool testDynamicValueUsingBounds) {
222  // For simplicity, only look at transfers of the same type.
223  if (transferA.getVectorType() != transferB.getVectorType())
224  return false;
225  unsigned rankOffset = transferA.getLeadingShapedRank();
226  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227  Value indexA = transferA.getIndices()[i];
228  Value indexB = transferB.getIndices()[i];
229  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
230  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
231 
232  if (i < rankOffset) {
233  // For leading dimensions, if we can prove that the indices are different,
234  // we know we are accessing disjoint slices.
235  if (cstIndexA.has_value() && cstIndexB.has_value()) {
236  if (*cstIndexA != *cstIndexB)
237  return true;
238  continue;
239  }
240  if (testDynamicValueUsingBounds) {
241  // First try to see if we can fully compose and simplify the affine
242  // expression as a fast track.
243  FailureOr<uint64_t> delta =
244  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
245  if (succeeded(delta) && *delta != 0)
246  return true;
247 
248  FailureOr<bool> testEqual =
249  ValueBoundsConstraintSet::areEqual(indexA, indexB);
250  if (succeeded(testEqual) && !testEqual.value())
251  return true;
252  }
253  } else {
254  // For this dimension, we slice a part of the memref, so we need to make
255  // sure the intervals accessed don't overlap.
256  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257  if (cstIndexA.has_value() && cstIndexB.has_value()) {
258  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
259  if (distance >= vectorDim)
260  return true;
261  continue;
262  }
263  if (testDynamicValueUsingBounds) {
264  // First try to see if we can fully compose and simplify the affine
265  // expression as a fast track.
266  FailureOr<int64_t> delta =
267  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
268  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
269  return true;
270 
271  FailureOr<int64_t> computeDelta =
272  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
273  if (succeeded(computeDelta)) {
274  if (std::abs(computeDelta.value()) >= vectorDim)
275  return true;
276  }
277  }
278  }
279  }
280  return false;
281 }
282 
283 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
284  VectorTransferOpInterface transferB,
285  bool testDynamicValueUsingBounds) {
286  if (transferA.getSource() != transferB.getSource())
287  return false;
288  return isDisjointTransferIndices(transferA, transferB,
289  testDynamicValueUsingBounds);
290 }
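// For illustration, two transfers on the same memref that the checks above
// report as disjoint: with constant indices 0 and 4 and a 4-element vector,
// the accessed intervals [0, 4) and [4, 8) do not overlap.
//
//   %a = vector.transfer_read %mem[%c0], %pad : memref<8xf32>, vector<4xf32>
//   %b = vector.transfer_read %mem[%c4], %pad : memref<8xf32>, vector<4xf32>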
291 
292 // Helper to iterate over n-D vector slice elements. Calculates the next
293 // `position` in the n-D vector of size `shape`, applying the offsets in
294 // `offsets`. Modifies `position` in place. Returns failure when `position`
295 // reaches the end position.
296 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
297  ArrayRef<int64_t> shape,
298  ArrayRef<int64_t> offsets) {
299  for (auto [posInDim, dimSize, offsetInDim] :
300  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
301  ++posInDim;
302  if (posInDim < dimSize + offsetInDim)
303  return success();
304 
305  // Carry the overflow to the next loop iteration.
306  posInDim = offsetInDim;
307  }
308 
309  return failure();
310 }
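// For example, with shape = [2, 3] and offsets = [0, 0], starting from
// position [0, 0] successive calls yield [0, 1], [0, 2], [1, 0], [1, 1],
// [1, 2], and then failure once the iteration space is exhausted.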
311 
312 /// Returns the integer numbers in `values`. `values` are expected to be
313 /// constant operations.
314 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
315  SmallVector<int64_t> ints;
316  llvm::transform(values, std::back_inserter(ints), [](Value value) {
317  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
318  assert(constOp && "Unexpected non-constant index");
319  return constOp.value();
320  });
321  return ints;
322 }
323 
324 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
325 /// be constant operations.
326 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
327  SmallVector<int64_t> ints;
328  llvm::transform(
329  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
330  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
331  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
332  });
333  return ints;
334 }
335 
336 /// Convert `foldResults` into Values. Integer attributes are converted to
337 /// constant ops.
338 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
339  ArrayRef<OpFoldResult> foldResults) {
340  SmallVector<Value> values;
341  llvm::transform(foldResults, std::back_inserter(values),
342  [&](OpFoldResult foldResult) {
343  if (auto attr = foldResult.dyn_cast<Attribute>())
344  return builder
345  .create<arith::ConstantIndexOp>(
346  loc, cast<IntegerAttr>(attr).getInt())
347  .getResult();
348 
349  return cast<Value>(foldResult);
350  });
351  return values;
352 }
353 
354 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
355  if (value.getDefiningOp<vector::VectorScaleOp>())
356  return 1;
357  auto mul = value.getDefiningOp<arith::MulIOp>();
358  if (!mul)
359  return {};
360  auto lhs = mul.getLhs();
361  auto rhs = mul.getRhs();
362  if (lhs.getDefiningOp<vector::VectorScaleOp>())
363  return getConstantIntValue(rhs);
364  if (rhs.getDefiningOp<vector::VectorScaleOp>())
365  return getConstantIntValue(lhs);
366  return {};
367 }
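// For illustration (%c4 is a constant 4, %n and %m are the queried values):
//
//   %vs = vector.vscale                 // multiplier = 1
//   %c4 = arith.constant 4 : index
//   %n = arith.muli %vs, %c4 : index    // multiplier = 4
//   %m = arith.muli %c4, %vs : index    // multiplier = 4 (commuted operands)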
368 
369 //===----------------------------------------------------------------------===//
370 // CombiningKindAttr
371 //===----------------------------------------------------------------------===//
372 
373 namespace mlir {
374 namespace vector {
375 namespace detail {
376 struct BitmaskEnumStorage : public AttributeStorage {
377  using KeyTy = uint64_t;
378 
379  BitmaskEnumStorage(KeyTy val) : value(val) {}
380 
381  bool operator==(const KeyTy &key) const { return value == key; }
382 
383  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
384  const KeyTy &key) {
385  return new (allocator.allocate<BitmaskEnumStorage>())
386  BitmaskEnumStorage(key);
387  }
388 
389  KeyTy value = 0;
390 };
391 } // namespace detail
392 } // namespace vector
393 } // namespace mlir
394 
395 //===----------------------------------------------------------------------===//
396 // VectorDialect
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 /// This class defines the interface for handling inlining with vector dialect
401 /// operations.
402 struct VectorInlinerInterface : public DialectInlinerInterface {
403  using DialectInlinerInterface::DialectInlinerInterface;
404 
405  /// All vector dialect ops can be inlined.
406  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
407  return true;
408  }
409 };
410 } // namespace
411 
412 void VectorDialect::initialize() {
413  addAttributes<
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
416  >();
417 
418  addOperations<
419 #define GET_OP_LIST
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
421  >();
422 
423  addInterfaces<VectorInlinerInterface>();
424 
425  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
427  YieldOp>();
428  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
429  TransferWriteOp>();
430  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
433 }
434 
435 /// Materialize a single constant operation from a given attribute value with
436 /// the desired resultant type.
437 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
438  Attribute value, Type type,
439  Location loc) {
440  if (isa<ub::PoisonAttrInterface>(value))
441  return value.getDialect().materializeConstant(builder, value, type, loc);
442 
443  return arith::ConstantOp::materialize(builder, value, type, loc);
444 }
445 
445 
446 IntegerType vector::getVectorSubscriptType(Builder &builder) {
447  return builder.getIntegerType(64);
448 }
449 
450 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
451  ArrayRef<int64_t> values) {
452  return builder.getI64ArrayAttr(values);
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // MultiDimReductionOp
457 //===----------------------------------------------------------------------===//
458 
459 void vector::MultiDimReductionOp::build(OpBuilder &builder,
460  OperationState &result, Value source,
461  Value acc, ArrayRef<bool> reductionMask,
462  CombiningKind kind) {
463  SmallVector<int64_t> reductionDims;
464  for (const auto &en : llvm::enumerate(reductionMask))
465  if (en.value())
466  reductionDims.push_back(en.index());
467  build(builder, result, kind, source, acc, reductionDims);
468 }
469 
470 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
471  // Single parallel dim, this is a noop.
472  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
473  return getSource();
474  return {};
475 }
476 
477 std::optional<SmallVector<int64_t, 4>>
478 MultiDimReductionOp::getShapeForUnroll() {
479  return llvm::to_vector<4>(getSourceVectorType().getShape());
480 }
481 
482 LogicalResult MultiDimReductionOp::verify() {
483  SmallVector<int64_t> targetShape;
484  SmallVector<bool> scalableDims;
485  Type inferredReturnType;
486  auto sourceScalableDims = getSourceVectorType().getScalableDims();
487  for (auto [dimIdx, dimSize] :
488  llvm::enumerate(getSourceVectorType().getShape()))
489  if (!llvm::any_of(getReductionDims(),
490  [dimIdx = dimIdx](int64_t reductionDimIdx) {
491  return reductionDimIdx == static_cast<int64_t>(dimIdx);
492  })) {
493  targetShape.push_back(dimSize);
494  scalableDims.push_back(sourceScalableDims[dimIdx]);
495  }
496  // TODO: update to also allow 0-d vectors when available.
497  if (targetShape.empty())
498  inferredReturnType = getSourceVectorType().getElementType();
499  else
500  inferredReturnType = VectorType::get(
501  targetShape, getSourceVectorType().getElementType(), scalableDims);
502  if (getType() != inferredReturnType)
503  return emitOpError() << "destination type " << getType()
504  << " is incompatible with source type "
505  << getSourceVectorType();
506 
507  return success();
508 }
509 
510 /// Returns the mask type expected by this operation.
511 Type MultiDimReductionOp::getExpectedMaskType() {
512  auto vecType = getSourceVectorType();
513  return VectorType::get(vecType.getShape(),
514  IntegerType::get(vecType.getContext(), /*width=*/1),
515  vecType.getScalableDims());
516 }
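// For example, for the reduction below the expected mask type is
// vector<2x4xi1>, i.e. an i1 vector matching the source shape:
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<2x4xf32> to vector<2xf32>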
517 
518 namespace {
519 // Only unit dimensions that are being reduced are folded. If a dimension is
520 // unit but not reduced, it is not folded, thereby keeping the output type the
521 // same. If not all of the reduced dimensions are unit dimensions, this
522 // transformation does nothing. This is just a generalization of
523 // ElideSingleElementReduction for ReduceOp.
524 struct ElideUnitDimsInMultiDimReduction
525  : public OpRewritePattern<MultiDimReductionOp> {
526  using OpRewritePattern::OpRewritePattern;
527 
528  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
529  PatternRewriter &rewriter) const override {
530  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
531  for (const auto &dim : enumerate(shape)) {
532  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
533  return failure();
534  }
535 
536  // Vector mask setup.
537  OpBuilder::InsertionGuard guard(rewriter);
538  Operation *rootOp;
539  Value mask;
540  if (reductionOp.isMasked()) {
541  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
542  rootOp = reductionOp.getMaskingOp();
543  mask = reductionOp.getMaskingOp().getMask();
544  } else {
545  rootOp = reductionOp;
546  }
547 
548  Location loc = reductionOp.getLoc();
549  Value acc = reductionOp.getAcc();
550  Value cast;
551  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
552  if (mask) {
553  VectorType newMaskType =
554  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
555  dstVecType.getScalableDims());
556  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
557  }
558  cast = rewriter.create<vector::ShapeCastOp>(
559  loc, reductionOp.getDestType(), reductionOp.getSource());
560  } else {
561  // This means we are reducing all the dimensions, and all reduction
562  // dimensions are of size 1. So a simple extraction would do.
563  if (mask)
564  mask = rewriter.create<vector::ExtractOp>(loc, mask);
565  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
566  }
567 
568  Value result =
569  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
570  cast, /*fastmath=*/nullptr, mask);
571  rewriter.replaceOp(rootOp, result);
572  return success();
573  }
574 };
575 } // namespace
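// Roughly, the pattern above rewrites
//
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//       : vector<4x1xf32> to vector<4xf32>
//
// into
//
//   %c = vector.shape_cast %v : vector<4x1xf32> to vector<4xf32>
//   %r = arith.addf %acc, %c : vector<4xf32>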
576 
577 void MultiDimReductionOp::getCanonicalizationPatterns(
578  RewritePatternSet &results, MLIRContext *context) {
579  results.add<ElideUnitDimsInMultiDimReduction>(context);
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // ReductionOp
584 //===----------------------------------------------------------------------===//
585 
586 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
587  CombiningKind kind, Value vector,
588  arith::FastMathFlags fastMathFlags) {
589  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
590 }
591 
592 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
593  CombiningKind kind, Value vector, Value acc,
594  arith::FastMathFlags fastMathFlags) {
595  build(builder, result,
596  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
597  acc, fastMathFlags);
598 }
599 
600 LogicalResult ReductionOp::verify() {
601  // Verify for 0-D and 1-D vector.
602  int64_t rank = getSourceVectorType().getRank();
603  if (rank > 1)
604  return emitOpError("unsupported reduction rank: ") << rank;
605 
606  // Verify supported reduction kind.
607  Type eltType = getDest().getType();
608  if (!isSupportedCombiningKind(getKind(), eltType))
609  return emitOpError("unsupported reduction type '")
610  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
611  << "'";
612 
613  return success();
614 }
615 
616 // MaskableOpInterface methods.
617 
618 /// Returns the mask type expected by this operation.
619 Type ReductionOp::getExpectedMaskType() {
620  auto vecType = getSourceVectorType();
621  return VectorType::get(vecType.getShape(),
622  IntegerType::get(vecType.getContext(), /*width=*/1),
623  vecType.getScalableDims());
624 }
625 
626 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
627  OpBuilder &builder, Location loc,
628  Value vector) {
629  switch (op) {
630  case arith::AtomicRMWKind::addf:
631  case arith::AtomicRMWKind::addi:
632  return builder.create<vector::ReductionOp>(vector.getLoc(),
633  CombiningKind::ADD, vector);
634  case arith::AtomicRMWKind::mulf:
635  case arith::AtomicRMWKind::muli:
636  return builder.create<vector::ReductionOp>(vector.getLoc(),
637  CombiningKind::MUL, vector);
638  case arith::AtomicRMWKind::minimumf:
639  return builder.create<vector::ReductionOp>(vector.getLoc(),
640  CombiningKind::MINIMUMF, vector);
641  case arith::AtomicRMWKind::mins:
642  return builder.create<vector::ReductionOp>(vector.getLoc(),
643  CombiningKind::MINSI, vector);
644  case arith::AtomicRMWKind::minu:
645  return builder.create<vector::ReductionOp>(vector.getLoc(),
646  CombiningKind::MINUI, vector);
647  case arith::AtomicRMWKind::maximumf:
648  return builder.create<vector::ReductionOp>(vector.getLoc(),
649  CombiningKind::MAXIMUMF, vector);
650  case arith::AtomicRMWKind::maxs:
651  return builder.create<vector::ReductionOp>(vector.getLoc(),
652  CombiningKind::MAXSI, vector);
653  case arith::AtomicRMWKind::maxu:
654  return builder.create<vector::ReductionOp>(vector.getLoc(),
655  CombiningKind::MAXUI, vector);
656  case arith::AtomicRMWKind::andi:
657  return builder.create<vector::ReductionOp>(vector.getLoc(),
658  CombiningKind::AND, vector);
659  case arith::AtomicRMWKind::ori:
660  return builder.create<vector::ReductionOp>(vector.getLoc(),
661  CombiningKind::OR, vector);
662  // TODO: Add remaining reduction operations.
663  default:
664  (void)emitOptionalError(loc, "Reduction operation type not supported");
665  break;
666  }
667  return nullptr;
668 }
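// For example, requesting arith::AtomicRMWKind::addf for a value
// %v : vector<8xf32> builds, roughly:
//
//   %r = vector.reduction <add>, %v : vector<8xf32> into f32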
669 
670 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
671  return llvm::to_vector<4>(getSourceVectorType().getShape());
672 }
673 
674 namespace {
675 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
676  using OpRewritePattern::OpRewritePattern;
677 
678  LogicalResult matchAndRewrite(ReductionOp reductionOp,
679  PatternRewriter &rewriter) const override {
680  // Vector mask setup.
681  OpBuilder::InsertionGuard guard(rewriter);
682  auto maskableOp =
683  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
684  Operation *rootOp;
685  Value mask;
686  if (maskableOp.isMasked()) {
687  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
688  rootOp = maskableOp.getMaskingOp();
689  mask = maskableOp.getMaskingOp().getMask();
690  } else {
691  rootOp = reductionOp;
692  }
693 
694  auto vectorType = reductionOp.getSourceVectorType();
695  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
696  return failure();
697 
698  Location loc = reductionOp.getLoc();
699  if (mask)
700  mask = rewriter.create<ExtractOp>(loc, mask);
701  Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
702 
703  if (Value acc = reductionOp.getAcc())
704  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
705  result, acc,
706  reductionOp.getFastmathAttr(), mask);
707 
708  rewriter.replaceOp(rootOp, result);
709  return success();
710  }
711 };
712 } // namespace
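// Roughly, the pattern above rewrites
//
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
//
// into
//
//   %e = vector.extract %v[0] : f32 from vector<1xf32>
//   %r = arith.addf %e, %acc : f32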
713 
714 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
715  MLIRContext *context) {
716  results.add<ElideSingleElementReduction>(context);
717 }
718 
719 //===----------------------------------------------------------------------===//
720 // ContractionOp
721 //===----------------------------------------------------------------------===//
722 
723 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
724  Value lhs, Value rhs, Value acc,
725  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
726  ArrayRef<IteratorType> iteratorTypes) {
727  result.addOperands({lhs, rhs, acc});
728  result.addTypes(acc.getType());
729  result.addAttribute(
730  getIndexingMapsAttrName(result.name),
731  builder.getAffineMapArrayAttr(
732  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
733  result.addAttribute(
734  getIteratorTypesAttrName(result.name),
735  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
736  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
737  return IteratorTypeAttr::get(builder.getContext(), t);
738  }))));
739 }
740 
741 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
742  Value lhs, Value rhs, Value acc,
743  ArrayAttr indexingMaps,
744  ArrayAttr iteratorTypes) {
745  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
746  ContractionOp::getDefaultKind());
747 }
748 
749 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
750  Value lhs, Value rhs, Value acc,
751  ArrayAttr indexingMaps,
752  ArrayAttr iteratorTypes, CombiningKind kind) {
753  result.addOperands({lhs, rhs, acc});
754  result.addTypes(acc.getType());
755  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
756  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
757  result.addAttribute(getKindAttrName(result.name),
758  CombiningKindAttr::get(builder.getContext(), kind));
759 }
760 
761 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
762  OpAsmParser::UnresolvedOperand lhsInfo;
763  OpAsmParser::UnresolvedOperand rhsInfo;
764  OpAsmParser::UnresolvedOperand accInfo;
765  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
766  SmallVector<Type, 2> types;
767  Type resultType;
768  auto loc = parser.getCurrentLocation();
769  DictionaryAttr dictAttr;
770  // TODO: Unify linalg op attribute parsing.
771  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
772  parser.parseComma() || parser.parseOperand(rhsInfo) ||
773  parser.parseComma() || parser.parseOperand(accInfo) ||
774  parser.parseTrailingOperandList(masksInfo) ||
775  parser.parseOptionalAttrDict(result.attributes) ||
776  parser.parseColonTypeList(types) ||
777  parser.parseKeywordType("into", resultType) ||
778  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
779  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
780  parser.resolveOperand(accInfo, resultType, result.operands) ||
781  parser.addTypeToList(resultType, result.types))
782  return failure();
783  result.attributes.append(dictAttr.getValue().begin(),
784  dictAttr.getValue().end());
785 
786  // Convert the array of strings into an array of IteratorType enums. This is
787  // needed because tests still use the old format in which the 'iterator_types'
788  // attribute is represented as an array of strings.
789  // TODO: Remove this conversion once tests are fixed.
790  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
791  result.attributes.get(getIteratorTypesAttrName(result.name)));
792  if (!iteratorTypes) {
793  return parser.emitError(loc)
794  << "expected " << getIteratorTypesAttrName(result.name)
795  << " array attribute";
796  }
797 
798  SmallVector<Attribute> iteratorTypeAttrs;
799 
800  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
801  auto maybeIteratorType = symbolizeIteratorType(s);
802  if (!maybeIteratorType.has_value())
803  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
804 
805  iteratorTypeAttrs.push_back(
806  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
807  }
808  result.attributes.set(getIteratorTypesAttrName(result.name),
809  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
810 
811  if (!result.attributes.get(getKindAttrName(result.name))) {
812  result.addAttribute(
813  getKindAttrName(result.name),
814  CombiningKindAttr::get(parser.getContext(),
815  ContractionOp::getDefaultKind()));
816  }
817  if (masksInfo.empty())
818  return success();
819  if (masksInfo.size() != 2)
820  return parser.emitError(parser.getNameLoc(),
821  "expected zero or exactly 2 vector mask operands");
822  auto lhsType = llvm::cast<VectorType>(types[0]);
823  auto rhsType = llvm::cast<VectorType>(types[1]);
824  auto maskElementType = parser.getBuilder().getI1Type();
825  std::array<VectorType, 2> maskTypes = {
826  VectorType::Builder(lhsType).setElementType(maskElementType),
827  VectorType::Builder(rhsType).setElementType(maskElementType)};
828  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
829  return failure();
830  return success();
831 }
832 
833 void ContractionOp::print(OpAsmPrinter &p) {
834  // TODO: Unify printing code with linalg ops.
835  auto attrNames = getTraitAttrNames();
836  llvm::StringSet<> traitAttrsSet;
837  traitAttrsSet.insert_range(attrNames);
838  SmallVector<NamedAttribute> attrs;
839  for (auto attr : (*this)->getAttrs()) {
840  if (attr.getName() == getIteratorTypesAttrName()) {
841  auto iteratorTypes =
842  llvm::cast<ArrayAttr>(attr.getValue())
843  .getAsValueRange<IteratorTypeAttr, IteratorType>();
844  // Convert the IteratorType enums into their string representation. This is
845  // needed because tests still use the old format in which the 'iterator_types'
846  // attribute is represented as an array of strings.
847  // TODO: Remove this conversion once tests are fixed.
848  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
849  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
850  return StringAttr::get(getContext(), stringifyIteratorType(t));
851  }));
852 
853  attrs.emplace_back(getIteratorTypesAttrName(),
854  ArrayAttr::get(getContext(), iteratorTypeNames));
855  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
856  attrs.push_back(attr);
857  }
858 
859  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
860  p << " " << dictAttr << " " << getLhs() << ", ";
861  p << getRhs() << ", " << getAcc();
862 
863  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
864  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
865  << getResultType();
866 }
867 
868 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
869  const std::vector<std::pair<int64_t, int64_t>> &map) {
870  for (auto &dimPair : map) {
871  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
872  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
873  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
874  return false;
875  }
876  return true;
877 }
878 
879 static LogicalResult verifyOutputShape(
880  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
881  Type resType,
882  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
883  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
884  DenseSet<int64_t> lhsContractingDimSet;
885  DenseSet<int64_t> rhsContractingDimSet;
886  for (auto &dimPair : contractingDimMap) {
887  lhsContractingDimSet.insert(dimPair.first);
888  rhsContractingDimSet.insert(dimPair.second);
889  }
890  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
891  llvm::make_second_range(batchDimMap));
892 
893  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
894  SmallVector<int64_t, 4> expectedResultDims;
895  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
896  if (lhsContractingDimSet.count(i) > 0)
897  continue;
898  expectedResultDims.push_back(lhsType.getDimSize(i));
899  }
900 
901  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
902  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
903  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
904  continue;
905  expectedResultDims.push_back(rhsType.getDimSize(i));
906  }
907 
908  // Verify 'expectedResultDims'.
909  if (expectedResultDims.empty()) {
910  // No batch or free dimension implies a scalar result.
911  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
912  return op.emitOpError("invalid accumulator/result vector shape");
913  } else {
914  // At least one batch or free dimension implies a vector result.
915  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
916  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
917  if (!resVectorType || !accVectorType)
918  return op.emitOpError("invalid accumulator/result vector shape");
919 
920  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
921  // types fully define the result vector type. This assumes the affine maps
922  // are well-formed, which must have been verified already.
923  MLIRContext *ctx = op.getContext();
924  AffineMap lhsMap = op.getIndexingMapsArray()[0];
925  AffineMap rhsMap = op.getIndexingMapsArray()[1];
926  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
927  return op.emitOpError(
928  "expected all dimensions to be either a LHS or a RHS dimension");
929  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
930  for (auto pair :
931  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
932  VectorType v = pair.first;
933  auto map = pair.second;
934  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
935  unsigned pos = map.getDimPosition(idx);
936  if (!extents[pos])
937  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
938  }
939  }
940  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
941  return op.emitOpError("expected all dimensions to get an extent as "
942  "either a LHS or a RHS dimension");
943 
944  AffineMap resMap = op.getIndexingMapsArray()[2];
945  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
946  /*symbolCount=*/0, extents, ctx);
947  // Compose the resMap with the extentsMap, which is a constant map.
948  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
949  assert(llvm::all_of(expectedMap.getResults(),
950  llvm::IsaPred<AffineConstantExpr>) &&
951  "expected constant extent along all dimensions.");
952  // Extract the expected shape and build the type.
953  auto expectedShape = llvm::to_vector<4>(
954  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
955  return cast<AffineConstantExpr>(e).getValue();
956  }));
957  auto expected =
958  VectorType::get(expectedShape, resVectorType.getElementType(),
959  resVectorType.getScalableDims());
960  if (resVectorType != expected || accVectorType != expected)
961  return op.emitOpError(
962  "invalid accumulator/result vector shape, expected: ")
963  << expected;
964  }
965  return success();
966 }
967 
968 LogicalResult ContractionOp::verify() {
969  VectorType lhsType = getLhsType();
970  VectorType rhsType = getRhsType();
971  Type accType = getAccType();
972  Type resType = getResultType();
973 
974  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
975  if (!lhsType.getElementType().isSignlessInteger())
976  return emitOpError("only supports signless integer types");
977  }
978 
979  // Verify that an indexing map was specified for each vector operand.
980  if (getIndexingMapsArray().size() != 3)
981  return emitOpError("expected an indexing map for each vector operand");
982 
983  // Verify that each index map has 'numIterators' inputs, no symbols, and
984  // that the number of map outputs equals the rank of its associated
985  // vector operand.
986  unsigned numIterators = getIteratorTypes().getValue().size();
987  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
988  auto index = it.index();
989  auto map = it.value();
990  if (map.getNumSymbols() != 0)
991  return emitOpError("expected indexing map ")
992  << index << " to have no symbols";
993  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
994  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
995  // Verify that the map has the right number of inputs, outputs, and indices.
996  // This also correctly accounts for (..) -> () for rank-0 results.
997  if (map.getNumDims() != numIterators)
998  return emitOpError("expected indexing map ")
999  << index << " to have " << numIterators << " number of inputs";
1000  if (map.getNumResults() != rank)
1001  return emitOpError("expected indexing map ")
1002  << index << " to have " << rank << " number of outputs";
1003  if (!map.isProjectedPermutation())
1004  return emitOpError("expected indexing map ")
1005  << index << " to be a projected permutation of its inputs";
1006  }
1007 
1008  auto contractingDimMap = getContractingDimMap();
1009  auto batchDimMap = getBatchDimMap();
1010 
1011  // Verify at least one contracting dimension pair was specified.
1012  if (contractingDimMap.empty())
1013  return emitOpError("expected at least one contracting dimension pair");
1014 
1015  // Verify contracting dimension map was properly constructed.
1016  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1017  return emitOpError("invalid contracting dimension map");
1018 
1019  // Verify batch dimension map was properly constructed.
1020  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1021  return emitOpError("invalid batch dimension map");
1022 
1023  // Verify 'accType' and 'resType' shape.
1024  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1025  contractingDimMap, batchDimMap)))
1026  return failure();
1027 
1028  // Verify supported combining kind.
1029  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1030  auto elementType = vectorType ? vectorType.getElementType() : resType;
1031  if (!isSupportedCombiningKind(getKind(), elementType))
1032  return emitOpError("unsupported contraction type");
1033 
1034  return success();
1035 }
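// For reference, a contraction that passes this verifier: a plain matmul with
// iterators (i, j, k) where k is the contracting dimension (%a, %b and %acc
// are placeholders):
//
//   %c = vector.contract {
//          indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
//                           affine_map<(i, j, k) -> (k, j)>,
//                           affine_map<(i, j, k) -> (i, j)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %acc
//     : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>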
1036 
1037 // MaskableOpInterface methods.
1038 
1039 /// Returns the mask type expected by this operation. Mostly used for
1040 /// verification purposes. It requires the operation to be vectorized.
1041 Type ContractionOp::getExpectedMaskType() {
1042  auto indexingMaps = this->getIndexingMapsArray();
1043  AffineMap lhsIdxMap = indexingMaps[0];
1044  AffineMap rhsIdxMap = indexingMaps[1];
1045  VectorType lhsType = this->getLhsType();
1046  VectorType rhsType = this->getRhsType();
1047 
1048  unsigned numVecDims = lhsIdxMap.getNumDims();
1049  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1050  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1051 
1052  // Using the information in the indexing maps, extract the size of each
1053  // dimension in the vector.contract operation from the two input operands.
1054  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1055  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1056  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1057  lhsType.getScalableDims()[dimIdx];
1058  }
1059  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1060  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1061  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1062  rhsType.getScalableDims()[dimIdx];
1063  }
1064 
1065  assert(!ShapedType::isDynamicShape(maskShape) &&
1066  "Mask shape couldn't be computed");
1067 
1068  return VectorType::get(maskShape,
1069  IntegerType::get(lhsType.getContext(), /*width=*/1),
1070  maskShapeScalableDims);
1071 }
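// For the matmul-style contraction sketched after the verifier above (lhs
// vector<4x8xf32>, rhs vector<8x16xf32>, iterators (i, j, k)), the mask shape
// is indexed by the iteration dimensions, so the expected mask type is
// vector<4x16x8xi1>.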
1072 
1073 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1074  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1075  getIteratorTypesAttrName(), getKindAttrName()};
1076 }
1077 
1078 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1079  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1080  if (targetExpr == map.getResult(i))
1081  return i;
1082  return -1;
1083 }
1084 
1085 static std::vector<std::pair<int64_t, int64_t>>
1086 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1087  IteratorType targetIteratorType, MLIRContext *context) {
1088  std::vector<std::pair<int64_t, int64_t>> dimMap;
1089  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1090  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1091  if (iteratorType != targetIteratorType)
1092  continue;
1093  // Search lhs/rhs map results for 'targetExpr'.
1094  auto targetExpr = getAffineDimExpr(it.index(), context);
1095  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1096  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1097  if (lhsDim >= 0 && rhsDim >= 0)
1098  dimMap.emplace_back(lhsDim, rhsDim);
1099  }
1100  return dimMap;
1101 }
1102 
1103 void ContractionOp::getIterationBounds(
1104  SmallVectorImpl<int64_t> &iterationBounds) {
1105  auto lhsShape = getLhsType().getShape();
1106  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1107  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1108  SmallVector<int64_t, 2> iterationShape;
1109  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1110  // Search lhs/rhs map results for 'targetExpr'.
1111  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1112  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1113  if (iteratorType == IteratorType::reduction) {
1114  // Get reduction dim size from lhs shape (same size in rhsShape).
1115  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1116  assert(lhsDimIndex >= 0);
1117  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1118  continue;
1119  }
1120  // Get parallel dimension size from result shape.
1121  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1122  assert(resDimIndex >= 0);
1123  assert(resVectorType != nullptr);
1124  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1125  }
1126 }
1127 
1128 void ContractionOp::getIterationIndexMap(
1129  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1130  unsigned numMaps = getIndexingMapsArray().size();
1131  iterationIndexMap.resize(numMaps);
1132  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1133  auto index = it.index();
1134  auto map = it.value();
1135  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1136  auto dim = cast<AffineDimExpr>(map.getResult(i));
1137  iterationIndexMap[index][dim.getPosition()] = i;
1138  }
1139  }
1140 }
1141 
1142 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1143  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1144  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1145  getContext());
1146 }
1147 
1148 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1149  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1150  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1151  getContext());
1152 }
1153 
1154 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1155  SmallVector<int64_t, 4> shape;
1156  getIterationBounds(shape);
1157  return shape;
1158 }
1159 
1160 /// Return a fused vector::ContractionOp which represents a pattern such as:
1161 ///
1162 /// ```mlir
1163 /// %c0 = arith.constant 0: ...
1164 /// %c = vector.contract %a, %b, %c0: ...
1165 /// %e = add %c, %d: ...
1166 /// ```
1167 ///
1168 /// by:
1169 ///
1170 /// ```mlir
1171 /// %e = vector.contract %a, %b, %d: ...
1172 /// ```
1173 ///
1174 /// Return null if the canonicalization does not apply.
1175 // TODO: This should be a folding of Add into Contract in core but while they
1176 // live in different dialects, it is not possible without unnatural
1177 // dependencies.
1178 template <typename AddOpType>
1179 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1180  using OpRewritePattern<AddOpType>::OpRewritePattern;
1181 
1182  LogicalResult matchAndRewrite(AddOpType addOp,
1183  PatternRewriter &rewriter) const override {
1184  auto canonicalize = [&](Value maybeContraction,
1185  Value otherOperand) -> vector::ContractionOp {
1186  vector::ContractionOp contractionOp =
1187  dyn_cast_or_null<vector::ContractionOp>(
1188  maybeContraction.getDefiningOp());
1189  if (!contractionOp)
1190  return vector::ContractionOp();
1191  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1192  contractionOp.getAcc().getDefiningOp())) {
1193  if (maybeZero.getValue() ==
1194  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1195  IRMapping bvm;
1196  bvm.map(contractionOp.getAcc(), otherOperand);
1197  auto newContraction =
1198  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1199  rewriter.replaceOp(addOp, newContraction.getResult());
1200  return newContraction;
1201  }
1202  }
1203  return vector::ContractionOp();
1204  };
1205 
1206  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1207  vector::ContractionOp contract = canonicalize(a, b);
1208  contract = contract ? contract : canonicalize(b, a);
1209  return contract ? success() : failure();
1210  }
1211 };
1212 
1213 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1214  MLIRContext *context) {
1215  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1216  CanonicalizeContractAdd<arith::AddFOp>>(context);
1217 }
1218 
1219 //===----------------------------------------------------------------------===//
1220 // ExtractElementOp
1221 //===----------------------------------------------------------------------===//
1222 
1223 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1224  SetIntRangeFn setResultRanges) {
1225  setResultRanges(getResult(), argRanges.front());
1226 }
1227 
1228 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1229  Value source) {
1230  result.addOperands({source});
1231  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1232 }
1233 
1234 LogicalResult vector::ExtractElementOp::verify() {
1235  VectorType vectorType = getSourceVectorType();
1236  if (vectorType.getRank() == 0) {
1237  if (getPosition())
1238  return emitOpError("expected position to be empty with 0-D vector");
1239  return success();
1240  }
1241  if (vectorType.getRank() != 1)
1242  return emitOpError("unexpected >1 vector rank");
1243  if (!getPosition())
1244  return emitOpError("expected position for 1-D vector");
1245  return success();
1246 }
1247 
1248 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1249  // Skip the 0-D vector case here for now.
1250  if (!adaptor.getPosition())
1251  return {};
1252 
1253  // Fold extractelement (splat X) -> X.
1254  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1255  return splat.getInput();
1256 
1257  // Fold extractelement(broadcast(X)) -> X.
1258  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1259  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1260  return broadcast.getSource();
1261 
1262  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1263  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1264  if (!pos || !src)
1265  return {};
1266 
1267  auto srcElements = src.getValues<Attribute>();
1268 
1269  uint64_t posIdx = pos.getInt();
1270  if (posIdx >= srcElements.size())
1271  return {};
1272 
1273  return srcElements[posIdx];
1274 }
1275 
1276 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1277 // `poisonValue`.
1278 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1279  int64_t maxIndex) {
1280  return index == poisonValue || (index >= 0 && index < maxIndex);
1281 }
1282 
1283 //===----------------------------------------------------------------------===//
1284 // ExtractOp
1285 //===----------------------------------------------------------------------===//
1286 
1287 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1288  SetIntRangeFn setResultRanges) {
1289  setResultRanges(getResult(), argRanges.front());
1290 }
1291 
1292 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1293  Value source) {
1294  auto vectorTy = cast<VectorType>(source.getType());
1295  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1296 }
1297 
1298 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1299  Value source, int64_t position) {
1300  build(builder, result, source, ArrayRef<int64_t>{position});
1301 }
1302 
1303 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1304  Value source, OpFoldResult position) {
1305  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1306 }
1307 
1308 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1309  Value source, ArrayRef<int64_t> position) {
1310  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1311  builder.getDenseI64ArrayAttr(position));
1312 }
1313 
1314 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1315  Value source, ArrayRef<OpFoldResult> position) {
1316  SmallVector<int64_t> staticPos;
1317  SmallVector<Value> dynamicPos;
1318  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1319  build(builder, result, source, dynamicPos,
1320  builder.getDenseI64ArrayAttr(staticPos));
1321 }
1322 
1323 LogicalResult
1324 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1325  ExtractOp::Adaptor adaptor,
1326  SmallVectorImpl<Type> &inferredReturnTypes) {
1327  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1328  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1329  vectorType.getRank()) {
1330  inferredReturnTypes.push_back(vectorType.getElementType());
1331  } else {
1332  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1333  vectorType.getRank());
1334  inferredReturnTypes.push_back(VectorType::get(
1335  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1336  vectorType.getScalableDims().drop_front(n)));
1337  }
1338  return success();
1339 }
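// For example, a position of rank 1 on a vector<2x3x4xf32> source infers a
// vector result, while a full-rank position infers the element type:
//
//   %s = vector.extract %v[0] : vector<3x4xf32> from vector<2x3x4xf32>
//   %e = vector.extract %v[0, 1, 2] : f32 from vector<2x3x4xf32>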
1340 
1341 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1342  // Allow extracting 1-element vectors instead of scalars.
1343  auto isCompatible = [](TypeRange l, TypeRange r) {
1344  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1345  return vectorType && vectorType.getShape().equals({1}) &&
1346  vectorType.getElementType() == r.front();
1347  };
1348  if (l.size() == 1 && r.size() == 1 &&
1349  (isCompatible(l, r) || isCompatible(r, l)))
1350  return true;
1351  return l == r;
1352 }
1353 
1354 LogicalResult vector::ExtractOp::verify() {
1355  // Note: This check must come before getMixedPosition() to prevent a crash.
1356  auto dynamicMarkersCount =
1357  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1358  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1359  return emitOpError(
1360  "mismatch between dynamic and static positions (kDynamic marker but no "
1361  "corresponding dynamic position) -- this can only happen due to an "
1362  "incorrect fold/rewrite");
1363  auto position = getMixedPosition();
1364  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1365  return emitOpError(
1366  "expected position attribute of rank no greater than vector rank");
1367  for (auto [idx, pos] : llvm::enumerate(position)) {
1368  if (auto attr = dyn_cast<Attribute>(pos)) {
1369  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1370  if (!isValidPositiveIndexOrPoison(
1371  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1372  return emitOpError("expected position attribute #")
1373  << (idx + 1)
1374  << " to be a non-negative integer smaller than the "
1375  "corresponding vector dimension or poison (-1)";
1376  }
1377  }
1378  }
1379  return success();
1380 }
1381 
1382 template <typename IntType>
1383 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1384  return llvm::to_vector<4>(llvm::map_range(
1385  arrayAttr.getAsRange<IntegerAttr>(),
1386  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1387 }
1388 
1389 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1390 /// positions.
1391 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1392  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1393  return failure();
1394 
1395  // TODO: Canonicalization for dynamic position not implemented yet.
1396  if (extractOp.hasDynamicPosition())
1397  return failure();
1398 
1399  SmallVector<int64_t> globalPosition;
1400  ExtractOp currentOp = extractOp;
1401  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1402  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1403  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1404  currentOp = nextOp;
1405  // TODO: Canonicalization for dynamic position not implemented yet.
1406  if (currentOp.hasDynamicPosition())
1407  return failure();
1408  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1409  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1410  }
1411  extractOp.setOperand(0, currentOp.getVector());
1412  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1413  OpBuilder b(extractOp.getContext());
1414  std::reverse(globalPosition.begin(), globalPosition.end());
1415  extractOp.setStaticPosition(globalPosition);
1416  return success();
1417 }
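// Roughly, the chain fold above turns
//
//   %a = vector.extract %v[0, 2] : vector<5xf32> from vector<2x3x5xf32>
//   %b = vector.extract %a[1] : f32 from vector<5xf32>
//
// into the single extraction
//
//   %b = vector.extract %v[0, 2, 1] : f32 from vector<2x3x5xf32>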
1418 
1419 namespace {
1420 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1421 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1422 /// Compose TransposeOp permutations as we walk back.
1423 /// This helper class keeps an updated extraction position `extractPosition`
1424 /// with extra trailing sentinels.
1425 /// The sentinels encode the internal transposition status of the result vector.
1426 /// As we iterate, extractPosition is permuted and updated.
1427 class ExtractFromInsertTransposeChainState {
1428 public:
1429  ExtractFromInsertTransposeChainState(ExtractOp e);
1430 
1431  /// Iterate over producing insert and transpose ops until we find a fold.
1432  Value fold();
1433 
1434 private:
1435  /// Return true if the vector at position `a` is contained within the vector
1436  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1437  /// is a prefix of `b`.
1438  template <typename ContainerA, typename ContainerB>
1439  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1440  return a.size() <= b.size() &&
1441  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1442  }
1443 
1444  /// Return true if the vector at position `a` intersects the vector at
1445  /// position `b`. Under insert/extract semantics, this is the same as equality
1446  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1447  /// Comparison is on the common prefix (i.e. zip).
1448  template <typename ContainerA, typename ContainerB>
1449  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1450  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1451  if (elemA < 0 || elemB < 0)
1452  continue;
1453  if (elemA != elemB)
1454  return false;
1455  }
1456  return true;
1457  }
1458 
1459  /// Folding is only possible in the absence of an internal permutation in the
1460  /// result vector.
1461  bool canFold() {
1462  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1463  }
1464 
1465  // Helper to get the next defining op of interest.
1466  void updateStateForNextIteration(Value v) {
1467  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1468  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1469  };
1470 
1471  // Case 1. If we hit a transpose, just compose the map and iterate.
1472  // Invariant: insert + transpose do not change rank, we can always compose.
1473  LogicalResult handleTransposeOp();
1474 
1475  // Case 2: the insert position matches extractPosition exactly, early return.
1476  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1477 
1478  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1479  /// portion of the source of the insert.
 1480  /// Example:
 1481  /// ```
 1482  /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5>
 1483  /// // extractPosition == [1, 2, 3]
 1484  /// %ext = vector.extract %ins[1, 2, 3]: vector<5> from vector<2x3x4x5>
 1485  /// // can fold to:
 1486  /// %ext = vector.extract %source[2, 3]: vector<5> from vector<3x4x5>
 1487  /// ```
 1488  /// To traverse into %source, the leading dims covered by the insert position
 1489  /// are set to 0 and then dropped from the extraction position.
1490  /// This method updates the internal state.
1491  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1492 
1493  /// Try to fold in place to extract(source, extractPosition) and return the
1494  /// folded result. Return null if folding is not possible (e.g. due to an
1495  /// internal transposition in the result).
1496  Value tryToFoldExtractOpInPlace(Value source);
1497 
1498  ExtractOp extractOp;
1499  int64_t vectorRank;
1500  int64_t extractedRank;
1501 
1502  InsertOp nextInsertOp;
1503  TransposeOp nextTransposeOp;
1504 
1505  /// Sentinel values that encode the internal permutation status of the result.
1506  /// They are set to (-1, ... , -k) at the beginning and appended to
1507  /// `extractPosition`.
1508  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1509  /// ensure that there is no internal transposition.
1510  /// Internal transposition cannot be accounted for with a folding pattern.
1511  // TODO: We could relax the internal transposition with an extra transposition
1512  // operation in a future canonicalizer.
 1513  SmallVector<int64_t> sentinels;
 1514  SmallVector<int64_t> extractPosition;
1515 };
1516 } // namespace
1517 
1518 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1519  ExtractOp e)
1520  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1521  extractedRank(extractOp.getNumIndices()) {
1522  assert(vectorRank >= extractedRank && "Extracted position overflow");
1523  sentinels.reserve(vectorRank - extractedRank);
1524  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1525  sentinels.push_back(-(i + 1));
1526  extractPosition.assign(extractOp.getStaticPosition().begin(),
1527  extractOp.getStaticPosition().end());
1528  llvm::append_range(extractPosition, sentinels);
1529 }
1530 
1531 // Case 1. If we hit a transpose, just compose the map and iterate.
1532 // Invariant: insert + transpose do not change rank, we can always compose.
1533 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1534  // TODO: Canonicalization for dynamic position not implemented yet.
1535  if (extractOp.hasDynamicPosition())
1536  return failure();
1537 
1538  if (!nextTransposeOp)
1539  return failure();
1540  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
 1541  nextTransposeOp.getPermutation(), extractOp.getContext()));
 1542  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1543  return success();
1544 }
1545 
1546 // Case 2: the insert position matches extractPosition exactly, early return.
1547 LogicalResult
1548 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1549  Value &res) {
1550  // TODO: Canonicalization for dynamic position not implemented yet.
1551  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1552  return failure();
1553 
1554  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1555  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1556  return failure();
1557  // Case 2.a. early-exit fold.
1558  res = nextInsertOp.getSource();
1559  // Case 2.b. if internal transposition is present, canFold will be false.
1560  return success(canFold());
1561 }
1562 
 1563 /// Case 3: if the inserted position is a prefix of extractPosition,
1564 /// extract a portion of the source of the insertion.
1565 /// This method updates the internal state.
1566 LogicalResult
1567 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1568  // TODO: Canonicalization for dynamic position not implemented yet.
1569  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1570  return failure();
1571 
1572  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1573  if (!isContainedWithin(insertedPos, extractPosition))
1574  return failure();
1575  // Set leading dims to zero.
1576  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1577  // Drop extra leading dims.
1578  extractPosition.erase(extractPosition.begin(),
1579  extractPosition.begin() + insertedPos.size());
1580  extractedRank = extractPosition.size() - sentinels.size();
1581  // Case 3.a. early-exit fold (break and delegate to post-while path).
1582  res = nextInsertOp.getSource();
1583  // Case 3.b. if internal transposition is present, canFold will be false.
1584  return success();
1585 }
1586 
1587 /// Try to fold in place to extract(source, extractPosition) and return the
1588 /// folded result. Return null if folding is not possible (e.g. due to an
1589 /// internal transposition in the result).
1590 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1591  Value source) {
1592  // TODO: Canonicalization for dynamic position not implemented yet.
1593  if (extractOp.hasDynamicPosition())
1594  return Value();
1595 
1596  // If we can't fold (either internal transposition, or nothing to fold), bail.
1597  bool nothingToFold = (source == extractOp.getVector());
1598  if (nothingToFold || !canFold())
1599  return Value();
1600 
1601  // Otherwise, fold by updating the op inplace and return its result.
1602  OpBuilder b(extractOp.getContext());
1603  extractOp.setStaticPosition(
1604  ArrayRef(extractPosition).take_front(extractedRank));
1605  extractOp.getVectorMutable().assign(source);
1606  return extractOp.getResult();
1607 }
1608 
1609 /// Iterate over producing insert and transpose ops until we find a fold.
1610 Value ExtractFromInsertTransposeChainState::fold() {
1611  // TODO: Canonicalization for dynamic position not implemented yet.
1612  if (extractOp.hasDynamicPosition())
1613  return Value();
1614 
1615  Value valueToExtractFrom = extractOp.getVector();
1616  updateStateForNextIteration(valueToExtractFrom);
1617  while (nextInsertOp || nextTransposeOp) {
1618  // Case 1. If we hit a transpose, just compose the map and iterate.
1619  // Invariant: insert + transpose do not change rank, we can always compose.
1620  if (succeeded(handleTransposeOp())) {
1621  valueToExtractFrom = nextTransposeOp.getVector();
1622  updateStateForNextIteration(valueToExtractFrom);
1623  continue;
1624  }
1625 
1626  Value result;
1627  // Case 2: the position match exactly.
1628  if (succeeded(handleInsertOpWithMatchingPos(result)))
1629  return result;
1630 
1631  // Case 3: if the inserted position is a prefix of extractPosition, we can
1632  // just extract a portion of the source of the insert.
1633  if (succeeded(handleInsertOpWithPrefixPos(result)))
1634  return tryToFoldExtractOpInPlace(result);
1635 
1636  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1637  // values. This is a more difficult case and we bail.
1638  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1639  if (isContainedWithin(extractPosition, insertedPos) ||
1640  intersectsWhereNonNegative(extractPosition, insertedPos))
1641  return Value();
1642 
1643  // Case 5: No intersection, we forward the extract to insertOp.dest().
1644  valueToExtractFrom = nextInsertOp.getDest();
1645  updateStateForNextIteration(valueToExtractFrom);
1646  }
1647  // If after all this we can fold, go for it.
1648  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1649 }
1650 
1651 /// Returns true if the operation has a 0-D vector type operand or result.
1652 static bool hasZeroDimVectors(Operation *op) {
1653  auto hasZeroDimVectorType = [](Type type) -> bool {
1654  auto vecType = dyn_cast<VectorType>(type);
1655  return vecType && vecType.getRank() == 0;
1656  };
1657 
1658  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1659  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1660 }
1661 
1662 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
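/// Illustrative example (SSA names and shapes are hypothetical): the position
/// along a "dim-1" broadcasted dimension is reset to 0 before extracting from
/// the broadcast source:
///
/// %b = vector.broadcast %src : vector<1x4xf32> to vector<3x4xf32>
/// %e = vector.extract %b[2, 1] : f32 from vector<3x4xf32>
/// ==> rewritten in place to vector.extract %src[0, 1] : f32 from vector<1x4xf32>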
1663 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1664  Operation *defOp = extractOp.getVector().getDefiningOp();
1665  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1666  return Value();
1667 
1668  Value source = defOp->getOperand(0);
1669  if (extractOp.getType() == source.getType())
1670  return source;
1671  auto getRank = [](Type type) {
1672  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1673  : 0;
1674  };
1675 
1676  // If splat or broadcast from a scalar, just return the source scalar.
1677  unsigned broadcastSrcRank = getRank(source.getType());
1678  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1679  return source;
1680 
1681  unsigned extractResultRank = getRank(extractOp.getType());
1682  if (extractResultRank > broadcastSrcRank)
1683  return Value();
 1684  // Check that the dimensions of the result haven't been broadcast.
1685  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1686  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1687  if (extractVecType && broadcastVecType &&
1688  extractVecType.getShape() !=
1689  broadcastVecType.getShape().take_back(extractResultRank))
1690  return Value();
1691 
1692  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1693  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1694 
1695  // Detect all the positions that come from "dim-1" broadcasting.
 1696  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1697  // extract position to `0` when extracting from the source operand.
1698  llvm::SetVector<int64_t> broadcastedUnitDims =
1699  broadcastOp.computeBroadcastedUnitDims();
1700  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1701  OpBuilder b(extractOp.getContext());
1702  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1703  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1704  if (broadcastedUnitDims.contains(i))
1705  extractPos[i] = b.getIndexAttr(0);
1706  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1707  // matching extract position when extracting from the source operand.
1708  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1709  extractPos.erase(extractPos.begin(),
1710  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1711  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1712  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1713  extractOp->setOperands(
1714  llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1715  extractOp.setStaticPosition(staticPos);
1716  return extractOp.getResult();
1717 }
1718 
1719 /// Fold extractOp coming from ShuffleOp.
1720 ///
1721 /// Example:
1722 ///
1723 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1724 /// : vector<8xf32>, vector<8xf32>
1725 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1726 /// ->
1727 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1728 ///
1729 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1730  // Dynamic positions are not folded as the resulting code would be more
1731  // complex than the input code.
1732  if (extractOp.hasDynamicPosition())
1733  return Value();
1734 
1735  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1736  if (!shuffleOp)
1737  return Value();
1738 
1739  // TODO: 0-D or multi-dimensional vectors not supported yet.
1740  if (shuffleOp.getResultVectorType().getRank() != 1)
1741  return Value();
1742 
1743  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1744  auto shuffleMask = shuffleOp.getMask();
1745  int64_t extractIdx = extractOp.getStaticPosition()[0];
1746  int64_t shuffleIdx = shuffleMask[extractIdx];
1747 
1748  // Find the shuffled vector to extract from based on the shuffle index.
1749  if (shuffleIdx < inputVecSize) {
1750  extractOp.setOperand(0, shuffleOp.getV1());
1751  extractOp.setStaticPosition({shuffleIdx});
1752  } else {
1753  extractOp.setOperand(0, shuffleOp.getV2());
1754  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1755  }
1756 
1757  return extractOp.getResult();
1758 }
1759 
1760 // Fold extractOp with source coming from ShapeCast op.
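// Illustrative example (hypothetical shapes): the extract position is
// linearized w.r.t. the shape_cast result and re-delinearized w.r.t. its
// source:
//
//   %sc = vector.shape_cast %src : vector<2x2x4xf32> to vector<4x4xf32>
//   %e  = vector.extract %sc[2] : vector<4xf32> from vector<4x4xf32>
// ==> rewritten in place to
//   %e  = vector.extract %src[1, 0] : vector<4xf32> from vector<2x2x4xf32>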
1761 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1762  // TODO: Canonicalization for dynamic position not implemented yet.
1763  if (extractOp.hasDynamicPosition())
1764  return Value();
1765 
1766  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1767  if (!shapeCastOp)
1768  return Value();
1769 
1770  // Get the nth dimension size starting from lowest dimension.
1771  auto getDimReverse = [](VectorType type, int64_t n) {
1772  return type.getShape().take_back(n + 1).front();
1773  };
1774  int64_t destinationRank =
1775  llvm::isa<VectorType>(extractOp.getType())
1776  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1777  : 0;
1778  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1779  return Value();
1780  if (destinationRank > 0) {
1781  auto destinationType =
1782  llvm::cast<VectorType>(extractOp.getResult().getType());
1783  for (int64_t i = 0; i < destinationRank; i++) {
1784  // The lowest dimension of the destination must match the lowest
1785  // dimension of the shapecast op source.
 1786  // TODO: This case could be supported by a canonicalization pattern.
1787  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1788  getDimReverse(destinationType, i))
1789  return Value();
1790  }
1791  }
1792  // Extract the strides associated with the extract op vector source. Then use
1793  // this to calculate a linearized position for the extract.
1794  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1795  std::reverse(extractedPos.begin(), extractedPos.end());
1796  SmallVector<int64_t, 4> strides;
1797  int64_t stride = 1;
1798  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1799  strides.push_back(stride);
1800  stride *=
1801  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1802  }
1803 
1804  int64_t position = linearize(extractedPos, strides);
 1805  // Then extract the strides associated with the shapeCast op vector source and
1806  // delinearize the position using those strides.
1807  SmallVector<int64_t, 4> newStrides;
1808  int64_t numDimension =
1809  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1810  stride = 1;
1811  for (int64_t i = 0; i < numDimension; i++) {
1812  newStrides.push_back(stride);
1813  stride *=
1814  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1815  }
1816  std::reverse(newStrides.begin(), newStrides.end());
1817  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1818  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1819  OpBuilder b(extractOp.getContext());
1820  extractOp.setStaticPosition(newPosition);
1821  extractOp.setOperand(0, shapeCastOp.getSource());
1822  return extractOp.getResult();
1823 }
1824 
1825 /// Fold an ExtractOp from ExtractStridedSliceOp.
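/// Illustrative example (hypothetical shapes): with unit strides, the slice
/// offsets are simply added to the extraction position:
///
/// %s = vector.extract_strided_slice %src
///        {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x8xf32> to vector<2x2xf32>
/// %e = vector.extract %s[0, 1] : f32 from vector<2x2xf32>
/// ==> rewritten in place to vector.extract %src[1, 3] : f32 from vector<4x8xf32>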
1826 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1827  // TODO: Canonicalization for dynamic position not implemented yet.
1828  if (extractOp.hasDynamicPosition())
1829  return Value();
1830 
1831  auto extractStridedSliceOp =
1832  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1833  if (!extractStridedSliceOp)
1834  return Value();
1835 
1836  // 0-D vectors not supported.
1837  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1838  if (hasZeroDimVectors(extractStridedSliceOp))
1839  return Value();
1840 
1841  // Return if 'extractStridedSliceOp' has non-unit strides.
1842  if (extractStridedSliceOp.hasNonUnitStrides())
1843  return Value();
1844 
1845  // Trim offsets for dimensions fully extracted.
1846  auto sliceOffsets =
1847  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1848  while (!sliceOffsets.empty()) {
1849  size_t lastOffset = sliceOffsets.size() - 1;
1850  if (sliceOffsets.back() != 0 ||
1851  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1852  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1853  break;
1854  sliceOffsets.pop_back();
1855  }
1856  unsigned destinationRank = 0;
1857  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1858  destinationRank = vecType.getRank();
1859  // The dimensions of the result need to be untouched by the
1860  // extractStridedSlice op.
1861  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1862  sliceOffsets.size())
1863  return Value();
1864 
1865  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1866  assert(extractedPos.size() >= sliceOffsets.size());
1867  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1868  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1869  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1870 
1871  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1872  OpBuilder b(extractOp.getContext());
1873  extractOp.setStaticPosition(extractedPos);
1874  return extractOp.getResult();
1875 }
1876 
1877 /// Fold extract_op fed from a chain of insertStridedSlice ops.
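/// Illustrative example (hypothetical shapes): the extracted slice lies fully
/// inside the inserted chunk, so we extract from the inserted source instead:
///
/// %i = vector.insert_strided_slice %src, %dest
///        {offsets = [2, 0], strides = [1, 1]}
///        : vector<2x4xf32> into vector<8x4xf32>
/// %e = vector.extract %i[3] : vector<4xf32> from vector<8x4xf32>
/// ==> rewritten in place to vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>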
1878 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1879  // TODO: Canonicalization for dynamic position not implemented yet.
1880  if (extractOp.hasDynamicPosition())
1881  return Value();
1882 
1883  int64_t destinationRank =
1884  llvm::isa<VectorType>(extractOp.getType())
1885  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1886  : 0;
1887  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1888  if (!insertOp)
1889  return Value();
1890 
1891  // 0-D vectors not supported.
1892  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1893  if (hasZeroDimVectors(insertOp))
1894  return Value();
1895 
1896  while (insertOp) {
1897  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1898  insertOp.getSourceVectorType().getRank();
1899  if (destinationRank > insertOp.getSourceVectorType().getRank())
1900  return Value();
1901  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1902  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1903 
1904  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1905  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1906  }))
1907  return Value();
1908  bool disjoint = false;
1909  SmallVector<int64_t, 4> offsetDiffs;
1910  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1911  int64_t start = insertOffsets[dim];
1912  int64_t size =
1913  (dim < insertRankDiff)
1914  ? 1
1915  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1916  int64_t end = start + size;
1917  int64_t offset = extractOffsets[dim];
1918  // Check if the start of the extract offset is in the interval inserted.
1919  if (start <= offset && offset < end) {
1920  if (dim >= insertRankDiff)
1921  offsetDiffs.push_back(offset - start);
1922  continue;
1923  }
1924  disjoint = true;
1925  break;
1926  }
 1927  // The extracted chunk overlaps with the inserted vector.
1928  if (!disjoint) {
1929  // If any of the inner dimensions are only partially inserted we have a
1930  // partial overlap.
1931  int64_t srcRankDiff =
1932  insertOp.getSourceVectorType().getRank() - destinationRank;
1933  for (int64_t i = 0; i < destinationRank; i++) {
1934  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1935  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1936  insertRankDiff))
1937  return Value();
1938  }
1939  extractOp.getVectorMutable().assign(insertOp.getSource());
1940  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1941  OpBuilder b(extractOp.getContext());
1942  extractOp.setStaticPosition(offsetDiffs);
1943  return extractOp.getResult();
1944  }
1945  // If the chunk extracted is disjoint from the chunk inserted, keep
1946  // looking in the insert chain.
1947  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1948  }
1949  return Value();
1950 }
1951 
1952 /// Try to fold the extraction of a scalar from a vector defined by
1953 /// vector.from_elements. E.g.:
1954 ///
1955 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1956 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1957 /// ==> fold to %a
1958 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1959  // Dynamic extractions cannot be folded.
1960  if (extractOp.hasDynamicPosition())
1961  return {};
1962 
1963  // Look for extract(from_elements).
1964  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1965  if (!fromElementsOp)
1966  return {};
1967 
1968  // Scalable vectors are not supported.
1969  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1970  if (vecType.isScalable())
1971  return {};
1972 
1973  // Only extractions of scalars are supported.
1974  int64_t rank = vecType.getRank();
1975  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1976  if (extractOp.getType() != vecType.getElementType())
1977  return {};
1978  assert(static_cast<int64_t>(indices.size()) == rank &&
1979  "unexpected number of indices");
1980 
1981  // Compute flattened/linearized index and fold to operand.
1982  int flatIndex = 0;
1983  int stride = 1;
1984  for (int i = rank - 1; i >= 0; --i) {
1985  flatIndex += indices[i] * stride;
1986  stride *= vecType.getDimSize(i);
1987  }
1988  return fromElementsOp.getElements()[flatIndex];
1989 }
1990 
1991 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
1992 /// then fold it.
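/// Illustrative example (hypothetical values), with %c2 = arith.constant 2 : index:
///
/// vector.extract %v[%c2, %i] : f32 from vector<4x4xf32>
/// ==> folds in place to vector.extract %v[2, %i] : f32 from vector<4x4xf32>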
1993 template <typename OpType, typename AdaptorType>
1994 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
1995  SmallVectorImpl<Value> &operands) {
1996  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
1997  OperandRange dynamicPosition = op.getDynamicPosition();
1998  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
1999 
 2000  // If there are no dynamic operands, there is nothing to fold.
2001  if (!dynamicPosition.size())
2002  return {};
2003 
2004  // `index` is used to iterate over the `dynamicPosition`.
2005  unsigned index = 0;
2006 
 2007  // `opChange` records whether `op` should be updated in place.
2008  bool opChange = false;
2009  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2010  if (!ShapedType::isDynamic(staticPosition[i]))
2011  continue;
2012  Attribute positionAttr = dynamicPositionAttr[index];
2013  Value position = dynamicPosition[index++];
2014  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2015  staticPosition[i] = attr.getInt();
2016  opChange = true;
2017  continue;
2018  }
2019  operands.push_back(position);
2020  }
2021 
2022  if (opChange) {
2023  op.setStaticPosition(staticPosition);
2024  op.getOperation()->setOperands(operands);
2025  return op.getResult();
2026  }
2027  return {};
2028 }
2029 
 2030 /// Fold an insert or extract operation into a poison value when a poison index
 2031 /// is found at any dimension of the static position.
 2032 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
 2033  ArrayRef<int64_t> staticPos,
2034  int64_t poisonVal) {
2035  if (!is_contained(staticPos, poisonVal))
2036  return {};
2037 
2038  return ub::PoisonAttr::get(context);
2039 }
2040 
 2041 /// Fold a vector extract whose source is a poison value.
 2042 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
 2043  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2044  return srcAttr;
2045 
2046  return {};
2047 }
2048 
 2049 /// Fold a vector extract that extracts from a DenseElementsAttr source.
 2050 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
 2051  Attribute srcAttr) {
2052  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2053  if (!denseAttr) {
2054  return {};
2055  }
2056 
2057  if (denseAttr.isSplat()) {
2058  Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2059  if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2060  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2061  return newAttr;
2062  }
2063 
2064  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2065  if (vecTy.isScalable())
2066  return {};
2067 
2068  if (extractOp.hasDynamicPosition()) {
2069  return {};
2070  }
2071 
 2072  // Materializing subsets of a large constant array can generally lead to an
 2073  // explosion in IR size because of the different combinations of subsets that
2074  // can exist. However, vector.extract is a restricted form of subset
2075  // extract where you can only extract non-overlapping (or the same) subset for
2076  // a given rank of the subset. Because of this property, the IR size can only
2077  // increase at most by `rank * size(array)` from a single constant array being
2078  // extracted by multiple extracts.
2079 
 2080  // Calculate the linearized position of the contiguous chunk of elements to
2081  // extract.
2082  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2083  copy(extractOp.getStaticPosition(), completePositions.begin());
2084  int64_t startPos =
2085  linearize(completePositions, computeStrides(vecTy.getShape()));
2086  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2087 
2088  TypedAttr newAttr;
2089  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2090  SmallVector<Attribute> elementValues(
2091  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2092  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2093  } else {
2094  newAttr = *denseValuesBegin;
2095  }
2096 
2097  return newAttr;
2098 }
2099 
2100 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2101  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2102  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2103  // mismatch).
2104  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2105  return getVector();
2106  if (auto res = foldPoisonIndexInsertExtractOp(
2107  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2108  return res;
2109  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2110  return res;
2111  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
2112  return res;
2113  if (succeeded(foldExtractOpFromExtractChain(*this)))
2114  return getResult();
2115  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2116  return res;
2117  if (auto res = foldExtractFromBroadcast(*this))
2118  return res;
2119  if (auto res = foldExtractFromShuffle(*this))
2120  return res;
2121  if (auto res = foldExtractFromShapeCast(*this))
2122  return res;
2123  if (auto val = foldExtractFromExtractStrided(*this))
2124  return val;
2125  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2126  return val;
2127  if (auto val = foldScalarExtractFromFromElements(*this))
2128  return val;
2129  SmallVector<Value> operands = {getVector()};
2130  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2131  return val;
2132  return OpFoldResult();
2133 }
2134 
2135 namespace {
2136 
 2137 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
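// Illustrative example (hypothetical shapes), for the case where the extract
// result rank is at least the broadcast source rank:
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// is rewritten to:
//   %e = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>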
2138 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2139 public:
2141 
2142  LogicalResult matchAndRewrite(ExtractOp extractOp,
2143  PatternRewriter &rewriter) const override {
2144  Operation *defOp = extractOp.getVector().getDefiningOp();
2145  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2146  return failure();
2147 
2148  Value source = defOp->getOperand(0);
2149  if (extractOp.getType() == source.getType())
2150  return failure();
2151  auto getRank = [](Type type) {
2152  return llvm::isa<VectorType>(type)
2153  ? llvm::cast<VectorType>(type).getRank()
2154  : 0;
2155  };
2156  unsigned broadcastSrcRank = getRank(source.getType());
2157  unsigned extractResultRank = getRank(extractOp.getType());
2158  // We only consider the case where the rank of the source is less than or
2159  // equal to the rank of the extract dst. The other cases are handled in the
2160  // folding patterns.
2161  if (extractResultRank < broadcastSrcRank)
2162  return failure();
 2163  // For a scalar result, the input can only be a rank-0 vector, which will
 2164  // be handled by the folder.
2165  if (extractResultRank == 0)
2166  return failure();
2167 
2168  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2169  extractOp, extractOp.getType(), source);
2170  return success();
2171  }
2172 };
2173 
 2174 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
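// Illustrative example (hypothetical operands), with %c2 = arith.constant 2 : index,
// %c3 = arith.constant 3 : index, and %d a dynamic value:
//
//   %m = vector.create_mask %c2, %c3, %d : vector<4x4x8xi1>
//   %e = vector.extract %m[1] : vector<4x8xi1> from vector<4x4x8xi1>
// is rewritten to:
//   %e = vector.create_mask %c3, %d : vector<4x8xi1>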
2175 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2176 public:
2178 
2179  LogicalResult matchAndRewrite(ExtractOp extractOp,
2180  PatternRewriter &rewriter) const override {
2181  auto createMaskOp =
2182  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2183  if (!createMaskOp)
2184  return failure();
2185 
2186  VectorType extractedMaskType =
2187  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2188 
2189  if (!extractedMaskType)
2190  return failure();
2191 
2192  auto maskOperands = createMaskOp.getOperands();
2193  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2194  VectorType maskType = createMaskOp.getVectorType();
2195 
2196  bool containsUnknownDims = false;
2197  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2198 
2199  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2200  dimIdx++) {
2201  int64_t pos = extractOpPos[dimIdx];
2202  Value operand = maskOperands[dimIdx];
2203  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2204  if (!constantOp) {
2205  // Bounds of this dim unknown.
2206  containsUnknownDims = true;
2207  continue;
2208  }
2209 
2210  int64_t createMaskBound =
2211  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2212 
2213  if (pos != ShapedType::kDynamic) {
2214  // If any position is outside the range from the `create_mask`, then the
2215  // extracted mask will be all-false.
2216  allFalse |= pos >= createMaskBound;
2217  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2218  // This dim is not all-true and since this is a dynamic index we don't
2219  // know if the extraction is within the true or false region.
 2220  // Note: Zero dims have already been handled via getMaskFormat().
2221  containsUnknownDims = true;
2222  }
2223  }
2224 
2225  if (allFalse) {
2226  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2227  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2228  } else if (!containsUnknownDims) {
2229  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2230  extractOp, extractedMaskType,
2231  maskOperands.drop_front(extractOpPos.size()));
2232  } else {
2233  return failure();
2234  }
2235  return success();
2236  }
2237 };
2238 
2239 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2240 // does not change.
2241 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2242  PatternRewriter &rewriter) {
2243  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2244  if (!castOp)
2245  return failure();
2246 
2247  VectorType sourceType = castOp.getSourceVectorType();
2248  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2249  if (!targetType)
2250  return failure();
2251 
2252  if (sourceType.getNumElements() != targetType.getNumElements())
2253  return failure();
2254 
2255  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2256  castOp.getSource());
2257  return success();
2258 }
2259 
2260 /// Try to canonicalize the extraction of a subvector from a vector defined by
2261 /// vector.from_elements. E.g.:
2262 ///
2263 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2264 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2265 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2266 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2267  PatternRewriter &rewriter) {
2268  // Dynamic positions are not supported.
2269  if (extractOp.hasDynamicPosition())
2270  return failure();
2271 
2272  // Scalar extracts are handled by the folder.
2273  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2274  if (!resultType)
2275  return failure();
2276 
2277  // Look for extracts from a from_elements op.
2278  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2279  if (!fromElementsOp)
2280  return failure();
2281  VectorType inputType = fromElementsOp.getType();
2282 
2283  // Scalable vectors are not supported.
2284  if (resultType.isScalable() || inputType.isScalable())
2285  return failure();
2286 
 2287  // Compute the position of the first extracted element and flatten/linearize
 2288  // the position.
2289  SmallVector<int64_t> firstElementPos =
2290  llvm::to_vector(extractOp.getStaticPosition());
2291  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2292  int flatIndex = 0;
2293  int stride = 1;
2294  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2295  flatIndex += firstElementPos[i] * stride;
2296  stride *= inputType.getDimSize(i);
2297  }
2298 
2299  // Replace the op with a smaller from_elements op.
2300  rewriter.replaceOpWithNewOp<FromElementsOp>(
2301  extractOp, resultType,
2302  fromElementsOp.getElements().slice(flatIndex,
2303  resultType.getNumElements()));
2304  return success();
2305 }
2306 
2307 } // namespace
2308 
2309 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2310  MLIRContext *context) {
2311  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2312  results.add(foldExtractFromShapeCastToShapeCast);
2313  results.add(foldExtractFromFromElements);
2314 }
2315 
2316 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2317  SmallVectorImpl<int64_t> &results) {
2318  for (auto attr : arrayAttr)
2319  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2320 }
2321 
2322 //===----------------------------------------------------------------------===//
2323 // FmaOp
2324 //===----------------------------------------------------------------------===//
2325 
2326 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2327  return llvm::to_vector<4>(getVectorType().getShape());
2328 }
2329 
2330 //===----------------------------------------------------------------------===//
2331 // FromElementsOp
2332 //===----------------------------------------------------------------------===//
2333 
2334 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2335 /// same SSA value. E.g.:
2336 ///
2337 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2338 /// ==> rewrite to vector.splat %a : vector<3xf32>
2339 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2340  PatternRewriter &rewriter) {
2341  if (!llvm::all_equal(fromElementsOp.getElements()))
2342  return failure();
2343  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2344  fromElementsOp.getElements().front());
2345  return success();
2346 }
2347 
2348 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2349  MLIRContext *context) {
 2350  results.add(rewriteFromElementsAsSplat);
 2351 }
2352 
2353 //===----------------------------------------------------------------------===//
2354 // BroadcastOp
2355 //===----------------------------------------------------------------------===//
2356 
2357 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2358  SetIntRangeFn setResultRanges) {
2359  setResultRanges(getResult(), argRanges.front());
2360 }
2361 
2362 /// Return the dimensions of the result vector that were formerly ones in the
 2363 /// source vector and thus correspond to "dim-1" broadcasting.
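/// Illustrative example (hypothetical shapes): with srcShape = [1, 4] and
/// dstShape = [3, 2, 4], dimension 0 of dstShape is a new leading dimension
/// and dimension 1 is a "dim-1" broadcast, so this returns {1}.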
 2364 static llvm::SetVector<int64_t>
 2365 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
 2366  ArrayRef<int64_t> dstShape) {
2367  int64_t rankDiff = dstShape.size() - srcShape.size();
2368  int64_t dstDim = rankDiff;
 2369  llvm::SetVector<int64_t> res;
 2370  for (auto [s1, s2] :
2371  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2372  if (s1 != s2) {
2373  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2374  res.insert(dstDim);
2375  }
2376  ++dstDim;
2377  }
2378  return res;
2379 }
2380 
 2381 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
 2382  // A scalar broadcast does not involve any "dim-1" broadcasting.
2383  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2384  if (!srcVectorType)
2385  return {};
2386  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2387  getResultVectorType().getShape());
2388 }
2389 
2390 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2391 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2392 /// This requires (and asserts) that the broadcast is free of "dim-1"
2393 /// broadcasting.
2394 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2395 /// vector.transpose may be inserted to make the broadcast possible.
2396 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2397 /// the helper will assert. This means:
2398 /// 1. `dstShape` must not be empty.
2399 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
 2400 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
 2401 ///    must match the `value` shape.
2402 Value BroadcastOp::createOrFoldBroadcastOp(
2403  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2404  const llvm::SetVector<int64_t> &broadcastedDims) {
2405  assert(!dstShape.empty() && "unexpected empty dst shape");
2406 
 2407  // Step 1. Well-formedness check.
2408  SmallVector<int64_t> checkShape;
2409  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2410  if (broadcastedDims.contains(i))
2411  continue;
2412  checkShape.push_back(dstShape[i]);
2413  }
2414  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2415  "ill-formed broadcastedDims contains values not confined to "
2416  "destVectorShape");
2417 
2418  Location loc = value.getLoc();
2419  Type elementType = getElementTypeOrSelf(value.getType());
2420  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2421  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2422 
2423  // Step 2. If scalar -> dstShape broadcast, just do it.
2424  if (!srcVectorType) {
2425  assert(checkShape.empty() &&
2426  "ill-formed createOrFoldBroadcastOp arguments");
2427  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2428  }
2429 
2430  assert(srcVectorType.getShape().equals(checkShape) &&
2431  "ill-formed createOrFoldBroadcastOp arguments");
2432 
2433  // Step 3. Since vector.broadcast only allows creating leading dims,
2434  // vector -> dstShape broadcast may require a transpose.
2435  // Traverse the dims in order and construct:
2436  // 1. The leading entries of the broadcastShape that is guaranteed to be
2437  // achievable by a simple broadcast.
2438  // 2. The induced permutation for the subsequent vector.transpose that will
 2439  // bring us from `broadcastShape` back to the desired `dstShape`.
2440  // If the induced permutation is not the identity, create a vector.transpose.
2441  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2442  broadcastShape.reserve(dstShape.size());
2443  // Consider the example:
2444  // srcShape = 2x4
2445  // dstShape = 1x2x3x4x5
2446  // broadcastedDims = [0, 2, 4]
2447  //
2448  // We want to build:
2449  // broadcastShape = 1x3x5x2x4
2450  // permutation = [0, 2, 4, 1, 3]
2451  // ---V--- -----V-----
2452  // leading broadcast part src shape part
2453  //
2454  // Note that the trailing dims of broadcastShape are exactly the srcShape
2455  // by construction.
2456  // nextSrcShapeDim is used to keep track of where in the permutation the
2457  // "src shape part" occurs.
2458  int64_t nextSrcShapeDim = broadcastedDims.size();
2459  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2460  if (broadcastedDims.contains(i)) {
2461  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2462  // bring it to the head of the broadcastShape.
2463  // It will need to be permuted back from `broadcastShape.size() - 1` into
2464  // position `i`.
2465  broadcastShape.push_back(dstShape[i]);
2466  permutation[i] = broadcastShape.size() - 1;
2467  } else {
2468  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2469  // shape and needs to be permuted into position `i`.
2470  // Don't touch `broadcastShape` here, the whole srcShape will be
2471  // appended after.
2472  permutation[i] = nextSrcShapeDim++;
2473  }
2474  }
2475  // 3.c. Append the srcShape.
2476  llvm::append_range(broadcastShape, srcVectorType.getShape());
2477 
2478  // Ensure there are no "dim-1" broadcasts.
2479  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2480  .empty() &&
2481  "unexpected \"dim-1\" broadcast");
2482 
2483  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2484  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2485  vector::BroadcastableToResult::Success &&
2486  "must be broadcastable");
2487  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2488  // Step 4. If we find any dimension that indeed needs to be permuted,
2489  // immediately return a new vector.transpose.
2490  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2491  if (permutation[i] != i)
2492  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2493  // Otherwise return res.
2494  return res;
2495 }
2496 
 2497 BroadcastableToResult mlir::vector::isBroadcastableTo(
 2498  Type srcType, VectorType dstVectorType,
2499  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2500  // Broadcast scalar to vector of the same element type.
2501  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2502  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2503  return BroadcastableToResult::Success;
2504  // From now on, only vectors broadcast.
2505  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2506  if (!srcVectorType)
2507  return BroadcastableToResult::SourceTypeNotAVector;
2508 
2509  int64_t srcRank = srcVectorType.getRank();
2510  int64_t dstRank = dstVectorType.getRank();
2511  if (srcRank > dstRank)
2512  return BroadcastableToResult::SourceRankHigher;
2513  // Source has an exact match or singleton value for all trailing dimensions
2514  // (all leading dimensions are simply duplicated).
2515  int64_t lead = dstRank - srcRank;
2516  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2517  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2518  // encountered?
2519  bool foundMismatchingDims = false;
2520 
2521  // Check fixed-width dims.
2522  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2523  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2524  if (srcDim != 1 && srcDim != dstDim)
2525  foundMismatchingDims = true;
2526 
2527  // Check scalable flags.
2528  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2529  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2530  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2531  // 1 -> [N] is fine, everything else should be rejected when mixing
2532  // fixed-width and scalable dims
2533  (srcDimScalableFlag != dstDimScalableFlag &&
2534  (srcDim != 1 || srcDimScalableFlag)))
2535  foundMismatchingDims = true;
2536 
2537  if (foundMismatchingDims) {
2538  if (mismatchingDims != nullptr) {
2539  mismatchingDims->first.dim = srcDim;
2540  mismatchingDims->first.isScalable = srcDimScalableFlag;
2541 
2542  mismatchingDims->second.dim = dstDim;
2543  mismatchingDims->second.isScalable = dstDimScalableFlag;
2544  }
2545  return BroadcastableToResult::DimensionMismatch;
2546  }
2547  }
2548 
2549  return BroadcastableToResult::Success;
2550 }
2551 
2552 LogicalResult BroadcastOp::verify() {
2553  std::pair<VectorDim, VectorDim> mismatchingDims;
 2554  BroadcastableToResult res = isBroadcastableTo(
 2555  getSourceType(), getResultVectorType(), &mismatchingDims);
2556  if (res == BroadcastableToResult::Success)
2557  return success();
2558  if (res == BroadcastableToResult::SourceRankHigher)
2559  return emitOpError("source rank higher than destination rank");
2560  if (res == BroadcastableToResult::DimensionMismatch) {
2561  return emitOpError("dimension mismatch (")
2562  << (mismatchingDims.first.isScalable ? "[" : "")
2563  << mismatchingDims.first.dim
2564  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2565  << (mismatchingDims.second.isScalable ? "[" : "")
2566  << mismatchingDims.second.dim
2567  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2568  }
2569  if (res == BroadcastableToResult::SourceTypeNotAVector)
2570  return emitOpError("source type is not a vector");
2571  llvm_unreachable("unexpected vector.broadcast op error");
2572 }
2573 
2574 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2575  if (getSourceType() == getResultVectorType())
2576  return getSource();
2577  if (!adaptor.getSource())
2578  return {};
2579  auto vectorType = getResultVectorType();
2580  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2581  if (vectorType.getElementType() != attr.getType())
2582  return {};
2583  return DenseElementsAttr::get(vectorType, attr);
2584  }
2585  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2586  if (vectorType.getElementType() != attr.getType())
2587  return {};
2588  return DenseElementsAttr::get(vectorType, attr);
2589  }
2590  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2591  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2592  return {};
2593 }
2594 
2595 namespace {
2596 
2597 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
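// Illustrative example (hypothetical types):
//   %b0 = vector.broadcast %src : f32 to vector<4xf32>
//   %b1 = vector.broadcast %b0 : vector<4xf32> to vector<2x4xf32>
// is rewritten to:
//   %b1 = vector.broadcast %src : f32 to vector<2x4xf32>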
2598 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2600 
2601  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2602  PatternRewriter &rewriter) const override {
2603  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2604  if (!srcBroadcast)
2605  return failure();
2606  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2607  broadcastOp.getResultVectorType(),
2608  srcBroadcast.getSource());
2609  return success();
2610  }
2611 };
2612 } // namespace
2613 
2614 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2615  MLIRContext *context) {
2616  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2617  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2618  results.add<BroadcastFolder>(context);
2619 }
2620 
2621 //===----------------------------------------------------------------------===//
2622 // ShuffleOp
2623 //===----------------------------------------------------------------------===//
2624 
2625 LogicalResult ShuffleOp::verify() {
2626  VectorType resultType = getResultVectorType();
2627  VectorType v1Type = getV1VectorType();
2628  VectorType v2Type = getV2VectorType();
2629  // Verify ranks.
2630  int64_t resRank = resultType.getRank();
2631  int64_t v1Rank = v1Type.getRank();
2632  int64_t v2Rank = v2Type.getRank();
2633  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2634  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2635  if (!wellFormed0DCase && !wellFormedNDCase)
2636  return emitOpError("rank mismatch");
2637 
2638  // Verify all but leading dimension sizes.
2639  for (int64_t r = 1; r < v1Rank; ++r) {
2640  int64_t resDim = resultType.getDimSize(r);
2641  int64_t v1Dim = v1Type.getDimSize(r);
2642  int64_t v2Dim = v2Type.getDimSize(r);
2643  if (resDim != v1Dim || v1Dim != v2Dim)
2644  return emitOpError("dimension mismatch");
2645  }
2646  // Verify mask length.
2647  ArrayRef<int64_t> mask = getMask();
2648  int64_t maskLength = mask.size();
2649  if (maskLength <= 0)
2650  return emitOpError("invalid mask length");
2651  if (maskLength != resultType.getDimSize(0))
2652  return emitOpError("mask length mismatch");
2653  // Verify all indices.
2654  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2655  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2656  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2657  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2658  return emitOpError("mask index #") << (idx + 1) << " out of range";
2659  }
2660  return success();
2661 }
2662 
2663 LogicalResult
2664 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2665  ShuffleOp::Adaptor adaptor,
2666  SmallVectorImpl<Type> &inferredReturnTypes) {
2667  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2668  auto v1Rank = v1Type.getRank();
2669  // Construct resulting type: leading dimension matches mask
2670  // length, all trailing dimensions match the operands.
 2671  SmallVector<int64_t> shape;
 2672  shape.reserve(v1Rank);
2673  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2674  // In the 0-D case there is no trailing shape to append.
2675  if (v1Rank > 0)
2676  llvm::append_range(shape, v1Type.getShape().drop_front());
2677  inferredReturnTypes.push_back(
2678  VectorType::get(shape, v1Type.getElementType()));
2679  return success();
2680 }
2681 
2682 template <typename T>
2683 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2684  T expected = begin;
2685  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2686  return value == expected++;
2687  });
2688 }
2689 
2690 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2691  auto v1Type = getV1VectorType();
2692  auto v2Type = getV2VectorType();
2693 
2694  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2695  "Vector shuffle does not support scalable vectors");
2696 
 2697  // For consistency: the 0-D shuffle result type is 1-D, so this cannot be a
 2698  // fold and must instead be a canonicalization into a vector.broadcast.
2699  if (v1Type.getRank() == 0)
2700  return {};
2701 
2702  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2703  auto mask = getMask();
2704  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
2705  return getV1();
2706  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2707  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
2708  return getV2();
2709 
2710  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2711  if (!v1Attr || !v2Attr)
2712  return {};
2713 
2714  // Fold shuffle poison, poison -> poison.
2715  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2716  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2717  if (isV1Poison && isV2Poison)
2718  return ub::PoisonAttr::get(getContext());
2719 
2720  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2721  // manipulation.
2722  if (v1Type.getRank() != 1)
2723  return {};
2724 
2725  // Poison input attributes need special handling as they are not
2726  // DenseElementsAttr. If an index is poison, we select the first element of
2727  // the first non-poison input.
2728  SmallVector<Attribute> v1Elements, v2Elements;
2729  Attribute poisonElement;
2730  if (!isV2Poison) {
2731  v2Elements =
2732  to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2733  poisonElement = v2Elements[0];
2734  }
2735  if (!isV1Poison) {
2736  v1Elements =
2737  to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2738  poisonElement = v1Elements[0];
2739  }
2740 
2741  SmallVector<Attribute> results;
2742  int64_t v1Size = v1Type.getDimSize(0);
2743  for (int64_t maskIdx : mask) {
2744  Attribute indexedElm;
2745  // TODO: Return a partial poison vector when supported by the UB dialect.
2746  if (maskIdx == ShuffleOp::kPoisonIndex) {
2747  indexedElm = poisonElement;
2748  } else {
2749  if (maskIdx < v1Size)
2750  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2751  else
2752  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2753  }
2754 
2755  results.push_back(indexedElm);
2756  }
2757 
2758  return DenseElementsAttr::get(getResultVectorType(), results);
2759 }
2760 
2761 namespace {
2762 
2763 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2764 // to a broadcast.
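// Illustrative example (hypothetical values):
//   %r = vector.shuffle %a, %b [1] : vector<f32>, vector<f32>
// is rewritten to:
//   %r = vector.broadcast %b : vector<f32> to vector<1xf32>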
2765 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2767 
2768  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2769  PatternRewriter &rewriter) const override {
2770  VectorType v1VectorType = shuffleOp.getV1VectorType();
2771  ArrayRef<int64_t> mask = shuffleOp.getMask();
2772  if (v1VectorType.getRank() > 0)
2773  return failure();
2774  if (mask.size() != 1)
2775  return failure();
2776  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2777  if (mask[0] == 0)
2778  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2779  shuffleOp.getV1());
2780  else
2781  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2782  shuffleOp.getV2());
2783  return success();
2784  }
2785 };
2786 
2787 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
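/// Illustrative example (hypothetical values), where both operands splat the
/// same value:
///
/// %s = vector.splat %x : vector<4xf32>
/// %r = vector.shuffle %s, %s [0, 3, 1, 2] : vector<4xf32>, vector<4xf32>
/// ==> rewritten to vector.splat %x : vector<4xf32>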
2788 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2789 public:
2791 
2792  LogicalResult matchAndRewrite(ShuffleOp op,
2793  PatternRewriter &rewriter) const override {
2794  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2795  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2796 
2797  if (!v1Splat || !v2Splat)
2798  return failure();
2799 
2800  if (v1Splat.getInput() != v2Splat.getInput())
2801  return failure();
2802 
2803  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2804  return success();
2805  }
2806 };
2807 
2808 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2809 /// vector.interleave.
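/// Illustrative example (hypothetical values; the interleave syntax shown is
/// a sketch):
///
/// %r = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///        : vector<4xf32>, vector<4xf32>
/// ==> rewritten to %r = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>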
2810 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2811 public:
2813 
2814  LogicalResult matchAndRewrite(ShuffleOp op,
2815  PatternRewriter &rewriter) const override {
2816  VectorType resultType = op.getResultVectorType();
2817  if (resultType.isScalable())
2818  return rewriter.notifyMatchFailure(
2819  op, "ShuffleOp can't represent a scalable interleave");
2820 
2821  if (resultType.getRank() != 1)
2822  return rewriter.notifyMatchFailure(
2823  op, "ShuffleOp can't represent an n-D interleave");
2824 
2825  VectorType sourceType = op.getV1VectorType();
2826  if (sourceType != op.getV2VectorType() ||
2827  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2828  return rewriter.notifyMatchFailure(
2829  op, "ShuffleOp types don't match an interleave");
2830  }
2831 
2832  ArrayRef<int64_t> shuffleMask = op.getMask();
2833  int64_t resultVectorSize = resultType.getNumElements();
2834  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2835  int64_t maskValueA = shuffleMask[i * 2];
2836  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2837  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2838  return rewriter.notifyMatchFailure(op,
2839  "ShuffleOp mask not interleaving");
2840  }
2841 
2842  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2843  return success();
2844  }
2845 };
2846 
2847 } // namespace
2848 
2849 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2850  MLIRContext *context) {
2851  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2852  context);
2853 }
2854 
2855 //===----------------------------------------------------------------------===//
2856 // InsertElementOp
2857 //===----------------------------------------------------------------------===//
2858 
2859 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2860  SetIntRangeFn setResultRanges) {
2861  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2862 }
2863 
2864 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2865  Value source, Value dest) {
2866  build(builder, result, source, dest, {});
2867 }
2868 
2869 LogicalResult InsertElementOp::verify() {
2870  auto dstVectorType = getDestVectorType();
2871  if (dstVectorType.getRank() == 0) {
2872  if (getPosition())
2873  return emitOpError("expected position to be empty with 0-D vector");
2874  return success();
2875  }
2876  if (dstVectorType.getRank() != 1)
2877  return emitOpError("unexpected >1 vector rank");
2878  if (!getPosition())
2879  return emitOpError("expected position for 1-D vector");
2880  return success();
2881 }
2882 
2883 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2884  // Skip the 0-D vector here.
2885  if (!adaptor.getPosition())
2886  return {};
2887 
2888  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2889  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2890  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2891  if (!src || !dst || !pos)
2892  return {};
2893 
2894  if (src.getType() != getDestVectorType().getElementType())
2895  return {};
2896 
2897  auto dstElements = dst.getValues<Attribute>();
2898 
2899  SmallVector<Attribute> results(dstElements);
2900 
2901  uint64_t posIdx = pos.getInt();
2902  if (posIdx >= results.size())
2903  return {};
2904  results[posIdx] = src;
2905 
2906  return DenseElementsAttr::get(getDestVectorType(), results);
2907 }
2908 
2909 //===----------------------------------------------------------------------===//
2910 // InsertOp
2911 //===----------------------------------------------------------------------===//
2912 
2913 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2914  SetIntRangeFn setResultRanges) {
2915  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2916 }
2917 
2918 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2919  Value source, Value dest) {
2920  auto vectorTy = cast<VectorType>(dest.getType());
2921  build(builder, result, source, dest,
2922  SmallVector<int64_t>(vectorTy.getRank(), 0));
2923 }
2924 
2925 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2926  Value source, Value dest, int64_t position) {
2927  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2928 }
2929 
2930 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2931  Value source, Value dest, OpFoldResult position) {
2932  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2933 }
2934 
2935 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2936  Value source, Value dest,
2937  ArrayRef<int64_t> position) {
2938  SmallVector<OpFoldResult> posVals;
2939  posVals.reserve(position.size());
2940  llvm::transform(position, std::back_inserter(posVals),
2941  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2942  build(builder, result, source, dest, posVals);
2943 }
2944 
2945 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2946  Value source, Value dest,
2947  ArrayRef<OpFoldResult> position) {
2948  SmallVector<int64_t> staticPos;
2949  SmallVector<Value> dynamicPos;
2950  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2951  build(builder, result, source, dest, dynamicPos,
2952  builder.getDenseI64ArrayAttr(staticPos));
2953 }
2954 
2955 LogicalResult InsertOp::verify() {
2956  SmallVector<OpFoldResult> position = getMixedPosition();
2957  auto destVectorType = getDestVectorType();
2958  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2959  return emitOpError(
2960  "expected position attribute of rank no greater than dest vector rank");
2961  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2962  if (srcVectorType &&
2963  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2964  static_cast<unsigned>(destVectorType.getRank())))
2965  return emitOpError("expected position attribute rank + source rank to "
2966  "match dest vector rank");
2967  if (!srcVectorType &&
2968  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2969  return emitOpError(
2970  "expected position attribute rank to match the dest vector rank");
2971  for (auto [idx, pos] : llvm::enumerate(position)) {
2972  if (auto attr = pos.dyn_cast<Attribute>()) {
2973  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2974  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2975  destVectorType.getDimSize(idx))) {
2976  return emitOpError("expected position attribute #")
2977  << (idx + 1)
2978  << " to be a non-negative integer smaller than the "
2979  "corresponding "
2980  "dest vector dimension";
2981  }
2982  }
2983  }
2984  return success();
2985 }
2986 
2987 namespace {
2988 
2989 // If insertOp is only inserting unit dimensions, it can be transformed into a
2990 // broadcast.
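// For example (illustrative types), inserting a vector with the same number of
// elements as the destination:
//   %0 = vector.insert %src, %dst[0] : vector<4xf32> into vector<1x4xf32>
// can be rewritten as
//   %0 = vector.broadcast %src : vector<4xf32> to vector<1x4xf32>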
2991 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2992 public:
2993  using OpRewritePattern::OpRewritePattern;
2994 
2995  LogicalResult matchAndRewrite(InsertOp insertOp,
2996  PatternRewriter &rewriter) const override {
2997  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2998  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2999  srcVecType.getNumElements())
3000  return failure();
3001  rewriter.replaceOpWithNewOp<BroadcastOp>(
3002  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
3003  return success();
3004  }
3005 };
3006 
3007 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
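/// An illustrative example (types chosen arbitrarily):
///   %0 = vector.splat %x : vector<4xf32>
///   %1 = vector.splat %x : vector<2x4xf32>
///   %2 = vector.insert %0, %1[0] : vector<4xf32> into vector<2x4xf32>
/// is rewritten to
///   %2 = vector.splat %x : vector<2x4xf32>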
3008 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3009 public:
3010  using OpRewritePattern::OpRewritePattern;
3011 
3012  LogicalResult matchAndRewrite(InsertOp op,
3013  PatternRewriter &rewriter) const override {
3014  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
3015  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3016 
3017  if (!srcSplat || !dstSplat)
3018  return failure();
3019 
3020  if (srcSplat.getInput() != dstSplat.getInput())
3021  return failure();
3022 
3023  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3024  return success();
3025  }
3026 };
3027 
3028 } // namespace
3029 
3030 static Attribute
3031 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3032  Attribute dstAttr,
3033  int64_t maxVectorSizeFoldThreshold) {
3034  if (insertOp.hasDynamicPosition())
3035  return {};
3036 
3037  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3038  if (!denseDst)
3039  return {};
3040 
3041  if (!srcAttr) {
3042  return {};
3043  }
3044 
3045  VectorType destTy = insertOp.getDestVectorType();
3046  if (destTy.isScalable())
3047  return {};
3048 
3049  // Make sure we do not create too many large constants.
3050  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3051  !insertOp->hasOneUse())
3052  return {};
3053 
3054  // Calculate the linearized position of the contiguous chunk of elements to
3055  // insert.
3056  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3057  copy(insertOp.getStaticPosition(), completePositions.begin());
3058  int64_t insertBeginPosition =
3059  linearize(completePositions, computeStrides(destTy.getShape()));
3060 
3061  SmallVector<Attribute> insertedValues;
3062  Type destEltType = destTy.getElementType();
3063 
3064  /// Converts `attr` to an IntegerAttr of `expectedType` if there's
3065  /// a type mismatch.
3066  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
3067  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3068  if (intAttr.getType() != expectedType)
3069  return IntegerAttr::get(expectedType, intAttr.getInt());
3070  }
3071  return attr;
3072  };
3073 
3074  // The `convertIntegerAttr` method specifically handles the case
3075  // for `llvm.mlir.constant` which can hold an attribute with a
3076  // different type than the return type.
3077  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3078  for (auto value : denseSource.getValues<Attribute>())
3079  insertedValues.push_back(convertIntegerAttr(value, destEltType));
3080  } else {
3081  insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
3082  }
3083 
3084  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3085  copy(insertedValues, allValues.begin() + insertBeginPosition);
3086  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3087 
3088  return newAttr;
3089 }
3090 
3091 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3092  MLIRContext *context) {
3093  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3094 }
3095 
3096 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3097  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3098  // unless the source vector constant has a single use.
3099  constexpr int64_t vectorSizeFoldThreshold = 256;
3100  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
3101  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3102  // (type mismatch).
3103  if (getNumIndices() == 0 && getSourceType() == getType())
3104  return getSource();
3105  SmallVector<Value> operands = {getSource(), getDest()};
3106  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3107  return val;
3108  if (auto res = foldPoisonIndexInsertExtractOp(
3109  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3110  return res;
3111  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
3112  adaptor.getDest(),
3113  vectorSizeFoldThreshold)) {
3114  return res;
3115  }
3116 
3117  return {};
3118 }
3119 
3120 //===----------------------------------------------------------------------===//
3121 // InsertStridedSliceOp
3122 //===----------------------------------------------------------------------===//
3123 
3124 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3125  Value source, Value dest,
3126  ArrayRef<int64_t> offsets,
3127  ArrayRef<int64_t> strides) {
3128  result.addOperands({source, dest});
3129  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3130  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3131  result.addTypes(dest.getType());
3132  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3133  offsetsAttr);
3134  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3135  stridesAttr);
3136 }
3137 
3138 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3139 template <typename OpType>
3140 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3141  ArrayAttr arrayAttr,
3142  ArrayRef<int64_t> shape,
3143  StringRef attrName) {
3144  if (arrayAttr.size() > shape.size())
3145  return op.emitOpError("expected ")
3146  << attrName << " attribute of rank no greater than vector rank";
3147  return success();
3148 }
3149 
3150 // Returns success if all integers in `arrayAttr` are confined to the interval
3151 // [min, max). If `halfOpen` is false, the admissible interval is [min, max]
3152 // instead.
3153 template <typename OpType>
3154 static LogicalResult
3155 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3156  int64_t max, StringRef attrName,
3157  bool halfOpen = true) {
3158  for (auto attr : arrayAttr) {
3159  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3160  auto upper = max;
3161  if (!halfOpen)
3162  upper += 1;
3163  if (val < min || val >= upper)
3164  return op.emitOpError("expected ") << attrName << " to be confined to ["
3165  << min << ", " << upper << ")";
3166  }
3167  return success();
3168 }
3169 
3170 // Returns success if each integer in `arrayAttr` is confined to the interval
3171 // [min, shape[i]) for its dimension i. If `halfOpen` is false, the admissible
3172 // interval is [min, shape[i]] instead.
3173 template <typename OpType>
3174 static LogicalResult
3175 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3176  ArrayRef<int64_t> shape, StringRef attrName,
3177  bool halfOpen = true, int64_t min = 0) {
3178  for (auto [index, attrDimPair] :
3179  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3180  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3181  int64_t max = std::get<1>(attrDimPair);
3182  if (!halfOpen)
3183  max += 1;
3184  if (val < min || val >= max)
3185  return op.emitOpError("expected ")
3186  << attrName << " dimension " << index << " to be confined to ["
3187  << min << ", " << max << ")";
3188  }
3189  return success();
3190 }
3191 
3192 // Returns success if, for each index i in 0..shape.size()-1, the sum
3193 // `arrayAttr1[i]` + `arrayAttr2[i]` is confined to the interval
3194 // [min, shape[i]).
3195 // If `halfOpen` is false, the admissible interval is [min, shape[i]]
3196 // instead.
3197 template <typename OpType>
3198 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3199  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3200  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3201  bool halfOpen = true, int64_t min = 1) {
3202  assert(arrayAttr1.size() <= shape.size());
3203  assert(arrayAttr2.size() <= shape.size());
3204  for (auto [index, it] :
3205  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3206  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3207  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3208  int64_t max = std::get<2>(it);
3209  if (!halfOpen)
3210  max += 1;
3211  if (val1 + val2 < 0 || val1 + val2 >= max)
3212  return op.emitOpError("expected sum(")
3213  << attrName1 << ", " << attrName2 << ") dimension " << index
3214  << " to be confined to [" << min << ", " << max << ")";
3215  }
3216  return success();
3217 }
3218 
3219 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3220  MLIRContext *context) {
3221  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3222  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3223  });
3224  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3225 }
3226 
3227 LogicalResult InsertStridedSliceOp::verify() {
3228  auto sourceVectorType = getSourceVectorType();
3229  auto destVectorType = getDestVectorType();
3230  auto offsets = getOffsetsAttr();
3231  auto strides = getStridesAttr();
3232  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3233  return emitOpError(
3234  "expected offsets of same size as destination vector rank");
3235  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3236  return emitOpError("expected strides of same size as source vector rank");
3237  if (sourceVectorType.getRank() > destVectorType.getRank())
3238  return emitOpError(
3239  "expected source rank to be no greater than destination rank");
3240 
3241  auto sourceShape = sourceVectorType.getShape();
3242  auto destShape = destVectorType.getShape();
3243  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3244  destShape.size() - sourceShape.size(), 0);
3245  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3246  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3247  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3248  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3249  offName)) ||
3250  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3251  /*max=*/1, stridesName,
3252  /*halfOpen=*/false)) ||
3253  failed(isSumOfIntegerArrayAttrConfinedToShape(
3254  *this, offsets,
3255  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3256  offName, "source vector shape",
3257  /*halfOpen=*/false, /*min=*/1)))
3258  return failure();
3259 
3260  unsigned rankDiff = destShape.size() - sourceShape.size();
3261  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3262  if (sourceVectorType.getScalableDims()[idx] !=
3263  destVectorType.getScalableDims()[idx + rankDiff]) {
3264  return emitOpError("mismatching scalable flags (at source vector idx=")
3265  << idx << ")";
3266  }
3267  if (sourceVectorType.getScalableDims()[idx]) {
3268  auto sourceSize = sourceShape[idx];
3269  auto destSize = destShape[idx + rankDiff];
3270  if (sourceSize != destSize) {
3271  return emitOpError("expected size at idx=")
3272  << idx
3273  << (" to match the corresponding base size from the input "
3274  "vector (")
3275  << sourceSize << (" vs ") << destSize << (")");
3276  }
3277  }
3278  }
3279 
3280  return success();
3281 }
3282 
3283 namespace {
3284 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3285 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
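/// A sketch of the fold, assuming both splats take the same scalar %x
/// (illustrative types):
///   %0 = vector.splat %x : vector<2xf32>
///   %1 = vector.splat %x : vector<4xf32>
///   %2 = vector.insert_strided_slice %0, %1
///          {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
/// folds %2 to %1.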
3286 class FoldInsertStridedSliceSplat final
3287  : public OpRewritePattern<InsertStridedSliceOp> {
3288 public:
3289  using OpRewritePattern::OpRewritePattern;
3290 
3291  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3292  PatternRewriter &rewriter) const override {
3293  auto srcSplatOp =
3294  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3295  auto destSplatOp =
3296  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3297 
3298  if (!srcSplatOp || !destSplatOp)
3299  return failure();
3300 
3301  if (srcSplatOp.getInput() != destSplatOp.getInput())
3302  return failure();
3303 
3304  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3305  return success();
3306  }
3307 };
3308 
3309 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3310 /// to dst.
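/// For illustration (offsets and strides must match between the two ops;
/// types are arbitrary):
///   %0 = vector.extract_strided_slice %dst
///          {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8xf32> to vector<4xf32>
///   %1 = vector.insert_strided_slice %0, %dst
///          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// folds %1 to %dst.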
3311 class FoldInsertStridedSliceOfExtract final
3312  : public OpRewritePattern<InsertStridedSliceOp> {
3313 public:
3314  using OpRewritePattern::OpRewritePattern;
3315 
3316  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3317  PatternRewriter &rewriter) const override {
3318  auto extractStridedSliceOp =
3319  insertStridedSliceOp.getSource()
3320  .getDefiningOp<vector::ExtractStridedSliceOp>();
3321 
3322  if (!extractStridedSliceOp)
3323  return failure();
3324 
3325  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3326  return failure();
3327 
3328  // Check if they have the same strides and offsets.
3329  if (extractStridedSliceOp.getStrides() !=
3330  insertStridedSliceOp.getStrides() ||
3331  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3332  return failure();
3333 
3334  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3335  return success();
3336  }
3337 };
3338 
3339 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3340 // ConstantOp.
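// For example (illustrative constants):
//   %a = arith.constant dense<0.0> : vector<4xf32>
//   %b = arith.constant dense<1.0> : vector<2xf32>
//   %0 = vector.insert_strided_slice %b, %a
//          {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
// becomes
//   %0 = arith.constant dense<[0.0, 1.0, 1.0, 0.0]> : vector<4xf32>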
3341 class InsertStridedSliceConstantFolder final
3342  : public OpRewritePattern<InsertStridedSliceOp> {
3343 public:
3344  using OpRewritePattern::OpRewritePattern;
3345 
3346  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3347  // unless the source vector constant has a single use.
3348  static constexpr int64_t vectorSizeFoldThreshold = 256;
3349 
3350  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3351  PatternRewriter &rewriter) const override {
3352  // Return if the 'InsertStridedSliceOp' dest is not defined by a compatible vector
3353  // ConstantOp.
3354  TypedValue<VectorType> destVector = op.getDest();
3355  Attribute vectorDestCst;
3356  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3357  return failure();
3358 
3359  VectorType destTy = destVector.getType();
3360  if (destTy.isScalable())
3361  return failure();
3362 
3363  // Make sure we do not create too many large constants.
3364  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3365  !destVector.hasOneUse())
3366  return failure();
3367 
3368  TypedValue<VectorType> sourceValue = op.getSource();
3369  Attribute sourceCst;
3370  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3371  return failure();
3372 
3373  // TODO: Support poison.
3374  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3375  return failure();
3376 
3377  // TODO: Handle non-unit strides when they become available.
3378  if (op.hasNonUnitStrides())
3379  return failure();
3380 
3381  VectorType sliceVecTy = sourceValue.getType();
3382  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3383  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3384  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3385  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3386 
3387  // Calculate the destination element indices by enumerating all slice
3388  // positions within the destination and linearizing them. The enumeration
3389  // order is lexicographic which yields a sequence of monotonically
3390  // increasing linearized position indices.
3391  // Because the destination may have higher dimensionality than the slice,
3392  // we keep track of two overlapping sets of positions and offsets.
3393  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3394  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3395  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3396  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3397  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3398  MutableArrayRef<int64_t> currSlicePosition(
3399  currDestPosition.begin() + rankDifference, currDestPosition.end());
3400  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3401  offsets.end());
3402  do {
3403  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3404  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3405  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3406  "Invalid slice element");
3407  newValues[linearizedPosition] = *sliceValuesIt;
3408  ++sliceValuesIt;
3409  } while (succeeded(
3410  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3411 
3412  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3413  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3414  return success();
3415  }
3416 };
3417 
3418 } // namespace
3419 
3420 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3421  RewritePatternSet &results, MLIRContext *context) {
3422  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3423  InsertStridedSliceConstantFolder>(context);
3424 }
3425 
3426 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3427  if (getSourceVectorType() == getDestVectorType())
3428  return getSource();
3429  return {};
3430 }
3431 
3432 //===----------------------------------------------------------------------===//
3433 // OuterProductOp
3434 //===----------------------------------------------------------------------===//
3435 
3436 /// Build an op without mask, use the type of `acc` as the return type.
3437 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3438  Value lhs, Value rhs, Value acc) {
3439  result.addOperands({lhs, rhs, acc});
3440  result.addTypes(acc.getType());
3441 }
3442 
3443 void OuterProductOp::print(OpAsmPrinter &p) {
3444  p << " " << getLhs() << ", " << getRhs();
3445  if (getAcc()) {
3446  p << ", " << getAcc();
3447  p.printOptionalAttrDict((*this)->getAttrs());
3448  }
3449  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3450 }
3451 
3452 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3453  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3454  Type tLHS, tRHS;
3455  if (parser.parseOperandList(operandsInfo) ||
3456  parser.parseOptionalAttrDict(result.attributes) ||
3457  parser.parseColonType(tLHS) || parser.parseComma() ||
3458  parser.parseType(tRHS))
3459  return failure();
3460  if (operandsInfo.size() < 2)
3461  return parser.emitError(parser.getNameLoc(),
3462  "expected at least 2 operands");
3463  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3464  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3465  if (!vLHS)
3466  return parser.emitError(parser.getNameLoc(),
3467  "expected vector type for operand #1");
3468 
3469  VectorType resType;
3470  if (vRHS) {
3471  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3472  vRHS.getScalableDims()[0]};
3473  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3474  vLHS.getElementType(), scalableDimsRes);
3475  } else {
3476  // Scalar RHS operand
3477  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3478  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3479  scalableDimsRes);
3480  }
3481 
3482  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3483  result.attributes.append(
3484  OuterProductOp::getKindAttrName(result.name),
3485  CombiningKindAttr::get(result.getContext(),
3486  OuterProductOp::getDefaultKind()));
3487  }
3488 
3489  return failure(
3490  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3491  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3492  (operandsInfo.size() > 2 &&
3493  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3494  parser.addTypeToList(resType, result.types));
3495 }
3496 
3497 LogicalResult OuterProductOp::verify() {
3498  Type tRHS = getOperandTypeRHS();
3499  VectorType vLHS = getOperandVectorTypeLHS(),
3500  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3501  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3502 
3503  if (vLHS.getRank() != 1)
3504  return emitOpError("expected 1-d vector for operand #1");
3505 
3506  if (vRHS) {
3507  // Proper OUTER operation.
3508  if (vRHS.getRank() != 1)
3509  return emitOpError("expected 1-d vector for operand #2");
3510  if (vRES.getRank() != 2)
3511  return emitOpError("expected 2-d vector result");
3512  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3513  return emitOpError("expected #1 operand dim to match result dim #1");
3514  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3515  return emitOpError("expected #2 operand dim to match result dim #2");
3516  if (vLHS.isScalable() && !vRHS.isScalable()) {
3517  // This restriction reflects what's currently supported in terms of
3518  // scalable vectors. However, we could relax this if there's a use case.
3519  return emitOpError(
3520  "expected either both or only #2 operand dim to be scalable");
3521  }
3522  } else {
3523  // An AXPY operation.
3524  if (vRES.getRank() != 1)
3525  return emitOpError("expected 1-d vector result");
3526  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3527  return emitOpError("expected #1 operand dim to match result dim #1");
3528  }
3529 
3530  if (vACC && vACC != vRES)
3531  return emitOpError("expected operand #3 of same type as result type");
3532 
3533  // Verify supported combining kind.
3534  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3535  return emitOpError("unsupported outerproduct type");
3536 
3537  return success();
3538 }
3539 
3540 // MaskableOpInterface methods.
3541 
3542 /// Returns the mask type expected by this operation. Mostly used for
3543 /// verification purposes. It requires the operation to be vectorized.
3544 Type OuterProductOp::getExpectedMaskType() {
3545  auto vecType = this->getResultVectorType();
3546  return VectorType::get(vecType.getShape(),
3547  IntegerType::get(vecType.getContext(), /*width=*/1),
3548  vecType.getScalableDims());
3549 }
3550 
3551 //===----------------------------------------------------------------------===//
3552 // ExtractStridedSliceOp
3553 //===----------------------------------------------------------------------===//
3554 
3555 // Inference works as follows:
3556 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3557 // 2. Add sizes from 'vectorType' for remaining dims.
3558 // Scalable flags are inherited from 'vectorType'.
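// For example (illustrative): with vectorType = vector<4x8x16xf32>,
// offsets = [0, 2], sizes = [2, 4] and strides = [1, 1], the inferred result
// type is vector<2x4x16xf32>.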
3559 static Type inferStridedSliceOpResultType(VectorType vectorType,
3560  ArrayAttr offsets, ArrayAttr sizes,
3561  ArrayAttr strides) {
3562  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3563  SmallVector<int64_t, 4> shape;
3564  shape.reserve(vectorType.getRank());
3565  unsigned idx = 0;
3566  for (unsigned e = offsets.size(); idx < e; ++idx)
3567  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3568  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3569  shape.push_back(vectorType.getShape()[idx]);
3570 
3571  return VectorType::get(shape, vectorType.getElementType(),
3572  vectorType.getScalableDims());
3573 }
3574 
3575 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3576  Value source, ArrayRef<int64_t> offsets,
3577  ArrayRef<int64_t> sizes,
3578  ArrayRef<int64_t> strides) {
3579  result.addOperands(source);
3580  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3581  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3582  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3583  result.addTypes(
3584  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3585  offsetsAttr, sizesAttr, stridesAttr));
3586  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3587  offsetsAttr);
3588  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3589  sizesAttr);
3590  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3591  stridesAttr);
3592 }
3593 
3594 LogicalResult ExtractStridedSliceOp::verify() {
3595  auto type = getSourceVectorType();
3596  auto offsets = getOffsetsAttr();
3597  auto sizes = getSizesAttr();
3598  auto strides = getStridesAttr();
3599  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3600  return emitOpError(
3601  "expected offsets, sizes and strides attributes of same size");
3602 
3603  auto shape = type.getShape();
3604  auto offName = getOffsetsAttrName();
3605  auto sizesName = getSizesAttrName();
3606  auto stridesName = getStridesAttrName();
3607  if (failed(
3608  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3609  failed(
3610  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3611  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3612  stridesName)) ||
3613  failed(
3614  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3615  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3616  /*halfOpen=*/false,
3617  /*min=*/1)) ||
3618  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3619  /*max=*/1, stridesName,
3620  /*halfOpen=*/false)) ||
3621  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3622  shape, offName, sizesName,
3623  /*halfOpen=*/false)))
3624  return failure();
3625 
3626  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3627  offsets, sizes, strides);
3628  if (getResult().getType() != resultType)
3629  return emitOpError("expected result type to be ") << resultType;
3630 
3631  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3632  if (type.getScalableDims()[idx]) {
3633  auto inputDim = type.getShape()[idx];
3634  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3635  if (inputDim != inputSize)
3636  return emitOpError("expected size at idx=")
3637  << idx
3638  << (" to match the corresponding base size from the input "
3639  "vector (")
3640  << inputSize << (" vs ") << inputDim << (")");
3641  }
3642  }
3643 
3644  return success();
3645 }
3646 
3647 // When the source of the ExtractStridedSliceOp comes from a chain of
3648 // InsertStridedSliceOps, try to use the source of one of those inserts if we
3649 // can detect that the extracted vector is a subset of one of the inserted vectors.
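// A sketch of the fold, assuming the extracted chunk is fully contained in the
// inserted chunk (illustrative types and offsets):
//   %0 = vector.insert_strided_slice %src, %dst
//          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [3], sizes = [2], strides = [1]}
//          : vector<8xf32> to vector<2xf32>
// is rewritten in place to extract directly from %src:
//   %1 = vector.extract_strided_slice %src
//          {offsets = [1], sizes = [2], strides = [1]}
//          : vector<4xf32> to vector<2xf32>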
3650 static LogicalResult
3651 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3652  // Helper to extract integer out of ArrayAttr.
3653  auto getElement = [](ArrayAttr array, int idx) {
3654  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3655  };
3656  ArrayAttr extractOffsets = op.getOffsets();
3657  ArrayAttr extractStrides = op.getStrides();
3658  ArrayAttr extractSizes = op.getSizes();
3659  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3660  while (insertOp) {
3661  if (op.getSourceVectorType().getRank() !=
3662  insertOp.getSourceVectorType().getRank())
3663  return failure();
3664  ArrayAttr insertOffsets = insertOp.getOffsets();
3665  ArrayAttr insertStrides = insertOp.getStrides();
3666  // If the rank of extract is greater than the rank of insert, we are likely
3667  // extracting a partial chunk of the vector inserted.
3668  if (extractOffsets.size() > insertOffsets.size())
3669  return failure();
3670  bool partialOverlap = false;
3671  bool disjoint = false;
3672  SmallVector<int64_t, 4> offsetDiffs;
3673  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3674  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3675  return failure();
3676  int64_t start = getElement(insertOffsets, dim);
3677  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3678  int64_t offset = getElement(extractOffsets, dim);
3679  int64_t size = getElement(extractSizes, dim);
3680  // Check if the start of the extract offset is in the interval inserted.
3681  if (start <= offset && offset < end) {
3682  // If the extract interval overlaps but is not fully included we may
3683  // have a partial overlap that will prevent any folding.
3684  if (offset + size > end)
3685  partialOverlap = true;
3686  offsetDiffs.push_back(offset - start);
3687  continue;
3688  }
3689  disjoint = true;
3690  break;
3691  }
3692  // The extracted chunk is a subset of the inserted chunk.
3693  if (!disjoint && !partialOverlap) {
3694  op.setOperand(insertOp.getSource());
3695  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3696  OpBuilder b(op.getContext());
3697  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3698  return success();
3699  }
3700  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3701  // in the insert chain.
3702  if (disjoint)
3703  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3704  else {
3705  // The extracted vector partially overlap the inserted vector, we cannot
3706  // fold.
3707  return failure();
3708  }
3709  }
3710  return failure();
3711 }
3712 
3713 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3714  if (getSourceVectorType() == getResult().getType())
3715  return getVector();
3716  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3717  return getResult();
3718  return {};
3719 }
3720 
3721 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3722  populateFromInt64AttrArray(getOffsets(), results);
3723 }
3724 
3725 namespace {
3726 
3727 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3728 // ConstantMaskOp.
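// For example (illustrative mask sizes):
//   %0 = vector.constant_mask [4] : vector<8xi1>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [2], sizes = [4], strides = [1]} : vector<8xi1> to vector<4xi1>
// becomes
//   %1 = vector.constant_mask [2] : vector<4xi1>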
3729 class StridedSliceConstantMaskFolder final
3730  : public OpRewritePattern<ExtractStridedSliceOp> {
3731 public:
3732  using OpRewritePattern::OpRewritePattern;
3733 
3734  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3735  PatternRewriter &rewriter) const override {
3736  // Return if 'extractStridedSliceOp' operand is not defined by a
3737  // ConstantMaskOp.
3738  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3739  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3740  if (!constantMaskOp)
3741  return failure();
3742  // Return if 'extractStridedSliceOp' has non-unit strides.
3743  if (extractStridedSliceOp.hasNonUnitStrides())
3744  return failure();
3745  // Gather constant mask dimension sizes.
3746  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3747  // Gather strided slice offsets and sizes.
3748  SmallVector<int64_t, 4> sliceOffsets;
3749  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3750  sliceOffsets);
3751  SmallVector<int64_t, 4> sliceSizes;
3752  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3753 
3754  // Compute slice of vector mask region.
3755  SmallVector<int64_t, 4> sliceMaskDimSizes;
3756  sliceMaskDimSizes.reserve(maskDimSizes.size());
3757  for (auto [maskDimSize, sliceOffset, sliceSize] :
3758  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3759  int64_t sliceMaskDimSize = std::max(
3760  static_cast<int64_t>(0),
3761  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3762  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3763  }
3764  // Add unchanged dimensions.
3765  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3766  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3767  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3768  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3769  // region is a conjunction of mask dim intervals).
3770  if (llvm::is_contained(sliceMaskDimSizes, 0))
3771  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3772 
3773  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3774  // region.
3775  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3776  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3777  sliceMaskDimSizes);
3778  return success();
3779  }
3780 };
3781 
3782 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
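// For example (illustrative splat constant):
//   %0 = arith.constant dense<3.0> : vector<8xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [2], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
// becomes
//   %1 = arith.constant dense<3.0> : vector<4xf32>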
3783 class StridedSliceSplatConstantFolder final
3784  : public OpRewritePattern<ExtractStridedSliceOp> {
3785 public:
3786  using OpRewritePattern::OpRewritePattern;
3787 
3788  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3789  PatternRewriter &rewriter) const override {
3790  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3791  // ConstantOp.
3792  Value sourceVector = extractStridedSliceOp.getVector();
3793  Attribute vectorCst;
3794  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3795  return failure();
3796 
3797  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3798  if (!splat)
3799  return failure();
3800 
3801  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3802  splat.getSplatValue<Attribute>());
3803  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3804  newAttr);
3805  return success();
3806  }
3807 };
3808 
3809 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3810 // ConstantOp.
3811 class StridedSliceNonSplatConstantFolder final
3812  : public OpRewritePattern<ExtractStridedSliceOp> {
3813 public:
3814  using OpRewritePattern::OpRewritePattern;
3815 
3816  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3817  PatternRewriter &rewriter) const override {
3818  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3819  // ConstantOp.
3820  Value sourceVector = extractStridedSliceOp.getVector();
3821  Attribute vectorCst;
3822  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3823  return failure();
3824 
3825  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3826  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3827  if (!dense || dense.isSplat())
3828  return failure();
3829 
3830  // TODO: Handle non-unit strides when they become available.
3831  if (extractStridedSliceOp.hasNonUnitStrides())
3832  return failure();
3833 
3834  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3835  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3836  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3837 
3838  VectorType sliceVecTy = extractStridedSliceOp.getType();
3839  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3840  int64_t sliceRank = sliceVecTy.getRank();
3841 
3842  // Expand offsets and sizes to match the vector rank.
3843  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3844  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3845 
3846  SmallVector<int64_t, 4> sizes(sourceShape);
3847  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3848 
3849  // Calculate the slice elements by enumerating all slice positions and
3850  // linearizing them. The enumeration order is lexicographic which yields a
3851  // sequence of monotonically increasing linearized position indices.
3852  auto denseValuesBegin = dense.value_begin<Attribute>();
3853  SmallVector<Attribute> sliceValues;
3854  sliceValues.reserve(sliceVecTy.getNumElements());
3855  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3856  do {
3857  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3858  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3859  "Invalid index");
3860  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3861  } while (
3862  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3863 
3864  assert(static_cast<int64_t>(sliceValues.size()) ==
3865  sliceVecTy.getNumElements() &&
3866  "Invalid number of slice elements");
3867  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3868  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3869  newAttr);
3870  return success();
3871  }
3872 };
3873 
3874 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3875 // BroadcastOp(ExtractStridedSliceOp).
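// For example (illustrative), when the inner dimensions already match:
//   %0 = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]}
//          : vector<2x4xf32> to vector<1x4xf32>
// becomes
//   %1 = vector.broadcast %src : vector<4xf32> to vector<1x4xf32>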
3876 class StridedSliceBroadcast final
3877  : public OpRewritePattern<ExtractStridedSliceOp> {
3878 public:
3879  using OpRewritePattern::OpRewritePattern;
3880 
3881  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3882  PatternRewriter &rewriter) const override {
3883  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3884  if (!broadcast)
3885  return failure();
3886  auto srcVecType =
3887  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3888  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3889  auto dstVecType = llvm::cast<VectorType>(op.getType());
3890  unsigned dstRank = dstVecType.getRank();
3891  unsigned rankDiff = dstRank - srcRank;
3892  // Check if the innermost dimensions of the source of the broadcast are the
3893  // same as the innermost dimensions of the extract's result type. If this is
3894  // the case we can just use a broadcast, as the original dimensions are untouched.
3895  bool lowerDimMatch = true;
3896  for (unsigned i = 0; i < srcRank; i++) {
3897  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3898  lowerDimMatch = false;
3899  break;
3900  }
3901  }
3902  Value source = broadcast.getSource();
3903  // If the inner dimensions don't match, it means we need to extract from the
3904  // source of the original broadcast and then broadcast the extracted value.
3905  // We also need to handle degenerate cases where the source is effectively
3906  // just a single scalar.
3907  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3908  if (!lowerDimMatch && !isScalarSrc) {
3909  source = rewriter.create<ExtractStridedSliceOp>(
3910  op->getLoc(), source,
3911  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3912  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3913  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3914  }
3915  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3916  return success();
3917  }
3918 };
3919 
3920 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
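/// An illustrative example (types chosen arbitrarily):
///   %0 = vector.splat %x : vector<8xf32>
///   %1 = vector.extract_strided_slice %0
///          {offsets = [2], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
/// becomes
///   %1 = vector.splat %x : vector<4xf32>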
3921 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3922 public:
3923  using OpRewritePattern::OpRewritePattern;
3924 
3925  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3926  PatternRewriter &rewriter) const override {
3927  auto splat = op.getVector().getDefiningOp<SplatOp>();
3928  if (!splat)
3929  return failure();
3930  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3931  return success();
3932  }
3933 };
3934 
3935 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3936 /// slice is contiguous, into extract and shape_cast.
3937 ///
3938 /// Example:
3939 /// Before:
3940 /// %1 = vector.extract_strided_slice %arg0 {
3941 /// offsets = [0, 0, 0, 0, 0],
3942 /// sizes = [1, 1, 1, 1, 8],
3943 /// strides = [1, 1, 1, 1, 1]
3944 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3945 /// After:
3946 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3947 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3948 /// %1 = vector.shape_cast %0
3949 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3950 ///
3951 class ContiguousExtractStridedSliceToExtract final
3952  : public OpRewritePattern<ExtractStridedSliceOp> {
3953 public:
3954  using OpRewritePattern::OpRewritePattern;
3955 
3956  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3957  PatternRewriter &rewriter) const override {
3958  if (op.hasNonUnitStrides())
3959  return failure();
3960  Value source = op.getOperand();
3961  auto sourceType = cast<VectorType>(source.getType());
3962  if (sourceType.isScalable() || sourceType.getRank() == 0)
3963  return failure();
3964 
3965  // Compute the number of offsets to pass to ExtractOp::build. That is the
3966  // difference between the source rank and the desired slice rank. We walk
3967  // the dimensions from innermost out, and stop when the next slice dimension
3968  // is not full-size.
3969  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3970  int numOffsets;
3971  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3972  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3973  break;
3974  }
3975 
3976  // If the created extract op would have no offsets, then this whole
3977  // extract_strided_slice is the identity and should have been handled by
3978  // other canonicalizations.
3979  if (numOffsets == 0)
3980  return failure();
3981 
3982  // If not even the inner-most dimension is full-size, this op can't be
3983  // rewritten as an ExtractOp.
3984  if (numOffsets == sourceType.getRank() &&
3985  static_cast<int>(sizes.size()) == sourceType.getRank())
3986  return failure();
3987 
3988  // The outer dimensions must have unit size.
3989  for (int i = 0; i < numOffsets; ++i) {
3990  if (sizes[i] != 1)
3991  return failure();
3992  }
3993 
3994  // Avoid generating slices that have leading unit dimensions. The shape_cast
3995  // op that we create below would otherwise be lowered through poor generic
3996  // fallback patterns (ShapeCastOpRewritePattern).
3997  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
3998  sizes[numOffsets] == 1) {
3999  ++numOffsets;
4000  }
4001 
4002  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4003  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4004  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
4005  extractOffsets);
4006  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4007  return success();
4008  }
4009 };
4010 
4011 } // namespace
4012 
4013 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4014  RewritePatternSet &results, MLIRContext *context) {
4015  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
4016  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4017  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4018  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
4019  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4020  context);
4021 }
4022 
4023 //===----------------------------------------------------------------------===//
4024 // TransferReadOp
4025 //===----------------------------------------------------------------------===//
4026 
4027 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4028 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4029  VectorType vectorType, Value source,
4030  ValueRange indices, AffineMapAttr permutationMapAttr,
4031  /*optional*/ ArrayAttr inBoundsAttr) {
4032  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4033  Value padding = builder.create<arith::ConstantOp>(
4034  result.location, elemType, builder.getZeroAttr(elemType));
4035  build(builder, result, vectorType, source, indices, permutationMapAttr,
4036  padding, /*mask=*/Value(), inBoundsAttr);
4037 }
4038 
4039 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4040 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4041  VectorType vectorType, Value source,
4042  ValueRange indices, AffineMap permutationMap,
4043  std::optional<ArrayRef<bool>> inBounds) {
4044  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4045  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4046  ? builder.getBoolArrayAttr(inBounds.value())
4047  : builder.getBoolArrayAttr(
4048  SmallVector<bool>(vectorType.getRank(), false));
4049  build(builder, result, vectorType, source, indices, permutationMapAttr,
4050  inBoundsAttr);
4051 }
4052 
4053 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4054 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4055  VectorType vectorType, Value source,
4056  ValueRange indices, Value padding,
4057  std::optional<ArrayRef<bool>> inBounds) {
4058  AffineMap permutationMap = getTransferMinorIdentityMap(
4059  llvm::cast<ShapedType>(source.getType()), vectorType);
4060  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4061  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4062  ? builder.getBoolArrayAttr(inBounds.value())
4063  : builder.getBoolArrayAttr(
4064  SmallVector<bool>(vectorType.getRank(), false));
4065  build(builder, result, vectorType, source, indices, permutationMapAttr,
4066  padding,
4067  /*mask=*/Value(), inBoundsAttr);
4068 }
4069 
4070 /// 4. Builder that sets padding to zero and permutation map to
4071 /// 'getMinorIdentityMap'.
4072 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4073  VectorType vectorType, Value source,
4074  ValueRange indices,
4075  std::optional<ArrayRef<bool>> inBounds) {
4076  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4077  Value padding = builder.create<arith::ConstantOp>(
4078  result.location, elemType, builder.getZeroAttr(elemType));
4079  build(builder, result, vectorType, source, indices, padding, inBounds);
4080 }
4081 
4082 template <typename EmitFun>
4083 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4084  EmitFun emitOpError) {
4085  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4086  for (auto expr : permutationMap.getResults()) {
4087  auto dim = dyn_cast<AffineDimExpr>(expr);
4088  auto zero = dyn_cast<AffineConstantExpr>(expr);
4089  if (zero) {
4090  if (zero.getValue() != 0) {
4091  return emitOpError(
4092  "requires a projected permutation_map (at most one dim or the zero "
4093  "constant can appear in each result)");
4094  }
4095  continue;
4096  }
4097  if (!dim) {
4098  return emitOpError("requires a projected permutation_map (at most one "
4099  "dim or the zero constant can appear in each result)");
4100  }
4101  if (seen[dim.getPosition()]) {
4102  return emitOpError(
4103  "requires a permutation_map that is a permutation (found one dim "
4104  "used more than once)");
4105  }
4106  seen[dim.getPosition()] = true;
4107  }
4108  return success();
4109 }
4110 
4111 static LogicalResult
4112 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4113  VectorType vectorType, VectorType maskType,
4114  VectorType inferredMaskType, AffineMap permutationMap,
4115  ArrayAttr inBounds) {
4116  if (op->hasAttr("masked")) {
4117  return op->emitOpError("masked attribute has been removed. "
4118  "Use in_bounds instead.");
4119  }
4120 
4121  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4122  return op->emitOpError(
4123  "requires source to be a memref or ranked tensor type");
4124 
4125  auto elementType = shapedType.getElementType();
4126  DataLayout dataLayout = DataLayout::closest(op);
4127  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4128  // Memref or tensor has vector element type.
4129  unsigned sourceVecSize =
4130  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4131  vectorElementType.getShape().back();
4132  unsigned resultVecSize =
4133  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4134  vectorType.getShape().back();
4135  if (resultVecSize % sourceVecSize != 0)
4136  return op->emitOpError(
4137  "requires the bitwidth of the minor 1-D vector to be an integral "
4138  "multiple of the bitwidth of the minor 1-D vector of the source");
4139 
4140  unsigned sourceVecEltRank = vectorElementType.getRank();
4141  unsigned resultVecRank = vectorType.getRank();
4142  if (sourceVecEltRank > resultVecRank)
4143  return op->emitOpError(
4144  "requires source vector element and vector result ranks to match.");
4145  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4146  // Check that permutation map results match 'rankOffset' of vector type.
4147  if (permutationMap.getNumResults() != rankOffset)
4148  return op->emitOpError("requires a permutation_map with result dims of "
4149  "the same rank as the vector type");
4150 
4151  if (maskType)
4152  return op->emitOpError("does not support masks with vector element type");
4153  } else {
4154  // Memref or tensor has scalar element type.
4155  unsigned minorSize =
4156  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4157  unsigned resultVecSize =
4158  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4159  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4160  return op->emitOpError(
4161  "requires the bitwidth of the minor 1-D vector to be an integral "
4162  "multiple of the bitwidth of the source element type");
4163 
4164  // Check that permutation map results match rank of vector type.
4165  if (permutationMap.getNumResults() != vectorType.getRank())
4166  return op->emitOpError("requires a permutation_map with result dims of "
4167  "the same rank as the vector type");
4168  }
4169 
4170  if (permutationMap.getNumSymbols() != 0)
4171  return op->emitOpError("requires permutation_map without symbols");
4172 
4173  if (permutationMap.getNumInputs() != shapedType.getRank())
4174  return op->emitOpError("requires a permutation_map with input dims of the "
4175  "same rank as the source type");
4176 
4177  if (maskType && maskType != inferredMaskType)
4178  return op->emitOpError("inferred mask type (")
4179  << inferredMaskType << ") and mask operand type (" << maskType
4180  << ") don't match";
4181 
4182  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4183  return op->emitOpError("expects the in_bounds attr of same rank "
4184  "as permutation_map results: ")
4185  << AffineMapAttr::get(permutationMap)
4186  << " vs inBounds of size: " << inBounds.size();
4187 
4188  return success();
4189 }
4190 
4191 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4192  SmallVector<StringRef, 3> elidedAttrs;
4193  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4194  if (op.getPermutationMap().isMinorIdentity())
4195  elidedAttrs.push_back(op.getPermutationMapAttrName());
4196  // Elide in_bounds attribute if all dims are out-of-bounds.
4197  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4198  elidedAttrs.push_back(op.getInBoundsAttrName());
4199  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4200 }
4201 
4202 void TransferReadOp::print(OpAsmPrinter &p) {
4203  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4204  if (getMask())
4205  p << ", " << getMask();
4206  printTransferAttrs(p, *this);
4207  p << " : " << getShapedType() << ", " << getVectorType();
4208 }
4209 
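// A worked example of the mask-type inference below (illustrative shapes): for
// a transfer with vector type vector<4x8xf32> and permutation map
// (d0, d1, d2) -> (d2, d1), the unused dim d0 is compressed away, the remaining
// map is inverted, and the inferred mask type is vector<8x4xi1>, i.e. the mask
// follows the order of the used source dimensions rather than the vector
// dimensions.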
4210 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4211  AffineMap permMap) {
4212  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4213  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4214  assert(invPermMap && "Inverse permutation map couldn't be computed");
4215  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4216 
4217  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4218  // 0-D mask into a single-element 1-D mask.
4219  if (maskShape.empty())
4220  maskShape.push_back(1);
4221 
4222  SmallVector<bool> scalableDims =
4223  applyPermutationMap(invPermMap, vecType.getScalableDims());
4224 
4225  return VectorType::get(maskShape, i1Type, scalableDims);
4226 }
4227 
4228 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4229  auto &builder = parser.getBuilder();
4230  SMLoc typesLoc;
4231  OpAsmParser::UnresolvedOperand sourceInfo;
4232  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4233  OpAsmParser::UnresolvedOperand paddingInfo;
4234  SmallVector<Type, 2> types;
4235  OpAsmParser::UnresolvedOperand maskInfo;
4236  // Parsing with support for paddingValue.
4237  if (parser.parseOperand(sourceInfo) ||
4238  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4239  parser.parseComma() || parser.parseOperand(paddingInfo))
4240  return failure();
4241  ParseResult hasMask = parser.parseOptionalComma();
4242  if (hasMask.succeeded()) {
4243  if (parser.parseOperand(maskInfo))
4244  return failure();
4245  }
4246  if (parser.parseOptionalAttrDict(result.attributes) ||
4247  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4248  return failure();
4249  if (types.size() != 2)
4250  return parser.emitError(typesLoc, "requires two types");
4251  auto indexType = builder.getIndexType();
4252  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4253  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4254  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4255  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4256  if (!vectorType)
4257  return parser.emitError(typesLoc, "requires vector type");
4258  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4259  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4260  AffineMap permMap;
4261  if (!permMapAttr) {
4262  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4263  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4264  } else {
4265  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4266  }
4267  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4268  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4269  if (!inBoundsAttr) {
4270  result.addAttribute(inBoundsAttrName,
4271  builder.getBoolArrayAttr(
4272  SmallVector<bool>(permMap.getNumResults(), false)));
4273  }
4274  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4275  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4276  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4277  result.operands))
4278  return failure();
4279  if (hasMask.succeeded()) {
4280  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4281  return parser.emitError(
4282  maskInfo.location, "does not support masks with vector element type");
4283  if (vectorType.getRank() != permMap.getNumResults()) {
4284  return parser.emitError(typesLoc,
4285  "expected the same rank for the vector and the "
4286  "results of the permutation map");
4287  }
4288  // Instead of adding the mask type as an op type, compute it based on the
4289  // vector type and the permutation map (to keep the type signature small).
4290  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4291  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4292  return failure();
4293  }
4294  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4295  builder.getDenseI32ArrayAttr(
4296  {1, static_cast<int32_t>(indexInfo.size()), 1,
4297  static_cast<int32_t>(hasMask.succeeded())}));
4298  return parser.addTypeToList(vectorType, result.types);
4299 }
4300 
4301 LogicalResult TransferReadOp::verify() {
4302  // Consistency of elemental types in source and vector.
4303  ShapedType shapedType = getShapedType();
4304  VectorType vectorType = getVectorType();
4305  VectorType maskType = getMaskType();
4306  auto paddingType = getPadding().getType();
4307  auto permutationMap = getPermutationMap();
4308  VectorType inferredMaskType =
4309  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4310  : VectorType();
4311  auto sourceElementType = shapedType.getElementType();
4312 
4313  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4314  return emitOpError("requires ") << shapedType.getRank() << " indices";
4315 
4316  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4317  shapedType, vectorType, maskType,
4318  inferredMaskType, permutationMap, getInBounds())))
4319  return failure();
4320 
4321  if (auto sourceVectorElementType =
4322  llvm::dyn_cast<VectorType>(sourceElementType)) {
4323  // Source has vector element type.
4324  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4325  if (sourceVectorElementType != paddingType)
4326  return emitOpError(
4327  "requires source element type and padding type to match.");
4328 
4329  } else {
4330  // Check that 'paddingType' is valid to store in a vector type.
4331  if (!VectorType::isValidElementType(paddingType))
4332  return emitOpError("requires valid padding vector elemental type");
4333 
4334  // Check that padding type and vector element types match.
4335  if (paddingType != sourceElementType)
4336  return emitOpError(
4337  "requires formal padding and source of the same elemental type");
4338  }
4339 
4340  return verifyPermutationMap(permutationMap,
4341  [&](Twine t) { return emitOpError(t); });
4342 }
4343 
4344 // MaskableOpInterface methods.
4345 
4346 /// Returns the mask type expected by this operation. Mostly used for
4347 /// verification purposes. It requires the operation to be vectorized.
4348 Type TransferReadOp::getExpectedMaskType() {
4349  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4350 }
4351 
4352 template <typename TransferOp>
4353 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4354  // TODO: support more aggressive createOrFold on:
4355  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4356  if (op.getShapedType().isDynamicDim(indicesIdx))
4357  return false;
4358  Value index = op.getIndices()[indicesIdx];
4359  std::optional<int64_t> cstOp = getConstantIntValue(index);
4360  if (!cstOp.has_value())
4361  return false;
4362 
4363  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4364  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4365 
4366  return cstOp.value() + vectorSize <= sourceSize;
4367 }
4368 
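// Illustrative example (hypothetical IR): given
//   %r = vector.transfer_read %src[%c2], %pad {in_bounds = [false]}
//       : memref<8xf32>, vector<4xf32>
// the access is provably in bounds because 2 + 4 <= 8, so the folder below can
// rewrite the attribute to in_bounds = [true].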
4369 template <typename TransferOp>
4370 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4371  // TODO: support 0-d corner case.
4372  // TODO: Be less conservative.
4373  if (op.getTransferRank() == 0)
4374  return failure();
4375  AffineMap permutationMap = op.getPermutationMap();
4376  bool changed = false;
4377  SmallVector<bool, 4> newInBounds;
4378  newInBounds.reserve(op.getTransferRank());
4379  // Idxs of non-bcast dims - used when analysing bcast dims.
4380  SmallVector<unsigned> nonBcastDims;
4381 
4382  // 1. Process non-broadcast dims
4383  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4384  // 1.1. Already marked as in-bounds, nothing to see here.
4385  if (op.isDimInBounds(i)) {
4386  newInBounds.push_back(true);
4387  continue;
4388  }
4389  // 1.2. Currently out-of-bounds, check whether we can statically determine
4390  // it is inBounds.
4391  bool inBounds = false;
4392  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4393  if (dimExpr) {
4394  inBounds = isInBounds(op, /*resultIdx=*/i,
4395  /*indicesIdx=*/dimExpr.getPosition());
4396  nonBcastDims.push_back(i);
4397  }
4398 
4399  newInBounds.push_back(inBounds);
4400  // We commit the pattern if it is "more inbounds".
4401  changed |= inBounds;
4402  }
4403 
4404  // 2. Handle broadcast dims
4405  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4406  // "in bounds" as well.
4407  bool allNonBcastDimsInBounds = llvm::all_of(
4408  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4409  if (allNonBcastDimsInBounds) {
4410  for (size_t idx : permutationMap.getBroadcastDims()) {
4411  changed |= !newInBounds[idx];
4412  newInBounds[idx] = true;
4413  }
4414  }
4415 
4416  if (!changed)
4417  return failure();
4418  // OpBuilder is only used as a helper to build a BoolArrayAttr.
4419  OpBuilder b(op.getContext());
4420  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4421  return success();
4422 }
4423 
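// Illustrative example (hypothetical IR): a statically all-true mask is
// redundant, so
//   %m = vector.constant_mask [4] : vector<4xi1>
//   %r = vector.transfer_read %src[%c0], %pad, %m {in_bounds = [true]}
//       : memref<8xf32>, vector<4xf32>
// is folded below to the unmasked form
//   %r = vector.transfer_read %src[%c0], %pad {in_bounds = [true]}
//       : memref<8xf32>, vector<4xf32>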
4424 template <typename TransferOp>
4425 static LogicalResult foldTransferFullMask(TransferOp op) {
4426  auto mask = op.getMask();
4427  if (!mask)
4428  return failure();
4429 
4430  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4431  return failure();
4432 
4433  op.getMaskMutable().clear();
4434  return success();
4435 }
4436 
4437 /// ```
4438 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4439 /// : vector<1x4xf32>, tensor<4x4xf32>
4440 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4441 /// : tensor<4x4xf32>, vector<1x4xf32>
4442 /// ```
4443 /// -> Folds into
4444 /// ```
4445 /// %v0
4446 /// ```
4447 static Value foldRAW(TransferReadOp readOp) {
4448  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4449  return {};
4450  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4451  while (defWrite) {
4452  if (checkSameValueRAW(defWrite, readOp))
4453  return defWrite.getVector();
4454  if (!isDisjointTransferIndices(
4455  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4456  cast<VectorTransferOpInterface>(readOp.getOperation())))
4457  break;
4458  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4459  }
4460  return {};
4461 }
4462 
4463 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4464  if (Value vec = foldRAW(*this))
4465  return vec;
4466  /// transfer_read(memrefcast) -> transfer_read
4467  if (succeeded(foldTransferInBoundsAttribute(*this)))
4468  return getResult();
4469  if (succeeded(foldTransferFullMask(*this)))
4470  return getResult();
4471  if (succeeded(memref::foldMemRefCast(*this)))
4472  return getResult();
4473  if (succeeded(tensor::foldTensorCast(*this)))
4474  return getResult();
4475  return OpFoldResult();
4476 }
4477 
4478 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4479  return llvm::to_vector<4>(getVectorType().getShape());
4480 }
4481 
4482 void TransferReadOp::getEffects(
4483  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4484  &effects) {
4485  if (llvm::isa<MemRefType>(getShapedType()))
4486  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4487  SideEffects::DefaultResource::get());
4488 }
4489 
4490 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4491  if (hasPureTensorSemantics())
4492  return Speculation::Speculatable;
4493  return Speculation::NotSpeculatable;
4494 }
4495 
4496 namespace {
4497 /// Store to load forwarding for transfer operations with permutation maps.
4498 /// Even if the permutation maps are different we can still propagate the store
4499 /// into the load if the size of the dimensions read and written match. Then we
4500 /// can replace the transfer_read + transfer_write by vector.broadcast and
4501 /// vector.transpose.
4502 /// Example:
4503 /// ```
4504 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4505 /// {in_bounds = [true, true],
4506 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4507 /// vector<4x1xf32>, tensor<4x4x4xf32>
4508 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4509 /// {in_bounds = [true, true, true, true],
4510 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4511 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4512 /// ```
4513 /// To:
4514 /// ```
4515 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4516 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4517 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4518 /// ```
4519 struct TransferReadAfterWriteToBroadcast
4520  : public OpRewritePattern<TransferReadOp> {
4521  using OpRewritePattern::OpRewritePattern;
4522 
4523  LogicalResult matchAndRewrite(TransferReadOp readOp,
4524  PatternRewriter &rewriter) const override {
4525  if (readOp.hasOutOfBoundsDim() ||
4526  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4527  return failure();
4528  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4529  if (!defWrite)
4530  return failure();
4531  // TODO: If the written transfer chunk is a superset of the read transfer
4532  // chunk we could do an extract_strided_slice.
4533  if (readOp.getTransferChunkAccessed() !=
4534  defWrite.getTransferChunkAccessed())
4535  return failure();
4536  // TODO: Support cases where a dim is explicitly written but implicitly
4537  // read (i.e., a unit dim that is rank reduced).
4538  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4539  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4540  return failure();
4541  if (readOp.getIndices() != defWrite.getIndices() ||
4542  readOp.getMask() != defWrite.getMask())
4543  return failure();
4544  Value vec = defWrite.getVector();
4545  // TODO: loop through the chain of transfer_write if we can prove that they
4546  // don't overlap with the transfer_read. This requires improving
4547  // `isDisjointTransferIndices` helper.
4548  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4549  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4550  AffineMap map = readMap.compose(writeMap);
4551  if (map.getNumResults() == 0)
4552  return failure();
4553  // Calculate the permutation to apply to go from the vector stored to the
4554  // vector read.
4555  SmallVector<unsigned> permutation;
4556  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4557  return failure();
4558 
4559  Location loc = readOp.getLoc();
4560  // Calculate the broadcast shape by applying the reverse permutation to the
4561  // final shape we want.
4562  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4563  SmallVector<int64_t> broadcastShape(destShape.size());
4564  SmallVector<bool> broadcastScalableFlags(destShape.size());
4565  for (const auto &pos : llvm::enumerate(permutation)) {
4566  broadcastShape[pos.value()] = destShape[pos.index()];
4567  broadcastScalableFlags[pos.value()] =
4568  readOp.getVectorType().getScalableDims()[pos.index()];
4569  }
4570  VectorType broadcastedType = VectorType::get(
4571  broadcastShape, defWrite.getVectorType().getElementType(),
4572  broadcastScalableFlags);
4573  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4574  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4575  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4576  transposePerm);
4577  return success();
4578  }
4579 };
4580 } // namespace
4581 
4582 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4583  MLIRContext *context) {
4584  results.add<TransferReadAfterWriteToBroadcast>(context);
4585 }
4586 
4587 //===----------------------------------------------------------------------===//
4588 // TransferWriteOp
4589 //===----------------------------------------------------------------------===//
4590 
4591 /// 1. Builder with type inference.
4592 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4593  Value vector, Value dest, ValueRange indices,
4594  AffineMapAttr permutationMapAttr,
4595  /*optional*/ Value mask,
4596  /*optional*/ ArrayAttr inBoundsAttr) {
4597  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4598  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4599  mask, inBoundsAttr);
4600 }
4601 
4602 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4603 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4604  Value vector, Value dest, ValueRange indices,
4605  AffineMapAttr permutationMapAttr,
4606  /*optional*/ ArrayAttr inBoundsAttr) {
4607  build(builder, result, vector, dest, indices, permutationMapAttr,
4608  /*mask=*/Value(), inBoundsAttr);
4609 }
4610 
4611 /// 3. Builder with type inference that sets an empty mask (variant without
4612 /// attrs)
4613 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4614  Value vector, Value dest, ValueRange indices,
4615  AffineMap permutationMap,
4616  std::optional<ArrayRef<bool>> inBounds) {
4617  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4618  auto inBoundsAttr =
4619  (inBounds && !inBounds.value().empty())
4620  ? builder.getBoolArrayAttr(inBounds.value())
4621  : builder.getBoolArrayAttr(SmallVector<bool>(
4622  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4623  build(builder, result, vector, dest, indices, permutationMapAttr,
4624  /*mask=*/Value(), inBoundsAttr);
4625 }
4626 
4627 /// 4. Builder with type inference that sets an empty mask and sets permutation
4628 /// map to 'getMinorIdentityMap'.
4629 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4630  Value vector, Value dest, ValueRange indices,
4631  std::optional<ArrayRef<bool>> inBounds) {
4632  auto vectorType = llvm::cast<VectorType>(vector.getType());
4633  AffineMap permutationMap = getTransferMinorIdentityMap(
4634  llvm::cast<ShapedType>(dest.getType()), vectorType);
4635  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4636 }
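// Minimal usage sketch for builder (4) above; `loc`, `vec`, `dest` and `c0`
// are hypothetical values created elsewhere, and the exact overload resolution
// may differ in other revisions of this file:
//
//   auto write = builder.create<vector::TransferWriteOp>(
//       loc, vec, dest, ValueRange{c0, c0}, ArrayRef<bool>{true, true});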
4637 
4638 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4639  OperationState &result) {
4640  auto &builder = parser.getBuilder();
4641  SMLoc typesLoc;
4642  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4643  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4644  SmallVector<Type, 2> types;
4645  OpAsmParser::UnresolvedOperand maskInfo;
4646  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4647  parser.parseOperand(sourceInfo) ||
4648  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4649  return failure();
4650  ParseResult hasMask = parser.parseOptionalComma();
4651  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4652  return failure();
4653  if (parser.parseOptionalAttrDict(result.attributes) ||
4654  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4655  return failure();
4656  if (types.size() != 2)
4657  return parser.emitError(typesLoc, "requires two types");
4658  auto indexType = builder.getIndexType();
4659  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4660  if (!vectorType)
4661  return parser.emitError(typesLoc, "requires vector type");
4662  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4663  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4664  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4665  auto permMapAttrName =
4666  TransferWriteOp::getPermutationMapAttrName(result.name);
4667  auto permMapAttr = result.attributes.get(permMapAttrName);
4668  AffineMap permMap;
4669  if (!permMapAttr) {
4670  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4671  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4672  } else {
4673  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4674  }
4675  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4676  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4677  if (!inBoundsAttr) {
4678  result.addAttribute(inBoundsAttrName,
4679  builder.getBoolArrayAttr(
4680  SmallVector<bool>(permMap.getNumResults(), false)));
4681  }
4682  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4683  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4684  parser.resolveOperands(indexInfo, indexType, result.operands))
4685  return failure();
4686  if (hasMask.succeeded()) {
4687  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4688  return parser.emitError(
4689  maskInfo.location, "does not support masks with vector element type");
4690  if (vectorType.getRank() != permMap.getNumResults()) {
4691  return parser.emitError(typesLoc,
4692  "expected the same rank for the vector and the "
4693  "results of the permutation map");
4694  }
4695  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4696  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4697  return failure();
4698  }
4699  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4700  builder.getDenseI32ArrayAttr(
4701  {1, 1, static_cast<int32_t>(indexInfo.size()),
4702  static_cast<int32_t>(hasMask.succeeded())}));
4703  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4704  parser.addTypeToList(shapedType, result.types));
4705 }
4706 
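// The textual form handled by the parser above and the printer below looks
// like (hypothetical IR):
//   %t1 = vector.transfer_write %v, %t0[%c0, %c0], %mask {in_bounds = [true, true]}
//       : vector<2x4xf32>, tensor<8x8xf32>
// where the trailing ", %mask" and the attribute dictionary are optional, and
// a result is produced only when writing into a tensor.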
4707 void TransferWriteOp::print(OpAsmPrinter &p) {
4708  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4709  if (getMask())
4710  p << ", " << getMask();
4711  printTransferAttrs(p, *this);
4712  p << " : " << getVectorType() << ", " << getShapedType();
4713 }
4714 
4715 LogicalResult TransferWriteOp::verify() {
4716  // Consistency of elemental types in shape and vector.
4717  ShapedType shapedType = getShapedType();
4718  VectorType vectorType = getVectorType();
4719  VectorType maskType = getMaskType();
4720  auto permutationMap = getPermutationMap();
4721  VectorType inferredMaskType =
4722  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4723  : VectorType();
4724 
4725  if (llvm::size(getIndices()) != shapedType.getRank())
4726  return emitOpError("requires ") << shapedType.getRank() << " indices";
4727 
4728  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4729  // as the semantics is unclear. This can be revisited later if necessary.
4730  if (hasBroadcastDim())
4731  return emitOpError("should not have broadcast dimensions");
4732 
4733  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4734  shapedType, vectorType, maskType,
4735  inferredMaskType, permutationMap, getInBounds())))
4736  return failure();
4737 
4738  return verifyPermutationMap(permutationMap,
4739  [&](Twine t) { return emitOpError(t); });
4740 }
4741 
4742 // MaskableOpInterface methods.
4743 
4744 /// Returns the mask type expected by this operation. Mostly used for
4745 /// verification purposes.
4746 Type TransferWriteOp::getExpectedMaskType() {
4747  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4748 }
4749 
4750 /// Fold:
4751 /// ```
4752 /// %t1 = ...
4753 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4754 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4755 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4756 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4757 /// ```
4758 ///
4759 /// into:
4760 ///
4761 /// ```
4762 /// %t0
4763 /// ```
4764 ///
4765 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4766 /// block argument or has side effects.
4767 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4768  ArrayRef<Attribute>,
4769  SmallVectorImpl<OpFoldResult> &results) {
4770  // TODO: support 0-d corner case.
4771  if (write.getTransferRank() == 0)
4772  return failure();
4773  auto rankedTensorType =
4774  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4775  // If not operating on tensors, bail.
4776  if (!rankedTensorType)
4777  return failure();
4778  // If no read, bail.
4779  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4780  if (!read)
4781  return failure();
4782  // TODO: support 0-d corner case.
4783  if (read.getTransferRank() == 0)
4784  return failure();
4785  // For now, only accept minor identity. Future: composition is minor identity.
4786  if (!read.getPermutationMap().isMinorIdentity() ||
4787  !write.getPermutationMap().isMinorIdentity())
4788  return failure();
4789  // Bail on mismatching ranks.
4790  if (read.getTransferRank() != write.getTransferRank())
4791  return failure();
4792  // Bail on potential out-of-bounds accesses.
4793  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4794  return failure();
4795  // Tensor types must be the same.
4796  if (read.getSource().getType() != rankedTensorType)
4797  return failure();
4798  // Vector types must be the same.
4799  if (read.getVectorType() != write.getVectorType())
4800  return failure();
4801  // Vector and Tensor shapes must match.
4802  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4803  return failure();
4804  // If any index is nonzero.
4805  auto isNotConstantZero = [](Value v) {
4806  auto cstOp = getConstantIntValue(v);
4807  return !cstOp.has_value() || cstOp.value() != 0;
4808  };
4809  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4810  llvm::any_of(write.getIndices(), isNotConstantZero))
4811  return failure();
4812  // Success.
4813  results.push_back(read.getSource());
4814  return success();
4815 }
4816 
4817 static bool checkSameValueWAR(vector::TransferReadOp read,
4818  vector::TransferWriteOp write) {
4819  return read.getSource() == write.getSource() &&
4820  read.getIndices() == write.getIndices() &&
4821  read.getPermutationMap() == write.getPermutationMap() &&
4822  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4823  !write.getMask();
4824 }
4825 /// Fold transfer_write write after read:
4826 /// ```
4827 /// %t0 = ...
4828 /// %v = vector.transfer_read %t0[%c0...] :
4829 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4830 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4831 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4832 /// ```
4833 ///
4834 /// into:
4835 ///
4836 /// ```
4837 /// %t0
4838 /// ```
4839 static LogicalResult foldWAR(TransferWriteOp write,
4840  SmallVectorImpl<OpFoldResult> &results) {
4841  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4842  return failure();
4843  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4844  if (!read)
4845  return failure();
4846 
4847  if (!checkSameValueWAR(read, write))
4848  return failure();
4849  results.push_back(read.getSource());
4850  return success();
4851 }
4852 
4853 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4854  SmallVectorImpl<OpFoldResult> &results) {
4855  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4856  return success();
4857  if (succeeded(foldWAR(*this, results)))
4858  return success();
4859  if (succeeded(foldTransferInBoundsAttribute(*this)))
4860  return success();
4861  if (succeeded(foldTransferFullMask(*this)))
4862  return success();
4863  return memref::foldMemRefCast(*this);
4864 }
4865 
4866 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4867  return llvm::to_vector<4>(getVectorType().getShape());
4868 }
4869 
4870 void TransferWriteOp::getEffects(
4871  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4872  &effects) {
4873  if (llvm::isa<MemRefType>(getShapedType()))
4874  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4875  SideEffects::DefaultResource::get());
4876 }
4877 
4878 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4879  if (hasPureTensorSemantics())
4880  return Speculation::Speculatable;
4881  return Speculation::NotSpeculatable;
4882 }
4883 
4884 namespace {
4885 /// Remove dead transfer write from the SSA chain so that it can be eliminated by
4886 /// DCE
4887 /// ```
4888 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4889 /// : vector<1x4xf32>, tensor<4x4xf32>
4890 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4891 /// : vector<1x4xf32>, tensor<4x4xf32>
4892 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4893 /// : vector<1x4xf32>, tensor<4x4xf32>
4894 /// ```
4895 ///
4896 /// into:
4897 ///
4898 /// ```
4899 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4900 /// : vector<1x4xf32>, tensor<4x4xf32>
4901 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4902 /// : vector<1x4xf32>, tensor<4x4xf32>
4903 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4904 /// : vector<1x4xf32>, tensor<4x4xf32>
4905 /// ```
4906 ///
4907 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4908 /// any other uses.
4909 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4910 public:
4911  using OpRewritePattern::OpRewritePattern;
4912  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4913  PatternRewriter &rewriter) const override {
4914  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4915  return failure();
4916  vector::TransferWriteOp writeToModify = writeOp;
4917 
4918  auto defWrite =
4919  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4920  while (defWrite) {
4921  if (checkSameValueWAW(writeOp, defWrite)) {
4922  rewriter.modifyOpInPlace(writeToModify, [&]() {
4923  writeToModify.getSourceMutable().assign(defWrite.getSource());
4924  });
4925  return success();
4926  }
4927  if (!isDisjointTransferSets(
4928  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4929  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4930  break;
4931  // If the previous write op doesn't have any other use we can safely look
4932  // at the previous store to see if it can be removed.
4933  if (!defWrite->hasOneUse())
4934  break;
4935  writeToModify = defWrite;
4936  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4937  }
4938  return failure();
4939  }
4940 };
4941 
4942 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4943 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4944 /// overwritten and inserted into another tensor. After this rewrite, the
4945 /// operations bufferize in-place since all of them work on the same slice.
4946 ///
4947 /// For example:
4948 /// ```mlir
4949 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4950 /// : vector<8x16xf32>, tensor<8x16xf32>
4951 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4952 /// : tensor<8x16xf32> to tensor<?x?xf32>
4953 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4954 /// : tensor<?x?xf32> into tensor<27x37xf32>
4955 /// ```
4956 /// folds to
4957 /// ```mlir
4958 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4959 /// : tensor<27x37xf32> to tensor<?x?xf32>
4960 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4961 /// : vector<8x16xf32>, tensor<?x?xf32>
4962 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4963 /// : tensor<?x?xf32> into tensor<27x37xf32>
4964 /// ```
4965 struct SwapExtractSliceOfTransferWrite
4966  : public OpRewritePattern<tensor::InsertSliceOp> {
4967 public:
4968  using OpRewritePattern::OpRewritePattern;
4969 
4970  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4971  PatternRewriter &rewriter) const override {
4972  if (!insertOp.hasUnitStride())
4973  return failure();
4974  auto extractOp =
4975  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4976  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4977  return failure();
4978  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4979  if (!transferOp || !transferOp->hasOneUse())
4980  return failure();
4981 
4982  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4983  // rank-reducing.
4984  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4985  return rewriter.notifyMatchFailure(insertOp,
4986  "use-def chain is rank-reducing");
4987  }
4988 
4989  // Fail if tensor::ExtractSliceOp has non-zero offset.
4990  if (!extractOp.hasZeroOffset()) {
4991  return rewriter.notifyMatchFailure(insertOp,
4992  "ExtractSliceOp has non-zero offset");
4993  }
4994 
4995  // Fail if vector::TransferWriteOp has non-zero offset.
4996  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4997  return getConstantIntValue(value) == static_cast<int64_t>(0);
4998  })) {
4999  return rewriter.notifyMatchFailure(insertOp,
5000  "TransferWriteOp has non-zero offset");
5001  }
5002 
5003  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5004  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5005  return rewriter.notifyMatchFailure(
5006  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5007  }
5008 
5009  for (auto [insertSize, extractSize] :
5010  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5011  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5012  return rewriter.notifyMatchFailure(
5013  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5014  }
5015  }
5016 
5017  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5018  assert(transferOp.getVectorType().hasStaticShape() &&
5019  "expected vector to have a static shape");
5020  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5021  SmallVector<int64_t> resultShape = applyPermutationMap(
5022  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5023  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5024  return rewriter.notifyMatchFailure(
5025  insertOp, "TransferWriteOp may not write the full tensor.");
5026  }
5027 
5028  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5029  // Set all in_bounds to false and let the folder infer them.
5030  SmallVector<bool> newInBounds(vectorShape.size(), false);
5031  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5032  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5033  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5034  insertOp.getMixedStrides());
5035  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5036  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5037  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5038  rewriter.getBoolArrayAttr(newInBounds));
5039  rewriter.modifyOpInPlace(insertOp, [&]() {
5040  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5041  });
5042  return success();
5043  }
5044 };
5045 
5046 } // namespace
5047 
5048 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5049  MLIRContext *context) {
5050  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5051 }
5052 
5053 //===----------------------------------------------------------------------===//
5054 // LoadOp
5055 //===----------------------------------------------------------------------===//
5056 
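// Illustrative examples (hypothetical IR): the layout check below accepts
//   vector.load %base[%i, %j] : memref<4x8xf32>, vector<8xf32>
// but rejects a non-unit innermost stride such as
//   vector.load %base[%i, %j] : memref<4x8xf32, strided<[16, 2]>>, vector<8xf32>
// with "most minor memref dim must have unit stride".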
5057 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5058  VectorType vecTy,
5059  MemRefType memRefTy) {
5060  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5061  // need any stride limitations.
5062  if (!vecTy.isScalable() &&
5063  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5064  return success();
5065 
5066  if (!memRefTy.isLastDimUnitStride())
5067  return op->emitOpError("most minor memref dim must have unit stride");
5068  return success();
5069 }
5070 
5071 LogicalResult vector::LoadOp::verify() {
5072  VectorType resVecTy = getVectorType();
5073  MemRefType memRefTy = getMemRefType();
5074 
5075  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5076  return failure();
5077 
5078  // Checks for vector memrefs.
5079  Type memElemTy = memRefTy.getElementType();
5080  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5081  if (memVecTy != resVecTy)
5082  return emitOpError("base memref and result vector types should match");
5083  memElemTy = memVecTy.getElementType();
5084  }
5085 
5086  if (resVecTy.getElementType() != memElemTy)
5087  return emitOpError("base and result element types should match");
5088  if (llvm::size(getIndices()) != memRefTy.getRank())
5089  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5090  return success();
5091 }
5092 
5093 OpFoldResult LoadOp::fold(FoldAdaptor) {
5094  if (succeeded(memref::foldMemRefCast(*this)))
5095  return getResult();
5096  return OpFoldResult();
5097 }
5098 
5099 //===----------------------------------------------------------------------===//
5100 // StoreOp
5101 //===----------------------------------------------------------------------===//
5102 
5103 LogicalResult vector::StoreOp::verify() {
5104  VectorType valueVecTy = getVectorType();
5105  MemRefType memRefTy = getMemRefType();
5106 
5107  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5108  return failure();
5109 
5110  // Checks for vector memrefs.
5111  Type memElemTy = memRefTy.getElementType();
5112  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5113  if (memVecTy != valueVecTy)
5114  return emitOpError(
5115  "base memref and valueToStore vector types should match");
5116  memElemTy = memVecTy.getElementType();
5117  }
5118 
5119  if (valueVecTy.getElementType() != memElemTy)
5120  return emitOpError("base and valueToStore element type should match");
5121  if (llvm::size(getIndices()) != memRefTy.getRank())
5122  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5123  return success();
5124 }
5125 
5126 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5127  SmallVectorImpl<OpFoldResult> &results) {
5128  return memref::foldMemRefCast(*this);
5129 }
5130 
5131 //===----------------------------------------------------------------------===//
5132 // MaskedLoadOp
5133 //===----------------------------------------------------------------------===//
5134 
5135 LogicalResult MaskedLoadOp::verify() {
5136  VectorType maskVType = getMaskVectorType();
5137  VectorType passVType = getPassThruVectorType();
5138  VectorType resVType = getVectorType();
5139  MemRefType memType = getMemRefType();
5140 
5141  if (resVType.getElementType() != memType.getElementType())
5142  return emitOpError("base and result element type should match");
5143  if (llvm::size(getIndices()) != memType.getRank())
5144  return emitOpError("requires ") << memType.getRank() << " indices";
5145  if (resVType.getShape() != maskVType.getShape())
5146  return emitOpError("expected result shape to match mask shape");
5147  if (resVType != passVType)
5148  return emitOpError("expected pass_thru of same type as result type");
5149  return success();
5150 }
5151 
5152 namespace {
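// Illustrative folds (hypothetical IR): with a statically all-true mask,
//   %l = vector.maskedload %base[%c0], %alltrue, %pass
//       : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
// becomes %l = vector.load %base[%c0] : memref<16xf32>, vector<8xf32>; with a
// statically all-false mask the result is simply %pass.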
5153 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5154 public:
5155  using OpRewritePattern::OpRewritePattern;
5156  LogicalResult matchAndRewrite(MaskedLoadOp load,
5157  PatternRewriter &rewriter) const override {
5158  switch (getMaskFormat(load.getMask())) {
5159  case MaskFormat::AllTrue:
5160  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5161  load, load.getType(), load.getBase(), load.getIndices());
5162  return success();
5163  case MaskFormat::AllFalse:
5164  rewriter.replaceOp(load, load.getPassThru());
5165  return success();
5166  case MaskFormat::Unknown:
5167  return failure();
5168  }
5169  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5170  }
5171 };
5172 } // namespace
5173 
5174 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5175  MLIRContext *context) {
5176  results.add<MaskedLoadFolder>(context);
5177 }
5178 
5179 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5180  if (succeeded(memref::foldMemRefCast(*this)))
5181  return getResult();
5182  return OpFoldResult();
5183 }
5184 
5185 //===----------------------------------------------------------------------===//
5186 // MaskedStoreOp
5187 //===----------------------------------------------------------------------===//
5188 
5189 LogicalResult MaskedStoreOp::verify() {
5190  VectorType maskVType = getMaskVectorType();
5191  VectorType valueVType = getVectorType();
5192  MemRefType memType = getMemRefType();
5193 
5194  if (valueVType.getElementType() != memType.getElementType())
5195  return emitOpError("base and valueToStore element type should match");
5196  if (llvm::size(getIndices()) != memType.getRank())
5197  return emitOpError("requires ") << memType.getRank() << " indices";
5198  if (valueVType.getShape() != maskVType.getShape())
5199  return emitOpError("expected valueToStore shape to match mask shape");
5200  return success();
5201 }
5202 
5203 namespace {
5204 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5205 public:
5206  using OpRewritePattern::OpRewritePattern;
5207  LogicalResult matchAndRewrite(MaskedStoreOp store,
5208  PatternRewriter &rewriter) const override {
5209  switch (getMaskFormat(store.getMask())) {
5210  case MaskFormat::AllTrue:
5211  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5212  store, store.getValueToStore(), store.getBase(), store.getIndices());
5213  return success();
5214  case MaskFormat::AllFalse:
5215  rewriter.eraseOp(store);
5216  return success();
5217  case MaskFormat::Unknown:
5218  return failure();
5219  }
5220  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5221  }
5222 };
5223 } // namespace
5224 
5225 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5226  MLIRContext *context) {
5227  results.add<MaskedStoreFolder>(context);
5228 }
5229 
5230 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5231  SmallVectorImpl<OpFoldResult> &results) {
5232  return memref::foldMemRefCast(*this);
5233 }
5234 
5235 //===----------------------------------------------------------------------===//
5236 // GatherOp
5237 //===----------------------------------------------------------------------===//
5238 
5239 LogicalResult GatherOp::verify() {
5240  VectorType indVType = getIndexVectorType();
5241  VectorType maskVType = getMaskVectorType();
5242  VectorType resVType = getVectorType();
5243  ShapedType baseType = getBaseType();
5244 
5245  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5246  return emitOpError("requires base to be a memref or ranked tensor type");
5247 
5248  if (resVType.getElementType() != baseType.getElementType())
5249  return emitOpError("base and result element type should match");
5250  if (llvm::size(getIndices()) != baseType.getRank())
5251  return emitOpError("requires ") << baseType.getRank() << " indices";
5252  if (resVType.getShape() != indVType.getShape())
5253  return emitOpError("expected result dim to match indices dim");
5254  if (resVType.getShape() != maskVType.getShape())
5255  return emitOpError("expected result dim to match mask dim");
5256  if (resVType != getPassThruVectorType())
5257  return emitOpError("expected pass_thru of same type as result type");
5258  return success();
5259 }
5260 
5261 // MaskableOpInterface methods.
5262 
5263 /// Returns the mask type expected by this operation. Mostly used for
5264 /// verification purposes. It requires the operation to be vectorized.
5265 Type GatherOp::getExpectedMaskType() {
5266  auto vecType = this->getIndexVectorType();
5267  return VectorType::get(vecType.getShape(),
5268  IntegerType::get(vecType.getContext(), /*width=*/1),
5269  vecType.getScalableDims());
5270 }
5271 
5272 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5273  return llvm::to_vector<4>(getVectorType().getShape());
5274 }
5275 
5276 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
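/// Illustrative examples (hypothetical IR): a constant dense<[0, 1, 2, 3]> :
/// vector<4xi32> qualifies, as does the result of a `vector.step`; scalable or
/// multi-dimensional index vectors do not.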
5277 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5278  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5279  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5280  return failure();
5281 
5282  if (indexVec.getDefiningOp<StepOp>())
5283  return success();
5284 
5285  DenseIntElementsAttr elements;
5286  if (!matchPattern(indexVec, m_Constant(&elements)))
5287  return failure();
5288 
5289  return success(
5290  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5291 }
5292 
5293 namespace {
5294 class GatherFolder final : public OpRewritePattern<GatherOp> {
5295 public:
5296  using OpRewritePattern::OpRewritePattern;
5297  LogicalResult matchAndRewrite(GatherOp gather,
5298  PatternRewriter &rewriter) const override {
5299  switch (getMaskFormat(gather.getMask())) {
5300  case MaskFormat::AllTrue:
5301  return failure(); // no unmasked equivalent
5302  case MaskFormat::AllFalse:
5303  rewriter.replaceOp(gather, gather.getPassThru());
5304  return success();
5305  case MaskFormat::Unknown:
5306  return failure();
5307  }
5308  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5309  }
5310 };
5311 
5312 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5313 /// maskedload. Only 1D fixed vectors are supported for now.
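/// Illustrative example (hypothetical IR):
/// ```
/// %idx = vector.step : vector<8xindex>
/// %g = vector.gather %base[%c0] [%idx], %mask, %pass
///     : memref<16xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// rewrites to
/// ```
/// %g = vector.maskedload %base[%c0], %mask, %pass
///     : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```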
5314 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5315 public:
5316  using OpRewritePattern::OpRewritePattern;
5317  LogicalResult matchAndRewrite(GatherOp op,
5318  PatternRewriter &rewriter) const override {
5319  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5320  return failure();
5321 
5322  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5323  op.getIndices(), op.getMask(),
5324  op.getPassThru());
5325  return success();
5326  }
5327 };
5328 } // namespace
5329 
5330 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5331  MLIRContext *context) {
5332  results.add<GatherFolder, FoldContiguousGather>(context);
5333 }
5334 
5335 //===----------------------------------------------------------------------===//
5336 // ScatterOp
5337 //===----------------------------------------------------------------------===//
5338 
5339 LogicalResult ScatterOp::verify() {
5340  VectorType indVType = getIndexVectorType();
5341  VectorType maskVType = getMaskVectorType();
5342  VectorType valueVType = getVectorType();
5343  MemRefType memType = getMemRefType();
5344 
5345  if (valueVType.getElementType() != memType.getElementType())
5346  return emitOpError("base and valueToStore element type should match");
5347  if (llvm::size(getIndices()) != memType.getRank())
5348  return emitOpError("requires ") << memType.getRank() << " indices";
5349  if (valueVType.getShape() != indVType.getShape())
5350  return emitOpError("expected valueToStore dim to match indices dim");
5351  if (valueVType.getShape() != maskVType.getShape())
5352  return emitOpError("expected valueToStore dim to match mask dim");
5353  return success();
5354 }
5355 
5356 namespace {
5357 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5358 public:
5359  using OpRewritePattern::OpRewritePattern;
5360  LogicalResult matchAndRewrite(ScatterOp scatter,
5361  PatternRewriter &rewriter) const override {
5362  switch (getMaskFormat(scatter.getMask())) {
5363  case MaskFormat::AllTrue:
5364  return failure(); // no unmasked equivalent
5365  case MaskFormat::AllFalse:
5366  rewriter.eraseOp(scatter);
5367  return success();
5368  case MaskFormat::Unknown:
5369  return failure();
5370  }
5371  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5372  }
5373 };
5374 
5375 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5376 /// maskedstore. Only 1D fixed vectors are supported for now.
5377 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5378 public:
5379  using OpRewritePattern::OpRewritePattern;
5380  LogicalResult matchAndRewrite(ScatterOp op,
5381  PatternRewriter &rewriter) const override {
5382  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5383  return failure();
5384 
5385  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5386  op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5387  return success();
5388  }
5389 };
5390 } // namespace
5391 
5392 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5393  MLIRContext *context) {
5394  results.add<ScatterFolder, FoldContiguousScatter>(context);
5395 }
5396 
5397 //===----------------------------------------------------------------------===//
5398 // ExpandLoadOp
5399 //===----------------------------------------------------------------------===//
5400 
5401 LogicalResult ExpandLoadOp::verify() {
5402  VectorType maskVType = getMaskVectorType();
5403  VectorType passVType = getPassThruVectorType();
5404  VectorType resVType = getVectorType();
5405  MemRefType memType = getMemRefType();
5406 
5407  if (resVType.getElementType() != memType.getElementType())
5408  return emitOpError("base and result element type should match");
5409  if (llvm::size(getIndices()) != memType.getRank())
5410  return emitOpError("requires ") << memType.getRank() << " indices";
5411  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5412  return emitOpError("expected result dim to match mask dim");
5413  if (resVType != passVType)
5414  return emitOpError("expected pass_thru of same type as result type");
5415  return success();
5416 }
5417 
5418 namespace {
5419 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5420 public:
5421  using OpRewritePattern::OpRewritePattern;
5422  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5423  PatternRewriter &rewriter) const override {
5424  switch (getMaskFormat(expand.getMask())) {
5425  case MaskFormat::AllTrue:
5426  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5427  expand, expand.getType(), expand.getBase(), expand.getIndices());
5428  return success();
5429  case MaskFormat::AllFalse:
5430  rewriter.replaceOp(expand, expand.getPassThru());
5431  return success();
5432  case MaskFormat::Unknown:
5433  return failure();
5434  }
5435  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5436  }
5437 };
5438 } // namespace
5439 
5440 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5441  MLIRContext *context) {
5442  results.add<ExpandLoadFolder>(context);
5443 }
5444 
5445 //===----------------------------------------------------------------------===//
5446 // CompressStoreOp
5447 //===----------------------------------------------------------------------===//
5448 
5449 LogicalResult CompressStoreOp::verify() {
5450  VectorType maskVType = getMaskVectorType();
5451  VectorType valueVType = getVectorType();
5452  MemRefType memType = getMemRefType();
5453 
5454  if (valueVType.getElementType() != memType.getElementType())
5455  return emitOpError("base and valueToStore element type should match");
5456  if (llvm::size(getIndices()) != memType.getRank())
5457  return emitOpError("requires ") << memType.getRank() << " indices";
5458  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5459  return emitOpError("expected valueToStore dim to match mask dim");
5460  return success();
5461 }
5462 
5463 namespace {
5464 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5465 public:
5466  using OpRewritePattern::OpRewritePattern;
5467  LogicalResult matchAndRewrite(CompressStoreOp compress,
5468  PatternRewriter &rewriter) const override {
5469  switch (getMaskFormat(compress.getMask())) {
5470  case MaskFormat::AllTrue:
5471  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5472  compress, compress.getValueToStore(), compress.getBase(),
5473  compress.getIndices());
5474  return success();
5475  case MaskFormat::AllFalse:
5476  rewriter.eraseOp(compress);
5477  return success();
5478  case MaskFormat::Unknown:
5479  return failure();
5480  }
5481  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5482  }
5483 };
5484 } // namespace
5485 
5486 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5487  MLIRContext *context) {
5488  results.add<CompressStoreFolder>(context);
5489 }
5490 
5491 //===----------------------------------------------------------------------===//
5492 // ShapeCastOp
5493 //===----------------------------------------------------------------------===//
5494 
5495 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5496  SetIntRangeFn setResultRanges) {
5497  setResultRanges(getResult(), argRanges.front());
5498 }
5499 
5500 /// Returns true if each element of 'a' is equal to the product of a contiguous
5501 /// sequence of the elements of 'b'. Returns false otherwise.
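/// Illustrative example: a = [4, 6] and b = [4, 2, 3] is valid because
/// 6 == 2 * 3, whereas a = [4, 6] and b = [2, 12] is not, since no contiguous
/// run of 'b' multiplies to 4.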
5502 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5503  unsigned rankA = a.size();
5504  unsigned rankB = b.size();
5505  assert(rankA < rankB);
5506 
5507  auto isOne = [](int64_t v) { return v == 1; };
5508 
5509  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5510  // casted to a 0-d vector.
5511  if (rankA == 0 && llvm::all_of(b, isOne))
5512  return true;
5513 
5514  unsigned i = 0;
5515  unsigned j = 0;
5516  while (i < rankA && j < rankB) {
5517  int64_t dimA = a[i];
5518  int64_t dimB = 1;
5519  while (dimB < dimA && j < rankB)
5520  dimB *= b[j++];
5521  if (dimA != dimB)
5522  break;
5523  ++i;
5524 
5525  // Handle the case when trailing dimensions are of size 1.
5526  // Include them into the contiguous sequence.
5527  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5528  i = rankA;
5529  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5530  j = rankB;
5531  }
5532 
5533  return i == rankA && j == rankB;
5534 }
5535 
5536 static LogicalResult verifyVectorShapeCast(Operation *op,
5537  VectorType sourceVectorType,
5538  VectorType resultVectorType) {
5539  // Check that element type is the same.
5540  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5541  return op->emitOpError("source/result vectors must have same element type");
5542  auto sourceShape = sourceVectorType.getShape();
5543  auto resultShape = resultVectorType.getShape();
5544 
5545  // Check that product of source dim sizes matches product of result dim sizes.
5546  int64_t sourceDimProduct = std::accumulate(
5547  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5548  int64_t resultDimProduct = std::accumulate(
5549  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5550  if (sourceDimProduct != resultDimProduct)
5551  return op->emitOpError("source/result number of elements must match");
5552 
5553  // Check the expanding/contracting rank cases.
5554  unsigned sourceRank = sourceVectorType.getRank();
5555  unsigned resultRank = resultVectorType.getRank();
5556  if (sourceRank < resultRank) {
5557  if (!isValidShapeCast(sourceShape, resultShape))
5558  return op->emitOpError("invalid shape cast");
5559  } else if (sourceRank > resultRank) {
5560  if (!isValidShapeCast(resultShape, sourceShape))
5561  return op->emitOpError("invalid shape cast");
5562  }
5563 
5564  // Check that (non-)scalability is preserved
5565  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5566  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5567  if (sourceNScalableDims != resultNScalableDims)
5568  return op->emitOpError("different number of scalable dims at source (")
5569  << sourceNScalableDims << ") and result (" << resultNScalableDims
5570  << ")";
5572 
5573  return success();
5574 }
5575 
5576 LogicalResult ShapeCastOp::verify() {
5577  auto sourceVectorType =
5578  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5579  auto resultVectorType =
5580  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5581 
5582  // Check if source/result are of vector type.
5583  if (sourceVectorType && resultVectorType)
5584  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5585 
5586  return success();
5587 }
5588 
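// Illustrative example (hypothetical IR): two stacked casts such as
//   %a = vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>
//   %b = vector.shape_cast %a : vector<8xf32> to vector<2x4xf32>
// fold back to %v; when the outer result type differs from the original source
// type the pair is merged into a single shape_cast, provided the direct cast
// is itself valid.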
5589 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5590  // No-op shape cast.
5591  if (getSource().getType() == getResult().getType())
5592  return getSource();
5593 
5594  // Canceling shape casts.
5595  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5596  if (getResult().getType() == otherOp.getSource().getType())
5597  return otherOp.getSource();
5598 
5599  // Only allows valid transitive folding.
5600  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5601  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5602  if (srcType.getRank() < resultType.getRank()) {
5603  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5604  return {};
5605  } else if (srcType.getRank() > resultType.getRank()) {
5606  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5607  return {};
5608  } else {
5609  return {};
5610  }
5611 
5612  setOperand(otherOp.getSource());
5613  return getResult();
5614  }
5615 
5616  // Cancelling broadcast and shape cast ops.
5617  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5618  if (bcastOp.getSourceType() == getType())
5619  return bcastOp.getSource();
5620  }
5621 
5622  return {};
5623 }
5624 
5625 namespace {
5626 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5627 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5628 public:
5629  using OpRewritePattern::OpRewritePattern;
5630 
5631  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5632  PatternRewriter &rewriter) const override {
5633  auto constantOp =
5634  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5635  if (!constantOp)
5636  return failure();
5637  // Only handle splat for now.
5638  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5639  if (!dense)
5640  return failure();
5641  auto newAttr =
5642  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5643  dense.getSplatValue<Attribute>());
5644  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5645  return success();
5646  }
5647 };
5648 
5649 /// Helper function that computes a new vector type based on the input vector
5650 /// type by removing the trailing one dims:
5651 ///
5652 /// vector<4x1x1xi1> --> vector<4xi1>
5653 ///
5654 static VectorType trimTrailingOneDims(VectorType oldType) {
5655  ArrayRef<int64_t> oldShape = oldType.getShape();
5656  ArrayRef<int64_t> newShape = oldShape;
5657 
5658  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5659  ArrayRef<bool> newScalableDims = oldScalableDims;
5660 
5661  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5662  newShape = newShape.drop_back(1);
5663  newScalableDims = newScalableDims.drop_back(1);
5664  }
5665 
5666  // Make sure we have at least 1 dimension.
5667  // TODO: Add support for 0-D vectors.
5668  if (newShape.empty()) {
5669  newShape = oldShape.take_back();
5670  newScalableDims = oldScalableDims.take_back();
5671  }
5672 
5673  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5674 }
5675 
5676 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5677 ///
5678 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5679 /// dimension. If the input vector comes from `vector.create_mask` for which
5680 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5681 /// to fold shape_cast into create_mask.
5682 ///
5683 /// BEFORE:
5684 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5685 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5686 /// AFTER:
5687 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5688 class ShapeCastCreateMaskFolderTrailingOneDim final
5689  : public OpRewritePattern<ShapeCastOp> {
5690 public:
5691  using OpRewritePattern::OpRewritePattern;
5692 
5693  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5694  PatternRewriter &rewriter) const override {
5695  Value shapeOpSrc = shapeOp->getOperand(0);
5696  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5697  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5698  if (!createMaskOp && !constantMaskOp)
5699  return failure();
5700 
5701  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5702  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5703 
5704  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5705  if (newVecType != shapeOpResTy)
5706  return failure();
5707 
5708  auto numDimsToDrop =
5709  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5710 
5711  // No unit dims to drop
5712  if (!numDimsToDrop)
5713  return failure();
5714 
5715  if (createMaskOp) {
5716  auto maskOperands = createMaskOp.getOperands();
5717  auto numMaskOperands = maskOperands.size();
5718 
5719  // Check every mask dim size to see whether it can be dropped
5720  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5721  --i) {
5722  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5723  if (!constant || (constant.value() != 1))
5724  return failure();
5725  }
5726  SmallVector<Value> newMaskOperands =
5727  maskOperands.drop_back(numDimsToDrop);
5728 
5729  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5730  newMaskOperands);
5731  return success();
5732  }
5733 
5734  if (constantMaskOp) {
5735  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5736  auto numMaskOperands = maskDimSizes.size();
5737 
5738  // Check every mask dim size to see whether it can be dropped
5739  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5740  --i) {
5741  if (maskDimSizes[i] != 1)
5742  return failure();
5743  }
5744 
5745  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5746  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5747  newMaskOperands);
5748  return success();
5749  }
5750 
5751  return failure();
5752  }
5753 };
5754 
5755 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5756 /// This only applies when the shape of the broadcast source
5757 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5758 /// reshape is expressive enough to capture the result in a single op), or
5759 /// 2. has the same element count as the shape cast result.
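/// Illustrative examples (not from the original source; values are hypothetical):
///   Case 1 (suffix):
///     %b = vector.broadcast %v : vector<4xf32> to vector<1x2x4xf32>
///     %s = vector.shape_cast %b : vector<1x2x4xf32> to vector<2x4xf32>
///   folds to
///     %s = vector.broadcast %v : vector<4xf32> to vector<2x4xf32>
///   Case 2 (same element count):
///     %b = vector.broadcast %v : vector<8xf32> to vector<1x8xf32>
///     %s = vector.shape_cast %b : vector<1x8xf32> to vector<2x4xf32>
///   folds to
///     %s = vector.shape_cast %v : vector<8xf32> to vector<2x4xf32>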
5760 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5761 public:
5762  using OpRewritePattern::OpRewritePattern;
5763 
5764  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5765  PatternRewriter &rewriter) const override {
5766  auto broadcastOp =
5767  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5768  if (!broadcastOp)
5769  return failure();
5770 
5771  ArrayRef<int64_t> broadcastSourceShape;
5772  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5773  broadcastSourceShape = srcType.getShape();
5774  ArrayRef<int64_t> shapeCastTargetShape =
5775  shapeCastOp.getResultVectorType().getShape();
5776 
5777  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5778  // with a broadcast to the final shape.
5779  if (broadcastSourceShape ==
5780  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5781  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5782  shapeCastOp, shapeCastOp.getResultVectorType(),
5783  broadcastOp.getSource());
5784  return success();
5785  }
5786 
5787  // Otherwise, if the final result has the same element count, we can replace
5788  // with a shape cast.
5789  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5790  if (srcType.getNumElements() ==
5791  shapeCastOp.getResultVectorType().getNumElements()) {
5792  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5793  shapeCastOp, shapeCastOp.getResultVectorType(),
5794  broadcastOp.getSource());
5795  return success();
5796  }
5797  }
5798 
5799  return failure();
5800  }
5801 };
5802 
5803 } // namespace
5804 
5805 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5806  MLIRContext *context) {
5807  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5808  ShapeCastBroadcastFolder>(context);
5809 }
5810 
5811 //===----------------------------------------------------------------------===//
5812 // VectorBitCastOp
5813 //===----------------------------------------------------------------------===//
5814 
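// Illustrative example (not from the original source): the verifier accepts
//   vector.bitcast %v : vector<4x8xf16> to vector<4x4xf32>
// because all leading dims match and the minor dims carry the same number of
// bits (8 x 16 == 4 x 32).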
5815 LogicalResult BitCastOp::verify() {
5816  auto sourceVectorType = getSourceVectorType();
5817  auto resultVectorType = getResultVectorType();
5818 
5819  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5820  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5821  return emitOpError("dimension size mismatch at: ") << i;
5822  }
5823 
5824  DataLayout dataLayout = DataLayout::closest(*this);
5825  auto sourceElementBits =
5826  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5827  auto resultElementBits =
5828  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5829 
5830  if (sourceVectorType.getRank() == 0) {
5831  if (sourceElementBits != resultElementBits)
5832  return emitOpError("source/result bitwidth of the 0-D vector element "
5833  "types must be equal");
5834  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5835  resultElementBits * resultVectorType.getShape().back()) {
5836  return emitOpError(
5837  "source/result bitwidth of the minor 1-D vectors must be equal");
5838  }
5839 
5840  return success();
5841 }
5842 
5843 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5844  // Nop cast.
5845  if (getSource().getType() == getResult().getType())
5846  return getSource();
5847 
5848  // Canceling bitcasts.
5849  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5850  if (getResult().getType() == otherOp.getSource().getType())
5851  return otherOp.getSource();
5852 
5853  setOperand(otherOp.getSource());
5854  return getResult();
5855  }
5856 
5857  Attribute sourceConstant = adaptor.getSource();
5858  if (!sourceConstant)
5859  return {};
5860 
5861  Type srcElemType = getSourceVectorType().getElementType();
5862  Type dstElemType = getResultVectorType().getElementType();
5863 
5864  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5865  if (floatPack.isSplat()) {
5866  auto splat = floatPack.getSplatValue<FloatAttr>();
5867 
5868  // Casting fp16 into fp32.
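 // E.g. (illustrative) a vector<4xf16> splat of 0x3C00 (i.e. 1.0) folds to a
 // vector<2xf32> splat whose bit pattern is 0x3C003C00.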
5869  if (srcElemType.isF16() && dstElemType.isF32()) {
5870  uint32_t bits = static_cast<uint32_t>(
5871  splat.getValue().bitcastToAPInt().getZExtValue());
5872  // Duplicate the 16-bit pattern.
5873  bits = (bits << 16) | (bits & 0xffff);
5874  APInt intBits(32, bits);
5875  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5876  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5877  }
5878  }
5879  }
5880 
5881  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5882  if (intPack.isSplat()) {
5883  auto splat = intPack.getSplatValue<IntegerAttr>();
5884 
5885  if (llvm::isa<IntegerType>(dstElemType)) {
5886  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5887  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5888 
5889  // Casting to a larger integer bit width.
5890  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5891  APInt intBits = splat.getValue().zext(dstBitWidth);
5892 
5893  // Duplicate the lower width element.
5894  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5895  intBits = (intBits << srcBitWidth) | intBits;
5896  return DenseElementsAttr::get(getResultVectorType(), intBits);
5897  }
5898  }
5899  }
5900  }
5901 
5902  return {};
5903 }
5904 
5905 //===----------------------------------------------------------------------===//
5906 // TypeCastOp
5907 //===----------------------------------------------------------------------===//
5908 
5909 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5910  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5911  SmallVector<int64_t, 8> res(memRefType.getShape());
5912  if (vectorType)
5913  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5914  return res;
5915 }
5916 
5917 /// Build the canonical memRefType with a single vector.
5918 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5919 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5920  Value source) {
5921  result.addOperands(source);
5922  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5923  VectorType vectorType =
5924  VectorType::get(extractShape(memRefType),
5925  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5926  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5927  memRefType.getMemorySpace()));
5928 }
5929 
5930 LogicalResult TypeCastOp::verify() {
5931  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
5932  if (!canonicalType.getLayout().isIdentity())
5933  return emitOpError("expects operand to be a memref with identity layout");
5934  if (!getResultMemRefType().getLayout().isIdentity())
5935  return emitOpError("expects result to be a memref with identity layout");
5936  if (getResultMemRefType().getMemorySpace() !=
5937  getMemRefType().getMemorySpace())
5938  return emitOpError("expects result in same memory space");
5939 
5940  auto sourceType = getMemRefType();
5941  auto resultType = getResultMemRefType();
5942  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5943  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5944  return emitOpError(
5945  "expects result and operand with same underlying scalar type: ")
5946  << resultType;
5947  if (extractShape(sourceType) != extractShape(resultType))
5948  return emitOpError(
5949  "expects concatenated result and operand shapes to be equal: ")
5950  << resultType;
5951  return success();
5952 }
5953 
5954 //===----------------------------------------------------------------------===//
5955 // TransposeOp
5956 //===----------------------------------------------------------------------===//
5957 
5958 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5959  Value vector, ArrayRef<int64_t> permutation) {
5960  VectorType vt = llvm::cast<VectorType>(vector.getType());
5961  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5962  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5963  for (unsigned i = 0; i < permutation.size(); ++i) {
5964  transposedShape[i] = vt.getShape()[permutation[i]];
5965  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5966  }
5967 
5968  result.addOperands(vector);
5969  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5970  transposedScalableDims));
5971  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5972  builder.getDenseI64ArrayAttr(permutation));
5973 }
5974 
5975 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5976  // Eliminate splat constant transpose ops.
5977  if (auto attr =
5978  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5979  if (attr.isSplat())
5980  return attr.reshape(getResultVectorType());
5981 
5982  // Eliminate identity transpose ops. This happens when the dimensions of the
5983  // input vector remain in their original order after the transpose operation.
5984  ArrayRef<int64_t> perm = getPermutation();
5985 
5986  // Check if the permutation of the dimensions contains sequential values:
5987  // {0, 1, 2, ...}.
5988  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5989  if (perm[i] != i)
5990  return {};
5991  }
5992 
5993  return getVector();
5994 }
5995 
5996 LogicalResult vector::TransposeOp::verify() {
5997  VectorType vectorType = getSourceVectorType();
5998  VectorType resultType = getResultVectorType();
5999  int64_t rank = resultType.getRank();
6000  if (vectorType.getRank() != rank)
6001  return emitOpError("vector result rank mismatch: ") << rank;
6002  // Verify transposition array.
6003  ArrayRef<int64_t> perm = getPermutation();
6004  int64_t size = perm.size();
6005  if (rank != size)
6006  return emitOpError("transposition length mismatch: ") << size;
6007  SmallVector<bool, 8> seen(rank, false);
6008  for (const auto &ta : llvm::enumerate(perm)) {
6009  if (ta.value() < 0 || ta.value() >= rank)
6010  return emitOpError("transposition index out of range: ") << ta.value();
6011  if (seen[ta.value()])
6012  return emitOpError("duplicate position index: ") << ta.value();
6013  seen[ta.value()] = true;
6014  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6015  return emitOpError("dimension size mismatch at: ") << ta.value();
6016  }
6017  return success();
6018 }
6019 
6020 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6021  return llvm::to_vector<4>(getResultVectorType().getShape());
6022 }
6023 
6024 namespace {
6025 
6026 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
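// Illustrative example (not from the original source; values are hypothetical):
//   %t0 = vector.transpose %v,  [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
//   %t1 = vector.transpose %t0, [2, 0, 1] : vector<3x4x2xf32> to vector<2x3x4xf32>
// is rewritten to a single transpose with the composed permutation [0, 1, 2].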
6027 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6028 public:
6029  using OpRewritePattern::OpRewritePattern;
6030 
6031  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6032  PatternRewriter &rewriter) const override {
6033  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6034  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6035  ArrayRef<int64_t> permutation2) {
6036  SmallVector<int64_t, 4> result;
6037  for (auto index : permutation2)
6038  result.push_back(permutation1[index]);
6039  return result;
6040  };
6041 
6042  // Return if the input of 'transposeOp' is not defined by another transpose.
6043  vector::TransposeOp parentTransposeOp =
6044  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6045  if (!parentTransposeOp)
6046  return failure();
6047 
6048  SmallVector<int64_t, 4> permutation = composePermutations(
6049  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6050  // Replace 'transposeOp' with a new transpose operation.
6051  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6052  transposeOp, transposeOp.getResult().getType(),
6053  parentTransposeOp.getVector(), permutation);
6054  return success();
6055  }
6056 };
6057 
6058 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
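// Illustrative example (not from the original source):
//   %b = vector.broadcast %scalar : f32 to vector<4x8xf32>
//   %t = vector.transpose %b, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
// folds to
//   %t = vector.broadcast %scalar : f32 to vector<8x4xf32>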
6059 struct FoldTransposedScalarBroadcast final
6060  : public OpRewritePattern<vector::TransposeOp> {
6061  using OpRewritePattern::OpRewritePattern;
6062 
6063  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6064  PatternRewriter &rewriter) const override {
6065  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6066  if (!bcastOp)
6067  return failure();
6068 
6069  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6070  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6071  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6072  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6073  return success();
6074  }
6075 
6076  return failure();
6077  }
6078 };
6079 
6080 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
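// Illustrative example (not from the original source):
//   %s = vector.splat %x : vector<4x2xf32>
//   %t = vector.transpose %s, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// folds to
//   %t = vector.splat %x : vector<2x4xf32>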
6081 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6082 public:
6083  using OpRewritePattern::OpRewritePattern;
6084 
6085  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6086  PatternRewriter &rewriter) const override {
6087  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6088  if (!splatOp)
6089  return failure();
6090 
6091  rewriter.replaceOpWithNewOp<vector::SplatOp>(
6092  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6093  return success();
6094  }
6095 };
6096 
6097 /// Folds transpose(create_mask) into a new transposed create_mask.
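/// Illustrative example (not from the original source; %a and %b are hypothetical):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// folds to
///   %t = vector.create_mask %b, %a : vector<8x4xi1>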
6098 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6099 public:
6100  using OpRewritePattern::OpRewritePattern;
6101 
6102  LogicalResult matchAndRewrite(TransposeOp transpOp,
6103  PatternRewriter &rewriter) const override {
6104  Value transposeSrc = transpOp.getVector();
6105  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6106  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6107  if (!createMaskOp && !constantMaskOp)
6108  return failure();
6109 
6110  // Get the transpose permutation and apply it to the vector.create_mask or
6111  // vector.constant_mask operands.
6112  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6113 
6114  if (createMaskOp) {
6115  auto maskOperands = createMaskOp.getOperands();
6116  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6117  applyPermutationToVector(newOperands, permutation);
6118 
6119  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6120  transpOp, transpOp.getResultVectorType(), newOperands);
6121  return success();
6122  }
6123 
6124  // ConstantMaskOp case.
6125  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6126  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6127 
6128  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6129  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6130  return success();
6131  }
6132 };
6133 
6134 } // namespace
6135 
6136 void vector::TransposeOp::getCanonicalizationPatterns(
6137  RewritePatternSet &results, MLIRContext *context) {
6138  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6139  TransposeFolder, FoldTransposeSplat>(context);
6140 }
6141 
6142 //===----------------------------------------------------------------------===//
6143 // ConstantMaskOp
6144 //===----------------------------------------------------------------------===//
6145 
6146 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6147  VectorType type, ConstantMaskKind kind) {
6148  assert(kind == ConstantMaskKind::AllTrue ||
6149  kind == ConstantMaskKind::AllFalse);
6150  build(builder, result, type,
6151  kind == ConstantMaskKind::AllTrue
6152  ? type.getShape()
6153  : SmallVector<int64_t>(type.getRank(), 0));
6154 }
6155 
6156 LogicalResult ConstantMaskOp::verify() {
6157  auto resultType = llvm::cast<VectorType>(getResult().getType());
6158  // Check the corner case of 0-D vectors first.
6159  if (resultType.getRank() == 0) {
6160  if (getMaskDimSizes().size() != 1)
6161  return emitError("array attr must have length 1 for 0-D vectors");
6162  auto dim = getMaskDimSizes()[0];
6163  if (dim != 0 && dim != 1)
6164  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6165  return success();
6166  }
6167 
6168  // Verify that array attr size matches the rank of the vector result.
6169  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6170  return emitOpError(
6171  "must specify array attr of size equal vector result rank");
6172  // Verify that each array attr element is in bounds of corresponding vector
6173  // result dimension size.
6174  auto resultShape = resultType.getShape();
6175  auto resultScalableDims = resultType.getScalableDims();
6176  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6177  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6178  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6179  return emitOpError(
6180  "array attr of size out of bounds of vector result dimension size");
6181  if (resultScalableDims[index] && maskDimSize != 0 &&
6182  maskDimSize != resultShape[index])
6183  return emitOpError(
6184  "only supports 'none set' or 'all set' scalable dimensions");
6185  }
6186  // Verify that if any mask dim size is zero, then all of them are zero (because
6187  // the mask region is the conjunction of the per-dimension intervals).
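 // E.g. (illustrative) a constant_mask with sizes [2, 0] over vector<4x3xi1>
 // selects no elements, so it must be written as [0, 0].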
6188  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6189  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6190  if (anyZeros && !allZeros)
6191  return emitOpError("expected all mask dim sizes to be zeros, "
6192  "as a result of conjunction with zero mask dim");
6193  return success();
6194 }
6195 
6196 bool ConstantMaskOp::isAllOnesMask() {
6197  auto resultType = getVectorType();
6198  // Check the corner case of 0-D vectors first.
6199  if (resultType.getRank() == 0) {
6200  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6201  return getMaskDimSizes()[0] == 1;
6202  }
6203  for (const auto [resultSize, maskDimSize] :
6204  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6205  if (maskDimSize < resultSize)
6206  return false;
6207  }
6208  return true;
6209 }
6210 
6211 //===----------------------------------------------------------------------===//
6212 // CreateMaskOp
6213 //===----------------------------------------------------------------------===//
6214 
6215 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6216  VectorType type,
6217  ArrayRef<OpFoldResult> mixedOperands) {
6218  SmallVector<Value> operands =
6219  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6220  build(builder, result, type, operands);
6221 }
6222 
6223 LogicalResult CreateMaskOp::verify() {
6224  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6225  // Verify that an operand was specified for each result vector dimension.
6226  if (vectorType.getRank() == 0) {
6227  if (getNumOperands() != 1)
6228  return emitOpError(
6229  "must specify exactly one operand for 0-D create_mask");
6230  } else if (getNumOperands() !=
6231  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6232  return emitOpError(
6233  "must specify an operand for each result vector dimension");
6234  }
6235  return success();
6236 }
6237 
6238 namespace {
6239 
6240 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6241 ///
6242 /// Ex 1:
6243 /// %c2 = arith.constant 2 : index
6244 /// %c3 = arith.constant 3 : index
6245 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6246 /// Becomes:
6247 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6248 ///
6249 /// Ex 2:
6250 /// %c_neg_1 = arith.constant -1 : index
6251 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6252 /// becomes:
6253 /// vector.constant_mask [0] : vector<[8]xi1>
6254 ///
6255 /// Ex 3:
6256 /// %c8 = arith.constant 8 : index
6257 /// %c16 = arith.constant 16 : index
6258 /// %0 = vector.vscale
6259 /// %1 = arith.muli %0, %c16 : index
6260 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6261 /// becomes:
6262 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6263 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6264 public:
6265  using OpRewritePattern::OpRewritePattern;
6266 
6267  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6268  PatternRewriter &rewriter) const override {
6269  VectorType maskType = createMaskOp.getVectorType();
6270  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6271  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6272 
6273  // Special case: Rank zero shape.
6274  constexpr std::array<int64_t, 1> rankZeroShape{1};
6275  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6276  if (maskType.getRank() == 0) {
6277  maskTypeDimSizes = rankZeroShape;
6278  maskTypeDimScalableFlags = rankZeroScalableDims;
6279  }
6280 
6281  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6282  // collect the `constantDims` (for the ConstantMaskOp).
6283  SmallVector<int64_t, 4> constantDims;
6284  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6285  if (auto intSize = getConstantIntValue(dimSize)) {
6286  // Constant value.
6287  // If the mask dim is non-scalable this can be any value.
6288  // If the mask dim is scalable only zero (all-false) is supported.
6289  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6290  return failure();
6291  constantDims.push_back(*intSize);
6292  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6293  // Constant vscale multiple (e.g. 4 x vscale).
6294  // Must be all-true to fold to a ConstantMask.
6295  if (vscaleMultiplier < maskTypeDimSizes[i])
6296  return failure();
6297  constantDims.push_back(*vscaleMultiplier);
6298  } else {
6299  return failure();
6300  }
6301  }
6302 
6303  // Clamp values to constant_mask bounds.
6304  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6305  value = std::clamp<int64_t>(value, 0, maskDimSize);
6306 
6307  // If one of dim sizes is zero, set all dims to zero.
6308  if (llvm::is_contained(constantDims, 0))
6309  constantDims.assign(constantDims.size(), 0);
6310 
6311  // Replace 'createMaskOp' with ConstantMaskOp.
6312  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6313  constantDims);
6314  return success();
6315  }
6316 };
6317 
6318 } // namespace
6319 
6320 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6321  MLIRContext *context) {
6322  results.add<CreateMaskFolder>(context);
6323 }
6324 
6325 //===----------------------------------------------------------------------===//
6326 // MaskOp
6327 //===----------------------------------------------------------------------===//
6328 
6329 void MaskOp::build(
6330  OpBuilder &builder, OperationState &result, Value mask,
6331  Operation *maskableOp,
6332  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6333  assert(maskRegionBuilder &&
6334  "builder callback for 'maskRegion' must be present");
6335 
6336  result.addOperands(mask);
6337  OpBuilder::InsertionGuard guard(builder);
6338  Region *maskRegion = result.addRegion();
6339  builder.createBlock(maskRegion);
6340  maskRegionBuilder(builder, maskableOp);
6341 }
6342 
6343 void MaskOp::build(
6344  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6345  Value mask, Operation *maskableOp,
6346  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6347  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6348  maskRegionBuilder);
6349 }
6350 
6351 void MaskOp::build(
6352  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6353  Value mask, Value passthru, Operation *maskableOp,
6354  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6355  build(builder, result, mask, maskableOp, maskRegionBuilder);
6356  if (passthru)
6357  result.addOperands(passthru);
6358  result.addTypes(resultTypes);
6359 }
6360 
6361 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6362  // Create the op region.
6363  result.regions.reserve(1);
6364  Region &maskRegion = *result.addRegion();
6365 
6366  auto &builder = parser.getBuilder();
6367 
6368  // Parse all the operands.
6369  OpAsmParser::UnresolvedOperand mask;
6370  if (parser.parseOperand(mask))
6371  return failure();
6372 
6373  // Optional passthru operand.
6374  OpAsmParser::UnresolvedOperand passthru;
6375  ParseResult parsePassthru = parser.parseOptionalComma();
6376  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6377  return failure();
6378 
6379  // Parse op region.
6380  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6381  return failure();
6382 
6383  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6384 
6385  // Parse the optional attribute list.
6386  if (parser.parseOptionalAttrDict(result.attributes))
6387  return failure();
6388 
6389  // Parse all the types.
6390  Type maskType;
6391  if (parser.parseColonType(maskType))
6392  return failure();
6393 
6394  SmallVector<Type> resultTypes;
6395  if (parser.parseOptionalArrowTypeList(resultTypes))
6396  return failure();
6397  result.types.append(resultTypes);
6398 
6399  // Resolve operands.
6400  if (parser.resolveOperand(mask, maskType, result.operands))
6401  return failure();
6402 
6403  if (parsePassthru.succeeded())
6404  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6405  return failure();
6406 
6407  return success();
6408 }
6409 
6410 void MaskOp::print(OpAsmPrinter &p) {
6411  p << " " << getMask();
6412  if (getPassthru())
6413  p << ", " << getPassthru();
6414 
6415  // Print single masked operation and skip terminator.
6416  p << " { ";
6417  Block *singleBlock = &getMaskRegion().getBlocks().front();
6418  if (singleBlock && !singleBlock->getOperations().empty())
6419  p.printCustomOrGenericOp(&singleBlock->front());
6420  p << " }";
6421 
6422  p.printOptionalAttrDict(getOperation()->getAttrs());
6423 
6424  p << " : " << getMask().getType();
6425  if (getNumResults() > 0)
6426  p << " -> " << getResultTypes();
6427 }
6428 
6429 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6430  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6431  MaskOp>::ensureTerminator(region, builder, loc);
6432  // Keep the default yield terminator if the number of masked operations is not
6433  // the expected one. This case will trigger a verification failure.
6434  Block &block = region.front();
6435  if (block.getOperations().size() != 2)
6436  return;
6437 
6438  // Replace default yield terminator with a new one that returns the results
6439  // from the masked operation.
6440  OpBuilder opBuilder(builder.getContext());
6441  Operation *maskedOp = &block.front();
6442  Operation *oldYieldOp = &block.back();
6443  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6444 
6445  // Empty vector.mask op.
6446  if (maskedOp == oldYieldOp)
6447  return;
6448 
6449  opBuilder.setInsertionPoint(oldYieldOp);
6450  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6451  oldYieldOp->dropAllReferences();
6452  oldYieldOp->erase();
6453 }
6454 
6455 LogicalResult MaskOp::verify() {
6456  // Structural checks.
6457  Block &block = getMaskRegion().getBlocks().front();
6458  if (block.getOperations().empty())
6459  return emitOpError("expects a terminator within the mask region");
6460 
6461  unsigned numMaskRegionOps = block.getOperations().size();
6462  if (numMaskRegionOps > 2)
6463  return emitOpError("expects only one operation to mask");
6464 
6465  // Terminator checks.
6466  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6467  if (!terminator)
6468  return emitOpError("expects a terminator within the mask region");
6469 
6470  if (terminator->getNumOperands() != getNumResults())
6471  return emitOpError(
6472  "expects number of results to match mask region yielded values");
6473 
6474  // Empty vector.mask. Nothing else to check.
6475  if (numMaskRegionOps == 1)
6476  return success();
6477 
6478  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6479  if (!maskableOp)
6480  return emitOpError("expects a MaskableOpInterface within the mask region");
6481 
6482  // Result checks.
6483  if (maskableOp->getNumResults() != getNumResults())
6484  return emitOpError("expects number of results to match maskable operation "
6485  "number of results");
6486 
6487  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6488  return emitOpError(
6489  "expects result type to match maskable operation result type");
6490 
6491  if (llvm::count_if(maskableOp->getResultTypes(),
6492  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6493  return emitOpError("multiple vector results not supported");
6494 
6495  // Mask checks.
6496  Type expectedMaskType = maskableOp.getExpectedMaskType();
6497  if (getMask().getType() != expectedMaskType)
6498  return emitOpError("expects a ")
6499  << expectedMaskType << " mask for the maskable operation";
6500 
6501  // Passthru checks.
6502  Value passthru = getPassthru();
6503  if (passthru) {
6504  if (!maskableOp.supportsPassthru())
6505  return emitOpError(
6506  "doesn't expect a passthru argument for this maskable operation");
6507 
6508  if (maskableOp->getNumResults() != 1)
6509  return emitOpError("expects result when passthru argument is provided");
6510 
6511  if (passthru.getType() != maskableOp->getResultTypes()[0])
6512  return emitOpError("expects passthru type to match result type");
6513  }
6514 
6515  return success();
6516 }
6517 
6518 /// Folds vector.mask ops with an all-true mask.
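/// Illustrative example (not from the original source; %v is hypothetical):
///   %m = vector.constant_mask [8] : vector<8xi1>
///   %r = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
///          : vector<8xi1> -> f32
/// folds so that the masked op runs unmasked:
///   %r = vector.reduction <add>, %v : vector<8xf32> into f32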
6519 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6520  SmallVectorImpl<OpFoldResult> &results) {
6521  MaskFormat maskFormat = getMaskFormat(getMask());
6522  if (isEmpty())
6523  return failure();
6524 
6525  if (maskFormat != MaskFormat::AllTrue)
6526  return failure();
6527 
6528  // Move maskable operation outside of the `vector.mask` region.
6529  Operation *maskableOp = getMaskableOp();
6530  maskableOp->dropAllUses();
6531  maskableOp->moveBefore(getOperation());
6532 
6533  llvm::append_range(results, maskableOp->getResults());
6534  return success();
6535 }
6536 
6537 // Elides empty vector.mask operations with or without return values. Propagates
6538 // the values yielded by the vector.yield terminator, if any, or erases the op
6539 // otherwise.
6540 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6541  using OpRewritePattern::OpRewritePattern;
6542 
6543  LogicalResult matchAndRewrite(MaskOp maskOp,
6544  PatternRewriter &rewriter) const override {
6545  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6546  if (maskingOp.getMaskableOp())
6547  return failure();
6548 
6549  if (!maskOp.isEmpty())
6550  return failure();
6551 
6552  Block *block = maskOp.getMaskBlock();
6553  auto terminator = cast<vector::YieldOp>(block->front());
6554  if (terminator.getNumOperands() == 0)
6555  rewriter.eraseOp(maskOp);
6556  else
6557  rewriter.replaceOp(maskOp, terminator.getOperands());
6558 
6559  return success();
6560  }
6561 };
6562 
6563 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6564  MLIRContext *context) {
6565  results.add<ElideEmptyMaskOp>(context);
6566 }
6567 
6568 // MaskingOpInterface definitions.
6569 
6570 /// Returns the operation masked by this 'vector.mask'.
6571 Operation *MaskOp::getMaskableOp() {
6572  Block *block = getMaskBlock();
6573  if (block->getOperations().size() < 2)
6574  return nullptr;
6575 
6576  return &block->front();
6577 }
6578 
6579 /// Returns true if 'vector.mask' has a passthru value.
6580 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6581 
6582 //===----------------------------------------------------------------------===//
6583 // ScanOp
6584 //===----------------------------------------------------------------------===//
6585 
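// Illustrative example of an op this verifier accepts (not from the original
// source; values are hypothetical):
//   %r:2 = vector.scan <add>, %src, %init {inclusive = true, reduction_dim = 1 : i64}
//            : vector<4x8xf32>, vector<4xf32>
// Here rank(%init) == rank(%src) - 1 and the shapes agree outside dim 1.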
6586 LogicalResult ScanOp::verify() {
6587  VectorType srcType = getSourceType();
6588  VectorType initialType = getInitialValueType();
6589  // Check reduction dimension < rank.
6590  int64_t srcRank = srcType.getRank();
6591  int64_t reductionDim = getReductionDim();
6592  if (reductionDim >= srcRank)
6593  return emitOpError("reduction dimension ")
6594  << reductionDim << " has to be less than " << srcRank;
6595 
6596  // Check that rank(initial_value) = rank(src) - 1.
6597  int64_t initialValueRank = initialType.getRank();
6598  if (initialValueRank != srcRank - 1)
6599  return emitOpError("initial value rank ")
6600  << initialValueRank << " has to be equal to " << srcRank - 1;
6601 
6602  // Check shapes of initial value and src.
6603  ArrayRef<int64_t> srcShape = srcType.getShape();
6604  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6605  SmallVector<int64_t> expectedShape;
6606  for (int i = 0; i < srcRank; i++) {
6607  if (i != reductionDim)
6608  expectedShape.push_back(srcShape[i]);
6609  }
6610  if (!llvm::equal(initialValueShapes, expectedShape)) {
6611  return emitOpError("incompatible input/initial value shapes");
6612  }
6613 
6614  // Verify supported reduction kind.
6615  Type eltType = getDestType().getElementType();
6616  if (!isSupportedCombiningKind(getKind(), eltType))
6617  return emitOpError("unsupported reduction type ")
6618  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6619  << "'";
6620 
6621  return success();
6622 }
6623 
6624 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6625  RewritePatternSet &patterns, PatternBenefit benefit) {
6626  patterns
6627  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6628  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6629  StridedSliceConstantMaskFolder, TransposeFolder>(
6630  patterns.getContext(), benefit);
6631 }
6632 
6633 //===----------------------------------------------------------------------===//
6634 // SplatOp
6635 //===----------------------------------------------------------------------===//
6636 
6637 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6638  auto constOperand = adaptor.getInput();
6639  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6640  return {};
6641 
6642  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6643  return SplatElementsAttr::get(getType(), {constOperand});
6644 }
6645 
6646 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6647  SetIntRangeFn setResultRanges) {
6648  setResultRanges(getResult(), argRanges.front());
6649 }
6650 
6651 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6652  CombiningKind kind, Value v1, Value acc,
6653  arith::FastMathFlagsAttr fastmath,
6654  Value mask) {
6655  Type t1 = getElementTypeOrSelf(v1.getType());
6656  Type tAcc = getElementTypeOrSelf(acc.getType());
6657  Value result;
6658 
6659  switch (kind) {
6660  case CombiningKind::ADD:
6661  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6662  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6663  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6664  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6665  else
6666  llvm_unreachable("invalid value types for ADD reduction");
6667  break;
6668  case CombiningKind::AND:
6669  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6670  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6671  break;
6672  case CombiningKind::MAXNUMF:
6673  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6674  "expected float values");
6675  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6676  break;
6677  case CombiningKind::MAXIMUMF:
6678  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6679  "expected float values");
6680  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6681  break;
6682  case CombiningKind::MINNUMF:
6683  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6684  "expected float values");
6685  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6686  break;
6687  case CombiningKind::MINIMUMF:
6688  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6689  "expected float values");
6690  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6691  break;
6692  case CombiningKind::MAXSI:
6693  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6694  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6695  break;
6696  case CombiningKind::MINSI:
6697  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6698  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6699  break;
6700  case CombiningKind::MAXUI:
6701  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6702  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6703  break;
6704  case CombiningKind::MINUI:
6705  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6706  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6707  break;
6708  case CombiningKind::MUL:
6709  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6710  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6711  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6712  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6713  else
6714  llvm_unreachable("invalid value types for MUL reduction");
6715  break;
6716  case CombiningKind::OR:
6717  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6718  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6719  break;
6720  case CombiningKind::XOR:
6721  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6722  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6723  break;
6724  };
6725 
6726  assert(result && "unknown CombiningKind");
6727  return selectPassthru(b, mask, result, acc);
6728 }
6729 
6730 //===----------------------------------------------------------------------===//
6731 // Vector Masking Utilities
6732 //===----------------------------------------------------------------------===//
6733 
6734 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6735 /// as masked operation.
6736 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
6737  Operation *maskableOp) {
6738  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6739  Block *insBlock = builder.getInsertionBlock();
6740  // Create a block and move the op to that block.
6741  insBlock->getOperations().splice(
6742  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6743  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6744 }
6745 
6746 /// Creates a vector.mask operation around a maskable operation. Returns the
6747 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6748 /// the maskable operation itself.
6749 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6750  Operation *maskableOp, Value mask,
6751  Value passthru) {
6752  if (!mask)
6753  return maskableOp;
6754  if (passthru)
6755  return builder.create<MaskOp>(maskableOp->getLoc(),
6756  maskableOp->getResultTypes(), mask, passthru,
6757  maskableOp, createMaskOpRegion);
6758  return builder.create<MaskOp>(maskableOp->getLoc(),
6759  maskableOp->getResultTypes(), mask, maskableOp,
6760  createMaskOpRegion);
6761 }
6762 
6763 /// Creates a vector select operation that picks values from `newValue` or
6764 /// `passthru` for each result vector lane based on `mask`. This utility is used
6765 /// to propagate the pass-thru value of vector.mask or for cases where only the
6766 /// pass-thru value propagation is needed. VP intrinsics do not support
6767 /// pass-thru values and every masked-out lane is set to poison. LLVM backends are
6768 /// usually able to match op + select patterns and fold them into native
6769 /// target instructions.
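/// Illustrative example of the op this creates (not from the original source):
///   %res = arith.select %mask, %newValue, %passthru : vector<8xi1>, vector<8xf32>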
6770 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6771  Value newValue, Value passthru) {
6772  if (!mask)
6773  return newValue;
6774 
6775  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6776  mask, newValue, passthru);
6777 }
6778 
6779 //===----------------------------------------------------------------------===//
6780 // TableGen'd op method definitions
6781 //===----------------------------------------------------------------------===//
6782 
6783 #define GET_ATTRDEF_CLASSES
6784 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6785 
6786 #define GET_OP_CLASSES
6787 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static SmallVector< Value > computeStrides(Location loc, RewriterBase &rewriter, ValueRange dynamicBasis, ArrayRef< int64_t > staticBasis)
Given a basis (in static and dynamic components), return the sequence of suffix products of the basis...
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
union mlir::linalg::@1183::ArityGroupAndKind::Kind kind
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition: VectorOps.cpp:68
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
Definition: VectorOps.cpp:1086
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1391
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
Definition: VectorOps.cpp:1652
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:4191
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
Definition: VectorOps.cpp:1958
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1826
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1663
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
Definition: VectorOps.cpp:2042
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:2316
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:130
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
Definition: VectorOps.cpp:2032
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:58
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:3219
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
Definition: VectorOps.cpp:296
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:868
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
Definition: VectorOps.cpp:2365
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
Definition: VectorOps.cpp:1994
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:2683
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:3155
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:1078
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:3175
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
Definition: VectorOps.cpp:175
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:3198
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
Definition: VectorOps.cpp:3031
static LogicalResult foldTransferFullMask(TransferOp op)
Definition: VectorOps.cpp:4425
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1383
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
Definition: VectorOps.cpp:2339
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
Definition: VectorOps.cpp:1278
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:4083
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:879
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:3140
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1761
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:4112
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:4353
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:3559
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
Definition: VectorOps.cpp:1729
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:4370
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1878
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
Definition: VectorOps.cpp:2050
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:216
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:76
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation & back()
Definition: Block.h:152
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:266
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:584
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isF32() const
Definition: Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:282
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:289
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta between the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:390
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
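For illustration, a minimal sketch of a call (hypothetical `builder`, `loc`, `lhs`, and `acc` values are assumed to already exist) that combines two f32 vectors with an additive reduction:
  // Sketch only: `builder`, `loc`, `lhs`, and `acc` are assumed to be in scope.
  Value sum = vector::makeArithReduction(builder, loc,
                                         vector::CombiningKind::ADD, lhs, acc);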
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:450
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:125
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
Definition: VectorOps.cpp:354
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:153
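A sketch of what this produces, assuming an MLIRContext *ctx is available: for a rank-3 memref and a rank-2 vector the default map is the minor identity.
  auto f32 = Float32Type::get(ctx);                    // assumed context
  auto memrefType = MemRefType::get({4, 16, 8}, f32);
  auto vecType = VectorType::get({16, 8}, f32);
  AffineMap map = vector::getTransferMinorIdentityMap(memrefType, vecType);
  // map is (d0, d1, d2) -> (d1, d2).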
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition: VectorOps.h:60
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
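A minimal usage sketch (hypothetical `builder`, maskable `reductionOp`, and i1-vector `mask` assumed):
  // Wraps the maskable op in a vector.mask region and returns the mask op.
  Operation *masked = vector::maskOperation(builder, reductionOp, mask);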
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2497
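A sketch of a typical query, assuming an MLIRContext *ctx: a scalar f32 broadcasts to vector<4x8xf32>.
  auto f32 = Float32Type::get(ctx);
  auto dstType = VectorType::get({4, 8}, f32);
  bool ok = vector::isBroadcastableTo(f32, dstType) ==
            vector::BroadcastableToResult::Success;  // expected to hold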
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4210
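Sketch, assuming an MLIRContext *ctx: with a transposing permutation map the mask shape follows the memref iteration order rather than the vector shape.
  auto f32 = Float32Type::get(ctx);
  auto vecType = VectorType::get({4, 8}, f32);
  AffineMap permMap = AffineMap::getPermutationMap({1, 0}, ctx);  // (d0, d1) -> (d1, d0)
  VectorType maskType = vector::inferTransferOpMaskType(vecType, permMap);
  // maskType is vector<8x4xi1>.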
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
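A one-line sketch (hypothetical `builder`, `mask`, `newValue`, and `passthru` values assumed): lanes where the mask is set come from newValue, the remaining lanes from passthru.
  Value result = vector::selectPassthru(builder, mask, newValue, passthru);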
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
Definition: VectorOps.cpp:219
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:283
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-writes the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:314
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as the masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:338
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:626
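Sketch of a call (hypothetical `builder`, `loc`, and a vector value `vec` assumed): reduce a vector with the arith kind corresponding to floating-point addition.
  Value reduced = vector::getVectorReductionOp(arith::AtomicRMWKind::addf,
                                               builder, loc, vec);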
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:68
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:446
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
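Typical use, sketched with a hypothetical OpFoldResult `ofr`:
  if (std::optional<int64_t> cst = getConstantIntValue(ofr))
    llvm::errs() << "constant value: " << *cst << "\n";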
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
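For example, the row-major strides of a 2x3x4 shape:
  SmallVector<int64_t> strides = computeStrides({2, 3, 4});  // {12, 4, 1}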
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
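Sketch, assuming an MLIRContext *ctx: inverting a cyclic permutation map.
  AffineMap fwd = AffineMap::getPermutationMap({2, 0, 1}, ctx);  // (d0, d1, d2) -> (d2, d0, d1)
  AffineMap inv = inversePermutation(fwd);                       // (d0, d1, d2) -> (d1, d2, d0)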
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:930
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using classes in the detail namespace.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:419
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1179
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1182
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:368
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:341
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:381
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:383
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.