VectorOps.cpp
1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Support/LLVM.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify a mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of a constant value with dense attributes
66 /// and a constant mask operation (since the client may be called at
67 /// various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayRef<int64_t> masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  if (maskIdx < dimSize)
97  allTrue = false;
98  if (maskIdx > 0)
99  allFalse = false;
100  }
101  if (allTrue)
102  return MaskFormat::AllTrue;
103  if (allFalse)
104  return MaskFormat::AllFalse;
105  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
106  // Finds all-false create_masks. An all-true create_mask requires all
107  // dims to be constants, so that'll be folded to a constant_mask, then
108  // detected in the constant_mask case.
109  auto maskOperands = m.getOperands();
110  for (Value operand : maskOperands) {
111  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
112  int64_t dimSize =
113  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
114  if (dimSize <= 0)
115  return MaskFormat::AllFalse;
116  }
117  }
118  return MaskFormat::Unknown;
119  }
120  return MaskFormat::Unknown;
121 }
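// Illustrative examples (not part of the original source) of how the helper
// above classifies masks; the SSA names and shapes are hypothetical:
//
//   %m = arith.constant dense<true> : vector<4xi1>   // MaskFormat::AllTrue
//   %m = vector.constant_mask [0] : vector<4xi1>     // MaskFormat::AllFalse
//   %m = vector.create_mask %c0 : vector<4xi1>       // AllFalse if %c0 is a
//                                                    // constant <= 0
//   %m = vector.create_mask %n : vector<4xi1>        // MaskFormat::Unknown
//                                                    // for a dynamic %n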
122 
123 /// Default callback to build a region with a 'vector.yield' terminator with no
124 /// arguments.
125 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
126  builder.create<vector::YieldOp>(loc);
127 }
128 
129 // Helper for verifying combining kinds in contractions and reductions.
130 static bool isSupportedCombiningKind(CombiningKind combiningKind,
131  Type elementType) {
132  switch (combiningKind) {
133  case CombiningKind::ADD:
134  case CombiningKind::MUL:
135  return elementType.isIntOrIndexOrFloat();
136  case CombiningKind::MINUI:
137  case CombiningKind::MINSI:
138  case CombiningKind::MAXUI:
139  case CombiningKind::MAXSI:
140  case CombiningKind::AND:
141  case CombiningKind::OR:
142  case CombiningKind::XOR:
143  return elementType.isIntOrIndex();
144  case CombiningKind::MINNUMF:
145  case CombiningKind::MAXNUMF:
146  case CombiningKind::MINIMUMF:
147  case CombiningKind::MAXIMUMF:
148  return llvm::isa<FloatType>(elementType);
149  }
150  return false;
151 }
152 
153 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154  VectorType vectorType) {
155  int64_t elementVectorRank = 0;
156  VectorType elementVectorType =
157  llvm::dyn_cast<VectorType>(shapedType.getElementType());
158  if (elementVectorType)
159  elementVectorRank += elementVectorType.getRank();
160  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161  // TODO: replace once we have 0-d vectors.
162  if (shapedType.getRank() == 0 &&
163  vectorType.getShape() == ArrayRef<int64_t>{1})
164  return AffineMap::get(
165  /*numDims=*/0, /*numSymbols=*/0,
166  getAffineConstantExpr(0, shapedType.getContext()));
167  return AffineMap::getMinorIdentityMap(
168  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169  shapedType.getContext());
170 }
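// A minimal worked example (assumed shapes, not from the original source): for
// a transfer between memref<10x20x30xf32> and vector<20x30xf32>, the minor
// identity map returned above is (d0, d1, d2) -> (d1, d2), i.e. the vector is
// mapped onto the two minor-most dimensions of the memref.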
171 
172 /// Check if `write` is of a constant splat and the masked `read` is padded with
173 /// the same splat value -- meaning it could be the same value as the initial
174 /// constant splat.
175 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
176  vector::TransferReadOp read) {
177  auto readMask = read.getMask();
178  auto writeMask = write.getMask();
179  // Check if the masks are consistent. The splat value could be the same if the
180  // read is masked (and padded with the splat value), and the write is unmasked
181  // or has the same mask. Note this does not allow the case where the write is
182  // masked and the read is unmasked, as then the read could be of more elements
183  // than the write (which may not be the same value).
184  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185  if (!couldBeSameSplat)
186  return false;
187  // Check for constant splat (as the source of the write).
188  DenseElementsAttr splatAttr;
189  if (!matchPattern(write.getVector(),
190  m_Constant<DenseElementsAttr>(&splatAttr)) ||
191  !splatAttr.isSplat()) {
192  return false;
193  }
194  // The padding of the read and the constant splat value must be the same.
195  Attribute padAttr;
196  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
197  return false;
198  return padAttr == splatAttr.getSplatValue<Attribute>();
199 }
200 
201 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
202  vector::TransferReadOp read) {
203  return !defWrite.hasOutOfBoundsDim() &&
204  defWrite.getIndices() == read.getIndices() &&
205  defWrite.getVectorType() == read.getVectorType() &&
206  defWrite.getPermutationMap() == read.getPermutationMap() &&
207  ((!defWrite.getMask() && !read.getMask()) ||
208  isSplatWriteConsistentWithMaskedRead(defWrite, read));
209 }
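// A hedged sketch of the read-after-write pair this predicate is meant to
// recognize (illustrative IR; value names and shapes are assumptions):
//
//   %w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x8xf32>, tensor<4x8xf32>
//   %r = vector.transfer_read %w[%c0, %c0], %pad {in_bounds = [true, true]}
//       : tensor<4x8xf32>, vector<4x8xf32>
//
// With identical indices, types, and permutation maps and no masks (or a
// masked read padded with the splat value that was written), %r reads exactly
// the value %v that was written.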
210 
211 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
212  vector::TransferWriteOp priorWrite) {
213  return priorWrite.getIndices() == write.getIndices() &&
214  priorWrite.getMask() == write.getMask() &&
215  priorWrite.getVectorType() == write.getVectorType() &&
216  priorWrite.getPermutationMap() == write.getPermutationMap();
217 }
218 
219 bool mlir::vector::isDisjointTransferIndices(
220  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221  bool testDynamicValueUsingBounds) {
222  // For simplicity only look at transfer of same type.
223  if (transferA.getVectorType() != transferB.getVectorType())
224  return false;
225  unsigned rankOffset = transferA.getLeadingShapedRank();
226  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227  Value indexA = transferA.getIndices()[i];
228  Value indexB = transferB.getIndices()[i];
229  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
230  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
231 
232  if (i < rankOffset) {
233  // For leading dimensions, if we can prove that the indices are different,
234  // we know we are accessing disjoint slices.
235  if (cstIndexA.has_value() && cstIndexB.has_value()) {
236  if (*cstIndexA != *cstIndexB)
237  return true;
238  continue;
239  }
240  if (testDynamicValueUsingBounds) {
241  // First try to see if we can fully compose and simplify the affine
242  // expression as a fast track.
243  FailureOr<uint64_t> delta =
244  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
245  if (succeeded(delta) && *delta != 0)
246  return true;
247 
248  FailureOr<bool> testEqual =
249  ValueBoundsConstraintSet::areEqual(indexA, indexB);
250  if (succeeded(testEqual) && !testEqual.value())
251  return true;
252  }
253  } else {
254  // For this dimension, we slice a part of the memref, so we need to make
255  // sure the intervals accessed don't overlap.
256  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257  if (cstIndexA.has_value() && cstIndexB.has_value()) {
258  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
259  if (distance >= vectorDim)
260  return true;
261  continue;
262  }
263  if (testDynamicValueUsingBounds) {
264  // First try to see if we can fully compose and simplify the affine
265  // expression as a fast track.
266  FailureOr<int64_t> delta =
267  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
268  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
269  return true;
270 
271  FailureOr<int64_t> computeDelta =
272  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
273  if (succeeded(computeDelta)) {
274  if (std::abs(computeDelta.value()) >= vectorDim)
275  return true;
276  }
277  }
278  }
279  }
280  return false;
281 }
282 
283 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
284  VectorTransferOpInterface transferB,
285  bool testDynamicValueUsingBounds) {
286  if (transferA.getSource() != transferB.getSource())
287  return false;
288  return isDisjointTransferIndices(transferA, transferB,
289  testDynamicValueUsingBounds);
290 }
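// Illustrative example (assumed shapes and indices, not from the original
// source): two transfers of vector<4xf32> on the same memref at constant
// offsets 0 and 8 along the sliced dimension are disjoint because |0 - 8| >= 4;
// with offsets 0 and 2 the intervals may overlap, so the functions above
// conservatively return false.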
291 
292 // Helper to iterate over n-D vector slice elements. Calculate the next
293 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
294 // Modifies the `position` in place. Returns a failure when `position` becomes
295 // the end position.
296 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
297  ArrayRef<int64_t> shape,
298  ArrayRef<int64_t> offsets) {
299  for (auto [posInDim, dimSize, offsetInDim] :
300  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
301  ++posInDim;
302  if (posInDim < dimSize + offsetInDim)
303  return success();
304 
305  // Carry the overflow to the next loop iteration.
306  posInDim = offsetInDim;
307  }
308 
309  return failure();
310 }
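// Worked example (illustrative): with shape = [2, 3] and offsets = [0, 0], the
// helper above advances position as [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] ->
// [1, 1] -> [1, 2], and signals failure once the end position is reached.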
311 
312 /// Returns the integer numbers in `values`. `values` are expected to be
313 /// constant operations.
314 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
315  SmallVector<int64_t> ints;
316  llvm::transform(values, std::back_inserter(ints), [](Value value) {
317  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
318  assert(constOp && "Unexpected non-constant index");
319  return constOp.value();
320  });
321  return ints;
322 }
323 
324 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
325 /// be constant integer attributes.
326 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
327  SmallVector<int64_t> ints;
328  llvm::transform(
329  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
330  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
331  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
332  });
333  return ints;
334 }
335 
336 /// Convert `foldResults` into Values. Integer attributes are converted to
337 /// constant op.
338 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
339  ArrayRef<OpFoldResult> foldResults) {
340  SmallVector<Value> values;
341  llvm::transform(foldResults, std::back_inserter(values),
342  [&](OpFoldResult foldResult) {
343  if (auto attr = foldResult.dyn_cast<Attribute>())
344  return builder
345  .create<arith::ConstantIndexOp>(
346  loc, cast<IntegerAttr>(attr).getInt())
347  .getResult();
348 
349  return cast<Value>(foldResult);
350  });
351  return values;
352 }
353 
354 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
355  if (value.getDefiningOp<vector::VectorScaleOp>())
356  return 1;
357  auto mul = value.getDefiningOp<arith::MulIOp>();
358  if (!mul)
359  return {};
360  auto lhs = mul.getLhs();
361  auto rhs = mul.getRhs();
362  if (lhs.getDefiningOp<vector::VectorScaleOp>())
363  return getConstantIntValue(rhs);
364  if (rhs.getDefiningOp<vector::VectorScaleOp>())
365  return getConstantIntValue(lhs);
366  return {};
367 }
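// Illustrative use of the helper above (SSA names are hypothetical): for
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
// getConstantVscaleMultiplier(%n) returns 4, and for %vs itself it returns 1.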
368 
369 //===----------------------------------------------------------------------===//
370 // CombiningKindAttr
371 //===----------------------------------------------------------------------===//
372 
373 namespace mlir {
374 namespace vector {
375 namespace detail {
376 struct BitmaskEnumStorage : public AttributeStorage {
377  using KeyTy = uint64_t;
378 
379  BitmaskEnumStorage(KeyTy val) : value(val) {}
380 
381  bool operator==(const KeyTy &key) const { return value == key; }
382 
383  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
384  const KeyTy &key) {
385  return new (allocator.allocate<BitmaskEnumStorage>())
386  BitmaskEnumStorage(key);
387  }
388 
389  KeyTy value = 0;
390 };
391 } // namespace detail
392 } // namespace vector
393 } // namespace mlir
394 
395 //===----------------------------------------------------------------------===//
396 // VectorDialect
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 /// This class defines the interface for handling inlining with vector dialect
401 /// operations.
402 struct VectorInlinerInterface : public DialectInlinerInterface {
403  using DialectInlinerInterface::DialectInlinerInterface;
404 
405  /// All vector dialect ops can be inlined.
406  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
407  return true;
408  }
409 };
410 } // namespace
411 
412 void VectorDialect::initialize() {
413  addAttributes<
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
416  >();
417 
418  addOperations<
419 #define GET_OP_LIST
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
421  >();
422 
423  addInterfaces<VectorInlinerInterface>();
424 
425  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
427  YieldOp>();
428  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
429  TransferWriteOp>();
430  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
433 }
434 
435 /// Materialize a single constant operation from a given attribute value with
436 /// the desired resultant type.
437 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
438  Attribute value, Type type,
439  Location loc) {
440  if (isa<ub::PoisonAttrInterface>(value))
441  return value.getDialect().materializeConstant(builder, value, type, loc);
442 
443  return arith::ConstantOp::materialize(builder, value, type, loc);
444 }
445 
446 Type vector::getVectorSubscriptType(Builder &builder) {
447  return builder.getIntegerType(64);
448 }
449 
450 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
451  ArrayRef<int64_t> values) {
452  return builder.getI64ArrayAttr(values);
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // MultiDimReductionOp
457 //===----------------------------------------------------------------------===//
458 
459 void vector::MultiDimReductionOp::build(OpBuilder &builder,
460  OperationState &result, Value source,
461  Value acc, ArrayRef<bool> reductionMask,
462  CombiningKind kind) {
463  SmallVector<int64_t> reductionDims;
464  for (const auto &en : llvm::enumerate(reductionMask))
465  if (en.value())
466  reductionDims.push_back(en.index());
467  build(builder, result, kind, source, acc, reductionDims);
468 }
469 
470 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
471  // Single parallel dim, this is a noop.
472  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
473  return getSource();
474  return {};
475 }
476 
477 std::optional<SmallVector<int64_t, 4>>
478 MultiDimReductionOp::getShapeForUnroll() {
479  return llvm::to_vector<4>(getSourceVectorType().getShape());
480 }
481 
482 LogicalResult MultiDimReductionOp::verify() {
483  SmallVector<int64_t> targetShape;
484  SmallVector<bool> scalableDims;
485  Type inferredReturnType;
486  auto sourceScalableDims = getSourceVectorType().getScalableDims();
487  for (auto [dimIdx, dimSize] :
488  llvm::enumerate(getSourceVectorType().getShape()))
489  if (!llvm::any_of(getReductionDims(),
490  [dimIdx = dimIdx](int64_t reductionDimIdx) {
491  return reductionDimIdx == static_cast<int64_t>(dimIdx);
492  })) {
493  targetShape.push_back(dimSize);
494  scalableDims.push_back(sourceScalableDims[dimIdx]);
495  }
496  // TODO: update to also allow 0-d vectors when available.
497  if (targetShape.empty())
498  inferredReturnType = getSourceVectorType().getElementType();
499  else
500  inferredReturnType = VectorType::get(
501  targetShape, getSourceVectorType().getElementType(), scalableDims);
502  if (getType() != inferredReturnType)
503  return emitOpError() << "destination type " << getType()
504  << " is incompatible with source type "
505  << getSourceVectorType();
506 
507  return success();
508 }
509 
510 /// Returns the mask type expected by this operation.
511 Type MultiDimReductionOp::getExpectedMaskType() {
512  auto vecType = getSourceVectorType();
513  return VectorType::get(vecType.getShape(),
514  IntegerType::get(vecType.getContext(), /*width=*/1),
515  vecType.getScalableDims());
516 }
517 
518 namespace {
519 // Only unit dimensions that are being reduced are folded. If the dimension is
520 // unit, but not reduced, it is not folded, thereby keeping the output type the
521 // same. If not all dimensions which are reduced are of unit dimension, this
522 // transformation does nothing. This is just a generalization of
523 // ElideSingleElementReduction for ReduceOp.
524 struct ElideUnitDimsInMultiDimReduction
525  : public OpRewritePattern<MultiDimReductionOp> {
526  using OpRewritePattern::OpRewritePattern;
527 
528  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
529  PatternRewriter &rewriter) const override {
530  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
531  for (const auto &dim : enumerate(shape)) {
532  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
533  return failure();
534  }
535 
536  // Vector mask setup.
537  OpBuilder::InsertionGuard guard(rewriter);
538  Operation *rootOp;
539  Value mask;
540  if (reductionOp.isMasked()) {
541  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
542  rootOp = reductionOp.getMaskingOp();
543  mask = reductionOp.getMaskingOp().getMask();
544  } else {
545  rootOp = reductionOp;
546  }
547 
548  Location loc = reductionOp.getLoc();
549  Value acc = reductionOp.getAcc();
550  Value cast;
551  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
552  if (mask) {
553  VectorType newMaskType =
554  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
555  dstVecType.getScalableDims());
556  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
557  }
558  cast = rewriter.create<vector::ShapeCastOp>(
559  loc, reductionOp.getDestType(), reductionOp.getSource());
560  } else {
561  // This means we are reducing all the dimensions, and all reduction
562  // dimensions are of size 1. So a simple extraction would do.
563  SmallVector<int64_t> zeroIdx(shape.size(), 0);
564  if (mask)
565  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
566  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
567  zeroIdx);
568  }
569 
570  Value result =
571  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
572  cast, /*fastmath=*/nullptr, mask);
573  rewriter.replaceOp(rootOp, result);
574  return success();
575  }
576 };
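// A hedged example of the rewrite performed by the pattern above (illustrative
// IR, assumed shapes):
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<4x1xf32> to vector<4xf32>
//
// Only the reduced dimension is a unit dim, so the reduction is replaced by a
// shape_cast of %src to vector<4xf32>, combined with %acc via arith.addf.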
577 } // namespace
578 
579 void MultiDimReductionOp::getCanonicalizationPatterns(
580  RewritePatternSet &results, MLIRContext *context) {
581  results.add<ElideUnitDimsInMultiDimReduction>(context);
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // ReductionOp
586 //===----------------------------------------------------------------------===//
587 
588 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
589  CombiningKind kind, Value vector,
590  arith::FastMathFlags fastMathFlags) {
591  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
592 }
593 
594 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
595  CombiningKind kind, Value vector, Value acc,
596  arith::FastMathFlags fastMathFlags) {
597  build(builder, result,
598  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
599  acc, fastMathFlags);
600 }
601 
602 LogicalResult ReductionOp::verify() {
603  // Verify for 0-D and 1-D vector.
604  int64_t rank = getSourceVectorType().getRank();
605  if (rank > 1)
606  return emitOpError("unsupported reduction rank: ") << rank;
607 
608  // Verify supported reduction kind.
609  Type eltType = getDest().getType();
610  if (!isSupportedCombiningKind(getKind(), eltType))
611  return emitOpError("unsupported reduction type '")
612  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
613  << "'";
614 
615  return success();
616 }
617 
618 // MaskableOpInterface methods.
619 
620 /// Returns the mask type expected by this operation.
621 Type ReductionOp::getExpectedMaskType() {
622  auto vecType = getSourceVectorType();
623  return VectorType::get(vecType.getShape(),
624  IntegerType::get(vecType.getContext(), /*width=*/1),
625  vecType.getScalableDims());
626 }
627 
628 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
629  OpBuilder &builder, Location loc,
630  Value vector) {
631  switch (op) {
632  case arith::AtomicRMWKind::addf:
633  case arith::AtomicRMWKind::addi:
634  return builder.create<vector::ReductionOp>(vector.getLoc(),
635  CombiningKind::ADD, vector);
636  case arith::AtomicRMWKind::mulf:
637  case arith::AtomicRMWKind::muli:
638  return builder.create<vector::ReductionOp>(vector.getLoc(),
639  CombiningKind::MUL, vector);
640  case arith::AtomicRMWKind::minimumf:
641  return builder.create<vector::ReductionOp>(vector.getLoc(),
642  CombiningKind::MINIMUMF, vector);
643  case arith::AtomicRMWKind::mins:
644  return builder.create<vector::ReductionOp>(vector.getLoc(),
645  CombiningKind::MINSI, vector);
646  case arith::AtomicRMWKind::minu:
647  return builder.create<vector::ReductionOp>(vector.getLoc(),
648  CombiningKind::MINUI, vector);
649  case arith::AtomicRMWKind::maximumf:
650  return builder.create<vector::ReductionOp>(vector.getLoc(),
651  CombiningKind::MAXIMUMF, vector);
652  case arith::AtomicRMWKind::maxs:
653  return builder.create<vector::ReductionOp>(vector.getLoc(),
654  CombiningKind::MAXSI, vector);
655  case arith::AtomicRMWKind::maxu:
656  return builder.create<vector::ReductionOp>(vector.getLoc(),
657  CombiningKind::MAXUI, vector);
658  case arith::AtomicRMWKind::andi:
659  return builder.create<vector::ReductionOp>(vector.getLoc(),
660  CombiningKind::AND, vector);
661  case arith::AtomicRMWKind::ori:
662  return builder.create<vector::ReductionOp>(vector.getLoc(),
663  CombiningKind::OR, vector);
664  // TODO: Add remaining reduction operations.
665  default:
666  (void)emitOptionalError(loc, "Reduction operation type not supported");
667  break;
668  }
669  return nullptr;
670 }
671 
672 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
673  return llvm::to_vector<4>(getSourceVectorType().getShape());
674 }
675 
676 namespace {
677 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
678  using OpRewritePattern::OpRewritePattern;
679 
680  LogicalResult matchAndRewrite(ReductionOp reductionOp,
681  PatternRewriter &rewriter) const override {
682  // Vector mask setup.
683  OpBuilder::InsertionGuard guard(rewriter);
684  auto maskableOp =
685  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
686  Operation *rootOp;
687  Value mask;
688  if (maskableOp.isMasked()) {
689  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
690  rootOp = maskableOp.getMaskingOp();
691  mask = maskableOp.getMaskingOp().getMask();
692  } else {
693  rootOp = reductionOp;
694  }
695 
696  auto vectorType = reductionOp.getSourceVectorType();
697  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
698  return failure();
699 
700  Location loc = reductionOp.getLoc();
701  Value result;
702  if (vectorType.getRank() == 0) {
703  if (mask)
704  mask = rewriter.create<ExtractElementOp>(loc, mask);
705  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
706  } else {
707  if (mask)
708  mask = rewriter.create<ExtractOp>(loc, mask, 0);
709  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
710  }
711 
712  if (Value acc = reductionOp.getAcc())
713  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
714  result, acc,
715  reductionOp.getFastmathAttr(), mask);
716 
717  rewriter.replaceOp(rootOp, result);
718  return success();
719  }
720 };
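// A hedged example of the rewrite above (illustrative IR): a reduction over a
// single-element vector
//
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
//
// becomes an extraction of element 0 of %v combined with %acc via arith.addf,
// since no actual reduction is needed.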
721 } // namespace
722 
723 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
724  MLIRContext *context) {
725  results.add<ElideSingleElementReduction>(context);
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // ContractionOp
730 //===----------------------------------------------------------------------===//
731 
732 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
733  Value lhs, Value rhs, Value acc,
734  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
735  ArrayRef<IteratorType> iteratorTypes) {
736  result.addOperands({lhs, rhs, acc});
737  result.addTypes(acc.getType());
738  result.addAttribute(
739  getIndexingMapsAttrName(result.name),
740  builder.getAffineMapArrayAttr(
741  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
742  result.addAttribute(
743  getIteratorTypesAttrName(result.name),
744  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
745  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
746  return IteratorTypeAttr::get(builder.getContext(), t);
747  }))));
748 }
749 
750 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
751  Value lhs, Value rhs, Value acc,
752  ArrayAttr indexingMaps,
753  ArrayAttr iteratorTypes) {
754  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
755  ContractionOp::getDefaultKind());
756 }
757 
758 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
759  Value lhs, Value rhs, Value acc,
760  ArrayAttr indexingMaps,
761  ArrayAttr iteratorTypes, CombiningKind kind) {
762  result.addOperands({lhs, rhs, acc});
763  result.addTypes(acc.getType());
764  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
765  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
766  result.addAttribute(getKindAttrName(result.name),
767  CombiningKindAttr::get(builder.getContext(), kind));
768 }
769 
770 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
771  OpAsmParser::UnresolvedOperand lhsInfo;
772  OpAsmParser::UnresolvedOperand rhsInfo;
773  OpAsmParser::UnresolvedOperand accInfo;
774  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
775  SmallVector<Type, 2> types;
776  Type resultType;
777  auto loc = parser.getCurrentLocation();
778  DictionaryAttr dictAttr;
779  // TODO: Unify linalg op attribute parsing.
780  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
781  parser.parseComma() || parser.parseOperand(rhsInfo) ||
782  parser.parseComma() || parser.parseOperand(accInfo) ||
783  parser.parseTrailingOperandList(masksInfo) ||
784  parser.parseOptionalAttrDict(result.attributes) ||
785  parser.parseColonTypeList(types) ||
786  parser.parseKeywordType("into", resultType) ||
787  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
788  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
789  parser.resolveOperand(accInfo, resultType, result.operands) ||
790  parser.addTypeToList(resultType, result.types))
791  return failure();
792  result.attributes.append(dictAttr.getValue().begin(),
793  dictAttr.getValue().end());
794 
795  // Convert the array of strings into an array of IteratorType enums. This is
796  // needed because tests still use the old format in which the 'iterator_types'
797  // attribute is represented as an array of strings.
798  // TODO: Remove this conversion once tests are fixed.
799  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
800  result.attributes.get(getIteratorTypesAttrName(result.name)));
801 
802  SmallVector<Attribute> iteratorTypeAttrs;
803 
804  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
805  auto maybeIteratorType = symbolizeIteratorType(s);
806  if (!maybeIteratorType.has_value())
807  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
808 
809  iteratorTypeAttrs.push_back(
810  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
811  }
812  result.attributes.set(getIteratorTypesAttrName(result.name),
813  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
814 
815  if (!result.attributes.get(getKindAttrName(result.name))) {
816  result.addAttribute(
817  getKindAttrName(result.name),
818  CombiningKindAttr::get(parser.getContext(),
819  ContractionOp::getDefaultKind()));
820  }
821  if (masksInfo.empty())
822  return success();
823  if (masksInfo.size() != 2)
824  return parser.emitError(parser.getNameLoc(),
825  "expected zero or exactly 2 vector mask operands");
826  auto lhsType = llvm::cast<VectorType>(types[0]);
827  auto rhsType = llvm::cast<VectorType>(types[1]);
828  auto maskElementType = parser.getBuilder().getI1Type();
829  std::array<VectorType, 2> maskTypes = {
830  VectorType::Builder(lhsType).setElementType(maskElementType),
831  VectorType::Builder(rhsType).setElementType(maskElementType)};
832  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
833  return failure();
834  return success();
835 }
836 
837 void ContractionOp::print(OpAsmPrinter &p) {
838  // TODO: Unify printing code with linalg ops.
839  auto attrNames = getTraitAttrNames();
840  llvm::StringSet<> traitAttrsSet;
841  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
842  SmallVector<NamedAttribute> attrs;
843  for (auto attr : (*this)->getAttrs()) {
844  if (attr.getName() == getIteratorTypesAttrName()) {
845  auto iteratorTypes =
846  llvm::cast<ArrayAttr>(attr.getValue())
847  .getAsValueRange<IteratorTypeAttr, IteratorType>();
848  // Convert IteratorType enums into their string representation. This is
849  // needed because tests still use the old format in which the
850  // 'iterator_types' attribute is represented as an array of strings.
851  // TODO: Remove this conversion once tests are fixed.
852  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
853  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
854  return StringAttr::get(getContext(), stringifyIteratorType(t));
855  }));
856 
857  attrs.emplace_back(getIteratorTypesAttrName(),
858  ArrayAttr::get(getContext(), iteratorTypeNames));
859  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
860  attrs.push_back(attr);
861  }
862 
863  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
864  p << " " << dictAttr << " " << getLhs() << ", ";
865  p << getRhs() << ", " << getAcc();
866 
867  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
868  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
869  << getResultType();
870 }
871 
872 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
873  const std::vector<std::pair<int64_t, int64_t>> &map) {
874  for (auto &dimPair : map) {
875  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
876  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
877  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
878  return false;
879  }
880  return true;
881 }
882 
883 static LogicalResult verifyOutputShape(
884  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
885  Type resType,
886  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
887  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
888  DenseSet<int64_t> lhsContractingDimSet;
889  DenseSet<int64_t> rhsContractingDimSet;
890  for (auto &dimPair : contractingDimMap) {
891  lhsContractingDimSet.insert(dimPair.first);
892  rhsContractingDimSet.insert(dimPair.second);
893  }
894  DenseSet<int64_t> rhsBatchDimSet;
895  for (auto &dimPair : batchDimMap)
896  rhsBatchDimSet.insert(dimPair.second);
897 
898  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
899  SmallVector<int64_t, 4> expectedResultDims;
900  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
901  if (lhsContractingDimSet.count(i) > 0)
902  continue;
903  expectedResultDims.push_back(lhsType.getDimSize(i));
904  }
905 
906  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
907  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
908  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
909  continue;
910  expectedResultDims.push_back(rhsType.getDimSize(i));
911  }
912 
913  // Verify 'expectedResultDims'.
914  if (expectedResultDims.empty()) {
915  // No batch or free dimension implies a scalar result.
916  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
917  return op.emitOpError("invalid accumulator/result vector shape");
918  } else {
919  // At least one batch or free dimension implies a vector result.
920  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
921  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
922  if (!resVectorType || !accVectorType)
923  return op.emitOpError("invalid accumulator/result vector shape");
924 
925  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
926  // types fully define the result vector type. This assumes the affine maps
927  // are well-formed, which must have been verified already.
928  MLIRContext *ctx = op.getContext();
929  AffineMap lhsMap = op.getIndexingMapsArray()[0];
930  AffineMap rhsMap = op.getIndexingMapsArray()[1];
931  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
932  return op.emitOpError(
933  "expected all dimensions to be either a LHS or a RHS dimension");
934  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
935  for (auto pair :
936  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
937  VectorType v = pair.first;
938  auto map = pair.second;
939  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
940  unsigned pos = map.getDimPosition(idx);
941  if (!extents[pos])
942  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
943  }
944  }
945  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
946  return op.emitOpError("expected all dimensions to get an extent as "
947  "either a LHS or a RHS dimension");
948 
949  AffineMap resMap = op.getIndexingMapsArray()[2];
950  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
951  /*symbolCount=*/0, extents, ctx);
952  // Compose the resMap with the extentsMap, which is a constant map.
953  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
954  assert(llvm::all_of(expectedMap.getResults(),
955  llvm::IsaPred<AffineConstantExpr>) &&
956  "expected constant extent along all dimensions.");
957  // Extract the expected shape and build the type.
958  auto expectedShape = llvm::to_vector<4>(
959  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
960  return cast<AffineConstantExpr>(e).getValue();
961  }));
962  auto expected =
963  VectorType::get(expectedShape, resVectorType.getElementType(),
964  resVectorType.getScalableDims());
965  if (resVectorType != expected || accVectorType != expected)
966  return op.emitOpError(
967  "invalid accumulator/result vector shape, expected: ")
968  << expected;
969  }
970  return success();
971 }
972 
973 LogicalResult ContractionOp::verify() {
974  VectorType lhsType = getLhsType();
975  VectorType rhsType = getRhsType();
976  Type accType = getAccType();
977  Type resType = getResultType();
978 
979  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
980  if (!lhsType.getElementType().isSignlessInteger())
981  return emitOpError("only supports signless integer types");
982  }
983 
984  // Verify that an indexing map was specified for each vector operand.
985  if (getIndexingMapsArray().size() != 3)
986  return emitOpError("expected an indexing map for each vector operand");
987 
988  // Verify that each index map has 'numIterators' inputs, no symbols, and
989  // that the number of map outputs equals the rank of its associated
990  // vector operand.
991  unsigned numIterators = getIteratorTypes().getValue().size();
992  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
993  auto index = it.index();
994  auto map = it.value();
995  if (map.getNumSymbols() != 0)
996  return emitOpError("expected indexing map ")
997  << index << " to have no symbols";
998  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
999  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1000  // Verify that the map has the right number of inputs, outputs, and indices.
1001  // This also correctly accounts for (..) -> () for rank-0 results.
1002  if (map.getNumDims() != numIterators)
1003  return emitOpError("expected indexing map ")
1004  << index << " to have " << numIterators << " number of inputs";
1005  if (map.getNumResults() != rank)
1006  return emitOpError("expected indexing map ")
1007  << index << " to have " << rank << " number of outputs";
1008  if (!map.isProjectedPermutation())
1009  return emitOpError("expected indexing map ")
1010  << index << " to be a projected permutation of its inputs";
1011  }
1012 
1013  auto contractingDimMap = getContractingDimMap();
1014  auto batchDimMap = getBatchDimMap();
1015 
1016  // Verify at least one contracting dimension pair was specified.
1017  if (contractingDimMap.empty())
1018  return emitOpError("expected at least one contracting dimension pair");
1019 
1020  // Verify contracting dimension map was properly constructed.
1021  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1022  return emitOpError("invalid contracting dimension map");
1023 
1024  // Verify batch dimension map was properly constructed.
1025  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1026  return emitOpError("invalid batch dimension map");
1027 
1028  // Verify 'accType' and 'resType' shape.
1029  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1030  contractingDimMap, batchDimMap)))
1031  return failure();
1032 
1033  // Verify supported combining kind.
1034  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1035  auto elementType = vectorType ? vectorType.getElementType() : resType;
1036  if (!isSupportedCombiningKind(getKind(), elementType))
1037  return emitOpError("unsupported contraction type");
1038 
1039  return success();
1040 }
1041 
1042 // MaskableOpInterface methods.
1043 
1044 /// Returns the mask type expected by this operation. Mostly used for
1045 /// verification purposes. It requires the operation to be vectorized.
1046 Type ContractionOp::getExpectedMaskType() {
1047  auto indexingMaps = this->getIndexingMapsArray();
1048  AffineMap lhsIdxMap = indexingMaps[0];
1049  AffineMap rhsIdxMap = indexingMaps[1];
1050  VectorType lhsType = this->getLhsType();
1051  VectorType rhsType = this->getRhsType();
1052 
1053  unsigned numVecDims = lhsIdxMap.getNumDims();
1054  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1055  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1056 
1057  // Using the information in the indexing maps, extract the size of each
1058  // dimension in the vector.contract operation from the two input operands.
1059  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1060  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1061  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1062  lhsType.getScalableDims()[dimIdx];
1063  }
1064  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1065  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1066  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1067  rhsType.getScalableDims()[dimIdx];
1068  }
1069 
1070  assert(!ShapedType::isDynamicShape(maskShape) &&
1071  "Mask shape couldn't be computed");
1072 
1073  return VectorType::get(maskShape,
1074  IntegerType::get(lhsType.getContext(), /*width=*/1),
1075  maskShapeScalableDims);
1076 }
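// A worked example of the computation above (assumed matmul-like indexing
// maps, not from the original source): for a contraction with iterators
// (m, n, k), lhs map (m, k), rhs map (k, n), lhs type vector<4x8xf32> and rhs
// type vector<8x16xf32>, the expected mask type is vector<4x16x8xi1>, i.e. one
// i1 per point of the (m, n, k) iteration space.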
1077 
1078 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1079  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1080  getIteratorTypesAttrName(), getKindAttrName()};
1081 }
1082 
1083 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1084  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1085  if (targetExpr == map.getResult(i))
1086  return i;
1087  return -1;
1088 }
1089 
1090 static std::vector<std::pair<int64_t, int64_t>>
1091 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1092  IteratorType targetIteratorType, MLIRContext *context) {
1093  std::vector<std::pair<int64_t, int64_t>> dimMap;
1094  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1095  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1096  if (iteratorType != targetIteratorType)
1097  continue;
1098  // Search lhs/rhs map results for 'targetExpr'.
1099  auto targetExpr = getAffineDimExpr(it.index(), context);
1100  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1101  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1102  if (lhsDim >= 0 && rhsDim >= 0)
1103  dimMap.emplace_back(lhsDim, rhsDim);
1104  }
1105  return dimMap;
1106 }
1107 
1108 void ContractionOp::getIterationBounds(
1109  SmallVectorImpl<int64_t> &iterationBounds) {
1110  auto lhsShape = getLhsType().getShape();
1111  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1112  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1113  SmallVector<int64_t, 2> iterationShape;
1114  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1115  // Search lhs/rhs map results for 'targetExpr'.
1116  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1117  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1118  if (iteratorType == IteratorType::reduction) {
1119  // Get reduction dim size from lhs shape (same size in rhsShape).
1120  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1121  assert(lhsDimIndex >= 0);
1122  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1123  continue;
1124  }
1125  // Get parallel dimension size from result shape.
1126  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1127  assert(resDimIndex >= 0);
1128  assert(resVectorType != nullptr);
1129  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1130  }
1131 }
1132 
1133 void ContractionOp::getIterationIndexMap(
1134  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1135  unsigned numMaps = getIndexingMapsArray().size();
1136  iterationIndexMap.resize(numMaps);
1137  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1138  auto index = it.index();
1139  auto map = it.value();
1140  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1141  auto dim = cast<AffineDimExpr>(map.getResult(i));
1142  iterationIndexMap[index][dim.getPosition()] = i;
1143  }
1144  }
1145 }
1146 
1147 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1148  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1149  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1150  getContext());
1151 }
1152 
1153 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1154  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1155  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1156  getContext());
1157 }
1158 
1159 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1160  SmallVector<int64_t, 4> shape;
1161  getIterationBounds(shape);
1162  return shape;
1163 }
1164 
1165 /// Return a fused vector::ContractionOp which represents a pattern such as:
1166 ///
1167 /// ```mlir
1168 /// %c0 = vector.constant 0: ...
1169 /// %c = vector.contract %a, %b, %c0: ...
1170 /// %e = add %c, %d: ...
1171 /// ```
1172 ///
1173 /// by:
1174 ///
1175 /// ```mlir
1176 /// %e = vector.contract %a, %b, %d: ...
1177 /// ```
1178 ///
1179 /// Return null if the canonicalization does not apply.
1180 // TODO: This should be a folding of Add into Contract in core but while they
1181 // live in different dialects, it is not possible without unnatural
1182 // dependencies.
1183 template <typename AddOpType>
1184 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1185  using OpRewritePattern<AddOpType>::OpRewritePattern;
1186 
1187  LogicalResult matchAndRewrite(AddOpType addOp,
1188  PatternRewriter &rewriter) const override {
1189  auto canonicalize = [&](Value maybeContraction,
1190  Value otherOperand) -> vector::ContractionOp {
1191  vector::ContractionOp contractionOp =
1192  dyn_cast_or_null<vector::ContractionOp>(
1193  maybeContraction.getDefiningOp());
1194  if (!contractionOp)
1195  return vector::ContractionOp();
1196  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1197  contractionOp.getAcc().getDefiningOp())) {
1198  if (maybeZero.getValue() ==
1199  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1200  IRMapping bvm;
1201  bvm.map(contractionOp.getAcc(), otherOperand);
1202  auto newContraction =
1203  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1204  rewriter.replaceOp(addOp, newContraction.getResult());
1205  return newContraction;
1206  }
1207  }
1208  return vector::ContractionOp();
1209  };
1210 
1211  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1212  vector::ContractionOp contract = canonicalize(a, b);
1213  contract = contract ? contract : canonicalize(b, a);
1214  return contract ? success() : failure();
1215  }
1216 };
1217 
1218 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1219  MLIRContext *context) {
1220  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1221  CanonicalizeContractAdd<arith::AddFOp>>(context);
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // ExtractElementOp
1226 //===----------------------------------------------------------------------===//
1227 
1228 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1229  SetIntRangeFn setResultRanges) {
1230  setResultRanges(getResult(), argRanges.front());
1231 }
1232 
1233 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1234  Value source) {
1235  result.addOperands({source});
1236  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1237 }
1238 
1239 LogicalResult vector::ExtractElementOp::verify() {
1240  VectorType vectorType = getSourceVectorType();
1241  if (vectorType.getRank() == 0) {
1242  if (getPosition())
1243  return emitOpError("expected position to be empty with 0-D vector");
1244  return success();
1245  }
1246  if (vectorType.getRank() != 1)
1247  return emitOpError("unexpected >1 vector rank");
1248  if (!getPosition())
1249  return emitOpError("expected position for 1-D vector");
1250  return success();
1251 }
1252 
1253 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1254  // Skip the 0-D vector case for now.
1255  if (!adaptor.getPosition())
1256  return {};
1257 
1258  // Fold extractelement (splat X) -> X.
1259  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1260  return splat.getInput();
1261 
1262  // Fold extractelement(broadcast(X)) -> X.
1263  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1264  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1265  return broadcast.getSource();
1266 
1267  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1268  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1269  if (!pos || !src)
1270  return {};
1271 
1272  auto srcElements = src.getValues<Attribute>();
1273 
1274  uint64_t posIdx = pos.getInt();
1275  if (posIdx >= srcElements.size())
1276  return {};
1277 
1278  return srcElements[posIdx];
1279 }
1280 
1281 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1282 // `poisonValue`.
1283 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1284  int64_t maxIndex) {
1285  return index == poisonValue || (index >= 0 && index < maxIndex);
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // ExtractOp
1290 //===----------------------------------------------------------------------===//
1291 
1292 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1293  SetIntRangeFn setResultRanges) {
1294  setResultRanges(getResult(), argRanges.front());
1295 }
1296 
1297 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1298  Value source, int64_t position) {
1299  build(builder, result, source, ArrayRef<int64_t>{position});
1300 }
1301 
1302 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1303  Value source, OpFoldResult position) {
1304  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1305 }
1306 
1307 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1308  Value source, ArrayRef<int64_t> position) {
1309  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1310  builder.getDenseI64ArrayAttr(position));
1311 }
1312 
1313 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1314  Value source, ArrayRef<OpFoldResult> position) {
1315  SmallVector<int64_t> staticPos;
1316  SmallVector<Value> dynamicPos;
1317  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1318  build(builder, result, source, dynamicPos,
1319  builder.getDenseI64ArrayAttr(staticPos));
1320 }
1321 
1322 LogicalResult
1323 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1324  ExtractOp::Adaptor adaptor,
1325  SmallVectorImpl<Type> &inferredReturnTypes) {
1326  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1327  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1328  vectorType.getRank()) {
1329  inferredReturnTypes.push_back(vectorType.getElementType());
1330  } else {
1331  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1332  vectorType.getRank());
1333  inferredReturnTypes.push_back(VectorType::get(
1334  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1335  vectorType.getScalableDims().drop_front(n)));
1336  }
1337  return success();
1338 }
1339 
1340 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1341  // Allow extracting 1-element vectors instead of scalars.
1342  auto isCompatible = [](TypeRange l, TypeRange r) {
1343  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1344  return vectorType && vectorType.getShape().equals({1}) &&
1345  vectorType.getElementType() == r.front();
1346  };
1347  if (l.size() == 1 && r.size() == 1 &&
1348  (isCompatible(l, r) || isCompatible(r, l)))
1349  return true;
1350  return l == r;
1351 }
1352 
1353 LogicalResult vector::ExtractOp::verify() {
1354  // Note: This check must come before getMixedPosition() to prevent a crash.
1355  auto dynamicMarkersCount =
1356  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1357  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1358  return emitOpError(
1359  "mismatch between dynamic and static positions (kDynamic marker but no "
1360  "corresponding dynamic position) -- this can only happen due to an "
1361  "incorrect fold/rewrite");
1362  auto position = getMixedPosition();
1363  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1364  return emitOpError(
1365  "expected position attribute of rank no greater than vector rank");
1366  for (auto [idx, pos] : llvm::enumerate(position)) {
1367  if (auto attr = dyn_cast<Attribute>(pos)) {
1368  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1369  if (!isValidPositiveIndexOrPoison(
1370  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1371  return emitOpError("expected position attribute #")
1372  << (idx + 1)
1373  << " to be a non-negative integer smaller than the "
1374  "corresponding vector dimension or poison (-1)";
1375  }
1376  }
1377  }
1378  return success();
1379 }
1380 
1381 template <typename IntType>
1382 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1383  return llvm::to_vector<4>(llvm::map_range(
1384  arrayAttr.getAsRange<IntegerAttr>(),
1385  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1386 }
1387 
1388 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1389 /// positions.
1390 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1391  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1392  return failure();
1393 
1394  // TODO: Canonicalization for dynamic position not implemented yet.
1395  if (extractOp.hasDynamicPosition())
1396  return failure();
1397 
1398  SmallVector<int64_t> globalPosition;
1399  ExtractOp currentOp = extractOp;
1400  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1401  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1402  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1403  currentOp = nextOp;
1404  // TODO: Canonicalization for dynamic position not implemented yet.
1405  if (currentOp.hasDynamicPosition())
1406  return failure();
1407  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1408  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1409  }
1410  extractOp.setOperand(0, currentOp.getVector());
1411  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1412  OpBuilder b(extractOp.getContext());
1413  std::reverse(globalPosition.begin(), globalPosition.end());
1414  extractOp.setStaticPosition(globalPosition);
1415  return success();
1416 }
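// Illustrative example of the chain folding above (hypothetical values):
//
//   %a = vector.extract %v[0, 2] : vector<5xf32> from vector<3x4x5xf32>
//   %b = vector.extract %a[1] : f32 from vector<5xf32>
//
// folds in place to the equivalent of
//
//   %b = vector.extract %v[0, 2, 1] : f32 from vector<3x4x5xf32>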
1417 
1418 namespace {
1419 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1420 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1421 /// Compose TransposeOp permutations as we walk back.
1422 /// This helper class keeps an updated extraction position `extractPosition`
1423 /// with extra trailing sentinels.
1424 /// The sentinels encode the internal transposition status of the result vector.
1425 /// As we iterate, extractPosition is permuted and updated.
1426 class ExtractFromInsertTransposeChainState {
1427 public:
1428  ExtractFromInsertTransposeChainState(ExtractOp e);
1429 
1430  /// Iterate over producing insert and transpose ops until we find a fold.
1431  Value fold();
1432 
1433 private:
1434  /// Return true if the vector at position `a` is contained within the vector
1435  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1436  /// is a prefix of `b`.
1437  template <typename ContainerA, typename ContainerB>
1438  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1439  return a.size() <= b.size() &&
1440  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1441  }
1442 
1443  /// Return true if the vector at position `a` intersects the vector at
1444  /// position `b`. Under insert/extract semantics, this is the same as equality
1445  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1446  /// Comparison is on the common prefix (i.e. zip).
1447  template <typename ContainerA, typename ContainerB>
1448  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1449  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1450  if (elemA < 0 || elemB < 0)
1451  continue;
1452  if (elemA != elemB)
1453  return false;
1454  }
1455  return true;
1456  }
1457 
1458  /// Folding is only possible in the absence of an internal permutation in the
1459  /// result vector.
1460  bool canFold() {
1461  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1462  }
1463 
1464  // Helper to get the next defining op of interest.
1465  void updateStateForNextIteration(Value v) {
1466  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1467  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1468  };
1469 
1470  // Case 1. If we hit a transpose, just compose the map and iterate.
1471  // Invariant: insert + transpose do not change rank, we can always compose.
1472  LogicalResult handleTransposeOp();
1473 
1474  // Case 2: the insert position matches extractPosition exactly, early return.
1475  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1476 
1477  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1478  /// portion of the source of the insert.
1479  /// Example:
1480  /// ```
1481  /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5>
1482  /// // extractPosition == [1, 2, 3]
1483  /// %ext = vector.extract %ins[1, 2, 3]: vector<5> from vector<2x3x4x5>
1484  /// // can fold to
1485  /// %ext = vector.extract %source[2, 3]: vector<5> from vector<3x4x5>
1486  /// ```
1487  /// To traverse through %source, we need to set the leading dims to 0 and
1488  /// drop the extra leading dims.
1489  /// This method updates the internal state.
1490  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1491 
1492  /// Try to fold in place to extract(source, extractPosition) and return the
1493  /// folded result. Return null if folding is not possible (e.g. due to an
1494  /// internal transposition in the result).
1495  Value tryToFoldExtractOpInPlace(Value source);
1496 
1497  ExtractOp extractOp;
1498  int64_t vectorRank;
1499  int64_t extractedRank;
1500 
1501  InsertOp nextInsertOp;
1502  TransposeOp nextTransposeOp;
1503 
1504  /// Sentinel values that encode the internal permutation status of the result.
1505  /// They are set to (-1, ... , -k) at the beginning and appended to
1506  /// `extractPosition`.
1507  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1508  /// ensure that there is no internal transposition.
1509  /// Internal transposition cannot be accounted for with a folding pattern.
1510  // TODO: We could relax the internal transposition with an extra transposition
1511  // operation in a future canonicalizer.
1512  SmallVector<int64_t> sentinels;
1513  SmallVector<int64_t> extractPosition;
1514 };
1515 } // namespace
1516 
1517 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1518  ExtractOp e)
1519  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1520  extractedRank(extractOp.getNumIndices()) {
1521  assert(vectorRank >= extractedRank && "Extracted position overflow");
1522  sentinels.reserve(vectorRank - extractedRank);
1523  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1524  sentinels.push_back(-(i + 1));
1525  extractPosition.assign(extractOp.getStaticPosition().begin(),
1526  extractOp.getStaticPosition().end());
1527  llvm::append_range(extractPosition, sentinels);
1528 }
1529 
1530 // Case 1. If we hit a transpose, just compose the map and iterate.
1531 // Invariant: insert + transpose do not change rank, we can always compose.
1532 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1533  // TODO: Canonicalization for dynamic position not implemented yet.
1534  if (extractOp.hasDynamicPosition())
1535  return failure();
1536 
1537  if (!nextTransposeOp)
1538  return failure();
1539  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1540  nextTransposeOp.getPermutation(), extractOp.getContext()));
1541  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1542  return success();
1543 }
1544 
1545 // Case 2: the insert position matches extractPosition exactly, early return.
1546 LogicalResult
1547 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1548  Value &res) {
1549  // TODO: Canonicalization for dynamic position not implemented yet.
1550  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1551  return failure();
1552 
1553  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1554  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1555  return failure();
1556  // Case 2.a. early-exit fold.
1557  res = nextInsertOp.getSource();
1558  // Case 2.b. if internal transposition is present, canFold will be false.
1559  return success(canFold());
1560 }
1561 
1562 /// Case 3: if inserted position is a prefix of extractPosition,
1563 /// extract a portion of the source of the insertion.
1564 /// This method updates the internal state.
1565 LogicalResult
1566 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1567  // TODO: Canonicalization for dynamic position not implemented yet.
1568  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1569  return failure();
1570 
1571  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1572  if (!isContainedWithin(insertedPos, extractPosition))
1573  return failure();
1574  // Set leading dims to zero.
1575  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1576  // Drop extra leading dims.
1577  extractPosition.erase(extractPosition.begin(),
1578  extractPosition.begin() + insertedPos.size());
1579  extractedRank = extractPosition.size() - sentinels.size();
1580  // Case 3.a. early-exit fold (break and delegate to post-while path).
1581  res = nextInsertOp.getSource();
1582  // Case 3.b. if internal transposition is present, canFold will be false.
1583  return success();
1584 }
1585 
1586 /// Try to fold in place to extract(source, extractPosition) and return the
1587 /// folded result. Return null if folding is not possible (e.g. due to an
1588 /// internal transposition in the result).
1589 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1590  Value source) {
1591  // TODO: Canonicalization for dynamic position not implemented yet.
1592  if (extractOp.hasDynamicPosition())
1593  return Value();
1594 
1595  // If we can't fold (either internal transposition, or nothing to fold), bail.
1596  bool nothingToFold = (source == extractOp.getVector());
1597  if (nothingToFold || !canFold())
1598  return Value();
1599 
1600  // Otherwise, fold by updating the op in place and returning its result.
1601  OpBuilder b(extractOp.getContext());
1602  extractOp.setStaticPosition(
1603  ArrayRef(extractPosition).take_front(extractedRank));
1604  extractOp.getVectorMutable().assign(source);
1605  return extractOp.getResult();
1606 }
1607 
1608 /// Iterate over producing insert and transpose ops until we find a fold.
1609 Value ExtractFromInsertTransposeChainState::fold() {
1610  // TODO: Canonicalization for dynamic position not implemented yet.
1611  if (extractOp.hasDynamicPosition())
1612  return Value();
1613 
1614  Value valueToExtractFrom = extractOp.getVector();
1615  updateStateForNextIteration(valueToExtractFrom);
1616  while (nextInsertOp || nextTransposeOp) {
1617  // Case 1. If we hit a transpose, just compose the map and iterate.
1618  // Invariant: insert + transpose do not change rank, we can always compose.
1619  if (succeeded(handleTransposeOp())) {
1620  valueToExtractFrom = nextTransposeOp.getVector();
1621  updateStateForNextIteration(valueToExtractFrom);
1622  continue;
1623  }
1624 
1625  Value result;
1626  // Case 2: the insert position matches extractPosition exactly.
1627  if (succeeded(handleInsertOpWithMatchingPos(result)))
1628  return result;
1629 
1630  // Case 3: if the inserted position is a prefix of extractPosition, we can
1631  // just extract a portion of the source of the insert.
1632  if (succeeded(handleInsertOpWithPrefixPos(result)))
1633  return tryToFoldExtractOpInPlace(result);
1634 
1635  // Case 4: extractPosition intersects insertedPos on non-sentinel values.
1636  // This is a more difficult case and we bail.
1637  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1638  if (isContainedWithin(extractPosition, insertedPos) ||
1639  intersectsWhereNonNegative(extractPosition, insertedPos))
1640  return Value();
1641 
1642  // Case 5: No intersection, we forward the extract to insertOp.dest().
1643  valueToExtractFrom = nextInsertOp.getDest();
1644  updateStateForNextIteration(valueToExtractFrom);
1645  }
1646  // If after all this we can fold, go for it.
1647  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1648 }
1649 
1650 /// Returns true if the operation has a 0-D vector type operand or result.
1651 static bool hasZeroDimVectors(Operation *op) {
1652  auto hasZeroDimVectorType = [](Type type) -> bool {
1653  auto vecType = dyn_cast<VectorType>(type);
1654  return vecType && vecType.getRank() == 0;
1655  };
1656 
1657  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1658  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1659 }
1660 
1661 /// Fold an extractOp whose source comes from a BroadcastOp or SplatOp.
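/// A minimal illustrative case (hypothetical SSA values), assuming a scalar
/// broadcast feeding a scalar extract:
///
///   %b = vector.broadcast %a : f32 to vector<4xf32>
///   %e = vector.extract %b[1] : f32 from vector<4xf32>
///   ==> folds to %a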
1662 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1663  // TODO: Canonicalization for dynamic position not implemented yet.
1664  if (extractOp.hasDynamicPosition())
1665  return Value();
1666 
1667  Operation *defOp = extractOp.getVector().getDefiningOp();
1668  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1669  return Value();
1670 
1671  Value source = defOp->getOperand(0);
1672  if (extractOp.getType() == source.getType())
1673  return source;
1674  auto getRank = [](Type type) {
1675  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1676  : 0;
1677  };
1678 
1679  // If splat or broadcast from a scalar, just return the source scalar.
1680  unsigned broadcastSrcRank = getRank(source.getType());
1681  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1682  return source;
1683 
1684  unsigned extractResultRank = getRank(extractOp.getType());
1685  if (extractResultRank >= broadcastSrcRank)
1686  return Value();
1687  // Check that the dimensions of the result haven't been broadcast.
1688  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1689  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1690  if (extractVecType && broadcastVecType &&
1691  extractVecType.getShape() !=
1692  broadcastVecType.getShape().take_back(extractResultRank))
1693  return Value();
1694 
1695  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1696  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1697 
1698  // Detect all the positions that come from "dim-1" broadcasting.
1699  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1700  // extract position to `0` when extracting from the source operand.
1701  llvm::SetVector<int64_t> broadcastedUnitDims =
1702  broadcastOp.computeBroadcastedUnitDims();
1703  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1704  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1705  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1706  if (broadcastedUnitDims.contains(i))
1707  extractPos[i] = 0;
1708  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1709  // matching extract position when extracting from the source operand.
1710  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1711  extractPos.erase(extractPos.begin(),
1712  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1713  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1714  OpBuilder b(extractOp.getContext());
1715  extractOp.setOperand(0, source);
1716  extractOp.setStaticPosition(extractPos);
1717  return extractOp.getResult();
1718 }
1719 
1720 /// Fold extractOp coming from ShuffleOp.
1721 ///
1722 /// Example:
1723 ///
1724 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1725 /// : vector<8xf32>, vector<8xf32>
1726 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1727 /// ->
1728 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1729 ///
1730 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1731  // Dynamic positions are not folded as the resulting code would be more
1732  // complex than the input code.
1733  if (extractOp.hasDynamicPosition())
1734  return Value();
1735 
1736  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1737  if (!shuffleOp)
1738  return Value();
1739 
1740  // TODO: 0-D or multi-dimensional vectors not supported yet.
1741  if (shuffleOp.getResultVectorType().getRank() != 1)
1742  return Value();
1743 
1744  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1745  auto shuffleMask = shuffleOp.getMask();
1746  int64_t extractIdx = extractOp.getStaticPosition()[0];
1747  int64_t shuffleIdx = shuffleMask[extractIdx];
1748 
1749  // Find the shuffled vector to extract from based on the shuffle index.
1750  if (shuffleIdx < inputVecSize) {
1751  extractOp.setOperand(0, shuffleOp.getV1());
1752  extractOp.setStaticPosition({shuffleIdx});
1753  } else {
1754  extractOp.setOperand(0, shuffleOp.getV2());
1755  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1756  }
1757 
1758  return extractOp.getResult();
1759 }
1760 
1761 // Fold extractOp with source coming from ShapeCast op.
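// A minimal illustrative case (hypothetical SSA values); the extract position
// is linearized against the result shape and delinearized against the source:
//
//   %cast = vector.shape_cast %v : vector<4x4xf32> to vector<16xf32>
//   %e = vector.extract %cast[6] : f32 from vector<16xf32>
//   ==> can fold to vector.extract %v[1, 2] : f32 from vector<4x4xf32>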
1762 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1763  // TODO: Canonicalization for dynamic position not implemented yet.
1764  if (extractOp.hasDynamicPosition())
1765  return Value();
1766 
1767  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1768  if (!shapeCastOp)
1769  return Value();
1770 
1771  // Get the nth dimension size starting from lowest dimension.
1772  auto getDimReverse = [](VectorType type, int64_t n) {
1773  return type.getShape().take_back(n + 1).front();
1774  };
1775  int64_t destinationRank =
1776  llvm::isa<VectorType>(extractOp.getType())
1777  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1778  : 0;
1779  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1780  return Value();
1781  if (destinationRank > 0) {
1782  auto destinationType =
1783  llvm::cast<VectorType>(extractOp.getResult().getType());
1784  for (int64_t i = 0; i < destinationRank; i++) {
1785  // The lowest dimension of the destination must match the lowest
1786  // dimension of the shapecast op source.
1787  // TODO: This case could be supported in a canonicalization pattern.
1788  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1789  getDimReverse(destinationType, i))
1790  return Value();
1791  }
1792  }
1793  // Extract the strides associated with the extract op vector source. Then use
1794  // this to calculate a linearized position for the extract.
1795  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1796  std::reverse(extractedPos.begin(), extractedPos.end());
1797  SmallVector<int64_t, 4> strides;
1798  int64_t stride = 1;
1799  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1800  strides.push_back(stride);
1801  stride *=
1802  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1803  }
1804 
1805  int64_t position = linearize(extractedPos, strides);
1806  // Then extract the strides associated with the shapeCast op vector source and
1807  // delinearize the position using those strides.
1808  SmallVector<int64_t, 4> newStrides;
1809  int64_t numDimension =
1810  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1811  stride = 1;
1812  for (int64_t i = 0; i < numDimension; i++) {
1813  newStrides.push_back(stride);
1814  stride *=
1815  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1816  }
1817  std::reverse(newStrides.begin(), newStrides.end());
1818  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1819  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1820  OpBuilder b(extractOp.getContext());
1821  extractOp.setStaticPosition(newPosition);
1822  extractOp.setOperand(0, shapeCastOp.getSource());
1823  return extractOp.getResult();
1824 }
1825 
1826 /// Fold an ExtractOp from ExtractStridedSliceOp.
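/// A minimal illustrative case (hypothetical SSA values), assuming unit
/// strides:
///
///   %s = vector.extract_strided_slice %v
///       {offsets = [2], sizes = [4], strides = [1]}
///       : vector<8x16xf32> to vector<4x16xf32>
///   %e = vector.extract %s[1] : vector<16xf32> from vector<4x16xf32>
///   ==> can fold to vector.extract %v[3] : vector<16xf32> from vector<8x16xf32>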
1827 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1828  // TODO: Canonicalization for dynamic position not implemented yet.
1829  if (extractOp.hasDynamicPosition())
1830  return Value();
1831 
1832  auto extractStridedSliceOp =
1833  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1834  if (!extractStridedSliceOp)
1835  return Value();
1836 
1837  // 0-D vectors not supported.
1838  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1839  if (hasZeroDimVectors(extractStridedSliceOp))
1840  return Value();
1841 
1842  // Return if 'extractStridedSliceOp' has non-unit strides.
1843  if (extractStridedSliceOp.hasNonUnitStrides())
1844  return Value();
1845 
1846  // Trim offsets for dimensions fully extracted.
1847  auto sliceOffsets =
1848  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1849  while (!sliceOffsets.empty()) {
1850  size_t lastOffset = sliceOffsets.size() - 1;
1851  if (sliceOffsets.back() != 0 ||
1852  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1853  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1854  break;
1855  sliceOffsets.pop_back();
1856  }
1857  unsigned destinationRank = 0;
1858  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1859  destinationRank = vecType.getRank();
1860  // The dimensions of the result need to be untouched by the
1861  // extractStridedSlice op.
1862  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1863  sliceOffsets.size())
1864  return Value();
1865 
1866  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1867  assert(extractedPos.size() >= sliceOffsets.size());
1868  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1869  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1870  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1871 
1872  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1873  OpBuilder b(extractOp.getContext());
1874  extractOp.setStaticPosition(extractedPos);
1875  return extractOp.getResult();
1876 }
1877 
1878 /// Fold extract_op fed from a chain of insertStridedSlice ops.
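/// A minimal illustrative case (hypothetical SSA values), where the extracted
/// chunk lies entirely inside the inserted slice:
///
///   %ins = vector.insert_strided_slice %src, %dst
///       {offsets = [2, 0], strides = [1, 1]}
///       : vector<2x16xf32> into vector<8x16xf32>
///   %e = vector.extract %ins[3] : vector<16xf32> from vector<8x16xf32>
///   ==> can fold to vector.extract %src[1] : vector<16xf32> from vector<2x16xf32>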
1879 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1880  // TODO: Canonicalization for dynamic position not implemented yet.
1881  if (extractOp.hasDynamicPosition())
1882  return Value();
1883 
1884  int64_t destinationRank =
1885  llvm::isa<VectorType>(extractOp.getType())
1886  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1887  : 0;
1888  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1889  if (!insertOp)
1890  return Value();
1891 
1892  // 0-D vectors not supported.
1893  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1894  if (hasZeroDimVectors(insertOp))
1895  return Value();
1896 
1897  while (insertOp) {
1898  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1899  insertOp.getSourceVectorType().getRank();
1900  if (destinationRank > insertOp.getSourceVectorType().getRank())
1901  return Value();
1902  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1903  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1904 
1905  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1906  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1907  }))
1908  return Value();
1909  bool disjoint = false;
1910  SmallVector<int64_t, 4> offsetDiffs;
1911  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1912  int64_t start = insertOffsets[dim];
1913  int64_t size =
1914  (dim < insertRankDiff)
1915  ? 1
1916  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1917  int64_t end = start + size;
1918  int64_t offset = extractOffsets[dim];
1919  // Check if the start of the extract offset is in the interval inserted.
1920  if (start <= offset && offset < end) {
1921  if (dim >= insertRankDiff)
1922  offsetDiffs.push_back(offset - start);
1923  continue;
1924  }
1925  disjoint = true;
1926  break;
1927  }
1928  // The extracted element chunk overlaps with the inserted vector.
1929  if (!disjoint) {
1930  // If any of the inner dimensions are only partially inserted we have a
1931  // partial overlap.
1932  int64_t srcRankDiff =
1933  insertOp.getSourceVectorType().getRank() - destinationRank;
1934  for (int64_t i = 0; i < destinationRank; i++) {
1935  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1936  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1937  insertRankDiff))
1938  return Value();
1939  }
1940  extractOp.getVectorMutable().assign(insertOp.getSource());
1941  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1942  OpBuilder b(extractOp.getContext());
1943  extractOp.setStaticPosition(offsetDiffs);
1944  return extractOp.getResult();
1945  }
1946  // If the chunk extracted is disjoint from the chunk inserted, keep
1947  // looking in the insert chain.
1948  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1949  }
1950  return Value();
1951 }
1952 
1953 /// Try to fold the extraction of a scalar from a vector defined by
1954 /// vector.from_elements. E.g.:
1955 ///
1956 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1957 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1958 /// ==> fold to %a
1959 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1960  // Dynamic extractions cannot be folded.
1961  if (extractOp.hasDynamicPosition())
1962  return {};
1963 
1964  // Look for extract(from_elements).
1965  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1966  if (!fromElementsOp)
1967  return {};
1968 
1969  // Scalable vectors are not supported.
1970  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1971  if (vecType.isScalable())
1972  return {};
1973 
1974  // Only extractions of scalars are supported.
1975  int64_t rank = vecType.getRank();
1976  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1977  if (extractOp.getType() != vecType.getElementType())
1978  return {};
1979  assert(static_cast<int64_t>(indices.size()) == rank &&
1980  "unexpected number of indices");
1981 
1982  // Compute flattened/linearized index and fold to operand.
1983  int flatIndex = 0;
1984  int stride = 1;
1985  for (int i = rank - 1; i >= 0; --i) {
1986  flatIndex += indices[i] * stride;
1987  stride *= vecType.getDimSize(i);
1988  }
1989  return fromElementsOp.getElements()[flatIndex];
1990 }
1991 
1992 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
1993 /// then fold them into the static position.
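/// A minimal illustrative case (hypothetical SSA values):
///
///   %c1 = arith.constant 1 : index
///   %e = vector.extract %v[%c1] : f32 from vector<4xf32>
///   ==> folds to vector.extract %v[1] : f32 from vector<4xf32>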
1994 template <typename OpType, typename AdaptorType>
1995 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
1996  SmallVectorImpl<Value> &operands) {
1997  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
1998  OperandRange dynamicPosition = op.getDynamicPosition();
1999  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2000 
2001  // If there are no dynamic position operands, there is nothing to fold.
2002  if (!dynamicPosition.size())
2003  return {};
2004 
2005  // `index` is used to iterate over the `dynamicPosition`.
2006  unsigned index = 0;
2007 
2008  // `opChange` tracks whether `op` should be updated in place.
2009  bool opChange = false;
2010  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2011  if (!ShapedType::isDynamic(staticPosition[i]))
2012  continue;
2013  Attribute positionAttr = dynamicPositionAttr[index];
2014  Value position = dynamicPosition[index++];
2015  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2016  staticPosition[i] = attr.getInt();
2017  opChange = true;
2018  continue;
2019  }
2020  operands.push_back(position);
2021  }
2022 
2023  if (opChange) {
2024  op.setStaticPosition(staticPosition);
2025  op.getOperation()->setOperands(operands);
2026  return op.getResult();
2027  }
2028  return {};
2029 }
2030 
2031 /// Fold an insert or extract operation into a poison value when a poison index
2032 /// is found at any dimension of the static position.
2033 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2034  ArrayRef<int64_t> staticPos,
2035  int64_t poisonVal) {
2036  if (!llvm::is_contained(staticPos, poisonVal))
2037  return {};
2038 
2039  return ub::PoisonAttr::get(context);
2040 }
2041 
2042 /// Fold a vector extract whose source is a poison value.
2043 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2044  if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2045  return srcAttr;
2046 
2047  return {};
2048 }
2049 
2050 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2051  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2052  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2053  // mismatch).
2054  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2055  return getVector();
2056  if (auto res = foldPoisonIndexInsertExtractOp(
2057  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2058  return res;
2059  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2060  return res;
2061  if (succeeded(foldExtractOpFromExtractChain(*this)))
2062  return getResult();
2063  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2064  return res;
2065  if (auto res = foldExtractFromBroadcast(*this))
2066  return res;
2067  if (auto res = foldExtractFromShuffle(*this))
2068  return res;
2069  if (auto res = foldExtractFromShapeCast(*this))
2070  return res;
2071  if (auto val = foldExtractFromExtractStrided(*this))
2072  return val;
2073  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2074  return val;
2075  if (auto val = foldScalarExtractFromFromElements(*this))
2076  return val;
2077  SmallVector<Value> operands = {getVector()};
2078  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2079  return val;
2080  return OpFoldResult();
2081 }
2082 
2083 namespace {
2084 
2085 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
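// A minimal illustrative case (hypothetical SSA values):
//
//   %b = vector.broadcast %a : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   ==> rewrites to vector.broadcast %a : vector<4xf32> to vector<3x4xf32>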
2086 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2087 public:
2088  using OpRewritePattern::OpRewritePattern;
2089 
2090  LogicalResult matchAndRewrite(ExtractOp extractOp,
2091  PatternRewriter &rewriter) const override {
2092  Operation *defOp = extractOp.getVector().getDefiningOp();
2093  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2094  return failure();
2095 
2096  Value source = defOp->getOperand(0);
2097  if (extractOp.getType() == source.getType())
2098  return failure();
2099  auto getRank = [](Type type) {
2100  return llvm::isa<VectorType>(type)
2101  ? llvm::cast<VectorType>(type).getRank()
2102  : 0;
2103  };
2104  unsigned broadcastSrcRank = getRank(source.getType());
2105  unsigned extractResultRank = getRank(extractOp.getType());
2106  // We only consider the case where the rank of the source is less than or
2107  // equal to the rank of the extract dst. The other cases are handled in the
2108  // folding patterns.
2109  if (extractResultRank < broadcastSrcRank)
2110  return failure();
2111 
2112  // Special case if broadcast src is a 0D vector.
2113  if (extractResultRank == 0) {
2114  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2115  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2116  return success();
2117  }
2118  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2119  extractOp, extractOp.getType(), source);
2120  return success();
2121  }
2122 };
2123 
2124 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
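// A minimal illustrative case (hypothetical SSA values):
//
//   %cst = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %cst[2] : vector<8xf32> from vector<4x8xf32>
//   ==> rewrites to arith.constant dense<1.0> : vector<8xf32>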
2125 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2126 public:
2127  using OpRewritePattern::OpRewritePattern;
2128 
2129  LogicalResult matchAndRewrite(ExtractOp extractOp,
2130  PatternRewriter &rewriter) const override {
2131  // Return if 'ExtractOp' operand is not defined by a splat vector
2132  // ConstantOp.
2133  Value sourceVector = extractOp.getVector();
2134  Attribute vectorCst;
2135  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2136  return failure();
2137  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2138  if (!splat)
2139  return failure();
2140  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2141  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2142  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2143  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2144  return success();
2145  }
2146 };
2147 
2148 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
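// A minimal illustrative case (hypothetical SSA values):
//
//   %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
//   %e = vector.extract %cst[2] : i32 from vector<4xi32>
//   ==> rewrites to arith.constant 3 : i32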
2149 class ExtractOpNonSplatConstantFolder final
2150  : public OpRewritePattern<ExtractOp> {
2151 public:
2152  using OpRewritePattern::OpRewritePattern;
2153 
2154  LogicalResult matchAndRewrite(ExtractOp extractOp,
2155  PatternRewriter &rewriter) const override {
2156  // TODO: Canonicalization for dynamic position not implemented yet.
2157  if (extractOp.hasDynamicPosition())
2158  return failure();
2159 
2160  // Return if 'ExtractOp' operand is not defined by a compatible vector
2161  // ConstantOp.
2162  Value sourceVector = extractOp.getVector();
2163  Attribute vectorCst;
2164  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2165  return failure();
2166 
2167  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2168  if (vecTy.isScalable())
2169  return failure();
2170 
2171  // The splat case is handled by `ExtractOpSplatConstantFolder`.
2172  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2173  if (!dense || dense.isSplat())
2174  return failure();
2175 
2176  // Calculate the linearized position of the contiguous chunk of elements to
2177  // extract.
2178  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2179  copy(extractOp.getStaticPosition(), completePositions.begin());
2180  int64_t elemBeginPosition =
2181  linearize(completePositions, computeStrides(vecTy.getShape()));
2182  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2183 
2184  TypedAttr newAttr;
2185  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2186  SmallVector<Attribute> elementValues(
2187  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2188  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2189  } else {
2190  newAttr = *denseValuesBegin;
2191  }
2192 
2193  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2194  return success();
2195  }
2196 };
2197 
2198 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
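// A minimal illustrative case (hypothetical SSA values), assuming %c2 is a
// constant index with value 2 and %sz is a non-constant index:
//
//   %m = vector.create_mask %c2, %sz : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
//   ==> rewrites to vector.create_mask %sz : vector<8xi1>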
2199 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2200 public:
2201  using OpRewritePattern::OpRewritePattern;
2202 
2203  LogicalResult matchAndRewrite(ExtractOp extractOp,
2204  PatternRewriter &rewriter) const override {
2205  auto createMaskOp =
2206  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2207  if (!createMaskOp)
2208  return failure();
2209 
2210  VectorType extractedMaskType =
2211  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2212 
2213  if (!extractedMaskType)
2214  return failure();
2215 
2216  auto maskOperands = createMaskOp.getOperands();
2217  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2218  VectorType maskType = createMaskOp.getVectorType();
2219 
2220  bool containsUnknownDims = false;
2221  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2222 
2223  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2224  dimIdx++) {
2225  int64_t pos = extractOpPos[dimIdx];
2226  Value operand = maskOperands[dimIdx];
2227  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2228  if (!constantOp) {
2229  // Bounds of this dim unknown.
2230  containsUnknownDims = true;
2231  continue;
2232  }
2233 
2234  int64_t createMaskBound =
2235  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2236 
2237  if (pos != ShapedType::kDynamic) {
2238  // If any position is outside the range from the `create_mask`, then the
2239  // extracted mask will be all-false.
2240  allFalse |= pos >= createMaskBound;
2241  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2242  // This dim is not all-true and since this is a dynamic index we don't
2243  // know if the extraction is within the true or false region.
2244  // Note: Zero dims have already been handled via getMaskFormat().
2245  containsUnknownDims = true;
2246  }
2247  }
2248 
2249  if (allFalse) {
2250  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2251  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2252  } else if (!containsUnknownDims) {
2253  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2254  extractOp, extractedMaskType,
2255  maskOperands.drop_front(extractOpPos.size()));
2256  } else {
2257  return failure();
2258  }
2259  return success();
2260  }
2261 };
2262 
2263 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2264 // does not change.
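// A minimal illustrative case (hypothetical SSA values):
//
//   %cast = vector.shape_cast %v : vector<2x4xf32> to vector<1x8xf32>
//   %e = vector.extract %cast[0] : vector<8xf32> from vector<1x8xf32>
//   ==> rewrites to vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>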
2265 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2266  PatternRewriter &rewriter) {
2267  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2268  if (!castOp)
2269  return failure();
2270 
2271  VectorType sourceType = castOp.getSourceVectorType();
2272  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2273  if (!targetType)
2274  return failure();
2275 
2276  if (sourceType.getNumElements() != targetType.getNumElements())
2277  return failure();
2278 
2279  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2280  castOp.getSource());
2281  return success();
2282 }
2283 
2284 /// Try to canonicalize the extraction of a subvector from a vector defined by
2285 /// vector.from_elements. E.g.:
2286 ///
2287 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2288 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2289 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2290 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2291  PatternRewriter &rewriter) {
2292  // Dynamic positions are not supported.
2293  if (extractOp.hasDynamicPosition())
2294  return failure();
2295 
2296  // Scalar extracts are handled by the folder.
2297  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2298  if (!resultType)
2299  return failure();
2300 
2301  // Look for extracts from a from_elements op.
2302  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2303  if (!fromElementsOp)
2304  return failure();
2305  VectorType inputType = fromElementsOp.getType();
2306 
2307  // Scalable vectors are not supported.
2308  if (resultType.isScalable() || inputType.isScalable())
2309  return failure();
2310 
2311  // Compute the position of the first extracted element and flatten/linearize the
2312  // position.
2313  SmallVector<int64_t> firstElementPos =
2314  llvm::to_vector(extractOp.getStaticPosition());
2315  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2316  int flatIndex = 0;
2317  int stride = 1;
2318  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2319  flatIndex += firstElementPos[i] * stride;
2320  stride *= inputType.getDimSize(i);
2321  }
2322 
2323  // Replace the op with a smaller from_elements op.
2324  rewriter.replaceOpWithNewOp<FromElementsOp>(
2325  extractOp, resultType,
2326  fromElementsOp.getElements().slice(flatIndex,
2327  resultType.getNumElements()));
2328  return success();
2329 }
2330 
2331 } // namespace
2332 
2333 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2334  MLIRContext *context) {
2335  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2336  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2337  results.add(foldExtractFromShapeCastToShapeCast);
2338  results.add(foldExtractFromFromElements);
2339 }
2340 
2341 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2342  SmallVectorImpl<int64_t> &results) {
2343  for (auto attr : arrayAttr)
2344  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2345 }
2346 
2347 //===----------------------------------------------------------------------===//
2348 // FmaOp
2349 //===----------------------------------------------------------------------===//
2350 
2351 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2352  return llvm::to_vector<4>(getVectorType().getShape());
2353 }
2354 
2355 //===----------------------------------------------------------------------===//
2356 // FromElementsOp
2357 //===----------------------------------------------------------------------===//
2358 
2359 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2360 /// same SSA value. E.g.:
2361 ///
2362 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2363 /// ==> rewrite to vector.splat %a : vector<3xf32>
2364 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2365  PatternRewriter &rewriter) {
2366  if (!llvm::all_equal(fromElementsOp.getElements()))
2367  return failure();
2368  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2369  fromElementsOp.getElements().front());
2370  return success();
2371 }
2372 
2373 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2374  MLIRContext *context) {
2375  results.add(rewriteFromElementsAsSplat);
2376 }
2377 
2378 //===----------------------------------------------------------------------===//
2379 // BroadcastOp
2380 //===----------------------------------------------------------------------===//
2381 
2382 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2383  SetIntRangeFn setResultRanges) {
2384  setResultRanges(getResult(), argRanges.front());
2385 }
2386 
2387 /// Return the dimensions of the result vector that were formerly ones in the
2388 /// source vector and thus correspond to "dim-1" broadcasting.
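/// An illustrative case (hypothetical shapes): broadcasting srcShape [1, 4]
/// to dstShape [2, 3, 4] returns {1}, since only the unit source dimension is
/// expanded ("dim-1" broadcast); the new leading dimension 2 is not included.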
2389 static llvm::SetVector<int64_t>
2390 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2391  ArrayRef<int64_t> dstShape) {
2392  int64_t rankDiff = dstShape.size() - srcShape.size();
2393  int64_t dstDim = rankDiff;
2394  llvm::SetVector<int64_t> res;
2395  for (auto [s1, s2] :
2396  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2397  if (s1 != s2) {
2398  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2399  res.insert(dstDim);
2400  }
2401  ++dstDim;
2402  }
2403  return res;
2404 }
2405 
2406 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2407  // A scalar broadcast involves no unit-dim broadcasting.
2408  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2409  if (!srcVectorType)
2410  return {};
2411  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2412  getResultVectorType().getShape());
2413 }
2414 
2415 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2416 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2417 /// This requires (and asserts) that the broadcast is free of "dim-1"
2418 /// broadcasting.
2419 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2420 /// vector.transpose may be inserted to make the broadcast possible.
2421 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2422 /// the helper will assert. This means:
2423 /// 1. `dstShape` must not be empty.
2424 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType())]
2425 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2426 ///    must match the `value` shape.
2427 Value BroadcastOp::createOrFoldBroadcastOp(
2428  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2429  const llvm::SetVector<int64_t> &broadcastedDims) {
2430  assert(!dstShape.empty() && "unexpected empty dst shape");
2431 
2432  // Step 1. Well-formedness check.
2433  SmallVector<int64_t> checkShape;
2434  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2435  if (broadcastedDims.contains(i))
2436  continue;
2437  checkShape.push_back(dstShape[i]);
2438  }
2439  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2440  "ill-formed broadcastedDims contains values not confined to "
2441  "destVectorShape");
2442 
2443  Location loc = value.getLoc();
2444  Type elementType = getElementTypeOrSelf(value.getType());
2445  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2446  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2447 
2448  // Step 2. If scalar -> dstShape broadcast, just do it.
2449  if (!srcVectorType) {
2450  assert(checkShape.empty() &&
2451  "ill-formed createOrFoldBroadcastOp arguments");
2452  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2453  }
2454 
2455  assert(srcVectorType.getShape().equals(checkShape) &&
2456  "ill-formed createOrFoldBroadcastOp arguments");
2457 
2458  // Step 3. Since vector.broadcast only allows creating leading dims,
2459  // vector -> dstShape broadcast may require a transpose.
2460  // Traverse the dims in order and construct:
2461  // 1. The leading entries of the broadcastShape that is guaranteed to be
2462  // achievable by a simple broadcast.
2463  // 2. The induced permutation for the subsequent vector.transpose that will
2464  // bring us from `broadcastShape` back to the desired `dstShape`.
2465  // If the induced permutation is not the identity, create a vector.transpose.
2466  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2467  broadcastShape.reserve(dstShape.size());
2468  // Consider the example:
2469  // srcShape = 2x4
2470  // dstShape = 1x2x3x4x5
2471  // broadcastedDims = [0, 2, 4]
2472  //
2473  // We want to build:
2474  // broadcastShape = 1x3x5x2x4
2475  // permutation = [0, 2, 4, 1, 3]
2476  //   [0, 2, 4] -> leading broadcast part
2477  //   [1, 3]    -> src shape part
2478  //
2479  // Note that the trailing dims of broadcastShape are exactly the srcShape
2480  // by construction.
2481  // nextSrcShapeDim is used to keep track of where in the permutation the
2482  // "src shape part" occurs.
2483  int64_t nextSrcShapeDim = broadcastedDims.size();
2484  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2485  if (broadcastedDims.contains(i)) {
2486  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2487  // bring it to the head of the broadcastShape.
2488  // It will need to be permuted back from `broadcastShape.size() - 1` into
2489  // position `i`.
2490  broadcastShape.push_back(dstShape[i]);
2491  permutation[i] = broadcastShape.size() - 1;
2492  } else {
2493  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2494  // shape and needs to be permuted into position `i`.
2495  // Don't touch `broadcastShape` here, the whole srcShape will be
2496  // appended after.
2497  permutation[i] = nextSrcShapeDim++;
2498  }
2499  }
2500  // 3.c. Append the srcShape.
2501  llvm::append_range(broadcastShape, srcVectorType.getShape());
2502 
2503  // Ensure there are no "dim-1" broadcasts.
2504  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2505  .empty() &&
2506  "unexpected \"dim-1\" broadcast");
2507 
2508  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2509  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2510  vector::BroadcastableToResult::Success &&
2511  "must be broadcastable");
2512  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2513  // Step 4. If we find any dimension that indeed needs to be permuted,
2514  // immediately return a new vector.transpose.
2515  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2516  if (permutation[i] != i)
2517  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2518  // Otherwise return res.
2519  return res;
2520 }
2521 
2522 BroadcastableToResult mlir::vector::isBroadcastableTo(
2523  Type srcType, VectorType dstVectorType,
2524  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2525  // Broadcast scalar to vector of the same element type.
2526  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2527  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2528  return BroadcastableToResult::Success;
2529  // From now on, only vectors broadcast.
2530  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2531  if (!srcVectorType)
2532  return BroadcastableToResult::SourceTypeNotAVector;
2533 
2534  int64_t srcRank = srcVectorType.getRank();
2535  int64_t dstRank = dstVectorType.getRank();
2536  if (srcRank > dstRank)
2537  return BroadcastableToResult::SourceRankHigher;
2538  // Source has an exact match or singleton value for all trailing dimensions
2539  // (all leading dimensions are simply duplicated).
2540  int64_t lead = dstRank - srcRank;
2541  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2542  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2543  // encountered?
2544  bool foundMismatchingDims = false;
2545 
2546  // Check fixed-width dims.
2547  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2548  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2549  if (srcDim != 1 && srcDim != dstDim)
2550  foundMismatchingDims = true;
2551 
2552  // Check scalable flags.
2553  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2554  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2555  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2556  // 1 -> [N] is fine, everything else should be rejected when mixing
2557  // fixed-width and scalable dims
2558  (srcDimScalableFlag != dstDimScalableFlag &&
2559  (srcDim != 1 || srcDimScalableFlag)))
2560  foundMismatchingDims = true;
2561 
2562  if (foundMismatchingDims) {
2563  if (mismatchingDims != nullptr) {
2564  mismatchingDims->first.dim = srcDim;
2565  mismatchingDims->first.isScalable = srcDimScalableFlag;
2566 
2567  mismatchingDims->second.dim = dstDim;
2568  mismatchingDims->second.isScalable = dstDimScalableFlag;
2569  }
2570  return BroadcastableToResult::DimensionMismatch;
2571  }
2572  }
2573 
2574  return BroadcastableToResult::Success;
2575 }
2576 
2577 LogicalResult BroadcastOp::verify() {
2578  std::pair<VectorDim, VectorDim> mismatchingDims;
2579  BroadcastableToResult res = isBroadcastableTo(
2580  getSourceType(), getResultVectorType(), &mismatchingDims);
2581  if (res == BroadcastableToResult::Success)
2582  return success();
2583  if (res == BroadcastableToResult::SourceRankHigher)
2584  return emitOpError("source rank higher than destination rank");
2585  if (res == BroadcastableToResult::DimensionMismatch) {
2586  return emitOpError("dimension mismatch (")
2587  << (mismatchingDims.first.isScalable ? "[" : "")
2588  << mismatchingDims.first.dim
2589  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2590  << (mismatchingDims.second.isScalable ? "[" : "")
2591  << mismatchingDims.second.dim
2592  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2593  }
2594  if (res == BroadcastableToResult::SourceTypeNotAVector)
2595  return emitOpError("source type is not a vector");
2596  llvm_unreachable("unexpected vector.broadcast op error");
2597 }
2598 
2599 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2600  if (getSourceType() == getResultVectorType())
2601  return getSource();
2602  if (!adaptor.getSource())
2603  return {};
2604  auto vectorType = getResultVectorType();
2605  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2606  if (vectorType.getElementType() != attr.getType())
2607  return {};
2608  return DenseElementsAttr::get(vectorType, attr);
2609  }
2610  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2611  if (vectorType.getElementType() != attr.getType())
2612  return {};
2613  return DenseElementsAttr::get(vectorType, attr);
2614  }
2615  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2616  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2617  return {};
2618 }
2619 
2620 namespace {
2621 
2622 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
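// A minimal illustrative case (hypothetical SSA values):
//
//   %b1 = vector.broadcast %a : f32 to vector<4xf32>
//   %b2 = vector.broadcast %b1 : vector<4xf32> to vector<2x4xf32>
//   ==> %b2 rewrites to vector.broadcast %a : f32 to vector<2x4xf32>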
2623 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2624  using OpRewritePattern::OpRewritePattern;
2625 
2626  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2627  PatternRewriter &rewriter) const override {
2628  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2629  if (!srcBroadcast)
2630  return failure();
2631  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2632  broadcastOp.getResultVectorType(),
2633  srcBroadcast.getSource());
2634  return success();
2635  }
2636 };
2637 } // namespace
2638 
2639 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2640  MLIRContext *context) {
2641  // BroadcastToShapeCast is not a default canonicalization; it is opted in
2642  // by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2643  results.add<BroadcastFolder>(context);
2644 }
2645 
2646 //===----------------------------------------------------------------------===//
2647 // ShuffleOp
2648 //===----------------------------------------------------------------------===//
2649 
2650 LogicalResult ShuffleOp::verify() {
2651  VectorType resultType = getResultVectorType();
2652  VectorType v1Type = getV1VectorType();
2653  VectorType v2Type = getV2VectorType();
2654  // Verify ranks.
2655  int64_t resRank = resultType.getRank();
2656  int64_t v1Rank = v1Type.getRank();
2657  int64_t v2Rank = v2Type.getRank();
2658  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2659  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2660  if (!wellFormed0DCase && !wellFormedNDCase)
2661  return emitOpError("rank mismatch");
2662 
2663  // Verify all but leading dimension sizes.
2664  for (int64_t r = 1; r < v1Rank; ++r) {
2665  int64_t resDim = resultType.getDimSize(r);
2666  int64_t v1Dim = v1Type.getDimSize(r);
2667  int64_t v2Dim = v2Type.getDimSize(r);
2668  if (resDim != v1Dim || v1Dim != v2Dim)
2669  return emitOpError("dimension mismatch");
2670  }
2671  // Verify mask length.
2672  ArrayRef<int64_t> mask = getMask();
2673  int64_t maskLength = mask.size();
2674  if (maskLength <= 0)
2675  return emitOpError("invalid mask length");
2676  if (maskLength != resultType.getDimSize(0))
2677  return emitOpError("mask length mismatch");
2678  // Verify all indices.
2679  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2680  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2681  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2682  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2683  return emitOpError("mask index #") << (idx + 1) << " out of range";
2684  }
2685  return success();
2686 }
2687 
2688 LogicalResult
2689 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2690  ShuffleOp::Adaptor adaptor,
2691  SmallVectorImpl<Type> &inferredReturnTypes) {
2692  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2693  auto v1Rank = v1Type.getRank();
2694  // Construct resulting type: leading dimension matches mask
2695  // length, all trailing dimensions match the operands.
2696  SmallVector<int64_t> shape;
2697  shape.reserve(v1Rank);
2698  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2699  // In the 0-D case there is no trailing shape to append.
2700  if (v1Rank > 0)
2701  llvm::append_range(shape, v1Type.getShape().drop_front());
2702  inferredReturnTypes.push_back(
2703  VectorType::get(shape, v1Type.getElementType()));
2704  return success();
2705 }
2706 
2707 template <typename T>
2708 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2709  T expected = begin;
2710  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2711  return value == expected++;
2712  });
2713 }
2714 
2715 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2716  auto v1Type = getV1VectorType();
2717  auto v2Type = getV2VectorType();
2718 
2719  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2720  "Vector shuffle does not support scalable vectors");
2721 
2722  // For consistency: a 0-D shuffle returns a 1-D vector, so this cannot be a
2723  // folding; it must be a canonicalization into a vector.broadcast.
2724  if (v1Type.getRank() == 0)
2725  return {};
2726 
2727  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2728  auto mask = getMask();
2729  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
2730  return getV1();
2731  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2732  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
2733  return getV2();
2734 
2735  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2736  if (!v1Attr || !v2Attr)
2737  return {};
2738 
2739  // Fold shuffle poison, poison -> poison.
2740  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2741  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2742  if (isV1Poison && isV2Poison)
2743  return ub::PoisonAttr::get(getContext());
2744 
2745  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2746  // manipulation.
2747  if (v1Type.getRank() != 1)
2748  return {};
2749 
2750  // Poison input attributes need special handling as they are not
2751  // DenseElementsAttr. If an index is poison, we select the first element of
2752  // the first non-poison input.
2753  SmallVector<Attribute> v1Elements, v2Elements;
2754  Attribute poisonElement;
2755  if (!isV2Poison) {
2756  v2Elements =
2757  to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2758  poisonElement = v2Elements[0];
2759  }
2760  if (!isV1Poison) {
2761  v1Elements =
2762  to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2763  poisonElement = v1Elements[0];
2764  }
2765 
2766  SmallVector<Attribute> results;
2767  int64_t v1Size = v1Type.getDimSize(0);
2768  for (int64_t maskIdx : mask) {
2769  Attribute indexedElm;
2770  // TODO: Return a partial poison vector when supported by the UB dialect.
2771  if (maskIdx == ShuffleOp::kPoisonIndex) {
2772  indexedElm = poisonElement;
2773  } else {
2774  if (maskIdx < v1Size)
2775  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2776  else
2777  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2778  }
2779 
2780  results.push_back(indexedElm);
2781  }
2782 
2783  return DenseElementsAttr::get(getResultVectorType(), results);
2784 }
2785 
2786 namespace {
2787 
2788 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2789 // to a broadcast.
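// A minimal illustrative case (hypothetical SSA values):
//
//   %r = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
//   ==> rewrites to vector.broadcast %a : vector<f32> to vector<1xf32>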
2790 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2791  using OpRewritePattern::OpRewritePattern;
2792 
2793  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2794  PatternRewriter &rewriter) const override {
2795  VectorType v1VectorType = shuffleOp.getV1VectorType();
2796  ArrayRef<int64_t> mask = shuffleOp.getMask();
2797  if (v1VectorType.getRank() > 0)
2798  return failure();
2799  if (mask.size() != 1)
2800  return failure();
2801  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2802  if (mask[0] == 0)
2803  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2804  shuffleOp.getV1());
2805  else
2806  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2807  shuffleOp.getV2());
2808  return success();
2809  }
2810 };
2811 
2812 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
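/// A minimal illustrative case (hypothetical SSA values), where both operands
/// are splats of the same value:
///
///   %s1 = vector.splat %x : vector<4xf32>
///   %s2 = vector.splat %x : vector<4xf32>
///   %r = vector.shuffle %s1, %s2 [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
///   ==> rewrites to vector.splat %x : vector<4xf32>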
2813 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2814 public:
2815  using OpRewritePattern::OpRewritePattern;
2816 
2817  LogicalResult matchAndRewrite(ShuffleOp op,
2818  PatternRewriter &rewriter) const override {
2819  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2820  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2821 
2822  if (!v1Splat || !v2Splat)
2823  return failure();
2824 
2825  if (v1Splat.getInput() != v2Splat.getInput())
2826  return failure();
2827 
2828  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2829  return success();
2830  }
2831 };
2832 
2833 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2834 /// vector.interleave.
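/// A minimal illustrative case (hypothetical SSA values):
///
///   %r = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///       : vector<4xf32>, vector<4xf32>
///   ==> rewrites to vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>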
2835 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2836 public:
2837  using OpRewritePattern::OpRewritePattern;
2838 
2839  LogicalResult matchAndRewrite(ShuffleOp op,
2840  PatternRewriter &rewriter) const override {
2841  VectorType resultType = op.getResultVectorType();
2842  if (resultType.isScalable())
2843  return rewriter.notifyMatchFailure(
2844  op, "ShuffleOp can't represent a scalable interleave");
2845 
2846  if (resultType.getRank() != 1)
2847  return rewriter.notifyMatchFailure(
2848  op, "ShuffleOp can't represent an n-D interleave");
2849 
2850  VectorType sourceType = op.getV1VectorType();
2851  if (sourceType != op.getV2VectorType() ||
2852  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2853  return rewriter.notifyMatchFailure(
2854  op, "ShuffleOp types don't match an interleave");
2855  }
2856 
2857  ArrayRef<int64_t> shuffleMask = op.getMask();
2858  int64_t resultVectorSize = resultType.getNumElements();
2859  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2860  int64_t maskValueA = shuffleMask[i * 2];
2861  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2862  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2863  return rewriter.notifyMatchFailure(op,
2864  "ShuffleOp mask not interleaving");
2865  }
2866 
2867  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2868  return success();
2869  }
2870 };
2871 
2872 } // namespace
2873 
2874 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2875  MLIRContext *context) {
2876  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2877  context);
2878 }
2879 
2880 //===----------------------------------------------------------------------===//
2881 // InsertElementOp
2882 //===----------------------------------------------------------------------===//
2883 
2884 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2885  SetIntRangeFn setResultRanges) {
2886  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2887 }
2888 
2889 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2890  Value source, Value dest) {
2891  build(builder, result, source, dest, {});
2892 }
2893 
2894 LogicalResult InsertElementOp::verify() {
2895  auto dstVectorType = getDestVectorType();
2896  if (dstVectorType.getRank() == 0) {
2897  if (getPosition())
2898  return emitOpError("expected position to be empty with 0-D vector");
2899  return success();
2900  }
2901  if (dstVectorType.getRank() != 1)
2902  return emitOpError("unexpected >1 vector rank");
2903  if (!getPosition())
2904  return emitOpError("expected position for 1-D vector");
2905  return success();
2906 }
2907 
2908 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2909  // Skip the 0-D vector here.
2910  if (!adaptor.getPosition())
2911  return {};
2912 
2913  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2914  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2915  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2916  if (!src || !dst || !pos)
2917  return {};
2918 
2919  if (src.getType() != getDestVectorType().getElementType())
2920  return {};
2921 
2922  auto dstElements = dst.getValues<Attribute>();
2923 
2924  SmallVector<Attribute> results(dstElements);
2925 
2926  uint64_t posIdx = pos.getInt();
2927  if (posIdx >= results.size())
2928  return {};
2929  results[posIdx] = src;
2930 
2931  return DenseElementsAttr::get(getDestVectorType(), results);
2932 }
2933 
2934 //===----------------------------------------------------------------------===//
2935 // InsertOp
2936 //===----------------------------------------------------------------------===//
2937 
2938 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2939  SetIntRangeFn setResultRanges) {
2940  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2941 }
2942 
2943 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2944  Value source, Value dest, int64_t position) {
2945  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2946 }
2947 
2948 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2949  Value source, Value dest, OpFoldResult position) {
2950  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2951 }
2952 
2953 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2954  Value source, Value dest,
2955  ArrayRef<int64_t> position) {
2956  SmallVector<OpFoldResult> posVals;
2957  posVals.reserve(position.size());
2958  llvm::transform(position, std::back_inserter(posVals),
2959  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2960  build(builder, result, source, dest, posVals);
2961 }
2962 
2963 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2964  Value source, Value dest,
2965  ArrayRef<OpFoldResult> position) {
2966  SmallVector<int64_t> staticPos;
2967  SmallVector<Value> dynamicPos;
2968  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2969  build(builder, result, source, dest, dynamicPos,
2970  builder.getDenseI64ArrayAttr(staticPos));
2971 }
2972 
2973 LogicalResult InsertOp::verify() {
2974  SmallVector<OpFoldResult> position = getMixedPosition();
2975  auto destVectorType = getDestVectorType();
2976  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2977  return emitOpError(
2978  "expected position attribute of rank no greater than dest vector rank");
2979  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2980  if (srcVectorType &&
2981  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2982  static_cast<unsigned>(destVectorType.getRank())))
2983  return emitOpError("expected position attribute rank + source rank to "
2984  "match dest vector rank");
2985  if (!srcVectorType &&
2986  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2987  return emitOpError(
2988  "expected position attribute rank to match the dest vector rank");
2989  for (auto [idx, pos] : llvm::enumerate(position)) {
2990  if (auto attr = pos.dyn_cast<Attribute>()) {
2991  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2992  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2993  destVectorType.getDimSize(idx))) {
2994  return emitOpError("expected position attribute #")
2995  << (idx + 1)
2996  << " to be a non-negative integer smaller than the "
2997  "corresponding "
2998  "dest vector dimension";
2999  }
3000  }
3001  }
3002  return success();
3003 }
3004 
3005 namespace {
3006 
3007 // If insertOp is only inserting unit dimensions it can be transformed to a
3008 // broadcast.
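// An illustrative IR sketch (names and types assumed):
//   %0 = vector.insert %src, %dst [0] : vector<8xf32> into vector<1x8xf32>
// becomes:
//   %0 = vector.broadcast %src : vector<8xf32> to vector<1x8xf32>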
3009 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3010 public:
3011  using OpRewritePattern::OpRewritePattern;
3012 
3013  LogicalResult matchAndRewrite(InsertOp insertOp,
3014  PatternRewriter &rewriter) const override {
3015  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
3016  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3017  srcVecType.getNumElements())
3018  return failure();
3019  rewriter.replaceOpWithNewOp<BroadcastOp>(
3020  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
3021  return success();
3022  }
3023 };
3024 
3025 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
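/// An illustrative IR sketch (names and types assumed):
///   %s = vector.splat %x : vector<4xf32>
///   %d = vector.splat %x : vector<2x4xf32>
///   %0 = vector.insert %s, %d [1] : vector<4xf32> into vector<2x4xf32>
/// is rewritten to:
///   %0 = vector.splat %x : vector<2x4xf32>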
3026 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3027 public:
3028  using OpRewritePattern::OpRewritePattern;
3029 
3030  LogicalResult matchAndRewrite(InsertOp op,
3031  PatternRewriter &rewriter) const override {
3032  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
3033  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3034 
3035  if (!srcSplat || !dstSplat)
3036  return failure();
3037 
3038  if (srcSplat.getInput() != dstSplat.getInput())
3039  return failure();
3040 
3041  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3042  return success();
3043  }
3044 };
3045 
3046 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
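// An illustrative IR sketch (names and types assumed):
//   %c = arith.constant 7 : i32
//   %v = arith.constant dense<0> : vector<4xi32>
//   %0 = vector.insert %c, %v [2] : i32 into vector<4xi32>
// becomes:
//   %0 = arith.constant dense<[0, 0, 7, 0]> : vector<4xi32>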
3047 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3048 public:
3049  using OpRewritePattern::OpRewritePattern;
3050 
3051  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3052  // unless the destination vector constant has a single use.
3053  static constexpr int64_t vectorSizeFoldThreshold = 256;
3054 
3055  LogicalResult matchAndRewrite(InsertOp op,
3056  PatternRewriter &rewriter) const override {
3057  // TODO: Canonicalization for dynamic position not implemented yet.
3058  if (op.hasDynamicPosition())
3059  return failure();
3060 
3061  // Return if 'InsertOp' operand is not defined by a compatible vector
3062  // ConstantOp.
3063  TypedValue<VectorType> destVector = op.getDest();
3064  Attribute vectorDestCst;
3065  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3066  return failure();
3067  auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3068  if (!denseDest)
3069  return failure();
3070 
3071  VectorType destTy = destVector.getType();
3072  if (destTy.isScalable())
3073  return failure();
3074 
3075  // Make sure we do not create too many large constants.
3076  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3077  !destVector.hasOneUse())
3078  return failure();
3079 
3080  Value sourceValue = op.getSource();
3081  Attribute sourceCst;
3082  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3083  return failure();
3084 
3085  // Calculate the linearized position of the contiguous chunk of elements to
3086  // insert.
3087  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3088  copy(op.getStaticPosition(), completePositions.begin());
3089  int64_t insertBeginPosition =
3090  linearize(completePositions, computeStrides(destTy.getShape()));
3091 
3092  SmallVector<Attribute> insertedValues;
3093  Type destEltType = destTy.getElementType();
3094 
3095  // The `convertIntegerAttr` method specifically handles the case
3096  // for `llvm.mlir.constant` which can hold an attribute with a
3097  // different type than the return type.
3098  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3099  for (auto value : denseSource.getValues<Attribute>())
3100  insertedValues.push_back(convertIntegerAttr(value, destEltType));
3101  } else {
3102  insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
3103  }
3104 
3105  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
3106  copy(insertedValues, allValues.begin() + insertBeginPosition);
3107  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3108 
3109  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3110  return success();
3111  }
3112 
3113 private:
3114  /// Converts the expected type to an IntegerAttr if there's
3115  /// a mismatch.
3116  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3117  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3118  if (intAttr.getType() != expectedType)
3119  return IntegerAttr::get(expectedType, intAttr.getInt());
3120  }
3121  return attr;
3122  }
3123 };
3124 
3125 } // namespace
3126 
3127 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3128  MLIRContext *context) {
3129  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3130  InsertOpConstantFolder>(context);
3131 }
3132 
3133 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3134  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3135  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3136  // (type mismatch).
3137  if (getNumIndices() == 0 && getSourceType() == getType())
3138  return getSource();
3139  SmallVector<Value> operands = {getSource(), getDest()};
3140  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3141  return val;
3142  if (auto res = foldPoisonIndexInsertExtractOp(
3143  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3144  return res;
3145 
3146  return {};
3147 }
3148 
3149 //===----------------------------------------------------------------------===//
3150 // InsertStridedSliceOp
3151 //===----------------------------------------------------------------------===//
3152 
3153 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3154  Value source, Value dest,
3155  ArrayRef<int64_t> offsets,
3156  ArrayRef<int64_t> strides) {
3157  result.addOperands({source, dest});
3158  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3159  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3160  result.addTypes(dest.getType());
3161  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3162  offsetsAttr);
3163  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3164  stridesAttr);
3165 }
3166 
3167 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3168 template <typename OpType>
3169 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3170  ArrayAttr arrayAttr,
3171  ArrayRef<int64_t> shape,
3172  StringRef attrName) {
3173  if (arrayAttr.size() > shape.size())
3174  return op.emitOpError("expected ")
3175  << attrName << " attribute of rank no greater than vector rank";
3176  return success();
3177 }
3178 
3179 // Returns success if all integers in `arrayAttr` are in the interval
3180 // [min, max). If `halfOpen` is false, the admissible interval is [min, max]
3181 // instead.
3182 template <typename OpType>
3183 static LogicalResult
3184 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3185  int64_t max, StringRef attrName,
3186  bool halfOpen = true) {
3187  for (auto attr : arrayAttr) {
3188  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3189  auto upper = max;
3190  if (!halfOpen)
3191  upper += 1;
3192  if (val < min || val >= upper)
3193  return op.emitOpError("expected ") << attrName << " to be confined to ["
3194  << min << ", " << upper << ")";
3195  }
3196  return success();
3197 }
3198 
3199 // Returns success if all integers in `arrayAttr` are in the interval
3200 // [min, shape[i]) for each corresponding dimension i. If `halfOpen` is false,
3201 // the admissible interval is [min, shape[i]] instead.
3202 template <typename OpType>
3203 static LogicalResult
3204 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3205  ArrayRef<int64_t> shape, StringRef attrName,
3206  bool halfOpen = true, int64_t min = 0) {
3207  for (auto [index, attrDimPair] :
3208  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3209  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3210  int64_t max = std::get<1>(attrDimPair);
3211  if (!halfOpen)
3212  max += 1;
3213  if (val < min || val >= max)
3214  return op.emitOpError("expected ")
3215  << attrName << " dimension " << index << " to be confined to ["
3216  << min << ", " << max << ")";
3217  }
3218  return success();
3219 }
3220 
3221 // Returns success if, for all indices i = 0..shape.size()-1, the sum
3222 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3223 // is in the interval [min, shape[i]).
3224 // If `halfOpen` is false, the admissible interval is [min, shape[i]]
3225 // instead.
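// For example (illustrative), offsets = [3] and sizes = [2] on shape = [4] are
// rejected here, since 3 + 2 exceeds the dimension size 4.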
3226 template <typename OpType>
3227 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3228  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3229  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3230  bool halfOpen = true, int64_t min = 1) {
3231  assert(arrayAttr1.size() <= shape.size());
3232  assert(arrayAttr2.size() <= shape.size());
3233  for (auto [index, it] :
3234  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3235  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3236  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3237  int64_t max = std::get<2>(it);
3238  if (!halfOpen)
3239  max += 1;
3240  if (val1 + val2 < 0 || val1 + val2 >= max)
3241  return op.emitOpError("expected sum(")
3242  << attrName1 << ", " << attrName2 << ") dimension " << index
3243  << " to be confined to [" << min << ", " << max << ")";
3244  }
3245  return success();
3246 }
3247 
3248 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3249  MLIRContext *context) {
3250  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3251  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3252  });
3253  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3254 }
3255 
3256 LogicalResult InsertStridedSliceOp::verify() {
3257  auto sourceVectorType = getSourceVectorType();
3258  auto destVectorType = getDestVectorType();
3259  auto offsets = getOffsetsAttr();
3260  auto strides = getStridesAttr();
3261  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3262  return emitOpError(
3263  "expected offsets of same size as destination vector rank");
3264  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3265  return emitOpError("expected strides of same size as source vector rank");
3266  if (sourceVectorType.getRank() > destVectorType.getRank())
3267  return emitOpError(
3268  "expected source rank to be no greater than destination rank");
3269 
3270  auto sourceShape = sourceVectorType.getShape();
3271  auto destShape = destVectorType.getShape();
3272  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3273  destShape.size() - sourceShape.size(), 0);
3274  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3275  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3276  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3277  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3278  offName)) ||
3279  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3280  /*max=*/1, stridesName,
3281  /*halfOpen=*/false)) ||
3282  failed(isSumOfIntegerArrayAttrConfinedToShape(
3283  *this, offsets,
3284  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3285  offName, "source vector shape",
3286  /*halfOpen=*/false, /*min=*/1)))
3287  return failure();
3288 
3289  unsigned rankDiff = destShape.size() - sourceShape.size();
3290  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3291  if (sourceVectorType.getScalableDims()[idx] !=
3292  destVectorType.getScalableDims()[idx + rankDiff]) {
3293  return emitOpError("mismatching scalable flags (at source vector idx=")
3294  << idx << ")";
3295  }
3296  if (sourceVectorType.getScalableDims()[idx]) {
3297  auto sourceSize = sourceShape[idx];
3298  auto destSize = destShape[idx + rankDiff];
3299  if (sourceSize != destSize) {
3300  return emitOpError("expected size at idx=")
3301  << idx
3302  << (" to match the corresponding base size from the input "
3303  "vector (")
3304  << sourceSize << (" vs ") << destSize << (")");
3305  }
3306  }
3307  }
3308 
3309  return success();
3310 }
3311 
3312 namespace {
3313 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3314 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
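/// An illustrative IR sketch (names and types assumed):
///   %s = vector.splat %x : vector<8xf32>
///   %t = vector.splat %x : vector<2xf32>
///   %0 = vector.insert_strided_slice %t, %s {offsets = [4], strides = [1]}
///       : vector<2xf32> into vector<8xf32>
/// folds to %s, since every element of the result is the splatted value %x.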
3315 class FoldInsertStridedSliceSplat final
3316  : public OpRewritePattern<InsertStridedSliceOp> {
3317 public:
3318  using OpRewritePattern::OpRewritePattern;
3319 
3320  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3321  PatternRewriter &rewriter) const override {
3322  auto srcSplatOp =
3323  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3324  auto destSplatOp =
3325  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3326 
3327  if (!srcSplatOp || !destSplatOp)
3328  return failure();
3329 
3330  if (srcSplatOp.getInput() != destSplatOp.getInput())
3331  return failure();
3332 
3333  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3334  return success();
3335  }
3336 };
3337 
3338 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3339 /// to dst.
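/// An illustrative IR sketch (names and types assumed):
///   %e = vector.extract_strided_slice %dst
///       {offsets = [4], sizes = [2], strides = [1]}
///       : vector<8xf32> to vector<2xf32>
///   %0 = vector.insert_strided_slice %e, %dst {offsets = [4], strides = [1]}
///       : vector<2xf32> into vector<8xf32>
/// folds to %dst, because the same elements are re-inserted in place.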
3340 class FoldInsertStridedSliceOfExtract final
3341  : public OpRewritePattern<InsertStridedSliceOp> {
3342 public:
3343  using OpRewritePattern::OpRewritePattern;
3344 
3345  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3346  PatternRewriter &rewriter) const override {
3347  auto extractStridedSliceOp =
3348  insertStridedSliceOp.getSource()
3349  .getDefiningOp<vector::ExtractStridedSliceOp>();
3350 
3351  if (!extractStridedSliceOp)
3352  return failure();
3353 
3354  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3355  return failure();
3356 
3357  // Check if they have the same strides and offsets.
3358  if (extractStridedSliceOp.getStrides() !=
3359  insertStridedSliceOp.getStrides() ||
3360  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3361  return failure();
3362 
3363  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3364  return success();
3365  }
3366 };
3367 
3368 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3369 // ConstantOp.
3370 class InsertStridedSliceConstantFolder final
3371  : public OpRewritePattern<InsertStridedSliceOp> {
3372 public:
3373  using OpRewritePattern::OpRewritePattern;
3374 
3375  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3376  // unless the destination vector constant has a single use.
3377  static constexpr int64_t vectorSizeFoldThreshold = 256;
3378 
3379  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3380  PatternRewriter &rewriter) const override {
3381  // Return if the 'InsertStridedSliceOp' dest operand is not defined by a
3382  // compatible vector ConstantOp.
3383  TypedValue<VectorType> destVector = op.getDest();
3384  Attribute vectorDestCst;
3385  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3386  return failure();
3387 
3388  VectorType destTy = destVector.getType();
3389  if (destTy.isScalable())
3390  return failure();
3391 
3392  // Make sure we do not create too many large constants.
3393  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3394  !destVector.hasOneUse())
3395  return failure();
3396 
3397  TypedValue<VectorType> sourceValue = op.getSource();
3398  Attribute sourceCst;
3399  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3400  return failure();
3401 
3402  // TODO: Support poison.
3403  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3404  return failure();
3405 
3406  // TODO: Handle non-unit strides when they become available.
3407  if (op.hasNonUnitStrides())
3408  return failure();
3409 
3410  VectorType sliceVecTy = sourceValue.getType();
3411  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3412  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3413  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3414  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3415 
3416  // Calculate the destination element indices by enumerating all slice
3417  // positions within the destination and linearizing them. The enumeration
3418  // order is lexicographic, which yields a sequence of monotonically
3419  // increasing linearized position indices.
3420  // Because the destination may have higher dimensionality than the slice,
3421  // we keep track of two overlapping sets of positions and offsets.
3422  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3423  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3424  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3425  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3426  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3427  MutableArrayRef<int64_t> currSlicePosition(
3428  currDestPosition.begin() + rankDifference, currDestPosition.end());
3429  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3430  offsets.end());
3431  do {
3432  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3433  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3434  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3435  "Invalid slice element");
3436  newValues[linearizedPosition] = *sliceValuesIt;
3437  ++sliceValuesIt;
3438  } while (succeeded(
3439  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3440 
3441  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3442  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3443  return success();
3444  }
3445 };
3446 
3447 } // namespace
3448 
3449 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3450  RewritePatternSet &results, MLIRContext *context) {
3451  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3452  InsertStridedSliceConstantFolder>(context);
3453 }
3454 
3455 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3456  if (getSourceVectorType() == getDestVectorType())
3457  return getSource();
3458  return {};
3459 }
3460 
3461 //===----------------------------------------------------------------------===//
3462 // OuterProductOp
3463 //===----------------------------------------------------------------------===//
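// Illustrative textual forms accepted by the custom parser/printer below
// (operand names and types assumed):
//   %0 = vector.outerproduct %lhs, %rhs, %acc : vector<4xf32>, vector<8xf32>
//   %1 = vector.outerproduct %lhs, %scalar, %acc : vector<4xf32>, f32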
3464 
3465 /// Build an op without mask, use the type of `acc` as the return type.
3466 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3467  Value lhs, Value rhs, Value acc) {
3468  result.addOperands({lhs, rhs, acc});
3469  result.addTypes(acc.getType());
3470 }
3471 
3472 void OuterProductOp::print(OpAsmPrinter &p) {
3473  p << " " << getLhs() << ", " << getRhs();
3474  if (getAcc()) {
3475  p << ", " << getAcc();
3476  p.printOptionalAttrDict((*this)->getAttrs());
3477  }
3478  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3479 }
3480 
3481 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3482  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3483  Type tLHS, tRHS;
3484  if (parser.parseOperandList(operandsInfo) ||
3485  parser.parseOptionalAttrDict(result.attributes) ||
3486  parser.parseColonType(tLHS) || parser.parseComma() ||
3487  parser.parseType(tRHS))
3488  return failure();
3489  if (operandsInfo.size() < 2)
3490  return parser.emitError(parser.getNameLoc(),
3491  "expected at least 2 operands");
3492  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3493  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3494  if (!vLHS)
3495  return parser.emitError(parser.getNameLoc(),
3496  "expected vector type for operand #1");
3497 
3498  VectorType resType;
3499  if (vRHS) {
3500  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3501  vRHS.getScalableDims()[0]};
3502  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3503  vLHS.getElementType(), scalableDimsRes);
3504  } else {
3505  // Scalar RHS operand
3506  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3507  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3508  scalableDimsRes);
3509  }
3510 
3511  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3512  result.attributes.append(
3513  OuterProductOp::getKindAttrName(result.name),
3514  CombiningKindAttr::get(result.getContext(),
3515  OuterProductOp::getDefaultKind()));
3516  }
3517 
3518  return failure(
3519  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3520  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3521  (operandsInfo.size() > 2 &&
3522  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3523  parser.addTypeToList(resType, result.types));
3524 }
3525 
3526 LogicalResult OuterProductOp::verify() {
3527  Type tRHS = getOperandTypeRHS();
3528  VectorType vLHS = getOperandVectorTypeLHS(),
3529  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3530  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3531 
3532  if (vLHS.getRank() != 1)
3533  return emitOpError("expected 1-d vector for operand #1");
3534 
3535  if (vRHS) {
3536  // Proper OUTER operation.
3537  if (vRHS.getRank() != 1)
3538  return emitOpError("expected 1-d vector for operand #2");
3539  if (vRES.getRank() != 2)
3540  return emitOpError("expected 2-d vector result");
3541  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3542  return emitOpError("expected #1 operand dim to match result dim #1");
3543  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3544  return emitOpError("expected #2 operand dim to match result dim #2");
3545  if (vLHS.isScalable() && !vRHS.isScalable()) {
3546  // This restriction reflects what's currently supported in terms of
3547  // scalable vectors. However, we could relax this if there's a use case.
3548  return emitOpError(
3549  "expected either both or only #2 operand dim to be scalable");
3550  }
3551  } else {
3552  // An AXPY operation.
3553  if (vRES.getRank() != 1)
3554  return emitOpError("expected 1-d vector result");
3555  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3556  return emitOpError("expected #1 operand dim to match result dim #1");
3557  }
3558 
3559  if (vACC && vACC != vRES)
3560  return emitOpError("expected operand #3 of same type as result type");
3561 
3562  // Verify supported combining kind.
3563  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3564  return emitOpError("unsupported outerproduct type");
3565 
3566  return success();
3567 }
3568 
3569 // MaskableOpInterface methods.
3570 
3571 /// Returns the mask type expected by this operation. Mostly used for
3572 /// verification purposes. It requires the operation to be vectorized.
3573 Type OuterProductOp::getExpectedMaskType() {
3574  auto vecType = this->getResultVectorType();
3575  return VectorType::get(vecType.getShape(),
3576  IntegerType::get(vecType.getContext(), /*width=*/1),
3577  vecType.getScalableDims());
3578 }
3579 
3580 //===----------------------------------------------------------------------===//
3581 // ExtractStridedSliceOp
3582 //===----------------------------------------------------------------------===//
3583 
3584 // Inference works as follows:
3585 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3586 // 2. Add sizes from 'vectorType' for remaining dims.
3587 // Scalable flags are inherited from 'vectorType'.
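// For example (illustrative): a source of vector<4x8x16xf32> with rank-2
// offsets/sizes/strides and sizes = [2, 4] infers the result type
// vector<2x4x16xf32>.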
3588 static Type inferStridedSliceOpResultType(VectorType vectorType,
3589  ArrayAttr offsets, ArrayAttr sizes,
3590  ArrayAttr strides) {
3591  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3592  SmallVector<int64_t, 4> shape;
3593  shape.reserve(vectorType.getRank());
3594  unsigned idx = 0;
3595  for (unsigned e = offsets.size(); idx < e; ++idx)
3596  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3597  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3598  shape.push_back(vectorType.getShape()[idx]);
3599 
3600  return VectorType::get(shape, vectorType.getElementType(),
3601  vectorType.getScalableDims());
3602 }
3603 
3604 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3605  Value source, ArrayRef<int64_t> offsets,
3606  ArrayRef<int64_t> sizes,
3607  ArrayRef<int64_t> strides) {
3608  result.addOperands(source);
3609  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3610  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3611  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3612  result.addTypes(
3613  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3614  offsetsAttr, sizesAttr, stridesAttr));
3615  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3616  offsetsAttr);
3617  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3618  sizesAttr);
3619  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3620  stridesAttr);
3621 }
3622 
3623 LogicalResult ExtractStridedSliceOp::verify() {
3624  auto type = getSourceVectorType();
3625  auto offsets = getOffsetsAttr();
3626  auto sizes = getSizesAttr();
3627  auto strides = getStridesAttr();
3628  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3629  return emitOpError(
3630  "expected offsets, sizes and strides attributes of same size");
3631 
3632  auto shape = type.getShape();
3633  auto offName = getOffsetsAttrName();
3634  auto sizesName = getSizesAttrName();
3635  auto stridesName = getStridesAttrName();
3636  if (failed(
3637  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3638  failed(
3639  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3640  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3641  stridesName)) ||
3642  failed(
3643  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3644  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3645  /*halfOpen=*/false,
3646  /*min=*/1)) ||
3647  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3648  /*max=*/1, stridesName,
3649  /*halfOpen=*/false)) ||
3650  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3651  shape, offName, sizesName,
3652  /*halfOpen=*/false)))
3653  return failure();
3654 
3655  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3656  offsets, sizes, strides);
3657  if (getResult().getType() != resultType)
3658  return emitOpError("expected result type to be ") << resultType;
3659 
3660  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3661  if (type.getScalableDims()[idx]) {
3662  auto inputDim = type.getShape()[idx];
3663  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3664  if (inputDim != inputSize)
3665  return emitOpError("expected size at idx=")
3666  << idx
3667  << (" to match the corresponding base size from the input "
3668  "vector (")
3669  << inputSize << (" vs ") << inputDim << (")");
3670  }
3671  }
3672 
3673  return success();
3674 }
3675 
3676 // When the source of an ExtractStridedSliceOp comes from a chain of
3677 // InsertStridedSliceOp ops, try to use the source of one of those inserts
3678 // directly when the extracted vector is a subset of an inserted vector.
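// An illustrative IR sketch (names and types assumed):
//   %i = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//       : vector<4xf32> into vector<16xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [3], sizes = [2], strides = [1]}
//       : vector<16xf32> to vector<2xf32>
// Here the extracted chunk [3, 5) lies inside the inserted chunk [2, 6), so the
// extract is rewritten to read from %src with offsets = [1].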
3679 static LogicalResult
3680 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3681  // Helper to extract integer out of ArrayAttr.
3682  auto getElement = [](ArrayAttr array, int idx) {
3683  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3684  };
3685  ArrayAttr extractOffsets = op.getOffsets();
3686  ArrayAttr extractStrides = op.getStrides();
3687  ArrayAttr extractSizes = op.getSizes();
3688  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3689  while (insertOp) {
3690  if (op.getSourceVectorType().getRank() !=
3691  insertOp.getSourceVectorType().getRank())
3692  return failure();
3693  ArrayAttr insertOffsets = insertOp.getOffsets();
3694  ArrayAttr insertStrides = insertOp.getStrides();
3695  // If the rank of extract is greater than the rank of insert, we are likely
3696  // extracting a partial chunk of the vector inserted.
3697  if (extractOffsets.size() > insertOffsets.size())
3698  return failure();
3699  bool partialOverlap = false;
3700  bool disjoint = false;
3701  SmallVector<int64_t, 4> offsetDiffs;
3702  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3703  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3704  return failure();
3705  int64_t start = getElement(insertOffsets, dim);
3706  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3707  int64_t offset = getElement(extractOffsets, dim);
3708  int64_t size = getElement(extractSizes, dim);
3709  // Check if the start of the extract offset is in the interval inserted.
3710  if (start <= offset && offset < end) {
3711  // If the extract interval overlaps but is not fully included we may
3712  // have a partial overlap that will prevent any folding.
3713  if (offset + size > end)
3714  partialOverlap = true;
3715  offsetDiffs.push_back(offset - start);
3716  continue;
3717  }
3718  disjoint = true;
3719  break;
3720  }
3721  // The extract element chunk is a subset of the insert element.
3722  if (!disjoint && !partialOverlap) {
3723  op.setOperand(insertOp.getSource());
3724  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3725  OpBuilder b(op.getContext());
3726  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3727  return success();
3728  }
3729  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3730  // in the insert chain.
3731  if (disjoint)
3732  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3733  else {
3734  // The extracted vector partially overlaps the inserted vector; we cannot
3735  // fold.
3736  return failure();
3737  }
3738  }
3739  return failure();
3740 }
3741 
3742 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3743  if (getSourceVectorType() == getResult().getType())
3744  return getVector();
3745  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3746  return getResult();
3747  return {};
3748 }
3749 
3750 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3751  populateFromInt64AttrArray(getOffsets(), results);
3752 }
3753 
3754 namespace {
3755 
3756 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3757 // ConstantMaskOp.
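// An illustrative IR sketch (names and types assumed):
//   %m = vector.constant_mask [2, 3] : vector<4x4xi1>
//   %0 = vector.extract_strided_slice %m
//       {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x4xi1> to vector<2x4xi1>
// becomes:
//   %0 = vector.constant_mask [1, 3] : vector<2x4xi1>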
3758 class StridedSliceConstantMaskFolder final
3759  : public OpRewritePattern<ExtractStridedSliceOp> {
3760 public:
3761  using OpRewritePattern::OpRewritePattern;
3762 
3763  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3764  PatternRewriter &rewriter) const override {
3765  // Return if 'extractStridedSliceOp' operand is not defined by a
3766  // ConstantMaskOp.
3767  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3768  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3769  if (!constantMaskOp)
3770  return failure();
3771  // Return if 'extractStridedSliceOp' has non-unit strides.
3772  if (extractStridedSliceOp.hasNonUnitStrides())
3773  return failure();
3774  // Gather constant mask dimension sizes.
3775  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3776  // Gather strided slice offsets and sizes.
3777  SmallVector<int64_t, 4> sliceOffsets;
3778  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3779  sliceOffsets);
3780  SmallVector<int64_t, 4> sliceSizes;
3781  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3782 
3783  // Compute slice of vector mask region.
3784  SmallVector<int64_t, 4> sliceMaskDimSizes;
3785  sliceMaskDimSizes.reserve(maskDimSizes.size());
3786  for (auto [maskDimSize, sliceOffset, sliceSize] :
3787  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3788  int64_t sliceMaskDimSize = std::max(
3789  static_cast<int64_t>(0),
3790  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3791  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3792  }
3793  // Add unchanged dimensions.
3794  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3795  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3796  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3797  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3798  // region is a conjunction of mask dim intervals).
3799  if (llvm::is_contained(sliceMaskDimSizes, 0))
3800  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3801 
3802  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3803  // region.
3804  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3805  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3806  sliceMaskDimSizes);
3807  return success();
3808  }
3809 };
3810 
3811 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3812 class StridedSliceSplatConstantFolder final
3813  : public OpRewritePattern<ExtractStridedSliceOp> {
3814 public:
3815  using OpRewritePattern::OpRewritePattern;
3816 
3817  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3818  PatternRewriter &rewriter) const override {
3819  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3820  // ConstantOp.
3821  Value sourceVector = extractStridedSliceOp.getVector();
3822  Attribute vectorCst;
3823  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3824  return failure();
3825 
3826  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3827  if (!splat)
3828  return failure();
3829 
3830  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3831  splat.getSplatValue<Attribute>());
3832  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3833  newAttr);
3834  return success();
3835  }
3836 };
3837 
3838 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3839 // ConstantOp.
3840 class StridedSliceNonSplatConstantFolder final
3841  : public OpRewritePattern<ExtractStridedSliceOp> {
3842 public:
3843  using OpRewritePattern::OpRewritePattern;
3844 
3845  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3846  PatternRewriter &rewriter) const override {
3847  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3848  // ConstantOp.
3849  Value sourceVector = extractStridedSliceOp.getVector();
3850  Attribute vectorCst;
3851  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3852  return failure();
3853 
3854  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3855  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3856  if (!dense || dense.isSplat())
3857  return failure();
3858 
3859  // TODO: Handle non-unit strides when they become available.
3860  if (extractStridedSliceOp.hasNonUnitStrides())
3861  return failure();
3862 
3863  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3864  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3865  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3866 
3867  VectorType sliceVecTy = extractStridedSliceOp.getType();
3868  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3869  int64_t sliceRank = sliceVecTy.getRank();
3870 
3871  // Expand offsets and sizes to match the vector rank.
3872  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3873  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3874 
3875  SmallVector<int64_t, 4> sizes(sourceShape);
3876  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3877 
3878  // Calculate the slice elements by enumerating all slice positions and
3879  // linearizing them. The enumeration order is lexicographic which yields a
3880  // sequence of monotonically increasing linearized position indices.
3881  auto denseValuesBegin = dense.value_begin<Attribute>();
3882  SmallVector<Attribute> sliceValues;
3883  sliceValues.reserve(sliceVecTy.getNumElements());
3884  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3885  do {
3886  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3887  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3888  "Invalid index");
3889  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3890  } while (
3891  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3892 
3893  assert(static_cast<int64_t>(sliceValues.size()) ==
3894  sliceVecTy.getNumElements() &&
3895  "Invalid number of slice elements");
3896  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3897  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3898  newAttr);
3899  return success();
3900  }
3901 };
3902 
3903 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3904 // BroadcastOp(ExtractStridedSliceOp).
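// An illustrative IR sketch (names and types assumed):
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %0 = vector.extract_strided_slice %b
//       {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x8xf32>
// becomes:
//   %0 = vector.broadcast %src : vector<8xf32> to vector<2x8xf32>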
3905 class StridedSliceBroadcast final
3906  : public OpRewritePattern<ExtractStridedSliceOp> {
3907 public:
3908  using OpRewritePattern::OpRewritePattern;
3909 
3910  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3911  PatternRewriter &rewriter) const override {
3912  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3913  if (!broadcast)
3914  return failure();
3915  auto srcVecType =
3916  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3917  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3918  auto dstVecType = llvm::cast<VectorType>(op.getType());
3919  unsigned dstRank = dstVecType.getRank();
3920  unsigned rankDiff = dstRank - srcRank;
3921  // Check if the innermost dimensions of the source of the broadcast are the
3922  // same as the corresponding dimensions of the extract destination. If so, we
3923  // can use a broadcast directly, as the original dimensions are untouched.
3924  bool lowerDimMatch = true;
3925  for (unsigned i = 0; i < srcRank; i++) {
3926  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3927  lowerDimMatch = false;
3928  break;
3929  }
3930  }
3931  Value source = broadcast.getSource();
3932  // If the inner dimensions don't match, it means we need to extract from the
3933  // source of the original broadcast and then broadcast the extracted value.
3934  // We also need to handle degenerate cases where the source is effectively
3935  // just a single scalar.
3936  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3937  if (!lowerDimMatch && !isScalarSrc) {
3938  source = rewriter.create<ExtractStridedSliceOp>(
3939  op->getLoc(), source,
3940  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3941  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3942  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3943  }
3944  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3945  return success();
3946  }
3947 };
3948 
3949 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3950 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3951 public:
3952  using OpRewritePattern::OpRewritePattern;
3953 
3954  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3955  PatternRewriter &rewriter) const override {
3956  auto splat = op.getVector().getDefiningOp<SplatOp>();
3957  if (!splat)
3958  return failure();
3959  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3960  return success();
3961  }
3962 };
3963 
3964 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3965 /// slice is contiguous, into extract and shape_cast.
3966 ///
3967 /// Example:
3968 /// Before:
3969 /// %1 = vector.extract_strided_slice %arg0 {
3970 /// offsets = [0, 0, 0, 0, 0],
3971 /// sizes = [1, 1, 1, 1, 8],
3972 /// strides = [1, 1, 1, 1, 1]
3973 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3974 /// After:
3975 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3976 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3977 /// %1 = vector.shape_cast %0
3978 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3979 ///
3980 class ContiguousExtractStridedSliceToExtract final
3981  : public OpRewritePattern<ExtractStridedSliceOp> {
3982 public:
3983  using OpRewritePattern::OpRewritePattern;
3984 
3985  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3986  PatternRewriter &rewriter) const override {
3987  if (op.hasNonUnitStrides())
3988  return failure();
3989  Value source = op.getOperand();
3990  auto sourceType = cast<VectorType>(source.getType());
3991  if (sourceType.isScalable() || sourceType.getRank() == 0)
3992  return failure();
3993 
3994  // Compute the number of offsets to pass to ExtractOp::build. That is the
3995  // difference between the source rank and the desired slice rank. We walk
3996  // the dimensions from innermost out, and stop when the next slice dimension
3997  // is not full-size.
3998  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3999  int numOffsets;
4000  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4001  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4002  break;
4003  }
4004 
4005  // If the created extract op would have no offsets, then this whole
4006  // extract_strided_slice is the identity and should have been handled by
4007  // other canonicalizations.
4008  if (numOffsets == 0)
4009  return failure();
4010 
4011  // If not even the inner-most dimension is full-size, this op can't be
4012  // rewritten as an ExtractOp.
4013  if (numOffsets == sourceType.getRank() &&
4014  static_cast<int>(sizes.size()) == sourceType.getRank())
4015  return failure();
4016 
4017  // The outer dimensions must have unit size.
4018  for (int i = 0; i < numOffsets; ++i) {
4019  if (sizes[i] != 1)
4020  return failure();
4021  }
4022 
4023  // Avoid generating slices that have leading unit dimensions: the shape_cast
4024  // op created below would otherwise be lowered through the poor generic
4025  // fallback pattern (ShapeCastOpRewritePattern).
4026  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4027  sizes[numOffsets] == 1) {
4028  ++numOffsets;
4029  }
4030 
4031  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4032  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4033  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
4034  extractOffsets);
4035  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4036  return success();
4037  }
4038 };
4039 
4040 } // namespace
4041 
4042 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4043  RewritePatternSet &results, MLIRContext *context) {
4044  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
4045  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4046  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4047  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
4048  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4049  context);
4050 }
4051 
4052 //===----------------------------------------------------------------------===//
4053 // TransferReadOp
4054 //===----------------------------------------------------------------------===//
4055 
4056 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4057 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4058  VectorType vectorType, Value source,
4059  ValueRange indices, AffineMapAttr permutationMapAttr,
4060  /*optional*/ ArrayAttr inBoundsAttr) {
4061  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4062  Value padding = builder.create<arith::ConstantOp>(
4063  result.location, elemType, builder.getZeroAttr(elemType));
4064  build(builder, result, vectorType, source, indices, permutationMapAttr,
4065  padding, /*mask=*/Value(), inBoundsAttr);
4066 }
4067 
4068 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4069 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4070  VectorType vectorType, Value source,
4071  ValueRange indices, AffineMap permutationMap,
4072  std::optional<ArrayRef<bool>> inBounds) {
4073  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4074  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4075  ? builder.getBoolArrayAttr(inBounds.value())
4076  : builder.getBoolArrayAttr(
4077  SmallVector<bool>(vectorType.getRank(), false));
4078  build(builder, result, vectorType, source, indices, permutationMapAttr,
4079  inBoundsAttr);
4080 }
4081 
4082 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4083 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4084  VectorType vectorType, Value source,
4085  ValueRange indices, Value padding,
4086  std::optional<ArrayRef<bool>> inBounds) {
4087  AffineMap permutationMap = getTransferMinorIdentityMap(
4088  llvm::cast<ShapedType>(source.getType()), vectorType);
4089  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4090  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4091  ? builder.getBoolArrayAttr(inBounds.value())
4092  : builder.getBoolArrayAttr(
4093  SmallVector<bool>(vectorType.getRank(), false));
4094  build(builder, result, vectorType, source, indices, permutationMapAttr,
4095  padding,
4096  /*mask=*/Value(), inBoundsAttr);
4097 }
4098 
4099 /// 4. Builder that sets padding to zero and permutation map to
4100 /// 'getMinorIdentityMap'.
4101 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4102  VectorType vectorType, Value source,
4103  ValueRange indices,
4104  std::optional<ArrayRef<bool>> inBounds) {
4105  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4106  Value padding = builder.create<arith::ConstantOp>(
4107  result.location, elemType, builder.getZeroAttr(elemType));
4108  build(builder, result, vectorType, source, indices, padding, inBounds);
4109 }
4110 
4111 template <typename EmitFun>
4112 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4113  EmitFun emitOpError) {
4114  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4115  for (auto expr : permutationMap.getResults()) {
4116  auto dim = dyn_cast<AffineDimExpr>(expr);
4117  auto zero = dyn_cast<AffineConstantExpr>(expr);
4118  if (zero) {
4119  if (zero.getValue() != 0) {
4120  return emitOpError(
4121  "requires a projected permutation_map (at most one dim or the zero "
4122  "constant can appear in each result)");
4123  }
4124  continue;
4125  }
4126  if (!dim) {
4127  return emitOpError("requires a projected permutation_map (at most one "
4128  "dim or the zero constant can appear in each result)");
4129  }
4130  if (seen[dim.getPosition()]) {
4131  return emitOpError(
4132  "requires a permutation_map that is a permutation (found one dim "
4133  "used more than once)");
4134  }
4135  seen[dim.getPosition()] = true;
4136  }
4137  return success();
4138 }
4139 
4140 static LogicalResult
4141 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4142  VectorType vectorType, VectorType maskType,
4143  VectorType inferredMaskType, AffineMap permutationMap,
4144  ArrayAttr inBounds) {
4145  if (op->hasAttr("masked")) {
4146  return op->emitOpError("masked attribute has been removed. "
4147  "Use in_bounds instead.");
4148  }
4149 
4150  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4151  return op->emitOpError(
4152  "requires source to be a memref or ranked tensor type");
4153 
4154  auto elementType = shapedType.getElementType();
4155  DataLayout dataLayout = DataLayout::closest(op);
4156  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4157  // Memref or tensor has vector element type.
4158  unsigned sourceVecSize =
4159  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4160  vectorElementType.getShape().back();
4161  unsigned resultVecSize =
4162  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4163  vectorType.getShape().back();
4164  if (resultVecSize % sourceVecSize != 0)
4165  return op->emitOpError(
4166  "requires the bitwidth of the minor 1-D vector to be an integral "
4167  "multiple of the bitwidth of the minor 1-D vector of the source");
4168 
4169  unsigned sourceVecEltRank = vectorElementType.getRank();
4170  unsigned resultVecRank = vectorType.getRank();
4171  if (sourceVecEltRank > resultVecRank)
4172  return op->emitOpError(
4173  "requires source vector element and vector result ranks to match.");
4174  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4175  // Check that permutation map results match 'rankOffset' of vector type.
4176  if (permutationMap.getNumResults() != rankOffset)
4177  return op->emitOpError("requires a permutation_map with result dims of "
4178  "the same rank as the vector type");
4179 
4180  if (maskType)
4181  return op->emitOpError("does not support masks with vector element type");
4182  } else {
4183  // Memref or tensor has scalar element type.
4184  unsigned minorSize =
4185  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4186  unsigned resultVecSize =
4187  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4188  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4189  return op->emitOpError(
4190  "requires the bitwidth of the minor 1-D vector to be an integral "
4191  "multiple of the bitwidth of the source element type");
4192 
4193  // Check that permutation map results match rank of vector type.
4194  if (permutationMap.getNumResults() != vectorType.getRank())
4195  return op->emitOpError("requires a permutation_map with result dims of "
4196  "the same rank as the vector type");
4197  }
4198 
4199  if (permutationMap.getNumSymbols() != 0)
4200  return op->emitOpError("requires permutation_map without symbols");
4201 
4202  if (permutationMap.getNumInputs() != shapedType.getRank())
4203  return op->emitOpError("requires a permutation_map with input dims of the "
4204  "same rank as the source type");
4205 
4206  if (maskType && maskType != inferredMaskType)
4207  return op->emitOpError("inferred mask type (")
4208  << inferredMaskType << ") and mask operand type (" << maskType
4209  << ") don't match";
4210 
4211  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4212  return op->emitOpError("expects the in_bounds attr of same rank "
4213  "as permutation_map results: ")
4214  << AffineMapAttr::get(permutationMap)
4215  << " vs inBounds of size: " << inBounds.size();
4216 
4217  return success();
4218 }
4219 
4220 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4221  SmallVector<StringRef, 3> elidedAttrs;
4222  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4223  if (op.getPermutationMap().isMinorIdentity())
4224  elidedAttrs.push_back(op.getPermutationMapAttrName());
4225  // Elide in_bounds attribute if all dims are out-of-bounds.
4226  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4227  elidedAttrs.push_back(op.getInBoundsAttrName());
4228  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4229 }
4230 
4231 void TransferReadOp::print(OpAsmPrinter &p) {
4232  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4233  if (getMask())
4234  p << ", " << getMask();
4235  printTransferAttrs(p, *this);
4236  p << " : " << getShapedType() << ", " << getVectorType();
4237 }
4238 
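/// For example (illustrative): with permutation map (d0, d1, d2) -> (d2, d1)
/// and vector type vector<4x8xf32>, the unused dim d0 is compressed, the
/// remaining permutation is inverted, and the inferred mask type is
/// vector<8x4xi1>.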
4239 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4240  AffineMap permMap) {
4241  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4242  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4243  assert(invPermMap && "Inverted permutation map couldn't be computed");
4244  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4245 
4246  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4247  // 0-D mask into a single-element 1-D mask.
4248  if (maskShape.empty())
4249  maskShape.push_back(1);
4250 
4251  SmallVector<bool> scalableDims =
4252  applyPermutationMap(invPermMap, vecType.getScalableDims());
4253 
4254  return VectorType::get(maskShape, i1Type, scalableDims);
4255 }
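// Worked example (illustrative, not part of the original source): for a
// transfer of vector<4x8xf32> with permutation_map
//   affine_map<(d0, d1, d2) -> (d2, d1)>,
// the unused dim d0 is compressed away, the remaining permutation is
// inverted, and composing the inverse with the vector shape [4, 8] yields the
// mask shape [8, 4], i.e. the inferred mask type is vector<8x4xi1> (mask dims
// follow the source iteration order rather than the vector order).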
4256 
4257 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4258  auto &builder = parser.getBuilder();
4259  SMLoc typesLoc;
4260  OpAsmParser::UnresolvedOperand sourceInfo;
4261  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4262  OpAsmParser::UnresolvedOperand paddingInfo;
4263  SmallVector<Type, 2> types;
4264  OpAsmParser::UnresolvedOperand maskInfo;
4265  // Parsing with support for paddingValue.
4266  if (parser.parseOperand(sourceInfo) ||
4267  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4268  parser.parseComma() || parser.parseOperand(paddingInfo))
4269  return failure();
4270  ParseResult hasMask = parser.parseOptionalComma();
4271  if (hasMask.succeeded()) {
4272  if (parser.parseOperand(maskInfo))
4273  return failure();
4274  }
4275  if (parser.parseOptionalAttrDict(result.attributes) ||
4276  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4277  return failure();
4278  if (types.size() != 2)
4279  return parser.emitError(typesLoc, "requires two types");
4280  auto indexType = builder.getIndexType();
4281  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4282  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4283  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4284  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4285  if (!vectorType)
4286  return parser.emitError(typesLoc, "requires vector type");
4287  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4288  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4289  AffineMap permMap;
4290  if (!permMapAttr) {
4291  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4292  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4293  } else {
4294  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4295  }
4296  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4297  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4298  if (!inBoundsAttr) {
4299  result.addAttribute(inBoundsAttrName,
4300  builder.getBoolArrayAttr(
4301  SmallVector<bool>(permMap.getNumResults(), false)));
4302  }
4303  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4304  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4305  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4306  result.operands))
4307  return failure();
4308  if (hasMask.succeeded()) {
4309  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4310  return parser.emitError(
4311  maskInfo.location, "does not support masks with vector element type");
4312  if (vectorType.getRank() != permMap.getNumResults()) {
4313  return parser.emitError(typesLoc,
4314  "expected the same rank for the vector and the "
4315  "results of the permutation map");
4316  }
4317  // Instead of adding the mask type as an op type, compute it based on the
4318  // vector type and the permutation map (to keep the type signature small).
4319  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4320  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4321  return failure();
4322  }
4323  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4324  builder.getDenseI32ArrayAttr(
4325  {1, static_cast<int32_t>(indexInfo.size()), 1,
4326  static_cast<int32_t>(hasMask.succeeded())}));
4327  return parser.addTypeToList(vectorType, result.types);
4328 }
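// Example of the assembly format accepted by this parser (illustrative only;
// names and shapes are made up):
//
//   %v = vector.transfer_read %src[%i, %j], %pad, %mask
//       {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d1)>}
//       : memref<?x?xf32>, vector<8xf32>
//
// The two types after the colon are the source (memref or ranked tensor) and
// the result vector; the mask type is not spelled out but inferred from the
// vector type and the permutation map.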
4329 
4330 LogicalResult TransferReadOp::verify() {
4331  // Consistency of elemental types in source and vector.
4332  ShapedType shapedType = getShapedType();
4333  VectorType vectorType = getVectorType();
4334  VectorType maskType = getMaskType();
4335  auto paddingType = getPadding().getType();
4336  auto permutationMap = getPermutationMap();
4337  VectorType inferredMaskType =
4338  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4339  : VectorType();
4340  auto sourceElementType = shapedType.getElementType();
4341 
4342  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4343  return emitOpError("requires ") << shapedType.getRank() << " indices";
4344 
4345  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4346  shapedType, vectorType, maskType,
4347  inferredMaskType, permutationMap, getInBounds())))
4348  return failure();
4349 
4350  if (auto sourceVectorElementType =
4351  llvm::dyn_cast<VectorType>(sourceElementType)) {
4352  // Source has vector element type.
4353  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4354  if (sourceVectorElementType != paddingType)
4355  return emitOpError(
4356  "requires source element type and padding type to match.");
4357 
4358  } else {
4359  // Check that 'paddingType' is valid to store in a vector type.
4360  if (!VectorType::isValidElementType(paddingType))
4361  return emitOpError("requires valid padding vector elemental type");
4362 
4363  // Check that padding type and vector element types match.
4364  if (paddingType != sourceElementType)
4365  return emitOpError(
4366  "requires formal padding and source of the same elemental type");
4367  }
4368 
4369  return verifyPermutationMap(permutationMap,
4370  [&](Twine t) { return emitOpError(t); });
4371 }
4372 
4373 // MaskableOpInterface methods.
4374 
4375 /// Returns the mask type expected by this operation. Mostly used for
4376 /// verification purposes. It requires the operation to be vectorized.
4377 Type TransferReadOp::getExpectedMaskType() {
4378  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4379 }
4380 
4381 template <typename TransferOp>
4382 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4383  // TODO: support more aggressive createOrFold on:
4384  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4385  if (op.getShapedType().isDynamicDim(indicesIdx))
4386  return false;
4387  Value index = op.getIndices()[indicesIdx];
4388  std::optional<int64_t> cstOp = getConstantIntValue(index);
4389  if (!cstOp.has_value())
4390  return false;
4391 
4392  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4393  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4394 
4395  return cstOp.value() + vectorSize <= sourceSize;
4396 }
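// Illustrative check (not from the original source): for a transfer that reads
// vector<4xf32> starting at constant index %c2 of a memref<8xf32>, this helper
// evaluates 2 + 4 <= 8 and reports the dimension as in-bounds; a dynamic
// dimension or a non-constant index conservatively returns false.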
4397 
4398 template <typename TransferOp>
4399 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4400  // TODO: support 0-d corner case.
4401  // TODO: Be less conservative.
4402  if (op.getTransferRank() == 0)
4403  return failure();
4404  AffineMap permutationMap = op.getPermutationMap();
4405  bool changed = false;
4406  SmallVector<bool, 4> newInBounds;
4407  newInBounds.reserve(op.getTransferRank());
4408  // Indices of non-broadcast dims - used when analysing broadcast dims.
4409  SmallVector<unsigned> nonBcastDims;
4410 
4411  // 1. Process non-broadcast dims
4412  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4413  // 1.1. Already marked as in-bounds, nothing to see here.
4414  if (op.isDimInBounds(i)) {
4415  newInBounds.push_back(true);
4416  continue;
4417  }
4418  // 1.2. Currently out-of-bounds, check whether we can statically determine
4419  // it is inBounds.
4420  bool inBounds = false;
4421  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4422  if (dimExpr) {
4423  inBounds = isInBounds(op, /*resultIdx=*/i,
4424  /*indicesIdx=*/dimExpr.getPosition());
4425  nonBcastDims.push_back(i);
4426  }
4427 
4428  newInBounds.push_back(inBounds);
4429  // We commit the pattern if it is "more inbounds".
4430  changed |= inBounds;
4431  }
4432 
4433  // 2. Handle broadcast dims
4434  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4435  // "in bounds" as well.
4436  bool allNonBcastDimsInBounds = llvm::all_of(
4437  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4438  if (allNonBcastDimsInBounds) {
4439  for (size_t idx : permutationMap.getBroadcastDims()) {
4440  changed |= !newInBounds[idx];
4441  newInBounds[idx] = true;
4442  }
4443  }
4444 
4445  if (!changed)
4446  return failure();
4447  // OpBuilder is only used as a helper to build a BoolArrayAttr.
4448  OpBuilder b(op.getContext());
4449  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4450  return success();
4451 }
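// Sketch of the effect on IR (assumed example, not from the source):
//
//   %r = vector.transfer_read %m[%c0, %c0], %pad
//       : memref<8x8xf32>, vector<4x4xf32>
//
// Both accessed dims satisfy 0 + 4 <= 8, so the folder rewrites the op's
// attribute to in_bounds = [true, true]; broadcast dims are only marked
// in-bounds once every non-broadcast dim is known to be in-bounds.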
4452 
4453 template <typename TransferOp>
4454 static LogicalResult foldTransferFullMask(TransferOp op) {
4455  auto mask = op.getMask();
4456  if (!mask)
4457  return failure();
4458 
4459  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4460  return failure();
4461 
4462  op.getMaskMutable().clear();
4463  return success();
4464 }
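// Illustrative case (not from the original source): a mask produced by
//   %m = vector.constant_mask [4] : vector<4xi1>
// is classified as MaskFormat::AllTrue, so the mask operand can simply be
// dropped from the transfer op; an all-false or mixed mask is left untouched.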
4465 
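/// Fold a read-after-write on a tensor: a transfer_read of a value that was
/// just fully written by a matching transfer_write forwards the written
/// vector, e.g.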
4466 /// ```
4467 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4468 /// : vector<1x4xf32>, tensor<4x4xf32>
4469 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4470 /// : tensor<4x4xf32>, vector<1x4xf32>
4471 /// ```
4472 /// -> Folds into
4473 /// ```
4474 /// %v0
4475 /// ```
4476 static Value foldRAW(TransferReadOp readOp) {
4477  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4478  return {};
4479  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4480  while (defWrite) {
4481  if (checkSameValueRAW(defWrite, readOp))
4482  return defWrite.getVector();
4483  if (!isDisjointTransferIndices(
4484  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4485  cast<VectorTransferOpInterface>(readOp.getOperation())))
4486  break;
4487  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4488  }
4489  return {};
4490 }
4491 
4492 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4493  if (Value vec = foldRAW(*this))
4494  return vec;
4495  // transfer_read(memref.cast) -> transfer_read (handled by memref::foldMemRefCast below).
4496  if (succeeded(foldTransferInBoundsAttribute(*this)))
4497  return getResult();
4498  if (succeeded(foldTransferFullMask(*this)))
4499  return getResult();
4500  if (succeeded(memref::foldMemRefCast(*this)))
4501  return getResult();
4502  if (succeeded(tensor::foldTensorCast(*this)))
4503  return getResult();
4504  return OpFoldResult();
4505 }
4506 
4507 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4508  return llvm::to_vector<4>(getVectorType().getShape());
4509 }
4510 
4511 void TransferReadOp::getEffects(
4512  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4513  &effects) {
4514  if (llvm::isa<MemRefType>(getShapedType()))
4515  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4516  SideEffects::DefaultResource::get());
4517 }
4518 
4519 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4520  if (hasPureTensorSemantics())
4521  return Speculation::Speculatable;
4522  return Speculation::NotSpeculatable;
4523 }
4524 
4525 namespace {
4526 /// Store-to-load forwarding for transfer operations with permutation maps.
4527 /// Even if the permutation maps are different we can still propagate the store
4528 /// into the load if the size of the dimensions read and written match. Then we
4529 /// can replace the transfer_read + transfer_write by vector.broadcast and
4530 /// vector.transpose.
4531 /// Example:
4532 /// ```
4533 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4534 /// {in_bounds = [true, true],
4535 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4536 /// vector<4x1xf32>, tensor<4x4x4xf32>
4537 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4538 /// {in_bounds = [true, true, true, true],
4539 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4540 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4541 /// ```
4542 /// To:
4543 /// ```
4544 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
4545 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4546 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4547 /// ```
4548 struct TransferReadAfterWriteToBroadcast
4549  : public OpRewritePattern<TransferReadOp> {
4551 
4552  LogicalResult matchAndRewrite(TransferReadOp readOp,
4553  PatternRewriter &rewriter) const override {
4554  if (readOp.hasOutOfBoundsDim() ||
4555  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4556  return failure();
4557  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4558  if (!defWrite)
4559  return failure();
4560  // TODO: If the written transfer chunk is a superset of the read transfer
4561  // chunk we could do an extract_strided_slice.
4562  if (readOp.getTransferChunkAccessed() !=
4563  defWrite.getTransferChunkAccessed())
4564  return failure();
4565  // TODO: Support cases where a dim is explicitly written but implicitly
4566  // read (i.e., a unit dim that is rank reduced).
4567  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4568  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4569  return failure();
4570  if (readOp.getIndices() != defWrite.getIndices() ||
4571  readOp.getMask() != defWrite.getMask())
4572  return failure();
4573  Value vec = defWrite.getVector();
4574  // TODO: loop through the chain of transfer_write if we can prove that they
4575  // don't overlap with the transfer_read. This requires improving
4576  // `isDisjointTransferIndices` helper.
4577  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4578  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4579  AffineMap map = readMap.compose(writeMap);
4580  if (map.getNumResults() == 0)
4581  return failure();
4582  // Calculate the permutation to apply to go from the vector stored to the
4583  // vector read.
4584  SmallVector<unsigned> permutation;
4585  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4586  return failure();
4587 
4588  Location loc = readOp.getLoc();
4589  // Calculate the broadcast shape by applying the reverse permutation to the
4590  // final shape we want.
4591  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4592  SmallVector<int64_t> broadcastShape(destShape.size());
4593  SmallVector<bool> broadcastScalableFlags(destShape.size());
4594  for (const auto &pos : llvm::enumerate(permutation)) {
4595  broadcastShape[pos.value()] = destShape[pos.index()];
4596  broadcastScalableFlags[pos.value()] =
4597  readOp.getVectorType().getScalableDims()[pos.index()];
4598  }
4599  VectorType broadcastedType = VectorType::get(
4600  broadcastShape, defWrite.getVectorType().getElementType(),
4601  broadcastScalableFlags);
4602  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4603  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4604  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4605  transposePerm);
4606  return success();
4607  }
4608 };
4609 } // namespace
4610 
4611 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4612  MLIRContext *context) {
4613  results.add<TransferReadAfterWriteToBroadcast>(context);
4614 }
4615 
4616 //===----------------------------------------------------------------------===//
4617 // TransferWriteOp
4618 //===----------------------------------------------------------------------===//
4619 
4620 /// 1. Builder with type inference.
4621 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4622  Value vector, Value dest, ValueRange indices,
4623  AffineMapAttr permutationMapAttr,
4624  /*optional*/ Value mask,
4625  /*optional*/ ArrayAttr inBoundsAttr) {
4626  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4627  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4628  mask, inBoundsAttr);
4629 }
4630 
4631 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4632 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4633  Value vector, Value dest, ValueRange indices,
4634  AffineMapAttr permutationMapAttr,
4635  /*optional*/ ArrayAttr inBoundsAttr) {
4636  build(builder, result, vector, dest, indices, permutationMapAttr,
4637  /*mask=*/Value(), inBoundsAttr);
4638 }
4639 
4640 /// 3. Builder with type inference that sets an empty mask (variant without
4641 /// attrs)
4642 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4643  Value vector, Value dest, ValueRange indices,
4644  AffineMap permutationMap,
4645  std::optional<ArrayRef<bool>> inBounds) {
4646  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4647  auto inBoundsAttr =
4648  (inBounds && !inBounds.value().empty())
4649  ? builder.getBoolArrayAttr(inBounds.value())
4650  : builder.getBoolArrayAttr(SmallVector<bool>(
4651  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4652  build(builder, result, vector, dest, indices, permutationMapAttr,
4653  /*mask=*/Value(), inBoundsAttr);
4654 }
4655 
4656 /// 4. Builder with type inference that sets an empty mask and sets permutation
4657 /// map to 'getMinorIdentityMap'.
4658 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4659  Value vector, Value dest, ValueRange indices,
4660  std::optional<ArrayRef<bool>> inBounds) {
4661  auto vectorType = llvm::cast<VectorType>(vector.getType());
4662  AffineMap permutationMap = getTransferMinorIdentityMap(
4663  llvm::cast<ShapedType>(dest.getType()), vectorType);
4664  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4665 }
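// Minimal usage sketch for the inference builders above (hypothetical call
// site; `loc`, `vec`, `destTensor` and `indices` are made-up values). Builder
// 4 derives a minor identity permutation map and an all-false in_bounds
// attribute from the operand types:
//
//   auto write = builder.create<vector::TransferWriteOp>(
//       loc, vec, destTensor, indices, /*inBounds=*/std::nullopt);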
4666 
4667 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4668  OperationState &result) {
4669  auto &builder = parser.getBuilder();
4670  SMLoc typesLoc;
4671  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4672  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4673  SmallVector<Type, 2> types;
4674  OpAsmParser::UnresolvedOperand maskInfo;
4675  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4676  parser.parseOperand(sourceInfo) ||
4677  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4678  return failure();
4679  ParseResult hasMask = parser.parseOptionalComma();
4680  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4681  return failure();
4682  if (parser.parseOptionalAttrDict(result.attributes) ||
4683  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4684  return failure();
4685  if (types.size() != 2)
4686  return parser.emitError(typesLoc, "requires two types");
4687  auto indexType = builder.getIndexType();
4688  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4689  if (!vectorType)
4690  return parser.emitError(typesLoc, "requires vector type");
4691  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4692  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4693  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4694  auto permMapAttrName =
4695  TransferWriteOp::getPermutationMapAttrName(result.name);
4696  auto permMapAttr = result.attributes.get(permMapAttrName);
4697  AffineMap permMap;
4698  if (!permMapAttr) {
4699  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4700  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4701  } else {
4702  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4703  }
4704  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4705  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4706  if (!inBoundsAttr) {
4707  result.addAttribute(inBoundsAttrName,
4708  builder.getBoolArrayAttr(
4709  SmallVector<bool>(permMap.getNumResults(), false)));
4710  }
4711  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4712  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4713  parser.resolveOperands(indexInfo, indexType, result.operands))
4714  return failure();
4715  if (hasMask.succeeded()) {
4716  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4717  return parser.emitError(
4718  maskInfo.location, "does not support masks with vector element type");
4719  if (vectorType.getRank() != permMap.getNumResults()) {
4720  return parser.emitError(typesLoc,
4721  "expected the same rank for the vector and the "
4722  "results of the permutation map");
4723  }
4724  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4725  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4726  return failure();
4727  }
4728  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4729  builder.getDenseI32ArrayAttr(
4730  {1, 1, static_cast<int32_t>(indexInfo.size()),
4731  static_cast<int32_t>(hasMask.succeeded())}));
4732  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4733  parser.addTypeToList(shapedType, result.types));
4734 }
4735 
4736 void TransferWriteOp::print(OpAsmPrinter &p) {
4737  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4738  if (getMask())
4739  p << ", " << getMask();
4740  printTransferAttrs(p, *this);
4741  p << " : " << getVectorType() << ", " << getShapedType();
4742 }
4743 
4744 LogicalResult TransferWriteOp::verify() {
4745  // Consistency of elemental types in shape and vector.
4746  ShapedType shapedType = getShapedType();
4747  VectorType vectorType = getVectorType();
4748  VectorType maskType = getMaskType();
4749  auto permutationMap = getPermutationMap();
4750  VectorType inferredMaskType =
4751  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4752  : VectorType();
4753 
4754  if (llvm::size(getIndices()) != shapedType.getRank())
4755  return emitOpError("requires ") << shapedType.getRank() << " indices";
4756 
4757  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4758  // as the semantics is unclear. This can be revisited later if necessary.
4759  if (hasBroadcastDim())
4760  return emitOpError("should not have broadcast dimensions");
4761 
4762  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4763  shapedType, vectorType, maskType,
4764  inferredMaskType, permutationMap, getInBounds())))
4765  return failure();
4766 
4767  return verifyPermutationMap(permutationMap,
4768  [&](Twine t) { return emitOpError(t); });
4769 }
4770 
4771 // MaskableOpInterface methods.
4772 
4773 /// Returns the mask type expected by this operation. Mostly used for
4774 /// verification purposes.
4775 Type TransferWriteOp::getExpectedMaskType() {
4776  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4777 }
4778 
4779 /// Fold:
4780 /// ```
4781 /// %t1 = ...
4782 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4783 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4784 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4785 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4786 /// ```
4787 ///
4788 /// into:
4789 ///
4790 /// ```
4791 /// %t0
4792 /// ```
4793 ///
4794 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4795 /// block argument or has side effects.
4796 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4797  ArrayRef<Attribute>,
4798  SmallVectorImpl<OpFoldResult> &results) {
4799  // TODO: support 0-d corner case.
4800  if (write.getTransferRank() == 0)
4801  return failure();
4802  auto rankedTensorType =
4803  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4804  // If not operating on tensors, bail.
4805  if (!rankedTensorType)
4806  return failure();
4807  // If no read, bail.
4808  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4809  if (!read)
4810  return failure();
4811  // TODO: support 0-d corner case.
4812  if (read.getTransferRank() == 0)
4813  return failure();
4814  // For now, only accept minor identity maps. Future: accept any composition that is a minor identity.
4815  if (!read.getPermutationMap().isMinorIdentity() ||
4816  !write.getPermutationMap().isMinorIdentity())
4817  return failure();
4818  // Bail on mismatching ranks.
4819  if (read.getTransferRank() != write.getTransferRank())
4820  return failure();
4821  // Bail on potential out-of-bounds accesses.
4822  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4823  return failure();
4824  // Tensor types must be the same.
4825  if (read.getSource().getType() != rankedTensorType)
4826  return failure();
4827  // Vector types must be the same.
4828  if (read.getVectorType() != write.getVectorType())
4829  return failure();
4830  // Vector and Tensor shapes must match.
4831  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4832  return failure();
4833  // Bail if any index is nonzero.
4834  auto isNotConstantZero = [](Value v) {
4835  auto cstOp = getConstantIntValue(v);
4836  return !cstOp.has_value() || cstOp.value() != 0;
4837  };
4838  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4839  llvm::any_of(write.getIndices(), isNotConstantZero))
4840  return failure();
4841  // Success.
4842  results.push_back(read.getSource());
4843  return success();
4844 }
4845 
4846 static bool checkSameValueWAR(vector::TransferReadOp read,
4847  vector::TransferWriteOp write) {
4848  return read.getSource() == write.getSource() &&
4849  read.getIndices() == write.getIndices() &&
4850  read.getPermutationMap() == write.getPermutationMap() &&
4851  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4852  !write.getMask();
4853 }
4854 /// Fold transfer_write write after read:
4855 /// ```
4856 /// %t0 = ...
4857 /// %v = vector.transfer_read %t0[%c0...] :
4858 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4859 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4860 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4861 /// ```
4862 ///
4863 /// into:
4864 ///
4865 /// ```
4866 /// %t0
4867 /// ```
4868 static LogicalResult foldWAR(TransferWriteOp write,
4869  SmallVectorImpl<OpFoldResult> &results) {
4870  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4871  return failure();
4872  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4873  if (!read)
4874  return failure();
4875 
4876  if (!checkSameValueWAR(read, write))
4877  return failure();
4878  results.push_back(read.getSource());
4879  return success();
4880 }
4881 
4882 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4883  SmallVectorImpl<OpFoldResult> &results) {
4884  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4885  return success();
4886  if (succeeded(foldWAR(*this, results)))
4887  return success();
4888  if (succeeded(foldTransferInBoundsAttribute(*this)))
4889  return success();
4890  if (succeeded(foldTransferFullMask(*this)))
4891  return success();
4892  return memref::foldMemRefCast(*this);
4893 }
4894 
4895 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4896  return llvm::to_vector<4>(getVectorType().getShape());
4897 }
4898 
4899 void TransferWriteOp::getEffects(
4900  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4901  &effects) {
4902  if (llvm::isa<MemRefType>(getShapedType()))
4903  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4904  SideEffects::DefaultResource::get());
4905 }
4906 
4907 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4908  if (hasPureTensorSemantics())
4909  return Speculation::Speculatable;
4910  return Speculation::NotSpeculatable;
4911 }
4912 
4913 namespace {
4914 /// Remove a dead transfer write from the SSA chain so that it can be
4915 /// eliminated by DCE.
4916 /// ```
4917 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4918 /// : vector<1x4xf32>, tensor<4x4xf32>
4919 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4920 /// : vector<1x4xf32>, tensor<4x4xf32>
4921 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4922 /// : vector<1x4xf32>, tensor<4x4xf32>
4923 /// ```
4924 ///
4925 /// into:
4926 ///
4927 /// ```
4928 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4929 /// : vector<1x4xf32>, tensor<4x4xf32>
4930 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4931 /// : vector<1x4xf32>, tensor<4x4xf32>
4932 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4933 /// : vector<1x4xf32>, tensor<4x4xf32>
4934 /// ```
4935 ///
4936 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4937 /// any other uses.
4938 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4939 public:
4941  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4942  PatternRewriter &rewriter) const override {
4943  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4944  return failure();
4945  vector::TransferWriteOp writeToModify = writeOp;
4946 
4947  auto defWrite =
4948  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4949  while (defWrite) {
4950  if (checkSameValueWAW(writeOp, defWrite)) {
4951  rewriter.modifyOpInPlace(writeToModify, [&]() {
4952  writeToModify.getSourceMutable().assign(defWrite.getSource());
4953  });
4954  return success();
4955  }
4956  if (!isDisjointTransferIndices(
4957  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4958  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4959  break;
4960  // If the previous write op doesn't have any other use we can safely look
4961  // at the previous store to see if it can be removed.
4962  if (!defWrite->hasOneUse())
4963  break;
4964  writeToModify = defWrite;
4965  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4966  }
4967  return failure();
4968  }
4969 };
4970 
4971 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4972 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4973 /// overwritten and inserted into another tensor. After this rewrite, the
4974 /// operations bufferize in-place since all of them work on the same slice.
4975 ///
4976 /// For example:
4977 /// ```mlir
4978 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4979 /// : vector<8x16xf32>, tensor<8x16xf32>
4980 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4981 /// : tensor<8x16xf32> to tensor<?x?xf32>
4982 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4983 /// : tensor<?x?xf32> into tensor<27x37xf32>
4984 /// ```
4985 /// folds to
4986 /// ```mlir
4987 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4988 /// : tensor<27x37xf32> to tensor<?x?xf32>
4989 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4990 /// : vector<8x16xf32>, tensor<?x?xf32>
4991 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4992 /// : tensor<?x?xf32> into tensor<27x37xf32>
4993 /// ```
4994 struct SwapExtractSliceOfTransferWrite
4995  : public OpRewritePattern<tensor::InsertSliceOp> {
4996 public:
4998 
4999  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5000  PatternRewriter &rewriter) const override {
5001  if (!insertOp.hasUnitStride())
5002  return failure();
5003  auto extractOp =
5004  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5005  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5006  return failure();
5007  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5008  if (!transferOp || !transferOp->hasOneUse())
5009  return failure();
5010 
5011  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5012  // rank-reducing.
5013  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5014  return rewriter.notifyMatchFailure(insertOp,
5015  "use-def chain is rank-reducing");
5016  }
5017 
5018  // Fail if tensor::ExtractSliceOp has non-zero offset.
5019  if (!extractOp.hasZeroOffset()) {
5020  return rewriter.notifyMatchFailure(insertOp,
5021  "ExtractSliceOp has non-zero offset");
5022  }
5023 
5024  // Fail if vector::TransferWriteOp has non-zero offset.
5025  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5026  return getConstantIntValue(value) == static_cast<int64_t>(0);
5027  })) {
5028  return rewriter.notifyMatchFailure(insertOp,
5029  "TranferWriteOp has non-zero offset");
5030  }
5031 
5032  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5033  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5034  return rewriter.notifyMatchFailure(
5035  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5036  }
5037 
5038  for (auto [insertSize, extractSize] :
5039  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5040  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5041  return rewriter.notifyMatchFailure(
5042  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5043  }
5044  }
5045 
5046  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5047  assert(transferOp.getVectorType().hasStaticShape() &&
5048  "expected vector to have a static shape");
5049  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5050  SmallVector<int64_t> resultShape = applyPermutationMap(
5051  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5052  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5053  return rewriter.notifyMatchFailure(
5054  insertOp, "TransferWriteOp may not write the full tensor.");
5055  }
5056 
5057  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5058  // Set all in_bounds to false and let the folder infer them.
5059  SmallVector<bool> newInBounds(vectorShape.size(), false);
5060  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5061  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5062  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5063  insertOp.getMixedStrides());
5064  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5065  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5066  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5067  rewriter.getBoolArrayAttr(newInBounds));
5068  rewriter.modifyOpInPlace(insertOp, [&]() {
5069  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5070  });
5071  return success();
5072  }
5073 };
5074 
5075 } // namespace
5076 
5077 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5078  MLIRContext *context) {
5079  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5080 }
5081 
5082 //===----------------------------------------------------------------------===//
5083 // LoadOp
5084 //===----------------------------------------------------------------------===//
5085 
5086 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5087  VectorType vecTy,
5088  MemRefType memRefTy) {
5089  // If rank==0 or size==1 it's equivalent to a scalar load/store, so we don't
5090  // need any stride restrictions.
5091  if (!vecTy.isScalable() &&
5092  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5093  return success();
5094 
5095  if (!memRefTy.isLastDimUnitStride())
5096  return op->emitOpError("most minor memref dim must have unit stride");
5097  return success();
5098 }
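// Illustrative consequence (not from the original source): loading
// vector<4xf32> from a memref whose innermost dimension is strided, e.g.
//   memref<8x8xf32, strided<[16, 2]>>,
// is rejected with "most minor memref dim must have unit stride", while a
// single-element or 0-D vector load from the same memref is still accepted.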
5099 
5100 LogicalResult vector::LoadOp::verify() {
5101  VectorType resVecTy = getVectorType();
5102  MemRefType memRefTy = getMemRefType();
5103 
5104  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5105  return failure();
5106 
5107  // Checks for vector memrefs.
5108  Type memElemTy = memRefTy.getElementType();
5109  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5110  if (memVecTy != resVecTy)
5111  return emitOpError("base memref and result vector types should match");
5112  memElemTy = memVecTy.getElementType();
5113  }
5114 
5115  if (resVecTy.getElementType() != memElemTy)
5116  return emitOpError("base and result element types should match");
5117  if (llvm::size(getIndices()) != memRefTy.getRank())
5118  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5119  return success();
5120 }
5121 
5122 OpFoldResult LoadOp::fold(FoldAdaptor) {
5123  if (succeeded(memref::foldMemRefCast(*this)))
5124  return getResult();
5125  return OpFoldResult();
5126 }
5127 
5128 //===----------------------------------------------------------------------===//
5129 // StoreOp
5130 //===----------------------------------------------------------------------===//
5131 
5132 LogicalResult vector::StoreOp::verify() {
5133  VectorType valueVecTy = getVectorType();
5134  MemRefType memRefTy = getMemRefType();
5135 
5136  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5137  return failure();
5138 
5139  // Checks for vector memrefs.
5140  Type memElemTy = memRefTy.getElementType();
5141  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5142  if (memVecTy != valueVecTy)
5143  return emitOpError(
5144  "base memref and valueToStore vector types should match");
5145  memElemTy = memVecTy.getElementType();
5146  }
5147 
5148  if (valueVecTy.getElementType() != memElemTy)
5149  return emitOpError("base and valueToStore element type should match");
5150  if (llvm::size(getIndices()) != memRefTy.getRank())
5151  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5152  return success();
5153 }
5154 
5155 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5156  SmallVectorImpl<OpFoldResult> &results) {
5157  return memref::foldMemRefCast(*this);
5158 }
5159 
5160 //===----------------------------------------------------------------------===//
5161 // MaskedLoadOp
5162 //===----------------------------------------------------------------------===//
5163 
5164 LogicalResult MaskedLoadOp::verify() {
5165  VectorType maskVType = getMaskVectorType();
5166  VectorType passVType = getPassThruVectorType();
5167  VectorType resVType = getVectorType();
5168  MemRefType memType = getMemRefType();
5169 
5170  if (resVType.getElementType() != memType.getElementType())
5171  return emitOpError("base and result element type should match");
5172  if (llvm::size(getIndices()) != memType.getRank())
5173  return emitOpError("requires ") << memType.getRank() << " indices";
5174  if (resVType.getShape() != maskVType.getShape())
5175  return emitOpError("expected result shape to match mask shape");
5176  if (resVType != passVType)
5177  return emitOpError("expected pass_thru of same type as result type");
5178  return success();
5179 }
5180 
5181 namespace {
5182 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5183 public:
5185  LogicalResult matchAndRewrite(MaskedLoadOp load,
5186  PatternRewriter &rewriter) const override {
5187  switch (getMaskFormat(load.getMask())) {
5188  case MaskFormat::AllTrue:
5189  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5190  load, load.getType(), load.getBase(), load.getIndices());
5191  return success();
5192  case MaskFormat::AllFalse:
5193  rewriter.replaceOp(load, load.getPassThru());
5194  return success();
5195  case MaskFormat::Unknown:
5196  return failure();
5197  }
5198  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5199  }
5200 };
5201 } // namespace
5202 
5203 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5204  MLIRContext *context) {
5205  results.add<MaskedLoadFolder>(context);
5206 }
5207 
5208 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5209  if (succeeded(memref::foldMemRefCast(*this)))
5210  return getResult();
5211  return OpFoldResult();
5212 }
5213 
5214 //===----------------------------------------------------------------------===//
5215 // MaskedStoreOp
5216 //===----------------------------------------------------------------------===//
5217 
5218 LogicalResult MaskedStoreOp::verify() {
5219  VectorType maskVType = getMaskVectorType();
5220  VectorType valueVType = getVectorType();
5221  MemRefType memType = getMemRefType();
5222 
5223  if (valueVType.getElementType() != memType.getElementType())
5224  return emitOpError("base and valueToStore element type should match");
5225  if (llvm::size(getIndices()) != memType.getRank())
5226  return emitOpError("requires ") << memType.getRank() << " indices";
5227  if (valueVType.getShape() != maskVType.getShape())
5228  return emitOpError("expected valueToStore shape to match mask shape");
5229  return success();
5230 }
5231 
5232 namespace {
5233 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5234 public:
5236  LogicalResult matchAndRewrite(MaskedStoreOp store,
5237  PatternRewriter &rewriter) const override {
5238  switch (getMaskFormat(store.getMask())) {
5239  case MaskFormat::AllTrue:
5240  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5241  store, store.getValueToStore(), store.getBase(), store.getIndices());
5242  return success();
5243  case MaskFormat::AllFalse:
5244  rewriter.eraseOp(store);
5245  return success();
5246  case MaskFormat::Unknown:
5247  return failure();
5248  }
5249  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5250  }
5251 };
5252 } // namespace
5253 
5254 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5255  MLIRContext *context) {
5256  results.add<MaskedStoreFolder>(context);
5257 }
5258 
5259 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5260  SmallVectorImpl<OpFoldResult> &results) {
5261  return memref::foldMemRefCast(*this);
5262 }
5263 
5264 //===----------------------------------------------------------------------===//
5265 // GatherOp
5266 //===----------------------------------------------------------------------===//
5267 
5268 LogicalResult GatherOp::verify() {
5269  VectorType indVType = getIndexVectorType();
5270  VectorType maskVType = getMaskVectorType();
5271  VectorType resVType = getVectorType();
5272  ShapedType baseType = getBaseType();
5273 
5274  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5275  return emitOpError("requires base to be a memref or ranked tensor type");
5276 
5277  if (resVType.getElementType() != baseType.getElementType())
5278  return emitOpError("base and result element type should match");
5279  if (llvm::size(getIndices()) != baseType.getRank())
5280  return emitOpError("requires ") << baseType.getRank() << " indices";
5281  if (resVType.getShape() != indVType.getShape())
5282  return emitOpError("expected result dim to match indices dim");
5283  if (resVType.getShape() != maskVType.getShape())
5284  return emitOpError("expected result dim to match mask dim");
5285  if (resVType != getPassThruVectorType())
5286  return emitOpError("expected pass_thru of same type as result type");
5287  return success();
5288 }
5289 
5290 // MaskableOpInterface methods.
5291 
5292 /// Returns the mask type expected by this operation. Mostly used for
5293 /// verification purposes. It requires the operation to be vectorized.
5294 Type GatherOp::getExpectedMaskType() {
5295  auto vecType = this->getIndexVectorType();
5296  return VectorType::get(vecType.getShape(),
5297  IntegerType::get(vecType.getContext(), /*width=*/1),
5298  vecType.getScalableDims());
5299 }
5300 
5301 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5302  return llvm::to_vector<4>(getVectorType().getShape());
5303 }
5304 
5305 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
5306 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5307  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5308  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5309  return failure();
5310 
5311  if (indexVec.getDefiningOp<StepOp>())
5312  return success();
5313 
5314  DenseIntElementsAttr elements;
5315  if (!matchPattern(indexVec, m_Constant(&elements)))
5316  return failure();
5317 
5318  return success(
5319  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5320 }
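// Illustrative inputs (not from the original source): an index vector produced
// by `vector.step : vector<4xindex>` or by
// `arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>` qualifies, which lets
// the patterns below turn a gather/scatter into a contiguous masked
// load/store; scalable or multi-dimensional index vectors never qualify.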
5321 
5322 namespace {
5323 class GatherFolder final : public OpRewritePattern<GatherOp> {
5324 public:
5326  LogicalResult matchAndRewrite(GatherOp gather,
5327  PatternRewriter &rewriter) const override {
5328  switch (getMaskFormat(gather.getMask())) {
5329  case MaskFormat::AllTrue:
5330  return failure(); // no unmasked equivalent
5331  case MaskFormat::AllFalse:
5332  rewriter.replaceOp(gather, gather.getPassThru());
5333  return success();
5334  case MaskFormat::Unknown:
5335  return failure();
5336  }
5337  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5338  }
5339 };
5340 
5341 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5342 /// maskedload. Only 1D fixed vectors are supported for now.
5343 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5344 public:
5346  LogicalResult matchAndRewrite(GatherOp op,
5347  PatternRewriter &rewriter) const override {
5348  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5349  return failure();
5350 
5351  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5352  op.getIndices(), op.getMask(),
5353  op.getPassThru());
5354  return success();
5355  }
5356 };
5357 } // namespace
5358 
5359 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5360  MLIRContext *context) {
5361  results.add<GatherFolder, FoldContiguousGather>(context);
5362 }
5363 
5364 //===----------------------------------------------------------------------===//
5365 // ScatterOp
5366 //===----------------------------------------------------------------------===//
5367 
5368 LogicalResult ScatterOp::verify() {
5369  VectorType indVType = getIndexVectorType();
5370  VectorType maskVType = getMaskVectorType();
5371  VectorType valueVType = getVectorType();
5372  MemRefType memType = getMemRefType();
5373 
5374  if (valueVType.getElementType() != memType.getElementType())
5375  return emitOpError("base and valueToStore element type should match");
5376  if (llvm::size(getIndices()) != memType.getRank())
5377  return emitOpError("requires ") << memType.getRank() << " indices";
5378  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5379  return emitOpError("expected valueToStore dim to match indices dim");
5380  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5381  return emitOpError("expected valueToStore dim to match mask dim");
5382  return success();
5383 }
5384 
5385 namespace {
5386 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5387 public:
5389  LogicalResult matchAndRewrite(ScatterOp scatter,
5390  PatternRewriter &rewriter) const override {
5391  switch (getMaskFormat(scatter.getMask())) {
5392  case MaskFormat::AllTrue:
5393  return failure(); // no unmasked equivalent
5394  case MaskFormat::AllFalse:
5395  rewriter.eraseOp(scatter);
5396  return success();
5397  case MaskFormat::Unknown:
5398  return failure();
5399  }
5400  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5401  }
5402 };
5403 
5404 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5405 /// maskedstore. Only 1D fixed vectors are supported for now.
5406 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5407 public:
5409  LogicalResult matchAndRewrite(ScatterOp op,
5410  PatternRewriter &rewriter) const override {
5411  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5412  return failure();
5413 
5414  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5415  op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5416  return success();
5417  }
5418 };
5419 } // namespace
5420 
5421 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5422  MLIRContext *context) {
5423  results.add<ScatterFolder, FoldContiguousScatter>(context);
5424 }
5425 
5426 //===----------------------------------------------------------------------===//
5427 // ExpandLoadOp
5428 //===----------------------------------------------------------------------===//
5429 
5430 LogicalResult ExpandLoadOp::verify() {
5431  VectorType maskVType = getMaskVectorType();
5432  VectorType passVType = getPassThruVectorType();
5433  VectorType resVType = getVectorType();
5434  MemRefType memType = getMemRefType();
5435 
5436  if (resVType.getElementType() != memType.getElementType())
5437  return emitOpError("base and result element type should match");
5438  if (llvm::size(getIndices()) != memType.getRank())
5439  return emitOpError("requires ") << memType.getRank() << " indices";
5440  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5441  return emitOpError("expected result dim to match mask dim");
5442  if (resVType != passVType)
5443  return emitOpError("expected pass_thru of same type as result type");
5444  return success();
5445 }
5446 
5447 namespace {
5448 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5449 public:
5451  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5452  PatternRewriter &rewriter) const override {
5453  switch (getMaskFormat(expand.getMask())) {
5454  case MaskFormat::AllTrue:
5455  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5456  expand, expand.getType(), expand.getBase(), expand.getIndices());
5457  return success();
5458  case MaskFormat::AllFalse:
5459  rewriter.replaceOp(expand, expand.getPassThru());
5460  return success();
5461  case MaskFormat::Unknown:
5462  return failure();
5463  }
5464  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5465  }
5466 };
5467 } // namespace
5468 
5469 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5470  MLIRContext *context) {
5471  results.add<ExpandLoadFolder>(context);
5472 }
5473 
5474 //===----------------------------------------------------------------------===//
5475 // CompressStoreOp
5476 //===----------------------------------------------------------------------===//
5477 
5478 LogicalResult CompressStoreOp::verify() {
5479  VectorType maskVType = getMaskVectorType();
5480  VectorType valueVType = getVectorType();
5481  MemRefType memType = getMemRefType();
5482 
5483  if (valueVType.getElementType() != memType.getElementType())
5484  return emitOpError("base and valueToStore element type should match");
5485  if (llvm::size(getIndices()) != memType.getRank())
5486  return emitOpError("requires ") << memType.getRank() << " indices";
5487  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5488  return emitOpError("expected valueToStore dim to match mask dim");
5489  return success();
5490 }
5491 
5492 namespace {
5493 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5494 public:
5496  LogicalResult matchAndRewrite(CompressStoreOp compress,
5497  PatternRewriter &rewriter) const override {
5498  switch (getMaskFormat(compress.getMask())) {
5499  case MaskFormat::AllTrue:
5500  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5501  compress, compress.getValueToStore(), compress.getBase(),
5502  compress.getIndices());
5503  return success();
5504  case MaskFormat::AllFalse:
5505  rewriter.eraseOp(compress);
5506  return success();
5507  case MaskFormat::Unknown:
5508  return failure();
5509  }
5510  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5511  }
5512 };
5513 } // namespace
5514 
5515 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5516  MLIRContext *context) {
5517  results.add<CompressStoreFolder>(context);
5518 }
5519 
5520 //===----------------------------------------------------------------------===//
5521 // ShapeCastOp
5522 //===----------------------------------------------------------------------===//
5523 
5524 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5525  SetIntRangeFn setResultRanges) {
5526  setResultRanges(getResult(), argRanges.front());
5527 }
5528 
5529 /// Returns true if each element of 'a' is equal to the product of a contiguous
5530 /// sequence of the elements of 'b'. Returns false otherwise.
5531 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5532  unsigned rankA = a.size();
5533  unsigned rankB = b.size();
5534  assert(rankA < rankB);
5535 
5536  auto isOne = [](int64_t v) { return v == 1; };
5537 
5538  // Special case for shape casts involving a 0-D vector: 'b' must be all
5539  // ones to be shape cast to or from a 0-D vector.
5540  if (rankA == 0 && llvm::all_of(b, isOne))
5541  return true;
5542 
5543  unsigned i = 0;
5544  unsigned j = 0;
5545  while (i < rankA && j < rankB) {
5546  int64_t dimA = a[i];
5547  int64_t dimB = 1;
5548  while (dimB < dimA && j < rankB)
5549  dimB *= b[j++];
5550  if (dimA != dimB)
5551  break;
5552  ++i;
5553 
5554  // Handle the case when trailing dimensions are of size 1.
5555  // Include them into the contiguous sequence.
5556  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5557  i = rankA;
5558  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5559  j = rankB;
5560  }
5561 
5562  return i == rankA && j == rankB;
5563 }
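// Worked examples (illustrative, not from the original source): with
// a = [4, 6] and b = [2, 2, 6], the first element 4 matches the product
// 2 * 2 and 6 matches 6, so the cast is valid; with a = [4, 6] and
// b = [3, 8] no contiguous prefix of b multiplies to 4, so it is rejected.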
5564 
5565 static LogicalResult verifyVectorShapeCast(Operation *op,
5566  VectorType sourceVectorType,
5567  VectorType resultVectorType) {
5568  // Check that element type is the same.
5569  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5570  return op->emitOpError("source/result vectors must have same element type");
5571  auto sourceShape = sourceVectorType.getShape();
5572  auto resultShape = resultVectorType.getShape();
5573 
5574  // Check that product of source dim sizes matches product of result dim sizes.
5575  int64_t sourceDimProduct = std::accumulate(
5576  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5577  int64_t resultDimProduct = std::accumulate(
5578  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5579  if (sourceDimProduct != resultDimProduct)
5580  return op->emitOpError("source/result number of elements must match");
5581 
5582  // Check the rank-expanding/contracting cases.
5583  unsigned sourceRank = sourceVectorType.getRank();
5584  unsigned resultRank = resultVectorType.getRank();
5585  if (sourceRank < resultRank) {
5586  if (!isValidShapeCast(sourceShape, resultShape))
5587  return op->emitOpError("invalid shape cast");
5588  } else if (sourceRank > resultRank) {
5589  if (!isValidShapeCast(resultShape, sourceShape))
5590  return op->emitOpError("invalid shape cast");
5591  }
5592 
5593  // Check that (non-)scalability is preserved
5594  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5595  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5596  if (sourceNScalableDims != resultNScalableDims)
5597  return op->emitOpError("different number of scalable dims at source (")
5598  << sourceNScalableDims << ") and result (" << resultNScalableDims
5599  << ")";
5600 
5601 
5602  return success();
5603 }
5604 
5605 LogicalResult ShapeCastOp::verify() {
5606  auto sourceVectorType =
5607  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5608  auto resultVectorType =
5609  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5610 
5611  // Check if source/result are of vector type.
5612  if (sourceVectorType && resultVectorType)
5613  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5614 
5615  return success();
5616 }
5617 
5618 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5619  // No-op shape cast.
5620  if (getSource().getType() == getResult().getType())
5621  return getSource();
5622 
5623  // Canceling shape casts.
5624  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5625  if (getResult().getType() == otherOp.getSource().getType())
5626  return otherOp.getSource();
5627 
5628  // Only allows valid transitive folding.
5629  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5630  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5631  if (srcType.getRank() < resultType.getRank()) {
5632  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5633  return {};
5634  } else if (srcType.getRank() > resultType.getRank()) {
5635  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5636  return {};
5637  } else {
5638  return {};
5639  }
5640 
5641  setOperand(otherOp.getSource());
5642  return getResult();
5643  }
5644 
5645  // Cancelling broadcast and shape cast ops.
5646  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5647  if (bcastOp.getSourceType() == getType())
5648  return bcastOp.getSource();
5649  }
5650 
5651  return {};
5652 }
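// Illustrative fold (assumed IR, not from the source): the broadcast/shape_cast
// cancellation above turns
//   %b = vector.broadcast %v : vector<4xf32> to vector<1x4xf32>
//   %s = vector.shape_cast %b : vector<1x4xf32> to vector<4xf32>
// into a direct use of %v, since the shape_cast result type equals the
// broadcast source type.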
5653 
5654 namespace {
5655 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5656 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5657 public:
5659 
5660  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5661  PatternRewriter &rewriter) const override {
5662  auto constantOp =
5663  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5664  if (!constantOp)
5665  return failure();
5666  // Only handle splat for now.
5667  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5668  if (!dense)
5669  return failure();
5670  auto newAttr =
5671  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5672  dense.getSplatValue<Attribute>());
5673  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5674  return success();
5675  }
5676 };
5677 
5678 /// Helper function that computes a new vector type based on the input vector
5679 /// type by removing the trailing one dims:
5680 ///
5681 /// vector<4x1x1xi1> --> vector<4xi1>
5682 ///
5683 static VectorType trimTrailingOneDims(VectorType oldType) {
5684  ArrayRef<int64_t> oldShape = oldType.getShape();
5685  ArrayRef<int64_t> newShape = oldShape;
5686 
5687  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5688  ArrayRef<bool> newScalableDims = oldScalableDims;
5689 
5690  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5691  newShape = newShape.drop_back(1);
5692  newScalableDims = newScalableDims.drop_back(1);
5693  }
5694 
5695  // Make sure we have at least 1 dimension.
5696  // TODO: Add support for 0-D vectors.
5697  if (newShape.empty()) {
5698  newShape = oldShape.take_back();
5699  newScalableDims = oldScalableDims.take_back();
5700  }
5701 
5702  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5703 }
5704 
5705 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5706 ///
5707 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5708 /// dimension. If the input vector comes from `vector.create_mask` for which
5709 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5710 /// to fold shape_cast into create_mask.
5711 ///
5712 /// BEFORE:
5713 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5714 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5715 /// AFTER:
5716 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5717 class ShapeCastCreateMaskFolderTrailingOneDim final
5718  : public OpRewritePattern<ShapeCastOp> {
5719 public:
5720  using OpRewritePattern::OpRewritePattern;
5721 
5722  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5723  PatternRewriter &rewriter) const override {
5724  Value shapeOpSrc = shapeOp->getOperand(0);
5725  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5726  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5727  if (!createMaskOp && !constantMaskOp)
5728  return failure();
5729 
5730  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5731  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5732 
5733  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5734  if (newVecType != shapeOpResTy)
5735  return failure();
5736 
5737  auto numDimsToDrop =
5738  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5739 
5740  // No unit dims to drop
5741  if (!numDimsToDrop)
5742  return failure();
5743 
5744  if (createMaskOp) {
5745  auto maskOperands = createMaskOp.getOperands();
5746  auto numMaskOperands = maskOperands.size();
5747 
5748  // Check every mask dim size to see whether it can be dropped
5749  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5750  --i) {
5751  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5752  if (!constant || (constant.value() != 1))
5753  return failure();
5754  }
5755  SmallVector<Value> newMaskOperands =
5756  maskOperands.drop_back(numDimsToDrop);
5757 
5758  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5759  newMaskOperands);
5760  return success();
5761  }
5762 
5763  if (constantMaskOp) {
5764  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5765  auto numMaskOperands = maskDimSizes.size();
5766 
5767  // Check every mask dim size to see whether it can be dropped
5768  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5769  --i) {
5770  if (maskDimSizes[i] != 1)
5771  return failure();
5772  }
5773 
5774  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5775  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5776  newMaskOperands);
5777  return success();
5778  }
5779 
5780  return failure();
5781  }
5782 };
5783 
5784 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5785 /// This only applies when the shape of the broadcast source
5786 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5787 /// reshape is expressive enough to capture the result in a single op), or
5788 /// 2. has the same element count as the shape cast result.
5789 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5790 public:
5791  using OpRewritePattern::OpRewritePattern;
5792 
5793  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5794  PatternRewriter &rewriter) const override {
5795  auto broadcastOp =
5796  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5797  if (!broadcastOp)
5798  return failure();
5799 
5800  ArrayRef<int64_t> broadcastSourceShape;
5801  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5802  broadcastSourceShape = srcType.getShape();
5803  ArrayRef<int64_t> shapeCastTargetShape =
5804  shapeCastOp.getResultVectorType().getShape();
5805 
5806  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5807  // with a broadcast to the final shape.
5808  if (broadcastSourceShape ==
5809  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5810  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5811  shapeCastOp, shapeCastOp.getResultVectorType(),
5812  broadcastOp.getSource());
5813  return success();
5814  }
5815 
5816  // Otherwise, if the final result has the same element count, we can replace
5817  // with a shape cast.
5818  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5819  if (srcType.getNumElements() ==
5820  shapeCastOp.getResultVectorType().getNumElements()) {
5821  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5822  shapeCastOp, shapeCastOp.getResultVectorType(),
5823  broadcastOp.getSource());
5824  return success();
5825  }
5826  }
5827 
5828  return failure();
5829  }
5830 };
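// Illustrative IR examples for the two cases handled by ShapeCastBroadcastFolder
// above (hypothetical SSA names, not part of this file):
//
// Case 1 - the broadcast source shape is a suffix of the shape_cast result shape:
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %r = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
// rewrites to:
//   %r = vector.broadcast %src : vector<4xf32> to vector<6x4xf32>
//
// Case 2 - not a suffix, but the element counts match:
//   %b = vector.broadcast %src : vector<2x4xf32> to vector<1x2x4xf32>
//   %r = vector.shape_cast %b : vector<1x2x4xf32> to vector<8xf32>
// rewrites to:
//   %r = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>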
5831 
5832 } // namespace
5833 
5834 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5835  MLIRContext *context) {
5836  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5837  ShapeCastBroadcastFolder>(context);
5838 }
5839 
5840 //===----------------------------------------------------------------------===//
5841 // VectorBitCastOp
5842 //===----------------------------------------------------------------------===//
5843 
5844 LogicalResult BitCastOp::verify() {
5845  auto sourceVectorType = getSourceVectorType();
5846  auto resultVectorType = getResultVectorType();
5847 
5848  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5849  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5850  return emitOpError("dimension size mismatch at: ") << i;
5851  }
5852 
5853  DataLayout dataLayout = DataLayout::closest(*this);
5854  auto sourceElementBits =
5855  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5856  auto resultElementBits =
5857  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5858 
5859  if (sourceVectorType.getRank() == 0) {
5860  if (sourceElementBits != resultElementBits)
5861  return emitOpError("source/result bitwidth of the 0-D vector element "
5862  "types must be equal");
5863  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5864  resultElementBits * resultVectorType.getShape().back()) {
5865  return emitOpError(
5866  "source/result bitwidth of the minor 1-D vectors must be equal");
5867  }
5868 
5869  return success();
5870 }
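// Illustrative examples for the checks above (hypothetical SSA names):
//   vector.bitcast %v : vector<4x8xi16> to vector<4x4xi32>  // ok: 8*16 == 4*32 bits in the minor dim
//   vector.bitcast %v : vector<4x8xi16> to vector<4x8xi32>  // rejected: minor 1-D bitwidths differ
//   vector.bitcast %v : vector<2x8xi16> to vector<4x4xi32>  // rejected: dimension size mismatch at 0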
5871 
5872 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5873  // Nop cast.
5874  if (getSource().getType() == getResult().getType())
5875  return getSource();
5876 
5877  // Canceling bitcasts.
5878  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5879  if (getResult().getType() == otherOp.getSource().getType())
5880  return otherOp.getSource();
5881 
5882  setOperand(otherOp.getSource());
5883  return getResult();
5884  }
5885 
5886  Attribute sourceConstant = adaptor.getSource();
5887  if (!sourceConstant)
5888  return {};
5889 
5890  Type srcElemType = getSourceVectorType().getElementType();
5891  Type dstElemType = getResultVectorType().getElementType();
5892 
5893  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5894  if (floatPack.isSplat()) {
5895  auto splat = floatPack.getSplatValue<FloatAttr>();
5896 
5897  // Casting fp16 into fp32.
5898  if (srcElemType.isF16() && dstElemType.isF32()) {
5899  uint32_t bits = static_cast<uint32_t>(
5900  splat.getValue().bitcastToAPInt().getZExtValue());
5901  // Duplicate the 16-bit pattern.
5902  bits = (bits << 16) | (bits & 0xffff);
5903  APInt intBits(32, bits);
5904  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5905  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5906  }
5907  }
5908  }
5909 
5910  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5911  if (intPack.isSplat()) {
5912  auto splat = intPack.getSplatValue<IntegerAttr>();
5913 
5914  if (llvm::isa<IntegerType>(dstElemType)) {
5915  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5916  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5917 
5918  // Casting to a larger integer bit width.
5919  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5920  APInt intBits = splat.getValue().zext(dstBitWidth);
5921 
5922  // Duplicate the lower width element.
5923  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5924  intBits = (intBits << srcBitWidth) | intBits;
5925  return DenseElementsAttr::get(getResultVectorType(), intBits);
5926  }
5927  }
5928  }
5929  }
5930 
5931  return {};
5932 }
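// Worked example of the splat-widening fold above: bitcasting a splat
// vector<8xi8> of value 0x2A to vector<2xi32> yields a splat 0x2A2A2A2A,
// since dstBitWidth / srcBitWidth == 4 copies of the 8-bit pattern are packed
// into each i32 lane (the loop shifts and ORs the zero-extended value 3 times).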
5933 
5934 //===----------------------------------------------------------------------===//
5935 // TypeCastOp
5936 //===----------------------------------------------------------------------===//
5937 
5938 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5939  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5940  SmallVector<int64_t, 8> res(memRefType.getShape());
5941  if (vectorType)
5942  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5943  return res;
5944 }
5945 
5946 /// Build the canonical memRefType with a single vector.
5947 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5948 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5949  Value source) {
5950  result.addOperands(source);
5951  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5952  VectorType vectorType =
5953  VectorType::get(extractShape(memRefType),
5954  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5955  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5956  memRefType.getMemorySpace()));
5957 }
5958 
5959 LogicalResult TypeCastOp::verify() {
5960  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
5961  if (!canonicalType.getLayout().isIdentity())
5962  return emitOpError("expects operand to be a memref with identity layout");
5963  if (!getResultMemRefType().getLayout().isIdentity())
5964  return emitOpError("expects result to be a memref with identity layout");
5965  if (getResultMemRefType().getMemorySpace() !=
5966  getMemRefType().getMemorySpace())
5967  return emitOpError("expects result in same memory space");
5968 
5969  auto sourceType = getMemRefType();
5970  auto resultType = getResultMemRefType();
5971  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5972  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5973  return emitOpError(
5974  "expects result and operand with same underlying scalar type: ")
5975  << resultType;
5976  if (extractShape(sourceType) != extractShape(resultType))
5977  return emitOpError(
5978  "expects concatenated result and operand shapes to be equal: ")
5979  << resultType;
5980  return success();
5981 }
5982 
5983 //===----------------------------------------------------------------------===//
5984 // TransposeOp
5985 //===----------------------------------------------------------------------===//
5986 
5987 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5988  Value vector, ArrayRef<int64_t> permutation) {
5989  VectorType vt = llvm::cast<VectorType>(vector.getType());
5990  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5991  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5992  for (unsigned i = 0; i < permutation.size(); ++i) {
5993  transposedShape[i] = vt.getShape()[permutation[i]];
5994  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5995  }
5996 
5997  result.addOperands(vector);
5998  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5999  transposedScalableDims));
6000  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6001  builder.getDenseI64ArrayAttr(permutation));
6002 }
6003 
6004 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6005  // Eliminate splat constant transpose ops.
6006  if (auto attr =
6007  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
6008  if (attr.isSplat())
6009  return attr.reshape(getResultVectorType());
6010 
6011  // Eliminate identity transpose ops. This happens when the dimensions of the
6012  // input vector remain in their original order after the transpose operation.
6013  ArrayRef<int64_t> perm = getPermutation();
6014 
6015  // Check if the permutation of the dimensions contains sequential values:
6016  // {0, 1, 2, ...}.
6017  for (int64_t i = 0, e = perm.size(); i < e; i++) {
6018  if (perm[i] != i)
6019  return {};
6020  }
6021 
6022  return getVector();
6023 }
6024 
6025 LogicalResult vector::TransposeOp::verify() {
6026  VectorType vectorType = getSourceVectorType();
6027  VectorType resultType = getResultVectorType();
6028  int64_t rank = resultType.getRank();
6029  if (vectorType.getRank() != rank)
6030  return emitOpError("vector result rank mismatch: ") << rank;
6031  // Verify transposition array.
6032  ArrayRef<int64_t> perm = getPermutation();
6033  int64_t size = perm.size();
6034  if (rank != size)
6035  return emitOpError("transposition length mismatch: ") << size;
6036  SmallVector<bool, 8> seen(rank, false);
6037  for (const auto &ta : llvm::enumerate(perm)) {
6038  if (ta.value() < 0 || ta.value() >= rank)
6039  return emitOpError("transposition index out of range: ") << ta.value();
6040  if (seen[ta.value()])
6041  return emitOpError("duplicate position index: ") << ta.value();
6042  seen[ta.value()] = true;
6043  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6044  return emitOpError("dimension size mismatch at: ") << ta.value();
6045  }
6046  return success();
6047 }
6048 
6049 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6050  return llvm::to_vector<4>(getResultVectorType().getShape());
6051 }
6052 
6053 namespace {
6054 
6055 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
6056 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6057 public:
6058  using OpRewritePattern::OpRewritePattern;
6059 
6060  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6061  PatternRewriter &rewriter) const override {
6062  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6063  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6064  ArrayRef<int64_t> permutation2) {
6065  SmallVector<int64_t, 4> result;
6066  for (auto index : permutation2)
6067  result.push_back(permutation1[index]);
6068  return result;
6069  };
6070 
6071  // Return if the input of 'transposeOp' is not defined by another transpose.
6072  vector::TransposeOp parentTransposeOp =
6073  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6074  if (!parentTransposeOp)
6075  return failure();
6076 
6077  SmallVector<int64_t, 4> permutation = composePermutations(
6078  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6079  // Replace 'transposeOp' with a new transpose operation.
6080  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6081  transposeOp, transposeOp.getResult().getType(),
6082  parentTransposeOp.getVector(), permutation);
6083  return success();
6084  }
6085 };
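// Worked example of the permutation composition above: if the parent transpose
// uses permutation [1, 2, 0] and the outer transpose uses [2, 0, 1], the
// composed permutation is [parent[2], parent[0], parent[1]] = [0, 1, 2], i.e.
// the identity, which the identity-transpose fold then removes entirely.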
6086 
6087 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
6088 struct FoldTransposedScalarBroadcast final
6089  : public OpRewritePattern<vector::TransposeOp> {
6090  using OpRewritePattern::OpRewritePattern;
6091 
6092  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6093  PatternRewriter &rewriter) const override {
6094  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6095  if (!bcastOp)
6096  return failure();
6097 
6098  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6099  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6100  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6101  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6102  return success();
6103  }
6104 
6105  return failure();
6106  }
6107 };
6108 
6109 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6110 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6111 public:
6112  using OpRewritePattern::OpRewritePattern;
6113 
6114  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6115  PatternRewriter &rewriter) const override {
6116  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6117  if (!splatOp)
6118  return failure();
6119 
6120  rewriter.replaceOpWithNewOp<vector::SplatOp>(
6121  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6122  return success();
6123  }
6124 };
6125 
6126 /// Folds transpose(create_mask) into a new transposed create_mask.
6127 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6128 public:
6129  using OpRewritePattern::OpRewritePattern;
6130 
6131  LogicalResult matchAndRewrite(TransposeOp transpOp,
6132  PatternRewriter &rewriter) const override {
6133  Value transposeSrc = transpOp.getVector();
6134  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6135  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6136  if (!createMaskOp && !constantMaskOp)
6137  return failure();
6138 
6139  // Get the transpose permutation and apply it to the vector.create_mask or
6140  // vector.constant_mask operands.
6141  ArrayRef<int64_t> permutation = transpOp.getPermutation();
6142 
6143  if (createMaskOp) {
6144  auto maskOperands = createMaskOp.getOperands();
6145  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6146  applyPermutationToVector(newOperands, permutation);
6147 
6148  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6149  transpOp, transpOp.getResultVectorType(), newOperands);
6150  return success();
6151  }
6152 
6153  // ConstantMaskOp case.
6154  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6155  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6156 
6157  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6158  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6159  return success();
6160  }
6161 };
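// Illustrative IR example for FoldTransposeCreateMask above (hypothetical SSA
// names, not part of this file):
//   %m = vector.create_mask %a, %b : vector<4x8xi1>
//   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
// rewrites to:
//   %t = vector.create_mask %b, %a : vector<8x4xi1>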
6162 
6163 } // namespace
6164 
6165 void vector::TransposeOp::getCanonicalizationPatterns(
6166  RewritePatternSet &results, MLIRContext *context) {
6167  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6168  TransposeFolder, FoldTransposeSplat>(context);
6169 }
6170 
6171 //===----------------------------------------------------------------------===//
6172 // ConstantMaskOp
6173 //===----------------------------------------------------------------------===//
6174 
6175 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6176  VectorType type, ConstantMaskKind kind) {
6177  assert(kind == ConstantMaskKind::AllTrue ||
6178  kind == ConstantMaskKind::AllFalse);
6179  build(builder, result, type,
6180  kind == ConstantMaskKind::AllTrue
6181  ? type.getShape()
6182  : SmallVector<int64_t>(type.getRank(), 0));
6183 }
6184 
6185 LogicalResult ConstantMaskOp::verify() {
6186  auto resultType = llvm::cast<VectorType>(getResult().getType());
6187  // Check the corner case of 0-D vectors first.
6188  if (resultType.getRank() == 0) {
6189  if (getMaskDimSizes().size() != 1)
6190  return emitError("array attr must have length 1 for 0-D vectors");
6191  auto dim = getMaskDimSizes()[0];
6192  if (dim != 0 && dim != 1)
6193  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6194  return success();
6195  }
6196 
6197  // Verify that array attr size matches the rank of the vector result.
6198  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6199  return emitOpError(
6200  "must specify array attr of size equal vector result rank");
6201  // Verify that each array attr element is in bounds of corresponding vector
6202  // result dimension size.
6203  auto resultShape = resultType.getShape();
6204  auto resultScalableDims = resultType.getScalableDims();
6205  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6206  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6207  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6208  return emitOpError(
6209  "array attr of size out of bounds of vector result dimension size");
6210  if (resultScalableDims[index] && maskDimSize != 0 &&
6211  maskDimSize != resultShape[index])
6212  return emitOpError(
6213  "only supports 'none set' or 'all set' scalable dimensions");
6214  }
6215  // Verify that if one mask dim size is zero, they all should be zero (because
6216  // the mask region is a conjunction of each mask dimension interval).
6217  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6218  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6219  if (anyZeros && !allZeros)
6220  return emitOpError("expected all mask dim sizes to be zeros, "
6221  "as a result of conjunction with zero mask dim");
6222  return success();
6223 }
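// Illustrative examples for the last rule above (hypothetical values):
//   vector.constant_mask [2, 2] : vector<4x3xi1>  // ok: selects the leading 2x2 region
//   vector.constant_mask [2, 0] : vector<4x3xi1>  // rejected: a zero in any dimension empties
//                                                 // the mask, so it must be written as [0, 0]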
6224 
6225 bool ConstantMaskOp::isAllOnesMask() {
6226  auto resultType = getVectorType();
6227  // Check the corner case of 0-D vectors first.
6228  if (resultType.getRank() == 0) {
6229  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6230  return getMaskDimSizes()[0] == 1;
6231  }
6232  for (const auto [resultSize, maskDimSize] :
6233  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6234  if (maskDimSize < resultSize)
6235  return false;
6236  }
6237  return true;
6238 }
6239 
6240 //===----------------------------------------------------------------------===//
6241 // CreateMaskOp
6242 //===----------------------------------------------------------------------===//
6243 
6244 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6245  VectorType type,
6246  ArrayRef<OpFoldResult> mixedOperands) {
6247  SmallVector<Value> operands =
6248  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6249  build(builder, result, type, operands);
6250 }
6251 
6252 LogicalResult CreateMaskOp::verify() {
6253  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6254  // Verify that an operand was specified for each result vector dimension.
6255  if (vectorType.getRank() == 0) {
6256  if (getNumOperands() != 1)
6257  return emitOpError(
6258  "must specify exactly one operand for 0-D create_mask");
6259  } else if (getNumOperands() !=
6260  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6261  return emitOpError(
6262  "must specify an operand for each result vector dimension");
6263  }
6264  return success();
6265 }
6266 
6267 namespace {
6268 
6269 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6270 ///
6271 /// Ex 1:
6272 /// %c2 = arith.constant 2 : index
6273 /// %c3 = arith.constant 3 : index
6274 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6275 /// Becomes:
6276 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6277 ///
6278 /// Ex 2:
6279 /// %c_neg_1 = arith.constant -1 : index
6280 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6281 /// becomes:
6282 /// vector.constant_mask [0] : vector<[8]xi1>
6283 ///
6284 /// Ex 3:
6285 /// %c8 = arith.constant 8 : index
6286 /// %c16 = arith.constant 16 : index
6287 /// %0 = vector.vscale
6288 /// %1 = arith.muli %0, %c16 : index
6289 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6290 /// becomes:
6291 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6292 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6293 public:
6294  using OpRewritePattern::OpRewritePattern;
6295 
6296  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6297  PatternRewriter &rewriter) const override {
6298  VectorType maskType = createMaskOp.getVectorType();
6299  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6300  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6301 
6302  // Special case: Rank zero shape.
6303  constexpr std::array<int64_t, 1> rankZeroShape{1};
6304  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6305  if (maskType.getRank() == 0) {
6306  maskTypeDimSizes = rankZeroShape;
6307  maskTypeDimScalableFlags = rankZeroScalableDims;
6308  }
6309 
6310  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6311  // collect the `constantDims` (for the ConstantMaskOp).
6312  SmallVector<int64_t, 4> constantDims;
6313  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6314  if (auto intSize = getConstantIntValue(dimSize)) {
6315  // Constant value.
6316  // If the mask dim is non-scalable this can be any value.
6317  // If the mask dim is scalable only zero (all-false) is supported.
6318  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6319  return failure();
6320  constantDims.push_back(*intSize);
6321  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6322  // Constant vscale multiple (e.g. 4 x vscale).
6323  // Must be all-true to fold to a ConstantMask.
6324  if (vscaleMultiplier < maskTypeDimSizes[i])
6325  return failure();
6326  constantDims.push_back(*vscaleMultiplier);
6327  } else {
6328  return failure();
6329  }
6330  }
6331 
6332  // Clamp values to constant_mask bounds.
6333  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6334  value = std::clamp<int64_t>(value, 0, maskDimSize);
6335 
6336  // If one of dim sizes is zero, set all dims to zero.
6337  if (llvm::is_contained(constantDims, 0))
6338  constantDims.assign(constantDims.size(), 0);
6339 
6340  // Replace 'createMaskOp' with ConstantMaskOp.
6341  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6342  constantDims);
6343  return success();
6344  }
6345 };
6346 
6347 } // namespace
6348 
6349 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6350  MLIRContext *context) {
6351  results.add<CreateMaskFolder>(context);
6352 }
6353 
6354 //===----------------------------------------------------------------------===//
6355 // MaskOp
6356 //===----------------------------------------------------------------------===//
6357 
6358 void MaskOp::build(
6359  OpBuilder &builder, OperationState &result, Value mask,
6360  Operation *maskableOp,
6361  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6362  assert(maskRegionBuilder &&
6363  "builder callback for 'maskRegion' must be present");
6364 
6365  result.addOperands(mask);
6366  OpBuilder::InsertionGuard guard(builder);
6367  Region *maskRegion = result.addRegion();
6368  builder.createBlock(maskRegion);
6369  maskRegionBuilder(builder, maskableOp);
6370 }
6371 
6372 void MaskOp::build(
6373  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6374  Value mask, Operation *maskableOp,
6375  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6376  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6377  maskRegionBuilder);
6378 }
6379 
6380 void MaskOp::build(
6381  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6382  Value mask, Value passthru, Operation *maskableOp,
6383  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6384  build(builder, result, mask, maskableOp, maskRegionBuilder);
6385  if (passthru)
6386  result.addOperands(passthru);
6387  result.addTypes(resultTypes);
6388 }
6389 
6390 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6391  // Create the op region.
6392  result.regions.reserve(1);
6393  Region &maskRegion = *result.addRegion();
6394 
6395  auto &builder = parser.getBuilder();
6396 
6397  // Parse all the operands.
6398  OpAsmParser::UnresolvedOperand mask;
6399  if (parser.parseOperand(mask))
6400  return failure();
6401 
6402  // Optional passthru operand.
6403  OpAsmParser::UnresolvedOperand passthru;
6404  ParseResult parsePassthru = parser.parseOptionalComma();
6405  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6406  return failure();
6407 
6408  // Parse op region.
6409  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6410  return failure();
6411 
6412  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6413 
6414  // Parse the optional attribute list.
6415  if (parser.parseOptionalAttrDict(result.attributes))
6416  return failure();
6417 
6418  // Parse all the types.
6419  Type maskType;
6420  if (parser.parseColonType(maskType))
6421  return failure();
6422 
6423  SmallVector<Type> resultTypes;
6424  if (parser.parseOptionalArrowTypeList(resultTypes))
6425  return failure();
6426  result.types.append(resultTypes);
6427 
6428  // Resolve operands.
6429  if (parser.resolveOperand(mask, maskType, result.operands))
6430  return failure();
6431 
6432  if (parsePassthru.succeeded())
6433  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6434  return failure();
6435 
6436  return success();
6437 }
6438 
6439 void MaskOp::print(OpAsmPrinter &p) {
6440  p << " " << getMask();
6441  if (getPassthru())
6442  p << ", " << getPassthru();
6443 
6444  // Print single masked operation and skip terminator.
6445  p << " { ";
6446  Block *singleBlock = &getMaskRegion().getBlocks().front();
6447  if (singleBlock && !singleBlock->getOperations().empty())
6448  p.printCustomOrGenericOp(&singleBlock->front());
6449  p << " }";
6450 
6451  p.printOptionalAttrDict(getOperation()->getAttrs());
6452 
6453  p << " : " << getMask().getType();
6454  if (getNumResults() > 0)
6455  p << " -> " << getResultTypes();
6456 }
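// The printer above produces output of the form (hypothetical SSA names and
// masked operation, shown for illustration only):
//   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
//          : vector<8xi1> -> f32
// and, when a passthru value is present:
//   %0 = vector.mask %m, %pt { ... } : vector<8xi1> -> f32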
6457 
6458 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6459  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6460  MaskOp>::ensureTerminator(region, builder, loc);
6461  // Keep the default yield terminator if the number of masked operations is not
6462  // the expected one. This case will trigger a verification failure.
6463  Block &block = region.front();
6464  if (block.getOperations().size() != 2)
6465  return;
6466 
6467  // Replace default yield terminator with a new one that returns the results
6468  // from the masked operation.
6469  OpBuilder opBuilder(builder.getContext());
6470  Operation *maskedOp = &block.front();
6471  Operation *oldYieldOp = &block.back();
6472  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6473 
6474  // Empty vector.mask op.
6475  if (maskedOp == oldYieldOp)
6476  return;
6477 
6478  opBuilder.setInsertionPoint(oldYieldOp);
6479  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6480  oldYieldOp->dropAllReferences();
6481  oldYieldOp->erase();
6482 }
6483 
6484 LogicalResult MaskOp::verify() {
6485  // Structural checks.
6486  Block &block = getMaskRegion().getBlocks().front();
6487  if (block.getOperations().empty())
6488  return emitOpError("expects a terminator within the mask region");
6489 
6490  unsigned numMaskRegionOps = block.getOperations().size();
6491  if (numMaskRegionOps > 2)
6492  return emitOpError("expects only one operation to mask");
6493 
6494  // Terminator checks.
6495  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6496  if (!terminator)
6497  return emitOpError("expects a terminator within the mask region");
6498 
6499  if (terminator->getNumOperands() != getNumResults())
6500  return emitOpError(
6501  "expects number of results to match mask region yielded values");
6502 
6503  // Empty vector.mask. Nothing else to check.
6504  if (numMaskRegionOps == 1)
6505  return success();
6506 
6507  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6508  if (!maskableOp)
6509  return emitOpError("expects a MaskableOpInterface within the mask region");
6510 
6511  // Result checks.
6512  if (maskableOp->getNumResults() != getNumResults())
6513  return emitOpError("expects number of results to match maskable operation "
6514  "number of results");
6515 
6516  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6517  return emitOpError(
6518  "expects result type to match maskable operation result type");
6519 
6520  if (llvm::count_if(maskableOp->getResultTypes(),
6521  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6522  return emitOpError("multiple vector results not supported");
6523 
6524  // Mask checks.
6525  Type expectedMaskType = maskableOp.getExpectedMaskType();
6526  if (getMask().getType() != expectedMaskType)
6527  return emitOpError("expects a ")
6528  << expectedMaskType << " mask for the maskable operation";
6529 
6530  // Passthru checks.
6531  Value passthru = getPassthru();
6532  if (passthru) {
6533  if (!maskableOp.supportsPassthru())
6534  return emitOpError(
6535  "doesn't expect a passthru argument for this maskable operation");
6536 
6537  if (maskableOp->getNumResults() != 1)
6538  return emitOpError("expects result when passthru argument is provided");
6539 
6540  if (passthru.getType() != maskableOp->getResultTypes()[0])
6541  return emitOpError("expects passthru type to match result type");
6542  }
6543 
6544  return success();
6545 }
6546 
6547 /// Folds vector.mask ops with an all-true mask.
6548 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6549  SmallVectorImpl<OpFoldResult> &results) {
6550  MaskFormat maskFormat = getMaskFormat(getMask());
6551  if (isEmpty())
6552  return failure();
6553 
6554  if (maskFormat != MaskFormat::AllTrue)
6555  return failure();
6556 
6557  // Move maskable operation outside of the `vector.mask` region.
6558  Operation *maskableOp = getMaskableOp();
6559  maskableOp->dropAllUses();
6560  maskableOp->moveBefore(getOperation());
6561 
6562  llvm::append_range(results, maskableOp->getResults());
6563  return success();
6564 }
6565 
6566 // Elides empty vector.mask operations with or without return values. Propagates
6567 // the values yielded by the vector.yield terminator, if any, or erases the op
6568 // otherwise.
6569 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6570  using OpRewritePattern::OpRewritePattern;
6571 
6572  LogicalResult matchAndRewrite(MaskOp maskOp,
6573  PatternRewriter &rewriter) const override {
6574  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6575  if (maskingOp.getMaskableOp())
6576  return failure();
6577 
6578  if (!maskOp.isEmpty())
6579  return failure();
6580 
6581  Block *block = maskOp.getMaskBlock();
6582  auto terminator = cast<vector::YieldOp>(block->front());
6583  if (terminator.getNumOperands() == 0)
6584  rewriter.eraseOp(maskOp);
6585  else
6586  rewriter.replaceOp(maskOp, terminator.getOperands());
6587 
6588  return success();
6589  }
6590 };
6591 
6592 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6593  MLIRContext *context) {
6594  results.add<ElideEmptyMaskOp>(context);
6595 }
6596 
6597 // MaskingOpInterface definitions.
6598 
6599 /// Returns the operation masked by this 'vector.mask'.
6600 Operation *MaskOp::getMaskableOp() {
6601  Block *block = getMaskBlock();
6602  if (block->getOperations().size() < 2)
6603  return nullptr;
6604 
6605  return &block->front();
6606 }
6607 
6608 /// Returns true if 'vector.mask' has a passthru value.
6609 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6610 
6611 //===----------------------------------------------------------------------===//
6612 // ScanOp
6613 //===----------------------------------------------------------------------===//
6614 
6615 LogicalResult ScanOp::verify() {
6616  VectorType srcType = getSourceType();
6617  VectorType initialType = getInitialValueType();
6618  // Check reduction dimension < rank.
6619  int64_t srcRank = srcType.getRank();
6620  int64_t reductionDim = getReductionDim();
6621  if (reductionDim >= srcRank)
6622  return emitOpError("reduction dimension ")
6623  << reductionDim << " has to be less than " << srcRank;
6624 
6625  // Check that rank(initial_value) = rank(src) - 1.
6626  int64_t initialValueRank = initialType.getRank();
6627  if (initialValueRank != srcRank - 1)
6628  return emitOpError("initial value rank ")
6629  << initialValueRank << " has to be equal to " << srcRank - 1;
6630 
6631  // Check shapes of initial value and src.
6632  ArrayRef<int64_t> srcShape = srcType.getShape();
6633  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6634  SmallVector<int64_t> expectedShape;
6635  for (int i = 0; i < srcRank; i++) {
6636  if (i != reductionDim)
6637  expectedShape.push_back(srcShape[i]);
6638  }
6639  if (!llvm::equal(initialValueShapes, expectedShape)) {
6640  return emitOpError("incompatible input/initial value shapes");
6641  }
6642 
6643  // Verify supported reduction kind.
6644  Type eltType = getDestType().getElementType();
6645  if (!isSupportedCombiningKind(getKind(), eltType))
6646  return emitOpError("unsupported reduction type ")
6647  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6648  << "'";
6649 
6650  return success();
6651 }
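// Shape example for the checks above: scanning %src : vector<4x8xf32> along
// reduction_dim = 1 requires an initial value of type vector<4xf32> (the
// source shape with the reduction dimension removed); the scan result keeps
// the full source shape vector<4x8xf32>.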
6652 
6653 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6654  RewritePatternSet &patterns, PatternBenefit benefit) {
6655  patterns
6656  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6657  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6658  StridedSliceConstantMaskFolder, TransposeFolder>(
6659  patterns.getContext(), benefit);
6660 }
6661 
6662 //===----------------------------------------------------------------------===//
6663 // SplatOp
6664 //===----------------------------------------------------------------------===//
6665 
6666 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6667  auto constOperand = adaptor.getInput();
6668  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6669  return {};
6670 
6671  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6672  return SplatElementsAttr::get(getType(), {constOperand});
6673 }
6674 
6675 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6676  SetIntRangeFn setResultRanges) {
6677  setResultRanges(getResult(), argRanges.front());
6678 }
6679 
6680 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6681  CombiningKind kind, Value v1, Value acc,
6682  arith::FastMathFlagsAttr fastmath,
6683  Value mask) {
6684  Type t1 = getElementTypeOrSelf(v1.getType());
6685  Type tAcc = getElementTypeOrSelf(acc.getType());
6686  Value result;
6687 
6688  switch (kind) {
6689  case CombiningKind::ADD:
6690  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6691  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6692  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6693  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6694  else
6695  llvm_unreachable("invalid value types for ADD reduction");
6696  break;
6697  case CombiningKind::AND:
6698  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6699  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6700  break;
6701  case CombiningKind::MAXNUMF:
6702  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6703  "expected float values");
6704  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6705  break;
6706  case CombiningKind::MAXIMUMF:
6707  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6708  "expected float values");
6709  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6710  break;
6711  case CombiningKind::MINNUMF:
6712  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6713  "expected float values");
6714  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6715  break;
6716  case CombiningKind::MINIMUMF:
6717  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6718  "expected float values");
6719  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6720  break;
6721  case CombiningKind::MAXSI:
6722  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6723  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6724  break;
6725  case CombiningKind::MINSI:
6726  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6727  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6728  break;
6729  case CombiningKind::MAXUI:
6730  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6731  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6732  break;
6733  case CombiningKind::MINUI:
6734  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6735  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6736  break;
6737  case CombiningKind::MUL:
6738  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6739  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6740  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6741  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6742  else
6743  llvm_unreachable("invalid value types for MUL reduction");
6744  break;
6745  case CombiningKind::OR:
6746  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6747  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6748  break;
6749  case CombiningKind::XOR:
6750  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6751  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6752  break;
6753  };
6754 
6755  assert(result && "unknown CombiningKind");
6756  return selectPassthru(b, mask, result, acc);
6757 }
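// Illustrative usage sketch of the helper above (not part of this file;
// `builder`, `loc`, `lhs`, `acc` and `mask` are assumed to be in scope):
//   Value combined = makeArithReduction(builder, loc, CombiningKind::ADD,
//                                       lhs, acc, /*fastmath=*/nullptr,
//                                       /*mask=*/mask);
// Integer or index operands fold to arith.addi, float operands to arith.addf;
// with a non-null mask, masked-off lanes receive `acc` through the
// selectPassthru call at the end.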
6758 
6759 //===----------------------------------------------------------------------===//
6760 // Vector Masking Utilities
6761 //===----------------------------------------------------------------------===//
6762 
6763 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6764 /// as masked operation.
6765 static void createMaskOpRegion(OpBuilder &builder,
6766  Operation *maskableOp) {
6767  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6768  Block *insBlock = builder.getInsertionBlock();
6769  // Create a block and move the op to that block.
6770  insBlock->getOperations().splice(
6771  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6772  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6773 }
6774 
6775 /// Creates a vector.mask operation around a maskable operation. Returns the
6776 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6777 /// the maskable operation itself.
6778 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6779  Operation *maskableOp, Value mask,
6780  Value passthru) {
6781  if (!mask)
6782  return maskableOp;
6783  if (passthru)
6784  return builder.create<MaskOp>(maskableOp->getLoc(),
6785  maskableOp->getResultTypes(), mask, passthru,
6786  maskableOp, createMaskOpRegion);
6787  return builder.create<MaskOp>(maskableOp->getLoc(),
6788  maskableOp->getResultTypes(), mask, maskableOp,
6789  createMaskOpRegion);
6790 }
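// Illustrative usage sketch of the helper above (not part of this file;
// `builder`, `readOp` and `maskValue` are assumed to be in scope):
//   Operation *masked = maskOperation(builder, readOp, maskValue);
// If `maskValue` is null, `readOp` is returned unchanged; otherwise the op is
// moved into the region of a newly created vector.mask and its results are
// yielded from that region.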
6791 
6792 /// Creates a vector select operation that picks values from `newValue` or
6793 /// `passthru` for each result vector lane based on `mask`. This utility is used
6794 /// to propagate the pass-thru value of vector.mask or for cases where only the
6795 /// pass-thru value propagation is needed. VP intrinsics do not support
6796 /// pass-thru values and every masked-out lane is set to poison. LLVM backends are
6797 /// usually able to match op + select patterns and fold them into native
6798 /// target instructions.
6799 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6800  Value newValue, Value passthru) {
6801  if (!mask)
6802  return newValue;
6803 
6804  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6805  mask, newValue, passthru);
6806 }
6807 
6808 //===----------------------------------------------------------------------===//
6809 // TableGen'd op method definitions
6810 //===----------------------------------------------------------------------===//
6811 
6812 #define GET_ATTRDEF_CLASSES
6813 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6814 
6815 #define GET_OP_CLASSES
6816 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition: VectorOps.cpp:68
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
Definition: VectorOps.cpp:1091
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1390
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
Definition: VectorOps.cpp:1651
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:4220
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
Definition: VectorOps.cpp:1959
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1827
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1662
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
Definition: VectorOps.cpp:2043
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:2341
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:130
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
Definition: VectorOps.cpp:2033
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:58
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:3248
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
Definition: VectorOps.cpp:296
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:872
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
Definition: VectorOps.cpp:2390
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
Definition: VectorOps.cpp:1995
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:2708
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:3184
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:1083
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:3204
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
Definition: VectorOps.cpp:175
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:3227
static LogicalResult foldTransferFullMask(TransferOp op)
Definition: VectorOps.cpp:4454
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1382
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
Definition: VectorOps.cpp:2364
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
Definition: VectorOps.cpp:1283
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:4112
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:883
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:3169
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1762
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:4141
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:4382
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:3588
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
Definition: VectorOps.cpp:1730
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:4399
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1879
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:216
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:76
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation & back()
Definition: Block.h:152
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:266
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
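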
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:584
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:112
bool isF32() const
Definition: Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:104
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:114
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:282
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:289
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:389
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
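A minimal usage sketch, assuming the caller already has a builder, location, and two values of matching type (all placeholders here):
#include "mlir/Dialect/Vector/IR/VectorOps.h"

// Folds `value` into `acc` with the arith op matching CombiningKind::ADD
// (arith.addf or arith.addi, chosen from the element type).
static mlir::Value accumulate(mlir::OpBuilder &builder, mlir::Location loc,
                              mlir::Value value, mlir::Value acc) {
  return mlir::vector::makeArithReduction(
      builder, loc, mlir::vector::CombiningKind::ADD, value, acc);
}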
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:450
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:125
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
Definition: VectorOps.cpp:354
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:153
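For illustration, a sketch with concrete types (the shapes are arbitrary): reading a 16x32 vector from a 3-D memref defaults to the minor identity (d0, d1, d2) -> (d1, d2).
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::AffineMap defaultTransferMap(mlir::MLIRContext *ctx) {
  mlir::Type f32 = mlir::Float32Type::get(ctx);
  mlir::Type memrefTy = mlir::MemRefType::get({8, 16, 32}, f32);
  mlir::VectorType vecTy = mlir::VectorType::get({16, 32}, f32);
  return mlir::vector::getTransferMinorIdentityMap(
      llvm::cast<mlir::ShapedType>(memrefTy), vecTy);
}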
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition: VectorOps.h:60
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2522
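A minimal sketch of the common query (helper name hypothetical); only the Success case of BroadcastableToResult is checked here:
#include "mlir/Dialect/Vector/IR/VectorOps.h"

// True when `srcTy` (a scalar or a smaller vector) can legally be
// vector.broadcast to `dstTy`.
static bool canBroadcast(mlir::Type srcTy, mlir::VectorType dstTy) {
  return mlir::vector::isBroadcastableTo(srcTy, dstTy) ==
         mlir::vector::BroadcastableToResult::Success;
}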
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4239
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
Definition: VectorOps.cpp:219
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:283
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:314
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:338
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:628
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:68
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:446
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
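A small worked sketch (helper name hypothetical): for a projected permutation, result i of the map selects source[dimPosition(i)].
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"

// (d0, d1, d2) -> (d2, d0) applied to {10, 20, 30} yields {30, 10}.
static llvm::SmallVector<int64_t> pickDims(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1, d2;
  mlir::bindDims(ctx, d0, d1, d2);
  mlir::AffineMap map = mlir::AffineMap::get(3, 0, {d2, d0}, ctx);
  int64_t source[] = {10, 20, 30};
  return mlir::applyPermutationMap<int64_t>(map, source);
}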
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
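A brief sketch (helper name hypothetical): inverting the rotation built with getPermutationMap above gives the map that undoes it.
#include "mlir/IR/AffineMap.h"

// The inverse of (d0, d1, d2) -> (d1, d2, d0) is (d0, d1, d2) -> (d2, d0, d1);
// composing the two yields the identity map.
static mlir::AffineMap invertRotation(mlir::MLIRContext *ctx) {
  unsigned perm[] = {1, 2, 0};
  return mlir::inversePermutation(
      mlir::AffineMap::getPermutationMap(perm, ctx));
}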
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
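A minimal sketch, with builder, loc, and ofr as caller-supplied placeholders:
#include "mlir/Dialect/Arith/Utils/Utils.h"

// Returns `ofr` as-is when it already holds a Value; otherwise materializes
// the attribute as an arith.constant of index type.
static mlir::Value materializeIndex(mlir::OpBuilder &builder,
                                    mlir::Location loc,
                                    mlir::OpFoldResult ofr) {
  return mlir::getValueOrCreateConstantIndexOp(builder, loc, ofr);
}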
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:930
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
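A worked sketch combining linearize with computeStrides (listed above); the shape and multi-index are arbitrary:
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"

// Row-major strides of a 2x3x4 shape are {12, 4, 1}; the multi-index
// {1, 2, 3} therefore linearizes to 1*12 + 2*4 + 3*1 = 23.
static int64_t linearIndexExample() {
  llvm::SmallVector<int64_t> strides = mlir::computeStrides({2, 3, 4});
  return mlir::linearize({1, 2, 3}, strides);
}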
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
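A short sketch of the matcher pair (helper name hypothetical); v is a caller-supplied placeholder:
#include "mlir/IR/Matchers.h"

// Binds the folded constant attribute of `v`'s defining op, if there is one.
static bool getConstantAttr(mlir::Value v, mlir::Attribute &attr) {
  return mlir::matchPattern(v, mlir::m_Constant(&attr));
}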
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1184
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1187
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
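A structural sketch only, not one of the patterns defined in this file: a hypothetical OpRewritePattern on vector.transfer_write showing the usual matchAndRewrite shape.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

struct TransferWritePatternSketch
    : public mlir::OpRewritePattern<mlir::vector::TransferWriteOp> {
  using mlir::OpRewritePattern<mlir::vector::TransferWriteOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::TransferWriteOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Skeleton only: a real pattern would analyze the op here and either
    // rewrite it or report why it does not apply.
    if (!op.getMask())
      return rewriter.notifyMatchFailure(op, "no mask to inspect");
    return mlir::failure();
  }
};
Such a pattern would be registered through RewritePatternSet::add, as the add entry above describes.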
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:329
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:381
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:383
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.