1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/TypeUtilities.h"
37 #include "mlir/Support/LLVM.h"
39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of a constant value with dense attributes
66 /// and a constant mask operation (since the client may be called at
67 /// various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayRef<int64_t> masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  if (maskIdx < dimSize)
97  allTrue = false;
98  if (maskIdx > 0)
99  allFalse = false;
100  }
101  if (allTrue)
102  return MaskFormat::AllTrue;
103  if (allFalse)
104  return MaskFormat::AllFalse;
105  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
106  // Finds all-false create_masks. An all-true create_mask requires all
107  // dims to be constants, so that'll be folded to a constant_mask, then
108  // detected in the constant_mask case.
109  auto maskOperands = m.getOperands();
110  for (Value operand : maskOperands) {
111  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
112  int64_t dimSize =
113  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
114  if (dimSize <= 0)
115  return MaskFormat::AllFalse;
116  }
117  }
118  return MaskFormat::Unknown;
119  }
120  return MaskFormat::Unknown;
121 }
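// Illustrative sketch (not part of the original source): IR-level examples of
// how getMaskFormat would classify common mask producers. SSA names are
// hypothetical.
//
//   %m0 = arith.constant dense<true> : vector<4xi1>        // -> AllTrue
//   %m1 = arith.constant dense<false> : vector<4xi1>       // -> AllFalse
//   %m2 = vector.constant_mask [4] : vector<4xi1>          // -> AllTrue
//   %m3 = vector.constant_mask [0] : vector<4xi1>          // -> AllFalse
//   %c0 = arith.constant 0 : index
//   %m4 = vector.create_mask %c0 : vector<4xi1>            // -> AllFalse
//   %m5 = vector.create_mask %n : vector<4xi1>              // -> Unknown (%n not constant)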
122 
123 /// Default callback to build a region with a 'vector.yield' terminator with no
124 /// arguments.
125 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
126  builder.create<vector::YieldOp>(loc);
127 }
128 
129 // Helper for verifying combining kinds in contractions and reductions.
130 static bool isSupportedCombiningKind(CombiningKind combiningKind,
131  Type elementType) {
132  switch (combiningKind) {
133  case CombiningKind::ADD:
134  case CombiningKind::MUL:
135  return elementType.isIntOrIndexOrFloat();
136  case CombiningKind::MINUI:
137  case CombiningKind::MINSI:
138  case CombiningKind::MAXUI:
139  case CombiningKind::MAXSI:
140  case CombiningKind::AND:
141  case CombiningKind::OR:
142  case CombiningKind::XOR:
143  return elementType.isIntOrIndex();
144  case CombiningKind::MINNUMF:
145  case CombiningKind::MAXNUMF:
146  case CombiningKind::MINIMUMF:
147  case CombiningKind::MAXIMUMF:
148  return llvm::isa<FloatType>(elementType);
149  }
150  return false;
151 }
152 
153 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154  VectorType vectorType) {
155  int64_t elementVectorRank = 0;
156  VectorType elementVectorType =
157  llvm::dyn_cast<VectorType>(shapedType.getElementType());
158  if (elementVectorType)
159  elementVectorRank += elementVectorType.getRank();
160  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161  // TODO: replace once we have 0-d vectors.
162  if (shapedType.getRank() == 0 &&
163  vectorType.getShape() == ArrayRef<int64_t>{1})
164  return AffineMap::get(
165  /*numDims=*/0, /*numSymbols=*/0,
166  getAffineConstantExpr(0, shapedType.getContext()));
167  return AffineMap::getMinorIdentityMap(
168  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169  shapedType.getContext());
170 }
171 
172 /// Check if `write` is of a constant splat and the masked `read` is padded with
173 /// the same splat value -- meaning it could be the same value as the initial
174 /// constant splat.
175 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
176  vector::TransferReadOp read) {
177  auto readMask = read.getMask();
178  auto writeMask = write.getMask();
179  // Check if the masks are consistent. The splat value could be the same if the
180  // read is masked (and padded with the splat value), and the write is unmasked
181  // or has the same mask. Note this does not allow the case where the write is
182  // masked and the read is unmasked, as then the read could cover more elements
183  // than the write (and those extra elements may not equal the splat value).
184  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185  if (!couldBeSameSplat)
186  return false;
187  // Check for constant splat (as the source of the write).
188  DenseElementsAttr splatAttr;
189  if (!matchPattern(write.getVector(),
190  m_Constant<DenseElementsAttr>(&splatAttr)) ||
191  !splatAttr.isSplat()) {
192  return false;
193  }
194  // The padding of the read and the constant splat value must be the same.
195  Attribute padAttr;
196  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
197  return false;
198  return padAttr == splatAttr.getSplatValue<Attribute>();
199 }
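// Illustrative sketch (not part of the original source): the write/read pair
// this helper is meant to recognize, with hypothetical SSA names. The write
// stores a constant splat and the masked read uses the same value as padding,
// so the read yields the splat regardless of the mask.
//
//   %cst = arith.constant dense<0.0> : vector<4xf32>
//   %w = vector.transfer_write %cst, %t[%c0], %mask {in_bounds = [true]}
//          : vector<4xf32>, tensor<4xf32>
//   %r = vector.transfer_read %w[%c0], %f0, %mask {in_bounds = [true]}
//          : tensor<4xf32>, vector<4xf32>
//   // With %f0 == 0.0 and the same %mask, %r is the same splat as %cst.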
200 
201 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
202  vector::TransferReadOp read) {
203  return !defWrite.hasOutOfBoundsDim() &&
204  defWrite.getIndices() == read.getIndices() &&
205  defWrite.getVectorType() == read.getVectorType() &&
206  defWrite.getPermutationMap() == read.getPermutationMap() &&
207  ((!defWrite.getMask() && !read.getMask()) ||
208  isSplatWriteConsistentWithMaskedRead(defWrite, read));
209 }
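// Illustrative sketch (not part of the original source): the RAW case that
// checkSameValueRAW enables forwarding for, with hypothetical SSA names. Same
// indices, vector type, and permutation map, both unmasked and in-bounds, so
// the read can be replaced by the written vector %v.
//
//   %w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]}
//          : vector<4x8xf32>, tensor<16x16xf32>
//   %r = vector.transfer_read %w[%c0, %c0], %pad {in_bounds = [true, true]}
//          : tensor<16x16xf32>, vector<4x8xf32>
//   // checkSameValueRAW on this write/read pair returns true, so %r folds to %v.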
210 
211 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
212  vector::TransferWriteOp priorWrite) {
213  return priorWrite.getIndices() == write.getIndices() &&
214  priorWrite.getMask() == write.getMask() &&
215  priorWrite.getVectorType() == write.getVectorType() &&
216  priorWrite.getPermutationMap() == write.getPermutationMap();
217 }
218 
219 bool mlir::vector::isDisjointTransferIndices(
220  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221  bool testDynamicValueUsingBounds) {
222  // For simplicity only look at transfer of same type.
223  if (transferA.getVectorType() != transferB.getVectorType())
224  return false;
225  unsigned rankOffset = transferA.getLeadingShapedRank();
226  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227  Value indexA = transferA.getIndices()[i];
228  Value indexB = transferB.getIndices()[i];
229  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
230  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
231 
232  if (i < rankOffset) {
233  // For leading dimensions, if we can prove that the indices are different,
234  // we know we are accessing disjoint slices.
235  if (cstIndexA.has_value() && cstIndexB.has_value()) {
236  if (*cstIndexA != *cstIndexB)
237  return true;
238  continue;
239  }
240  if (testDynamicValueUsingBounds) {
241  // First try to see if we can fully compose and simplify the affine
242  // expression as a fast track.
243  FailureOr<uint64_t> delta =
244  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
245  if (succeeded(delta) && *delta != 0)
246  return true;
247 
248  FailureOr<bool> testEqual =
249  ValueBoundsConstraintSet::areEqual(indexA, indexB);
250  if (succeeded(testEqual) && !testEqual.value())
251  return true;
252  }
253  } else {
254  // For this dimension, we slice a part of the memref, so we need to make
255  // sure the intervals accessed don't overlap.
256  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257  if (cstIndexA.has_value() && cstIndexB.has_value()) {
258  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
259  if (distance >= vectorDim)
260  return true;
261  continue;
262  }
263  if (testDynamicValueUsingBounds) {
264  // First try to see if we can fully compose and simplify the affine
265  // expression as a fast track.
266  FailureOr<int64_t> delta =
267  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
268  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
269  return true;
270 
271  FailureOr<int64_t> computeDelta =
272  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
273  if (succeeded(computeDelta)) {
274  if (std::abs(computeDelta.value()) >= vectorDim)
275  return true;
276  }
277  }
278  }
279  }
280  return false;
281 }
282 
283 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
284  VectorTransferOpInterface transferB,
285  bool testDynamicValueUsingBounds) {
286  if (transferA.getSource() != transferB.getSource())
287  return false;
288  return isDisjointTransferIndices(transferA, transferB,
289  testDynamicValueUsingBounds);
290 }
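// Illustrative sketch (not part of the original source): two transfers that
// isDisjointTransferSet can prove disjoint, with hypothetical SSA names. Both
// read vector<4xf32> from the same memref; the constant indices differ by 8,
// which is >= the vector dimension (4), so the accessed intervals cannot
// overlap.
//
//   %a = vector.transfer_read %m[%c0], %pad : memref<64xf32>, vector<4xf32>
//   %b = vector.transfer_read %m[%c8], %pad : memref<64xf32>, vector<4xf32>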
291 
292 // Helper to iterate over n-D vector slice elements. Calculate the next
293 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
294  // Modifies `position` in place. Returns failure when `position` reaches
295  // the end position.
296 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
297  ArrayRef<int64_t> shape,
298  ArrayRef<int64_t> offsets) {
299  for (auto [posInDim, dimSize, offsetInDim] :
300  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
301  ++posInDim;
302  if (posInDim < dimSize + offsetInDim)
303  return success();
304 
305  // Carry the overflow to the next loop iteration.
306  posInDim = offsetInDim;
307  }
308 
309  return failure();
310 }
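// Illustrative sketch (not part of the original source): how incSlicePosition
// advances a position for shape [2, 3] with offsets [0, 0]:
//   [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> failure()
// The innermost (last) dimension is incremented first; on overflow that
// dimension wraps back to its offset and the carry moves to the next outer
// dimension.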
311 
312 /// Returns the integer numbers in `values`. `values` are expected to be
313 /// constant operations.
314 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
315  SmallVector<int64_t> ints;
316  llvm::transform(values, std::back_inserter(ints), [](Value value) {
317  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
318  assert(constOp && "Unexpected non-constant index");
319  return constOp.value();
320  });
321  return ints;
322 }
323 
324 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
325 /// be constant operations.
326 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
327  SmallVector<int64_t> ints;
328  llvm::transform(
329  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
330  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
331  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
332  });
333  return ints;
334 }
335 
336 /// Convert `foldResults` into Values. Integer attributes are converted to
337 /// constant ops.
338 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
339  ArrayRef<OpFoldResult> foldResults) {
340  SmallVector<Value> values;
341  llvm::transform(foldResults, std::back_inserter(values),
342  [&](OpFoldResult foldResult) {
343  if (auto attr = foldResult.dyn_cast<Attribute>())
344  return builder
345  .create<arith::ConstantIndexOp>(
346  loc, cast<IntegerAttr>(attr).getInt())
347  .getResult();
348 
349  return cast<Value>(foldResult);
350  });
351  return values;
352 }
353 
354 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
355  if (value.getDefiningOp<vector::VectorScaleOp>())
356  return 1;
357  auto mul = value.getDefiningOp<arith::MulIOp>();
358  if (!mul)
359  return {};
360  auto lhs = mul.getLhs();
361  auto rhs = mul.getRhs();
362  if (lhs.getDefiningOp<vector::VectorScaleOp>())
363  return getConstantIntValue(rhs);
364  if (rhs.getDefiningOp<vector::VectorScaleOp>())
365  return getConstantIntValue(lhs);
366  return {};
367 }
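// Illustrative sketch (not part of the original source): values this helper
// recognizes, with hypothetical SSA names.
//
//   %vs = vector.vscale                    // returns 1
//   %c4 = arith.constant 4 : index
//   %m  = arith.muli %vs, %c4 : index      // returns 4 (either operand order)
//   %d  = arith.addi %vs, %c4 : index      // returns std::nullopt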
368 
369 //===----------------------------------------------------------------------===//
370 // CombiningKindAttr
371 //===----------------------------------------------------------------------===//
372 
373 namespace mlir {
374 namespace vector {
375 namespace detail {
376 struct BitmaskEnumStorage : public AttributeStorage {
377  using KeyTy = uint64_t;
378 
379  BitmaskEnumStorage(KeyTy val) : value(val) {}
380 
381  bool operator==(const KeyTy &key) const { return value == key; }
382 
383  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
384  const KeyTy &key) {
385  return new (allocator.allocate<BitmaskEnumStorage>())
386  BitmaskEnumStorage(key);
387  }
388 
389  KeyTy value = 0;
390 };
391 } // namespace detail
392 } // namespace vector
393 } // namespace mlir
394 
395 //===----------------------------------------------------------------------===//
396 // VectorDialect
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 /// This class defines the interface for handling inlining with vector dialect
401 /// operations.
402 struct VectorInlinerInterface : public DialectInlinerInterface {
403  using DialectInlinerInterface::DialectInlinerInterface;
404 
405  /// All vector dialect ops can be inlined.
406  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
407  return true;
408  }
409 };
410 } // namespace
411 
412 void VectorDialect::initialize() {
413  addAttributes<
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
416  >();
417 
418  addOperations<
419 #define GET_OP_LIST
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
421  >();
422 
423  addInterfaces<VectorInlinerInterface>();
424 
425  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
427  YieldOp>();
428  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
429  TransferWriteOp>();
430  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432 }
433 
434 /// Materialize a single constant operation from a given attribute value with
435 /// the desired resultant type.
436 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
437  Attribute value, Type type,
438  Location loc) {
439  return arith::ConstantOp::materialize(builder, value, type, loc);
440 }
441 
442 IntegerType vector::getVectorSubscriptType(Builder &builder) {
443  return builder.getIntegerType(64);
444 }
445 
446 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
447  ArrayRef<int64_t> values) {
448  return builder.getI64ArrayAttr(values);
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // MultiDimReductionOp
453 //===----------------------------------------------------------------------===//
454 
455 void vector::MultiDimReductionOp::build(OpBuilder &builder,
456  OperationState &result, Value source,
457  Value acc, ArrayRef<bool> reductionMask,
458  CombiningKind kind) {
459  SmallVector<int64_t> reductionDims;
460  for (const auto &en : llvm::enumerate(reductionMask))
461  if (en.value())
462  reductionDims.push_back(en.index());
463  build(builder, result, kind, source, acc, reductionDims);
464 }
465 
466 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
467  // Single parallel dim, this is a noop.
468  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
469  return getSource();
470  return {};
471 }
472 
473 std::optional<SmallVector<int64_t, 4>>
474 MultiDimReductionOp::getShapeForUnroll() {
475  return llvm::to_vector<4>(getSourceVectorType().getShape());
476 }
477 
478 LogicalResult MultiDimReductionOp::verify() {
479  SmallVector<int64_t> targetShape;
480  SmallVector<bool> scalableDims;
481  Type inferredReturnType;
482  auto sourceScalableDims = getSourceVectorType().getScalableDims();
483  for (auto [dimIdx, dimSize] :
484  llvm::enumerate(getSourceVectorType().getShape()))
485  if (!llvm::any_of(getReductionDims(),
486  [dimIdx = dimIdx](int64_t reductionDimIdx) {
487  return reductionDimIdx == static_cast<int64_t>(dimIdx);
488  })) {
489  targetShape.push_back(dimSize);
490  scalableDims.push_back(sourceScalableDims[dimIdx]);
491  }
492  // TODO: update to also allow 0-d vectors when available.
493  if (targetShape.empty())
494  inferredReturnType = getSourceVectorType().getElementType();
495  else
496  inferredReturnType = VectorType::get(
497  targetShape, getSourceVectorType().getElementType(), scalableDims);
498  if (getType() != inferredReturnType)
499  return emitOpError() << "destination type " << getType()
500  << " is incompatible with source type "
501  << getSourceVectorType();
502 
503  return success();
504 }
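// Illustrative sketch (not part of the original source): a multi_reduction
// whose destination type matches what the verifier infers. Reducing dim 1 of
// vector<4x8xf32> keeps the non-reduced dims, so the result is vector<4xf32>;
// reducing both dims would yield the scalar element type f32.
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x8xf32> to vector<4xf32>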
505 
506 /// Returns the mask type expected by this operation.
507 Type MultiDimReductionOp::getExpectedMaskType() {
508  auto vecType = getSourceVectorType();
509  return VectorType::get(vecType.getShape(),
510  IntegerType::get(vecType.getContext(), /*width=*/1),
511  vecType.getScalableDims());
512 }
513 
514 namespace {
515 // Only unit dimensions that are being reduced are folded. If the dimension is
516 // unit, but not reduced, it is not folded, thereby keeping the output type the
517 // same. If not all dimensions which are reduced are of unit dimension, this
518 // transformation does nothing. This is just a generalization of
519  // ElideSingleElementReduction for ReductionOp.
520 struct ElideUnitDimsInMultiDimReduction
521  : public OpRewritePattern<MultiDimReductionOp> {
522  using OpRewritePattern::OpRewritePattern;
523 
524  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
525  PatternRewriter &rewriter) const override {
526  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
527  for (const auto &dim : enumerate(shape)) {
528  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
529  return failure();
530  }
531 
532  // Vector mask setup.
533  OpBuilder::InsertionGuard guard(rewriter);
534  Operation *rootOp;
535  Value mask;
536  if (reductionOp.isMasked()) {
537  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
538  rootOp = reductionOp.getMaskingOp();
539  mask = reductionOp.getMaskingOp().getMask();
540  } else {
541  rootOp = reductionOp;
542  }
543 
544  Location loc = reductionOp.getLoc();
545  Value acc = reductionOp.getAcc();
546  Value cast;
547  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
548  if (mask) {
549  VectorType newMaskType =
550  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
551  dstVecType.getScalableDims());
552  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
553  }
554  cast = rewriter.create<vector::ShapeCastOp>(
555  loc, reductionOp.getDestType(), reductionOp.getSource());
556  } else {
557  // This means we are reducing all the dimensions, and all reduction
558  // dimensions are of size 1. So a simple extraction would do.
559  SmallVector<int64_t> zeroIdx(shape.size(), 0);
560  if (mask)
561  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
562  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
563  zeroIdx);
564  }
565 
566  Value result =
567  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
568  cast, /*fastmath=*/nullptr, mask);
569  rewriter.replaceOp(rootOp, result);
570  return success();
571  }
572 };
573 } // namespace
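// Illustrative sketch (not part of the original source): the rewrite performed
// by ElideUnitDimsInMultiDimReduction when every reduced dimension has size 1,
// with hypothetical SSA names.
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// becomes
//   %c = vector.shape_cast %src : vector<4x1xf32> to vector<4xf32>
//   %r = arith.addf %c, %acc : vector<4xf32>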
574 
575 void MultiDimReductionOp::getCanonicalizationPatterns(
576  RewritePatternSet &results, MLIRContext *context) {
577  results.add<ElideUnitDimsInMultiDimReduction>(context);
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // ReductionOp
582 //===----------------------------------------------------------------------===//
583 
584 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
585  CombiningKind kind, Value vector,
586  arith::FastMathFlags fastMathFlags) {
587  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
588 }
589 
590 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
591  CombiningKind kind, Value vector, Value acc,
592  arith::FastMathFlags fastMathFlags) {
593  build(builder, result,
594  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
595  acc, fastMathFlags);
596 }
597 
598 LogicalResult ReductionOp::verify() {
599  // Verify for 0-D and 1-D vector.
600  int64_t rank = getSourceVectorType().getRank();
601  if (rank > 1)
602  return emitOpError("unsupported reduction rank: ") << rank;
603 
604  // Verify supported reduction kind.
605  Type eltType = getDest().getType();
606  if (!isSupportedCombiningKind(getKind(), eltType))
607  return emitOpError("unsupported reduction type '")
608  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
609  << "'";
610 
611  return success();
612 }
613 
614 // MaskableOpInterface methods.
615 
616 /// Returns the mask type expected by this operation.
617 Type ReductionOp::getExpectedMaskType() {
618  auto vecType = getSourceVectorType();
619  return VectorType::get(vecType.getShape(),
620  IntegerType::get(vecType.getContext(), /*width=*/1),
621  vecType.getScalableDims());
622 }
623 
624 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
625  OpBuilder &builder, Location loc,
626  Value vector) {
627  switch (op) {
628  case arith::AtomicRMWKind::addf:
629  case arith::AtomicRMWKind::addi:
630  return builder.create<vector::ReductionOp>(vector.getLoc(),
631  CombiningKind::ADD, vector);
632  case arith::AtomicRMWKind::mulf:
633  case arith::AtomicRMWKind::muli:
634  return builder.create<vector::ReductionOp>(vector.getLoc(),
635  CombiningKind::MUL, vector);
636  case arith::AtomicRMWKind::minimumf:
637  return builder.create<vector::ReductionOp>(vector.getLoc(),
638  CombiningKind::MINIMUMF, vector);
639  case arith::AtomicRMWKind::mins:
640  return builder.create<vector::ReductionOp>(vector.getLoc(),
641  CombiningKind::MINSI, vector);
642  case arith::AtomicRMWKind::minu:
643  return builder.create<vector::ReductionOp>(vector.getLoc(),
644  CombiningKind::MINUI, vector);
645  case arith::AtomicRMWKind::maximumf:
646  return builder.create<vector::ReductionOp>(vector.getLoc(),
647  CombiningKind::MAXIMUMF, vector);
648  case arith::AtomicRMWKind::maxs:
649  return builder.create<vector::ReductionOp>(vector.getLoc(),
650  CombiningKind::MAXSI, vector);
651  case arith::AtomicRMWKind::maxu:
652  return builder.create<vector::ReductionOp>(vector.getLoc(),
653  CombiningKind::MAXUI, vector);
654  case arith::AtomicRMWKind::andi:
655  return builder.create<vector::ReductionOp>(vector.getLoc(),
656  CombiningKind::AND, vector);
657  case arith::AtomicRMWKind::ori:
658  return builder.create<vector::ReductionOp>(vector.getLoc(),
659  CombiningKind::OR, vector);
660  // TODO: Add remaining reduction operations.
661  default:
662  (void)emitOptionalError(loc, "Reduction operation type not supported");
663  break;
664  }
665  return nullptr;
666 }
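// Illustrative sketch (not part of the original source): typical use of
// getVectorReductionOp, e.g. when lowering a vectorized loop reduction.
// Variable names are hypothetical.
//
//   Value red = mlir::vector::getVectorReductionOp(
//       arith::AtomicRMWKind::addf, builder, loc, vectorOperand);
//   // `red` is the result of a `vector.reduction <add>` over vectorOperand,
//   // or nullptr if the AtomicRMWKind is not supported.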
667 
668 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
669  return llvm::to_vector<4>(getSourceVectorType().getShape());
670 }
671 
672 namespace {
673 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
674  using OpRewritePattern::OpRewritePattern;
675 
676  LogicalResult matchAndRewrite(ReductionOp reductionOp,
677  PatternRewriter &rewriter) const override {
678  // Vector mask setup.
679  OpBuilder::InsertionGuard guard(rewriter);
680  auto maskableOp =
681  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
682  Operation *rootOp;
683  Value mask;
684  if (maskableOp.isMasked()) {
685  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
686  rootOp = maskableOp.getMaskingOp();
687  mask = maskableOp.getMaskingOp().getMask();
688  } else {
689  rootOp = reductionOp;
690  }
691 
692  auto vectorType = reductionOp.getSourceVectorType();
693  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
694  return failure();
695 
696  Location loc = reductionOp.getLoc();
697  Value result;
698  if (vectorType.getRank() == 0) {
699  if (mask)
700  mask = rewriter.create<ExtractElementOp>(loc, mask);
701  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
702  } else {
703  if (mask)
704  mask = rewriter.create<ExtractOp>(loc, mask, 0);
705  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
706  }
707 
708  if (Value acc = reductionOp.getAcc())
709  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
710  result, acc,
711  reductionOp.getFastmathAttr(), mask);
712 
713  rewriter.replaceOp(rootOp, result);
714  return success();
715  }
716 };
717 } // namespace
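// Illustrative sketch (not part of the original source): the rewrite performed
// by ElideSingleElementReduction for a single-element source vector, with
// hypothetical SSA names.
//
//   %r = vector.reduction <add>, %src, %acc : vector<1xf32> into f32
// becomes
//   %e = vector.extract %src[0] : f32 from vector<1xf32>
//   %r = arith.addf %e, %acc : f32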
718 
719 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
720  MLIRContext *context) {
721  results.add<ElideSingleElementReduction>(context);
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // ContractionOp
726 //===----------------------------------------------------------------------===//
727 
728 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
729  Value lhs, Value rhs, Value acc,
730  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
731  ArrayRef<IteratorType> iteratorTypes) {
732  result.addOperands({lhs, rhs, acc});
733  result.addTypes(acc.getType());
734  result.addAttribute(
735  getIndexingMapsAttrName(result.name),
736  builder.getAffineMapArrayAttr(
737  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
738  result.addAttribute(
739  getIteratorTypesAttrName(result.name),
740  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
741  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
742  return IteratorTypeAttr::get(builder.getContext(), t);
743  }))));
744 }
745 
746 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
747  Value lhs, Value rhs, Value acc,
748  ArrayAttr indexingMaps,
749  ArrayAttr iteratorTypes) {
750  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
751  ContractionOp::getDefaultKind());
752 }
753 
754 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
755  Value lhs, Value rhs, Value acc,
756  ArrayAttr indexingMaps,
757  ArrayAttr iteratorTypes, CombiningKind kind) {
758  result.addOperands({lhs, rhs, acc});
759  result.addTypes(acc.getType());
760  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
761  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
762  result.addAttribute(getKindAttrName(result.name),
763  CombiningKindAttr::get(builder.getContext(), kind));
764 }
765 
766 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
767  OpAsmParser::UnresolvedOperand lhsInfo;
768  OpAsmParser::UnresolvedOperand rhsInfo;
769  OpAsmParser::UnresolvedOperand accInfo;
770  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
771  SmallVector<Type, 2> types;
772  Type resultType;
773  auto loc = parser.getCurrentLocation();
774  DictionaryAttr dictAttr;
775  // TODO: Unify linalg op attribute parsing.
776  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
777  parser.parseComma() || parser.parseOperand(rhsInfo) ||
778  parser.parseComma() || parser.parseOperand(accInfo) ||
779  parser.parseTrailingOperandList(masksInfo) ||
780  parser.parseOptionalAttrDict(result.attributes) ||
781  parser.parseColonTypeList(types) ||
782  parser.parseKeywordType("into", resultType) ||
783  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
784  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
785  parser.resolveOperand(accInfo, resultType, result.operands) ||
786  parser.addTypeToList(resultType, result.types))
787  return failure();
788  result.attributes.append(dictAttr.getValue().begin(),
789  dictAttr.getValue().end());
790 
791  // Convert the array of strings into an array of IteratorType enums. This is
792  // needed because tests still use the old format in which the 'iterator_types'
793  // attribute is represented as an array of strings.
794  // TODO: Remove this conversion once tests are fixed.
795  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
796  result.attributes.get(getIteratorTypesAttrName(result.name)));
797 
798  SmallVector<Attribute> iteratorTypeAttrs;
799 
800  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
801  auto maybeIteratorType = symbolizeIteratorType(s);
802  if (!maybeIteratorType.has_value())
803  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
804 
805  iteratorTypeAttrs.push_back(
806  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
807  }
808  result.attributes.set(getIteratorTypesAttrName(result.name),
809  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
810 
811  if (!result.attributes.get(getKindAttrName(result.name))) {
812  result.addAttribute(
813  getKindAttrName(result.name),
814  CombiningKindAttr::get(result.getContext(),
815  ContractionOp::getDefaultKind()));
816  }
817  if (masksInfo.empty())
818  return success();
819  if (masksInfo.size() != 2)
820  return parser.emitError(parser.getNameLoc(),
821  "expected zero or exactly 2 vector mask operands");
822  auto lhsType = llvm::cast<VectorType>(types[0]);
823  auto rhsType = llvm::cast<VectorType>(types[1]);
824  auto maskElementType = parser.getBuilder().getI1Type();
825  std::array<VectorType, 2> maskTypes = {
826  VectorType::Builder(lhsType).setElementType(maskElementType),
827  VectorType::Builder(rhsType).setElementType(maskElementType)};
828  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
829  return failure();
830  return success();
831 }
832 
833 void ContractionOp::print(OpAsmPrinter &p) {
834  // TODO: Unify printing code with linalg ops.
835  auto attrNames = getTraitAttrNames();
836  llvm::StringSet<> traitAttrsSet;
837  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
838  SmallVector<NamedAttribute> attrs;
839  for (auto attr : (*this)->getAttrs()) {
840  if (attr.getName() == getIteratorTypesAttrName()) {
841  auto iteratorTypes =
842  llvm::cast<ArrayAttr>(attr.getValue())
843  .getAsValueRange<IteratorTypeAttr, IteratorType>();
844  // Convert IteratorType enums into the string representation. This is
845  // needed because tests still use the old format in which the 'iterator_types'
846  // attribute is represented as an array of strings.
847  // TODO: Remove this conversion once tests are fixed.
848  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
849  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
850  return StringAttr::get(getContext(), stringifyIteratorType(t));
851  }));
852 
853  attrs.emplace_back(getIteratorTypesAttrName(),
854  ArrayAttr::get(getContext(), iteratorTypeNames));
855  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
856  attrs.push_back(attr);
857  }
858 
859  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
860  p << " " << dictAttr << " " << getLhs() << ", ";
861  p << getRhs() << ", " << getAcc();
862 
863  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
864  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
865  << getResultType();
866 }
867 
868 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
869  const std::vector<std::pair<int64_t, int64_t>> &map) {
870  for (auto &dimPair : map) {
871  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
872  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
873  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
874  return false;
875  }
876  return true;
877 }
878 
879 static LogicalResult verifyOutputShape(
880  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
881  Type resType,
882  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
883  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
884  DenseSet<int64_t> lhsContractingDimSet;
885  DenseSet<int64_t> rhsContractingDimSet;
886  for (auto &dimPair : contractingDimMap) {
887  lhsContractingDimSet.insert(dimPair.first);
888  rhsContractingDimSet.insert(dimPair.second);
889  }
890  DenseSet<int64_t> rhsBatchDimSet;
891  for (auto &dimPair : batchDimMap)
892  rhsBatchDimSet.insert(dimPair.second);
893 
894  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
895  SmallVector<int64_t, 4> expectedResultDims;
896  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
897  if (lhsContractingDimSet.count(i) > 0)
898  continue;
899  expectedResultDims.push_back(lhsType.getDimSize(i));
900  }
901 
902  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
903  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
904  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
905  continue;
906  expectedResultDims.push_back(rhsType.getDimSize(i));
907  }
908 
909  // Verify 'expectedResultDims'.
910  if (expectedResultDims.empty()) {
911  // No batch or free dimension implies a scalar result.
912  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
913  return op.emitOpError("invalid accumulator/result vector shape");
914  } else {
915  // At least one batch or free dimension implies a vector result.
916  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
917  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
918  if (!resVectorType || !accVectorType)
919  return op.emitOpError("invalid accumulator/result vector shape");
920 
921  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
922  // types fully define the result vector type. This assumes the affine maps
923  // are well-formed, which must have been verified already.
924  MLIRContext *ctx = op.getContext();
925  AffineMap lhsMap = op.getIndexingMapsArray()[0];
926  AffineMap rhsMap = op.getIndexingMapsArray()[1];
927  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
928  return op.emitOpError(
929  "expected all dimensions to be either a LHS or a RHS dimension");
930  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
931  for (auto pair :
932  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
933  VectorType v = pair.first;
934  auto map = pair.second;
935  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
936  unsigned pos = map.getDimPosition(idx);
937  if (!extents[pos])
938  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
939  }
940  }
941  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
942  return op.emitOpError("expected all dimensions to get an extent as "
943  "either a LHS or a RHS dimension");
944 
945  AffineMap resMap = op.getIndexingMapsArray()[2];
946  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
947  /*symbolCount=*/0, extents, ctx);
948  // Compose the resMap with the extentsMap, which is a constant map.
949  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
950  assert(llvm::all_of(expectedMap.getResults(),
951  llvm::IsaPred<AffineConstantExpr>) &&
952  "expected constant extent along all dimensions.");
953  // Extract the expected shape and build the type.
954  auto expectedShape = llvm::to_vector<4>(
955  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
956  return cast<AffineConstantExpr>(e).getValue();
957  }));
958  auto expected =
959  VectorType::get(expectedShape, resVectorType.getElementType(),
960  resVectorType.getScalableDims());
961  if (resVectorType != expected || accVectorType != expected)
962  return op.emitOpError(
963  "invalid accumulator/result vector shape, expected: ")
964  << expected;
965  }
966  return success();
967 }
968 
969 LogicalResult ContractionOp::verify() {
970  VectorType lhsType = getLhsType();
971  VectorType rhsType = getRhsType();
972  Type accType = getAccType();
973  Type resType = getResultType();
974 
975  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
976  if (!lhsType.getElementType().isSignlessInteger())
977  return emitOpError("only supports signless integer types");
978  }
979 
980  // Verify that an indexing map was specified for each vector operand.
981  if (getIndexingMapsArray().size() != 3)
982  return emitOpError("expected an indexing map for each vector operand");
983 
984  // Verify that each index map has 'numIterators' inputs, no symbols, and
985  // that the number of map outputs equals the rank of its associated
986  // vector operand.
987  unsigned numIterators = getIteratorTypes().getValue().size();
988  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
989  auto index = it.index();
990  auto map = it.value();
991  if (map.getNumSymbols() != 0)
992  return emitOpError("expected indexing map ")
993  << index << " to have no symbols";
994  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
995  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
996  // Verify that the map has the right number of inputs, outputs, and indices.
997  // This also correctly accounts for (..) -> () for rank-0 results.
998  if (map.getNumDims() != numIterators)
999  return emitOpError("expected indexing map ")
1000  << index << " to have " << numIterators << " number of inputs";
1001  if (map.getNumResults() != rank)
1002  return emitOpError("expected indexing map ")
1003  << index << " to have " << rank << " number of outputs";
1004  if (!map.isProjectedPermutation())
1005  return emitOpError("expected indexing map ")
1006  << index << " to be a projected permutation of its inputs";
1007  }
1008 
1009  auto contractingDimMap = getContractingDimMap();
1010  auto batchDimMap = getBatchDimMap();
1011 
1012  // Verify at least one contracting dimension pair was specified.
1013  if (contractingDimMap.empty())
1014  return emitOpError("expected at least one contracting dimension pair");
1015 
1016  // Verify contracting dimension map was properly constructed.
1017  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1018  return emitOpError("invalid contracting dimension map");
1019 
1020  // Verify batch dimension map was properly constructed.
1021  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1022  return emitOpError("invalid batch dimension map");
1023 
1024  // Verify 'accType' and 'resType' shape.
1025  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1026  contractingDimMap, batchDimMap)))
1027  return failure();
1028 
1029  // Verify supported combining kind.
1030  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1031  auto elementType = vectorType ? vectorType.getElementType() : resType;
1032  if (!isSupportedCombiningKind(getKind(), elementType))
1033  return emitOpError("unsupported contraction type");
1034 
1035  return success();
1036 }
1037 
1038 // MaskableOpInterface methods.
1039 
1040 /// Returns the mask type expected by this operation. Mostly used for
1041 /// verification purposes. It requires the operation to be vectorized.
1042 Type ContractionOp::getExpectedMaskType() {
1043  auto indexingMaps = this->getIndexingMapsArray();
1044  AffineMap lhsIdxMap = indexingMaps[0];
1045  AffineMap rhsIdxMap = indexingMaps[1];
1046  VectorType lhsType = this->getLhsType();
1047  VectorType rhsType = this->getRhsType();
1048 
1049  unsigned numVecDims = lhsIdxMap.getNumDims();
1050  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1051  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1052 
1053  // Using the information in the indexing maps, extract the size of each
1054  // dimension in the vector.contract operation from the two input operands.
1055  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1056  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1057  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1058  lhsType.getScalableDims()[dimIdx];
1059  }
1060  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1061  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1062  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1063  rhsType.getScalableDims()[dimIdx];
1064  }
1065 
1066  assert(!ShapedType::isDynamicShape(maskShape) &&
1067  "Mask shape couldn't be computed");
1068 
1069  return VectorType::get(maskShape,
1070  IntegerType::get(lhsType.getContext(), /*width=*/1),
1071  maskShapeScalableDims);
1072 }
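// Illustrative sketch (not part of the original source): the expected mask
// type for a plain matmul contraction. With indexing maps
//   lhs: (d0, d1, d2) -> (d0, d2), rhs: (d0, d1, d2) -> (d2, d1),
//   acc: (d0, d1, d2) -> (d0, d1)
// and lhs: vector<4x3xf32>, rhs: vector<3x5xf32>, the per-dimension sizes are
// d0 = 4, d1 = 5, d2 = 3, so the expected mask type is vector<4x5x3xi1>.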
1073 
1074 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1075  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1076  getIteratorTypesAttrName(), getKindAttrName()};
1077 }
1078 
1079 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1080  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1081  if (targetExpr == map.getResult(i))
1082  return i;
1083  return -1;
1084 }
1085 
1086 static std::vector<std::pair<int64_t, int64_t>>
1087 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1088  IteratorType targetIteratorType, MLIRContext *context) {
1089  std::vector<std::pair<int64_t, int64_t>> dimMap;
1090  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1091  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1092  if (iteratorType != targetIteratorType)
1093  continue;
1094  // Search lhs/rhs map results for 'targetExpr'.
1095  auto targetExpr = getAffineDimExpr(it.index(), context);
1096  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1097  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1098  if (lhsDim >= 0 && rhsDim >= 0)
1099  dimMap.emplace_back(lhsDim, rhsDim);
1100  }
1101  return dimMap;
1102 }
1103 
1104 void ContractionOp::getIterationBounds(
1105  SmallVectorImpl<int64_t> &iterationBounds) {
1106  auto lhsShape = getLhsType().getShape();
1107  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1108  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1109  SmallVector<int64_t, 2> iterationShape;
1110  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1111  // Search lhs/rhs map results for 'targetExpr'.
1112  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1113  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1114  if (iteratorType == IteratorType::reduction) {
1115  // Get reduction dim size from lhs shape (same size in rhsShape).
1116  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1117  assert(lhsDimIndex >= 0);
1118  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1119  continue;
1120  }
1121  // Get parallel dimension size from result shape.
1122  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1123  assert(resDimIndex >= 0);
1124  assert(resVectorType != nullptr);
1125  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1126  }
1127 }
1128 
1129 void ContractionOp::getIterationIndexMap(
1130  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1131  unsigned numMaps = getIndexingMapsArray().size();
1132  iterationIndexMap.resize(numMaps);
1133  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1134  auto index = it.index();
1135  auto map = it.value();
1136  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1137  auto dim = cast<AffineDimExpr>(map.getResult(i));
1138  iterationIndexMap[index][dim.getPosition()] = i;
1139  }
1140  }
1141 }
1142 
1143 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1144  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1145  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1146  getContext());
1147 }
1148 
1149 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1150  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1151  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1152  getContext());
1153 }
1154 
1155 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1156  SmallVector<int64_t, 4> shape;
1157  getIterationBounds(shape);
1158  return shape;
1159 }
1160 
1161 /// Return a fused vector::ContractionOp which represents a pattern such as:
1162 ///
1163 /// ```mlir
1164 /// %c0 = arith.constant 0: ...
1165 /// %c = vector.contract %a, %b, %c0: ...
1166 /// %e = add %c, %d: ...
1167 /// ```
1168 ///
1169 /// by:
1170 ///
1171 /// ```mlir
1172 /// %e = vector.contract %a, %b, %d: ...
1173 /// ```
1174 ///
1175 /// Return null if the canonicalization does not apply.
1176 // TODO: This should be a folding of Add into Contract in core but while they
1177 // live in different dialects, it is not possible without unnatural
1178 // dependencies.
1179 template <typename AddOpType>
1180 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1181  using OpRewritePattern<AddOpType>::OpRewritePattern;
1182 
1183  LogicalResult matchAndRewrite(AddOpType addOp,
1184  PatternRewriter &rewriter) const override {
1185  auto canonicalize = [&](Value maybeContraction,
1186  Value otherOperand) -> vector::ContractionOp {
1187  vector::ContractionOp contractionOp =
1188  dyn_cast_or_null<vector::ContractionOp>(
1189  maybeContraction.getDefiningOp());
1190  if (!contractionOp)
1191  return vector::ContractionOp();
1192  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1193  contractionOp.getAcc().getDefiningOp())) {
1194  if (maybeZero.getValue() ==
1195  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1196  IRMapping bvm;
1197  bvm.map(contractionOp.getAcc(), otherOperand);
1198  auto newContraction =
1199  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1200  rewriter.replaceOp(addOp, newContraction.getResult());
1201  return newContraction;
1202  }
1203  }
1204  return vector::ContractionOp();
1205  };
1206 
1207  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1208  vector::ContractionOp contract = canonicalize(a, b);
1209  contract = contract ? contract : canonicalize(b, a);
1210  return contract ? success() : failure();
1211  }
1212 };
1213 
1214 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215  MLIRContext *context) {
1216  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1217  CanonicalizeContractAdd<arith::AddFOp>>(context);
1218 }
1219 
1220 //===----------------------------------------------------------------------===//
1221 // ExtractElementOp
1222 //===----------------------------------------------------------------------===//
1223 
1224 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1225  SetIntRangeFn setResultRanges) {
1226  setResultRanges(getResult(), argRanges.front());
1227 }
1228 
1229 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1230  Value source) {
1231  result.addOperands({source});
1232  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1233 }
1234 
1235 LogicalResult vector::ExtractElementOp::verify() {
1236  VectorType vectorType = getSourceVectorType();
1237  if (vectorType.getRank() == 0) {
1238  if (getPosition())
1239  return emitOpError("expected position to be empty with 0-D vector");
1240  return success();
1241  }
1242  if (vectorType.getRank() != 1)
1243  return emitOpError("unexpected >1 vector rank");
1244  if (!getPosition())
1245  return emitOpError("expected position for 1-D vector");
1246  return success();
1247 }
1248 
1249 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1250  // Skip the 0-D vector case for now.
1251  if (!adaptor.getPosition())
1252  return {};
1253 
1254  // Fold extractelement (splat X) -> X.
1255  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1256  return splat.getInput();
1257 
1258  // Fold extractelement(broadcast(X)) -> X.
1259  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1260  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1261  return broadcast.getSource();
1262 
1263  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1264  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1265  if (!pos || !src)
1266  return {};
1267 
1268  auto srcElements = src.getValues<Attribute>();
1269 
1270  uint64_t posIdx = pos.getInt();
1271  if (posIdx >= srcElements.size())
1272  return {};
1273 
1274  return srcElements[posIdx];
1275 }
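// Illustrative sketch (not part of the original source): folds performed by
// ExtractElementOp::fold, with hypothetical SSA names. The position must fold
// to a constant for the fold to apply.
//
//   %c3 = arith.constant 3 : index
//   %s  = vector.splat %x : vector<8xf32>
//   %e0 = vector.extractelement %s[%c3 : index] : vector<8xf32>   // -> %x
//
//   %b  = vector.broadcast %y : f32 to vector<8xf32>
//   %e1 = vector.extractelement %b[%c3 : index] : vector<8xf32>   // -> %y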
1276 
1277 //===----------------------------------------------------------------------===//
1278 // ExtractOp
1279 //===----------------------------------------------------------------------===//
1280 
1281 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1282  SetIntRangeFn setResultRanges) {
1283  setResultRanges(getResult(), argRanges.front());
1284 }
1285 
1286 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1287  Value source, int64_t position) {
1288  build(builder, result, source, ArrayRef<int64_t>{position});
1289 }
1290 
1291 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1292  Value source, OpFoldResult position) {
1293  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1294 }
1295 
1296 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1297  Value source, ArrayRef<int64_t> position) {
1298  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1299  builder.getDenseI64ArrayAttr(position));
1300 }
1301 
1302 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1303  Value source, ArrayRef<OpFoldResult> position) {
1304  SmallVector<int64_t> staticPos;
1305  SmallVector<Value> dynamicPos;
1306  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1307  build(builder, result, source, dynamicPos,
1308  builder.getDenseI64ArrayAttr(staticPos));
1309 }
1310 
1311 LogicalResult
1312 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1313  ExtractOp::Adaptor adaptor,
1314  SmallVectorImpl<Type> &inferredReturnTypes) {
1315  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1316  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1317  vectorType.getRank()) {
1318  inferredReturnTypes.push_back(vectorType.getElementType());
1319  } else {
1320  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1321  vectorType.getRank());
1322  inferredReturnTypes.push_back(VectorType::get(
1323  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1324  vectorType.getScalableDims().drop_front(n)));
1325  }
1326  return success();
1327 }
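// Illustrative sketch (not part of the original source): result types inferred
// for vector.extract from a vector<2x3x4xf32> source.
//
//   vector.extract %v[0]        // 1 index  < rank 3 -> vector<3x4xf32>
//   vector.extract %v[0, 1]     // 2 indices < rank 3 -> vector<4xf32>
//   vector.extract %v[0, 1, 2]  // 3 indices == rank 3 -> f32 (element type)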
1328 
1329 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1330  // Allow extracting 1-element vectors instead of scalars.
1331  auto isCompatible = [](TypeRange l, TypeRange r) {
1332  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1333  return vectorType && vectorType.getShape().equals({1}) &&
1334  vectorType.getElementType() == r.front();
1335  };
1336  if (l.size() == 1 && r.size() == 1 &&
1337  (isCompatible(l, r) || isCompatible(r, l)))
1338  return true;
1339  return l == r;
1340 }
1341 
1342 LogicalResult vector::ExtractOp::verify() {
1343  // Note: This check must come before getMixedPosition() to prevent a crash.
1344  auto dynamicMarkersCount =
1345  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1346  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1347  return emitOpError(
1348  "mismatch between dynamic and static positions (kDynamic marker but no "
1349  "corresponding dynamic position) -- this can only happen due to an "
1350  "incorrect fold/rewrite");
1351  auto position = getMixedPosition();
1352  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1353  return emitOpError(
1354  "expected position attribute of rank no greater than vector rank");
1355  for (auto [idx, pos] : llvm::enumerate(position)) {
1356  if (auto attr = dyn_cast<Attribute>(pos)) {
1357  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1358  if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1359  return emitOpError("expected position attribute #")
1360  << (idx + 1)
1361  << " to be a non-negative integer smaller than the "
1362  "corresponding vector dimension";
1363  }
1364  }
1365  }
1366  return success();
1367 }
1368 
1369 template <typename IntType>
1370 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1371  return llvm::to_vector<4>(llvm::map_range(
1372  arrayAttr.getAsRange<IntegerAttr>(),
1373  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1374 }
1375 
1376 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1377 /// positions.
1378 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1379  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1380  return failure();
1381 
1382  // TODO: Canonicalization for dynamic position not implemented yet.
1383  if (extractOp.hasDynamicPosition())
1384  return failure();
1385 
1386  SmallVector<int64_t> globalPosition;
1387  ExtractOp currentOp = extractOp;
1388  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1389  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1390  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1391  currentOp = nextOp;
1392  // TODO: Canonicalization for dynamic position not implemented yet.
1393  if (currentOp.hasDynamicPosition())
1394  return failure();
1395  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1396  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1397  }
1398  extractOp.setOperand(0, currentOp.getVector());
1399  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1400  OpBuilder b(extractOp.getContext());
1401  std::reverse(globalPosition.begin(), globalPosition.end());
1402  extractOp.setStaticPosition(globalPosition);
1403  return success();
1404 }
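// Illustrative sketch (not part of the original source): the in-place fold of
// an extract-of-extract chain by concatenating static positions, with
// hypothetical SSA names.
//
//   %a = vector.extract %src[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[2]   : vector<4xf32> from vector<3x4xf32>
// becomes
//   %b = vector.extract %src[1, 2] : vector<4xf32> from vector<2x3x4xf32>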
1405 
1406 namespace {
1407 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1408 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1409 /// Compose TransposeOp permutations as we walk back.
1410 /// This helper class keeps an updated extraction position `extractPosition`
1411 /// with extra trailing sentinels.
1412 /// The sentinels encode the internal transposition status of the result vector.
1413 /// As we iterate, extractPosition is permuted and updated.
1414 class ExtractFromInsertTransposeChainState {
1415 public:
1416  ExtractFromInsertTransposeChainState(ExtractOp e);
1417 
1418  /// Iterate over producing insert and transpose ops until we find a fold.
1419  Value fold();
1420 
1421 private:
1422  /// Return true if the vector at position `a` is contained within the vector
1423  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1424  /// is a prefix of `b`.
1425  template <typename ContainerA, typename ContainerB>
1426  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1427  return a.size() <= b.size() &&
1428  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1429  }
1430 
1431  /// Return true if the vector at position `a` intersects the vector at
1432  /// position `b`. Under insert/extract semantics, this is the same as equality
1433  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1434  /// Comparison is on the common prefix (i.e. zip).
1435  template <typename ContainerA, typename ContainerB>
1436  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1437  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1438  if (elemA < 0 || elemB < 0)
1439  continue;
1440  if (elemA != elemB)
1441  return false;
1442  }
1443  return true;
1444  }
1445 
1446  /// Folding is only possible in the absence of an internal permutation in the
1447  /// result vector.
1448  bool canFold() {
1449  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1450  }
1451 
1452  // Helper to get the next defining op of interest.
1453  void updateStateForNextIteration(Value v) {
1454  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1455  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1456  };
1457 
1458  // Case 1. If we hit a transpose, just compose the map and iterate.
1459  // Invariant: insert + transpose do not change rank, we can always compose.
1460  LogicalResult handleTransposeOp();
1461 
1462  // Case 2: the insert position matches extractPosition exactly, early return.
1463  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1464 
1465  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1466  /// portion of the source of the insert.
1467  /// Example:
1468  /// ```
1469  /// %ins = vector.insert %source, %dest[1] : vector<3x4x5xf32> into vector<2x3x4x5xf32>
1470  /// // extractPosition == [1, 2, 3]
1471  /// %ext = vector.extract %ins[1, 2, 3] : vector<5xf32> from vector<2x3x4x5xf32>
1472  /// // can fold to
1473  /// %ext = vector.extract %source[2, 3] : vector<5xf32> from vector<3x4x5xf32>
1474  /// ```
1475  /// To traverse through %source, we need to set the leading dims to 0 and
1476  /// drop the extra leading dims.
1477  /// This method updates the internal state.
1478  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1479 
1480  /// Try to fold in place to extract(source, extractPosition) and return the
1481  /// folded result. Return null if folding is not possible (e.g. due to an
1482  /// internal transposition in the result).
1483  Value tryToFoldExtractOpInPlace(Value source);
1484 
1485  ExtractOp extractOp;
1486  int64_t vectorRank;
1487  int64_t extractedRank;
1488 
1489  InsertOp nextInsertOp;
1490  TransposeOp nextTransposeOp;
1491 
1492  /// Sentinel values that encode the internal permutation status of the result.
1493  /// They are set to (-1, ... , -k) at the beginning and appended to
1494  /// `extractPosition`.
1495  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1496  /// ensure that there is no internal transposition.
1497  /// Internal transposition cannot be accounted for with a folding pattern.
1498  // TODO: We could relax the internal transposition with an extra transposition
1499  // operation in a future canonicalizer.
1500  SmallVector<int64_t> sentinels;
1501  SmallVector<int64_t> extractPosition;
1502 };
1503 } // namespace
1504 
1505 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1506  ExtractOp e)
1507  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1508  extractedRank(extractOp.getNumIndices()) {
1509  assert(vectorRank >= extractedRank && "Extracted position overflow");
1510  sentinels.reserve(vectorRank - extractedRank);
1511  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1512  sentinels.push_back(-(i + 1));
1513  extractPosition.assign(extractOp.getStaticPosition().begin(),
1514  extractOp.getStaticPosition().end());
1515  llvm::append_range(extractPosition, sentinels);
1516 }
1517 
1518 // Case 1. If we hit a transpose, just compose the map and iterate.
1519 // Invariant: insert + transpose do not change rank, we can always compose.
1520 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1521  // TODO: Canonicalization for dynamic position not implemented yet.
1522  if (extractOp.hasDynamicPosition())
1523  return failure();
1524 
1525  if (!nextTransposeOp)
1526  return failure();
1527  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1528  nextTransposeOp.getPermutation(), extractOp.getContext()));
1529  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1530  return success();
1531 }
1532 
1533 // Case 2: the insert position matches extractPosition exactly, early return.
1534 LogicalResult
1535 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1536  Value &res) {
1537  // TODO: Canonicalization for dynamic position not implemented yet.
1538  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1539  return failure();
1540 
1541  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1542  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1543  return failure();
1544  // Case 2.a. early-exit fold.
1545  res = nextInsertOp.getSource();
1546  // Case 2.b. if internal transposition is present, canFold will be false.
1547  return success(canFold());
1548 }
1549 
1550 /// Case 3: if the inserted position is a prefix of extractPosition,
1551 /// extract a portion of the source of the insertion.
1552 /// This method updates the internal state.
1553 LogicalResult
1554 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1555  // TODO: Canonicalization for dynamic position not implemented yet.
1556  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1557  return failure();
1558 
1559  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1560  if (!isContainedWithin(insertedPos, extractPosition))
1561  return failure();
1562  // Set leading dims to zero.
1563  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1564  // Drop extra leading dims.
1565  extractPosition.erase(extractPosition.begin(),
1566  extractPosition.begin() + insertedPos.size());
1567  extractedRank = extractPosition.size() - sentinels.size();
1568  // Case 3.a. early-exit fold (break and delegate to post-while path).
1569  res = nextInsertOp.getSource();
1570  // Case 3.b. if internal transposition is present, canFold will be false.
1571  return success();
1572 }
1573 
1574 /// Try to fold in place to extract(source, extractPosition) and return the
1575 /// folded result. Return null if folding is not possible (e.g. due to an
1576 /// internal transposition in the result).
1577 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1578  Value source) {
1579  // TODO: Canonicalization for dynamic position not implemented yet.
1580  if (extractOp.hasDynamicPosition())
1581  return Value();
1582 
1583  // If we can't fold (either internal transposition, or nothing to fold), bail.
1584  bool nothingToFold = (source == extractOp.getVector());
1585  if (nothingToFold || !canFold())
1586  return Value();
1587 
1588  // Otherwise, fold by updating the op inplace and return its result.
1589  OpBuilder b(extractOp.getContext());
1590  extractOp.setStaticPosition(
1591  ArrayRef(extractPosition).take_front(extractedRank));
1592  extractOp.getVectorMutable().assign(source);
1593  return extractOp.getResult();
1594 }
1595 
1596 /// Iterate over producing insert and transpose ops until we find a fold.
1597 Value ExtractFromInsertTransposeChainState::fold() {
1598  // TODO: Canonicalization for dynamic position not implemented yet.
1599  if (extractOp.hasDynamicPosition())
1600  return Value();
1601 
1602  Value valueToExtractFrom = extractOp.getVector();
1603  updateStateForNextIteration(valueToExtractFrom);
1604  while (nextInsertOp || nextTransposeOp) {
1605  // Case 1. If we hit a transpose, just compose the map and iterate.
1606  // Invariant: insert + transpose do not change rank, we can always compose.
1607  if (succeeded(handleTransposeOp())) {
1608  valueToExtractFrom = nextTransposeOp.getVector();
1609  updateStateForNextIteration(valueToExtractFrom);
1610  continue;
1611  }
1612 
1613  Value result;
1614  // Case 2: the positions match exactly.
1615  if (succeeded(handleInsertOpWithMatchingPos(result)))
1616  return result;
1617 
1618  // Case 3: if the inserted position is a prefix of extractPosition, we can
1619  // just extract a portion of the source of the insert.
1620  if (succeeded(handleInsertOpWithPrefixPos(result)))
1621  return tryToFoldExtractOpInPlace(result);
1622 
1623  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1624  // values. This is a more difficult case and we bail.
1625  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1626  if (isContainedWithin(extractPosition, insertedPos) ||
1627  intersectsWhereNonNegative(extractPosition, insertedPos))
1628  return Value();
1629 
1630  // Case 5: No intersection, we forward the extract to insertOp.dest().
1631  valueToExtractFrom = nextInsertOp.getDest();
1632  updateStateForNextIteration(valueToExtractFrom);
1633  }
1634  // If after all this we can fold, go for it.
1635  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1636 }
1637 
1638 /// Returns true if the operation has a 0-D vector type operand or result.
1639 static bool hasZeroDimVectors(Operation *op) {
1640  auto hasZeroDimVectorType = [](Type type) -> bool {
1641  auto vecType = dyn_cast<VectorType>(type);
1642  return vecType && vecType.getRank() == 0;
1643  };
1644 
1645  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1646  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1647 }
1648 
1649 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
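/// A minimal sketch (names and shapes invented for illustration):
///   %b = vector.broadcast %s : f32 to vector<4x8xf32>
///   %e = vector.extract %b[1, 3] : f32 from vector<4x8xf32>
/// folds to %s, since every element of %b is the broadcast scalar.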
1650 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1651  // TODO: Canonicalization for dynamic position not implemented yet.
1652  if (extractOp.hasDynamicPosition())
1653  return Value();
1654 
1655  Operation *defOp = extractOp.getVector().getDefiningOp();
1656  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1657  return Value();
1658 
1659  Value source = defOp->getOperand(0);
1660  if (extractOp.getType() == source.getType())
1661  return source;
1662  auto getRank = [](Type type) {
1663  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1664  : 0;
1665  };
1666 
1667  // If splat or broadcast from a scalar, just return the source scalar.
1668  unsigned broadcastSrcRank = getRank(source.getType());
1669  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1670  return source;
1671 
1672  unsigned extractResultRank = getRank(extractOp.getType());
1673  if (extractResultRank >= broadcastSrcRank)
1674  return Value();
1675  // Check that the dimensions of the result haven't been broadcasted.
1676  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1677  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1678  if (extractVecType && broadcastVecType &&
1679  extractVecType.getShape() !=
1680  broadcastVecType.getShape().take_back(extractResultRank))
1681  return Value();
1682 
1683  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1684  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1685 
1686  // Detect all the positions that come from "dim-1" broadcasting.
1687  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1688  // extract position to `0` when extracting from the source operand.
1689  llvm::SetVector<int64_t> broadcastedUnitDims =
1690  broadcastOp.computeBroadcastedUnitDims();
1691  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1692  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1693  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1694  if (broadcastedUnitDims.contains(i))
1695  extractPos[i] = 0;
1696  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1697  // matching extract position when extracting from the source operand.
1698  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1699  extractPos.erase(extractPos.begin(),
1700  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1701  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1702  OpBuilder b(extractOp.getContext());
1703  extractOp.setOperand(0, source);
1704  extractOp.setStaticPosition(extractPos);
1705  return extractOp.getResult();
1706 }
1707 
1708 /// Fold extractOp coming from ShuffleOp.
1709 ///
1710 /// Example:
1711 ///
1712 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1713 /// : vector<8xf32>, vector<8xf32>
1714 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1715 /// ->
1716 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1717 ///
1718 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1719  // Dynamic positions are not folded as the resulting code would be more
1720  // complex than the input code.
1721  if (extractOp.hasDynamicPosition())
1722  return Value();
1723 
1724  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1725  if (!shuffleOp)
1726  return Value();
1727 
1728  // TODO: 0-D or multi-dimensional vectors not supported yet.
1729  if (shuffleOp.getResultVectorType().getRank() != 1)
1730  return Value();
1731 
1732  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1733  auto shuffleMask = shuffleOp.getMask();
1734  int64_t extractIdx = extractOp.getStaticPosition()[0];
1735  int64_t shuffleIdx = shuffleMask[extractIdx];
1736 
1737  // Find the shuffled vector to extract from based on the shuffle index.
1738  if (shuffleIdx < inputVecSize) {
1739  extractOp.setOperand(0, shuffleOp.getV1());
1740  extractOp.setStaticPosition({shuffleIdx});
1741  } else {
1742  extractOp.setOperand(0, shuffleOp.getV2());
1743  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1744  }
1745 
1746  return extractOp.getResult();
1747 }
1748 
1749 // Fold extractOp with source coming from ShapeCast op.
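// An illustrative sketch of the linearize/delinearize logic below (shapes are
// invented): given
//   %sc = vector.shape_cast %src : vector<3x4xf32> to vector<2x6xf32>
//   %e  = vector.extract %sc[1, 2] : f32 from vector<2x6xf32>
// the extract position linearizes to 1 * 6 + 2 = 8, which delinearizes over
// the 3x4 source as [2, 0], so the op is rewritten in place as
//   %e = vector.extract %src[2, 0] : f32 from vector<3x4xf32>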
1750 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1751  // TODO: Canonicalization for dynamic position not implemented yet.
1752  if (extractOp.hasDynamicPosition())
1753  return Value();
1754 
1755  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1756  if (!shapeCastOp)
1757  return Value();
1758 
1759  // Get the nth dimension size starting from lowest dimension.
1760  auto getDimReverse = [](VectorType type, int64_t n) {
1761  return type.getShape().take_back(n + 1).front();
1762  };
1763  int64_t destinationRank =
1764  llvm::isa<VectorType>(extractOp.getType())
1765  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1766  : 0;
1767  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1768  return Value();
1769  if (destinationRank > 0) {
1770  auto destinationType =
1771  llvm::cast<VectorType>(extractOp.getResult().getType());
1772  for (int64_t i = 0; i < destinationRank; i++) {
1773  // The lowest dimension of the destination must match the lowest
1774  // dimension of the shapecast op source.
1775  // TODO: This case could be supported in a canonicalization pattern.
1776  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1777  getDimReverse(destinationType, i))
1778  return Value();
1779  }
1780  }
1781  // Extract the strides associated with the extract op vector source. Then use
1782  // this to calculate a linearized position for the extract.
1783  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1784  std::reverse(extractedPos.begin(), extractedPos.end());
1785  SmallVector<int64_t, 4> strides;
1786  int64_t stride = 1;
1787  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1788  strides.push_back(stride);
1789  stride *=
1790  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1791  }
1792 
1793  int64_t position = linearize(extractedPos, strides);
1794  // Then extract the strides associated with the shapeCast op vector source and
1795  // delinearize the position using those strides.
1796  SmallVector<int64_t, 4> newStrides;
1797  int64_t numDimension =
1798  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1799  stride = 1;
1800  for (int64_t i = 0; i < numDimension; i++) {
1801  newStrides.push_back(stride);
1802  stride *=
1803  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1804  }
1805  std::reverse(newStrides.begin(), newStrides.end());
1806  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1807  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1808  OpBuilder b(extractOp.getContext());
1809  extractOp.setStaticPosition(newPosition);
1810  extractOp.setOperand(0, shapeCastOp.getSource());
1811  return extractOp.getResult();
1812 }
1813 
1814 /// Fold an ExtractOp from ExtractStridedSliceOp.
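/// An illustrative sketch (unit strides, invented shapes): given
///   %s = vector.extract_strided_slice %v
///          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
///          : vector<4x8xf32> to vector<2x4xf32>
///   %e = vector.extract %s[0, 3] : f32 from vector<2x4xf32>
/// the slice offsets are added to the extract position, yielding
///   %e = vector.extract %v[1, 5] : f32 from vector<4x8xf32>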
1815 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1816  // TODO: Canonicalization for dynamic position not implemented yet.
1817  if (extractOp.hasDynamicPosition())
1818  return Value();
1819 
1820  auto extractStridedSliceOp =
1821  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1822  if (!extractStridedSliceOp)
1823  return Value();
1824 
1825  // 0-D vectors not supported.
1826  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1827  if (hasZeroDimVectors(extractStridedSliceOp))
1828  return Value();
1829 
1830  // Return if 'extractStridedSliceOp' has non-unit strides.
1831  if (extractStridedSliceOp.hasNonUnitStrides())
1832  return Value();
1833 
1834  // Trim offsets for dimensions fully extracted.
1835  auto sliceOffsets =
1836  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1837  while (!sliceOffsets.empty()) {
1838  size_t lastOffset = sliceOffsets.size() - 1;
1839  if (sliceOffsets.back() != 0 ||
1840  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1841  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1842  break;
1843  sliceOffsets.pop_back();
1844  }
1845  unsigned destinationRank = 0;
1846  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1847  destinationRank = vecType.getRank();
1848  // The dimensions of the result need to be untouched by the
1849  // extractStridedSlice op.
1850  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1851  sliceOffsets.size())
1852  return Value();
1853 
1854  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1855  assert(extractedPos.size() >= sliceOffsets.size());
1856  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1857  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1858  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1859 
1860  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1861  OpBuilder b(extractOp.getContext());
1862  extractOp.setStaticPosition(extractedPos);
1863  return extractOp.getResult();
1864 }
1865 
1866 /// Fold extract_op fed from a chain of insertStridedSlice ops.
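/// An illustrative sketch (invented shapes): given
///   %i = vector.insert_strided_slice %src, %dst
///          {offsets = [1, 0], strides = [1]}
///          : vector<8xf32> into vector<4x8xf32>
///   %e = vector.extract %i[1] : vector<8xf32> from vector<4x8xf32>
/// the extracted row is exactly the inserted chunk, so the extract is
/// redirected to read from %src; a disjoint extract position would instead
/// keep walking the chain through %dst.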
1867 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1868  // TODO: Canonicalization for dynamic position not implemented yet.
1869  if (extractOp.hasDynamicPosition())
1870  return Value();
1871 
1872  int64_t destinationRank =
1873  llvm::isa<VectorType>(extractOp.getType())
1874  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1875  : 0;
1876  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1877  if (!insertOp)
1878  return Value();
1879 
1880  // 0-D vectors not supported.
1881  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1882  if (hasZeroDimVectors(insertOp))
1883  return Value();
1884 
1885  while (insertOp) {
1886  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1887  insertOp.getSourceVectorType().getRank();
1888  if (destinationRank > insertOp.getSourceVectorType().getRank())
1889  return Value();
1890  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1891  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1892 
1893  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1894  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1895  }))
1896  return Value();
1897  bool disjoint = false;
1898  SmallVector<int64_t, 4> offsetDiffs;
1899  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1900  int64_t start = insertOffsets[dim];
1901  int64_t size =
1902  (dim < insertRankDiff)
1903  ? 1
1904  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1905  int64_t end = start + size;
1906  int64_t offset = extractOffsets[dim];
1907  // Check if the start of the extract offset is in the interval inserted.
1908  if (start <= offset && offset < end) {
1909  if (dim >= insertRankDiff)
1910  offsetDiffs.push_back(offset - start);
1911  continue;
1912  }
1913  disjoint = true;
1914  break;
1915  }
1916  // The extracted element chunk overlaps with the inserted vector.
1917  if (!disjoint) {
1918  // If any of the inner dimensions are only partially inserted we have a
1919  // partial overlap.
1920  int64_t srcRankDiff =
1921  insertOp.getSourceVectorType().getRank() - destinationRank;
1922  for (int64_t i = 0; i < destinationRank; i++) {
1923  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1924  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1925  insertRankDiff))
1926  return Value();
1927  }
1928  extractOp.getVectorMutable().assign(insertOp.getSource());
1929  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1930  OpBuilder b(extractOp.getContext());
1931  extractOp.setStaticPosition(offsetDiffs);
1932  return extractOp.getResult();
1933  }
1934  // If the chunk extracted is disjoint from the chunk inserted, keep
1935  // looking in the insert chain.
1936  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1937  }
1938  return Value();
1939 }
1940 
1941 /// Try to fold the extraction of a scalar from a vector defined by
1942 /// vector.from_elements. E.g.:
1943 ///
1944 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1945 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1946 /// ==> fold to %a
1947 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1948  // Dynamic extractions cannot be folded.
1949  if (extractOp.hasDynamicPosition())
1950  return {};
1951 
1952  // Look for extract(from_elements).
1953  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1954  if (!fromElementsOp)
1955  return {};
1956 
1957  // Scalable vectors are not supported.
1958  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1959  if (vecType.isScalable())
1960  return {};
1961 
1962  // Only extractions of scalars are supported.
1963  int64_t rank = vecType.getRank();
1964  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1965  if (extractOp.getType() != vecType.getElementType())
1966  return {};
1967  assert(static_cast<int64_t>(indices.size()) == rank &&
1968  "unexpected number of indices");
1969 
1970  // Compute flattened/linearized index and fold to operand.
1971  int flatIndex = 0;
1972  int stride = 1;
1973  for (int i = rank - 1; i >= 0; --i) {
1974  flatIndex += indices[i] * stride;
1975  stride *= vecType.getDimSize(i);
1976  }
1977  return fromElementsOp.getElements()[flatIndex];
1978 }
1979 
1980 OpFoldResult ExtractOp::fold(FoldAdaptor) {
1981  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
1982  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
1983  // mismatch).
1984  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
1985  return getVector();
1986  if (succeeded(foldExtractOpFromExtractChain(*this)))
1987  return getResult();
1988  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1989  return res;
1990  if (auto res = foldExtractFromBroadcast(*this))
1991  return res;
1992  if (auto res = foldExtractFromShuffle(*this))
1993  return res;
1994  if (auto res = foldExtractFromShapeCast(*this))
1995  return res;
1996  if (auto val = foldExtractFromExtractStrided(*this))
1997  return val;
1998  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1999  return val;
2000  if (auto val = foldScalarExtractFromFromElements(*this))
2001  return val;
2002  return OpFoldResult();
2003 }
2004 
2005 namespace {
2006 
2007 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
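// A minimal sketch (shapes invented for illustration): this turns
//   %b = vector.broadcast %s : vector<8xf32> to vector<2x4x8xf32>
//   %e = vector.extract %b[1] : vector<4x8xf32> from vector<2x4x8xf32>
// into a direct broadcast of the source:
//   %e = vector.broadcast %s : vector<8xf32> to vector<4x8xf32>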
2008 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2009 public:
2010  using OpRewritePattern::OpRewritePattern;
2011 
2012  LogicalResult matchAndRewrite(ExtractOp extractOp,
2013  PatternRewriter &rewriter) const override {
2014  Operation *defOp = extractOp.getVector().getDefiningOp();
2015  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2016  return failure();
2017 
2018  Value source = defOp->getOperand(0);
2019  if (extractOp.getType() == source.getType())
2020  return failure();
2021  auto getRank = [](Type type) {
2022  return llvm::isa<VectorType>(type)
2023  ? llvm::cast<VectorType>(type).getRank()
2024  : 0;
2025  };
2026  unsigned broadcastSrcRank = getRank(source.getType());
2027  unsigned extractResultRank = getRank(extractOp.getType());
2028  // We only consider the case where the rank of the source is less than or
2029  // equal to the rank of the extract dst. The other cases are handled in the
2030  // folding patterns.
2031  if (extractResultRank < broadcastSrcRank)
2032  return failure();
2033 
2034  // Special case if broadcast src is a 0D vector.
2035  if (extractResultRank == 0) {
2036  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2037  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2038  return success();
2039  }
2040  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2041  extractOp, extractOp.getType(), source);
2042  return success();
2043  }
2044 };
2045 
2046 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
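// A minimal sketch (values invented for illustration): this turns
//   %c = arith.constant dense<7.0> : vector<2x4xf32>
//   %e = vector.extract %c[1] : vector<4xf32> from vector<2x4xf32>
// into
//   %e = arith.constant dense<7.0> : vector<4xf32>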
2047 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2048 public:
2049  using OpRewritePattern::OpRewritePattern;
2050 
2051  LogicalResult matchAndRewrite(ExtractOp extractOp,
2052  PatternRewriter &rewriter) const override {
2053  // Return if 'ExtractOp' operand is not defined by a splat vector
2054  // ConstantOp.
2055  Value sourceVector = extractOp.getVector();
2056  Attribute vectorCst;
2057  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2058  return failure();
2059  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2060  if (!splat)
2061  return failure();
2062  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2063  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2064  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2065  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2066  return success();
2067  }
2068 };
2069 
2070 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
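// A minimal sketch (values invented for illustration): this turns
//   %c = arith.constant dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>
//   %e = vector.extract %c[2] : f32 from vector<4xf32>
// into
//   %e = arith.constant 2.0 : f32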
2071 class ExtractOpNonSplatConstantFolder final
2072  : public OpRewritePattern<ExtractOp> {
2073 public:
2074  using OpRewritePattern::OpRewritePattern;
2075 
2076  LogicalResult matchAndRewrite(ExtractOp extractOp,
2077  PatternRewriter &rewriter) const override {
2078  // TODO: Canonicalization for dynamic position not implemented yet.
2079  if (extractOp.hasDynamicPosition())
2080  return failure();
2081 
2082  // Return if 'ExtractOp' operand is not defined by a compatible vector
2083  // ConstantOp.
2084  Value sourceVector = extractOp.getVector();
2085  Attribute vectorCst;
2086  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2087  return failure();
2088 
2089  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2090  if (vecTy.isScalable())
2091  return failure();
2092 
2093  // The splat case is handled by `ExtractOpSplatConstantFolder`.
2094  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2095  if (!dense || dense.isSplat())
2096  return failure();
2097 
2098  // Calculate the linearized position of the contiguous chunk of elements to
2099  // extract.
2100  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2101  copy(extractOp.getStaticPosition(), completePositions.begin());
2102  int64_t elemBeginPosition =
2103  linearize(completePositions, computeStrides(vecTy.getShape()));
2104  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2105 
2106  TypedAttr newAttr;
2107  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2108  SmallVector<Attribute> elementValues(
2109  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2110  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2111  } else {
2112  newAttr = *denseValuesBegin;
2113  }
2114 
2115  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2116  return success();
2117  }
2118 };
2119 
2120 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
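// An illustrative sketch (sizes invented): with %c2 and %c3 constant indices,
//   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// rewrites to
//   %e = vector.create_mask %c3 : vector<8xi1>
// because row 1 lies inside the all-true leading region; a row index >= 2
// would instead rewrite to an all-false constant mask.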
2121 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2122 public:
2123  using OpRewritePattern::OpRewritePattern;
2124 
2125  LogicalResult matchAndRewrite(ExtractOp extractOp,
2126  PatternRewriter &rewriter) const override {
2127  auto createMaskOp =
2128  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2129  if (!createMaskOp)
2130  return failure();
2131 
2132  VectorType extractedMaskType =
2133  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2134 
2135  if (!extractedMaskType)
2136  return failure();
2137 
2138  auto maskOperands = createMaskOp.getOperands();
2139  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2140  VectorType maskType = createMaskOp.getVectorType();
2141 
2142  bool containsUnknownDims = false;
2143  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2144 
2145  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2146  dimIdx++) {
2147  int64_t pos = extractOpPos[dimIdx];
2148  Value operand = maskOperands[dimIdx];
2149  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2150  if (!constantOp) {
2151  // Bounds of this dim unknown.
2152  containsUnknownDims = true;
2153  continue;
2154  }
2155 
2156  int64_t createMaskBound =
2157  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2158 
2159  if (pos != ShapedType::kDynamic) {
2160  // If any position is outside the range from the `create_mask`, then the
2161  // extracted mask will be all-false.
2162  allFalse |= pos >= createMaskBound;
2163  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2164  // This dim is not all-true and since this is a dynamic index we don't
2165  // know if the extraction is within the true or false region.
2166  // Note: Zero dims have already been handled via getMaskFormat().
2167  containsUnknownDims = true;
2168  }
2169  }
2170 
2171  if (allFalse) {
2172  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2173  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2174  } else if (!containsUnknownDims) {
2175  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2176  extractOp, extractedMaskType,
2177  maskOperands.drop_front(extractOpPos.size()));
2178  } else {
2179  return failure();
2180  }
2181  return success();
2182  }
2183 };
2184 
2185 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2186 // does not change.
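// A minimal sketch (shapes invented for illustration): given
//   %sc = vector.shape_cast %src : vector<2x2xf32> to vector<1x4xf32>
//   %e  = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
// both %src and the extracted result hold 4 elements, so this rewrites to
//   %e = vector.shape_cast %src : vector<2x2xf32> to vector<4xf32>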
2187 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2188  PatternRewriter &rewriter) {
2189  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2190  if (!castOp)
2191  return failure();
2192 
2193  VectorType sourceType = castOp.getSourceVectorType();
2194  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2195  if (!targetType)
2196  return failure();
2197 
2198  if (sourceType.getNumElements() != targetType.getNumElements())
2199  return failure();
2200 
2201  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2202  castOp.getSource());
2203  return success();
2204 }
2205 
2206 /// Try to canonicalize the extraction of a subvector from a vector defined by
2207 /// vector.from_elements. E.g.:
2208 ///
2209 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2210 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2211 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2212 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2213  PatternRewriter &rewriter) {
2214  // Dynamic positions are not supported.
2215  if (extractOp.hasDynamicPosition())
2216  return failure();
2217 
2218  // Scalar extracts are handled by the folder.
2219  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2220  if (!resultType)
2221  return failure();
2222 
2223  // Look for extracts from a from_elements op.
2224  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2225  if (!fromElementsOp)
2226  return failure();
2227  VectorType inputType = fromElementsOp.getType();
2228 
2229  // Scalable vectors are not supported.
2230  if (resultType.isScalable() || inputType.isScalable())
2231  return failure();
2232 
2233  // Compute the position of the first extracted element and flatten/linearize the
2234  // position.
2235  SmallVector<int64_t> firstElementPos =
2236  llvm::to_vector(extractOp.getStaticPosition());
2237  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2238  int flatIndex = 0;
2239  int stride = 1;
2240  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2241  flatIndex += firstElementPos[i] * stride;
2242  stride *= inputType.getDimSize(i);
2243  }
2244 
2245  // Replace the op with a smaller from_elements op.
2246  rewriter.replaceOpWithNewOp<FromElementsOp>(
2247  extractOp, resultType,
2248  fromElementsOp.getElements().slice(flatIndex,
2249  resultType.getNumElements()));
2250  return success();
2251 }
2252 } // namespace
2253 
2254 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2255  MLIRContext *context) {
2256  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2257  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2258  results.add(foldExtractFromShapeCastToShapeCast);
2259  results.add(foldExtractFromFromElements);
2260 }
2261 
2262 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2263  SmallVectorImpl<int64_t> &results) {
2264  for (auto attr : arrayAttr)
2265  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2266 }
2267 
2268 //===----------------------------------------------------------------------===//
2269 // FMAOp
2270 //===----------------------------------------------------------------------===//
2271 
2272 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2273  return llvm::to_vector<4>(getVectorType().getShape());
2274 }
2275 
2276 //===----------------------------------------------------------------------===//
2277 // FromElementsOp
2278 //===----------------------------------------------------------------------===//
2279 
2280 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2281 /// same SSA value. E.g.:
2282 ///
2283 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2284 /// ==> rewrite to vector.splat %a : vector<3xf32>
2285 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2286  PatternRewriter &rewriter) {
2287  if (!llvm::all_equal(fromElementsOp.getElements()))
2288  return failure();
2289  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2290  fromElementsOp.getElements().front());
2291  return success();
2292 }
2293 
2294 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2295  MLIRContext *context) {
2296  results.add(rewriteFromElementsAsSplat);
2297 }
2298 
2299 //===----------------------------------------------------------------------===//
2300 // BroadcastOp
2301 //===----------------------------------------------------------------------===//
2302 
2303 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2304  SetIntRangeFn setResultRanges) {
2305  setResultRanges(getResult(), argRanges.front());
2306 }
2307 
2308 /// Return the dimensions of the result vector that were formerly ones in the
2309 /// source vector and thus correspond to "dim-1" broadcasting.
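/// For example (shapes invented for illustration), broadcasting a
/// vector<1x2xf32> to vector<4x3x2xf32> yields {1}: dimension 0 of the result
/// is a new leading dimension, while dimension 1 comes from stretching the
/// unit dimension of the source.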
2310 static llvm::SetVector<int64_t>
2311 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2312  ArrayRef<int64_t> dstShape) {
2313  int64_t rankDiff = dstShape.size() - srcShape.size();
2314  int64_t dstDim = rankDiff;
2315  llvm::SetVector<int64_t> res;
2316  for (auto [s1, s2] :
2317  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2318  if (s1 != s2) {
2319  assert(s1 == 1 && "expected dim-1 broadcasting");
2320  res.insert(dstDim);
2321  }
2322  ++dstDim;
2323  }
2324  return res;
2325 }
2326 
2327 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2328  // A broadcast from a scalar involves no unit-dim broadcasting.
2329  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2330  if (!srcVectorType)
2331  return {};
2332  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2333  getResultVectorType().getShape());
2334 }
2335 
2336 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2337 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2338 /// This requires (and asserts) that the broadcast is free of dim-1
2339 /// broadcasting.
2340 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2341 /// vector.transpose may be inserted to make the broadcast possible.
2342 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2343 /// the helper will assert. This means:
2344 /// 1. `dstShape` must not be empty.
2345 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2346 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2347 ///    must match the `value` shape.
2348 Value BroadcastOp::createOrFoldBroadcastOp(
2349  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2350  const llvm::SetVector<int64_t> &broadcastedDims) {
2351  assert(!dstShape.empty() && "unexpected empty dst shape");
2352 
2353  // Well-formedness check.
2354  SmallVector<int64_t> checkShape;
2355  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2356  if (broadcastedDims.contains(i))
2357  continue;
2358  checkShape.push_back(dstShape[i]);
2359  }
2360  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2361  "ill-formed broadcastedDims contains values not confined to "
2362  "destVectorShape");
2363 
2364  Location loc = value.getLoc();
2365  Type elementType = getElementTypeOrSelf(value.getType());
2366  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2367  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2368 
2369  // Step 2. If scalar -> dstShape broadcast, just do it.
2370  if (!srcVectorType) {
2371  assert(checkShape.empty() &&
2372  "ill-formed createOrFoldBroadcastOp arguments");
2373  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2374  }
2375 
2376  assert(srcVectorType.getShape().equals(checkShape) &&
2377  "ill-formed createOrFoldBroadcastOp arguments");
2378 
2379  // Step 3. Since vector.broadcast only allows creating leading dims,
2380  // vector -> dstShape broadcast may require a transpose.
2381  // Traverse the dims in order and construct:
2382  // 1. The leading entries of the broadcastShape that is guaranteed to be
2383  // achievable by a simple broadcast.
2384  // 2. The induced permutation for the subsequent vector.transpose that will
2385  // bring us from `broadcastShape` back to the desired `dstShape`.
2386  // If the induced permutation is not the identity, create a vector.transpose.
2387  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2388  broadcastShape.reserve(dstShape.size());
2389  // Consider the example:
2390  // srcShape = 2x4
2391  // dstShape = 1x2x3x4x5
2392  // broadcastedDims = [0, 2, 4]
2393  //
2394  // We want to build:
2395  // broadcastShape = 1x3x5x2x4
2396  // permutation = [0, 2, 4, 1, 3]
2397  // ---V--- -----V-----
2398  // leading broadcast part src shape part
2399  //
2400  // Note that the trailing dims of broadcastShape are exactly the srcShape
2401  // by construction.
2402  // nextSrcShapeDim is used to keep track of where in the permutation the
2403  // "src shape part" occurs.
2404  int64_t nextSrcShapeDim = broadcastedDims.size();
2405  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2406  if (broadcastedDims.contains(i)) {
2407  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2408  // bring it to the head of the broadcastShape.
2409  // It will need to be permuted back from `broadcastShape.size() - 1` into
2410  // position `i`.
2411  broadcastShape.push_back(dstShape[i]);
2412  permutation[i] = broadcastShape.size() - 1;
2413  } else {
2414  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2415  // shape and needs to be permuted into position `i`.
2416  // Don't touch `broadcastShape` here, the whole srcShape will be
2417  // appended after.
2418  permutation[i] = nextSrcShapeDim++;
2419  }
2420  }
2421  // 3.c. Append the srcShape.
2422  llvm::append_range(broadcastShape, srcVectorType.getShape());
2423 
2424  // Ensure there are no dim-1 broadcasts.
2425  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2426  .empty() &&
2427  "unexpected dim-1 broadcast");
2428 
2429  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2430  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2431  vector::BroadcastableToResult::Success &&
2432  "must be broadcastable");
2433  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2434  // Step 4. If we find any dimension that indeed needs to be permuted,
2435  // immediately return a new vector.transpose.
2436  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2437  if (permutation[i] != i)
2438  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2439  // Otherwise return res.
2440  return res;
2441 }
2442 
2443 BroadcastableToResult mlir::vector::isBroadcastableTo(
2444  Type srcType, VectorType dstVectorType,
2445  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2446  // Broadcast scalar to vector of the same element type.
2447  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2448  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2449  return BroadcastableToResult::Success;
2450  // From now on, only vectors broadcast.
2451  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2452  if (!srcVectorType)
2453  return BroadcastableToResult::SourceTypeNotAVector;
2454 
2455  int64_t srcRank = srcVectorType.getRank();
2456  int64_t dstRank = dstVectorType.getRank();
2457  if (srcRank > dstRank)
2458  return BroadcastableToResult::SourceRankHigher;
2459  // Source has an exact match or singleton value for all trailing dimensions
2460  // (all leading dimensions are simply duplicated).
2461  int64_t lead = dstRank - srcRank;
2462  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2463  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2464  // encountered?
2465  bool foundMismatchingDims = false;
2466 
2467  // Check fixed-width dims.
2468  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2469  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2470  if (srcDim != 1 && srcDim != dstDim)
2471  foundMismatchingDims = true;
2472 
2473  // Check scalable flags.
2474  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2475  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2476  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2477  // 1 -> [N] is fine, everything else should be rejected when mixing
2478  // fixed-width and scalable dims
2479  (srcDimScalableFlag != dstDimScalableFlag &&
2480  (srcDim != 1 || srcDimScalableFlag)))
2481  foundMismatchingDims = true;
2482 
2483  if (foundMismatchingDims) {
2484  if (mismatchingDims != nullptr) {
2485  mismatchingDims->first.dim = srcDim;
2486  mismatchingDims->first.isScalable = srcDimScalableFlag;
2487 
2488  mismatchingDims->second.dim = dstDim;
2489  mismatchingDims->second.isScalable = dstDimScalableFlag;
2490  }
2491  return BroadcastableToResult::DimensionMismatch;
2492  }
2493  }
2494 
2495  return BroadcastableToResult::Success;
2496 }
2497 
2498 LogicalResult BroadcastOp::verify() {
2499  std::pair<VectorDim, VectorDim> mismatchingDims;
2500  BroadcastableToResult res = isBroadcastableTo(
2501  getSourceType(), getResultVectorType(), &mismatchingDims);
2502  if (res == BroadcastableToResult::Success)
2503  return success();
2504  if (res == BroadcastableToResult::SourceRankHigher)
2505  return emitOpError("source rank higher than destination rank");
2506  if (res == BroadcastableToResult::DimensionMismatch) {
2507  return emitOpError("dimension mismatch (")
2508  << (mismatchingDims.first.isScalable ? "[" : "")
2509  << mismatchingDims.first.dim
2510  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2511  << (mismatchingDims.second.isScalable ? "[" : "")
2512  << mismatchingDims.second.dim
2513  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2514  }
2515  if (res == BroadcastableToResult::SourceTypeNotAVector)
2516  return emitOpError("source type is not a vector");
2517  llvm_unreachable("unexpected vector.broadcast op error");
2518 }
2519 
2520 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2521  if (getSourceType() == getResultVectorType())
2522  return getSource();
2523  if (!adaptor.getSource())
2524  return {};
2525  auto vectorType = getResultVectorType();
2526  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2527  return DenseElementsAttr::get(vectorType, adaptor.getSource());
2528  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2529  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2530  return {};
2531 }
2532 
2533 namespace {
2534 
2535 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2536 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2537  using OpRewritePattern::OpRewritePattern;
2538 
2539  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2540  PatternRewriter &rewriter) const override {
2541  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2542  if (!srcBroadcast)
2543  return failure();
2544  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2545  broadcastOp.getResultVectorType(),
2546  srcBroadcast.getSource());
2547  return success();
2548  }
2549 };
2550 } // namespace
2551 
2552 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2553  MLIRContext *context) {
2554  // BroadcastToShapeCast is not a default canonicalization; it is opted into
2555  // by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2556  results.add<BroadcastFolder>(context);
2557 }
2558 
2559 //===----------------------------------------------------------------------===//
2560 // ShuffleOp
2561 //===----------------------------------------------------------------------===//
2562 
2563 LogicalResult ShuffleOp::verify() {
2564  VectorType resultType = getResultVectorType();
2565  VectorType v1Type = getV1VectorType();
2566  VectorType v2Type = getV2VectorType();
2567  // Verify ranks.
2568  int64_t resRank = resultType.getRank();
2569  int64_t v1Rank = v1Type.getRank();
2570  int64_t v2Rank = v2Type.getRank();
2571  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2572  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2573  if (!wellFormed0DCase && !wellFormedNDCase)
2574  return emitOpError("rank mismatch");
2575 
2576  // Verify all but leading dimension sizes.
2577  for (int64_t r = 1; r < v1Rank; ++r) {
2578  int64_t resDim = resultType.getDimSize(r);
2579  int64_t v1Dim = v1Type.getDimSize(r);
2580  int64_t v2Dim = v2Type.getDimSize(r);
2581  if (resDim != v1Dim || v1Dim != v2Dim)
2582  return emitOpError("dimension mismatch");
2583  }
2584  // Verify mask length.
2585  ArrayRef<int64_t> mask = getMask();
2586  int64_t maskLength = mask.size();
2587  if (maskLength <= 0)
2588  return emitOpError("invalid mask length");
2589  if (maskLength != resultType.getDimSize(0))
2590  return emitOpError("mask length mismatch");
2591  // Verify all indices.
2592  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2593  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2594  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2595  if (maskPos < 0 || maskPos >= indexSize)
2596  return emitOpError("mask index #") << (idx + 1) << " out of range";
2597  }
2598  return success();
2599 }
2600 
2601 LogicalResult
2602 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2603  ShuffleOp::Adaptor adaptor,
2604  SmallVectorImpl<Type> &inferredReturnTypes) {
2605  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2606  auto v1Rank = v1Type.getRank();
2607  // Construct resulting type: leading dimension matches mask
2608  // length, all trailing dimensions match the operands.
2609  SmallVector<int64_t, 4> shape;
2610  shape.reserve(v1Rank);
2611  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2612  // In the 0-D case there is no trailing shape to append.
2613  if (v1Rank > 0)
2614  llvm::append_range(shape, v1Type.getShape().drop_front());
2615  inferredReturnTypes.push_back(
2616  VectorType::get(shape, v1Type.getElementType()));
2617  return success();
2618 }
2619 
2620 template <typename T>
2621 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2622  T expected = begin;
2623  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2624  return value == expected++;
2625  });
2626 }
2627 
2628 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2629  VectorType v1Type = getV1VectorType();
2630  // For consistency: the return type of a 0-D shuffle is 1-D, so this cannot
2631  // be a folding; it must be a canonicalization into a vector.broadcast.
2632  if (v1Type.getRank() == 0)
2633  return {};
2634 
2635  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2636  if (!v1Type.isScalable() &&
2637  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2638  return getV1();
2639  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2640  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2641  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2642  getV2VectorType().getDimSize(0)))
2643  return getV2();
2644 
2645  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2646  if (!lhs || !rhs)
2647  return {};
2648 
2649  auto lhsType =
2650  llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2651  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2652  // manipulation.
2653  if (lhsType.getRank() != 1)
2654  return {};
2655  int64_t lhsSize = lhsType.getDimSize(0);
2656 
2657  SmallVector<Attribute> results;
2658  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2659  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2660  for (int64_t i : this->getMask()) {
2661  if (i >= lhsSize) {
2662  results.push_back(rhsElements[i - lhsSize]);
2663  } else {
2664  results.push_back(lhsElements[i]);
2665  }
2666  }
2667 
2668  return DenseElementsAttr::get(getResultVectorType(), results);
2669 }
2670 
2671 namespace {
2672 
2673 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2674 // to a broadcast.
2675 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2676  using OpRewritePattern::OpRewritePattern;
2677 
2678  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2679  PatternRewriter &rewriter) const override {
2680  VectorType v1VectorType = shuffleOp.getV1VectorType();
2681  ArrayRef<int64_t> mask = shuffleOp.getMask();
2682  if (v1VectorType.getRank() > 0)
2683  return failure();
2684  if (mask.size() != 1)
2685  return failure();
2686  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2687  if (mask[0] == 0)
2688  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2689  shuffleOp.getV1());
2690  else
2691  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2692  shuffleOp.getV2());
2693  return success();
2694  }
2695 };
2696 
2697 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2698 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2699 public:
2700  using OpRewritePattern::OpRewritePattern;
2701 
2702  LogicalResult matchAndRewrite(ShuffleOp op,
2703  PatternRewriter &rewriter) const override {
2704  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2705  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2706 
2707  if (!v1Splat || !v2Splat)
2708  return failure();
2709 
2710  if (v1Splat.getInput() != v2Splat.getInput())
2711  return failure();
2712 
2713  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2714  return success();
2715  }
2716 };
2717 
2718 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2719 /// vector.interleave.
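/// A minimal sketch (shapes invented for illustration): the mask of
///   %s = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xf32>, vector<2xf32>
/// alternates elements of %a and %b, so it rewrites to
///   %s = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>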
2720 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2721 public:
2722  using OpRewritePattern::OpRewritePattern;
2723 
2724  LogicalResult matchAndRewrite(ShuffleOp op,
2725  PatternRewriter &rewriter) const override {
2726  VectorType resultType = op.getResultVectorType();
2727  if (resultType.isScalable())
2728  return rewriter.notifyMatchFailure(
2729  op, "ShuffleOp can't represent a scalable interleave");
2730 
2731  if (resultType.getRank() != 1)
2732  return rewriter.notifyMatchFailure(
2733  op, "ShuffleOp can't represent an n-D interleave");
2734 
2735  VectorType sourceType = op.getV1VectorType();
2736  if (sourceType != op.getV2VectorType() ||
2737  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2738  return rewriter.notifyMatchFailure(
2739  op, "ShuffleOp types don't match an interleave");
2740  }
2741 
2742  ArrayRef<int64_t> shuffleMask = op.getMask();
2743  int64_t resultVectorSize = resultType.getNumElements();
2744  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2745  int64_t maskValueA = shuffleMask[i * 2];
2746  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2747  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2748  return rewriter.notifyMatchFailure(op,
2749  "ShuffleOp mask not interleaving");
2750  }
2751 
2752  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2753  return success();
2754  }
2755 };
2756 
2757 } // namespace
2758 
2759 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2760  MLIRContext *context) {
2761  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2762  context);
2763 }
2764 
2765 //===----------------------------------------------------------------------===//
2766 // InsertElementOp
2767 //===----------------------------------------------------------------------===//
2768 
2769 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2770  SetIntRangeFn setResultRanges) {
2771  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2772 }
2773 
2774 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2775  Value source, Value dest) {
2776  build(builder, result, source, dest, {});
2777 }
2778 
2779 LogicalResult InsertElementOp::verify() {
2780  auto dstVectorType = getDestVectorType();
2781  if (dstVectorType.getRank() == 0) {
2782  if (getPosition())
2783  return emitOpError("expected position to be empty with 0-D vector");
2784  return success();
2785  }
2786  if (dstVectorType.getRank() != 1)
2787  return emitOpError("unexpected >1 vector rank");
2788  if (!getPosition())
2789  return emitOpError("expected position for 1-D vector");
2790  return success();
2791 }
2792 
2793 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2794  // Skip the 0-D vector here.
2795  if (!adaptor.getPosition())
2796  return {};
2797 
2798  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2799  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2800  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2801  if (!src || !dst || !pos)
2802  return {};
2803 
2804  if (src.getType() != getDestVectorType().getElementType())
2805  return {};
2806 
2807  auto dstElements = dst.getValues<Attribute>();
2808 
2809  SmallVector<Attribute> results(dstElements);
2810 
2811  uint64_t posIdx = pos.getInt();
2812  if (posIdx >= results.size())
2813  return {};
2814  results[posIdx] = src;
2815 
2816  return DenseElementsAttr::get(getDestVectorType(), results);
2817 }
2818 
2819 //===----------------------------------------------------------------------===//
2820 // InsertOp
2821 //===----------------------------------------------------------------------===//
2822 
2823 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2824  SetIntRangeFn setResultRanges) {
2825  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2826 }
2827 
2828 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2829  Value source, Value dest, int64_t position) {
2830  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2831 }
2832 
2833 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2834  Value source, Value dest, OpFoldResult position) {
2835  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2836 }
2837 
2838 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2839  Value source, Value dest,
2840  ArrayRef<int64_t> position) {
2841  SmallVector<OpFoldResult> posVals;
2842  posVals.reserve(position.size());
2843  llvm::transform(position, std::back_inserter(posVals),
2844  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2845  build(builder, result, source, dest, posVals);
2846 }
2847 
2848 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2849  Value source, Value dest,
2850  ArrayRef<OpFoldResult> position) {
2851  SmallVector<int64_t> staticPos;
2852  SmallVector<Value> dynamicPos;
2853  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2854  build(builder, result, source, dest, dynamicPos,
2855  builder.getDenseI64ArrayAttr(staticPos));
2856 }
2857 
2858 LogicalResult InsertOp::verify() {
2859  SmallVector<OpFoldResult> position = getMixedPosition();
2860  auto destVectorType = getDestVectorType();
2861  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2862  return emitOpError(
2863  "expected position attribute of rank no greater than dest vector rank");
2864  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2865  if (srcVectorType &&
2866  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2867  static_cast<unsigned>(destVectorType.getRank())))
2868  return emitOpError("expected position attribute rank + source rank to "
2869  "match dest vector rank");
2870  if (!srcVectorType &&
2871  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2872  return emitOpError(
2873  "expected position attribute rank to match the dest vector rank");
2874  for (auto [idx, pos] : llvm::enumerate(position)) {
2875  if (auto attr = pos.dyn_cast<Attribute>()) {
2876  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2877  if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2878  return emitOpError("expected position attribute #")
2879  << (idx + 1)
2880  << " to be a non-negative integer smaller than the "
2881  "corresponding "
2882  "dest vector dimension";
2883  }
2884  }
2885  }
2886  return success();
2887 }
2888 
2889 namespace {
2890 
2891 // If insertOp is only inserting unit dimensions it can be transformed to a
2892 // broadcast.
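// A hypothetical example of the rewrite (shapes chosen for illustration):
//   %1 = vector.insert %0, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// becomes
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<1x1x4xf32>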
2893 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2894 public:
2895   using OpRewritePattern::OpRewritePattern;
2896 
2897  LogicalResult matchAndRewrite(InsertOp insertOp,
2898  PatternRewriter &rewriter) const override {
2899  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2900  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2901  srcVecType.getNumElements())
2902  return failure();
2903  rewriter.replaceOpWithNewOp<BroadcastOp>(
2904  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2905  return success();
2906  }
2907 };
2908 
2909 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
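/// An illustrative sketch (hypothetical values and shapes):
///   %0 = vector.splat %x : vector<4xf32>
///   %1 = vector.splat %x : vector<2x4xf32>
///   %2 = vector.insert %0, %1[0] : vector<4xf32> into vector<2x4xf32>
/// is rewritten to
///   %2 = vector.splat %x : vector<2x4xf32>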
2910 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2911 public:
2912   using OpRewritePattern::OpRewritePattern;
2913 
2914  LogicalResult matchAndRewrite(InsertOp op,
2915  PatternRewriter &rewriter) const override {
2916  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2917  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2918 
2919  if (!srcSplat || !dstSplat)
2920  return failure();
2921 
2922  if (srcSplat.getInput() != dstSplat.getInput())
2923  return failure();
2924 
2925  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2926  return success();
2927  }
2928 };
2929 
2930 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
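// For example (a sketch with made-up constants):
//   %dst = arith.constant dense<0.0> : vector<2x2xf32>
//   %src = arith.constant dense<[1.0, 2.0]> : vector<2xf32>
//   %res = vector.insert %src, %dst[1] : vector<2xf32> into vector<2x2xf32>
// folds to
//   %res = arith.constant dense<[[0.0, 0.0], [1.0, 2.0]]> : vector<2x2xf32>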
2931 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2932 public:
2933   using OpRewritePattern::OpRewritePattern;
2934 
2935  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2936  // unless the source vector constant has a single use.
2937  static constexpr int64_t vectorSizeFoldThreshold = 256;
2938 
2939  LogicalResult matchAndRewrite(InsertOp op,
2940  PatternRewriter &rewriter) const override {
2941  // TODO: Canonicalization for dynamic position not implemented yet.
2942  if (op.hasDynamicPosition())
2943  return failure();
2944 
2945  // Return if 'InsertOp' operand is not defined by a compatible vector
2946  // ConstantOp.
2947  TypedValue<VectorType> destVector = op.getDest();
2948  Attribute vectorDestCst;
2949  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2950  return failure();
2951  auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2952  if (!denseDest)
2953  return failure();
2954 
2955  VectorType destTy = destVector.getType();
2956  if (destTy.isScalable())
2957  return failure();
2958 
2959  // Make sure we do not create too many large constants.
2960  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2961  !destVector.hasOneUse())
2962  return failure();
2963 
2964  Value sourceValue = op.getSource();
2965  Attribute sourceCst;
2966  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2967  return failure();
2968 
2969  // Calculate the linearized position of the contiguous chunk of elements to
2970  // insert.
2971  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2972  copy(op.getStaticPosition(), completePositions.begin());
2973  int64_t insertBeginPosition =
2974  linearize(completePositions, computeStrides(destTy.getShape()));
2975 
2976  SmallVector<Attribute> insertedValues;
2977  Type destEltType = destTy.getElementType();
2978 
2979  // The `convertIntegerAttr` method specifically handles the case
2980  // for `llvm.mlir.constant` which can hold an attribute with a
2981  // different type than the return type.
2982  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2983  for (auto value : denseSource.getValues<Attribute>())
2984  insertedValues.push_back(convertIntegerAttr(value, destEltType));
2985  } else {
2986  insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
2987  }
2988 
2989  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2990  copy(insertedValues, allValues.begin() + insertBeginPosition);
2991  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2992 
2993  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2994  return success();
2995  }
2996 
2997 private:
2998  /// Converts `attr` to an IntegerAttr of the expected type if there is a
2999  /// type mismatch.
3000  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3001  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3002  if (intAttr.getType() != expectedType)
3003  return IntegerAttr::get(expectedType, intAttr.getInt());
3004  }
3005  return attr;
3006  }
3007 };
3008 
3009 } // namespace
3010 
3011 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3012  MLIRContext *context) {
3013  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3014  InsertOpConstantFolder>(context);
3015 }
3016 
3017 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3018  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3019  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3020  // (type mismatch).
3021  if (getNumIndices() == 0 && getSourceType() == getType())
3022  return getSource();
3023  return {};
3024 }
3025 
3026 //===----------------------------------------------------------------------===//
3027 // InsertStridedSliceOp
3028 //===----------------------------------------------------------------------===//
3029 
3030 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3031  Value source, Value dest,
3032  ArrayRef<int64_t> offsets,
3033  ArrayRef<int64_t> strides) {
3034  result.addOperands({source, dest});
3035  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3036  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3037  result.addTypes(dest.getType());
3038  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3039  offsetsAttr);
3040  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3041  stridesAttr);
3042 }
3043 
3044 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3045 template <typename OpType>
3046 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3047  ArrayAttr arrayAttr,
3048  ArrayRef<int64_t> shape,
3049  StringRef attrName) {
3050  if (arrayAttr.size() > shape.size())
3051  return op.emitOpError("expected ")
3052  << attrName << " attribute of rank no greater than vector rank";
3053  return success();
3054 }
3055 
3056 // Checks that all integers in `arrayAttr` are confined to the range given by
3057 // `min` and `max`. If `halfOpen` is true then the admissible interval is
3058 // [min, max). Otherwise, the admissible interval is [min, max].
3059 template <typename OpType>
3060 static LogicalResult
3061 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3062  int64_t max, StringRef attrName,
3063  bool halfOpen = true) {
3064  for (auto attr : arrayAttr) {
3065  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3066  auto upper = max;
3067  if (!halfOpen)
3068  upper += 1;
3069  if (val < min || val >= upper)
3070  return op.emitOpError("expected ") << attrName << " to be confined to ["
3071  << min << ", " << upper << ")";
3072  }
3073  return success();
3074 }
3075 
3076 // Checks that all integers in `arrayAttr` are confined to the corresponding
3077 // dimension in `shape`. If `halfOpen` is true then the admissible interval is
3078 // [min, shape[i]). Otherwise, the admissible interval is [min, shape[i]].
3079 template <typename OpType>
3080 static LogicalResult
3081 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3082  ArrayRef<int64_t> shape, StringRef attrName,
3083  bool halfOpen = true, int64_t min = 0) {
3084  for (auto [index, attrDimPair] :
3085  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3086  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3087  int64_t max = std::get<1>(attrDimPair);
3088  if (!halfOpen)
3089  max += 1;
3090  if (val < min || val >= max)
3091  return op.emitOpError("expected ")
3092  << attrName << " dimension " << index << " to be confined to ["
3093  << min << ", " << max << ")";
3094  }
3095  return success();
3096 }
3097 
3098 // Checks that, for all indices i = 0..shape.size()-1, the sum
3099 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3100 // is confined to the corresponding dimension in `shape`. If `halfOpen` is
3101 // true then the admissible interval is [min, shape[i]). Otherwise, the
3102 // admissible interval is [min, shape[i]].
3103 template <typename OpType>
3104 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3105     OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3106  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3107  bool halfOpen = true, int64_t min = 1) {
3108  assert(arrayAttr1.size() <= shape.size());
3109  assert(arrayAttr2.size() <= shape.size());
3110  for (auto [index, it] :
3111  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3112  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3113  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3114  int64_t max = std::get<2>(it);
3115  if (!halfOpen)
3116  max += 1;
3117  if (val1 + val2 < 0 || val1 + val2 >= max)
3118  return op.emitOpError("expected sum(")
3119  << attrName1 << ", " << attrName2 << ") dimension " << index
3120  << " to be confined to [" << min << ", " << max << ")";
3121  }
3122  return success();
3123 }
3124 
3125 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3126  MLIRContext *context) {
3127  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3128  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3129  });
3130  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3131 }
3132 
3133 LogicalResult InsertStridedSliceOp::verify() {
3134  auto sourceVectorType = getSourceVectorType();
3135  auto destVectorType = getDestVectorType();
3136  auto offsets = getOffsetsAttr();
3137  auto strides = getStridesAttr();
3138  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3139  return emitOpError(
3140  "expected offsets of same size as destination vector rank");
3141  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3142  return emitOpError("expected strides of same size as source vector rank");
3143  if (sourceVectorType.getRank() > destVectorType.getRank())
3144  return emitOpError(
3145  "expected source rank to be no greater than destination rank");
3146 
3147  auto sourceShape = sourceVectorType.getShape();
3148  auto destShape = destVectorType.getShape();
3149  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3150  destShape.size() - sourceShape.size(), 0);
3151  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3152  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3153  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3154  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3155  offName)) ||
3156  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3157  /*max=*/1, stridesName,
3158  /*halfOpen=*/false)) ||
3160       failed(isSumOfIntegerArrayAttrConfinedToShape(
3161  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3162  offName, "source vector shape",
3163  /*halfOpen=*/false, /*min=*/1)))
3164  return failure();
3165 
3166  unsigned rankDiff = destShape.size() - sourceShape.size();
3167  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3168  if (sourceVectorType.getScalableDims()[idx] !=
3169  destVectorType.getScalableDims()[idx + rankDiff]) {
3170  return emitOpError("mismatching scalable flags (at source vector idx=")
3171  << idx << ")";
3172  }
3173  if (sourceVectorType.getScalableDims()[idx]) {
3174  auto sourceSize = sourceShape[idx];
3175  auto destSize = destShape[idx + rankDiff];
3176  if (sourceSize != destSize) {
3177  return emitOpError("expected size at idx=")
3178  << idx
3179  << (" to match the corresponding base size from the input "
3180  "vector (")
3181  << sourceSize << (" vs ") << destSize << (")");
3182  }
3183  }
3184  }
3185 
3186  return success();
3187 }
3188 
3189 namespace {
3190 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3191 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
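/// A hypothetical instance:
///   %0 = vector.splat %x : vector<2xf32>
///   %1 = vector.splat %x : vector<4x2xf32>
///   %2 = vector.insert_strided_slice %0, %1 {offsets = [0, 0], strides = [1]}
///          : vector<2xf32> into vector<4x2xf32>
/// is replaced with %1 (inserting a splat of %x into a splat of %x is a no-op).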
3192 class FoldInsertStridedSliceSplat final
3193  : public OpRewritePattern<InsertStridedSliceOp> {
3194 public:
3195   using OpRewritePattern::OpRewritePattern;
3196 
3197  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3198  PatternRewriter &rewriter) const override {
3199  auto srcSplatOp =
3200  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3201  auto destSplatOp =
3202  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3203 
3204  if (!srcSplatOp || !destSplatOp)
3205  return failure();
3206 
3207  if (srcSplatOp.getInput() != destSplatOp.getInput())
3208  return failure();
3209 
3210  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3211  return success();
3212  }
3213 };
3214 
3215 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3216 /// to dst.
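/// E.g., an illustrative sketch:
///   %e = vector.extract_strided_slice %dst
///          {offsets = [2], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %e, %dst
///          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// is replaced with %dst because the offsets and strides match.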
3217 class FoldInsertStridedSliceOfExtract final
3218  : public OpRewritePattern<InsertStridedSliceOp> {
3219 public:
3220   using OpRewritePattern::OpRewritePattern;
3221 
3222  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3223  PatternRewriter &rewriter) const override {
3224  auto extractStridedSliceOp =
3225  insertStridedSliceOp.getSource()
3226  .getDefiningOp<vector::ExtractStridedSliceOp>();
3227 
3228  if (!extractStridedSliceOp)
3229  return failure();
3230 
3231  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3232  return failure();
3233 
3234  // Check if they have the same strides and offsets.
3235  if (extractStridedSliceOp.getStrides() !=
3236  insertStridedSliceOp.getStrides() ||
3237  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3238  return failure();
3239 
3240  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3241  return success();
3242  }
3243 };
3244 
3245 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3246 // ConstantOp.
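// Sketch with hypothetical constants:
//   %dst = arith.constant dense<0> : vector<4xi32>
//   %src = arith.constant dense<[7, 8]> : vector<2xi32>
//   %r = vector.insert_strided_slice %src, %dst {offsets = [1], strides = [1]}
//          : vector<2xi32> into vector<4xi32>
// folds to
//   %r = arith.constant dense<[0, 7, 8, 0]> : vector<4xi32>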
3247 class InsertStridedSliceConstantFolder final
3248  : public OpRewritePattern<InsertStridedSliceOp> {
3249 public:
3250   using OpRewritePattern::OpRewritePattern;
3251 
3252  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3253  // unless the source vector constant has a single use.
3254  static constexpr int64_t vectorSizeFoldThreshold = 256;
3255 
3256  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3257  PatternRewriter &rewriter) const override {
3258  // Return if 'InsertStridedSliceOp' operand is not defined by a compatible
3259  // vector ConstantOp.
3260  TypedValue<VectorType> destVector = op.getDest();
3261  Attribute vectorDestCst;
3262  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3263  return failure();
3264 
3265  VectorType destTy = destVector.getType();
3266  if (destTy.isScalable())
3267  return failure();
3268 
3269  // Make sure we do not create too many large constants.
3270  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3271  !destVector.hasOneUse())
3272  return failure();
3273 
3274  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3275 
3276  TypedValue<VectorType> sourceValue = op.getSource();
3277  Attribute sourceCst;
3278  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3279  return failure();
3280 
3281  // TODO: Handle non-unit strides when they become available.
3282  if (op.hasNonUnitStrides())
3283  return failure();
3284 
3285  VectorType sliceVecTy = sourceValue.getType();
3286  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3287  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3288  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3289  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3290 
3291  // Calculate the destination element indices by enumerating all slice
3292  // positions within the destination and linearizing them. The enumeration
3293  // order is lexicographic, which yields a sequence of monotonically
3294  // increasing linearized position indices.
3295  // Because the destination may have higher dimensionality than the slice,
3296  // we keep track of two overlapping sets of positions and offsets.
3297  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3298  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3299  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3300  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3301  MutableArrayRef<int64_t> currSlicePosition(
3302  currDestPosition.begin() + rankDifference, currDestPosition.end());
3303  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3304  offsets.end());
3305  do {
3306  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3307  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3308  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3309  "Invalid slice element");
3310  newValues[linearizedPosition] = *sliceValuesIt;
3311  ++sliceValuesIt;
3312  } while (succeeded(
3313  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3314 
3315  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3316  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3317  return success();
3318  }
3319 };
3320 
3321 } // namespace
3322 
3323 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3324  RewritePatternSet &results, MLIRContext *context) {
3325  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3326  InsertStridedSliceConstantFolder>(context);
3327 }
3328 
3329 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3330  if (getSourceVectorType() == getDestVectorType())
3331  return getSource();
3332  return {};
3333 }
3334 
3335 //===----------------------------------------------------------------------===//
3336 // OuterProductOp
3337 //===----------------------------------------------------------------------===//
3338 
3339 /// Build an op without mask, use the type of `acc` as the return type.
3340 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3341  Value lhs, Value rhs, Value acc) {
3342  result.addOperands({lhs, rhs, acc});
3343  result.addTypes(acc.getType());
3344 }
3345 
3346 void OuterProductOp::print(OpAsmPrinter &p) {
3347   p << " " << getLhs() << ", " << getRhs();
3348  if (getAcc()) {
3349  p << ", " << getAcc();
3350  p.printOptionalAttrDict((*this)->getAttrs());
3351  }
3352  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3353 }
3354 
3355 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3356   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3357   Type tLHS, tRHS;
3358  if (parser.parseOperandList(operandsInfo) ||
3359  parser.parseOptionalAttrDict(result.attributes) ||
3360  parser.parseColonType(tLHS) || parser.parseComma() ||
3361  parser.parseType(tRHS))
3362  return failure();
3363  if (operandsInfo.size() < 2)
3364  return parser.emitError(parser.getNameLoc(),
3365  "expected at least 2 operands");
3366  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3367  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3368  if (!vLHS)
3369  return parser.emitError(parser.getNameLoc(),
3370  "expected vector type for operand #1");
3371 
3372  VectorType resType;
3373  if (vRHS) {
3374  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3375  vRHS.getScalableDims()[0]};
3376  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3377  vLHS.getElementType(), scalableDimsRes);
3378  } else {
3379  // Scalar RHS operand
3380  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3381  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3382  scalableDimsRes);
3383  }
3384 
3385  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3386  result.attributes.append(
3387  OuterProductOp::getKindAttrName(result.name),
3388         CombiningKindAttr::get(result.getContext(),
3389                                OuterProductOp::getDefaultKind()));
3390  }
3391 
3392  return failure(
3393  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3394  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3395  (operandsInfo.size() > 2 &&
3396  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3397  parser.addTypeToList(resType, result.types));
3398 }
3399 
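// For reference when reading the checks below, a sketch of the two accepted
// forms (types are illustrative):
//   %o = vector.outerproduct %a, %b, %acc : vector<4xf32>, vector<8xf32>
//        (proper outer product, 2-D result vector<4x8xf32>)
//   %x = vector.outerproduct %a, %s, %acc : vector<4xf32>, f32
//        (AXPY with a scalar RHS, 1-D result vector<4xf32>)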
3400 LogicalResult OuterProductOp::verify() {
3401  Type tRHS = getOperandTypeRHS();
3402  VectorType vLHS = getOperandVectorTypeLHS(),
3403  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3404  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3405 
3406  if (vLHS.getRank() != 1)
3407  return emitOpError("expected 1-d vector for operand #1");
3408 
3409  if (vRHS) {
3410  // Proper OUTER operation.
3411  if (vRHS.getRank() != 1)
3412  return emitOpError("expected 1-d vector for operand #2");
3413  if (vRES.getRank() != 2)
3414  return emitOpError("expected 2-d vector result");
3415  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3416  return emitOpError("expected #1 operand dim to match result dim #1");
3417  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3418  return emitOpError("expected #2 operand dim to match result dim #2");
3419  if (vLHS.isScalable() && !vRHS.isScalable()) {
3420  // This restriction reflects what's currently supported in terms of
3421  // scalable vectors. However, we could relax this if there's a use case.
3422  return emitOpError(
3423  "expected either both or only #2 operand dim to be scalable");
3424  }
3425  } else {
3426  // An AXPY operation.
3427  if (vRES.getRank() != 1)
3428  return emitOpError("expected 1-d vector result");
3429  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3430  return emitOpError("expected #1 operand dim to match result dim #1");
3431  }
3432 
3433  if (vACC && vACC != vRES)
3434  return emitOpError("expected operand #3 of same type as result type");
3435 
3436  // Verify supported combining kind.
3437  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3438  return emitOpError("unsupported outerproduct type");
3439 
3440  return success();
3441 }
3442 
3443 // MaskableOpInterface methods.
3444 
3445 /// Returns the mask type expected by this operation. Mostly used for
3446 /// verification purposes. It requires the operation to be vectorized.
3447 Type OuterProductOp::getExpectedMaskType() {
3448  auto vecType = this->getResultVectorType();
3449  return VectorType::get(vecType.getShape(),
3450  IntegerType::get(vecType.getContext(), /*width=*/1),
3451  vecType.getScalableDims());
3452 }
3453 
3454 //===----------------------------------------------------------------------===//
3455 // ExtractStridedSliceOp
3456 //===----------------------------------------------------------------------===//
3457 
3458 // Inference works as follows:
3459 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3460 // 2. Add sizes from 'vectorType' for remaining dims.
3461 // Scalable flags are inherited from 'vectorType'.
3462 static Type inferStridedSliceOpResultType(VectorType vectorType,
3463  ArrayAttr offsets, ArrayAttr sizes,
3464  ArrayAttr strides) {
3465  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3466   SmallVector<int64_t, 4> shape;
3467   shape.reserve(vectorType.getRank());
3468  unsigned idx = 0;
3469  for (unsigned e = offsets.size(); idx < e; ++idx)
3470  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3471  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3472  shape.push_back(vectorType.getShape()[idx]);
3473 
3474  return VectorType::get(shape, vectorType.getElementType(),
3475  vectorType.getScalableDims());
3476 }
3477 
3478 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3479  Value source, ArrayRef<int64_t> offsets,
3480  ArrayRef<int64_t> sizes,
3481  ArrayRef<int64_t> strides) {
3482  result.addOperands(source);
3483  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3484  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3485  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3486  result.addTypes(
3487  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3488  offsetsAttr, sizesAttr, stridesAttr));
3489  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3490  offsetsAttr);
3491  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3492  sizesAttr);
3493  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3494  stridesAttr);
3495 }
3496 
3497 LogicalResult ExtractStridedSliceOp::verify() {
3498  auto type = getSourceVectorType();
3499  auto offsets = getOffsetsAttr();
3500  auto sizes = getSizesAttr();
3501  auto strides = getStridesAttr();
3502  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3503  return emitOpError(
3504  "expected offsets, sizes and strides attributes of same size");
3505 
3506  auto shape = type.getShape();
3507  auto offName = getOffsetsAttrName();
3508  auto sizesName = getSizesAttrName();
3509  auto stridesName = getStridesAttrName();
3510  if (failed(
3511  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3512  failed(
3513  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3514  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3515  stridesName)) ||
3516  failed(
3517  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3518  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3519  /*halfOpen=*/false,
3520  /*min=*/1)) ||
3521  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3522  /*max=*/1, stridesName,
3523  /*halfOpen=*/false)) ||
3524  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3525  shape, offName, sizesName,
3526  /*halfOpen=*/false)))
3527  return failure();
3528 
3529  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3530  offsets, sizes, strides);
3531  if (getResult().getType() != resultType)
3532  return emitOpError("expected result type to be ") << resultType;
3533 
3534  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3535  if (type.getScalableDims()[idx]) {
3536  auto inputDim = type.getShape()[idx];
3537  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3538  if (inputDim != inputSize)
3539  return emitOpError("expected size at idx=")
3540  << idx
3541  << (" to match the corresponding base size from the input "
3542  "vector (")
3543  << inputSize << (" vs ") << inputDim << (")");
3544  }
3545  }
3546 
3547  return success();
3548 }
3549 
3550 // When the source of ExtractStridedSliceOp comes from a chain of
3551 // InsertStridedSliceOp ops, try to use the source of one of those inserts if
3552 // we can detect that the extracted vector is a subset of an inserted vector.
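// For instance (hypothetical shapes):
//   %i = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//          : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i {offsets = [3], sizes = [2], strides = [1]}
//          : vector<8xf32> to vector<2xf32>
// can instead extract directly from %src with offsets = [3 - 2] = [1].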
3553 static LogicalResult
3554 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3555  // Helper to extract integer out of ArrayAttr.
3556  auto getElement = [](ArrayAttr array, int idx) {
3557  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3558  };
3559  ArrayAttr extractOffsets = op.getOffsets();
3560  ArrayAttr extractStrides = op.getStrides();
3561  ArrayAttr extractSizes = op.getSizes();
3562  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3563  while (insertOp) {
3564  if (op.getSourceVectorType().getRank() !=
3565  insertOp.getSourceVectorType().getRank())
3566  return failure();
3567  ArrayAttr insertOffsets = insertOp.getOffsets();
3568  ArrayAttr insertStrides = insertOp.getStrides();
3569  // If the rank of extract is greater than the rank of insert, we are likely
3570  // extracting a partial chunk of the vector inserted.
3571  if (extractOffsets.size() > insertOffsets.size())
3572  return failure();
3573  bool partialOverlap = false;
3574  bool disjoint = false;
3575  SmallVector<int64_t, 4> offsetDiffs;
3576  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3577  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3578  return failure();
3579  int64_t start = getElement(insertOffsets, dim);
3580  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3581  int64_t offset = getElement(extractOffsets, dim);
3582  int64_t size = getElement(extractSizes, dim);
3583  // Check if the start of the extract offset lies in the inserted interval.
3584  if (start <= offset && offset < end) {
3585  // If the extract interval overlaps but is not fully included we may
3586  // have a partial overlap that will prevent any folding.
3587  if (offset + size > end)
3588  partialOverlap = true;
3589  offsetDiffs.push_back(offset - start);
3590  continue;
3591  }
3592  disjoint = true;
3593  break;
3594  }
3595  // The extract element chunk is a subset of the insert element.
3596  if (!disjoint && !partialOverlap) {
3597  op.setOperand(insertOp.getSource());
3598  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3599  OpBuilder b(op.getContext());
3600  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3601  return success();
3602  }
3603  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3604  // in the insert chain.
3605  if (disjoint)
3606  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3607  else {
3608  // The extracted vector partially overlaps the inserted vector, so we
3609  // cannot fold.
3610  return failure();
3611  }
3612  }
3613  return failure();
3614 }
3615 
3616 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3617  if (getSourceVectorType() == getResult().getType())
3618  return getVector();
3619  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3620  return getResult();
3621  return {};
3622 }
3623 
3624 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3625  populateFromInt64AttrArray(getOffsets(), results);
3626 }
3627 
3628 namespace {
3629 
3630 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3631 // ConstantMaskOp.
3632 class StridedSliceConstantMaskFolder final
3633  : public OpRewritePattern<ExtractStridedSliceOp> {
3634 public:
3635   using OpRewritePattern::OpRewritePattern;
3636 
3637  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3638  PatternRewriter &rewriter) const override {
3639  // Return if 'extractStridedSliceOp' operand is not defined by a
3640  // ConstantMaskOp.
3641  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3642  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3643  if (!constantMaskOp)
3644  return failure();
3645  // Return if 'extractStridedSliceOp' has non-unit strides.
3646  if (extractStridedSliceOp.hasNonUnitStrides())
3647  return failure();
3648  // Gather constant mask dimension sizes.
3649  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3650  // Gather strided slice offsets and sizes.
3651  SmallVector<int64_t, 4> sliceOffsets;
3652  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3653  sliceOffsets);
3654  SmallVector<int64_t, 4> sliceSizes;
3655  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3656 
3657  // Compute slice of vector mask region.
3658  SmallVector<int64_t, 4> sliceMaskDimSizes;
3659  sliceMaskDimSizes.reserve(maskDimSizes.size());
3660  for (auto [maskDimSize, sliceOffset, sliceSize] :
3661  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3662  int64_t sliceMaskDimSize = std::max(
3663  static_cast<int64_t>(0),
3664  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3665  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3666  }
3667  // Add unchanged dimensions.
3668  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3669  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3670  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3671  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3672  // region is a conjunction of mask dim intervals).
3673  if (llvm::is_contained(sliceMaskDimSizes, 0))
3674  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3675 
3676  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3677  // region.
3678  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3679  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3680  sliceMaskDimSizes);
3681  return success();
3682  }
3683 };
3684 
3685 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3686 class StridedSliceSplatConstantFolder final
3687  : public OpRewritePattern<ExtractStridedSliceOp> {
3688 public:
3689   using OpRewritePattern::OpRewritePattern;
3690 
3691  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3692  PatternRewriter &rewriter) const override {
3693  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3694  // ConstantOp.
3695  Value sourceVector = extractStridedSliceOp.getVector();
3696  Attribute vectorCst;
3697  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3698  return failure();
3699 
3700  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3701  if (!splat)
3702  return failure();
3703 
3704  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3705  splat.getSplatValue<Attribute>());
3706  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3707  newAttr);
3708  return success();
3709  }
3710 };
3711 
3712 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3713 // ConstantOp.
3714 class StridedSliceNonSplatConstantFolder final
3715  : public OpRewritePattern<ExtractStridedSliceOp> {
3716 public:
3717   using OpRewritePattern::OpRewritePattern;
3718 
3719  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3720  PatternRewriter &rewriter) const override {
3721  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3722  // ConstantOp.
3723  Value sourceVector = extractStridedSliceOp.getVector();
3724  Attribute vectorCst;
3725  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3726  return failure();
3727 
3728  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3729  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3730  if (!dense || dense.isSplat())
3731  return failure();
3732 
3733  // TODO: Handle non-unit strides when they become available.
3734  if (extractStridedSliceOp.hasNonUnitStrides())
3735  return failure();
3736 
3737  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3738  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3739  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3740 
3741  VectorType sliceVecTy = extractStridedSliceOp.getType();
3742  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3743  int64_t sliceRank = sliceVecTy.getRank();
3744 
3745  // Expand offsets and sizes to match the vector rank.
3746  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3747  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3748 
3749  SmallVector<int64_t, 4> sizes(sourceShape);
3750  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3751 
3752  // Calculate the slice elements by enumerating all slice positions and
3753  // linearizing them. The enumeration order is lexicographic, which yields a
3754  // sequence of monotonically increasing linearized position indices.
3755  auto denseValuesBegin = dense.value_begin<Attribute>();
3756  SmallVector<Attribute> sliceValues;
3757  sliceValues.reserve(sliceVecTy.getNumElements());
3758  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3759  do {
3760  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3761  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3762  "Invalid index");
3763  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3764  } while (
3765  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3766 
3767  assert(static_cast<int64_t>(sliceValues.size()) ==
3768  sliceVecTy.getNumElements() &&
3769  "Invalid number of slice elements");
3770  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3771  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3772  newAttr);
3773  return success();
3774  }
3775 };
3776 
3777 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3778 // BroadcastOp(ExtractStridedSliceOp).
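// E.g., a sketch where the broadcast only adds a leading dimension:
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %e = vector.extract_strided_slice %b
//          {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
//          : vector<4x8xf32> to vector<2x8xf32>
// becomes
//   %e = vector.broadcast %src : vector<8xf32> to vector<2x8xf32>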
3779 class StridedSliceBroadcast final
3780  : public OpRewritePattern<ExtractStridedSliceOp> {
3781 public:
3782   using OpRewritePattern::OpRewritePattern;
3783 
3784  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3785  PatternRewriter &rewriter) const override {
3786  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3787  if (!broadcast)
3788  return failure();
3789  auto srcVecType =
3790  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3791  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3792  auto dstVecType = llvm::cast<VectorType>(op.getType());
3793  unsigned dstRank = dstVecType.getRank();
3794  unsigned rankDiff = dstRank - srcRank;
3795  // Check if the innermost dimensions of the source of the broadcast are the
3796  // same as the destination of the extract. If this is the case, we can just
3797  // use a broadcast, as the original dimensions are untouched.
3798  bool lowerDimMatch = true;
3799  for (unsigned i = 0; i < srcRank; i++) {
3800  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3801  lowerDimMatch = false;
3802  break;
3803  }
3804  }
3805  Value source = broadcast.getSource();
3806  // If the inner dimensions don't match, it means we need to extract from the
3807  // source of the original broadcast and then broadcast the extracted value.
3808  // We also need to handle degenerate cases where the source is effectively
3809  // just a single scalar.
3810  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3811  if (!lowerDimMatch && !isScalarSrc) {
3812  source = rewriter.create<ExtractStridedSliceOp>(
3813  op->getLoc(), source,
3814  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3815  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3816  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3817  }
3818  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3819  return success();
3820  }
3821 };
3822 
3823 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
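/// E.g. (illustrative):
///   %s = vector.splat %x : vector<8xf32>
///   %e = vector.extract_strided_slice %s {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8xf32> to vector<4xf32>
/// becomes
///   %e = vector.splat %x : vector<4xf32>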
3824 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3825 public:
3826   using OpRewritePattern::OpRewritePattern;
3827 
3828  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3829  PatternRewriter &rewriter) const override {
3830  auto splat = op.getVector().getDefiningOp<SplatOp>();
3831  if (!splat)
3832  return failure();
3833  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3834  return success();
3835  }
3836 };
3837 
3838 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3839 /// slice is contiguous, into extract and shape_cast.
3840 ///
3841 /// Example:
3842 /// Before:
3843 /// %1 = vector.extract_strided_slice %arg0 {
3844 /// offsets = [0, 0, 0, 0, 0],
3845 /// sizes = [1, 1, 1, 1, 8],
3846 /// strides = [1, 1, 1, 1, 1]
3847 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3848 /// After:
3849 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3850 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3851 /// %1 = vector.shape_cast %0
3852 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3853 ///
3854 class ContiguousExtractStridedSliceToExtract final
3855  : public OpRewritePattern<ExtractStridedSliceOp> {
3856 public:
3857   using OpRewritePattern::OpRewritePattern;
3858 
3859  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3860  PatternRewriter &rewriter) const override {
3861  if (op.hasNonUnitStrides())
3862  return failure();
3863  Value source = op.getOperand();
3864  auto sourceType = cast<VectorType>(source.getType());
3865  if (sourceType.isScalable() || sourceType.getRank() == 0)
3866  return failure();
3867 
3868  // Compute the number of offsets to pass to ExtractOp::build. That is the
3869  // difference between the source rank and the desired slice rank. We walk
3870  // the dimensions from innermost out, and stop when the next slice dimension
3871  // is not full-size.
3872  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3873  int numOffsets;
3874  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3875  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3876  break;
3877  }
3878 
3879  // If the created extract op would have no offsets, then this whole
3880  // extract_strided_slice is the identity and should have been handled by
3881  // other canonicalizations.
3882  if (numOffsets == 0)
3883  return failure();
3884 
3885  // If not even the inner-most dimension is full-size, this op can't be
3886  // rewritten as an ExtractOp.
3887  if (numOffsets == sourceType.getRank() &&
3888  static_cast<int>(sizes.size()) == sourceType.getRank())
3889  return failure();
3890 
3891  // The outer dimensions must have unit size.
3892  for (int i = 0; i < numOffsets; ++i) {
3893  if (sizes[i] != 1)
3894  return failure();
3895  }
3896 
3897  // Avoid generating slices that have leading unit dimensions. The shape_cast
3898  // op that we create below would otherwise go through the inefficient generic
3899  // fallback pattern (ShapeCastOpRewritePattern).
3900  while (sizes[numOffsets] == 1 &&
3901  numOffsets < static_cast<int>(sizes.size()) - 1) {
3902  ++numOffsets;
3903  }
3904 
3905  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
3906  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
3907  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
3908  extractOffsets);
3909  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
3910  return success();
3911  }
3912 };
3913 
3914 } // namespace
3915 
3916 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3917  RewritePatternSet &results, MLIRContext *context) {
3918  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
3919  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3920  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3921  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3922  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
3923  context);
3924 }
3925 
3926 //===----------------------------------------------------------------------===//
3927 // TransferReadOp
3928 //===----------------------------------------------------------------------===//
3929 
3930 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3931 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3932  VectorType vectorType, Value source,
3933  ValueRange indices, AffineMapAttr permutationMapAttr,
3934  /*optional*/ ArrayAttr inBoundsAttr) {
3935  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3936  Value padding = builder.create<arith::ConstantOp>(
3937  result.location, elemType, builder.getZeroAttr(elemType));
3938  build(builder, result, vectorType, source, indices, permutationMapAttr,
3939  padding, /*mask=*/Value(), inBoundsAttr);
3940 }
3941 
3942 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3943 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3944  VectorType vectorType, Value source,
3945  ValueRange indices, AffineMap permutationMap,
3946  std::optional<ArrayRef<bool>> inBounds) {
3947  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3948  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3949  ? builder.getBoolArrayAttr(inBounds.value())
3950  : builder.getBoolArrayAttr(
3951  SmallVector<bool>(vectorType.getRank(), false));
3952  build(builder, result, vectorType, source, indices, permutationMapAttr,
3953  inBoundsAttr);
3954 }
3955 
3956 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3957 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3958  VectorType vectorType, Value source,
3959  ValueRange indices, Value padding,
3960  std::optional<ArrayRef<bool>> inBounds) {
3961  AffineMap permutationMap = getTransferMinorIdentityMap(
3962  llvm::cast<ShapedType>(source.getType()), vectorType);
3963  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3964  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3965  ? builder.getBoolArrayAttr(inBounds.value())
3966  : builder.getBoolArrayAttr(
3967  SmallVector<bool>(vectorType.getRank(), false));
3968  build(builder, result, vectorType, source, indices, permutationMapAttr,
3969  padding,
3970  /*mask=*/Value(), inBoundsAttr);
3971 }
3972 
3973 /// 4. Builder that sets padding to zero and permutation map to
3974 /// 'getMinorIdentityMap'.
3975 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3976  VectorType vectorType, Value source,
3977  ValueRange indices,
3978  std::optional<ArrayRef<bool>> inBounds) {
3979  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3980  Value padding = builder.create<arith::ConstantOp>(
3981  result.location, elemType, builder.getZeroAttr(elemType));
3982  build(builder, result, vectorType, source, indices, padding, inBounds);
3983 }
3984 
3985 template <typename EmitFun>
3986 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3987  EmitFun emitOpError) {
3988  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3989  for (auto expr : permutationMap.getResults()) {
3990  auto dim = dyn_cast<AffineDimExpr>(expr);
3991  auto zero = dyn_cast<AffineConstantExpr>(expr);
3992  if (zero) {
3993  if (zero.getValue() != 0) {
3994  return emitOpError(
3995  "requires a projected permutation_map (at most one dim or the zero "
3996  "constant can appear in each result)");
3997  }
3998  continue;
3999  }
4000  if (!dim) {
4001  return emitOpError("requires a projected permutation_map (at most one "
4002  "dim or the zero constant can appear in each result)");
4003  }
4004  if (seen[dim.getPosition()]) {
4005  return emitOpError(
4006  "requires a permutation_map that is a permutation (found one dim "
4007  "used more than once)");
4008  }
4009  seen[dim.getPosition()] = true;
4010  }
4011  return success();
4012 }
4013 
4014 static LogicalResult
4015 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4016  VectorType vectorType, VectorType maskType,
4017  VectorType inferredMaskType, AffineMap permutationMap,
4018  ArrayAttr inBounds) {
4019  if (op->hasAttr("masked")) {
4020  return op->emitOpError("masked attribute has been removed. "
4021  "Use in_bounds instead.");
4022  }
4023 
4024  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4025  return op->emitOpError(
4026  "requires source to be a memref or ranked tensor type");
4027 
4028  auto elementType = shapedType.getElementType();
4029  DataLayout dataLayout = DataLayout::closest(op);
4030  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4031  // Memref or tensor has vector element type.
4032  unsigned sourceVecSize =
4033  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4034  vectorElementType.getShape().back();
4035  unsigned resultVecSize =
4036  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4037  vectorType.getShape().back();
4038  if (resultVecSize % sourceVecSize != 0)
4039  return op->emitOpError(
4040  "requires the bitwidth of the minor 1-D vector to be an integral "
4041  "multiple of the bitwidth of the minor 1-D vector of the source");
4042 
4043  unsigned sourceVecEltRank = vectorElementType.getRank();
4044  unsigned resultVecRank = vectorType.getRank();
4045  if (sourceVecEltRank > resultVecRank)
4046  return op->emitOpError(
4047  "requires source vector element and vector result ranks to match.");
4048  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4049  // Check that permutation map results match 'rankOffset' of vector type.
4050  if (permutationMap.getNumResults() != rankOffset)
4051  return op->emitOpError("requires a permutation_map with result dims of "
4052  "the same rank as the vector type");
4053 
4054  if (maskType)
4055  return op->emitOpError("does not support masks with vector element type");
4056  } else {
4057  // Memref or tensor has scalar element type.
4058  unsigned minorSize =
4059  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4060  unsigned resultVecSize =
4061  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4062  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4063  return op->emitOpError(
4064  "requires the bitwidth of the minor 1-D vector to be an integral "
4065  "multiple of the bitwidth of the source element type");
4066 
4067  // Check that permutation map results match rank of vector type.
4068  if (permutationMap.getNumResults() != vectorType.getRank())
4069  return op->emitOpError("requires a permutation_map with result dims of "
4070  "the same rank as the vector type");
4071  }
4072 
4073  if (permutationMap.getNumSymbols() != 0)
4074  return op->emitOpError("requires permutation_map without symbols");
4075 
4076  if (permutationMap.getNumInputs() != shapedType.getRank())
4077  return op->emitOpError("requires a permutation_map with input dims of the "
4078  "same rank as the source type");
4079 
4080  if (maskType && maskType != inferredMaskType)
4081  return op->emitOpError("inferred mask type (")
4082  << inferredMaskType << ") and mask operand type (" << maskType
4083  << ") don't match";
4084 
4085  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4086  return op->emitOpError("expects the in_bounds attr of same rank "
4087  "as permutation_map results: ")
4088  << AffineMapAttr::get(permutationMap)
4089  << " vs inBounds of size: " << inBounds.size();
4090 
4091  return success();
4092 }
4093 
4094 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4095  SmallVector<StringRef, 3> elidedAttrs;
4096  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4097  if (op.getPermutationMap().isMinorIdentity())
4098  elidedAttrs.push_back(op.getPermutationMapAttrName());
4099  // Elide in_bounds attribute if all dims are out-of-bounds.
4100  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4101  elidedAttrs.push_back(op.getInBoundsAttrName());
4102  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4103 }
4104 
4105 void TransferReadOp::print(OpAsmPrinter &p) {
4106   p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4107  if (getMask())
4108  p << ", " << getMask();
4109  printTransferAttrs(p, *this);
4110  p << " : " << getShapedType() << ", " << getVectorType();
4111 }
4112 
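// A worked example for the mask type inference below (shapes assumed for
// illustration): for a transfer of vector<2x4xf32> with permutation_map
// (d0, d1) -> (d1, d0), the inverse permutation composed with the vector shape
// gives [4, 2], so the inferred mask type is vector<4x2xi1>.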
4113 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4114  AffineMap permMap) {
4115  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4116  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4117   assert(invPermMap && "Inverse permutation map couldn't be computed");
4118  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4119 
4120  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4121  // 0-D mask into a single-element 1-D mask.
4122  if (maskShape.empty())
4123  maskShape.push_back(1);
4124 
4125  SmallVector<bool> scalableDims =
4126  applyPermutationMap(invPermMap, vecType.getScalableDims());
4127 
4128  return VectorType::get(maskShape, i1Type, scalableDims);
4129 }
4130 
4131 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4132  auto &builder = parser.getBuilder();
4133  SMLoc typesLoc;
4134  OpAsmParser::UnresolvedOperand sourceInfo;
4135   SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4136   OpAsmParser::UnresolvedOperand paddingInfo;
4137  SmallVector<Type, 2> types;
4138   OpAsmParser::UnresolvedOperand maskInfo;
4139   // Parsing with support for paddingValue.
4140  if (parser.parseOperand(sourceInfo) ||
4141  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4142  parser.parseComma() || parser.parseOperand(paddingInfo))
4143  return failure();
4144  ParseResult hasMask = parser.parseOptionalComma();
4145  if (hasMask.succeeded()) {
4146  if (parser.parseOperand(maskInfo))
4147  return failure();
4148  }
4149  if (parser.parseOptionalAttrDict(result.attributes) ||
4150  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4151  return failure();
4152  if (types.size() != 2)
4153  return parser.emitError(typesLoc, "requires two types");
4154  auto indexType = builder.getIndexType();
4155  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4156  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4157  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4158  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4159  if (!vectorType)
4160  return parser.emitError(typesLoc, "requires vector type");
4161  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4162  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4163  AffineMap permMap;
4164  if (!permMapAttr) {
4165  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4166  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4167  } else {
4168  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4169  }
4170  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4171  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4172  if (!inBoundsAttr) {
4173  result.addAttribute(inBoundsAttrName,
4174  builder.getBoolArrayAttr(
4175  SmallVector<bool>(permMap.getNumResults(), false)));
4176  }
4177  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4178  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4179  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4180  result.operands))
4181  return failure();
4182  if (hasMask.succeeded()) {
4183  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4184  return parser.emitError(
4185  maskInfo.location, "does not support masks with vector element type");
4186  if (vectorType.getRank() != permMap.getNumResults()) {
4187  return parser.emitError(typesLoc,
4188  "expected the same rank for the vector and the "
4189  "results of the permutation map");
4190  }
4191  // Instead of adding the mask type as an op type, compute it based on the
4192  // vector type and the permutation map (to keep the type signature small).
4193  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4194  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4195  return failure();
4196  }
4197  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4198  builder.getDenseI32ArrayAttr(
4199  {1, static_cast<int32_t>(indexInfo.size()), 1,
4200  static_cast<int32_t>(hasMask.succeeded())}));
4201  return parser.addTypeToList(vectorType, result.types);
4202 }
4203 
4204 LogicalResult TransferReadOp::verify() {
4205  // Consistency of elemental types in source and vector.
4206  ShapedType shapedType = getShapedType();
4207  VectorType vectorType = getVectorType();
4208  VectorType maskType = getMaskType();
4209  auto paddingType = getPadding().getType();
4210  auto permutationMap = getPermutationMap();
4211  VectorType inferredMaskType =
4212  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4213  : VectorType();
4214  auto sourceElementType = shapedType.getElementType();
4215 
4216  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4217  return emitOpError("requires ") << shapedType.getRank() << " indices";
4218 
4219  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4220  shapedType, vectorType, maskType,
4221  inferredMaskType, permutationMap, getInBounds())))
4222  return failure();
4223 
4224  if (auto sourceVectorElementType =
4225  llvm::dyn_cast<VectorType>(sourceElementType)) {
4226  // Source has vector element type.
4227  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4228  if (sourceVectorElementType != paddingType)
4229  return emitOpError(
4230  "requires source element type and padding type to match.");
4231 
4232  } else {
4233  // Check that 'paddingType' is valid to store in a vector type.
4234  if (!VectorType::isValidElementType(paddingType))
4235  return emitOpError("requires valid padding vector elemental type");
4236 
4237  // Check that padding type and vector element types match.
4238  if (paddingType != sourceElementType)
4239  return emitOpError(
4240  "requires formal padding and source of the same elemental type");
4241  }
4242 
4243  return verifyPermutationMap(permutationMap,
4244  [&](Twine t) { return emitOpError(t); });
4245 }
4246 
4247 // MaskableOpInterface methods.
4248 
4249 /// Returns the mask type expected by this operation. Mostly used for
4250 /// verification purposes. It requires the operation to be vectorized."
4251 Type TransferReadOp::getExpectedMaskType() {
4252  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4253 }
4254 
4255 template <typename TransferOp>
4256 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4257  // TODO: support more aggressive createOrFold on:
4258  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4259  if (op.getShapedType().isDynamicDim(indicesIdx))
4260  return false;
4261  Value index = op.getIndices()[indicesIdx];
4262  std::optional<int64_t> cstOp = getConstantIntValue(index);
4263  if (!cstOp.has_value())
4264  return false;
4265 
4266  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4267  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4268 
4269  return cstOp.value() + vectorSize <= sourceSize;
4270 }
4271 
4272 template <typename TransferOp>
4273 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4274  // TODO: support 0-d corner case.
4275  // TODO: Be less conservative.
4276  if (op.getTransferRank() == 0)
4277  return failure();
4278  AffineMap permutationMap = op.getPermutationMap();
4279  bool changed = false;
4280  SmallVector<bool, 4> newInBounds;
4281  newInBounds.reserve(op.getTransferRank());
4282  // Indices of non-broadcast dims, used when analyzing broadcast dims.
4283  SmallVector<unsigned> nonBcastDims;
4284 
4285  // 1. Process non-broadcast dims
4286  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4287  // 1.1. Already marked as in-bounds, nothing to see here.
4288  if (op.isDimInBounds(i)) {
4289  newInBounds.push_back(true);
4290  continue;
4291  }
4292  // 1.2. Currently out-of-bounds, check whether we can statically determine
4293  // it is inBounds.
4294  bool inBounds = false;
4295  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4296  if (dimExpr) {
4297  inBounds = isInBounds(op, /*resultIdx=*/i,
4298  /*indicesIdx=*/dimExpr.getPosition());
4299  nonBcastDims.push_back(i);
4300  }
4301 
4302  newInBounds.push_back(inBounds);
4303  // We commit the pattern if it is "more inbounds".
4304  changed |= inBounds;
4305  }
4306 
4307  // 2. Handle broadcast dims
4308  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4309  // "in bounds" as well.
4310  bool allNonBcastDimsInBounds = llvm::all_of(
4311  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4312  if (allNonBcastDimsInBounds) {
4313  for (size_t idx : permutationMap.getBroadcastDims()) {
4314  changed |= !newInBounds[idx];
4315  newInBounds[idx] = true;
4316  }
4317  }
4318 
4319  if (!changed)
4320  return failure();
4321  // OpBuilder is only used as a helper to build a BoolArrayAttr.
4322  OpBuilder b(op.getContext());
4323  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4324  return success();
4325 }
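// Illustrative sketch (not from the original source) of what
// foldTransferInBoundsAttribute can prove: with a constant index and static
// sizes, "index + vector size <= dim size" marks that dimension as in-bounds.
//
//   // Before: no in_bounds attribute, i.e. assumed out-of-bounds.
//   %0 = vector.transfer_read %A[%c0, %c2], %pad
//          : memref<8x8xf32>, vector<4xf32>
//   // After: 2 + 4 <= 8, so the single transfer dimension becomes in-bounds.
//   %0 = vector.transfer_read %A[%c0, %c2], %pad {in_bounds = [true]}
//          : memref<8x8xf32>, vector<4xf32>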
4326 
4327 template <typename TransferOp>
4328 static LogicalResult foldTransferFullMask(TransferOp op) {
4329  auto mask = op.getMask();
4330  if (!mask)
4331  return failure();
4332 
4333  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4334  return failure();
4335 
4336  op.getMaskMutable().clear();
4337  return success();
4338 }
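// Illustrative sketch (not from the original source): a mask that is provably
// all-true, e.g. a vector.constant_mask covering the full dimension sizes, is
// simply dropped from the transfer operation.
//
//   %m = vector.constant_mask [4] : vector<4xi1>
//   %0 = vector.transfer_read %A[%c0], %pad, %m : memref<?xf32>, vector<4xf32>
//   // folds to:
//   %0 = vector.transfer_read %A[%c0], %pad : memref<?xf32>, vector<4xf32>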
4339 
4340 /// ```
4341 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4342 /// : vector<1x4xf32>, tensor<4x4xf32>
4343 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4344 /// : tensor<4x4xf32>, vector<1x4xf32>
4345 /// ```
4346 /// -> Folds into
4347 /// ```
4348 /// %v0
4349 /// ```
4350 static Value foldRAW(TransferReadOp readOp) {
4351  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4352  return {};
4353  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4354  while (defWrite) {
4355  if (checkSameValueRAW(defWrite, readOp))
4356  return defWrite.getVector();
4357  if (!isDisjointTransferIndices(
4358  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4359  cast<VectorTransferOpInterface>(readOp.getOperation())))
4360  break;
4361  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4362  }
4363  return {};
4364 }
4365 
4366 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4367  if (Value vec = foldRAW(*this))
4368  return vec;
4369  /// transfer_read(memrefcast) -> transfer_read
4370  if (succeeded(foldTransferInBoundsAttribute(*this)))
4371  return getResult();
4372  if (succeeded(foldTransferFullMask(*this)))
4373  return getResult();
4374  if (succeeded(memref::foldMemRefCast(*this)))
4375  return getResult();
4376  if (succeeded(tensor::foldTensorCast(*this)))
4377  return getResult();
4378  return OpFoldResult();
4379 }
4380 
4381 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4382  return llvm::to_vector<4>(getVectorType().getShape());
4383 }
4384 
4385 void TransferReadOp::getEffects(
4386  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4387  &effects) {
4388  if (llvm::isa<MemRefType>(getShapedType()))
4389  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4390  SideEffects::DefaultResource::get());
4391 }
4392 
4393 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4394  if (hasPureTensorSemantics())
4395  return Speculation::Speculatable;
4396  return Speculation::NotSpeculatable;
4397 }
4398 
4399 namespace {
4400 /// Store to load forwarding for transfer operations with permutation maps.
4401 /// Even if the permutation maps are different we can still propagate the store
4402 /// into the load if the size of the dimensions read and written match. Then we
4403 /// can replace the transfer_read + transfer_write by vector.broadcast and
4404 /// vector.transpose.
4405 /// Example:
4406 /// ```
4407 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4408 /// {in_bounds = [true, true],
4409 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4410 /// vector<4x1xf32>, tensor<4x4x4xf32>
4411 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4412 /// {in_bounds = [true, true, true, true],
4413 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4414 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4415 /// ```
4416 /// To:
4417 /// ```
4418 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4419 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4420 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4421 /// ```
4422 struct TransferReadAfterWriteToBroadcast
4423  : public OpRewritePattern<TransferReadOp> {
4425 
4426  LogicalResult matchAndRewrite(TransferReadOp readOp,
4427  PatternRewriter &rewriter) const override {
4428  if (readOp.hasOutOfBoundsDim() ||
4429  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4430  return failure();
4431  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4432  if (!defWrite)
4433  return failure();
4434  // TODO: If the written transfer chunk is a superset of the read transfer
4435  // chunk we could do an extract_strided_slice.
4436  if (readOp.getTransferChunkAccessed() !=
4437  defWrite.getTransferChunkAccessed())
4438  return failure();
4439  // TODO: Support cases where a dim is explicitly written but implicitly
4440  // read (i.e., a unit dim that is rank reduced).
4441  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4442  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4443  return failure();
4444  if (readOp.getIndices() != defWrite.getIndices() ||
4445  readOp.getMask() != defWrite.getMask())
4446  return failure();
4447  Value vec = defWrite.getVector();
4448  // TODO: loop through the chain of transfer_write if we can prove that they
4449  // don't overlap with the transfer_read. This requires improving
4450  // `isDisjointTransferIndices` helper.
4451  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4452  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4453  AffineMap map = readMap.compose(writeMap);
4454  if (map.getNumResults() == 0)
4455  return failure();
4456  // Calculate the permutation to apply to go from the vector stored to the
4457  // vector read.
4458  SmallVector<unsigned> permutation;
4459  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4460  return failure();
4461 
4462  Location loc = readOp.getLoc();
4463  // Calculate the broadcast shape by applying the reverse permutation to the
4464  // final shape we want.
4465  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4466  SmallVector<int64_t> broadcastShape(destShape.size());
4467  SmallVector<bool> broadcastScalableFlags(destShape.size());
4468  for (const auto &pos : llvm::enumerate(permutation)) {
4469  broadcastShape[pos.value()] = destShape[pos.index()];
4470  broadcastScalableFlags[pos.value()] =
4471  readOp.getVectorType().getScalableDims()[pos.index()];
4472  }
4473  VectorType broadcastedType = VectorType::get(
4474  broadcastShape, defWrite.getVectorType().getElementType(),
4475  broadcastScalableFlags);
4476  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4477  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4478  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4479  transposePerm);
4480  return success();
4481  }
4482 };
4483 } // namespace
4484 
4485 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4486  MLIRContext *context) {
4487  results.add<TransferReadAfterWriteToBroadcast>(context);
4488 }
4489 
4490 //===----------------------------------------------------------------------===//
4491 // TransferWriteOp
4492 //===----------------------------------------------------------------------===//
4493 
4494 /// 1. Builder with type inference.
4495 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4496  Value vector, Value dest, ValueRange indices,
4497  AffineMapAttr permutationMapAttr,
4498  /*optional*/ Value mask,
4499  /*optional*/ ArrayAttr inBoundsAttr) {
4500  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4501  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4502  mask, inBoundsAttr);
4503 }
4504 
4505 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4506 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4507  Value vector, Value dest, ValueRange indices,
4508  AffineMapAttr permutationMapAttr,
4509  /*optional*/ ArrayAttr inBoundsAttr) {
4510  build(builder, result, vector, dest, indices, permutationMapAttr,
4511  /*mask=*/Value(), inBoundsAttr);
4512 }
4513 
4514 /// 3. Builder with type inference that sets an empty mask (variant without
4515 /// attrs)
4516 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4517  Value vector, Value dest, ValueRange indices,
4518  AffineMap permutationMap,
4519  std::optional<ArrayRef<bool>> inBounds) {
4520  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4521  auto inBoundsAttr =
4522  (inBounds && !inBounds.value().empty())
4523  ? builder.getBoolArrayAttr(inBounds.value())
4524  : builder.getBoolArrayAttr(SmallVector<bool>(
4525  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4526  build(builder, result, vector, dest, indices, permutationMapAttr,
4527  /*mask=*/Value(), inBoundsAttr);
4528 }
4529 
4530 /// 4. Builder with type inference that sets an empty mask and sets permutation
4531 /// map to 'getMinorIdentityMap'.
4532 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4533  Value vector, Value dest, ValueRange indices,
4534  std::optional<ArrayRef<bool>> inBounds) {
4535  auto vectorType = llvm::cast<VectorType>(vector.getType());
4536  AffineMap permutationMap = getTransferMinorIdentityMap(
4537  llvm::cast<ShapedType>(dest.getType()), vectorType);
4538  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4539 }
4540 
4541 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4542  OperationState &result) {
4543  auto &builder = parser.getBuilder();
4544  SMLoc typesLoc;
4545  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4546  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4547  SmallVector<Type, 2> types;
4548  OpAsmParser::UnresolvedOperand maskInfo;
4549  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4550  parser.parseOperand(sourceInfo) ||
4551  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4552  return failure();
4553  ParseResult hasMask = parser.parseOptionalComma();
4554  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4555  return failure();
4556  if (parser.parseOptionalAttrDict(result.attributes) ||
4557  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4558  return failure();
4559  if (types.size() != 2)
4560  return parser.emitError(typesLoc, "requires two types");
4561  auto indexType = builder.getIndexType();
4562  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4563  if (!vectorType)
4564  return parser.emitError(typesLoc, "requires vector type");
4565  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4566  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4567  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4568  auto permMapAttrName =
4569  TransferWriteOp::getPermutationMapAttrName(result.name);
4570  auto permMapAttr = result.attributes.get(permMapAttrName);
4571  AffineMap permMap;
4572  if (!permMapAttr) {
4573  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4574  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4575  } else {
4576  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4577  }
4578  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4579  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4580  if (!inBoundsAttr) {
4581  result.addAttribute(inBoundsAttrName,
4582  builder.getBoolArrayAttr(
4583  SmallVector<bool>(permMap.getNumResults(), false)));
4584  }
4585  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4586  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4587  parser.resolveOperands(indexInfo, indexType, result.operands))
4588  return failure();
4589  if (hasMask.succeeded()) {
4590  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4591  return parser.emitError(
4592  maskInfo.location, "does not support masks with vector element type");
4593  if (vectorType.getRank() != permMap.getNumResults()) {
4594  return parser.emitError(typesLoc,
4595  "expected the same rank for the vector and the "
4596  "results of the permutation map");
4597  }
4598  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4599  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4600  return failure();
4601  }
4602  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4603  builder.getDenseI32ArrayAttr(
4604  {1, 1, static_cast<int32_t>(indexInfo.size()),
4605  static_cast<int32_t>(hasMask.succeeded())}));
4606  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4607  parser.addTypeToList(shapedType, result.types));
4608 }
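// Illustrative example (not from the original source) of the assembly form
// parsed above; %v, %A, %t0 and the indices are hypothetical values.
//
//   vector.transfer_write %v, %A[%i, %j], %mask {in_bounds = [true, true]}
//       : vector<4x8xf32>, memref<?x?xf32>
//   // With a ranked tensor destination the op also produces the updated tensor:
//   %t1 = vector.transfer_write %v, %t0[%i, %j] : vector<4x8xf32>, tensor<?x?xf32>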
4609 
4611  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4612  if (getMask())
4613  p << ", " << getMask();
4614  printTransferAttrs(p, *this);
4615  p << " : " << getVectorType() << ", " << getShapedType();
4616 }
4617 
4618 LogicalResult TransferWriteOp::verify() {
4619  // Consistency of elemental types in shape and vector.
4620  ShapedType shapedType = getShapedType();
4621  VectorType vectorType = getVectorType();
4622  VectorType maskType = getMaskType();
4623  auto permutationMap = getPermutationMap();
4624  VectorType inferredMaskType =
4625  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4626  : VectorType();
4627 
4628  if (llvm::size(getIndices()) != shapedType.getRank())
4629  return emitOpError("requires ") << shapedType.getRank() << " indices";
4630 
4631  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4632  // as the semantics is unclear. This can be revisited later if necessary.
4633  if (hasBroadcastDim())
4634  return emitOpError("should not have broadcast dimensions");
4635 
4636  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4637  shapedType, vectorType, maskType,
4638  inferredMaskType, permutationMap, getInBounds())))
4639  return failure();
4640 
4641  return verifyPermutationMap(permutationMap,
4642  [&](Twine t) { return emitOpError(t); });
4643 }
4644 
4645 // MaskableOpInterface methods.
4646 
4647 /// Returns the mask type expected by this operation. Mostly used for
4648 /// verification purposes.
4649 Type TransferWriteOp::getExpectedMaskType() {
4650  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4651 }
4652 
4653 /// Fold:
4654 /// ```
4655 /// %t1 = ...
4656 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4657 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4658 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4659 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4660 /// ```
4661 ///
4662 /// into:
4663 ///
4664 /// ```
4665 /// %t0
4666 /// ```
4667 ///
4668 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4669 /// block argument or has side effects.
4670 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4671  ArrayRef<Attribute>,
4672  SmallVectorImpl<OpFoldResult> &results) {
4673  // TODO: support 0-d corner case.
4674  if (write.getTransferRank() == 0)
4675  return failure();
4676  auto rankedTensorType =
4677  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4678  // If not operating on tensors, bail.
4679  if (!rankedTensorType)
4680  return failure();
4681  // If no read, bail.
4682  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4683  if (!read)
4684  return failure();
4685  // TODO: support 0-d corner case.
4686  if (read.getTransferRank() == 0)
4687  return failure();
4688  // For now, only accept minor identity maps. Future: also accept cases where
4688  // the composition of the two maps is a minor identity.
4689  if (!read.getPermutationMap().isMinorIdentity() ||
4690  !write.getPermutationMap().isMinorIdentity())
4691  return failure();
4692  // Bail on mismatching ranks.
4693  if (read.getTransferRank() != write.getTransferRank())
4694  return failure();
4695  // Bail on potential out-of-bounds accesses.
4696  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4697  return failure();
4698  // Tensor types must be the same.
4699  if (read.getSource().getType() != rankedTensorType)
4700  return failure();
4701  // Vector types must be the same.
4702  if (read.getVectorType() != write.getVectorType())
4703  return failure();
4704  // Vector and Tensor shapes must match.
4705  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4706  return failure();
4707  // Bail if any read or write index is nonzero.
4708  auto isNotConstantZero = [](Value v) {
4709  auto cstOp = getConstantIntValue(v);
4710  return !cstOp.has_value() || cstOp.value() != 0;
4711  };
4712  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4713  llvm::any_of(write.getIndices(), isNotConstantZero))
4714  return failure();
4715  // Success.
4716  results.push_back(read.getSource());
4717  return success();
4718 }
4719 
4720 static bool checkSameValueWAR(vector::TransferReadOp read,
4721  vector::TransferWriteOp write) {
4722  return read.getSource() == write.getSource() &&
4723  read.getIndices() == write.getIndices() &&
4724  read.getPermutationMap() == write.getPermutationMap() &&
4725  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4726  !write.getMask();
4727 }
4728 /// Fold transfer_write write after read:
4729 /// ```
4730 /// %t0 = ...
4731 /// %v = vector.transfer_read %t0[%c0...] :
4732 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4733 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4734 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4735 /// ```
4736 ///
4737 /// into:
4738 ///
4739 /// ```
4740 /// %t0
4741 /// ```
4742 static LogicalResult foldWAR(TransferWriteOp write,
4743  SmallVectorImpl<OpFoldResult> &results) {
4744  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4745  return failure();
4746  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4747  if (!read)
4748  return failure();
4749 
4750  if (!checkSameValueWAR(read, write))
4751  return failure();
4752  results.push_back(read.getSource());
4753  return success();
4754 }
4755 
4756 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4757  SmallVectorImpl<OpFoldResult> &results) {
4758  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4759  return success();
4760  if (succeeded(foldWAR(*this, results)))
4761  return success();
4762  if (succeeded(foldTransferInBoundsAttribute(*this)))
4763  return success();
4764  if (succeeded(foldTransferFullMask(*this)))
4765  return success();
4766  return memref::foldMemRefCast(*this);
4767 }
4768 
4769 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4770  return llvm::to_vector<4>(getVectorType().getShape());
4771 }
4772 
4773 void TransferWriteOp::getEffects(
4774  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4775  &effects) {
4776  if (llvm::isa<MemRefType>(getShapedType()))
4777  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4778  SideEffects::DefaultResource::get());
4779 }
4780 
4781 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4782  if (hasPureTensorSemantics())
4783  return Speculation::Speculatable;
4784  return Speculation::NotSpeculatable;
4785 }
4786 
4787 namespace {
4788 /// Remove a dead transfer write from the SSA chain so that it can be
4789 /// eliminated by DCE.
4790 /// ```
4791 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4792 /// : vector<1x4xf32>, tensor<4x4xf32>
4793 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4794 /// : vector<1x4xf32>, tensor<4x4xf32>
4795 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4796 /// : vector<1x4xf32>, tensor<4x4xf32>
4797 /// ```
4798 ///
4799 /// into:
4800 ///
4801 /// ```
4802 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4803 /// : vector<1x4xf32>, tensor<4x4xf32>
4804 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4805 /// : vector<1x4xf32>, tensor<4x4xf32>
4806 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4807 /// : vector<1x4xf32>, tensor<4x4xf32>
4808 /// ```
4809 ///
4810 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4811 /// any other uses.
4812 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4813 public:
4815  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4816  PatternRewriter &rewriter) const override {
4817  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4818  return failure();
4819  vector::TransferWriteOp writeToModify = writeOp;
4820 
4821  auto defWrite =
4822  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4823  while (defWrite) {
4824  if (checkSameValueWAW(writeOp, defWrite)) {
4825  rewriter.modifyOpInPlace(writeToModify, [&]() {
4826  writeToModify.getSourceMutable().assign(defWrite.getSource());
4827  });
4828  return success();
4829  }
4830  if (!isDisjointTransferIndices(
4831  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4832  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4833  break;
4834  // If the previous write op doesn't have any other use, we can safely look
4835  // at the previous store to see if it can be removed.
4836  if (!defWrite->hasOneUse())
4837  break;
4838  writeToModify = defWrite;
4839  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4840  }
4841  return failure();
4842  }
4843 };
4844 
4845 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4846 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4847 /// overwritten and inserted into another tensor. After this rewrite, the
4848 /// operations bufferize in-place since all of them work on the same slice.
4849 ///
4850 /// For example:
4851 /// ```mlir
4852 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4853 /// : vector<8x16xf32>, tensor<8x16xf32>
4854 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4855 /// : tensor<8x16xf32> to tensor<?x?xf32>
4856 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4857 /// : tensor<?x?xf32> into tensor<27x37xf32>
4858 /// ```
4859 /// folds to
4860 /// ```mlir
4861 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4862 /// : tensor<27x37xf32> to tensor<?x?xf32>
4863 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4864 /// : vector<8x16xf32>, tensor<?x?xf32>
4865 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4866 /// : tensor<?x?xf32> into tensor<27x37xf32>
4867 /// ```
4868 struct SwapExtractSliceOfTransferWrite
4869  : public OpRewritePattern<tensor::InsertSliceOp> {
4870 public:
4872 
4873  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4874  PatternRewriter &rewriter) const override {
4875  if (!insertOp.hasUnitStride())
4876  return failure();
4877  auto extractOp =
4878  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4879  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4880  return failure();
4881  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4882  if (!transferOp || !transferOp->hasOneUse())
4883  return failure();
4884 
4885  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4886  // rank-reducing.
4887  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4888  return rewriter.notifyMatchFailure(insertOp,
4889  "use-def chain is rank-reducing");
4890  }
4891 
4892  // Fail if tensor::ExtractSliceOp has non-zero offset.
4893  if (!extractOp.hasZeroOffset()) {
4894  return rewriter.notifyMatchFailure(insertOp,
4895  "ExtractSliceOp has non-zero offset");
4896  }
4897 
4898  // Fail if vector::TransferWriteOp has non-zero offset.
4899  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4900  return getConstantIntValue(value) == static_cast<int64_t>(0);
4901  })) {
4902  return rewriter.notifyMatchFailure(insertOp,
4903  "TranferWriteOp has non-zero offset");
4904  }
4905 
4906  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4907  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4908  return rewriter.notifyMatchFailure(
4909  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4910  }
4911 
4912  for (auto [insertSize, extractSize] :
4913  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4914  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4915  return rewriter.notifyMatchFailure(
4916  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4917  }
4918  }
4919 
4920  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4921  assert(transferOp.getVectorType().hasStaticShape() &&
4922  "expected vector to have a static shape");
4923  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4924  SmallVector<int64_t> resultShape = applyPermutationMap(
4925  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4926  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4927  return rewriter.notifyMatchFailure(
4928  insertOp, "TransferWriteOp may not write the full tensor.");
4929  }
4930 
4931  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4932  // Set all in_bounds to false and let the folder infer them.
4933  SmallVector<bool> newInBounds(vectorShape.size(), false);
4934  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4935  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4936  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4937  insertOp.getMixedStrides());
4938  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4939  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4940  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4941  rewriter.getBoolArrayAttr(newInBounds));
4942  rewriter.modifyOpInPlace(insertOp, [&]() {
4943  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4944  });
4945  return success();
4946  }
4947 };
4948 
4949 } // namespace
4950 
4951 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4952  MLIRContext *context) {
4953  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4954 }
4955 
4956 //===----------------------------------------------------------------------===//
4957 // LoadOp
4958 //===----------------------------------------------------------------------===//
4959 
4960 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4961  VectorType vecTy,
4962  MemRefType memRefTy) {
4963  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
4964  // need any stride limitations.
4965  if (!vecTy.isScalable() &&
4966  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
4967  return success();
4968 
4969  if (!isLastMemrefDimUnitStride(memRefTy))
4970  return op->emitOpError("most minor memref dim must have unit stride");
4971  return success();
4972 }
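// Illustrative sketch (not from the original source): only the innermost
// memref dimension must have unit stride, unless the vector degenerates to a
// single element.
//
//   %0 = vector.load %A[%i] : memref<16xf32>, vector<4xf32>               // OK
//   %1 = vector.load %B[%i] : memref<16xf32, strided<[2]>>, vector<4xf32> // rejected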
4973 
4974 LogicalResult vector::LoadOp::verify() {
4975  VectorType resVecTy = getVectorType();
4976  MemRefType memRefTy = getMemRefType();
4977 
4978  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
4979  return failure();
4980 
4981  // Checks for vector memrefs.
4982  Type memElemTy = memRefTy.getElementType();
4983  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4984  if (memVecTy != resVecTy)
4985  return emitOpError("base memref and result vector types should match");
4986  memElemTy = memVecTy.getElementType();
4987  }
4988 
4989  if (resVecTy.getElementType() != memElemTy)
4990  return emitOpError("base and result element types should match");
4991  if (llvm::size(getIndices()) != memRefTy.getRank())
4992  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4993  return success();
4994 }
4995 
4996 OpFoldResult LoadOp::fold(FoldAdaptor) {
4997  if (succeeded(memref::foldMemRefCast(*this)))
4998  return getResult();
4999  return OpFoldResult();
5000 }
5001 
5002 //===----------------------------------------------------------------------===//
5003 // StoreOp
5004 //===----------------------------------------------------------------------===//
5005 
5006 LogicalResult vector::StoreOp::verify() {
5007  VectorType valueVecTy = getVectorType();
5008  MemRefType memRefTy = getMemRefType();
5009 
5010  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5011  return failure();
5012 
5013  // Checks for vector memrefs.
5014  Type memElemTy = memRefTy.getElementType();
5015  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5016  if (memVecTy != valueVecTy)
5017  return emitOpError(
5018  "base memref and valueToStore vector types should match");
5019  memElemTy = memVecTy.getElementType();
5020  }
5021 
5022  if (valueVecTy.getElementType() != memElemTy)
5023  return emitOpError("base and valueToStore element type should match");
5024  if (llvm::size(getIndices()) != memRefTy.getRank())
5025  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5026  return success();
5027 }
5028 
5029 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5030  SmallVectorImpl<OpFoldResult> &results) {
5031  return memref::foldMemRefCast(*this);
5032 }
5033 
5034 //===----------------------------------------------------------------------===//
5035 // MaskedLoadOp
5036 //===----------------------------------------------------------------------===//
5037 
5038 LogicalResult MaskedLoadOp::verify() {
5039  VectorType maskVType = getMaskVectorType();
5040  VectorType passVType = getPassThruVectorType();
5041  VectorType resVType = getVectorType();
5042  MemRefType memType = getMemRefType();
5043 
5044  if (resVType.getElementType() != memType.getElementType())
5045  return emitOpError("base and result element type should match");
5046  if (llvm::size(getIndices()) != memType.getRank())
5047  return emitOpError("requires ") << memType.getRank() << " indices";
5048  if (resVType.getShape() != maskVType.getShape())
5049  return emitOpError("expected result shape to match mask shape");
5050  if (resVType != passVType)
5051  return emitOpError("expected pass_thru of same type as result type");
5052  return success();
5053 }
5054 
5055 namespace {
5056 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5057 public:
5059  LogicalResult matchAndRewrite(MaskedLoadOp load,
5060  PatternRewriter &rewriter) const override {
5061  switch (getMaskFormat(load.getMask())) {
5062  case MaskFormat::AllTrue:
5063  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5064  load, load.getType(), load.getBase(), load.getIndices());
5065  return success();
5066  case MaskFormat::AllFalse:
5067  rewriter.replaceOp(load, load.getPassThru());
5068  return success();
5069  case MaskFormat::Unknown:
5070  return failure();
5071  }
5072  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5073  }
5074 };
5075 } // namespace
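// Illustrative sketch (not from the original source) of the MaskedLoadFolder
// canonicalization: an all-true mask turns the masked load into a plain
// vector.load, while an all-false mask is replaced by the pass-through value.
//
//   %0 = vector.maskedload %A[%c0], %mask, %pass
//          : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>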
5076 
5077 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5078  MLIRContext *context) {
5079  results.add<MaskedLoadFolder>(context);
5080 }
5081 
5082 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5083  if (succeeded(memref::foldMemRefCast(*this)))
5084  return getResult();
5085  return OpFoldResult();
5086 }
5087 
5088 //===----------------------------------------------------------------------===//
5089 // MaskedStoreOp
5090 //===----------------------------------------------------------------------===//
5091 
5092 LogicalResult MaskedStoreOp::verify() {
5093  VectorType maskVType = getMaskVectorType();
5094  VectorType valueVType = getVectorType();
5095  MemRefType memType = getMemRefType();
5096 
5097  if (valueVType.getElementType() != memType.getElementType())
5098  return emitOpError("base and valueToStore element type should match");
5099  if (llvm::size(getIndices()) != memType.getRank())
5100  return emitOpError("requires ") << memType.getRank() << " indices";
5101  if (valueVType.getShape() != maskVType.getShape())
5102  return emitOpError("expected valueToStore shape to match mask shape");
5103  return success();
5104 }
5105 
5106 namespace {
5107 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5108 public:
5110  LogicalResult matchAndRewrite(MaskedStoreOp store,
5111  PatternRewriter &rewriter) const override {
5112  switch (getMaskFormat(store.getMask())) {
5113  case MaskFormat::AllTrue:
5114  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5115  store, store.getValueToStore(), store.getBase(), store.getIndices());
5116  return success();
5117  case MaskFormat::AllFalse:
5118  rewriter.eraseOp(store);
5119  return success();
5120  case MaskFormat::Unknown:
5121  return failure();
5122  }
5123  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5124  }
5125 };
5126 } // namespace
5127 
5128 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5129  MLIRContext *context) {
5130  results.add<MaskedStoreFolder>(context);
5131 }
5132 
5133 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5134  SmallVectorImpl<OpFoldResult> &results) {
5135  return memref::foldMemRefCast(*this);
5136 }
5137 
5138 //===----------------------------------------------------------------------===//
5139 // GatherOp
5140 //===----------------------------------------------------------------------===//
5141 
5142 LogicalResult GatherOp::verify() {
5143  VectorType indVType = getIndexVectorType();
5144  VectorType maskVType = getMaskVectorType();
5145  VectorType resVType = getVectorType();
5146  ShapedType baseType = getBaseType();
5147 
5148  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5149  return emitOpError("requires base to be a memref or ranked tensor type");
5150 
5151  if (resVType.getElementType() != baseType.getElementType())
5152  return emitOpError("base and result element type should match");
5153  if (llvm::size(getIndices()) != baseType.getRank())
5154  return emitOpError("requires ") << baseType.getRank() << " indices";
5155  if (resVType.getShape() != indVType.getShape())
5156  return emitOpError("expected result dim to match indices dim");
5157  if (resVType.getShape() != maskVType.getShape())
5158  return emitOpError("expected result dim to match mask dim");
5159  if (resVType != getPassThruVectorType())
5160  return emitOpError("expected pass_thru of same type as result type");
5161  return success();
5162 }
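// Illustrative example (not from the original source) of a gather accepted by
// the verifier: the index, mask and pass-through vectors all share the result
// shape.
//
//   %0 = vector.gather %A[%c0][%idxs], %mask, %pass
//          : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
//            into vector<16xf32>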
5163 
5164 // MaskableOpInterface methods.
5165 
5166 /// Returns the mask type expected by this operation. Mostly used for
5167 /// verification purposes. It requires the operation to be vectorized."
5168 Type GatherOp::getExpectedMaskType() {
5169  auto vecType = this->getIndexVectorType();
5170  return VectorType::get(vecType.getShape(),
5171  IntegerType::get(vecType.getContext(), /*width=*/1),
5172  vecType.getScalableDims());
5173 }
5174 
5175 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5176  return llvm::to_vector<4>(getVectorType().getShape());
5177 }
5178 
5179 namespace {
5180 class GatherFolder final : public OpRewritePattern<GatherOp> {
5181 public:
5183  LogicalResult matchAndRewrite(GatherOp gather,
5184  PatternRewriter &rewriter) const override {
5185  switch (getMaskFormat(gather.getMask())) {
5186  case MaskFormat::AllTrue:
5187  return failure(); // no unmasked equivalent
5188  case MaskFormat::AllFalse:
5189  rewriter.replaceOp(gather, gather.getPassThru());
5190  return success();
5191  case MaskFormat::Unknown:
5192  return failure();
5193  }
5194  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5195  }
5196 };
5197 } // namespace
5198 
5199 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5200  MLIRContext *context) {
5201  results.add<GatherFolder>(context);
5202 }
5203 
5204 //===----------------------------------------------------------------------===//
5205 // ScatterOp
5206 //===----------------------------------------------------------------------===//
5207 
5208 LogicalResult ScatterOp::verify() {
5209  VectorType indVType = getIndexVectorType();
5210  VectorType maskVType = getMaskVectorType();
5211  VectorType valueVType = getVectorType();
5212  MemRefType memType = getMemRefType();
5213 
5214  if (valueVType.getElementType() != memType.getElementType())
5215  return emitOpError("base and valueToStore element type should match");
5216  if (llvm::size(getIndices()) != memType.getRank())
5217  return emitOpError("requires ") << memType.getRank() << " indices";
5218  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5219  return emitOpError("expected valueToStore dim to match indices dim");
5220  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5221  return emitOpError("expected valueToStore dim to match mask dim");
5222  return success();
5223 }
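// Illustrative example (not from the original source) of a scatter accepted
// by the verifier: the leading dimensions of the value, index and mask
// vectors match.
//
//   vector.scatter %A[%c0][%idxs], %mask, %val
//       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>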
5224 
5225 namespace {
5226 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5227 public:
5229  LogicalResult matchAndRewrite(ScatterOp scatter,
5230  PatternRewriter &rewriter) const override {
5231  switch (getMaskFormat(scatter.getMask())) {
5232  case MaskFormat::AllTrue:
5233  return failure(); // no unmasked equivalent
5234  case MaskFormat::AllFalse:
5235  rewriter.eraseOp(scatter);
5236  return success();
5237  case MaskFormat::Unknown:
5238  return failure();
5239  }
5240  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5241  }
5242 };
5243 } // namespace
5244 
5245 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5246  MLIRContext *context) {
5247  results.add<ScatterFolder>(context);
5248 }
5249 
5250 //===----------------------------------------------------------------------===//
5251 // ExpandLoadOp
5252 //===----------------------------------------------------------------------===//
5253 
5254 LogicalResult ExpandLoadOp::verify() {
5255  VectorType maskVType = getMaskVectorType();
5256  VectorType passVType = getPassThruVectorType();
5257  VectorType resVType = getVectorType();
5258  MemRefType memType = getMemRefType();
5259 
5260  if (resVType.getElementType() != memType.getElementType())
5261  return emitOpError("base and result element type should match");
5262  if (llvm::size(getIndices()) != memType.getRank())
5263  return emitOpError("requires ") << memType.getRank() << " indices";
5264  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5265  return emitOpError("expected result dim to match mask dim");
5266  if (resVType != passVType)
5267  return emitOpError("expected pass_thru of same type as result type");
5268  return success();
5269 }
5270 
5271 namespace {
5272 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5273 public:
5275  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5276  PatternRewriter &rewriter) const override {
5277  switch (getMaskFormat(expand.getMask())) {
5278  case MaskFormat::AllTrue:
5279  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5280  expand, expand.getType(), expand.getBase(), expand.getIndices());
5281  return success();
5282  case MaskFormat::AllFalse:
5283  rewriter.replaceOp(expand, expand.getPassThru());
5284  return success();
5285  case MaskFormat::Unknown:
5286  return failure();
5287  }
5288  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5289  }
5290 };
5291 } // namespace
5292 
5293 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5294  MLIRContext *context) {
5295  results.add<ExpandLoadFolder>(context);
5296 }
5297 
5298 //===----------------------------------------------------------------------===//
5299 // CompressStoreOp
5300 //===----------------------------------------------------------------------===//
5301 
5302 LogicalResult CompressStoreOp::verify() {
5303  VectorType maskVType = getMaskVectorType();
5304  VectorType valueVType = getVectorType();
5305  MemRefType memType = getMemRefType();
5306 
5307  if (valueVType.getElementType() != memType.getElementType())
5308  return emitOpError("base and valueToStore element type should match");
5309  if (llvm::size(getIndices()) != memType.getRank())
5310  return emitOpError("requires ") << memType.getRank() << " indices";
5311  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5312  return emitOpError("expected valueToStore dim to match mask dim");
5313  return success();
5314 }
5315 
5316 namespace {
5317 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5318 public:
5320  LogicalResult matchAndRewrite(CompressStoreOp compress,
5321  PatternRewriter &rewriter) const override {
5322  switch (getMaskFormat(compress.getMask())) {
5323  case MaskFormat::AllTrue:
5324  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5325  compress, compress.getValueToStore(), compress.getBase(),
5326  compress.getIndices());
5327  return success();
5328  case MaskFormat::AllFalse:
5329  rewriter.eraseOp(compress);
5330  return success();
5331  case MaskFormat::Unknown:
5332  return failure();
5333  }
5334  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5335  }
5336 };
5337 } // namespace
5338 
5339 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5340  MLIRContext *context) {
5341  results.add<CompressStoreFolder>(context);
5342 }
5343 
5344 //===----------------------------------------------------------------------===//
5345 // ShapeCastOp
5346 //===----------------------------------------------------------------------===//
5347 
5348 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5349  SetIntRangeFn setResultRanges) {
5350  setResultRanges(getResult(), argRanges.front());
5351 }
5352 
5353 /// Returns true if each element of 'a' is equal to the product of a contiguous
5354 /// sequence of the elements of 'b'. Returns false otherwise.
5355 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5356  unsigned rankA = a.size();
5357  unsigned rankB = b.size();
5358  assert(rankA < rankB);
5359 
5360  auto isOne = [](int64_t v) { return v == 1; };
5361 
5362  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5363  // casted to a 0-d vector.
5364  if (rankA == 0 && llvm::all_of(b, isOne))
5365  return true;
5366 
5367  unsigned i = 0;
5368  unsigned j = 0;
5369  while (i < rankA && j < rankB) {
5370  int64_t dimA = a[i];
5371  int64_t dimB = 1;
5372  while (dimB < dimA && j < rankB)
5373  dimB *= b[j++];
5374  if (dimA != dimB)
5375  break;
5376  ++i;
5377 
5378  // Handle the case when trailing dimensions are of size 1.
5379  // Include them into the contiguous sequence.
5380  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5381  i = rankA;
5382  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5383  j = rankB;
5384  }
5385 
5386  return i == rankA && j == rankB;
5387 }
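// Illustrative sketch (not from the original source): each dim of the
// lower-rank shape must be the product of a contiguous run of dims of the
// higher-rank shape.
//
//   %0 = vector.shape_cast %a : vector<2x3xf32> to vector<6xf32>     // 6 == 2*3
//   %1 = vector.shape_cast %b : vector<2x4x2xf32> to vector<4x4xf32> // invalid:
//   // 4 cannot be formed from a contiguous prefix (2, then 2*4 = 8).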
5388 
5389 static LogicalResult verifyVectorShapeCast(Operation *op,
5390  VectorType sourceVectorType,
5391  VectorType resultVectorType) {
5392  // Check that element type is the same.
5393  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5394  return op->emitOpError("source/result vectors must have same element type");
5395  auto sourceShape = sourceVectorType.getShape();
5396  auto resultShape = resultVectorType.getShape();
5397 
5398  // Check that product of source dim sizes matches product of result dim sizes.
5399  int64_t sourceDimProduct = std::accumulate(
5400  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5401  int64_t resultDimProduct = std::accumulate(
5402  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5403  if (sourceDimProduct != resultDimProduct)
5404  return op->emitOpError("source/result number of elements must match");
5405 
5406  // Check the expanding/contracting rank cases.
5407  unsigned sourceRank = sourceVectorType.getRank();
5408  unsigned resultRank = resultVectorType.getRank();
5409  if (sourceRank < resultRank) {
5410  if (!isValidShapeCast(sourceShape, resultShape))
5411  return op->emitOpError("invalid shape cast");
5412  } else if (sourceRank > resultRank) {
5413  if (!isValidShapeCast(resultShape, sourceShape))
5414  return op->emitOpError("invalid shape cast");
5415  }
5416 
5417  // Check that (non-)scalability is preserved
5418  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5419  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5420  if (sourceNScalableDims != resultNScalableDims)
5421  return op->emitOpError("different number of scalable dims at source (")
5422  << sourceNScalableDims << ") and result (" << resultNScalableDims
5423  << ")";
5425 
5426  return success();
5427 }
5428 
5429 LogicalResult ShapeCastOp::verify() {
5430  auto sourceVectorType =
5431  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5432  auto resultVectorType =
5433  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5434 
5435  // Check if source/result are of vector type.
5436  if (sourceVectorType && resultVectorType)
5437  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5438 
5439  return success();
5440 }
5441 
5442 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5443  // No-op shape cast.
5444  if (getSource().getType() == getResult().getType())
5445  return getSource();
5446 
5447  // Canceling shape casts.
5448  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5449  if (getResult().getType() == otherOp.getSource().getType())
5450  return otherOp.getSource();
5451 
5452  // Only allows valid transitive folding.
5453  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5454  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5455  if (srcType.getRank() < resultType.getRank()) {
5456  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5457  return {};
5458  } else if (srcType.getRank() > resultType.getRank()) {
5459  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5460  return {};
5461  } else {
5462  return {};
5463  }
5464 
5465  setOperand(otherOp.getSource());
5466  return getResult();
5467  }
5468 
5469  // Cancelling broadcast and shape cast ops.
5470  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5471  if (bcastOp.getSourceType() == getType())
5472  return bcastOp.getSource();
5473  }
5474 
5475  return {};
5476 }
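// Illustrative sketch (not from the original source) of the broadcast /
// shape_cast cancellation handled above:
//
//   %b = vector.broadcast %v : vector<4xf32> to vector<1x4xf32>
//   %c = vector.shape_cast %b : vector<1x4xf32> to vector<4xf32>
//   // %c folds to %v since the broadcast source type equals the result type.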
5477 
5478 namespace {
5479 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5480 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5481 public:
5483 
5484  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5485  PatternRewriter &rewriter) const override {
5486  auto constantOp =
5487  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5488  if (!constantOp)
5489  return failure();
5490  // Only handle splat for now.
5491  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5492  if (!dense)
5493  return failure();
5494  auto newAttr =
5495  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5496  dense.getSplatValue<Attribute>());
5497  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5498  return success();
5499  }
5500 };
5501 
5502 /// Helper function that computes a new vector type based on the input vector
5503 /// type by removing the trailing one dims:
5504 ///
5505 /// vector<4x1x1xi1> --> vector<4xi1>
5506 ///
5507 static VectorType trimTrailingOneDims(VectorType oldType) {
5508  ArrayRef<int64_t> oldShape = oldType.getShape();
5509  ArrayRef<int64_t> newShape = oldShape;
5510 
5511  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5512  ArrayRef<bool> newScalableDims = oldScalableDims;
5513 
5514  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5515  newShape = newShape.drop_back(1);
5516  newScalableDims = newScalableDims.drop_back(1);
5517  }
5518 
5519  // Make sure we have at least 1 dimension.
5520  // TODO: Add support for 0-D vectors.
5521  if (newShape.empty()) {
5522  newShape = oldShape.take_back();
5523  newScalableDims = oldScalableDims.take_back();
5524  }
5525 
5526  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5527 }
5528 
5529 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5530 ///
5531 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5532 /// dimension. If the input vector comes from `vector.create_mask` for which
5533 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5534 /// to fold shape_cast into create_mask.
5535 ///
5536 /// BEFORE:
5537 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5538 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5539 /// AFTER:
5540 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5541 class ShapeCastCreateMaskFolderTrailingOneDim final
5542  : public OpRewritePattern<ShapeCastOp> {
5543 public:
5545 
5546  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5547  PatternRewriter &rewriter) const override {
5548  Value shapeOpSrc = shapeOp->getOperand(0);
5549  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5550  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5551  if (!createMaskOp && !constantMaskOp)
5552  return failure();
5553 
5554  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5555  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5556 
5557  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5558  if (newVecType != shapeOpResTy)
5559  return failure();
5560 
5561  auto numDimsToDrop =
5562  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5563 
5564  // No unit dims to drop
5565  if (!numDimsToDrop)
5566  return failure();
5567 
5568  if (createMaskOp) {
5569  auto maskOperands = createMaskOp.getOperands();
5570  auto numMaskOperands = maskOperands.size();
5571 
5572  // Check every mask dim size to see whether it can be dropped
5573  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5574  --i) {
5575  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5576  if (!constant || (constant.value() != 1))
5577  return failure();
5578  }
5579  SmallVector<Value> newMaskOperands =
5580  maskOperands.drop_back(numDimsToDrop);
5581 
5582  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5583  newMaskOperands);
5584  return success();
5585  }
5586 
5587  if (constantMaskOp) {
5588  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5589  auto numMaskOperands = maskDimSizes.size();
5590 
5591  // Check every mask dim size to see whether it can be dropped
5592  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5593  --i) {
5594  if (maskDimSizes[i] != 1)
5595  return failure();
5596  }
5597 
5598  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5599  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5600  newMaskOperands);
5601  return success();
5602  }
5603 
5604  return failure();
5605  }
5606 };
5607 
5608 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5609 /// This only applies when the shape of the broadcast source
5610 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5611 /// reshape is expressive enough to capture the result in a single op), or
5612 /// 2. has the same element count as the shape cast result.
5613 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5614 public:
5616 
5617  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5618  PatternRewriter &rewriter) const override {
5619  auto broadcastOp =
5620  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5621  if (!broadcastOp)
5622  return failure();
5623 
5624  ArrayRef<int64_t> broadcastSourceShape;
5625  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5626  broadcastSourceShape = srcType.getShape();
5627  ArrayRef<int64_t> shapeCastTargetShape =
5628  shapeCastOp.getResultVectorType().getShape();
5629 
5630  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5631  // with a broadcast to the final shape.
5632  if (broadcastSourceShape ==
5633  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5634  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5635  shapeCastOp, shapeCastOp.getResultVectorType(),
5636  broadcastOp.getSource());
5637  return success();
5638  }
5639 
5640  // Otherwise, if the final result has the same element count, we can replace
5641  // with a shape cast.
5642  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5643  if (srcType.getNumElements() ==
5644  shapeCastOp.getResultVectorType().getNumElements()) {
5645  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5646  shapeCastOp, shapeCastOp.getResultVectorType(),
5647  broadcastOp.getSource());
5648  return success();
5649  }
5650  }
5651 
5652  return failure();
5653  }
5654 };
5655 
5656 } // namespace
5657 
5658 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5659  MLIRContext *context) {
5660  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5661  ShapeCastBroadcastFolder>(context);
5662 }
5663 
5664 //===----------------------------------------------------------------------===//
5665 // VectorBitCastOp
5666 //===----------------------------------------------------------------------===//
5667 
5668 LogicalResult BitCastOp::verify() {
5669  auto sourceVectorType = getSourceVectorType();
5670  auto resultVectorType = getResultVectorType();
5671 
5672  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5673  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5674  return emitOpError("dimension size mismatch at: ") << i;
5675  }
5676 
5677  DataLayout dataLayout = DataLayout::closest(*this);
5678  auto sourceElementBits =
5679  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5680  auto resultElementBits =
5681  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5682 
5683  if (sourceVectorType.getRank() == 0) {
5684  if (sourceElementBits != resultElementBits)
5685  return emitOpError("source/result bitwidth of the 0-D vector element "
5686  "types must be equal");
5687  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5688  resultElementBits * resultVectorType.getShape().back()) {
5689  return emitOpError(
5690  "source/result bitwidth of the minor 1-D vectors must be equal");
5691  }
5692 
5693  return success();
5694 }
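// Editorial note (not part of the source): the verifier above only requires
// that the leading dimensions match and that the minor 1-D slices have equal
// total bitwidth. For example (illustrative):
//
//   vector.bitcast %v : vector<4x8xi8> to vector<4x2xi32>   // 8*8 == 2*32 bits
//
// is valid, while vector<4x8xi8> to vector<4x3xi32> (64 vs. 96 bits in the
// minor dimension) would be rejected.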
5695 
5696 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5697  // Nop cast.
5698  if (getSource().getType() == getResult().getType())
5699  return getSource();
5700 
5701  // Canceling bitcasts.
5702  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5703  if (getResult().getType() == otherOp.getSource().getType())
5704  return otherOp.getSource();
5705 
5706  setOperand(otherOp.getSource());
5707  return getResult();
5708  }
5709 
5710  Attribute sourceConstant = adaptor.getSource();
5711  if (!sourceConstant)
5712  return {};
5713 
5714  Type srcElemType = getSourceVectorType().getElementType();
5715  Type dstElemType = getResultVectorType().getElementType();
5716 
5717  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5718  if (floatPack.isSplat()) {
5719  auto splat = floatPack.getSplatValue<FloatAttr>();
5720 
5721  // Casting fp16 into fp32.
5722  if (srcElemType.isF16() && dstElemType.isF32()) {
5723  uint32_t bits = static_cast<uint32_t>(
5724  splat.getValue().bitcastToAPInt().getZExtValue());
5725  // Duplicate the 16-bit pattern.
5726  bits = (bits << 16) | (bits & 0xffff);
5727  APInt intBits(32, bits);
5728  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5729  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5730  }
5731  }
5732  }
5733 
5734  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5735  if (intPack.isSplat()) {
5736  auto splat = intPack.getSplatValue<IntegerAttr>();
5737 
5738  if (llvm::isa<IntegerType>(dstElemType)) {
5739  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5740  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5741 
5742  // Casting to a larger integer bit width.
5743  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5744  APInt intBits = splat.getValue().zext(dstBitWidth);
5745 
5746  // Duplicate the lower width element.
5747  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5748  intBits = (intBits << srcBitWidth) | intBits;
5749  return DenseElementsAttr::get(getResultVectorType(), intBits);
5750  }
5751  }
5752  }
5753  }
5754 
5755  return {};
5756 }
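// Editorial note (not part of the source): a worked example of the
// splat-widening fold above, with illustrative values.
//
//   %c = arith.constant dense<1> : vector<8xi8>
//   %b = vector.bitcast %c : vector<8xi8> to vector<2xi32>
//
// folds to a splat constant whose 32-bit pattern repeats the 8-bit value:
//
//   arith.constant dense<0x01010101> : vector<2xi32>   // i.e. 16843009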
5757 
5758 //===----------------------------------------------------------------------===//
5759 // TypeCastOp
5760 //===----------------------------------------------------------------------===//
5761 
5762 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5763  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5764  SmallVector<int64_t, 8> res(memRefType.getShape());
5765  if (vectorType)
5766  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5767  return res;
5768 }
5769 
5770 /// Build the canonical memRefType with a single vector.
5771 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5772 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5773  Value source) {
5774  result.addOperands(source);
5775  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5776  VectorType vectorType =
5777  VectorType::get(extractShape(memRefType),
5778                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5779  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5780  memRefType.getMemorySpace()));
5781 }
5782 
5783 LogicalResult TypeCastOp::verify() {
5784  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5785  if (!canonicalType.getLayout().isIdentity())
5786  return emitOpError("expects operand to be a memref with identity layout");
5787  if (!getResultMemRefType().getLayout().isIdentity())
5788  return emitOpError("expects result to be a memref with identity layout");
5789  if (getResultMemRefType().getMemorySpace() !=
5790  getMemRefType().getMemorySpace())
5791  return emitOpError("expects result in same memory space");
5792 
5793  auto sourceType = getMemRefType();
5794  auto resultType = getResultMemRefType();
5795  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5796      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5797  return emitOpError(
5798  "expects result and operand with same underlying scalar type: ")
5799  << resultType;
5800  if (extractShape(sourceType) != extractShape(resultType))
5801  return emitOpError(
5802  "expects concatenated result and operand shapes to be equal: ")
5803  << resultType;
5804  return success();
5805 }
5806 
5807 //===----------------------------------------------------------------------===//
5808 // TransposeOp
5809 //===----------------------------------------------------------------------===//
5810 
5811 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5812  Value vector, ArrayRef<int64_t> permutation) {
5813  VectorType vt = llvm::cast<VectorType>(vector.getType());
5814  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5815  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5816  for (unsigned i = 0; i < permutation.size(); ++i) {
5817  transposedShape[i] = vt.getShape()[permutation[i]];
5818  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5819  }
5820 
5821  result.addOperands(vector);
5822  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5823  transposedScalableDims));
5824  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5825  builder.getDenseI64ArrayAttr(permutation));
5826 }
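// Editorial note (not part of the source): a usage sketch for the build method
// above. The helper below is hypothetical and only illustrates the overload
// that infers the transposed result type.
static Value buildTranspose2D(OpBuilder &b, Location loc, Value v) {
  // For v : vector<2x3xf32>, the inferred result type is vector<3x2xf32>,
  // since result dim i is source dim permutation[i].
  SmallVector<int64_t> perm = {1, 0};
  return b.create<vector::TransposeOp>(loc, v, perm).getResult();
}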
5827 
5828 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5829  // Eliminate splat constant transpose ops.
5830  if (auto attr =
5831  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5832  if (attr.isSplat())
5833  return attr.reshape(getResultVectorType());
5834 
5835  // Eliminate identity transpose ops. This happens when the dimensions of the
5836  // input vector remain in their original order after the transpose operation.
5837  ArrayRef<int64_t> perm = getPermutation();
5838 
5839  // Check if the permutation of the dimensions contains sequential values:
5840  // {0, 1, 2, ...}.
5841  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5842  if (perm[i] != i)
5843  return {};
5844  }
5845 
5846  return getVector();
5847 }
5848 
5849 LogicalResult vector::TransposeOp::verify() {
5850  VectorType vectorType = getSourceVectorType();
5851  VectorType resultType = getResultVectorType();
5852  int64_t rank = resultType.getRank();
5853  if (vectorType.getRank() != rank)
5854  return emitOpError("vector result rank mismatch: ") << rank;
5855  // Verify transposition array.
5856  ArrayRef<int64_t> perm = getPermutation();
5857  int64_t size = perm.size();
5858  if (rank != size)
5859  return emitOpError("transposition length mismatch: ") << size;
5860  SmallVector<bool, 8> seen(rank, false);
5861  for (const auto &ta : llvm::enumerate(perm)) {
5862  if (ta.value() < 0 || ta.value() >= rank)
5863  return emitOpError("transposition index out of range: ") << ta.value();
5864  if (seen[ta.value()])
5865  return emitOpError("duplicate position index: ") << ta.value();
5866  seen[ta.value()] = true;
5867  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5868  return emitOpError("dimension size mismatch at: ") << ta.value();
5869  }
5870  return success();
5871 }
5872 
5873 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5874  return llvm::to_vector<4>(getResultVectorType().getShape());
5875 }
5876 
5877 namespace {
5878 
5879 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
5880 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5881 public:
5882   using OpRewritePattern::OpRewritePattern;
5883 
5884  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5885  PatternRewriter &rewriter) const override {
5886  // Composes two permutations: result[i] = permutation1[permutation2[i]].
5887  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5888  ArrayRef<int64_t> permutation2) {
5889  SmallVector<int64_t, 4> result;
5890  for (auto index : permutation2)
5891  result.push_back(permutation1[index]);
5892  return result;
5893  };
5894 
5895  // Return if the input of 'transposeOp' is not defined by another transpose.
5896  vector::TransposeOp parentTransposeOp =
5897  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5898  if (!parentTransposeOp)
5899  return failure();
5900 
5901  SmallVector<int64_t, 4> permutation = composePermutations(
5902  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5903  // Replace 'transposeOp' with a new transpose operation.
5904  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5905  transposeOp, transposeOp.getResult().getType(),
5906  parentTransposeOp.getVector(), permutation);
5907  return success();
5908  }
5909 };
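// Editorial note (not part of the source): a worked example of the composition
// above, result[i] = permutation1[permutation2[i]]. With a parent permutation
// [1, 2, 0] followed by [2, 0, 1], the composition is
// [perm1[2], perm1[0], perm1[1]] = [0, 1, 2], i.e. the two transposes cancel
// and the resulting identity permutation is later removed by TransposeOp::fold.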
5910 
5911 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
5912 struct FoldTransposedScalarBroadcast final
5913  : public OpRewritePattern<vector::TransposeOp> {
5914   using OpRewritePattern::OpRewritePattern;
5915 
5916  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5917  PatternRewriter &rewriter) const override {
5918  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5919  if (!bcastOp)
5920  return failure();
5921 
5922  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5923  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5924  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5925  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5926  return success();
5927  }
5928 
5929  return failure();
5930  }
5931 };
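// Editorial note (not part of the source): an illustrative example of the
// scalar-broadcast fold above.
//
//   %b = vector.broadcast %s : f32 to vector<4x8xf32>
//   %t = vector.transpose %b, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
//
// becomes:
//
//   %t = vector.broadcast %s : f32 to vector<8x4xf32>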
5932 
5933 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5934 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5935 public:
5936   using OpRewritePattern::OpRewritePattern;
5937 
5938  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5939  PatternRewriter &rewriter) const override {
5940  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5941  if (!splatOp)
5942  return failure();
5943 
5944  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5945  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5946  return success();
5947  }
5948 };
5949 
5950 /// Folds transpose(create_mask) into a new transposed create_mask.
5951 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5952 public:
5953   using OpRewritePattern::OpRewritePattern;
5954 
5955  LogicalResult matchAndRewrite(TransposeOp transpOp,
5956  PatternRewriter &rewriter) const override {
5957  Value transposeSrc = transpOp.getVector();
5958  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5959  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5960  if (!createMaskOp && !constantMaskOp)
5961  return failure();
5962 
5963  // Get the transpose permutation and apply it to the vector.create_mask or
5964  // vector.constant_mask operands.
5965  ArrayRef<int64_t> permutation = transpOp.getPermutation();
5966 
5967  if (createMaskOp) {
5968  auto maskOperands = createMaskOp.getOperands();
5969  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
5970  applyPermutationToVector(newOperands, permutation);
5971 
5972  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5973  transpOp, transpOp.getResultVectorType(), newOperands);
5974  return success();
5975  }
5976 
5977  // ConstantMaskOp case.
5978  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5979  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
5980 
5981  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5982  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
5983  return success();
5984  }
5985 };
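// Editorial note (not part of the source): an illustrative example of the
// create_mask case above, with hypothetical index values %a and %b.
//
//   %m = vector.create_mask %a, %b : vector<4x8xi1>
//   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
//
// becomes:
//
//   %t = vector.create_mask %b, %a : vector<8x4xi1>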
5986 
5987 } // namespace
5988 
5989 void vector::TransposeOp::getCanonicalizationPatterns(
5990  RewritePatternSet &results, MLIRContext *context) {
5991  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5992  TransposeFolder, FoldTransposeSplat>(context);
5993 }
5994 
5995 //===----------------------------------------------------------------------===//
5996 // ConstantMaskOp
5997 //===----------------------------------------------------------------------===//
5998 
5999 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6000  VectorType type, ConstantMaskKind kind) {
6001  assert(kind == ConstantMaskKind::AllTrue ||
6002  kind == ConstantMaskKind::AllFalse);
6003  build(builder, result, type,
6004  kind == ConstantMaskKind::AllTrue
6005  ? type.getShape()
6006  : SmallVector<int64_t>(type.getRank(), 0));
6007 }
6008 
6009 LogicalResult ConstantMaskOp::verify() {
6010  auto resultType = llvm::cast<VectorType>(getResult().getType());
6011  // Check the corner case of 0-D vectors first.
6012  if (resultType.getRank() == 0) {
6013  if (getMaskDimSizes().size() != 1)
6014  return emitError("array attr must have length 1 for 0-D vectors");
6015  auto dim = getMaskDimSizes()[0];
6016  if (dim != 0 && dim != 1)
6017  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6018  return success();
6019  }
6020 
6021  // Verify that array attr size matches the rank of the vector result.
6022  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6023  return emitOpError(
6024  "must specify array attr of size equal vector result rank");
6025  // Verify that each array attr element is in bounds of corresponding vector
6026  // result dimension size.
6027  auto resultShape = resultType.getShape();
6028  auto resultScalableDims = resultType.getScalableDims();
6029  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6030  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6031  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6032  return emitOpError(
6033  "array attr of size out of bounds of vector result dimension size");
6034  if (resultScalableDims[index] && maskDimSize != 0 &&
6035  maskDimSize != resultShape[index])
6036  return emitOpError(
6037  "only supports 'none set' or 'all set' scalable dimensions");
6038  }
6039  // Verify that if one mask dim size is zero, then all of them must be zero
6040  // (because the mask region is a conjunction of each mask dimension interval).
6041  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6042  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6043  if (anyZeros && !allZeros)
6044  return emitOpError("expected all mask dim sizes to be zeros, "
6045  "as a result of conjunction with zero mask dim");
6046  return success();
6047 }
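// Editorial note (not part of the source): per the rules above (illustrative):
//
//   vector.constant_mask [3, 2] : vector<4x3xi1>     // valid
//   vector.constant_mask [0, 0] : vector<4x3xi1>     // valid (all false)
//   vector.constant_mask [0, 2] : vector<4x3xi1>     // rejected: one zero dim
//                                                    // forces all dims to zero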
6048 
6049 bool ConstantMaskOp::isAllOnesMask() {
6050  auto resultType = getVectorType();
6051  // Check the corner case of 0-D vectors first.
6052  if (resultType.getRank() == 0) {
6053  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6054  return getMaskDimSizes()[0] == 1;
6055  }
6056  for (const auto [resultSize, maskDimSize] :
6057  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6058  if (maskDimSize < resultSize)
6059  return false;
6060  }
6061  return true;
6062 }
6063 
6064 //===----------------------------------------------------------------------===//
6065 // CreateMaskOp
6066 //===----------------------------------------------------------------------===//
6067 
6068 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6069  VectorType type,
6070  ArrayRef<OpFoldResult> mixedOperands) {
6071  SmallVector<Value> operands =
6072  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6073  build(builder, result, type, operands);
6074 }
6075 
6076 LogicalResult CreateMaskOp::verify() {
6077  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6078  // Verify that an operand was specified for each result vector dimension.
6079  if (vectorType.getRank() == 0) {
6080  if (getNumOperands() != 1)
6081  return emitOpError(
6082  "must specify exactly one operand for 0-D create_mask");
6083  } else if (getNumOperands() !=
6084  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6085  return emitOpError(
6086  "must specify an operand for each result vector dimension");
6087  }
6088  return success();
6089 }
6090 
6091 namespace {
6092 
6093 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6094 ///
6095 /// Ex 1:
6096 /// %c2 = arith.constant 2 : index
6097 /// %c3 = arith.constant 3 : index
6098 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6099 /// Becomes:
6100 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6101 ///
6102 /// Ex 2:
6103 /// %c_neg_1 = arith.constant -1 : index
6104 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6105 /// becomes:
6106 /// vector.constant_mask [0] : vector<[8]xi1>
6107 ///
6108 /// Ex 3:
6109 /// %c8 = arith.constant 8 : index
6110 /// %c16 = arith.constant 16 : index
6111 /// %0 = vector.vscale
6112 /// %1 = arith.muli %0, %c16 : index
6113 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6114 /// becomes:
6115 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6116 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6117 public:
6118   using OpRewritePattern::OpRewritePattern;
6119 
6120  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6121  PatternRewriter &rewriter) const override {
6122  VectorType maskType = createMaskOp.getVectorType();
6123  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6124  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6125 
6126  // Special case: Rank zero shape.
6127  constexpr std::array<int64_t, 1> rankZeroShape{1};
6128  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6129  if (maskType.getRank() == 0) {
6130  maskTypeDimSizes = rankZeroShape;
6131  maskTypeDimScalableFlags = rankZeroScalableDims;
6132  }
6133 
6134  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6135  // collect the `constantDims` (for the ConstantMaskOp).
6136  SmallVector<int64_t, 4> constantDims;
6137  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6138  if (auto intSize = getConstantIntValue(dimSize)) {
6139  // Constant value.
6140  // If the mask dim is non-scalable this can be any value.
6141  // If the mask dim is scalable only zero (all-false) is supported.
6142  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6143  return failure();
6144  constantDims.push_back(*intSize);
6145  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6146  // Constant vscale multiple (e.g. 4 x vscale).
6147  // Must be all-true to fold to a ConstantMask.
6148  if (vscaleMultiplier < maskTypeDimSizes[i])
6149  return failure();
6150  constantDims.push_back(*vscaleMultiplier);
6151  } else {
6152  return failure();
6153  }
6154  }
6155 
6156  // Clamp values to constant_mask bounds.
6157  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6158  value = std::clamp<int64_t>(value, 0, maskDimSize);
6159 
6160  // If one of the dim sizes is zero, set all dims to zero.
6161  if (llvm::is_contained(constantDims, 0))
6162  constantDims.assign(constantDims.size(), 0);
6163 
6164  // Replace 'createMaskOp' with ConstantMaskOp.
6165  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6166  constantDims);
6167  return success();
6168  }
6169 };
6170 
6171 } // namespace
6172 
6173 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6174  MLIRContext *context) {
6175  results.add<CreateMaskFolder>(context);
6176 }
6177 
6178 //===----------------------------------------------------------------------===//
6179 // MaskOp
6180 //===----------------------------------------------------------------------===//
6181 
6182 void MaskOp::build(
6183  OpBuilder &builder, OperationState &result, Value mask,
6184  Operation *maskableOp,
6185  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6186  assert(maskRegionBuilder &&
6187  "builder callback for 'maskRegion' must be present");
6188 
6189  result.addOperands(mask);
6190  OpBuilder::InsertionGuard guard(builder);
6191  Region *maskRegion = result.addRegion();
6192  builder.createBlock(maskRegion);
6193  maskRegionBuilder(builder, maskableOp);
6194 }
6195 
6196 void MaskOp::build(
6197  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6198  Value mask, Operation *maskableOp,
6199  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6200  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6201  maskRegionBuilder);
6202 }
6203 
6204 void MaskOp::build(
6205  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6206  Value mask, Value passthru, Operation *maskableOp,
6207  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6208  build(builder, result, mask, maskableOp, maskRegionBuilder);
6209  if (passthru)
6210  result.addOperands(passthru);
6211  result.addTypes(resultTypes);
6212 }
6213 
6214 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6215  // Create the op region.
6216  result.regions.reserve(1);
6217  Region &maskRegion = *result.addRegion();
6218 
6219  auto &builder = parser.getBuilder();
6220 
6221  // Parse all the operands.
6222  OpAsmParser::UnresolvedOperand mask;
6223  if (parser.parseOperand(mask))
6224  return failure();
6225 
6226  // Optional passthru operand.
6227  OpAsmParser::UnresolvedOperand passthru;
6228  ParseResult parsePassthru = parser.parseOptionalComma();
6229  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6230  return failure();
6231 
6232  // Parse op region.
6233  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6234  return failure();
6235 
6236  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6237 
6238  // Parse the optional attribute list.
6239  if (parser.parseOptionalAttrDict(result.attributes))
6240  return failure();
6241 
6242  // Parse all the types.
6243  Type maskType;
6244  if (parser.parseColonType(maskType))
6245  return failure();
6246 
6247  SmallVector<Type> resultTypes;
6248  if (parser.parseOptionalArrowTypeList(resultTypes))
6249  return failure();
6250  result.types.append(resultTypes);
6251 
6252  // Resolve operands.
6253  if (parser.resolveOperand(mask, maskType, result.operands))
6254  return failure();
6255 
6256  if (parsePassthru.succeeded())
6257  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6258  return failure();
6259 
6260  return success();
6261 }
6262 
6263 void MaskOp::print(OpAsmPrinter &p) {
6264  p << " " << getMask();
6265  if (getPassthru())
6266  p << ", " << getPassthru();
6267 
6268  // Print single masked operation and skip terminator.
6269  p << " { ";
6270  Block *singleBlock = &getMaskRegion().getBlocks().front();
6271  if (singleBlock && !singleBlock->getOperations().empty())
6272  p.printCustomOrGenericOp(&singleBlock->front());
6273  p << " }";
6274 
6275  p.printOptionalAttrDict(getOperation()->getAttrs());
6276 
6277  p << " : " << getMask().getType();
6278  if (getNumResults() > 0)
6279  p << " -> " << getResultTypes();
6280 }
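// Editorial note (not part of the source): with the custom printer above, a
// masked read with a passthru value prints roughly as (illustrative):
//
//   %0 = vector.mask %m, %pt { vector.transfer_read %A[%i], %pad
//          : memref<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>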
6281 
6282 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6283  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6284  MaskOp>::ensureTerminator(region, builder, loc);
6285  // Keep the default yield terminator if the number of masked operations is
6286  // not the expected one. This case will trigger a verification failure.
6287  Block &block = region.front();
6288  if (block.getOperations().size() != 2)
6289  return;
6290 
6291  // Replace default yield terminator with a new one that returns the results
6292  // from the masked operation.
6293  OpBuilder opBuilder(builder.getContext());
6294  Operation *maskedOp = &block.front();
6295  Operation *oldYieldOp = &block.back();
6296  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6297 
6298  // Empty vector.mask op.
6299  if (maskedOp == oldYieldOp)
6300  return;
6301 
6302  opBuilder.setInsertionPoint(oldYieldOp);
6303  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6304  oldYieldOp->dropAllReferences();
6305  oldYieldOp->erase();
6306 }
6307 
6308 LogicalResult MaskOp::verify() {
6309  // Structural checks.
6310  Block &block = getMaskRegion().getBlocks().front();
6311  if (block.getOperations().empty())
6312  return emitOpError("expects a terminator within the mask region");
6313 
6314  unsigned numMaskRegionOps = block.getOperations().size();
6315  if (numMaskRegionOps > 2)
6316  return emitOpError("expects only one operation to mask");
6317 
6318  // Terminator checks.
6319  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6320  if (!terminator)
6321  return emitOpError("expects a terminator within the mask region");
6322 
6323  if (terminator->getNumOperands() != getNumResults())
6324  return emitOpError(
6325  "expects number of results to match mask region yielded values");
6326 
6327  // Empty vector.mask. Nothing else to check.
6328  if (numMaskRegionOps == 1)
6329  return success();
6330 
6331  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6332  if (!maskableOp)
6333  return emitOpError("expects a MaskableOpInterface within the mask region");
6334 
6335  // Result checks.
6336  if (maskableOp->getNumResults() != getNumResults())
6337  return emitOpError("expects number of results to match maskable operation "
6338  "number of results");
6339 
6340  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6341  return emitOpError(
6342  "expects result type to match maskable operation result type");
6343 
6344  if (llvm::count_if(maskableOp->getResultTypes(),
6345  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6346  return emitOpError("multiple vector results not supported");
6347 
6348  // Mask checks.
6349  Type expectedMaskType = maskableOp.getExpectedMaskType();
6350  if (getMask().getType() != expectedMaskType)
6351  return emitOpError("expects a ")
6352  << expectedMaskType << " mask for the maskable operation";
6353 
6354  // Passthru checks.
6355  Value passthru = getPassthru();
6356  if (passthru) {
6357  if (!maskableOp.supportsPassthru())
6358  return emitOpError(
6359  "doesn't expect a passthru argument for this maskable operation");
6360 
6361  if (maskableOp->getNumResults() != 1)
6362  return emitOpError("expects result when passthru argument is provided");
6363 
6364  if (passthru.getType() != maskableOp->getResultTypes()[0])
6365  return emitOpError("expects passthru type to match result type");
6366  }
6367 
6368  return success();
6369 }
6370 
6371 /// Folds vector.mask ops with an all-true mask.
6372 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6373  SmallVectorImpl<OpFoldResult> &results) {
6374  MaskFormat maskFormat = getMaskFormat(getMask());
6375  if (isEmpty())
6376  return failure();
6377 
6378  if (maskFormat != MaskFormat::AllTrue)
6379  return failure();
6380 
6381  // Move maskable operation outside of the `vector.mask` region.
6382  Operation *maskableOp = getMaskableOp();
6383  maskableOp->dropAllUses();
6384  maskableOp->moveBefore(getOperation());
6385 
6386  llvm::append_range(results, maskableOp->getResults());
6387  return success();
6388 }
6389 
6390 // Elides empty vector.mask operations with or without return values.
6391 // Propagates the values yielded by the vector.yield terminator, if any, or
6392 // erases the op otherwise.
6393 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6394   using OpRewritePattern::OpRewritePattern;
6395 
6396  LogicalResult matchAndRewrite(MaskOp maskOp,
6397  PatternRewriter &rewriter) const override {
6398  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6399  if (maskingOp.getMaskableOp())
6400  return failure();
6401 
6402  if (!maskOp.isEmpty())
6403  return failure();
6404 
6405  Block *block = maskOp.getMaskBlock();
6406  auto terminator = cast<vector::YieldOp>(block->front());
6407  if (terminator.getNumOperands() == 0)
6408  rewriter.eraseOp(maskOp);
6409  else
6410  rewriter.replaceOp(maskOp, terminator.getOperands());
6411 
6412  return success();
6413  }
6414 };
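// Editorial note (not part of the source): an empty vector.mask that only
// yields a value, e.g. (illustrative)
//
//   %0 = vector.mask %m { vector.yield %a : vector<8xf32> }
//          : vector<8xi1> -> vector<8xf32>
//
// is replaced by %a directly; an empty mask op without results is erased.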
6415 
6416 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6417  MLIRContext *context) {
6418  results.add<ElideEmptyMaskOp>(context);
6419 }
6420 
6421 // MaskingOpInterface definitions.
6422 
6423 /// Returns the operation masked by this 'vector.mask'.
6424 Operation *MaskOp::getMaskableOp() {
6425  Block *block = getMaskBlock();
6426  if (block->getOperations().size() < 2)
6427  return nullptr;
6428 
6429  return &block->front();
6430 }
6431 
6432 /// Returns true if 'vector.mask' has a passthru value.
6433 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6434 
6435 //===----------------------------------------------------------------------===//
6436 // ScanOp
6437 //===----------------------------------------------------------------------===//
6438 
6439 LogicalResult ScanOp::verify() {
6440  VectorType srcType = getSourceType();
6441  VectorType initialType = getInitialValueType();
6442  // Check reduction dimension < rank.
6443  int64_t srcRank = srcType.getRank();
6444  int64_t reductionDim = getReductionDim();
6445  if (reductionDim >= srcRank)
6446  return emitOpError("reduction dimension ")
6447  << reductionDim << " has to be less than " << srcRank;
6448 
6449  // Check that rank(initial_value) = rank(src) - 1.
6450  int64_t initialValueRank = initialType.getRank();
6451  if (initialValueRank != srcRank - 1)
6452  return emitOpError("initial value rank ")
6453  << initialValueRank << " has to be equal to " << srcRank - 1;
6454 
6455  // Check shapes of initial value and src.
6456  ArrayRef<int64_t> srcShape = srcType.getShape();
6457  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6458  SmallVector<int64_t> expectedShape;
6459  for (int i = 0; i < srcRank; i++) {
6460  if (i != reductionDim)
6461  expectedShape.push_back(srcShape[i]);
6462  }
6463  if (!llvm::equal(initialValueShapes, expectedShape)) {
6464  return emitOpError("incompatible input/initial value shapes");
6465  }
6466 
6467  // Verify supported reduction kind.
6468  Type eltType = getDestType().getElementType();
6469  if (!isSupportedCombiningKind(getKind(), eltType))
6470  return emitOpError("unsupported reduction type ")
6471  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6472  << "'";
6473 
6474  return success();
6475 }
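// Editorial note (not part of the source): a well-formed scan per the checks
// above (illustrative) reduces along dim 1 of a 4x8 source, so the initial
// value drops that dimension:
//
//   %r:2 = vector.scan <add>, %src, %init {inclusive = true, reduction_dim = 1 : i64}
//            : vector<4x8xf32>, vector<4xf32>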
6476 
6477 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6478     RewritePatternSet &patterns, PatternBenefit benefit) {
6479  patterns
6480  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6481  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6482  StridedSliceConstantMaskFolder, TransposeFolder>(
6483  patterns.getContext(), benefit);
6484 }
6485 
6486 //===----------------------------------------------------------------------===//
6487 // SplatOp
6488 //===----------------------------------------------------------------------===//
6489 
6490 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6491  auto constOperand = adaptor.getInput();
6492  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6493  return {};
6494 
6495  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6496  return SplatElementsAttr::get(getType(), {constOperand});
6497 }
6498 
6499 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6500  SetIntRangeFn setResultRanges) {
6501  setResultRanges(getResult(), argRanges.front());
6502 }
6503 
6504 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6505  CombiningKind kind, Value v1, Value acc,
6506  arith::FastMathFlagsAttr fastmath,
6507  Value mask) {
6508  Type t1 = getElementTypeOrSelf(v1.getType());
6509  Type tAcc = getElementTypeOrSelf(acc.getType());
6510  Value result;
6511 
6512  switch (kind) {
6513  case CombiningKind::ADD:
6514  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6515  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6516  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6517  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6518  else
6519  llvm_unreachable("invalid value types for ADD reduction");
6520  break;
6521  case CombiningKind::AND:
6522  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6523  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6524  break;
6525  case CombiningKind::MAXNUMF:
6526  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6527  "expected float values");
6528  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6529  break;
6530  case CombiningKind::MAXIMUMF:
6531  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6532  "expected float values");
6533  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6534  break;
6535  case CombiningKind::MINNUMF:
6536  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6537  "expected float values");
6538  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6539  break;
6540  case CombiningKind::MINIMUMF:
6541  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6542  "expected float values");
6543  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6544  break;
6545  case CombiningKind::MAXSI:
6546  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6547  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6548  break;
6549  case CombiningKind::MINSI:
6550  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6551  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6552  break;
6553  case CombiningKind::MAXUI:
6554  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6555  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6556  break;
6557  case CombiningKind::MINUI:
6558  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6559  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6560  break;
6561  case CombiningKind::MUL:
6562  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6563  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6564  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6565  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6566  else
6567  llvm_unreachable("invalid value types for MUL reduction");
6568  break;
6569  case CombiningKind::OR:
6570  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6571  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6572  break;
6573  case CombiningKind::XOR:
6574  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6575  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6576  break;
6577  };
6578 
6579  assert(result && "unknown CombiningKind");
6580  return selectPassthru(b, mask, result, acc);
6581 }
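// Editorial note (not part of the source): a usage sketch for the helper
// above. The wrapper below is hypothetical; with no mask the call simply
// materializes `arith.addf %x, %acc` for f32 operands, and with a mask the
// result is additionally routed through the arith.select created by
// selectPassthru below.
static Value buildMaskedAdd(OpBuilder &b, Location loc, Value x, Value acc,
                            Value mask) {
  return makeArithReduction(b, loc, CombiningKind::ADD, x, acc,
                            /*fastmath=*/nullptr, mask);
}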
6582 
6583 //===----------------------------------------------------------------------===//
6584 // Vector Masking Utilities
6585 //===----------------------------------------------------------------------===//
6586 
6587 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6588 /// as masked operation.
6589 static void createMaskOpRegion(OpBuilder &builder,
6590  Operation *maskableOp) {
6591  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6592  Block *insBlock = builder.getInsertionBlock();
6593  // Create a block and move the op to that block.
6594  insBlock->getOperations().splice(
6595  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6596  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6597 }
6598 
6599 /// Creates a vector.mask operation around a maskable operation. Returns the
6600 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6601 /// the maskable operation itself.
6602 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6603  Operation *maskableOp, Value mask,
6604  Value passthru) {
6605  if (!mask)
6606  return maskableOp;
6607  if (passthru)
6608  return builder.create<MaskOp>(maskableOp->getLoc(),
6609  maskableOp->getResultTypes(), mask, passthru,
6610  maskableOp, createMaskOpRegion);
6611  return builder.create<MaskOp>(maskableOp->getLoc(),
6612  maskableOp->getResultTypes(), mask, maskableOp,
6613  createMaskOpRegion);
6614 }
6615 
6616 /// Creates a vector select operation that picks values from `newValue` or
6617 /// `passthru` for each result vector lane based on `mask`. This utility is used
6618 /// to propagate the pass-thru value of vector.mask or for cases where only the
6619 /// pass-thru value propagation is needed. VP intrinsics do not support
6620 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
6621 /// are usually able to match op + select patterns and fold them into native
6622 /// target instructions.
6623 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6624  Value newValue, Value passthru) {
6625  if (!mask)
6626  return newValue;
6627 
6628  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6629  mask, newValue, passthru);
6630 }
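// Editorial note (not part of the source): selectPassthru simply materializes
// a lane-wise select, e.g. (illustrative)
//
//   %r = arith.select %mask, %new, %passthru : vector<16xi1>, vector<16xf32>
//
// and returns `newValue` unchanged when no mask is provided.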
6631 
6632 //===----------------------------------------------------------------------===//
6633 // TableGen'd op method definitions
6634 //===----------------------------------------------------------------------===//
6635 
6636 #define GET_ATTRDEF_CLASSES
6637 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6638 
6639 #define GET_OP_CLASSES
6640 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition: VectorOps.cpp:68
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
Definition: VectorOps.cpp:1087
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1378
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
Definition: VectorOps.cpp:1639
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:4094
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
Definition: VectorOps.cpp:1947
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1815
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1650
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:2262
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:130
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:58
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:3125
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
Definition: VectorOps.cpp:296
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:868
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
Definition: VectorOps.cpp:2311
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:2621
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:3061
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:1079
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:3081
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
Definition: VectorOps.cpp:175
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:3104
static LogicalResult foldTransferFullMask(TransferOp op)
Definition: VectorOps.cpp:4328
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1370
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
Definition: VectorOps.cpp:2285
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:3986
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:879
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:3046
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1750
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:4015
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:4256
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:3462
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
Definition: VectorOps.cpp:1718
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:4273
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1867
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:216
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation & back()
Definition: Block.h:152
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:97
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:451
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:584
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:131
bool isF32() const
Definition: Types.cpp:59
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:123
bool isF16() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:329
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:336
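A minimal sketch of this builder in use (the helper name withI1Elements is hypothetical): derive a new vector type from an existing one, keeping its shape and scalability flags while swapping the element type.
VectorType withI1Elements(VectorType srcTy) {
  // Keep srcTy's shape and scalable dims, switch the element type to i1.
  return VectorType::Builder(srcTy).setElementType(
      IntegerType::get(srcTy.getContext(), 1));
}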
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:357
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
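A minimal usage sketch (the helper name combineWithAdd is hypothetical; v and acc are assumed to have matching scalar or vector types):
Value combineWithAdd(OpBuilder &builder, Location loc, Value v, Value acc) {
  // Emits arith.addi for integer element types, arith.addf for floats.
  return vector::makeArithReduction(builder, loc,
                                    vector::CombiningKind::ADD, v, acc);
}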
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:446
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:125
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
Definition: VectorOps.cpp:354
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:153
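As a worked example (types chosen for illustration only), the minor identity map maps the vector dimensions onto the innermost dimensions of the shaped type:
AffineMap exampleMinorIdentity(MLIRContext *ctx) {
  auto memrefTy = MemRefType::get({8, 16, 32}, Float32Type::get(ctx));
  auto vecTy = VectorType::get({16, 32}, Float32Type::get(ctx));
  // Expected result: (d0, d1, d2) -> (d1, d2).
  return vector::getTransferMinorIdentityMap(memrefTy, vecTy);
}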
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition: VectorOps.h:60
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2443
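A minimal sketch of querying broadcastability before building a broadcast (the helper name broadcastIfLegal is hypothetical):
Value broadcastIfLegal(OpBuilder &builder, Location loc, Value src,
                       VectorType dstVecTy) {
  if (vector::isBroadcastableTo(src.getType(), dstVecTy) !=
      vector::BroadcastableToResult::Success)
    return Value(); // Caller decides how to handle the failure.
  return builder.create<vector::BroadcastOp>(loc, dstVecTy, src);
}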
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4113
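An illustrative sketch (values assumed, not taken from a test): for a transfer with vector type vector<4x8xf32> and a transposing permutation map (d0, d1) -> (d1, d0), the inferred mask type follows the memory-side (unpermuted) ordering of the accessed dimensions, here vector<8x4xi1>.
VectorType exampleMaskType(MLIRContext *ctx) {
  auto vecTy = VectorType::get({4, 8}, Float32Type::get(ctx));
  AffineMap transpose = AffineMap::getPermutationMap({1, 0}, ctx);
  // Expected: vector<8x4xi1>.
  return vector::inferTransferOpMaskType(vecTy, transpose);
}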
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
Definition: VectorOps.cpp:219
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:283
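A minimal sketch of using the disjointness query on a write/read pair (the helper name provablyDisjoint is hypothetical):
bool provablyDisjoint(vector::TransferWriteOp write, vector::TransferReadOp read) {
  // With testDynamicValueUsingBounds=true, dynamic indices can also be
  // proven disjoint when value-bounds analysis can bound their difference.
  return vector::isDisjointTransferSet(
      cast<VectorTransferOpInterface>(write.getOperation()),
      cast<VectorTransferOpInterface>(read.getOperation()),
      /*testDynamicValueUsingBounds=*/true);
}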
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-writes the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:314
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:338
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:624
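A minimal usage sketch (the helper name reduceToScalar is hypothetical; vec is assumed to be a 1-D vector value):
Value reduceToScalar(OpBuilder &builder, Location loc, Value vec) {
  // For addf this builds a vector.reduction <add> over vec.
  return vector::getVectorReductionOp(arith::AtomicRMWKind::addf,
                                      builder, loc, vec);
}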
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:68
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:442
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet &, GreedyRewriteConfig, bool *changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
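As a worked example (values chosen for illustration only), a projected permutation selects source elements by dimension position:
SmallVector<int64_t> permuteShape(MLIRContext *ctx) {
  SmallVector<int64_t> shape = {4, 8, 16};
  // Projected permutation (d0, d1, d2) -> (d2, d0).
  AffineMap map = AffineMap::get(
      /*dimCount=*/3, /*symbolCount=*/0,
      {getAffineDimExpr(2, ctx), getAffineDimExpr(0, ctx)}, ctx);
  // Each result picks the source element at its dimension position,
  // so the result here is {16, 4}.
  return applyPermutationMap(map, ArrayRef<int64_t>(shape));
}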
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:930
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
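As a worked example (sizes chosen for illustration only), computeStrides and linearize compose to turn a multi-dimensional offset into a linear index:
int64_t exampleLinearIndex() {
  // Row-major strides for a 2x3x4 iteration space: {12, 4, 1}.
  SmallVector<int64_t> strides = computeStrides({2, 3, 4});
  // 1*12 + 2*4 + 3*1 = 23.
  return linearize({1, 2, 3}, strides);
}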
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1180
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1183
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
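A minimal skeleton of the OpRewritePattern structure (the pattern FoldTrivialBroadcast is hypothetical and only illustrates the shape of such patterns): match a vector.broadcast whose source already has the result type and forward the operand.
struct FoldTrivialBroadcast : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getSourceType() != op.getResultVectorType())
      return rewriter.notifyMatchFailure(op, "not a same-type broadcast");
    rewriter.replaceOp(op, op.getSource());
    return success();
  }
};
Such a pattern would typically be registered on a RewritePatternSet with patterns.add<FoldTrivialBroadcast>(ctx) and driven by a greedy rewrite.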
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:329
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:381
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:383
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.