1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/IRMapping.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Support/LLVM.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify a mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of a constant value with dense attributes
66 /// and a constant mask operation (since the client may be called at
67 /// various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayRef<int64_t> masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  if (maskIdx < dimSize)
97  allTrue = false;
98  if (maskIdx > 0)
99  allFalse = false;
100  }
101  if (allTrue)
102  return MaskFormat::AllTrue;
103  if (allFalse)
104  return MaskFormat::AllFalse;
105  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
106  // Finds all-false create_masks. An all-true create_mask requires all
107  // dims to be constants, so that'll be folded to a constant_mask, then
108  // detected in the constant_mask case.
109  auto maskOperands = m.getOperands();
110  for (Value operand : maskOperands) {
111  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
112  int64_t dimSize =
113  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
114  if (dimSize <= 0)
115  return MaskFormat::AllFalse;
116  }
117  }
118  return MaskFormat::Unknown;
119  }
120  return MaskFormat::Unknown;
121 }
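// Illustrative examples of the three classifications (IR sketches; the SSA
// names are hypothetical and not taken from this file):
//
//   %m0 = arith.constant dense<true> : vector<4xi1>   // MaskFormat::AllTrue
//   %m1 = vector.constant_mask [0] : vector<4xi1>     // MaskFormat::AllFalse
//   %m2 = vector.create_mask %d : vector<4xi1>        // MaskFormat::AllFalse if %d is
//                                                     // a constant <= 0, Unknown otherwise.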
122 
123 /// Default callback to build a region with a 'vector.yield' terminator with no
124 /// arguments.
125 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
126  builder.create<vector::YieldOp>(loc);
127 }
128 
129 // Helper for verifying combining kinds in contractions and reductions.
130 static bool isSupportedCombiningKind(CombiningKind combiningKind,
131  Type elementType) {
132  switch (combiningKind) {
133  case CombiningKind::ADD:
134  case CombiningKind::MUL:
135  return elementType.isIntOrIndexOrFloat();
136  case CombiningKind::MINUI:
137  case CombiningKind::MINSI:
138  case CombiningKind::MAXUI:
139  case CombiningKind::MAXSI:
140  case CombiningKind::AND:
141  case CombiningKind::OR:
142  case CombiningKind::XOR:
143  return elementType.isIntOrIndex();
144  case CombiningKind::MINNUMF:
145  case CombiningKind::MAXNUMF:
146  case CombiningKind::MINIMUMF:
147  case CombiningKind::MAXIMUMF:
148  return llvm::isa<FloatType>(elementType);
149  }
150  return false;
151 }
152 
153 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154  VectorType vectorType) {
155  int64_t elementVectorRank = 0;
156  VectorType elementVectorType =
157  llvm::dyn_cast<VectorType>(shapedType.getElementType());
158  if (elementVectorType)
159  elementVectorRank += elementVectorType.getRank();
160  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161  // TODO: replace once we have 0-d vectors.
162  if (shapedType.getRank() == 0 &&
163  vectorType.getShape() == ArrayRef<int64_t>{1})
164  return AffineMap::get(
165  /*numDims=*/0, /*numSymbols=*/0,
166  getAffineConstantExpr(0, shapedType.getContext()));
167  return AffineMap::getMinorIdentityMap(
168  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169  shapedType.getContext());
170 }
171 
172 /// Check if `write` is of a constant splat and the masked `read` is padded with
173 /// the same splat value -- meaning it could be the same value as the initial
174 /// constant splat.
175 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
176  vector::TransferReadOp read) {
177  auto readMask = read.getMask();
178  auto writeMask = write.getMask();
179  // Check if the masks are consistent. The splat value could be the same if the
180  // read is masked (and padded with the splat value), and the write is unmasked
181  // or has the same mask. Note this does not allow the case where the write is
182  // masked and the read is unmasked, as then the read could be of more elements
183  // than the write (which may not be the same value).
184  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185  if (!couldBeSameSplat)
186  return false;
187  // Check for constant splat (as the source of the write).
188  DenseElementsAttr splatAttr;
189  if (!matchPattern(write.getVector(),
190  m_Constant<DenseElementsAttr>(&splatAttr)) ||
191  !splatAttr.isSplat()) {
192  return false;
193  }
194  // The padding of the read and the constant splat value must be the same.
195  Attribute padAttr;
196  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
197  return false;
198  return padAttr == splatAttr.getSplatValue<Attribute>();
199 }
200 
201 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
202  vector::TransferReadOp read) {
203  return !defWrite.hasOutOfBoundsDim() &&
204  defWrite.getIndices() == read.getIndices() &&
205  defWrite.getVectorType() == read.getVectorType() &&
206  defWrite.getPermutationMap() == read.getPermutationMap() &&
207  ((!defWrite.getMask() && !read.getMask()) ||
208  isSplatWriteConsistentWithMaskedRead(defWrite, read));
209 }
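// For example (an illustrative sketch, with hypothetical SSA names), the read
// below satisfies checkSameValueRAW with respect to the preceding write, so a
// forwarding pattern may replace %r with %v:
//
//   vector.transfer_write %v, %buf[%i, %j] {in_bounds = [true, true]}
//       : vector<4x8xf32>, memref<?x?xf32>
//   %r = vector.transfer_read %buf[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<?x?xf32>, vector<4x8xf32>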
210 
211 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
212  vector::TransferWriteOp priorWrite) {
213  return priorWrite.getIndices() == write.getIndices() &&
214  priorWrite.getMask() == write.getMask() &&
215  priorWrite.getVectorType() == write.getVectorType() &&
216  priorWrite.getPermutationMap() == write.getPermutationMap();
217 }
218 
219 bool mlir::vector::isDisjointTransferIndices(
220  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221  bool testDynamicValueUsingBounds) {
222  // For simplicity only look at transfer of same type.
223  if (transferA.getVectorType() != transferB.getVectorType())
224  return false;
225  unsigned rankOffset = transferA.getLeadingShapedRank();
226  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227  Value indexA = transferA.getIndices()[i];
228  Value indexB = transferB.getIndices()[i];
229  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
230  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
231 
232  if (i < rankOffset) {
233  // For leading dimensions, if we can prove that the indices are different,
234  // we know we are accessing disjoint slices.
235  if (cstIndexA.has_value() && cstIndexB.has_value()) {
236  if (*cstIndexA != *cstIndexB)
237  return true;
238  continue;
239  }
240  if (testDynamicValueUsingBounds) {
241  // First try to see if we can fully compose and simplify the affine
242  // expression as a fast track.
243  FailureOr<uint64_t> delta =
244  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
245  if (succeeded(delta) && *delta != 0)
246  return true;
247 
248  FailureOr<bool> testEqual =
249  ValueBoundsConstraintSet::areEqual(indexA, indexB);
250  if (succeeded(testEqual) && !testEqual.value())
251  return true;
252  }
253  } else {
254  // For this dimension, we slice a part of the memref, so we need to make
255  // sure the intervals accessed don't overlap.
256  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257  if (cstIndexA.has_value() && cstIndexB.has_value()) {
258  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
259  if (distance >= vectorDim)
260  return true;
261  continue;
262  }
263  if (testDynamicValueUsingBounds) {
264  // First try to see if we can fully compose and simplify the affine
265  // expression as a fast track.
266  FailureOr<int64_t> delta =
267  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
268  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
269  return true;
270 
271  FailureOr<int64_t> computeDelta =
272  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
273  if (succeeded(computeDelta)) {
274  if (std::abs(computeDelta.value()) >= vectorDim)
275  return true;
276  }
277  }
278  }
279  }
280  return false;
281 }
282 
283 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
284  VectorTransferOpInterface transferB,
285  bool testDynamicValueUsingBounds) {
286  if (transferA.getSource() != transferB.getSource())
287  return false;
288  return isDisjointTransferIndices(transferA, transferB,
289  testDynamicValueUsingBounds);
290 }
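// For example (an illustrative sketch, assuming %c0 and %c4 are index
// constants 0 and 4), the two reads below are disjoint: along the vector
// dimension their indices differ by 4, which is not less than the vector size:
//
//   %a = vector.transfer_read %buf[%c0], %pad : memref<?xf32>, vector<4xf32>
//   %b = vector.transfer_read %buf[%c4], %pad : memref<?xf32>, vector<4xf32>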
291 
292 // Helper to iterate over n-D vector slice elements. Calculate the next
293 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
294 // Modifies the `position` in place. Returns a failure when `position` becomes
295 // the end position.
296 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
297  ArrayRef<int64_t> shape,
298  ArrayRef<int64_t> offsets) {
299  for (auto [posInDim, dimSize, offsetInDim] :
300  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
301  ++posInDim;
302  if (posInDim < dimSize + offsetInDim)
303  return success();
304 
305  // Carry the overflow to the next loop iteration.
306  posInDim = offsetInDim;
307  }
308 
309  return failure();
310 }
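// A minimal usage sketch (hypothetical caller code, not part of this file):
// visit every position of a slice of size `shape` anchored at `offsets`.
//
//   SmallVector<int64_t> position(offsets.begin(), offsets.end());
//   do {
//     visit(position); // hypothetical callback
//   } while (succeeded(incSlicePosition(position, shape, offsets)));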
311 
312 /// Returns the integer numbers in `values`. `values` are expected to be
313 /// constant operations.
314 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
315  SmallVector<int64_t> ints;
316  llvm::transform(values, std::back_inserter(ints), [](Value value) {
317  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
318  assert(constOp && "Unexpected non-constant index");
319  return constOp.value();
320  });
321  return ints;
322 }
323 
324 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
325 /// be constant operations.
326 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
327  SmallVector<int64_t> ints;
328  llvm::transform(
329  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
330  assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
331  return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
332  });
333  return ints;
334 }
335 
336 /// Convert `foldResults` into Values. Integer attributes are converted to
337 /// constant ops.
338 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
339  ArrayRef<OpFoldResult> foldResults) {
340  SmallVector<Value> values;
341  llvm::transform(foldResults, std::back_inserter(values),
342  [&](OpFoldResult foldResult) {
343  if (auto attr = foldResult.dyn_cast<Attribute>())
344  return builder
345  .create<arith::ConstantIndexOp>(
346  loc, cast<IntegerAttr>(attr).getInt())
347  .getResult();
348 
349  return cast<Value>(foldResult);
350  });
351  return values;
352 }
353 
354 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
355  if (value.getDefiningOp<vector::VectorScaleOp>())
356  return 1;
357  auto mul = value.getDefiningOp<arith::MulIOp>();
358  if (!mul)
359  return {};
360  auto lhs = mul.getLhs();
361  auto rhs = mul.getRhs();
362  if (lhs.getDefiningOp<vector::VectorScaleOp>())
363  return getConstantIntValue(rhs);
364  if (rhs.getDefiningOp<vector::VectorScaleOp>())
365  return getConstantIntValue(lhs);
366  return {};
367 }
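// For example (an illustrative sketch):
//
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
//
// getConstantVscaleMultiplier(%n) returns 4, and passing %vs directly
// returns 1.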
368 
369 //===----------------------------------------------------------------------===//
370 // CombiningKindAttr
371 //===----------------------------------------------------------------------===//
372 
373 namespace mlir {
374 namespace vector {
375 namespace detail {
376 struct BitmaskEnumStorage : public AttributeStorage {
377  using KeyTy = uint64_t;
378 
379  BitmaskEnumStorage(KeyTy val) : value(val) {}
380 
381  bool operator==(const KeyTy &key) const { return value == key; }
382 
383  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
384  const KeyTy &key) {
385  return new (allocator.allocate<BitmaskEnumStorage>())
386  BitmaskEnumStorage(key);
387  }
388 
389  KeyTy value = 0;
390 };
391 } // namespace detail
392 } // namespace vector
393 } // namespace mlir
394 
395 //===----------------------------------------------------------------------===//
396 // VectorDialect
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 /// This class defines the interface for handling inlining with vector dialect
401 /// operations.
402 struct VectorInlinerInterface : public DialectInlinerInterface {
403  using DialectInlinerInterface::DialectInlinerInterface;
404 
405  /// All vector dialect ops can be inlined.
406  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
407  return true;
408  }
409 };
410 } // namespace
411 
412 void VectorDialect::initialize() {
413  addAttributes<
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
416  >();
417 
418  addOperations<
419 #define GET_OP_LIST
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
421  >();
422 
423  addInterfaces<VectorInlinerInterface>();
424 
425  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
427  YieldOp>();
428  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
429  TransferWriteOp>();
430  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
433 }
434 
435 /// Materialize a single constant operation from a given attribute value with
436 /// the desired resultant type.
437 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
438  Attribute value, Type type,
439  Location loc) {
440  if (isa<ub::PoisonAttrInterface>(value))
441  return value.getDialect().materializeConstant(builder, value, type, loc);
442 
443  return arith::ConstantOp::materialize(builder, value, type, loc);
444 }
445 
446 IntegerType vector::getVectorSubscriptType(Builder &builder) {
447  return builder.getIntegerType(64);
448 }
449 
450 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
451  ArrayRef<int64_t> values) {
452  return builder.getI64ArrayAttr(values);
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // MultiDimReductionOp
457 //===----------------------------------------------------------------------===//
458 
459 void vector::MultiDimReductionOp::build(OpBuilder &builder,
460  OperationState &result, Value source,
461  Value acc, ArrayRef<bool> reductionMask,
462  CombiningKind kind) {
463  SmallVector<int64_t> reductionDims;
464  for (const auto &en : llvm::enumerate(reductionMask))
465  if (en.value())
466  reductionDims.push_back(en.index());
467  build(builder, result, kind, source, acc, reductionDims);
468 }
469 
470 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
471  // Single parallel dim, this is a noop.
472  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
473  return getSource();
474  return {};
475 }
476 
477 std::optional<SmallVector<int64_t, 4>>
478 MultiDimReductionOp::getShapeForUnroll() {
479  return llvm::to_vector<4>(getSourceVectorType().getShape());
480 }
481 
482 LogicalResult MultiDimReductionOp::verify() {
483  SmallVector<int64_t> targetShape;
484  SmallVector<bool> scalableDims;
485  Type inferredReturnType;
486  auto sourceScalableDims = getSourceVectorType().getScalableDims();
487  for (auto [dimIdx, dimSize] :
488  llvm::enumerate(getSourceVectorType().getShape()))
489  if (!llvm::any_of(getReductionDims(),
490  [dimIdx = dimIdx](int64_t reductionDimIdx) {
491  return reductionDimIdx == static_cast<int64_t>(dimIdx);
492  })) {
493  targetShape.push_back(dimSize);
494  scalableDims.push_back(sourceScalableDims[dimIdx]);
495  }
496  // TODO: update to also allow 0-d vectors when available.
497  if (targetShape.empty())
498  inferredReturnType = getSourceVectorType().getElementType();
499  else
500  inferredReturnType = VectorType::get(
501  targetShape, getSourceVectorType().getElementType(), scalableDims);
502  if (getType() != inferredReturnType)
503  return emitOpError() << "destination type " << getType()
504  << " is incompatible with source type "
505  << getSourceVectorType();
506 
507  return success();
508 }
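// For example (an illustrative sketch), reducing dimension 1 of a
// vector<2x3x4xf32> source must produce a vector<2x4xf32> destination:
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<2x3x4xf32> to vector<2x4xf32>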
509 
510 /// Returns the mask type expected by this operation.
511 Type MultiDimReductionOp::getExpectedMaskType() {
512  auto vecType = getSourceVectorType();
513  return VectorType::get(vecType.getShape(),
514  IntegerType::get(vecType.getContext(), /*width=*/1),
515  vecType.getScalableDims());
516 }
517 
518 namespace {
519 // Only unit dimensions that are being reduced are folded. If the dimension is
520 // unit, but not reduced, it is not folded, thereby keeping the output type the
521 // same. If not all dimensions which are reduced are of unit dimension, this
522 // transformation does nothing. This is just a generalization of
523  // ElideSingleElementReduction for ReductionOp.
524 struct ElideUnitDimsInMultiDimReduction
525  : public OpRewritePattern<MultiDimReductionOp> {
526  using OpRewritePattern::OpRewritePattern;
527 
528  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
529  PatternRewriter &rewriter) const override {
530  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
531  for (const auto &dim : enumerate(shape)) {
532  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
533  return failure();
534  }
535 
536  // Vector mask setup.
537  OpBuilder::InsertionGuard guard(rewriter);
538  Operation *rootOp;
539  Value mask;
540  if (reductionOp.isMasked()) {
541  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
542  rootOp = reductionOp.getMaskingOp();
543  mask = reductionOp.getMaskingOp().getMask();
544  } else {
545  rootOp = reductionOp;
546  }
547 
548  Location loc = reductionOp.getLoc();
549  Value acc = reductionOp.getAcc();
550  Value cast;
551  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
552  if (mask) {
553  VectorType newMaskType =
554  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
555  dstVecType.getScalableDims());
556  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
557  }
558  cast = rewriter.create<vector::ShapeCastOp>(
559  loc, reductionOp.getDestType(), reductionOp.getSource());
560  } else {
561  // This means we are reducing all the dimensions, and all reduction
562  // dimensions are of size 1. So a simple extraction would do.
563  SmallVector<int64_t> zeroIdx(shape.size(), 0);
564  if (mask)
565  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
566  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
567  zeroIdx);
568  }
569 
570  Value result =
571  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
572  cast, /*fastmath=*/nullptr, mask);
573  rewriter.replaceOp(rootOp, result);
574  return success();
575  }
576 };
577 } // namespace
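// An illustrative before/after for the pattern above (hypothetical SSA names):
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<4x1xf32> to vector<4xf32>
//
// is rewritten to a shape_cast combined elementwise with the accumulator:
//
//   %c = vector.shape_cast %src : vector<4x1xf32> to vector<4xf32>
//   %r = arith.addf %acc, %c : vector<4xf32>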
578 
579 void MultiDimReductionOp::getCanonicalizationPatterns(
580  RewritePatternSet &results, MLIRContext *context) {
581  results.add<ElideUnitDimsInMultiDimReduction>(context);
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // ReductionOp
586 //===----------------------------------------------------------------------===//
587 
588 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
589  CombiningKind kind, Value vector,
590  arith::FastMathFlags fastMathFlags) {
591  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
592 }
593 
594 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
595  CombiningKind kind, Value vector, Value acc,
596  arith::FastMathFlags fastMathFlags) {
597  build(builder, result,
598  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
599  acc, fastMathFlags);
600 }
601 
602 LogicalResult ReductionOp::verify() {
603  // Verify for 0-D and 1-D vector.
604  int64_t rank = getSourceVectorType().getRank();
605  if (rank > 1)
606  return emitOpError("unsupported reduction rank: ") << rank;
607 
608  // Verify supported reduction kind.
609  Type eltType = getDest().getType();
610  if (!isSupportedCombiningKind(getKind(), eltType))
611  return emitOpError("unsupported reduction type '")
612  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
613  << "'";
614 
615  return success();
616 }
617 
618 // MaskableOpInterface methods.
619 
620 /// Returns the mask type expected by this operation.
621 Type ReductionOp::getExpectedMaskType() {
622  auto vecType = getSourceVectorType();
623  return VectorType::get(vecType.getShape(),
624  IntegerType::get(vecType.getContext(), /*width=*/1),
625  vecType.getScalableDims());
626 }
627 
628 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
629  OpBuilder &builder, Location loc,
630  Value vector) {
631  switch (op) {
632  case arith::AtomicRMWKind::addf:
633  case arith::AtomicRMWKind::addi:
634  return builder.create<vector::ReductionOp>(vector.getLoc(),
635  CombiningKind::ADD, vector);
636  case arith::AtomicRMWKind::mulf:
637  case arith::AtomicRMWKind::muli:
638  return builder.create<vector::ReductionOp>(vector.getLoc(),
639  CombiningKind::MUL, vector);
640  case arith::AtomicRMWKind::minimumf:
641  return builder.create<vector::ReductionOp>(vector.getLoc(),
642  CombiningKind::MINIMUMF, vector);
643  case arith::AtomicRMWKind::mins:
644  return builder.create<vector::ReductionOp>(vector.getLoc(),
645  CombiningKind::MINSI, vector);
646  case arith::AtomicRMWKind::minu:
647  return builder.create<vector::ReductionOp>(vector.getLoc(),
648  CombiningKind::MINUI, vector);
649  case arith::AtomicRMWKind::maximumf:
650  return builder.create<vector::ReductionOp>(vector.getLoc(),
651  CombiningKind::MAXIMUMF, vector);
652  case arith::AtomicRMWKind::maxs:
653  return builder.create<vector::ReductionOp>(vector.getLoc(),
654  CombiningKind::MAXSI, vector);
655  case arith::AtomicRMWKind::maxu:
656  return builder.create<vector::ReductionOp>(vector.getLoc(),
657  CombiningKind::MAXUI, vector);
658  case arith::AtomicRMWKind::andi:
659  return builder.create<vector::ReductionOp>(vector.getLoc(),
660  CombiningKind::AND, vector);
661  case arith::AtomicRMWKind::ori:
662  return builder.create<vector::ReductionOp>(vector.getLoc(),
663  CombiningKind::OR, vector);
664  // TODO: Add remaining reduction operations.
665  default:
666  (void)emitOptionalError(loc, "Reduction operation type not supported");
667  break;
668  }
669  return nullptr;
670 }
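// A minimal usage sketch (hypothetical caller code): mapping a parallel-loop
// reduction kind to the corresponding vector reduction.
//
//   Value red = getVectorReductionOp(arith::AtomicRMWKind::addf, builder, loc, vec);
//   // Builds: %red = vector.reduction <add>, %vec : vector<Nxf32> into f32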
671 
672 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
673  return llvm::to_vector<4>(getSourceVectorType().getShape());
674 }
675 
676 namespace {
677 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
678  using OpRewritePattern::OpRewritePattern;
679 
680  LogicalResult matchAndRewrite(ReductionOp reductionOp,
681  PatternRewriter &rewriter) const override {
682  // Vector mask setup.
683  OpBuilder::InsertionGuard guard(rewriter);
684  auto maskableOp =
685  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
686  Operation *rootOp;
687  Value mask;
688  if (maskableOp.isMasked()) {
689  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
690  rootOp = maskableOp.getMaskingOp();
691  mask = maskableOp.getMaskingOp().getMask();
692  } else {
693  rootOp = reductionOp;
694  }
695 
696  auto vectorType = reductionOp.getSourceVectorType();
697  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
698  return failure();
699 
700  Location loc = reductionOp.getLoc();
701  Value result;
702  if (vectorType.getRank() == 0) {
703  if (mask)
704  mask = rewriter.create<ExtractElementOp>(loc, mask);
705  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
706  } else {
707  if (mask)
708  mask = rewriter.create<ExtractOp>(loc, mask, 0);
709  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
710  }
711 
712  if (Value acc = reductionOp.getAcc())
713  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
714  result, acc,
715  reductionOp.getFastmathAttr(), mask);
716 
717  rewriter.replaceOp(rootOp, result);
718  return success();
719  }
720 };
721 } // namespace
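// An illustrative before/after for the pattern above:
//
//   %r = vector.reduction <mul>, %v : vector<1xf32> into f32
//
// becomes a plain extract (combined with the accumulator, if one is present):
//
//   %r = vector.extract %v[0] : f32 from vector<1xf32>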
722 
723 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
724  MLIRContext *context) {
725  results.add<ElideSingleElementReduction>(context);
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // ContractionOp
730 //===----------------------------------------------------------------------===//
731 
732 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
733  Value lhs, Value rhs, Value acc,
734  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
735  ArrayRef<IteratorType> iteratorTypes) {
736  result.addOperands({lhs, rhs, acc});
737  result.addTypes(acc.getType());
738  result.addAttribute(
739  getIndexingMapsAttrName(result.name),
740  builder.getAffineMapArrayAttr(
741  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
742  result.addAttribute(
743  getIteratorTypesAttrName(result.name),
744  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
745  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
746  return IteratorTypeAttr::get(builder.getContext(), t);
747  }))));
748 }
749 
750 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
751  Value lhs, Value rhs, Value acc,
752  ArrayAttr indexingMaps,
753  ArrayAttr iteratorTypes) {
754  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
755  ContractionOp::getDefaultKind());
756 }
757 
758 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
759  Value lhs, Value rhs, Value acc,
760  ArrayAttr indexingMaps,
761  ArrayAttr iteratorTypes, CombiningKind kind) {
762  result.addOperands({lhs, rhs, acc});
763  result.addTypes(acc.getType());
764  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
765  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
766  result.addAttribute(getKindAttrName(result.name),
767  CombiningKindAttr::get(builder.getContext(), kind));
768 }
769 
770 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
771  OpAsmParser::UnresolvedOperand lhsInfo;
772  OpAsmParser::UnresolvedOperand rhsInfo;
773  OpAsmParser::UnresolvedOperand accInfo;
774  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
775  SmallVector<Type, 2> types;
776  Type resultType;
777  auto loc = parser.getCurrentLocation();
778  DictionaryAttr dictAttr;
779  // TODO: Unify linalg op attribute parsing.
780  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
781  parser.parseComma() || parser.parseOperand(rhsInfo) ||
782  parser.parseComma() || parser.parseOperand(accInfo) ||
783  parser.parseTrailingOperandList(masksInfo) ||
784  parser.parseOptionalAttrDict(result.attributes) ||
785  parser.parseColonTypeList(types) ||
786  parser.parseKeywordType("into", resultType) ||
787  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
788  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
789  parser.resolveOperand(accInfo, resultType, result.operands) ||
790  parser.addTypeToList(resultType, result.types))
791  return failure();
792  result.attributes.append(dictAttr.getValue().begin(),
793  dictAttr.getValue().end());
794 
795  // Convert the array of strings into an array of IteratorType enums. This is
796  // needed because tests still use the old format where the 'iterator_types'
797  // attribute is represented as an array of strings.
798  // TODO: Remove this conversion once tests are fixed.
799  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
800  result.attributes.get(getIteratorTypesAttrName(result.name)));
801 
802  SmallVector<Attribute> iteratorTypeAttrs;
803 
804  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
805  auto maybeIteratorType = symbolizeIteratorType(s);
806  if (!maybeIteratorType.has_value())
807  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
808 
809  iteratorTypeAttrs.push_back(
810  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
811  }
812  result.attributes.set(getIteratorTypesAttrName(result.name),
813  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
814 
815  if (!result.attributes.get(getKindAttrName(result.name))) {
816  result.addAttribute(
817  getKindAttrName(result.name),
818  CombiningKindAttr::get(parser.getContext(),
819  ContractionOp::getDefaultKind()));
820  }
821  if (masksInfo.empty())
822  return success();
823  if (masksInfo.size() != 2)
824  return parser.emitError(parser.getNameLoc(),
825  "expected zero or exactly 2 vector mask operands");
826  auto lhsType = llvm::cast<VectorType>(types[0]);
827  auto rhsType = llvm::cast<VectorType>(types[1]);
828  auto maskElementType = parser.getBuilder().getI1Type();
829  std::array<VectorType, 2> maskTypes = {
830  VectorType::Builder(lhsType).setElementType(maskElementType),
831  VectorType::Builder(rhsType).setElementType(maskElementType)};
832  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
833  return failure();
834  return success();
835 }
836 
837 void ContractionOp::print(OpAsmPrinter &p) {
838  // TODO: Unify printing code with linalg ops.
839  auto attrNames = getTraitAttrNames();
840  llvm::StringSet<> traitAttrsSet;
841  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
842  SmallVector<NamedAttribute> attrs;
843  for (auto attr : (*this)->getAttrs()) {
844  if (attr.getName() == getIteratorTypesAttrName()) {
845  auto iteratorTypes =
846  llvm::cast<ArrayAttr>(attr.getValue())
847  .getAsValueRange<IteratorTypeAttr, IteratorType>();
848  // Convert IteratorType enums into the string representation. This is
849  // needed, because tests still use the old format when 'iterator_types'
850  // attribute is represented as an array of strings.
851  // TODO: Remove this conversion once tests are fixed.
852  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
853  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
854  return StringAttr::get(getContext(), stringifyIteratorType(t));
855  }));
856 
857  attrs.emplace_back(getIteratorTypesAttrName(),
858  ArrayAttr::get(getContext(), iteratorTypeNames));
859  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
860  attrs.push_back(attr);
861  }
862 
863  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
864  p << " " << dictAttr << " " << getLhs() << ", ";
865  p << getRhs() << ", " << getAcc();
866 
867  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
868  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
869  << getResultType();
870 }
871 
872 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
873  const std::vector<std::pair<int64_t, int64_t>> &map) {
874  for (auto &dimPair : map) {
875  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
876  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
877  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
878  return false;
879  }
880  return true;
881 }
882 
883 static LogicalResult verifyOutputShape(
884  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
885  Type resType,
886  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
887  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
888  DenseSet<int64_t> lhsContractingDimSet;
889  DenseSet<int64_t> rhsContractingDimSet;
890  for (auto &dimPair : contractingDimMap) {
891  lhsContractingDimSet.insert(dimPair.first);
892  rhsContractingDimSet.insert(dimPair.second);
893  }
894  DenseSet<int64_t> rhsBatchDimSet;
895  for (auto &dimPair : batchDimMap)
896  rhsBatchDimSet.insert(dimPair.second);
897 
898  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
899  SmallVector<int64_t, 4> expectedResultDims;
900  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
901  if (lhsContractingDimSet.count(i) > 0)
902  continue;
903  expectedResultDims.push_back(lhsType.getDimSize(i));
904  }
905 
906  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
907  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
908  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
909  continue;
910  expectedResultDims.push_back(rhsType.getDimSize(i));
911  }
912 
913  // Verify 'expectedResultDims'.
914  if (expectedResultDims.empty()) {
915  // No batch or free dimension implies a scalar result.
916  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
917  return op.emitOpError("invalid accumulator/result vector shape");
918  } else {
919  // At least one batch or free dimension implies a vector result.
920  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
921  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
922  if (!resVectorType || !accVectorType)
923  return op.emitOpError("invalid accumulator/result vector shape");
924 
925  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
926  // types fully define the result vector type. This assumes the affine maps
927  // are well-formed, which must have been verified already.
928  MLIRContext *ctx = op.getContext();
929  AffineMap lhsMap = op.getIndexingMapsArray()[0];
930  AffineMap rhsMap = op.getIndexingMapsArray()[1];
931  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
932  return op.emitOpError(
933  "expected all dimensions to be either a LHS or a RHS dimension");
934  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
935  for (auto pair :
936  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
937  VectorType v = pair.first;
938  auto map = pair.second;
939  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
940  unsigned pos = map.getDimPosition(idx);
941  if (!extents[pos])
942  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
943  }
944  }
945  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
946  return op.emitOpError("expected all dimensions to get an extent as "
947  "either a LHS or a RHS dimension");
948 
949  AffineMap resMap = op.getIndexingMapsArray()[2];
950  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
951  /*symbolCount=*/0, extents, ctx);
952  // Compose the resMap with the extentsMap, which is a constant map.
953  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
954  assert(llvm::all_of(expectedMap.getResults(),
955  llvm::IsaPred<AffineConstantExpr>) &&
956  "expected constant extent along all dimensions.");
957  // Extract the expected shape and build the type.
958  auto expectedShape = llvm::to_vector<4>(
959  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
960  return cast<AffineConstantExpr>(e).getValue();
961  }));
962  auto expected =
963  VectorType::get(expectedShape, resVectorType.getElementType(),
964  resVectorType.getScalableDims());
965  if (resVectorType != expected || accVectorType != expected)
966  return op.emitOpError(
967  "invalid accumulator/result vector shape, expected: ")
968  << expected;
969  }
970  return success();
971 }
972 
973 LogicalResult ContractionOp::verify() {
974  VectorType lhsType = getLhsType();
975  VectorType rhsType = getRhsType();
976  Type accType = getAccType();
977  Type resType = getResultType();
978 
979  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
980  if (!lhsType.getElementType().isSignlessInteger())
981  return emitOpError("only supports signless integer types");
982  }
983 
984  // Verify that an indexing map was specified for each vector operand.
985  if (getIndexingMapsArray().size() != 3)
986  return emitOpError("expected an indexing map for each vector operand");
987 
988  // Verify that each index map has 'numIterators' inputs, no symbols, and
989  // that the number of map outputs equals the rank of its associated
990  // vector operand.
991  unsigned numIterators = getIteratorTypes().getValue().size();
992  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
993  auto index = it.index();
994  auto map = it.value();
995  if (map.getNumSymbols() != 0)
996  return emitOpError("expected indexing map ")
997  << index << " to have no symbols";
998  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
999  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1000  // Verify that the map has the right number of inputs, outputs, and indices.
1001  // This also correctly accounts for (..) -> () for rank-0 results.
1002  if (map.getNumDims() != numIterators)
1003  return emitOpError("expected indexing map ")
1004  << index << " to have " << numIterators << " number of inputs";
1005  if (map.getNumResults() != rank)
1006  return emitOpError("expected indexing map ")
1007  << index << " to have " << rank << " number of outputs";
1008  if (!map.isProjectedPermutation())
1009  return emitOpError("expected indexing map ")
1010  << index << " to be a projected permutation of its inputs";
1011  }
1012 
1013  auto contractingDimMap = getContractingDimMap();
1014  auto batchDimMap = getBatchDimMap();
1015 
1016  // Verify at least one contracting dimension pair was specified.
1017  if (contractingDimMap.empty())
1018  return emitOpError("expected at least one contracting dimension pair");
1019 
1020  // Verify contracting dimension map was properly constructed.
1021  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1022  return emitOpError("invalid contracting dimension map");
1023 
1024  // Verify batch dimension map was properly constructed.
1025  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1026  return emitOpError("invalid batch dimension map");
1027 
1028  // Verify 'accType' and 'resType' shape.
1029  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1030  contractingDimMap, batchDimMap)))
1031  return failure();
1032 
1033  // Verify supported combining kind.
1034  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1035  auto elementType = vectorType ? vectorType.getElementType() : resType;
1036  if (!isSupportedCombiningKind(getKind(), elementType))
1037  return emitOpError("unsupported contraction type");
1038 
1039  return success();
1040 }
1041 
1042 // MaskableOpInterface methods.
1043 
1044 /// Returns the mask type expected by this operation. Mostly used for
1045 /// verification purposes. It requires the operation to be vectorized.
1046 Type ContractionOp::getExpectedMaskType() {
1047  auto indexingMaps = this->getIndexingMapsArray();
1048  AffineMap lhsIdxMap = indexingMaps[0];
1049  AffineMap rhsIdxMap = indexingMaps[1];
1050  VectorType lhsType = this->getLhsType();
1051  VectorType rhsType = this->getRhsType();
1052 
1053  unsigned numVecDims = lhsIdxMap.getNumDims();
1054  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1055  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1056 
1057  // Using the information in the indexing maps, extract the size of each
1058  // dimension in the vector.contract operation from the two input operands.
1059  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1060  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1061  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1062  lhsType.getScalableDims()[dimIdx];
1063  }
1064  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1065  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1066  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1067  rhsType.getScalableDims()[dimIdx];
1068  }
1069 
1070  assert(!ShapedType::isDynamicShape(maskShape) &&
1071  "Mask shape couldn't be computed");
1072 
1073  return VectorType::get(maskShape,
1074  IntegerType::get(lhsType.getContext(), /*width=*/1),
1075  maskShapeScalableDims);
1076 }
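// For example (an illustrative sketch), for a matmul-shaped contraction with
// iteration order (m, n, k):
//
//   lhs : vector<4x8xf32>   indexed by (m, k)
//   rhs : vector<8x16xf32>  indexed by (k, n)
//
// the iteration space is (m, n, k) = (4, 16, 8), so the expected mask type is
// vector<4x16x8xi1>.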
1077 
1078 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1079  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1080  getIteratorTypesAttrName(), getKindAttrName()};
1081 }
1082 
1083 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1084  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1085  if (targetExpr == map.getResult(i))
1086  return i;
1087  return -1;
1088 }
1089 
1090 static std::vector<std::pair<int64_t, int64_t>>
1091 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1092  IteratorType targetIteratorType, MLIRContext *context) {
1093  std::vector<std::pair<int64_t, int64_t>> dimMap;
1094  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1095  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1096  if (iteratorType != targetIteratorType)
1097  continue;
1098  // Search lhs/rhs map results for 'targetExpr'.
1099  auto targetExpr = getAffineDimExpr(it.index(), context);
1100  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1101  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1102  if (lhsDim >= 0 && rhsDim >= 0)
1103  dimMap.emplace_back(lhsDim, rhsDim);
1104  }
1105  return dimMap;
1106 }
1107 
1108 void ContractionOp::getIterationBounds(
1109  SmallVectorImpl<int64_t> &iterationBounds) {
1110  auto lhsShape = getLhsType().getShape();
1111  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1112  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1113  SmallVector<int64_t, 2> iterationShape;
1114  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1115  // Search lhs/rhs map results for 'targetExpr'.
1116  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1117  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1118  if (iteratorType == IteratorType::reduction) {
1119  // Get reduction dim size from lhs shape (same size in rhsShape).
1120  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1121  assert(lhsDimIndex >= 0);
1122  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1123  continue;
1124  }
1125  // Get parallel dimension size from result shape.
1126  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1127  assert(resDimIndex >= 0);
1128  assert(resVectorType != nullptr);
1129  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1130  }
1131 }
1132 
1133 void ContractionOp::getIterationIndexMap(
1134  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1135  unsigned numMaps = getIndexingMapsArray().size();
1136  iterationIndexMap.resize(numMaps);
1137  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1138  auto index = it.index();
1139  auto map = it.value();
1140  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1141  auto dim = cast<AffineDimExpr>(map.getResult(i));
1142  iterationIndexMap[index][dim.getPosition()] = i;
1143  }
1144  }
1145 }
1146 
1147 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1148  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1149  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1150  getContext());
1151 }
1152 
1153 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1154  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1155  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1156  getContext());
1157 }
1158 
1159 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1160  SmallVector<int64_t, 4> shape;
1161  getIterationBounds(shape);
1162  return shape;
1163 }
1164 
1165 /// Return a fused vector::ContractionOp which represents a patterns such as:
1166 ///
1167 /// ```mlir
1168 /// %c0 = vector.constant 0: ...
1169 /// %c = vector.contract %a, %b, %c0: ...
1170 /// %e = add %c, %d: ...
1171 /// ```
1172 ///
1173 /// by:
1174 ///
1175 /// ```mlir
1176 /// %e = vector.contract %a, %b, %d: ...
1177 /// ```
1178 ///
1179 /// Return null if the canonicalization does not apply.
1180 // TODO: This should be a folding of Add into Contract in core but while they
1181 // live in different dialects, it is not possible without unnatural
1182 // dependencies.
1183 template <typename AddOpType>
1184 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1185  using OpRewritePattern<AddOpType>::OpRewritePattern;
1186 
1187  LogicalResult matchAndRewrite(AddOpType addOp,
1188  PatternRewriter &rewriter) const override {
1189  auto canonicalize = [&](Value maybeContraction,
1190  Value otherOperand) -> vector::ContractionOp {
1191  vector::ContractionOp contractionOp =
1192  dyn_cast_or_null<vector::ContractionOp>(
1193  maybeContraction.getDefiningOp());
1194  if (!contractionOp)
1195  return vector::ContractionOp();
1196  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1197  contractionOp.getAcc().getDefiningOp())) {
1198  if (maybeZero.getValue() ==
1199  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1200  IRMapping bvm;
1201  bvm.map(contractionOp.getAcc(), otherOperand);
1202  auto newContraction =
1203  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1204  rewriter.replaceOp(addOp, newContraction.getResult());
1205  return newContraction;
1206  }
1207  }
1208  return vector::ContractionOp();
1209  };
1210 
1211  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1212  vector::ContractionOp contract = canonicalize(a, b);
1213  contract = contract ? contract : canonicalize(b, a);
1214  return contract ? success() : failure();
1215  }
1216 };
1217 
1218 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1219  MLIRContext *context) {
1220  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1221  CanonicalizeContractAdd<arith::AddFOp>>(context);
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // ExtractElementOp
1226 //===----------------------------------------------------------------------===//
1227 
1228 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1229  SetIntRangeFn setResultRanges) {
1230  setResultRanges(getResult(), argRanges.front());
1231 }
1232 
1233 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1234  Value source) {
1235  result.addOperands({source});
1236  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1237 }
1238 
1239 LogicalResult vector::ExtractElementOp::verify() {
1240  VectorType vectorType = getSourceVectorType();
1241  if (vectorType.getRank() == 0) {
1242  if (getPosition())
1243  return emitOpError("expected position to be empty with 0-D vector");
1244  return success();
1245  }
1246  if (vectorType.getRank() != 1)
1247  return emitOpError("unexpected >1 vector rank");
1248  if (!getPosition())
1249  return emitOpError("expected position for 1-D vector");
1250  return success();
1251 }
1252 
1253 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1254  // Skip the 0-D vector case for now.
1255  if (!adaptor.getPosition())
1256  return {};
1257 
1258  // Fold extractelement (splat X) -> X.
1259  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1260  return splat.getInput();
1261 
1262  // Fold extractelement(broadcast(X)) -> X.
1263  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1264  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1265  return broadcast.getSource();
1266 
1267  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1268  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1269  if (!pos || !src)
1270  return {};
1271 
1272  auto srcElements = src.getValues<Attribute>();
1273 
1274  uint64_t posIdx = pos.getInt();
1275  if (posIdx >= srcElements.size())
1276  return {};
1277 
1278  return srcElements[posIdx];
1279 }
1280 
1281 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1282 // `poisonValue`.
1283 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1284  int64_t maxIndex) {
1285  return index == poisonValue || (index >= 0 && index < maxIndex);
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // ExtractOp
1290 //===----------------------------------------------------------------------===//
1291 
1292 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1293  SetIntRangeFn setResultRanges) {
1294  setResultRanges(getResult(), argRanges.front());
1295 }
1296 
1297 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1298  Value source, int64_t position) {
1299  build(builder, result, source, ArrayRef<int64_t>{position});
1300 }
1301 
1302 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1303  Value source, OpFoldResult position) {
1304  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1305 }
1306 
1307 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1308  Value source, ArrayRef<int64_t> position) {
1309  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1310  builder.getDenseI64ArrayAttr(position));
1311 }
1312 
1313 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1314  Value source, ArrayRef<OpFoldResult> position) {
1315  SmallVector<int64_t> staticPos;
1316  SmallVector<Value> dynamicPos;
1317  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1318  build(builder, result, source, dynamicPos,
1319  builder.getDenseI64ArrayAttr(staticPos));
1320 }
1321 
1322 LogicalResult
1323 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1324  ExtractOp::Adaptor adaptor,
1325  SmallVectorImpl<Type> &inferredReturnTypes) {
1326  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1327  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1328  vectorType.getRank()) {
1329  inferredReturnTypes.push_back(vectorType.getElementType());
1330  } else {
1331  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1332  vectorType.getRank());
1333  inferredReturnTypes.push_back(VectorType::get(
1334  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1335  vectorType.getScalableDims().drop_front(n)));
1336  }
1337  return success();
1338 }
1339 
1340 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1341  // Allow extracting 1-element vectors instead of scalars.
1342  auto isCompatible = [](TypeRange l, TypeRange r) {
1343  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1344  return vectorType && vectorType.getShape().equals({1}) &&
1345  vectorType.getElementType() == r.front();
1346  };
1347  if (l.size() == 1 && r.size() == 1 &&
1348  (isCompatible(l, r) || isCompatible(r, l)))
1349  return true;
1350  return l == r;
1351 }
1352 
1353 LogicalResult vector::ExtractOp::verify() {
1354  // Note: This check must come before getMixedPosition() to prevent a crash.
1355  auto dynamicMarkersCount =
1356  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1357  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1358  return emitOpError(
1359  "mismatch between dynamic and static positions (kDynamic marker but no "
1360  "corresponding dynamic position) -- this can only happen due to an "
1361  "incorrect fold/rewrite");
1362  auto position = getMixedPosition();
1363  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1364  return emitOpError(
1365  "expected position attribute of rank no greater than vector rank");
1366  for (auto [idx, pos] : llvm::enumerate(position)) {
1367  if (auto attr = dyn_cast<Attribute>(pos)) {
1368  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1369  if (!isValidPositiveIndexOrPoison(
1370  constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1371  return emitOpError("expected position attribute #")
1372  << (idx + 1)
1373  << " to be a non-negative integer smaller than the "
1374  "corresponding vector dimension or poison (-1)";
1375  }
1376  }
1377  }
1378  return success();
1379 }
1380 
1381 template <typename IntType>
1382 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1383  return llvm::to_vector<4>(llvm::map_range(
1384  arrayAttr.getAsRange<IntegerAttr>(),
1385  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1386 }
1387 
1388 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1389 /// positions.
1390 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1391  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1392  return failure();
1393 
1394  // TODO: Canonicalization for dynamic position not implemented yet.
1395  if (extractOp.hasDynamicPosition())
1396  return failure();
1397 
1398  SmallVector<int64_t> globalPosition;
1399  ExtractOp currentOp = extractOp;
1400  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1401  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1402  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1403  currentOp = nextOp;
1404  // TODO: Canonicalization for dynamic position not implemented yet.
1405  if (currentOp.hasDynamicPosition())
1406  return failure();
1407  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1408  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1409  }
1410  extractOp.setOperand(0, currentOp.getVector());
1411  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1412  OpBuilder b(extractOp.getContext());
1413  std::reverse(globalPosition.begin(), globalPosition.end());
1414  extractOp.setStaticPosition(globalPosition);
1415  return success();
1416 }
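// An illustrative before/after for this folding (static positions only):
//
//   %a = vector.extract %src[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[2] : vector<4xf32> from vector<3x4xf32>
//
// is folded in place to:
//
//   %b = vector.extract %src[1, 2] : vector<4xf32> from vector<2x3x4xf32>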
1417 
1418 namespace {
1419 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1420 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1421 /// Compose TransposeOp permutations as we walk back.
1422 /// This helper class keeps an updated extraction position `extractPosition`
1423 /// with extra trailing sentinels.
1424 /// The sentinels encode the internal transposition status of the result vector.
1425 /// As we iterate, extractPosition is permuted and updated.
1426 class ExtractFromInsertTransposeChainState {
1427 public:
1428  ExtractFromInsertTransposeChainState(ExtractOp e);
1429 
1430  /// Iterate over producing insert and transpose ops until we find a fold.
1431  Value fold();
1432 
1433 private:
1434  /// Return true if the vector at position `a` is contained within the vector
1435  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1436  /// is a prefix of `b`.
1437  template <typename ContainerA, typename ContainerB>
1438  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1439  return a.size() <= b.size() &&
1440  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1441  }
1442 
1443  /// Return true if the vector at position `a` intersects the vector at
1444  /// position `b`. Under insert/extract semantics, this is the same as equality
1445  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1446  /// Comparison is on the common prefix (i.e. zip).
1447  template <typename ContainerA, typename ContainerB>
1448  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1449  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1450  if (elemA < 0 || elemB < 0)
1451  continue;
1452  if (elemA != elemB)
1453  return false;
1454  }
1455  return true;
1456  }
1457 
1458  /// Folding is only possible in the absence of an internal permutation in the
1459  /// result vector.
1460  bool canFold() {
1461  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1462  }
1463 
1464  // Helper to get the next defining op of interest.
1465  void updateStateForNextIteration(Value v) {
1466  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1467  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1468  };
1469 
1470  // Case 1. If we hit a transpose, just compose the map and iterate.
1471  // Invariant: insert + transpose do not change rank, we can always compose.
1472  LogicalResult handleTransposeOp();
1473 
1474  // Case 2: the insert position matches extractPosition exactly, early return.
1475  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1476 
1477  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1478  /// portion of the source of the insert.
1479  /// Example:
1480  /// ```
1481  /// %ins = vector.insert %source, %dest[1] : vector<3x4x5> into vector<2x3x4x5>
1482  /// // extractPosition == [1, 2, 3]
1483  /// %ext = vector.extract %ins[1, 2, 3] : vector<5> from vector<2x3x4x5>
1484  /// // can fold to:
1485  /// %ext = vector.extract %source[2, 3] : vector<5> from vector<3x4x5>
1486  /// ```
1487  /// To traverse through %source, we need to set the leading dims to 0 and
1488  /// drop the extra leading dims.
1489  /// This method updates the internal state.
1490  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1491 
1492  /// Try to fold in place to extract(source, extractPosition) and return the
1493  /// folded result. Return null if folding is not possible (e.g. due to an
1494  /// internal transposition in the result).
1495  Value tryToFoldExtractOpInPlace(Value source);
1496 
1497  ExtractOp extractOp;
1498  int64_t vectorRank;
1499  int64_t extractedRank;
1500 
1501  InsertOp nextInsertOp;
1502  TransposeOp nextTransposeOp;
1503 
1504  /// Sentinel values that encode the internal permutation status of the result.
1505  /// They are set to (-1, ... , -k) at the beginning and appended to
1506  /// `extractPosition`.
1507  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1508  /// ensure that there is no internal transposition.
1509  /// Internal transposition cannot be accounted for with a folding pattern.
1510  // TODO: We could relax the internal transposition with an extra transposition
1511  // operation in a future canonicalizer.
1512  SmallVector<int64_t> sentinels;
1513  SmallVector<int64_t> extractPosition;
1514 };
1515 } // namespace
1516 
1517 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1518  ExtractOp e)
1519  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1520  extractedRank(extractOp.getNumIndices()) {
1521  assert(vectorRank >= extractedRank && "Extracted position overflow");
1522  sentinels.reserve(vectorRank - extractedRank);
1523  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1524  sentinels.push_back(-(i + 1));
1525  extractPosition.assign(extractOp.getStaticPosition().begin(),
1526  extractOp.getStaticPosition().end());
1527  llvm::append_range(extractPosition, sentinels);
1528 }
1529 
1530 // Case 1. If we hit a transpose, just compose the map and iterate.
1531 // Invariant: insert + transpose do not change rank, we can always compose.
1532 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1533  // TODO: Canonicalization for dynamic position not implemented yet.
1534  if (extractOp.hasDynamicPosition())
1535  return failure();
1536 
1537  if (!nextTransposeOp)
1538  return failure();
1539  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1540  nextTransposeOp.getPermutation(), extractOp.getContext()));
1541  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1542  return success();
1543 }
1544 
1545 // Case 2: the insert position matches extractPosition exactly, early return.
1546 LogicalResult
1547 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1548  Value &res) {
1549  // TODO: Canonicalization for dynamic position not implemented yet.
1550  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1551  return failure();
1552 
1553  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1554  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1555  return failure();
1556  // Case 2.a. early-exit fold.
1557  res = nextInsertOp.getSource();
1558  // Case 2.b. if internal transposition is present, canFold will be false.
1559  return success(canFold());
1560 }
1561 
1562 /// Case 3: if the inserted position is a prefix of extractPosition,
1563 /// extract a portion of the source of the insertion.
1564 /// This method updates the internal state.
1565 LogicalResult
1566 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1567  // TODO: Canonicalization for dynamic position not implemented yet.
1568  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1569  return failure();
1570 
1571  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1572  if (!isContainedWithin(insertedPos, extractPosition))
1573  return failure();
1574  // Set leading dims to zero.
1575  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1576  // Drop extra leading dims.
1577  extractPosition.erase(extractPosition.begin(),
1578  extractPosition.begin() + insertedPos.size());
1579  extractedRank = extractPosition.size() - sentinels.size();
1580  // Case 3.a. early-exit fold (break and delegate to post-while path).
1581  res = nextInsertOp.getSource();
1582  // Case 3.b. if internal transposition is present, canFold will be false.
1583  return success();
1584 }
1585 
1586 /// Try to fold in place to extract(source, extractPosition) and return the
1587 /// folded result. Return null if folding is not possible (e.g. due to an
1588 /// internal transposition in the result).
1589 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1590  Value source) {
1591  // TODO: Canonicalization for dynamic position not implemented yet.
1592  if (extractOp.hasDynamicPosition())
1593  return Value();
1594 
1595  // If we can't fold (either internal transposition, or nothing to fold), bail.
1596  bool nothingToFold = (source == extractOp.getVector());
1597  if (nothingToFold || !canFold())
1598  return Value();
1599 
1600  // Otherwise, fold by updating the op inplace and return its result.
1601  OpBuilder b(extractOp.getContext());
1602  extractOp.setStaticPosition(
1603  ArrayRef(extractPosition).take_front(extractedRank));
1604  extractOp.getVectorMutable().assign(source);
1605  return extractOp.getResult();
1606 }
1607 
1608 /// Iterate over producing insert and transpose ops until we find a fold.
1609 Value ExtractFromInsertTransposeChainState::fold() {
1610  // TODO: Canonicalization for dynamic position not implemented yet.
1611  if (extractOp.hasDynamicPosition())
1612  return Value();
1613 
1614  Value valueToExtractFrom = extractOp.getVector();
1615  updateStateForNextIteration(valueToExtractFrom);
1616  while (nextInsertOp || nextTransposeOp) {
1617  // Case 1. If we hit a transpose, just compose the map and iterate.
1618  // Invariant: insert + transpose do not change rank, we can always compose.
1619  if (succeeded(handleTransposeOp())) {
1620  valueToExtractFrom = nextTransposeOp.getVector();
1621  updateStateForNextIteration(valueToExtractFrom);
1622  continue;
1623  }
1624 
1625  Value result;
1626  // Case 2: the position match exactly.
1627  if (succeeded(handleInsertOpWithMatchingPos(result)))
1628  return result;
1629 
1630  // Case 3: if the inserted position is a prefix of extractPosition, we can
1631  // just extract a portion of the source of the insert.
1632  if (succeeded(handleInsertOpWithPrefixPos(result)))
1633  return tryToFoldExtractOpInPlace(result);
1634 
1635  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1636  // values. This is a more difficult case and we bail.
1637  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1638  if (isContainedWithin(extractPosition, insertedPos) ||
1639  intersectsWhereNonNegative(extractPosition, insertedPos))
1640  return Value();
1641 
1642  // Case 5: No intersection, we forward the extract to insertOp.dest().
1643  valueToExtractFrom = nextInsertOp.getDest();
1644  updateStateForNextIteration(valueToExtractFrom);
1645  }
1646  // If after all this we can fold, go for it.
1647  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1648 }
1649 
1650 /// Returns true if the operation has a 0-D vector type operand or result.
1651 static bool hasZeroDimVectors(Operation *op) {
1652  auto hasZeroDimVectorType = [](Type type) -> bool {
1653  auto vecType = dyn_cast<VectorType>(type);
1654  return vecType && vecType.getRank() == 0;
1655  };
1656 
1657  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1658  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1659 }
1660 
1661 /// Fold an ExtractOp whose source comes from a BroadcastOp or SplatOp.
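/// For example (illustrative IR):
///   %b = vector.broadcast %x : f32 to vector<4x8xf32>
///   %e = vector.extract %b[1, 2] : f32 from vector<4x8xf32>
/// folds to %x, since every extracted element is the broadcast source scalar.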
1662 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1663  // TODO: Canonicalization for dynamic position not implemented yet.
1664  if (extractOp.hasDynamicPosition())
1665  return Value();
1666 
1667  Operation *defOp = extractOp.getVector().getDefiningOp();
1668  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1669  return Value();
1670 
1671  Value source = defOp->getOperand(0);
1672  if (extractOp.getType() == source.getType())
1673  return source;
1674  auto getRank = [](Type type) {
1675  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1676  : 0;
1677  };
1678 
1679  // If splat or broadcast from a scalar, just return the source scalar.
1680  unsigned broadcastSrcRank = getRank(source.getType());
1681  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1682  return source;
1683 
1684  unsigned extractResultRank = getRank(extractOp.getType());
1685  if (extractResultRank >= broadcastSrcRank)
1686  return Value();
1687  // Check that the dimensions of the result haven't been broadcasted.
1688  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1689  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1690  if (extractVecType && broadcastVecType &&
1691  extractVecType.getShape() !=
1692  broadcastVecType.getShape().take_back(extractResultRank))
1693  return Value();
1694 
1695  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1696  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1697 
1698  // Detect all the positions that come from "dim-1" broadcasting.
1699  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1700  // extract position to `0` when extracting from the source operand.
1701  llvm::SetVector<int64_t> broadcastedUnitDims =
1702  broadcastOp.computeBroadcastedUnitDims();
1703  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1704  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1705  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1706  if (broadcastedUnitDims.contains(i))
1707  extractPos[i] = 0;
1708  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1709  // matching extract position when extracting from the source operand.
1710  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1711  extractPos.erase(extractPos.begin(),
1712  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1713  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1714  OpBuilder b(extractOp.getContext());
1715  extractOp.setOperand(0, source);
1716  extractOp.setStaticPosition(extractPos);
1717  return extractOp.getResult();
1718 }
1719 
1720 /// Fold extractOp coming from ShuffleOp.
1721 ///
1722 /// Example:
1723 ///
1724 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1725 /// : vector<8xf32>, vector<8xf32>
1726 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1727 /// ->
1728 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1729 ///
1730 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1731  // Dynamic positions are not folded as the resulting code would be more
1732  // complex than the input code.
1733  if (extractOp.hasDynamicPosition())
1734  return Value();
1735 
1736  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1737  if (!shuffleOp)
1738  return Value();
1739 
1740  // TODO: 0-D or multi-dimensional vectors not supported yet.
1741  if (shuffleOp.getResultVectorType().getRank() != 1)
1742  return Value();
1743 
1744  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1745  auto shuffleMask = shuffleOp.getMask();
1746  int64_t extractIdx = extractOp.getStaticPosition()[0];
1747  int64_t shuffleIdx = shuffleMask[extractIdx];
1748 
1749  // Find the shuffled vector to extract from based on the shuffle index.
1750  if (shuffleIdx < inputVecSize) {
1751  extractOp.setOperand(0, shuffleOp.getV1());
1752  extractOp.setStaticPosition({shuffleIdx});
1753  } else {
1754  extractOp.setOperand(0, shuffleOp.getV2());
1755  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1756  }
1757 
1758  return extractOp.getResult();
1759 }
1760 
1761 // Fold extractOp with source coming from ShapeCast op.
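// For example (illustrative IR), the position is linearized against the
// result shape and delinearized against the shape_cast source shape:
//   %sc = vector.shape_cast %v : vector<8xf32> to vector<2x4xf32>
//   %e  = vector.extract %sc[1, 2] : f32 from vector<2x4xf32>
// folds to vector.extract %v[6] : f32 from vector<8xf32>.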
1762 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1763  // TODO: Canonicalization for dynamic position not implemented yet.
1764  if (extractOp.hasDynamicPosition())
1765  return Value();
1766 
1767  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1768  if (!shapeCastOp)
1769  return Value();
1770 
1771  // Get the n-th dimension size, counting from the innermost dimension.
1772  auto getDimReverse = [](VectorType type, int64_t n) {
1773  return type.getShape().take_back(n + 1).front();
1774  };
1775  int64_t destinationRank =
1776  llvm::isa<VectorType>(extractOp.getType())
1777  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1778  : 0;
1779  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1780  return Value();
1781  if (destinationRank > 0) {
1782  auto destinationType =
1783  llvm::cast<VectorType>(extractOp.getResult().getType());
1784  for (int64_t i = 0; i < destinationRank; i++) {
1785  // The lowest dimension of the destination must match the lowest
1786  // dimension of the shapecast op source.
1787  // TODO: This case could be supported in a canonicalization pattern.
1788  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1789  getDimReverse(destinationType, i))
1790  return Value();
1791  }
1792  }
1793  // Extract the strides associated with the extract op vector source. Then use
1794  // this to calculate a linearized position for the extract.
1795  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1796  std::reverse(extractedPos.begin(), extractedPos.end());
1797  SmallVector<int64_t, 4> strides;
1798  int64_t stride = 1;
1799  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1800  strides.push_back(stride);
1801  stride *=
1802  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1803  }
1804 
1805  int64_t position = linearize(extractedPos, strides);
1806  // Then extract the strides associated with the shapeCast op vector source and
1807  // delinearize the position using those strides.
1808  SmallVector<int64_t, 4> newStrides;
1809  int64_t numDimension =
1810  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1811  stride = 1;
1812  for (int64_t i = 0; i < numDimension; i++) {
1813  newStrides.push_back(stride);
1814  stride *=
1815  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1816  }
1817  std::reverse(newStrides.begin(), newStrides.end());
1818  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1819  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1820  OpBuilder b(extractOp.getContext());
1821  extractOp.setStaticPosition(newPosition);
1822  extractOp.setOperand(0, shapeCastOp.getSource());
1823  return extractOp.getResult();
1824 }
1825 
1826 /// Fold an ExtractOp from ExtractStridedSliceOp.
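/// For example (illustrative IR), the slice offsets are added to the extract
/// position:
///   %s = vector.extract_strided_slice %v
///          {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8x16xf32> to vector<4x16xf32>
///   %e = vector.extract %s[1] : vector<16xf32> from vector<4x16xf32>
/// folds to vector.extract %v[3] : vector<16xf32> from vector<8x16xf32>.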
1827 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1828  // TODO: Canonicalization for dynamic position not implemented yet.
1829  if (extractOp.hasDynamicPosition())
1830  return Value();
1831 
1832  auto extractStridedSliceOp =
1833  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1834  if (!extractStridedSliceOp)
1835  return Value();
1836 
1837  // 0-D vectors not supported.
1838  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1839  if (hasZeroDimVectors(extractStridedSliceOp))
1840  return Value();
1841 
1842  // Return if 'extractStridedSliceOp' has non-unit strides.
1843  if (extractStridedSliceOp.hasNonUnitStrides())
1844  return Value();
1845 
1846  // Trim offsets for dimensions fully extracted.
1847  auto sliceOffsets =
1848  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1849  while (!sliceOffsets.empty()) {
1850  size_t lastOffset = sliceOffsets.size() - 1;
1851  if (sliceOffsets.back() != 0 ||
1852  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1853  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1854  break;
1855  sliceOffsets.pop_back();
1856  }
1857  unsigned destinationRank = 0;
1858  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1859  destinationRank = vecType.getRank();
1860  // The dimensions of the result need to be untouched by the
1861  // extractStridedSlice op.
1862  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1863  sliceOffsets.size())
1864  return Value();
1865 
1866  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1867  assert(extractedPos.size() >= sliceOffsets.size());
1868  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1869  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1870  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1871 
1872  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1873  OpBuilder b(extractOp.getContext());
1874  extractOp.setStaticPosition(extractedPos);
1875  return extractOp.getResult();
1876 }
1877 
1878 /// Fold extract_op fed from a chain of insertStridedSlice ops.
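/// For example (illustrative IR), when the extracted chunk lies entirely
/// inside the inserted chunk:
///   %i = vector.insert_strided_slice %src, %dst
///          {offsets = [2, 0], strides = [1, 1]}
///          : vector<2x16xf32> into vector<8x16xf32>
///   %e = vector.extract %i[3] : vector<16xf32> from vector<8x16xf32>
/// folds to vector.extract %src[1] : vector<16xf32> from vector<2x16xf32>.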
1879 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1880  // TODO: Canonicalization for dynamic position not implemented yet.
1881  if (extractOp.hasDynamicPosition())
1882  return Value();
1883 
1884  int64_t destinationRank =
1885  llvm::isa<VectorType>(extractOp.getType())
1886  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1887  : 0;
1888  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1889  if (!insertOp)
1890  return Value();
1891 
1892  // 0-D vectors not supported.
1893  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1894  if (hasZeroDimVectors(insertOp))
1895  return Value();
1896 
1897  while (insertOp) {
1898  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1899  insertOp.getSourceVectorType().getRank();
1900  if (destinationRank > insertOp.getSourceVectorType().getRank())
1901  return Value();
1902  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1903  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1904 
1905  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1906  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1907  }))
1908  return Value();
1909  bool disjoint = false;
1910  SmallVector<int64_t, 4> offsetDiffs;
1911  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1912  int64_t start = insertOffsets[dim];
1913  int64_t size =
1914  (dim < insertRankDiff)
1915  ? 1
1916  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1917  int64_t end = start + size;
1918  int64_t offset = extractOffsets[dim];
1919  // Check if the start of the extract offset is in the interval inserted.
1920  if (start <= offset && offset < end) {
1921  if (dim >= insertRankDiff)
1922  offsetDiffs.push_back(offset - start);
1923  continue;
1924  }
1925  disjoint = true;
1926  break;
1927  }
1928  // The extracted chunk overlaps with the inserted vector.
1929  if (!disjoint) {
1930  // If any of the inner dimensions are only partially inserted we have a
1931  // partial overlap.
1932  int64_t srcRankDiff =
1933  insertOp.getSourceVectorType().getRank() - destinationRank;
1934  for (int64_t i = 0; i < destinationRank; i++) {
1935  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1936  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1937  insertRankDiff))
1938  return Value();
1939  }
1940  extractOp.getVectorMutable().assign(insertOp.getSource());
1941  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1942  OpBuilder b(extractOp.getContext());
1943  extractOp.setStaticPosition(offsetDiffs);
1944  return extractOp.getResult();
1945  }
1946  // If the chunk extracted is disjoint from the chunk inserted, keep
1947  // looking in the insert chain.
1948  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1949  }
1950  return Value();
1951 }
1952 
1953 /// Try to fold the extraction of a scalar from a vector defined by
1954 /// vector.from_elements. E.g.:
1955 ///
1956 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1957 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1958 /// ==> fold to %a
1959 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1960  // Dynamic extractions cannot be folded.
1961  if (extractOp.hasDynamicPosition())
1962  return {};
1963 
1964  // Look for extract(from_elements).
1965  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1966  if (!fromElementsOp)
1967  return {};
1968 
1969  // Scalable vectors are not supported.
1970  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1971  if (vecType.isScalable())
1972  return {};
1973 
1974  // Only extractions of scalars are supported.
1975  int64_t rank = vecType.getRank();
1976  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1977  if (extractOp.getType() != vecType.getElementType())
1978  return {};
1979  assert(static_cast<int64_t>(indices.size()) == rank &&
1980  "unexpected number of indices");
1981 
1982  // Compute flattened/linearized index and fold to operand.
1983  int flatIndex = 0;
1984  int stride = 1;
1985  for (int i = rank - 1; i >= 0; --i) {
1986  flatIndex += indices[i] * stride;
1987  stride *= vecType.getDimSize(i);
1988  }
1989  return fromElementsOp.getElements()[flatIndex];
1990 }
1991 
1992 /// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
1993 /// then fold it.
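/// For example (illustrative IR), with %c1 = arith.constant 1 : index:
///   vector.extract %v[%c1, %i] : f32 from vector<4x4xf32>
/// becomes vector.extract %v[1, %i]: the constant index is promoted into the
/// static position while %i remains a dynamic operand.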
1994 template <typename OpType, typename AdaptorType>
1995 static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
1996  SmallVectorImpl<Value> &operands) {
1997  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
1998  OperandRange dynamicPosition = op.getDynamicPosition();
1999  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2000 
2001  // If there are no dynamic operands, there is nothing to fold; return directly.
2002  if (!dynamicPosition.size())
2003  return {};
2004 
2005  // `index` is used to iterate over the `dynamicPosition`.
2006  unsigned index = 0;
2007 
2008  // `opChange` records whether `op` needs to be updated in place.
2009  bool opChange = false;
2010  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2011  if (!ShapedType::isDynamic(staticPosition[i]))
2012  continue;
2013  Attribute positionAttr = dynamicPositionAttr[index];
2014  Value position = dynamicPosition[index++];
2015  if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2016  staticPosition[i] = attr.getInt();
2017  opChange = true;
2018  continue;
2019  }
2020  operands.push_back(position);
2021  }
2022 
2023  if (opChange) {
2024  op.setStaticPosition(staticPosition);
2025  op.getOperation()->setOperands(operands);
2026  return op.getResult();
2027  }
2028  return {};
2029 }
2030 
2031 /// Fold an insert or extract operation into a poison value when a poison index
2032 /// is found at any dimension of the static position.
2033 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2034  ArrayRef<int64_t> staticPos,
2035  int64_t poisonVal) {
2036  if (!llvm::is_contained(staticPos, poisonVal))
2037  return {};
2038 
2039  return ub::PoisonAttr::get(context);
2040 }
2041 
2042 /// Fold a vector extract whose source is a poison value.
2043 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2044  if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2045  return srcAttr;
2046 
2047  return {};
2048 }
2049 
2050 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2051  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2052  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2053  // mismatch).
2054  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2055  return getVector();
2056  if (auto res = foldPoisonIndexInsertExtractOp(
2057  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2058  return res;
2059  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2060  return res;
2061  if (succeeded(foldExtractOpFromExtractChain(*this)))
2062  return getResult();
2063  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2064  return res;
2065  if (auto res = foldExtractFromBroadcast(*this))
2066  return res;
2067  if (auto res = foldExtractFromShuffle(*this))
2068  return res;
2069  if (auto res = foldExtractFromShapeCast(*this))
2070  return res;
2071  if (auto val = foldExtractFromExtractStrided(*this))
2072  return val;
2073  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2074  return val;
2075  if (auto val = foldScalarExtractFromFromElements(*this))
2076  return val;
2077  SmallVector<Value> operands = {getVector()};
2078  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2079  return val;
2080  return OpFoldResult();
2081 }
2082 
2083 namespace {
2084 
2085 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
2086 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2087 public:
2088  using OpRewritePattern::OpRewritePattern;
2089 
2090  LogicalResult matchAndRewrite(ExtractOp extractOp,
2091  PatternRewriter &rewriter) const override {
2092  Operation *defOp = extractOp.getVector().getDefiningOp();
2093  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2094  return failure();
2095 
2096  Value source = defOp->getOperand(0);
2097  if (extractOp.getType() == source.getType())
2098  return failure();
2099  auto getRank = [](Type type) {
2100  return llvm::isa<VectorType>(type)
2101  ? llvm::cast<VectorType>(type).getRank()
2102  : 0;
2103  };
2104  unsigned broadcastSrcRank = getRank(source.getType());
2105  unsigned extractResultRank = getRank(extractOp.getType());
2106  // We only consider the case where the rank of the source is less than or
2107  // equal to the rank of the extract dst. The other cases are handled in the
2108  // folding patterns.
2109  if (extractResultRank < broadcastSrcRank)
2110  return failure();
2111 
2112  // Special case if broadcast src is a 0D vector.
2113  if (extractResultRank == 0) {
2114  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2115  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2116  return success();
2117  }
2118  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2119  extractOp, extractOp.getType(), source);
2120  return success();
2121  }
2122 };
2123 
2124 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
2125 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2126 public:
2127  using OpRewritePattern::OpRewritePattern;
2128 
2129  LogicalResult matchAndRewrite(ExtractOp extractOp,
2130  PatternRewriter &rewriter) const override {
2131  // Return if 'ExtractOp' operand is not defined by a splat vector
2132  // ConstantOp.
2133  Value sourceVector = extractOp.getVector();
2134  Attribute vectorCst;
2135  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2136  return failure();
2137  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2138  if (!splat)
2139  return failure();
2140  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2141  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2142  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2143  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2144  return success();
2145  }
2146 };
2147 
2148 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
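// For example (illustrative IR):
//   %cst = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
//   %e   = vector.extract %cst[1] : vector<2xi32> from vector<2x2xi32>
// rewrites to arith.constant dense<[2, 3]> : vector<2xi32>.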
2149 class ExtractOpNonSplatConstantFolder final
2150  : public OpRewritePattern<ExtractOp> {
2151 public:
2152  using OpRewritePattern::OpRewritePattern;
2153 
2154  LogicalResult matchAndRewrite(ExtractOp extractOp,
2155  PatternRewriter &rewriter) const override {
2156  // TODO: Canonicalization for dynamic position not implemented yet.
2157  if (extractOp.hasDynamicPosition())
2158  return failure();
2159 
2160  // Return if 'ExtractOp' operand is not defined by a compatible vector
2161  // ConstantOp.
2162  Value sourceVector = extractOp.getVector();
2163  Attribute vectorCst;
2164  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2165  return failure();
2166 
2167  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2168  if (vecTy.isScalable())
2169  return failure();
2170 
2171  // The splat case is handled by `ExtractOpSplatConstantFolder`.
2172  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2173  if (!dense || dense.isSplat())
2174  return failure();
2175 
2176  // Calculate the linearized position of the contiguous chunk of elements to
2177  // extract.
2178  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2179  copy(extractOp.getStaticPosition(), completePositions.begin());
2180  int64_t elemBeginPosition =
2181  linearize(completePositions, computeStrides(vecTy.getShape()));
2182  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2183 
2184  TypedAttr newAttr;
2185  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2186  SmallVector<Attribute> elementValues(
2187  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2188  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2189  } else {
2190  newAttr = *denseValuesBegin;
2191  }
2192 
2193  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2194  return success();
2195  }
2196 };
2197 
2198 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
2199 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2200 public:
2201  using OpRewritePattern::OpRewritePattern;
2202 
2203  LogicalResult matchAndRewrite(ExtractOp extractOp,
2204  PatternRewriter &rewriter) const override {
2205  auto createMaskOp =
2206  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2207  if (!createMaskOp)
2208  return failure();
2209 
2210  VectorType extractedMaskType =
2211  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2212 
2213  if (!extractedMaskType)
2214  return failure();
2215 
2216  auto maskOperands = createMaskOp.getOperands();
2217  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2218  VectorType maskType = createMaskOp.getVectorType();
2219 
2220  bool containsUnknownDims = false;
2221  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2222 
2223  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2224  dimIdx++) {
2225  int64_t pos = extractOpPos[dimIdx];
2226  Value operand = maskOperands[dimIdx];
2227  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2228  if (!constantOp) {
2229  // Bounds of this dim unknown.
2230  containsUnknownDims = true;
2231  continue;
2232  }
2233 
2234  int64_t createMaskBound =
2235  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2236 
2237  if (pos != ShapedType::kDynamic) {
2238  // If any position is outside the range from the `create_mask`, then the
2239  // extracted mask will be all-false.
2240  allFalse |= pos >= createMaskBound;
2241  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2242  // This dim is not all-true and since this is a dynamic index we don't
2243  // know if the extraction is within the true or false region.
2244  // Note: Zero dims have already been handled via getMaskFormat().
2245  containsUnknownDims = true;
2246  }
2247  }
2248 
2249  if (allFalse) {
2250  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2251  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2252  } else if (!containsUnknownDims) {
2253  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2254  extractOp, extractedMaskType,
2255  maskOperands.drop_front(extractOpPos.size()));
2256  } else {
2257  return failure();
2258  }
2259  return success();
2260  }
2261 };
2262 
2263 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2264 // does not change.
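// For example (illustrative IR):
//   %sc = vector.shape_cast %v : vector<8xf32> to vector<1x2x4xf32>
//   %e  = vector.extract %sc[0] : vector<2x4xf32> from vector<1x2x4xf32>
// becomes vector.shape_cast %v : vector<8xf32> to vector<2x4xf32>.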
2265 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2266  PatternRewriter &rewriter) {
2267  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2268  if (!castOp)
2269  return failure();
2270 
2271  VectorType sourceType = castOp.getSourceVectorType();
2272  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2273  if (!targetType)
2274  return failure();
2275 
2276  if (sourceType.getNumElements() != targetType.getNumElements())
2277  return failure();
2278 
2279  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2280  castOp.getSource());
2281  return success();
2282 }
2283 
2284 /// Try to canonicalize the extraction of a subvector from a vector defined by
2285 /// vector.from_elements. E.g.:
2286 ///
2287 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2288 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2289 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2290 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2291  PatternRewriter &rewriter) {
2292  // Dynamic positions are not supported.
2293  if (extractOp.hasDynamicPosition())
2294  return failure();
2295 
2296  // Scalar extracts are handled by the folder.
2297  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2298  if (!resultType)
2299  return failure();
2300 
2301  // Look for extracts from a from_elements op.
2302  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2303  if (!fromElementsOp)
2304  return failure();
2305  VectorType inputType = fromElementsOp.getType();
2306 
2307  // Scalable vectors are not supported.
2308  if (resultType.isScalable() || inputType.isScalable())
2309  return failure();
2310 
2311  // Compute the position of the first extracted element and flatten/linearize
2312  // position.
2313  SmallVector<int64_t> firstElementPos =
2314  llvm::to_vector(extractOp.getStaticPosition());
2315  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2316  int flatIndex = 0;
2317  int stride = 1;
2318  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2319  flatIndex += firstElementPos[i] * stride;
2320  stride *= inputType.getDimSize(i);
2321  }
2322 
2323  // Replace the op with a smaller from_elements op.
2324  rewriter.replaceOpWithNewOp<FromElementsOp>(
2325  extractOp, resultType,
2326  fromElementsOp.getElements().slice(flatIndex,
2327  resultType.getNumElements()));
2328  return success();
2329 }
2330 
2331 } // namespace
2332 
2333 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2334  MLIRContext *context) {
2335  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2336  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2337  results.add(foldExtractFromShapeCastToShapeCast);
2338  results.add(foldExtractFromFromElements);
2339 }
2340 
2341 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2342  SmallVectorImpl<int64_t> &results) {
2343  for (auto attr : arrayAttr)
2344  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2345 }
2346 
2347 //===----------------------------------------------------------------------===//
2348 // FMAOp
2349 //===----------------------------------------------------------------------===//
2350 
2351 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2352  return llvm::to_vector<4>(getVectorType().getShape());
2353 }
2354 
2355 //===----------------------------------------------------------------------===//
2356 // FromElementsOp
2357 //===----------------------------------------------------------------------===//
2358 
2359 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2360 /// same SSA value. E.g.:
2361 ///
2362 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2363 /// ==> rewrite to vector.splat %a : vector<3xf32>
2364 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2365  PatternRewriter &rewriter) {
2366  if (!llvm::all_equal(fromElementsOp.getElements()))
2367  return failure();
2368  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2369  fromElementsOp.getElements().front());
2370  return success();
2371 }
2372 
2373 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2374  MLIRContext *context) {
2375  results.add(rewriteFromElementsAsSplat);
2376 }
2377 
2378 //===----------------------------------------------------------------------===//
2379 // BroadcastOp
2380 //===----------------------------------------------------------------------===//
2381 
2382 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2383  SetIntRangeFn setResultRanges) {
2384  setResultRanges(getResult(), argRanges.front());
2385 }
2386 
2387 /// Return the dimensions of the result vector that were formerly ones in the
2388 /// source vector and thus correspond to "dim-1" broadcasting.
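/// For example (illustrative): srcShape = [1, 4] broadcast to
/// dstShape = [2, 3, 4] returns {1}, the result dimension that was expanded
/// from a source dimension of size 1.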
2389 static llvm::SetVector<int64_t>
2390 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2391  ArrayRef<int64_t> dstShape) {
2392  int64_t rankDiff = dstShape.size() - srcShape.size();
2393  int64_t dstDim = rankDiff;
2394  llvm::SetVector<int64_t> res;
2395  for (auto [s1, s2] :
2396  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2397  if (s1 != s2) {
2398  assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2399  res.insert(dstDim);
2400  }
2401  ++dstDim;
2402  }
2403  return res;
2404 }
2405 
2406 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2407  // A scalar broadcast involves no unit-dim broadcasting.
2408  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2409  if (!srcVectorType)
2410  return {};
2411  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2412  getResultVectorType().getShape());
2413 }
2414 
2415 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2416 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2417 /// This requires (and asserts) that the broadcast is free of "dim-1"
2418 /// broadcasting.
2419 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2420 /// vector.transpose may be inserted to make the broadcast possible.
2421 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2422 /// the helper will assert. This means:
2423 /// 1. `dstShape` must not be empty.
2424 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2425 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2426 ///    must match the `value` shape.
2427 Value BroadcastOp::createOrFoldBroadcastOp(
2428  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2429  const llvm::SetVector<int64_t> &broadcastedDims) {
2430  assert(!dstShape.empty() && "unexpected empty dst shape");
2431 
2432  // Step 1. Well-formedness check.
2433  SmallVector<int64_t> checkShape;
2434  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2435  if (broadcastedDims.contains(i))
2436  continue;
2437  checkShape.push_back(dstShape[i]);
2438  }
2439  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2440  "ill-formed broadcastedDims contains values not confined to "
2441  "destVectorShape");
2442 
2443  Location loc = value.getLoc();
2444  Type elementType = getElementTypeOrSelf(value.getType());
2445  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2446  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2447 
2448  // Step 2. If scalar -> dstShape broadcast, just do it.
2449  if (!srcVectorType) {
2450  assert(checkShape.empty() &&
2451  "ill-formed createOrFoldBroadcastOp arguments");
2452  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2453  }
2454 
2455  assert(srcVectorType.getShape().equals(checkShape) &&
2456  "ill-formed createOrFoldBroadcastOp arguments");
2457 
2458  // Step 3. Since vector.broadcast only allows creating leading dims,
2459  // vector -> dstShape broadcast may require a transpose.
2460  // Traverse the dims in order and construct:
2461  // 1. The leading entries of the broadcastShape that is guaranteed to be
2462  // achievable by a simple broadcast.
2463  // 2. The induced permutation for the subsequent vector.transpose that will
2464  // bring us from `broadcastShape` back to the desired `dstShape`.
2465  // If the induced permutation is not the identity, create a vector.transpose.
2466  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2467  broadcastShape.reserve(dstShape.size());
2468  // Consider the example:
2469  // srcShape = 2x4
2470  // dstShape = 1x2x3x4x5
2471  // broadcastedDims = [0, 2, 4]
2472  //
2473  // We want to build:
2474  // broadcastShape = 1x3x5x2x4
2475  // permutation = [0, 2, 4, 1, 3]
2476  // ---V--- -----V-----
2477  // leading broadcast part src shape part
2478  //
2479  // Note that the trailing dims of broadcastShape are exactly the srcShape
2480  // by construction.
2481  // nextSrcShapeDim is used to keep track of where in the permutation the
2482  // "src shape part" occurs.
2483  int64_t nextSrcShapeDim = broadcastedDims.size();
2484  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2485  if (broadcastedDims.contains(i)) {
2486  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2487  // bring it to the head of the broadcastShape.
2488  // It will need to be permuted back from `broadcastShape.size() - 1` into
2489  // position `i`.
2490  broadcastShape.push_back(dstShape[i]);
2491  permutation[i] = broadcastShape.size() - 1;
2492  } else {
2493  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2494  // shape and needs to be permuted into position `i`.
2495  // Don't touch `broadcastShape` here, the whole srcShape will be
2496  // appended after.
2497  permutation[i] = nextSrcShapeDim++;
2498  }
2499  }
2500  // 3.c. Append the srcShape.
2501  llvm::append_range(broadcastShape, srcVectorType.getShape());
2502 
2503  // Ensure there are no "dim-1" broadcasts.
2504  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2505  .empty() &&
2506  "unexpected \"dim-1\" broadcast");
2507 
2508  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2509  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2510  vector::BroadcastableToResult::Success &&
2511  "must be broadcastable");
2512  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2513  // Step 4. If we find any dimension that indeed needs to be permuted,
2514  // immediately return a new vector.transpose.
2515  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2516  if (permutation[i] != i)
2517  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2518  // Otherwise return res.
2519  return res;
2520 }
2521 
2522 BroadcastableToResult mlir::vector::isBroadcastableTo(
2523  Type srcType, VectorType dstVectorType,
2524  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2525  // Broadcast scalar to vector of the same element type.
2526  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2527  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2528  return BroadcastableToResult::Success;
2529  // From now on, only vectors broadcast.
2530  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2531  if (!srcVectorType)
2532  return BroadcastableToResult::SourceTypeNotAVector;
2533 
2534  int64_t srcRank = srcVectorType.getRank();
2535  int64_t dstRank = dstVectorType.getRank();
2536  if (srcRank > dstRank)
2537  return BroadcastableToResult::SourceRankHigher;
2538  // Source has an exact match or singleton value for all trailing dimensions
2539  // (all leading dimensions are simply duplicated).
2540  int64_t lead = dstRank - srcRank;
2541  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2542  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2543  // encountered?
2544  bool foundMismatchingDims = false;
2545 
2546  // Check fixed-width dims.
2547  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2548  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2549  if (srcDim != 1 && srcDim != dstDim)
2550  foundMismatchingDims = true;
2551 
2552  // Check scalable flags.
2553  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2554  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2555  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2556  // 1 -> [N] is fine, everything else should be rejected when mixing
2557  // fixed-width and scalable dims
2558  (srcDimScalableFlag != dstDimScalableFlag &&
2559  (srcDim != 1 || srcDimScalableFlag)))
2560  foundMismatchingDims = true;
2561 
2562  if (foundMismatchingDims) {
2563  if (mismatchingDims != nullptr) {
2564  mismatchingDims->first.dim = srcDim;
2565  mismatchingDims->first.isScalable = srcDimScalableFlag;
2566 
2567  mismatchingDims->second.dim = dstDim;
2568  mismatchingDims->second.isScalable = dstDimScalableFlag;
2569  }
2570  return BroadcastableToResult::DimensionMismatch;
2571  }
2572  }
2573 
2574  return BroadcastableToResult::Success;
2575 }
2576 
2577 LogicalResult BroadcastOp::verify() {
2578  std::pair<VectorDim, VectorDim> mismatchingDims;
2579  BroadcastableToResult res = isBroadcastableTo(
2580  getSourceType(), getResultVectorType(), &mismatchingDims);
2581  if (res == BroadcastableToResult::Success)
2582  return success();
2583  if (res == BroadcastableToResult::SourceRankHigher)
2584  return emitOpError("source rank higher than destination rank");
2585  if (res == BroadcastableToResult::DimensionMismatch) {
2586  return emitOpError("dimension mismatch (")
2587  << (mismatchingDims.first.isScalable ? "[" : "")
2588  << mismatchingDims.first.dim
2589  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2590  << (mismatchingDims.second.isScalable ? "[" : "")
2591  << mismatchingDims.second.dim
2592  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2593  }
2594  if (res == BroadcastableToResult::SourceTypeNotAVector)
2595  return emitOpError("source type is not a vector");
2596  llvm_unreachable("unexpected vector.broadcast op error");
2597 }
2598 
2599 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2600  if (getSourceType() == getResultVectorType())
2601  return getSource();
2602  if (!adaptor.getSource())
2603  return {};
2604  auto vectorType = getResultVectorType();
2605  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2606  if (vectorType.getElementType() != attr.getType())
2607  return {};
2608  return DenseElementsAttr::get(vectorType, attr);
2609  }
2610  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2611  if (vectorType.getElementType() != attr.getType())
2612  return {};
2613  return DenseElementsAttr::get(vectorType, attr);
2614  }
2615  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2616  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2617  return {};
2618 }
2619 
2620 namespace {
2621 
2622 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2623 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2624  using OpRewritePattern::OpRewritePattern;
2625 
2626  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2627  PatternRewriter &rewriter) const override {
2628  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2629  if (!srcBroadcast)
2630  return failure();
2631  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2632  broadcastOp.getResultVectorType(),
2633  srcBroadcast.getSource());
2634  return success();
2635  }
2636 };
2637 } // namespace
2638 
2639 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2640  MLIRContext *context) {
2641  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2642  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2643  results.add<BroadcastFolder>(context);
2644 }
2645 
2646 //===----------------------------------------------------------------------===//
2647 // ShuffleOp
2648 //===----------------------------------------------------------------------===//
2649 
2650 LogicalResult ShuffleOp::verify() {
2651  VectorType resultType = getResultVectorType();
2652  VectorType v1Type = getV1VectorType();
2653  VectorType v2Type = getV2VectorType();
2654  // Verify ranks.
2655  int64_t resRank = resultType.getRank();
2656  int64_t v1Rank = v1Type.getRank();
2657  int64_t v2Rank = v2Type.getRank();
2658  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2659  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2660  if (!wellFormed0DCase && !wellFormedNDCase)
2661  return emitOpError("rank mismatch");
2662 
2663  // Verify all but leading dimension sizes.
2664  for (int64_t r = 1; r < v1Rank; ++r) {
2665  int64_t resDim = resultType.getDimSize(r);
2666  int64_t v1Dim = v1Type.getDimSize(r);
2667  int64_t v2Dim = v2Type.getDimSize(r);
2668  if (resDim != v1Dim || v1Dim != v2Dim)
2669  return emitOpError("dimension mismatch");
2670  }
2671  // Verify mask length.
2672  ArrayRef<int64_t> mask = getMask();
2673  int64_t maskLength = mask.size();
2674  if (maskLength <= 0)
2675  return emitOpError("invalid mask length");
2676  if (maskLength != resultType.getDimSize(0))
2677  return emitOpError("mask length mismatch");
2678  // Verify all indices.
2679  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2680  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2681  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2682  if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2683  return emitOpError("mask index #") << (idx + 1) << " out of range";
2684  }
2685  return success();
2686 }
2687 
2688 LogicalResult
2689 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2690  ShuffleOp::Adaptor adaptor,
2691  SmallVectorImpl<Type> &inferredReturnTypes) {
2692  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2693  auto v1Rank = v1Type.getRank();
2694  // Construct resulting type: leading dimension matches mask
2695  // length, all trailing dimensions match the operands.
2696  SmallVector<int64_t, 4> shape;
2697  shape.reserve(v1Rank);
2698  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2699  // In the 0-D case there is no trailing shape to append.
2700  if (v1Rank > 0)
2701  llvm::append_range(shape, v1Type.getShape().drop_front());
2702  inferredReturnTypes.push_back(
2703  VectorType::get(shape, v1Type.getElementType()));
2704  return success();
2705 }
2706 
2707 template <typename T>
2708 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2709  T expected = begin;
2710  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2711  return value == expected++;
2712  });
2713 }
2714 
2715 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2716  auto v1Type = getV1VectorType();
2717  auto v2Type = getV2VectorType();
2718 
2719  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2720  "Vector shuffle does not support scalable vectors");
2721 
2722  // For consistency: the 0-D shuffle's return type is 1-D, so this cannot be a
2723  // folding; it must be a canonicalization into a vector.broadcast.
2724  if (v1Type.getRank() == 0)
2725  return {};
2726 
2727  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2728  auto mask = getMask();
2729  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
2730  return getV1();
2731  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2732  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
2733  return getV2();
2734 
2735  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2736  if (!v1Attr || !v2Attr)
2737  return {};
2738 
2739  // Fold shuffle poison, poison -> poison.
2740  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2741  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2742  if (isV1Poison && isV2Poison)
2743  return ub::PoisonAttr::get(getContext());
2744 
2745  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2746  // manipulation.
2747  if (v1Type.getRank() != 1)
2748  return {};
2749 
2750  // Poison input attributes need special handling as they are not
2751  // DenseElementsAttr. If an index is poison, we select the first element of
2752  // the first non-poison input.
2753  SmallVector<Attribute> v1Elements, v2Elements;
2754  Attribute poisonElement;
2755  if (!isV2Poison) {
2756  v2Elements =
2757  to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2758  poisonElement = v2Elements[0];
2759  }
2760  if (!isV1Poison) {
2761  v1Elements =
2762  to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2763  poisonElement = v1Elements[0];
2764  }
2765 
2766  SmallVector<Attribute> results;
2767  int64_t v1Size = v1Type.getDimSize(0);
2768  for (int64_t maskIdx : mask) {
2769  Attribute indexedElm;
2770  // TODO: Return a partial poison vector when supported by the UB dialect.
2771  if (maskIdx == ShuffleOp::kPoisonIndex) {
2772  indexedElm = poisonElement;
2773  } else {
2774  if (maskIdx < v1Size)
2775  indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2776  else
2777  indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2778  }
2779 
2780  results.push_back(indexedElm);
2781  }
2782 
2783  return DenseElementsAttr::get(getResultVectorType(), results);
2784 }
2785 
2786 namespace {
2787 
2788 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2789 // to a broadcast.
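// For example (illustrative IR):
//   %0 = vector.shuffle %a, %b [1] : vector<f32>, vector<f32>
// rewrites to a vector.broadcast of %b producing vector<1xf32>.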
2790 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2791  using OpRewritePattern::OpRewritePattern;
2792 
2793  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2794  PatternRewriter &rewriter) const override {
2795  VectorType v1VectorType = shuffleOp.getV1VectorType();
2796  ArrayRef<int64_t> mask = shuffleOp.getMask();
2797  if (v1VectorType.getRank() > 0)
2798  return failure();
2799  if (mask.size() != 1)
2800  return failure();
2801  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2802  if (mask[0] == 0)
2803  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2804  shuffleOp.getV1());
2805  else
2806  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2807  shuffleOp.getV2());
2808  return success();
2809  }
2810 };
2811 
2812 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2813 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2814 public:
2815  using OpRewritePattern::OpRewritePattern;
2816 
2817  LogicalResult matchAndRewrite(ShuffleOp op,
2818  PatternRewriter &rewriter) const override {
2819  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2820  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2821 
2822  if (!v1Splat || !v2Splat)
2823  return failure();
2824 
2825  if (v1Splat.getInput() != v2Splat.getInput())
2826  return failure();
2827 
2828  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2829  return success();
2830  }
2831 };
2832 
2833 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2834 /// vector.interleave.
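/// For example (illustrative IR):
///   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///          : vector<4xf32>, vector<4xf32>
/// rewrites to a vector.interleave of %a and %b producing vector<8xf32>.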
2835 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2836 public:
2837  using OpRewritePattern::OpRewritePattern;
2838 
2839  LogicalResult matchAndRewrite(ShuffleOp op,
2840  PatternRewriter &rewriter) const override {
2841  VectorType resultType = op.getResultVectorType();
2842  if (resultType.isScalable())
2843  return rewriter.notifyMatchFailure(
2844  op, "ShuffleOp can't represent a scalable interleave");
2845 
2846  if (resultType.getRank() != 1)
2847  return rewriter.notifyMatchFailure(
2848  op, "ShuffleOp can't represent an n-D interleave");
2849 
2850  VectorType sourceType = op.getV1VectorType();
2851  if (sourceType != op.getV2VectorType() ||
2852  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2853  return rewriter.notifyMatchFailure(
2854  op, "ShuffleOp types don't match an interleave");
2855  }
2856 
2857  ArrayRef<int64_t> shuffleMask = op.getMask();
2858  int64_t resultVectorSize = resultType.getNumElements();
2859  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2860  int64_t maskValueA = shuffleMask[i * 2];
2861  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2862  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2863  return rewriter.notifyMatchFailure(op,
2864  "ShuffleOp mask not interleaving");
2865  }
2866 
2867  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2868  return success();
2869  }
2870 };
2871 
2872 } // namespace
2873 
2874 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2875  MLIRContext *context) {
2876  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2877  context);
2878 }
2879 
2880 //===----------------------------------------------------------------------===//
2881 // InsertElementOp
2882 //===----------------------------------------------------------------------===//
2883 
2884 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2885  SetIntRangeFn setResultRanges) {
2886  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2887 }
2888 
2889 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2890  Value source, Value dest) {
2891  build(builder, result, source, dest, {});
2892 }
2893 
2894 LogicalResult InsertElementOp::verify() {
2895  auto dstVectorType = getDestVectorType();
2896  if (dstVectorType.getRank() == 0) {
2897  if (getPosition())
2898  return emitOpError("expected position to be empty with 0-D vector");
2899  return success();
2900  }
2901  if (dstVectorType.getRank() != 1)
2902  return emitOpError("unexpected >1 vector rank");
2903  if (!getPosition())
2904  return emitOpError("expected position for 1-D vector");
2905  return success();
2906 }
2907 
2908 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2909  // Skip the 0-D vector here.
2910  if (!adaptor.getPosition())
2911  return {};
2912 
2913  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2914  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2915  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2916  if (!src || !dst || !pos)
2917  return {};
2918 
2919  if (src.getType() != getDestVectorType().getElementType())
2920  return {};
2921 
2922  auto dstElements = dst.getValues<Attribute>();
2923 
2924  SmallVector<Attribute> results(dstElements);
2925 
2926  uint64_t posIdx = pos.getInt();
2927  if (posIdx >= results.size())
2928  return {};
2929  results[posIdx] = src;
2930 
2931  return DenseElementsAttr::get(getDestVectorType(), results);
2932 }
2933 
2934 //===----------------------------------------------------------------------===//
2935 // InsertOp
2936 //===----------------------------------------------------------------------===//
2937 
2938 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2939  SetIntRangeFn setResultRanges) {
2940  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2941 }
2942 
2943 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2944  Value source, Value dest, int64_t position) {
2945  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2946 }
2947 
2948 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2949  Value source, Value dest, OpFoldResult position) {
2950  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2951 }
2952 
2953 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2954  Value source, Value dest,
2955  ArrayRef<int64_t> position) {
2956  SmallVector<OpFoldResult> posVals;
2957  posVals.reserve(position.size());
2958  llvm::transform(position, std::back_inserter(posVals),
2959  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2960  build(builder, result, source, dest, posVals);
2961 }
2962 
2963 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2964  Value source, Value dest,
2965  ArrayRef<OpFoldResult> position) {
2966  SmallVector<int64_t> staticPos;
2967  SmallVector<Value> dynamicPos;
2968  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2969  build(builder, result, source, dest, dynamicPos,
2970  builder.getDenseI64ArrayAttr(staticPos));
2971 }
2972 
2973 LogicalResult InsertOp::verify() {
2974  SmallVector<OpFoldResult> position = getMixedPosition();
2975  auto destVectorType = getDestVectorType();
2976  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2977  return emitOpError(
2978  "expected position attribute of rank no greater than dest vector rank");
2979  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2980  if (srcVectorType &&
2981  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2982  static_cast<unsigned>(destVectorType.getRank())))
2983  return emitOpError("expected position attribute rank + source rank to "
2984  "match dest vector rank");
2985  if (!srcVectorType &&
2986  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2987  return emitOpError(
2988  "expected position attribute rank to match the dest vector rank");
2989  for (auto [idx, pos] : llvm::enumerate(position)) {
2990  if (auto attr = pos.dyn_cast<Attribute>()) {
2991  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2992  if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2993  destVectorType.getDimSize(idx))) {
2994  return emitOpError("expected position attribute #")
2995  << (idx + 1)
2996  << " to be a non-negative integer smaller than the "
2997  "corresponding "
2998  "dest vector dimension";
2999  }
3000  }
3001  }
3002  return success();
3003 }
3004 
3005 namespace {
3006 
3007 // If insertOp is only inserting unit dimensions it can be transformed to a
3008 // broadcast.
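//
// Illustrative sketch (added for exposition; placeholder names):
//   %r = vector.insert %v, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// becomes:
//   %r = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>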
3009 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3010 public:
3012 
3013  LogicalResult matchAndRewrite(InsertOp insertOp,
3014  PatternRewriter &rewriter) const override {
3015  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
3016  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3017  srcVecType.getNumElements())
3018  return failure();
3019  rewriter.replaceOpWithNewOp<BroadcastOp>(
3020  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
3021  return success();
3022  }
3023 };
3024 
3025 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
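///
/// Illustrative sketch (added for exposition; %a is a placeholder scalar):
///   %0 = vector.splat %a : vector<4xf32>
///   %1 = vector.splat %a : vector<2x4xf32>
///   %r = vector.insert %0, %1 [0] : vector<4xf32> into vector<2x4xf32>
/// folds to:
///   %r = vector.splat %a : vector<2x4xf32>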
3026 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3027 public:
3029 
3030  LogicalResult matchAndRewrite(InsertOp op,
3031  PatternRewriter &rewriter) const override {
3032  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
3033  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3034 
3035  if (!srcSplat || !dstSplat)
3036  return failure();
3037 
3038  if (srcSplat.getInput() != dstSplat.getInput())
3039  return failure();
3040 
3041  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3042  return success();
3043  }
3044 };
3045 
3046 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
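//
// Illustrative sketch (added for exposition; placeholder names):
//   %c = arith.constant dense<0> : vector<4xi32>
//   %e = arith.constant 7 : i32
//   %r = vector.insert %e, %c [1] : i32 into vector<4xi32>
// folds to:
//   %r = arith.constant dense<[0, 7, 0, 0]> : vector<4xi32>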
3047 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3048 public:
3050 
3051  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3052  // unless the destination vector constant has a single use.
3053  static constexpr int64_t vectorSizeFoldThreshold = 256;
3054 
3055  LogicalResult matchAndRewrite(InsertOp op,
3056  PatternRewriter &rewriter) const override {
3057  // TODO: Canonicalization for dynamic position not implemented yet.
3058  if (op.hasDynamicPosition())
3059  return failure();
3060 
3061  // Return if 'InsertOp' operand is not defined by a compatible vector
3062  // ConstantOp.
3063  TypedValue<VectorType> destVector = op.getDest();
3064  Attribute vectorDestCst;
3065  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3066  return failure();
3067  auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3068  if (!denseDest)
3069  return failure();
3070 
3071  VectorType destTy = destVector.getType();
3072  if (destTy.isScalable())
3073  return failure();
3074 
3075  // Make sure we do not create too many large constants.
3076  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3077  !destVector.hasOneUse())
3078  return failure();
3079 
3080  Value sourceValue = op.getSource();
3081  Attribute sourceCst;
3082  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3083  return failure();
3084 
3085  // Calculate the linearized position of the contiguous chunk of elements to
3086  // insert.
3087  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3088  copy(op.getStaticPosition(), completePositions.begin());
3089  int64_t insertBeginPosition =
3090  linearize(completePositions, computeStrides(destTy.getShape()));
3091 
3092  SmallVector<Attribute> insertedValues;
3093  Type destEltType = destTy.getElementType();
3094 
3095  // The `convertIntegerAttr` method specifically handles the case
3096  // for `llvm.mlir.constant` which can hold an attribute with a
3097  // different type than the return type.
3098  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3099  for (auto value : denseSource.getValues<Attribute>())
3100  insertedValues.push_back(convertIntegerAttr(value, destEltType));
3101  } else {
3102  insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
3103  }
3104 
3105  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
3106  copy(insertedValues, allValues.begin() + insertBeginPosition);
3107  auto newAttr = DenseElementsAttr::get(destTy, allValues);
3108 
3109  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3110  return success();
3111  }
3112 
3113 private:
3114  /// Converts the expected type to an IntegerAttr if there's
3115  /// a mismatch.
3116  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3117  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3118  if (intAttr.getType() != expectedType)
3119  return IntegerAttr::get(expectedType, intAttr.getInt());
3120  }
3121  return attr;
3122  }
3123 };
3124 
3125 } // namespace
3126 
3127 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3128  MLIRContext *context) {
3129  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3130  InsertOpConstantFolder>(context);
3131 }
3132 
3133 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3134  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3135  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3136  // (type mismatch).
3137  if (getNumIndices() == 0 && getSourceType() == getType())
3138  return getSource();
3139  SmallVector<Value> operands = {getSource(), getDest()};
3140  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3141  return val;
3142  if (auto res = foldPoisonIndexInsertExtractOp(
3143  getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3144  return res;
3145 
3146  return {};
3147 }
3148 
3149 //===----------------------------------------------------------------------===//
3150 // InsertStridedSliceOp
3151 //===----------------------------------------------------------------------===//
3152 
3153 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3154  Value source, Value dest,
3155  ArrayRef<int64_t> offsets,
3156  ArrayRef<int64_t> strides) {
3157  result.addOperands({source, dest});
3158  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3159  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3160  result.addTypes(dest.getType());
3161  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3162  offsetsAttr);
3163  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3164  stridesAttr);
3165 }
3166 
3167 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3168 template <typename OpType>
3169 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3170  ArrayAttr arrayAttr,
3171  ArrayRef<int64_t> shape,
3172  StringRef attrName) {
3173  if (arrayAttr.size() > shape.size())
3174  return op.emitOpError("expected ")
3175  << attrName << " attribute of rank no greater than vector rank";
3176  return success();
3177 }
3178 
3179 // Returns success if all integers in `arrayAttr` are in the interval [min, max}.
3180 // If `halfOpen` is true then the admissible interval is [min, max).
3181 // Otherwise, the admissible interval is [min, max].
3182 template <typename OpType>
3183 static LogicalResult
3184 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3185  int64_t max, StringRef attrName,
3186  bool halfOpen = true) {
3187  for (auto attr : arrayAttr) {
3188  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3189  auto upper = max;
3190  if (!halfOpen)
3191  upper += 1;
3192  if (val < min || val >= upper)
3193  return op.emitOpError("expected ") << attrName << " to be confined to ["
3194  << min << ", " << upper << ")";
3195  }
3196  return success();
3197 }
3198 
3199 // Returns success if, for each index i, `arrayAttr[i]` is in the interval
3200 // [min, shape[i]}. If `halfOpen` is true then the admissible interval is
3201 // [min, shape[i]). Otherwise, the admissible interval is [min, shape[i]].
3202 template <typename OpType>
3203 static LogicalResult
3204 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3205  ArrayRef<int64_t> shape, StringRef attrName,
3206  bool halfOpen = true, int64_t min = 0) {
3207  for (auto [index, attrDimPair] :
3208  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3209  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3210  int64_t max = std::get<1>(attrDimPair);
3211  if (!halfOpen)
3212  max += 1;
3213  if (val < min || val >= max)
3214  return op.emitOpError("expected ")
3215  << attrName << " dimension " << index << " to be confined to ["
3216  << min << ", " << max << ")";
3217  }
3218  return success();
3219 }
3220 
3221 // Returns success if, for all indices i = 0..shape.size()-1, the value
3222 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3223 // is in the interval [min, shape[i]}.
3224 // If `halfOpen` is true then the admissible interval is [min, shape[i]).
3225 // Otherwise, the admissible interval is [min, shape[i]].
3226 template <typename OpType>
3227 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3228  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3229  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3230  bool halfOpen = true, int64_t min = 1) {
3231  assert(arrayAttr1.size() <= shape.size());
3232  assert(arrayAttr2.size() <= shape.size());
3233  for (auto [index, it] :
3234  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3235  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3236  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3237  int64_t max = std::get<2>(it);
3238  if (!halfOpen)
3239  max += 1;
3240  if (val1 + val2 < 0 || val1 + val2 >= max)
3241  return op.emitOpError("expected sum(")
3242  << attrName1 << ", " << attrName2 << ") dimension " << index
3243  << " to be confined to [" << min << ", " << max << ")";
3244  }
3245  return success();
3246 }
3247 
3248 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3249  MLIRContext *context) {
3250  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3251  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3252  });
3253  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3254 }
3255 
3256 LogicalResult InsertStridedSliceOp::verify() {
3257  auto sourceVectorType = getSourceVectorType();
3258  auto destVectorType = getDestVectorType();
3259  auto offsets = getOffsetsAttr();
3260  auto strides = getStridesAttr();
3261  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3262  return emitOpError(
3263  "expected offsets of same size as destination vector rank");
3264  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3265  return emitOpError("expected strides of same size as source vector rank");
3266  if (sourceVectorType.getRank() > destVectorType.getRank())
3267  return emitOpError(
3268  "expected source rank to be no greater than destination rank");
3269 
3270  auto sourceShape = sourceVectorType.getShape();
3271  auto destShape = destVectorType.getShape();
3272  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3273  destShape.size() - sourceShape.size(), 0);
3274  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3275  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3276  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3277  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3278  offName)) ||
3279  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3280  /*max=*/1, stridesName,
3281  /*halfOpen=*/false)) ||
3282  failed(isSumOfIntegerArrayAttrConfinedToShape(
3283  *this, offsets,
3284  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3285  offName, "source vector shape",
3286  /*halfOpen=*/false, /*min=*/1)))
3287  return failure();
3288 
3289  unsigned rankDiff = destShape.size() - sourceShape.size();
3290  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3291  if (sourceVectorType.getScalableDims()[idx] !=
3292  destVectorType.getScalableDims()[idx + rankDiff]) {
3293  return emitOpError("mismatching scalable flags (at source vector idx=")
3294  << idx << ")";
3295  }
3296  if (sourceVectorType.getScalableDims()[idx]) {
3297  auto sourceSize = sourceShape[idx];
3298  auto destSize = destShape[idx + rankDiff];
3299  if (sourceSize != destSize) {
3300  return emitOpError("expected size at idx=")
3301  << idx
3302  << (" to match the corresponding base size from the input "
3303  "vector (")
3304  << sourceSize << (" vs ") << destSize << (")");
3305  }
3306  }
3307  }
3308 
3309  return success();
3310 }
3311 
3312 namespace {
3313 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3314 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
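///
/// Illustrative sketch (added for exposition; %x is a placeholder scalar):
///   %0 = vector.splat %x : vector<2xf32>
///   %1 = vector.splat %x : vector<4xf32>
///   %r = vector.insert_strided_slice %0, %1 {offsets = [1], strides = [1]}
///        : vector<2xf32> into vector<4xf32>
/// is replaced by %1, the destination splat.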
3315 class FoldInsertStridedSliceSplat final
3316  : public OpRewritePattern<InsertStridedSliceOp> {
3317 public:
3319 
3320  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3321  PatternRewriter &rewriter) const override {
3322  auto srcSplatOp =
3323  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3324  auto destSplatOp =
3325  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3326 
3327  if (!srcSplatOp || !destSplatOp)
3328  return failure();
3329 
3330  if (srcSplatOp.getInput() != destSplatOp.getInput())
3331  return failure();
3332 
3333  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3334  return success();
3335  }
3336 };
3337 
3338 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3339 /// to dst.
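///
/// Illustrative sketch (added for exposition; placeholder names):
///   %e = vector.extract_strided_slice %dst
///        {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %e, %dst {offsets = [2], strides = [1]}
///        : vector<4xf32> into vector<8xf32>
/// is replaced by %dst, since the slice is written back where it was read.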
3340 class FoldInsertStridedSliceOfExtract final
3341  : public OpRewritePattern<InsertStridedSliceOp> {
3342 public:
3344 
3345  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3346  PatternRewriter &rewriter) const override {
3347  auto extractStridedSliceOp =
3348  insertStridedSliceOp.getSource()
3349  .getDefiningOp<vector::ExtractStridedSliceOp>();
3350 
3351  if (!extractStridedSliceOp)
3352  return failure();
3353 
3354  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3355  return failure();
3356 
3357  // Check if they have the same strides and offsets.
3358  if (extractStridedSliceOp.getStrides() !=
3359  insertStridedSliceOp.getStrides() ||
3360  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3361  return failure();
3362 
3363  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3364  return success();
3365  }
3366 };
3367 
3368 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3369 // ConstantOp.
3370 class InsertStridedSliceConstantFolder final
3371  : public OpRewritePattern<InsertStridedSliceOp> {
3372 public:
3374 
3375  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3376  // unless the destination vector constant has a single use.
3377  static constexpr int64_t vectorSizeFoldThreshold = 256;
3378 
3379  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3380  PatternRewriter &rewriter) const override {
3381  // Return if 'InsertStridedSliceOp' operand is not defined by a compatible vector
3382  // ConstantOp.
3383  TypedValue<VectorType> destVector = op.getDest();
3384  Attribute vectorDestCst;
3385  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3386  return failure();
3387 
3388  VectorType destTy = destVector.getType();
3389  if (destTy.isScalable())
3390  return failure();
3391 
3392  // Make sure we do not create too many large constants.
3393  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3394  !destVector.hasOneUse())
3395  return failure();
3396 
3397  TypedValue<VectorType> sourceValue = op.getSource();
3398  Attribute sourceCst;
3399  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3400  return failure();
3401 
3402  // TODO: Support poison.
3403  if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3404  return failure();
3405 
3406  // TODO: Handle non-unit strides when they become available.
3407  if (op.hasNonUnitStrides())
3408  return failure();
3409 
3410  VectorType sliceVecTy = sourceValue.getType();
3411  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3412  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3413  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3414  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3415 
3416  // Calculate the destination element indices by enumerating all slice
3417  // positions within the destination and linearizing them. The enumeration
3418  // order is lexicographic which yields a sequence of monotonically
3419  // increasing linearized position indices.
3420  // Because the destination may have higher dimensionality than the slice,
3421  // we keep track of two overlapping sets of positions and offsets.
3422  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3423  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3424  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3425  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3426  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3427  MutableArrayRef<int64_t> currSlicePosition(
3428  currDestPosition.begin() + rankDifference, currDestPosition.end());
3429  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3430  offsets.end());
3431  do {
3432  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3433  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3434  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3435  "Invalid slice element");
3436  newValues[linearizedPosition] = *sliceValuesIt;
3437  ++sliceValuesIt;
3438  } while (succeeded(
3439  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3440 
3441  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3442  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3443  return success();
3444  }
3445 };
3446 
3447 } // namespace
3448 
3449 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3450  RewritePatternSet &results, MLIRContext *context) {
3451  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3452  InsertStridedSliceConstantFolder>(context);
3453 }
3454 
3455 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3456  if (getSourceVectorType() == getDestVectorType())
3457  return getSource();
3458  return {};
3459 }
3460 
3461 //===----------------------------------------------------------------------===//
3462 // OuterProductOp
3463 //===----------------------------------------------------------------------===//
3464 
3465 /// Build an op without mask, use the type of `acc` as the return type.
3466 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3467  Value lhs, Value rhs, Value acc) {
3468  result.addOperands({lhs, rhs, acc});
3469  result.addTypes(acc.getType());
3470 }
3471 
3472 void OuterProductOp::print(OpAsmPrinter &p) {
3473  p << " " << getLhs() << ", " << getRhs();
3474  if (getAcc()) {
3475  p << ", " << getAcc();
3476  p.printOptionalAttrDict((*this)->getAttrs());
3477  }
3478  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3479 }
3480 
3481 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3482  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3483  Type tLHS, tRHS;
3484  if (parser.parseOperandList(operandsInfo) ||
3485  parser.parseOptionalAttrDict(result.attributes) ||
3486  parser.parseColonType(tLHS) || parser.parseComma() ||
3487  parser.parseType(tRHS))
3488  return failure();
3489  if (operandsInfo.size() < 2)
3490  return parser.emitError(parser.getNameLoc(),
3491  "expected at least 2 operands");
3492  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3493  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3494  if (!vLHS)
3495  return parser.emitError(parser.getNameLoc(),
3496  "expected vector type for operand #1");
3497 
3498  VectorType resType;
3499  if (vRHS) {
3500  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3501  vRHS.getScalableDims()[0]};
3502  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3503  vLHS.getElementType(), scalableDimsRes);
3504  } else {
3505  // Scalar RHS operand
3506  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3507  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3508  scalableDimsRes);
3509  }
3510 
3511  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3512  result.attributes.append(
3513  OuterProductOp::getKindAttrName(result.name),
3514  CombiningKindAttr::get(result.getContext(),
3515  OuterProductOp::getDefaultKind()));
3516  }
3517 
3518  return failure(
3519  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3520  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3521  (operandsInfo.size() > 2 &&
3522  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3523  parser.addTypeToList(resType, result.types));
3524 }
3525 
3526 LogicalResult OuterProductOp::verify() {
3527  Type tRHS = getOperandTypeRHS();
3528  VectorType vLHS = getOperandVectorTypeLHS(),
3529  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3530  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3531 
3532  if (vLHS.getRank() != 1)
3533  return emitOpError("expected 1-d vector for operand #1");
3534 
3535  if (vRHS) {
3536  // Proper OUTER operation.
3537  if (vRHS.getRank() != 1)
3538  return emitOpError("expected 1-d vector for operand #2");
3539  if (vRES.getRank() != 2)
3540  return emitOpError("expected 2-d vector result");
3541  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3542  return emitOpError("expected #1 operand dim to match result dim #1");
3543  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3544  return emitOpError("expected #2 operand dim to match result dim #2");
3545  if (vLHS.isScalable() && !vRHS.isScalable()) {
3546  // This restriction reflects what's currently supported in terms of
3547  // scalable vectors. However, we could relax this if there's a use case.
3548  return emitOpError(
3549  "expected either both or only #2 operand dim to be scalable");
3550  }
3551  } else {
3552  // An AXPY operation.
3553  if (vRES.getRank() != 1)
3554  return emitOpError("expected 1-d vector result");
3555  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3556  return emitOpError("expected #1 operand dim to match result dim #1");
3557  }
3558 
3559  if (vACC && vACC != vRES)
3560  return emitOpError("expected operand #3 of same type as result type");
3561 
3562  // Verify supported combining kind.
3563  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3564  return emitOpError("unsupported outerproduct type");
3565 
3566  return success();
3567 }
3568 
3569 // MaskableOpInterface methods.
3570 
3571 /// Returns the mask type expected by this operation. Mostly used for
3572 /// verification purposes. It requires the operation to be vectorized.
3573 Type OuterProductOp::getExpectedMaskType() {
3574  auto vecType = this->getResultVectorType();
3575  return VectorType::get(vecType.getShape(),
3576  IntegerType::get(vecType.getContext(), /*width=*/1),
3577  vecType.getScalableDims());
3578 }
3579 
3580 //===----------------------------------------------------------------------===//
3581 // ExtractStridedSliceOp
3582 //===----------------------------------------------------------------------===//
3583 
3584 // Inference works as follows:
3585 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3586 // 2. Add sizes from 'vectorType' for remaining dims.
3587 // Scalable flags are inherited from 'vectorType'.
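//
// Worked example (added for exposition): with offsets/sizes/strides of length
// 2, sizes = [2, 4], and a source of type vector<4x8x16xf32>, the inferred
// result type is vector<2x4x16xf32>.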
3588 static Type inferStridedSliceOpResultType(VectorType vectorType,
3589  ArrayAttr offsets, ArrayAttr sizes,
3590  ArrayAttr strides) {
3591  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3592  SmallVector<int64_t, 4> shape;
3593  shape.reserve(vectorType.getRank());
3594  unsigned idx = 0;
3595  for (unsigned e = offsets.size(); idx < e; ++idx)
3596  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3597  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3598  shape.push_back(vectorType.getShape()[idx]);
3599 
3600  return VectorType::get(shape, vectorType.getElementType(),
3601  vectorType.getScalableDims());
3602 }
3603 
3604 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3605  Value source, ArrayRef<int64_t> offsets,
3606  ArrayRef<int64_t> sizes,
3607  ArrayRef<int64_t> strides) {
3608  result.addOperands(source);
3609  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3610  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3611  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3612  result.addTypes(
3613  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3614  offsetsAttr, sizesAttr, stridesAttr));
3615  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3616  offsetsAttr);
3617  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3618  sizesAttr);
3619  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3620  stridesAttr);
3621 }
3622 
3623 LogicalResult ExtractStridedSliceOp::verify() {
3624  auto type = getSourceVectorType();
3625  auto offsets = getOffsetsAttr();
3626  auto sizes = getSizesAttr();
3627  auto strides = getStridesAttr();
3628  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3629  return emitOpError(
3630  "expected offsets, sizes and strides attributes of same size");
3631 
3632  auto shape = type.getShape();
3633  auto offName = getOffsetsAttrName();
3634  auto sizesName = getSizesAttrName();
3635  auto stridesName = getStridesAttrName();
3636  if (failed(
3637  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3638  failed(
3639  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3640  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3641  stridesName)) ||
3642  failed(
3643  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3644  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3645  /*halfOpen=*/false,
3646  /*min=*/1)) ||
3647  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3648  /*max=*/1, stridesName,
3649  /*halfOpen=*/false)) ||
3650  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3651  shape, offName, sizesName,
3652  /*halfOpen=*/false)))
3653  return failure();
3654 
3655  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3656  offsets, sizes, strides);
3657  if (getResult().getType() != resultType)
3658  return emitOpError("expected result type to be ") << resultType;
3659 
3660  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3661  if (type.getScalableDims()[idx]) {
3662  auto inputDim = type.getShape()[idx];
3663  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3664  if (inputDim != inputSize)
3665  return emitOpError("expected size at idx=")
3666  << idx
3667  << (" to match the corresponding base size from the input "
3668  "vector (")
3669  << inputSize << (" vs ") << inputDim << (")");
3670  }
3671  }
3672 
3673  return success();
3674 }
3675 
3676 // When the source of an ExtractStridedSliceOp comes from a chain of
3677 // InsertStridedSliceOps, try to use the source of one of the inserts if we can
3678 // detect that the extracted vector is a subset of one of the inserted vectors.
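//
// Illustrative sketch (added for exposition; placeholder names):
//   %i = vector.insert_strided_slice %v, %dst {offsets = [2], strides = [1]}
//        : vector<4xf32> into vector<16xf32>
//   %e = vector.extract_strided_slice %i
//        {offsets = [3], sizes = [2], strides = [1]}
//        : vector<16xf32> to vector<2xf32>
// Here [3, 5) lies inside the inserted interval [2, 6), so the fold rewrites
// %e to extract directly from %v with offsets = [1].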
3679 static LogicalResult
3680 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3681  // Helper to extract integer out of ArrayAttr.
3682  auto getElement = [](ArrayAttr array, int idx) {
3683  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3684  };
3685  ArrayAttr extractOffsets = op.getOffsets();
3686  ArrayAttr extractStrides = op.getStrides();
3687  ArrayAttr extractSizes = op.getSizes();
3688  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3689  while (insertOp) {
3690  if (op.getSourceVectorType().getRank() !=
3691  insertOp.getSourceVectorType().getRank())
3692  return failure();
3693  ArrayAttr insertOffsets = insertOp.getOffsets();
3694  ArrayAttr insertStrides = insertOp.getStrides();
3695  // If the rank of extract is greater than the rank of insert, we are likely
3696  // extracting a partial chunk of the vector inserted.
3697  if (extractOffsets.size() > insertOffsets.size())
3698  return failure();
3699  bool partialOverlap = false;
3700  bool disjoint = false;
3701  SmallVector<int64_t, 4> offsetDiffs;
3702  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3703  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3704  return failure();
3705  int64_t start = getElement(insertOffsets, dim);
3706  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3707  int64_t offset = getElement(extractOffsets, dim);
3708  int64_t size = getElement(extractSizes, dim);
3709  // Check if the start of the extract offset is in the interval inserted.
3710  if (start <= offset && offset < end) {
3711  // If the extract interval overlaps but is not fully included we may
3712  // have a partial overlap that will prevent any folding.
3713  if (offset + size > end)
3714  partialOverlap = true;
3715  offsetDiffs.push_back(offset - start);
3716  continue;
3717  }
3718  disjoint = true;
3719  break;
3720  }
3721  // The extract element chunk is a subset of the insert element.
3722  if (!disjoint && !partialOverlap) {
3723  op.setOperand(insertOp.getSource());
3724  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3725  OpBuilder b(op.getContext());
3726  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3727  return success();
3728  }
3729  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3730  // in the insert chain.
3731  if (disjoint)
3732  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3733  else {
3734  // The extracted vector partially overlaps the inserted vector; we cannot
3735  // fold.
3736  return failure();
3737  }
3738  }
3739  return failure();
3740 }
3741 
3742 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3743  if (getSourceVectorType() == getResult().getType())
3744  return getVector();
3745  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3746  return getResult();
3747  return {};
3748 }
3749 
3750 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3751  populateFromInt64AttrArray(getOffsets(), results);
3752 }
3753 
3754 namespace {
3755 
3756 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3757 // ConstantMaskOp.
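//
// Illustrative sketch (added for exposition):
//   %m = vector.constant_mask [4] : vector<8xi1>
//   %r = vector.extract_strided_slice %m
//        {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8xi1> to vector<4xi1>
// becomes:
//   %r = vector.constant_mask [2] : vector<4xi1>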
3758 class StridedSliceConstantMaskFolder final
3759  : public OpRewritePattern<ExtractStridedSliceOp> {
3760 public:
3762 
3763  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3764  PatternRewriter &rewriter) const override {
3765  // Return if 'extractStridedSliceOp' operand is not defined by a
3766  // ConstantMaskOp.
3767  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3768  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3769  if (!constantMaskOp)
3770  return failure();
3771  // Return if 'extractStridedSliceOp' has non-unit strides.
3772  if (extractStridedSliceOp.hasNonUnitStrides())
3773  return failure();
3774  // Gather constant mask dimension sizes.
3775  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3776  // Gather strided slice offsets and sizes.
3777  SmallVector<int64_t, 4> sliceOffsets;
3778  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3779  sliceOffsets);
3780  SmallVector<int64_t, 4> sliceSizes;
3781  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3782 
3783  // Compute slice of vector mask region.
3784  SmallVector<int64_t, 4> sliceMaskDimSizes;
3785  sliceMaskDimSizes.reserve(maskDimSizes.size());
3786  for (auto [maskDimSize, sliceOffset, sliceSize] :
3787  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3788  int64_t sliceMaskDimSize = std::max(
3789  static_cast<int64_t>(0),
3790  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3791  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3792  }
3793  // Add unchanged dimensions.
3794  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3795  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3796  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3797  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3798  // region is a conjunction of mask dim intervals).
3799  if (llvm::is_contained(sliceMaskDimSizes, 0))
3800  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3801 
3802  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3803  // region.
3804  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3805  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3806  sliceMaskDimSizes);
3807  return success();
3808  }
3809 };
3810 
3811 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3812 class StridedSliceSplatConstantFolder final
3813  : public OpRewritePattern<ExtractStridedSliceOp> {
3814 public:
3816 
3817  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3818  PatternRewriter &rewriter) const override {
3819  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3820  // ConstantOp.
3821  Value sourceVector = extractStridedSliceOp.getVector();
3822  Attribute vectorCst;
3823  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3824  return failure();
3825 
3826  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3827  if (!splat)
3828  return failure();
3829 
3830  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3831  splat.getSplatValue<Attribute>());
3832  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3833  newAttr);
3834  return success();
3835  }
3836 };
3837 
3838 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3839 // ConstantOp.
3840 class StridedSliceNonSplatConstantFolder final
3841  : public OpRewritePattern<ExtractStridedSliceOp> {
3842 public:
3844 
3845  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3846  PatternRewriter &rewriter) const override {
3847  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3848  // ConstantOp.
3849  Value sourceVector = extractStridedSliceOp.getVector();
3850  Attribute vectorCst;
3851  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3852  return failure();
3853 
3854  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3855  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3856  if (!dense || dense.isSplat())
3857  return failure();
3858 
3859  // TODO: Handle non-unit strides when they become available.
3860  if (extractStridedSliceOp.hasNonUnitStrides())
3861  return failure();
3862 
3863  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3864  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3865  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3866 
3867  VectorType sliceVecTy = extractStridedSliceOp.getType();
3868  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3869  int64_t sliceRank = sliceVecTy.getRank();
3870 
3871  // Expand offsets and sizes to match the vector rank.
3872  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3873  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3874 
3875  SmallVector<int64_t, 4> sizes(sourceShape);
3876  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3877 
3878  // Calculate the slice elements by enumerating all slice positions and
3879  // linearizing them. The enumeration order is lexicographic which yields a
3880  // sequence of monotonically increasing linearized position indices.
3881  auto denseValuesBegin = dense.value_begin<Attribute>();
3882  SmallVector<Attribute> sliceValues;
3883  sliceValues.reserve(sliceVecTy.getNumElements());
3884  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3885  do {
3886  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3887  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3888  "Invalid index");
3889  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3890  } while (
3891  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3892 
3893  assert(static_cast<int64_t>(sliceValues.size()) ==
3894  sliceVecTy.getNumElements() &&
3895  "Invalid number of slice elements");
3896  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3897  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3898  newAttr);
3899  return success();
3900  }
3901 };
3902 
3903 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3904 // BroadcastOp(ExtractStridedSliceOp).
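//
// Illustrative sketch (added for exposition): when the inner dimensions match,
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %r = vector.extract_strided_slice %b
//        {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]}
//        : vector<4x8xf32> to vector<2x8xf32>
// becomes:
//   %r = vector.broadcast %src : vector<8xf32> to vector<2x8xf32>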
3905 class StridedSliceBroadcast final
3906  : public OpRewritePattern<ExtractStridedSliceOp> {
3907 public:
3909 
3910  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3911  PatternRewriter &rewriter) const override {
3912  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3913  if (!broadcast)
3914  return failure();
3915  auto srcVecType =
3916  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3917  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3918  auto dstVecType = llvm::cast<VectorType>(op.getType());
3919  unsigned dstRank = dstVecType.getRank();
3920  unsigned rankDiff = dstRank - srcRank;
3921  // Check if the innermost dimensions of the source of the broadcast are the
3922  // same as those of the destination of the extract. If so, we can simply use
3923  // a broadcast, as the original dimensions are untouched.
3924  bool lowerDimMatch = true;
3925  for (unsigned i = 0; i < srcRank; i++) {
3926  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3927  lowerDimMatch = false;
3928  break;
3929  }
3930  }
3931  Value source = broadcast.getSource();
3932  // If the inner dimensions don't match, it means we need to extract from the
3933  // source of the original broadcast and then broadcast the extracted value.
3934  // We also need to handle degenerate cases where the source is effectively
3935  // just a single scalar.
3936  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3937  if (!lowerDimMatch && !isScalarSrc) {
3938  source = rewriter.create<ExtractStridedSliceOp>(
3939  op->getLoc(), source,
3940  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3941  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3942  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3943  }
3944  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3945  return success();
3946  }
3947 };
3948 
3949 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3950 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3951 public:
3953 
3954  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3955  PatternRewriter &rewriter) const override {
3956  auto splat = op.getVector().getDefiningOp<SplatOp>();
3957  if (!splat)
3958  return failure();
3959  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3960  return success();
3961  }
3962 };
3963 
3964 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3965 /// slice is contiguous, into extract and shape_cast.
3966 ///
3967 /// Example:
3968 /// Before:
3969 /// %1 = vector.extract_strided_slice %arg0 {
3970 /// offsets = [0, 0, 0, 0, 0],
3971 /// sizes = [1, 1, 1, 1, 8],
3972 /// strides = [1, 1, 1, 1, 1]
3973 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3974 /// After:
3975 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3976 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3977 /// %1 = vector.shape_cast %0
3978 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3979 ///
3980 class ContiguousExtractStridedSliceToExtract final
3981  : public OpRewritePattern<ExtractStridedSliceOp> {
3982 public:
3984 
3985  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3986  PatternRewriter &rewriter) const override {
3987  if (op.hasNonUnitStrides())
3988  return failure();
3989  Value source = op.getOperand();
3990  auto sourceType = cast<VectorType>(source.getType());
3991  if (sourceType.isScalable() || sourceType.getRank() == 0)
3992  return failure();
3993 
3994  // Compute the number of offsets to pass to ExtractOp::build. That is the
3995  // difference between the source rank and the desired slice rank. We walk
3996  // the dimensions from innermost out, and stop when the next slice dimension
3997  // is not full-size.
3998  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3999  int numOffsets;
4000  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4001  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4002  break;
4003  }
4004 
4005  // If the created extract op would have no offsets, then this whole
4006  // extract_strided_slice is the identity and should have been handled by
4007  // other canonicalizations.
4008  if (numOffsets == 0)
4009  return failure();
4010 
4011  // If not even the inner-most dimension is full-size, this op can't be
4012  // rewritten as an ExtractOp.
4013  if (numOffsets == sourceType.getRank() &&
4014  static_cast<int>(sizes.size()) == sourceType.getRank())
4015  return failure();
4016 
4017  // The outer dimensions must have unit size.
4018  for (int i = 0; i < numOffsets; ++i) {
4019  if (sizes[i] != 1)
4020  return failure();
4021  }
4022 
4023  // Avoid generating slices that have leading unit dimensions. The shape_cast
4024  // op that we create below would otherwise be handled by the poor generic
4025  // fallback pattern (ShapeCastOpRewritePattern).
4026  while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4027  sizes[numOffsets] == 1) {
4028  ++numOffsets;
4029  }
4030 
4031  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4032  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4033  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
4034  extractOffsets);
4035  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4036  return success();
4037  }
4038 };
4039 
4040 } // namespace
4041 
4042 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4043  RewritePatternSet &results, MLIRContext *context) {
4044  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
4045  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4046  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4047  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
4048  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4049  context);
4050 }
4051 
4052 //===----------------------------------------------------------------------===//
4053 // TransferReadOp
4054 //===----------------------------------------------------------------------===//
4055 
4056 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4057 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4058  VectorType vectorType, Value source,
4059  ValueRange indices, AffineMapAttr permutationMapAttr,
4060  /*optional*/ ArrayAttr inBoundsAttr) {
4061  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4062  Value padding = builder.create<arith::ConstantOp>(
4063  result.location, elemType, builder.getZeroAttr(elemType));
4064  build(builder, result, vectorType, source, indices, permutationMapAttr,
4065  padding, /*mask=*/Value(), inBoundsAttr);
4066 }
4067 
4068 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4069 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4070  VectorType vectorType, Value source,
4071  ValueRange indices, AffineMap permutationMap,
4072  std::optional<ArrayRef<bool>> inBounds) {
4073  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4074  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4075  ? builder.getBoolArrayAttr(inBounds.value())
4076  : builder.getBoolArrayAttr(
4077  SmallVector<bool>(vectorType.getRank(), false));
4078  build(builder, result, vectorType, source, indices, permutationMapAttr,
4079  inBoundsAttr);
4080 }
4081 
4082 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4083 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4084  VectorType vectorType, Value source,
4085  ValueRange indices, Value padding,
4086  std::optional<ArrayRef<bool>> inBounds) {
4087  AffineMap permutationMap = getTransferMinorIdentityMap(
4088  llvm::cast<ShapedType>(source.getType()), vectorType);
4089  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4090  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4091  ? builder.getBoolArrayAttr(inBounds.value())
4092  : builder.getBoolArrayAttr(
4093  SmallVector<bool>(vectorType.getRank(), false));
4094  build(builder, result, vectorType, source, indices, permutationMapAttr,
4095  padding,
4096  /*mask=*/Value(), inBoundsAttr);
4097 }
4098 
4099 /// 4. Builder that sets padding to zero and permutation map to
4100 /// 'getMinorIdentityMap'.
4101 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4102  VectorType vectorType, Value source,
4103  ValueRange indices,
4104  std::optional<ArrayRef<bool>> inBounds) {
4105  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4106  Value padding = builder.create<arith::ConstantOp>(
4107  result.location, elemType, builder.getZeroAttr(elemType));
4108  build(builder, result, vectorType, source, indices, padding, inBounds);
4109 }
4110 
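// Illustrative sketch (added for exposition): a valid projected permutation
// map for a transfer op is, e.g.,
//   affine_map<(d0, d1, d2) -> (d2, 0, d0)>
// where each result is either a distinct dim or the zero constant, whereas
//   affine_map<(d0, d1) -> (d0, d0)>
// reuses d0 and is rejected by the check below.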
4111 template <typename EmitFun>
4112 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4113  EmitFun emitOpError) {
4114  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4115  for (auto expr : permutationMap.getResults()) {
4116  auto dim = dyn_cast<AffineDimExpr>(expr);
4117  auto zero = dyn_cast<AffineConstantExpr>(expr);
4118  if (zero) {
4119  if (zero.getValue() != 0) {
4120  return emitOpError(
4121  "requires a projected permutation_map (at most one dim or the zero "
4122  "constant can appear in each result)");
4123  }
4124  continue;
4125  }
4126  if (!dim) {
4127  return emitOpError("requires a projected permutation_map (at most one "
4128  "dim or the zero constant can appear in each result)");
4129  }
4130  if (seen[dim.getPosition()]) {
4131  return emitOpError(
4132  "requires a permutation_map that is a permutation (found one dim "
4133  "used more than once)");
4134  }
4135  seen[dim.getPosition()] = true;
4136  }
4137  return success();
4138 }
4139 
4140 static LogicalResult
4141 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4142  VectorType vectorType, VectorType maskType,
4143  VectorType inferredMaskType, AffineMap permutationMap,
4144  ArrayAttr inBounds) {
4145  if (op->hasAttr("masked")) {
4146  return op->emitOpError("masked attribute has been removed. "
4147  "Use in_bounds instead.");
4148  }
4149 
4150  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4151  return op->emitOpError(
4152  "requires source to be a memref or ranked tensor type");
4153 
4154  auto elementType = shapedType.getElementType();
4155  DataLayout dataLayout = DataLayout::closest(op);
4156  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4157  // Memref or tensor has vector element type.
4158  unsigned sourceVecSize =
4159  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4160  vectorElementType.getShape().back();
4161  unsigned resultVecSize =
4162  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4163  vectorType.getShape().back();
4164  if (resultVecSize % sourceVecSize != 0)
4165  return op->emitOpError(
4166  "requires the bitwidth of the minor 1-D vector to be an integral "
4167  "multiple of the bitwidth of the minor 1-D vector of the source");
4168 
4169  unsigned sourceVecEltRank = vectorElementType.getRank();
4170  unsigned resultVecRank = vectorType.getRank();
4171  if (sourceVecEltRank > resultVecRank)
4172  return op->emitOpError(
4173  "requires source vector element and vector result ranks to match.");
4174  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4175  // Check that permutation map results match 'rankOffset' of vector type.
4176  if (permutationMap.getNumResults() != rankOffset)
4177  return op->emitOpError("requires a permutation_map with result dims of "
4178  "the same rank as the vector type");
4179 
4180  if (maskType)
4181  return op->emitOpError("does not support masks with vector element type");
4182  } else {
4183  // Memref or tensor has scalar element type.
4184  unsigned minorSize =
4185  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4186  unsigned resultVecSize =
4187  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4188  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4189  return op->emitOpError(
4190  "requires the bitwidth of the minor 1-D vector to be an integral "
4191  "multiple of the bitwidth of the source element type");
4192 
4193  // Check that permutation map results match rank of vector type.
4194  if (permutationMap.getNumResults() != vectorType.getRank())
4195  return op->emitOpError("requires a permutation_map with result dims of "
4196  "the same rank as the vector type");
4197  }
4198 
4199  if (permutationMap.getNumSymbols() != 0)
4200  return op->emitOpError("requires permutation_map without symbols");
4201 
4202  if (permutationMap.getNumInputs() != shapedType.getRank())
4203  return op->emitOpError("requires a permutation_map with input dims of the "
4204  "same rank as the source type");
4205 
4206  if (maskType && maskType != inferredMaskType)
4207  return op->emitOpError("inferred mask type (")
4208  << inferredMaskType << ") and mask operand type (" << maskType
4209  << ") don't match";
4210 
4211  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4212  return op->emitOpError("expects the in_bounds attr of same rank "
4213  "as permutation_map results: ")
4214  << AffineMapAttr::get(permutationMap)
4215  << " vs inBounds of size: " << inBounds.size();
4216 
4217  return success();
4218 }
4219 
4220 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4221  SmallVector<StringRef, 3> elidedAttrs;
4222  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4223  if (op.getPermutationMap().isMinorIdentity())
4224  elidedAttrs.push_back(op.getPermutationMapAttrName());
4225  // Elide in_bounds attribute if all dims are out-of-bounds.
4226  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4227  elidedAttrs.push_back(op.getInBoundsAttrName());
4228  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4229 }
4230 
4232  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4233  if (getMask())
4234  p << ", " << getMask();
4235  printTransferAttrs(p, *this);
4236  p << " : " << getShapedType() << ", " << getVectorType();
4237 }
4238 
4239 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4240  AffineMap permMap) {
4241  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4242  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4243  assert(invPermMap && "Inverted permutation map couldn't be computed");
4244  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4245 
4246  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4247  // 0-D mask into a single-element 1-D mask.
4248  if (maskShape.empty())
4249  maskShape.push_back(1);
4250 
4251  SmallVector<bool> scalableDims =
4252  applyPermutationMap(invPermMap, vecType.getScalableDims());
4253 
4254  return VectorType::get(maskShape, i1Type, scalableDims);
4255 }
4256 
4257 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4258  auto &builder = parser.getBuilder();
4259  SMLoc typesLoc;
4260  OpAsmParser::UnresolvedOperand sourceInfo;
4261  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4262  OpAsmParser::UnresolvedOperand paddingInfo;
4263  SmallVector<Type, 2> types;
4264  OpAsmParser::UnresolvedOperand maskInfo;
4265  // Parsing with support for paddingValue.
4266  if (parser.parseOperand(sourceInfo) ||
4267  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4268  parser.parseComma() || parser.parseOperand(paddingInfo))
4269  return failure();
4270  ParseResult hasMask = parser.parseOptionalComma();
4271  if (hasMask.succeeded()) {
4272  if (parser.parseOperand(maskInfo))
4273  return failure();
4274  }
4275  if (parser.parseOptionalAttrDict(result.attributes) ||
4276  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4277  return failure();
4278  if (types.size() != 2)
4279  return parser.emitError(typesLoc, "requires two types");
4280  auto indexType = builder.getIndexType();
4281  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4282  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4283  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4284  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4285  if (!vectorType)
4286  return parser.emitError(typesLoc, "requires vector type");
4287  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4288  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4289  AffineMap permMap;
4290  if (!permMapAttr) {
4291  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4292  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4293  } else {
4294  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4295  }
4296  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4297  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4298  if (!inBoundsAttr) {
4299  result.addAttribute(inBoundsAttrName,
4300  builder.getBoolArrayAttr(
4301  SmallVector<bool>(permMap.getNumResults(), false)));
4302  }
4303  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4304  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4305  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4306  result.operands))
4307  return failure();
4308  if (hasMask.succeeded()) {
4309  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4310  return parser.emitError(
4311  maskInfo.location, "does not support masks with vector element type");
4312  if (vectorType.getRank() != permMap.getNumResults()) {
4313  return parser.emitError(typesLoc,
4314  "expected the same rank for the vector and the "
4315  "results of the permutation map");
4316  }
4317  // Instead of adding the mask type as an op type, compute it based on the
4318  // vector type and the permutation map (to keep the type signature small).
4319  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4320  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4321  return failure();
4322  }
4323  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4324  builder.getDenseI32ArrayAttr(
4325  {1, static_cast<int32_t>(indexInfo.size()), 1,
4326  static_cast<int32_t>(hasMask.succeeded())}));
4327  return parser.addTypeToList(vectorType, result.types);
4328 }
4329 
4330 LogicalResult TransferReadOp::verify() {
4331  // Consistency of elemental types in source and vector.
4332  ShapedType shapedType = getShapedType();
4333  VectorType vectorType = getVectorType();
4334  VectorType maskType = getMaskType();
4335  auto paddingType = getPadding().getType();
4336  auto permutationMap = getPermutationMap();
4337  VectorType inferredMaskType =
4338  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4339  : VectorType();
4340  auto sourceElementType = shapedType.getElementType();
4341 
4342  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4343  return emitOpError("requires ") << shapedType.getRank() << " indices";
4344 
4345  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4346  shapedType, vectorType, maskType,
4347  inferredMaskType, permutationMap, getInBounds())))
4348  return failure();
4349 
4350  if (auto sourceVectorElementType =
4351  llvm::dyn_cast<VectorType>(sourceElementType)) {
4352  // Source has vector element type.
4353  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4354  if (sourceVectorElementType != paddingType)
4355  return emitOpError(
4356  "requires source element type and padding type to match.");
4357 
4358  } else {
4359  // Check that 'paddingType' is valid to store in a vector type.
4360  if (!VectorType::isValidElementType(paddingType))
4361  return emitOpError("requires valid padding vector elemental type");
4362 
4363  // Check that padding type and vector element types match.
4364  if (paddingType != sourceElementType)
4365  return emitOpError(
4366  "requires formal padding and source of the same elemental type");
4367  }
4368 
4369  return verifyPermutationMap(permutationMap,
4370  [&](Twine t) { return emitOpError(t); });
4371 }
4372 
4373 // MaskableOpInterface methods.
4374 
4375 /// Returns the mask type expected by this operation. Mostly used for
4376 /// verification purposes. It requires the operation to be vectorized.
4377 Type TransferReadOp::getExpectedMaskType() {
4378  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4379 }
4380 
4381 template <typename TransferOp>
4382 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4383  // TODO: support more aggressive createOrFold on:
4384  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4385  if (op.getShapedType().isDynamicDim(indicesIdx))
4386  return false;
4387  Value index = op.getIndices()[indicesIdx];
4388  std::optional<int64_t> cstOp = getConstantIntValue(index);
4389  if (!cstOp.has_value())
4390  return false;
4391 
4392  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4393  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4394 
4395  return cstOp.value() + vectorSize <= sourceSize;
4396 }
4397 
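/// Folds the in_bounds attribute of a transfer op to `true` for every
/// dimension that can statically be proven to stay within the source bounds.
/// A minimal illustrative example (the IR below only assumes standard vector
/// dialect syntax):
/// ```
/// %r = vector.transfer_read %A[%c3, %c0], %f0 {in_bounds = [false, true]}
///     : memref<10x8xf32>, vector<4x8xf32>
/// ```
/// Since 3 + 4 <= 10, the first dimension is provably in bounds and the
/// attribute can be tightened to [true, true].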
4398 template <typename TransferOp>
4399 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4400  // TODO: support 0-d corner case.
4401  // TODO: Be less conservative.
4402  if (op.getTransferRank() == 0)
4403  return failure();
4404  AffineMap permutationMap = op.getPermutationMap();
4405  bool changed = false;
4406  SmallVector<bool, 4> newInBounds;
4407  newInBounds.reserve(op.getTransferRank());
4408  // Indices of non-broadcast dims - used when analyzing broadcast dims.
4409  SmallVector<unsigned> nonBcastDims;
4410 
4411  // 1. Process non-broadcast dims
4412  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4413  // 1.1. Already marked as in-bounds, nothing to see here.
4414  if (op.isDimInBounds(i)) {
4415  newInBounds.push_back(true);
4416  continue;
4417  }
4418  // 1.2. Currently out-of-bounds, check whether we can statically determine
4419  // it is inBounds.
4420  bool inBounds = false;
4421  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4422  if (dimExpr) {
4423  inBounds = isInBounds(op, /*resultIdx=*/i,
4424  /*indicesIdx=*/dimExpr.getPosition());
4425  nonBcastDims.push_back(i);
4426  }
4427 
4428  newInBounds.push_back(inBounds);
4429  // We commit the pattern if it is "more inbounds".
4430  changed |= inBounds;
4431  }
4432 
4433  // 2. Handle broadcast dims
4434  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4435  // "in bounds" as well.
4436  bool allNonBcastDimsInBounds = llvm::all_of(
4437  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4438  if (allNonBcastDimsInBounds) {
4439  for (size_t idx : permutationMap.getBroadcastDims()) {
4440  changed |= !newInBounds[idx];
4441  newInBounds[idx] = true;
4442  }
4443  }
4444 
4445  if (!changed)
4446  return failure();
4447  // OpBuilder is only used as a helper to build a BoolArrayAttr.
4448  OpBuilder b(op.getContext());
4449  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4450  return success();
4451 }
4452 
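/// Drops the mask operand of a transfer op when the mask is statically known
/// to be all-true, e.g. (illustrative IR, assuming standard vector dialect
/// syntax):
/// ```
/// %m = vector.constant_mask [4] : vector<4xi1>
/// %r = vector.transfer_read %A[%c0], %f0, %m : memref<8xf32>, vector<4xf32>
/// ```
/// can be rewritten into the equivalent unmasked transfer_read.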
4453 template <typename TransferOp>
4454 static LogicalResult foldTransferFullMask(TransferOp op) {
4455  auto mask = op.getMask();
4456  if (!mask)
4457  return failure();
4458 
4459  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4460  return failure();
4461 
4462  op.getMaskMutable().clear();
4463  return success();
4464 }
4465 
4466 /// ```
4467 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4468 /// : vector<1x4xf32>, tensor<4x4xf32>
4469 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4470 /// : tensor<4x4xf32>, vector<1x4xf32>
4471 /// ```
4472 /// -> Folds into
4473 /// ```
4474 /// %v0
4475 /// ```
4476 static Value foldRAW(TransferReadOp readOp) {
4477  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4478  return {};
4479  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4480  while (defWrite) {
4481  if (checkSameValueRAW(defWrite, readOp))
4482  return defWrite.getVector();
4483  if (!isDisjointTransferIndices(
4484  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4485  cast<VectorTransferOpInterface>(readOp.getOperation())))
4486  break;
4487  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4488  }
4489  return {};
4490 }
4491 
4492 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4493  if (Value vec = foldRAW(*this))
4494  return vec;
4495  /// transfer_read(memrefcast) -> transfer_read
4496  if (succeeded(foldTransferInBoundsAttribute(*this)))
4497  return getResult();
4498  if (succeeded(foldTransferFullMask(*this)))
4499  return getResult();
4500  if (succeeded(memref::foldMemRefCast(*this)))
4501  return getResult();
4502  if (succeeded(tensor::foldTensorCast(*this)))
4503  return getResult();
4504  return OpFoldResult();
4505 }
4506 
4507 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4508  return llvm::to_vector<4>(getVectorType().getShape());
4509 }
4510 
4511 void TransferReadOp::getEffects(
4512  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4513  &effects) {
4514  if (llvm::isa<MemRefType>(getShapedType()))
4515  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4516  SideEffects::DefaultResource::get());
4517 }
4518 
4519 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4520  if (hasPureTensorSemantics())
4521  return Speculation::Speculatable;
4522  return Speculation::NotSpeculatable;
4523 }
4524 
4525 namespace {
4526 /// Store to load forwarding for transfer operations with permutation maps.
4527 /// Even if the permutation maps are different we can still propagate the store
4528 /// into the load if the size of the dimensions read and written match. Then we
4529 /// can replace the transfer_read + transfer_write by vector.broadcast and
4530 /// vector.transpose.
4531 /// Example:
4532 /// ```
4533 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4534 /// {in_bounds = [true, true],
4535 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4536 /// vector<4x1xf32>, tensor<4x4x4xf32>
4537 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4538 /// {in_bounds = [true, true, true, true],
4539 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4540 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4541 /// ```
4542 /// To:
4543 /// ```
4544 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4545 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4546 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4547 /// ```
4548 struct TransferReadAfterWriteToBroadcast
4549  : public OpRewritePattern<TransferReadOp> {
4551 
4552  LogicalResult matchAndRewrite(TransferReadOp readOp,
4553  PatternRewriter &rewriter) const override {
4554  if (readOp.hasOutOfBoundsDim() ||
4555  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4556  return failure();
4557  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4558  if (!defWrite)
4559  return failure();
4560  // TODO: If the written transfer chunk is a superset of the read transfer
4561  // chunk we could do an extract_strided_slice.
4562  if (readOp.getTransferChunkAccessed() !=
4563  defWrite.getTransferChunkAccessed())
4564  return failure();
4565  // TODO: Support cases where a dim is explicitly written but implicitly
4566  // read (i.e., a unit dim that is rank reduced).
4567  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4568  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4569  return failure();
4570  if (readOp.getIndices() != defWrite.getIndices() ||
4571  readOp.getMask() != defWrite.getMask())
4572  return failure();
4573  Value vec = defWrite.getVector();
4574  // TODO: loop through the chain of transfer_write if we can prove that they
4575  // don't overlap with the transfer_read. This requires improving
4576  // `isDisjointTransferIndices` helper.
4577  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4578  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4579  AffineMap map = readMap.compose(writeMap);
4580  if (map.getNumResults() == 0)
4581  return failure();
4582  // Calculate the permutation to apply to go from the vector stored to the
4583  // vector read.
4584  SmallVector<unsigned> permutation;
4585  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4586  return failure();
4587 
4588  Location loc = readOp.getLoc();
4589  // Calculate the broadcast shape by applying the reverse permutation to the
4590  // final shape we want.
4591  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4592  SmallVector<int64_t> broadcastShape(destShape.size());
4593  SmallVector<bool> broadcastScalableFlags(destShape.size());
4594  for (const auto &pos : llvm::enumerate(permutation)) {
4595  broadcastShape[pos.value()] = destShape[pos.index()];
4596  broadcastScalableFlags[pos.value()] =
4597  readOp.getVectorType().getScalableDims()[pos.index()];
4598  }
4599  VectorType broadcastedType = VectorType::get(
4600  broadcastShape, defWrite.getVectorType().getElementType(),
4601  broadcastScalableFlags);
4602  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4603  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4604  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4605  transposePerm);
4606  return success();
4607  }
4608 };
4609 } // namespace
4610 
4611 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4612  MLIRContext *context) {
4613  results.add<TransferReadAfterWriteToBroadcast>(context);
4614 }
4615 
4616 //===----------------------------------------------------------------------===//
4617 // TransferWriteOp
4618 //===----------------------------------------------------------------------===//
4619 
4620 /// 1. Builder with type inference.
4621 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4622  Value vector, Value dest, ValueRange indices,
4623  AffineMapAttr permutationMapAttr,
4624  /*optional*/ Value mask,
4625  /*optional*/ ArrayAttr inBoundsAttr) {
4626  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4627  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4628  mask, inBoundsAttr);
4629 }
4630 
4631 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4632 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4633  Value vector, Value dest, ValueRange indices,
4634  AffineMapAttr permutationMapAttr,
4635  /*optional*/ ArrayAttr inBoundsAttr) {
4636  build(builder, result, vector, dest, indices, permutationMapAttr,
4637  /*mask=*/Value(), inBoundsAttr);
4638 }
4639 
4640 /// 3. Builder with type inference that sets an empty mask (variant without
4641 /// attrs)
4642 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4643  Value vector, Value dest, ValueRange indices,
4644  AffineMap permutationMap,
4645  std::optional<ArrayRef<bool>> inBounds) {
4646  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4647  auto inBoundsAttr =
4648  (inBounds && !inBounds.value().empty())
4649  ? builder.getBoolArrayAttr(inBounds.value())
4650  : builder.getBoolArrayAttr(SmallVector<bool>(
4651  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4652  build(builder, result, vector, dest, indices, permutationMapAttr,
4653  /*mask=*/Value(), inBoundsAttr);
4654 }
4655 
4656 /// 4. Builder with type inference that sets an empty mask and sets permutation
4657 /// map to 'getMinorIdentityMap'.
4658 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4659  Value vector, Value dest, ValueRange indices,
4660  std::optional<ArrayRef<bool>> inBounds) {
4661  auto vectorType = llvm::cast<VectorType>(vector.getType());
4662  AffineMap permutationMap = getTransferMinorIdentityMap(
4663  llvm::cast<ShapedType>(dest.getType()), vectorType);
4664  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4665 }
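// A usage sketch for builder 4 above (illustrative only; `rewriter`, `loc`,
// `vecValue`, `destTensor` and `c0` are assumed to exist in the caller). It
// creates a write with the minor identity permutation map, no mask, and an
// explicit in_bounds array:
//
//   auto write = rewriter.create<vector::TransferWriteOp>(
//       loc, vecValue, destTensor, ValueRange{c0, c0},
//       /*inBounds=*/ArrayRef<bool>{true, true});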
4666 
4667 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4668  OperationState &result) {
4669  auto &builder = parser.getBuilder();
4670  SMLoc typesLoc;
4671  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4672  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4673  SmallVector<Type, 2> types;
4674  OpAsmParser::UnresolvedOperand maskInfo;
4675  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4676  parser.parseOperand(sourceInfo) ||
4677  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4678  return failure();
4679  ParseResult hasMask = parser.parseOptionalComma();
4680  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4681  return failure();
4682  if (parser.parseOptionalAttrDict(result.attributes) ||
4683  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4684  return failure();
4685  if (types.size() != 2)
4686  return parser.emitError(typesLoc, "requires two types");
4687  auto indexType = builder.getIndexType();
4688  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4689  if (!vectorType)
4690  return parser.emitError(typesLoc, "requires vector type");
4691  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4692  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4693  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4694  auto permMapAttrName =
4695  TransferWriteOp::getPermutationMapAttrName(result.name);
4696  auto permMapAttr = result.attributes.get(permMapAttrName);
4697  AffineMap permMap;
4698  if (!permMapAttr) {
4699  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4700  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4701  } else {
4702  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4703  }
4704  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4705  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4706  if (!inBoundsAttr) {
4707  result.addAttribute(inBoundsAttrName,
4708  builder.getBoolArrayAttr(
4709  SmallVector<bool>(permMap.getNumResults(), false)));
4710  }
4711  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4712  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4713  parser.resolveOperands(indexInfo, indexType, result.operands))
4714  return failure();
4715  if (hasMask.succeeded()) {
4716  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4717  return parser.emitError(
4718  maskInfo.location, "does not support masks with vector element type");
4719  if (vectorType.getRank() != permMap.getNumResults()) {
4720  return parser.emitError(typesLoc,
4721  "expected the same rank for the vector and the "
4722  "results of the permutation map");
4723  }
4724  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4725  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4726  return failure();
4727  }
4728  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4729  builder.getDenseI32ArrayAttr(
4730  {1, 1, static_cast<int32_t>(indexInfo.size()),
4731  static_cast<int32_t>(hasMask.succeeded())}));
4732  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4733  parser.addTypeToList(shapedType, result.types));
4734 }
4735 
4736 void TransferWriteOp::print(OpAsmPrinter &p) {
4737  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4738  if (getMask())
4739  p << ", " << getMask();
4740  printTransferAttrs(p, *this);
4741  p << " : " << getVectorType() << ", " << getShapedType();
4742 }
4743 
4744 LogicalResult TransferWriteOp::verify() {
4745  // Consistency of elemental types in shape and vector.
4746  ShapedType shapedType = getShapedType();
4747  VectorType vectorType = getVectorType();
4748  VectorType maskType = getMaskType();
4749  auto permutationMap = getPermutationMap();
4750  VectorType inferredMaskType =
4751  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4752  : VectorType();
4753 
4754  if (llvm::size(getIndices()) != shapedType.getRank())
4755  return emitOpError("requires ") << shapedType.getRank() << " indices";
4756 
4757  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4758  // as the semantics is unclear. This can be revisited later if necessary.
4759  if (hasBroadcastDim())
4760  return emitOpError("should not have broadcast dimensions");
4761 
4762  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4763  shapedType, vectorType, maskType,
4764  inferredMaskType, permutationMap, getInBounds())))
4765  return failure();
4766 
4767  return verifyPermutationMap(permutationMap,
4768  [&](Twine t) { return emitOpError(t); });
4769 }
4770 
4771 // MaskableOpInterface methods.
4772 
4773 /// Returns the mask type expected by this operation. Mostly used for
4774 /// verification purposes.
4775 Type TransferWriteOp::getExpectedMaskType() {
4776  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4777 }
4778 
4779 /// Fold:
4780 /// ```
4781 /// %t1 = ...
4782 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4783 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4784 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4785 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4786 /// ```
4787 ///
4788 /// into:
4789 ///
4790 /// ```
4791 /// %t0
4792 /// ```
4793 ///
4794 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4795 /// block argument or has side effects.
4796 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4797  ArrayRef<Attribute>,
4798  SmallVectorImpl<OpFoldResult> &results) {
4799  // TODO: support 0-d corner case.
4800  if (write.getTransferRank() == 0)
4801  return failure();
4802  auto rankedTensorType =
4803  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4804  // If not operating on tensors, bail.
4805  if (!rankedTensorType)
4806  return failure();
4807  // If no read, bail.
4808  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4809  if (!read)
4810  return failure();
4811  // TODO: support 0-d corner case.
4812  if (read.getTransferRank() == 0)
4813  return failure();
4814  // For now, only accept minor identity. Future: composition is minor identity.
4815  if (!read.getPermutationMap().isMinorIdentity() ||
4816  !write.getPermutationMap().isMinorIdentity())
4817  return failure();
4818  // Bail on mismatching ranks.
4819  if (read.getTransferRank() != write.getTransferRank())
4820  return failure();
4821  // Bail on potential out-of-bounds accesses.
4822  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4823  return failure();
4824  // Tensor types must be the same.
4825  if (read.getSource().getType() != rankedTensorType)
4826  return failure();
4827  // Vector types must be the same.
4828  if (read.getVectorType() != write.getVectorType())
4829  return failure();
4830  // Vector and Tensor shapes must match.
4831  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4832  return failure();
4833  // If any index is nonzero.
4834  auto isNotConstantZero = [](Value v) {
4835  auto cstOp = getConstantIntValue(v);
4836  return !cstOp.has_value() || cstOp.value() != 0;
4837  };
4838  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4839  llvm::any_of(write.getIndices(), isNotConstantZero))
4840  return failure();
4841  // Success.
4842  results.push_back(read.getSource());
4843  return success();
4844 }
4845 
4846 static bool checkSameValueWAR(vector::TransferReadOp read,
4847  vector::TransferWriteOp write) {
4848  return read.getSource() == write.getSource() &&
4849  read.getIndices() == write.getIndices() &&
4850  read.getPermutationMap() == write.getPermutationMap() &&
4851  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4852  !write.getMask();
4853 }
4854 /// Fold transfer_write write after read:
4855 /// ```
4856 /// %t0 = ...
4857 /// %v = vector.transfer_read %t0[%c0...] :
4858 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4859 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4860 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4861 /// ```
4862 ///
4863 /// into:
4864 ///
4865 /// ```
4866 /// %t0
4867 /// ```
4868 static LogicalResult foldWAR(TransferWriteOp write,
4869  SmallVectorImpl<OpFoldResult> &results) {
4870  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4871  return failure();
4872  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4873  if (!read)
4874  return failure();
4875 
4876  if (!checkSameValueWAR(read, write))
4877  return failure();
4878  results.push_back(read.getSource());
4879  return success();
4880 }
4881 
4882 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4883  SmallVectorImpl<OpFoldResult> &results) {
4884  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4885  return success();
4886  if (succeeded(foldWAR(*this, results)))
4887  return success();
4888  if (succeeded(foldTransferInBoundsAttribute(*this)))
4889  return success();
4890  if (succeeded(foldTransferFullMask(*this)))
4891  return success();
4892  return memref::foldMemRefCast(*this);
4893 }
4894 
4895 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4896  return llvm::to_vector<4>(getVectorType().getShape());
4897 }
4898 
4899 void TransferWriteOp::getEffects(
4900  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4901  &effects) {
4902  if (llvm::isa<MemRefType>(getShapedType()))
4903  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4904  SideEffects::DefaultResource::get());
4905 }
4906 
4907 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4908  if (hasPureTensorSemantics())
4909  return Speculation::Speculatable;
4910  return Speculation::NotSpeculatable;
4911 }
4912 
4913 namespace {
4914 /// Remove a dead transfer write from the SSA chain so that it can be
4915 /// eliminated by DCE:
4916 /// ```
4917 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4918 /// : vector<1x4xf32>, tensor<4x4xf32>
4919 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4920 /// : vector<1x4xf32>, tensor<4x4xf32>
4921 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4922 /// : vector<1x4xf32>, tensor<4x4xf32>
4923 /// ```
4924 ///
4925 /// into:
4926 ///
4927 /// ```
4928 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4929 /// : vector<1x4xf32>, tensor<4x4xf32>
4930 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4931 /// : vector<1x4xf32>, tensor<4x4xf32>
4932 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4933 /// : vector<1x4xf32>, tensor<4x4xf32>
4934 /// ```
4935 ///
4936 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4937 /// any other uses.
4938 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4939 public:
4941  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4942  PatternRewriter &rewriter) const override {
4943  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4944  return failure();
4945  vector::TransferWriteOp writeToModify = writeOp;
4946 
4947  auto defWrite =
4948  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4949  while (defWrite) {
4950  if (checkSameValueWAW(writeOp, defWrite)) {
4951  rewriter.modifyOpInPlace(writeToModify, [&]() {
4952  writeToModify.getSourceMutable().assign(defWrite.getSource());
4953  });
4954  return success();
4955  }
4956  if (!isDisjointTransferIndices(
4957  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4958  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4959  break;
4960  // If the previous write op doesn't have any other use, we can safely look
4961  // at the previous store to see if it can be removed.
4962  if (!defWrite->hasOneUse())
4963  break;
4964  writeToModify = defWrite;
4965  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4966  }
4967  return failure();
4968  }
4969 };
4970 
4971 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4972 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4973 /// overwritten and inserted into another tensor. After this rewrite, the
4974 /// operations bufferize in-place since all of them work on the same slice.
4975 ///
4976 /// For example:
4977 /// ```mlir
4978 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4979 /// : vector<8x16xf32>, tensor<8x16xf32>
4980 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4981 /// : tensor<8x16xf32> to tensor<?x?xf32>
4982 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4983 /// : tensor<?x?xf32> into tensor<27x37xf32>
4984 /// ```
4985 /// folds to
4986 /// ```mlir
4987 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4988 /// : tensor<27x37xf32> to tensor<?x?xf32>
4989 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4990 /// : vector<8x16xf32>, tensor<?x?xf32>
4991 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4992 /// : tensor<?x?xf32> into tensor<27x37xf32>
4993 /// ```
4994 struct SwapExtractSliceOfTransferWrite
4995  : public OpRewritePattern<tensor::InsertSliceOp> {
4996 public:
4998 
4999  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5000  PatternRewriter &rewriter) const override {
5001  if (!insertOp.hasUnitStride())
5002  return failure();
5003  auto extractOp =
5004  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5005  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5006  return failure();
5007  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5008  if (!transferOp || !transferOp->hasOneUse())
5009  return failure();
5010 
5011  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5012  // rank-reducing.
5013  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5014  return rewriter.notifyMatchFailure(insertOp,
5015  "use-def chain is rank-reducing");
5016  }
5017 
5018  // Fail if tensor::ExtractSliceOp has non-zero offset.
5019  if (!extractOp.hasZeroOffset()) {
5020  return rewriter.notifyMatchFailure(insertOp,
5021  "ExtractSliceOp has non-zero offset");
5022  }
5023 
5024  // Fail if vector::TransferWriteOp has non-zero offset.
5025  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5026  return getConstantIntValue(value) == static_cast<int64_t>(0);
5027  })) {
5028  return rewriter.notifyMatchFailure(insertOp,
5029  "TransferWriteOp has non-zero offset");
5030  }
5031 
5032  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5033  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5034  return rewriter.notifyMatchFailure(
5035  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5036  }
5037 
5038  for (auto [insertSize, extractSize] :
5039  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5040  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5041  return rewriter.notifyMatchFailure(
5042  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5043  }
5044  }
5045 
5046  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5047  assert(transferOp.getVectorType().hasStaticShape() &&
5048  "expected vector to have a static shape");
5049  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5050  SmallVector<int64_t> resultShape = applyPermutationMap(
5051  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5052  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5053  return rewriter.notifyMatchFailure(
5054  insertOp, "TransferWriteOp may not write the full tensor.");
5055  }
5056 
5057  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5058  // Set all in_bounds to false and let the folder infer them.
5059  SmallVector<bool> newInBounds(vectorShape.size(), false);
5060  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5061  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5062  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5063  insertOp.getMixedStrides());
5064  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5065  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5066  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5067  rewriter.getBoolArrayAttr(newInBounds));
5068  rewriter.modifyOpInPlace(insertOp, [&]() {
5069  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5070  });
5071  return success();
5072  }
5073 };
5074 
5075 } // namespace
5076 
5077 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5078  MLIRContext *context) {
5079  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5080 }
5081 
5082 //===----------------------------------------------------------------------===//
5083 // LoadOp
5084 //===----------------------------------------------------------------------===//
5085 
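/// Note on the layout restriction checked below: for multi-element loads and
/// stores the innermost memref dimension must have unit stride, e.g.
/// (illustrative types only):
///
///   memref<4x8xf32, strided<[8, 1]>>  -> accepted
///   memref<4x8xf32, strided<[8, 2]>>  -> rejected ("most minor memref dim
///                                        must have unit stride")
///
/// 0-D and fixed-size single-element vectors are exempt since they degenerate
/// to scalar accesses.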
5086 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5087  VectorType vecTy,
5088  MemRefType memRefTy) {
5089  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5090  // need any strides limitations.
5091  if (!vecTy.isScalable() &&
5092  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5093  return success();
5094 
5095  if (!memRefTy.isLastDimUnitStride())
5096  return op->emitOpError("most minor memref dim must have unit stride");
5097  return success();
5098 }
5099 
5100 LogicalResult vector::LoadOp::verify() {
5101  VectorType resVecTy = getVectorType();
5102  MemRefType memRefTy = getMemRefType();
5103 
5104  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5105  return failure();
5106 
5107  // Checks for vector memrefs.
5108  Type memElemTy = memRefTy.getElementType();
5109  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5110  if (memVecTy != resVecTy)
5111  return emitOpError("base memref and result vector types should match");
5112  memElemTy = memVecTy.getElementType();
5113  }
5114 
5115  if (resVecTy.getElementType() != memElemTy)
5116  return emitOpError("base and result element types should match");
5117  if (llvm::size(getIndices()) != memRefTy.getRank())
5118  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5119  return success();
5120 }
5121 
5122 OpFoldResult LoadOp::fold(FoldAdaptor) {
5123  if (succeeded(memref::foldMemRefCast(*this)))
5124  return getResult();
5125  return OpFoldResult();
5126 }
5127 
5128 //===----------------------------------------------------------------------===//
5129 // StoreOp
5130 //===----------------------------------------------------------------------===//
5131 
5132 LogicalResult vector::StoreOp::verify() {
5133  VectorType valueVecTy = getVectorType();
5134  MemRefType memRefTy = getMemRefType();
5135 
5136  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5137  return failure();
5138 
5139  // Checks for vector memrefs.
5140  Type memElemTy = memRefTy.getElementType();
5141  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5142  if (memVecTy != valueVecTy)
5143  return emitOpError(
5144  "base memref and valueToStore vector types should match");
5145  memElemTy = memVecTy.getElementType();
5146  }
5147 
5148  if (valueVecTy.getElementType() != memElemTy)
5149  return emitOpError("base and valueToStore element type should match");
5150  if (llvm::size(getIndices()) != memRefTy.getRank())
5151  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5152  return success();
5153 }
5154 
5155 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5156  SmallVectorImpl<OpFoldResult> &results) {
5157  return memref::foldMemRefCast(*this);
5158 }
5159 
5160 //===----------------------------------------------------------------------===//
5161 // MaskedLoadOp
5162 //===----------------------------------------------------------------------===//
5163 
5164 LogicalResult MaskedLoadOp::verify() {
5165  VectorType maskVType = getMaskVectorType();
5166  VectorType passVType = getPassThruVectorType();
5167  VectorType resVType = getVectorType();
5168  MemRefType memType = getMemRefType();
5169 
5170  if (resVType.getElementType() != memType.getElementType())
5171  return emitOpError("base and result element type should match");
5172  if (llvm::size(getIndices()) != memType.getRank())
5173  return emitOpError("requires ") << memType.getRank() << " indices";
5174  if (resVType.getShape() != maskVType.getShape())
5175  return emitOpError("expected result shape to match mask shape");
5176  if (resVType != passVType)
5177  return emitOpError("expected pass_thru of same type as result type");
5178  return success();
5179 }
5180 
5181 namespace {
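/// Canonicalizes masked loads whose mask is statically known. An illustrative
/// sketch (assuming standard vector dialect syntax):
/// ```
/// %l = vector.maskedload %base[%c0], %mask, %pass
///     : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// becomes a plain vector.load when %mask is all-true, and folds to %pass when
/// %mask is all-false.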
5182 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5183 public:
5185  LogicalResult matchAndRewrite(MaskedLoadOp load,
5186  PatternRewriter &rewriter) const override {
5187  switch (getMaskFormat(load.getMask())) {
5188  case MaskFormat::AllTrue:
5189  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5190  load, load.getType(), load.getBase(), load.getIndices());
5191  return success();
5192  case MaskFormat::AllFalse:
5193  rewriter.replaceOp(load, load.getPassThru());
5194  return success();
5195  case MaskFormat::Unknown:
5196  return failure();
5197  }
5198  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5199  }
5200 };
5201 } // namespace
5202 
5203 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5204  MLIRContext *context) {
5205  results.add<MaskedLoadFolder>(context);
5206 }
5207 
5208 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5209  if (succeeded(memref::foldMemRefCast(*this)))
5210  return getResult();
5211  return OpFoldResult();
5212 }
5213 
5214 //===----------------------------------------------------------------------===//
5215 // MaskedStoreOp
5216 //===----------------------------------------------------------------------===//
5217 
5218 LogicalResult MaskedStoreOp::verify() {
5219  VectorType maskVType = getMaskVectorType();
5220  VectorType valueVType = getVectorType();
5221  MemRefType memType = getMemRefType();
5222 
5223  if (valueVType.getElementType() != memType.getElementType())
5224  return emitOpError("base and valueToStore element type should match");
5225  if (llvm::size(getIndices()) != memType.getRank())
5226  return emitOpError("requires ") << memType.getRank() << " indices";
5227  if (valueVType.getShape() != maskVType.getShape())
5228  return emitOpError("expected valueToStore shape to match mask shape");
5229  return success();
5230 }
5231 
5232 namespace {
5233 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5234 public:
5236  LogicalResult matchAndRewrite(MaskedStoreOp store,
5237  PatternRewriter &rewriter) const override {
5238  switch (getMaskFormat(store.getMask())) {
5239  case MaskFormat::AllTrue:
5240  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5241  store, store.getValueToStore(), store.getBase(), store.getIndices());
5242  return success();
5243  case MaskFormat::AllFalse:
5244  rewriter.eraseOp(store);
5245  return success();
5246  case MaskFormat::Unknown:
5247  return failure();
5248  }
5249  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5250  }
5251 };
5252 } // namespace
5253 
5254 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5255  MLIRContext *context) {
5256  results.add<MaskedStoreFolder>(context);
5257 }
5258 
5259 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5260  SmallVectorImpl<OpFoldResult> &results) {
5261  return memref::foldMemRefCast(*this);
5262 }
5263 
5264 //===----------------------------------------------------------------------===//
5265 // GatherOp
5266 //===----------------------------------------------------------------------===//
5267 
5268 LogicalResult GatherOp::verify() {
5269  VectorType indVType = getIndexVectorType();
5270  VectorType maskVType = getMaskVectorType();
5271  VectorType resVType = getVectorType();
5272  ShapedType baseType = getBaseType();
5273 
5274  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5275  return emitOpError("requires base to be a memref or ranked tensor type");
5276 
5277  if (resVType.getElementType() != baseType.getElementType())
5278  return emitOpError("base and result element type should match");
5279  if (llvm::size(getIndices()) != baseType.getRank())
5280  return emitOpError("requires ") << baseType.getRank() << " indices";
5281  if (resVType.getShape() != indVType.getShape())
5282  return emitOpError("expected result dim to match indices dim");
5283  if (resVType.getShape() != maskVType.getShape())
5284  return emitOpError("expected result dim to match mask dim");
5285  if (resVType != getPassThruVectorType())
5286  return emitOpError("expected pass_thru of same type as result type");
5287  return success();
5288 }
5289 
5290 // MaskableOpInterface methods.
5291 
5292 /// Returns the mask type expected by this operation. Mostly used for
5293 /// verification purposes. It requires the operation to be vectorized.
5294 Type GatherOp::getExpectedMaskType() {
5295  auto vecType = this->getIndexVectorType();
5296  return VectorType::get(vecType.getShape(),
5297  IntegerType::get(vecType.getContext(), /*width=*/1),
5298  vecType.getScalableDims());
5299 }
5300 
5301 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5302  return llvm::to_vector<4>(getVectorType().getShape());
5303 }
5304 
5305 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
5306 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5307  auto vecType = dyn_cast<VectorType>(indexVec.getType());
5308  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5309  return failure();
5310 
5311  if (indexVec.getDefiningOp<StepOp>())
5312  return success();
5313 
5314  DenseIntElementsAttr elements;
5315  if (!matchPattern(indexVec, m_Constant(&elements)))
5316  return failure();
5317 
5318  return success(
5319  llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5320 }
5321 
5322 namespace {
5323 class GatherFolder final : public OpRewritePattern<GatherOp> {
5324 public:
5326  LogicalResult matchAndRewrite(GatherOp gather,
5327  PatternRewriter &rewriter) const override {
5328  switch (getMaskFormat(gather.getMask())) {
5329  case MaskFormat::AllTrue:
5330  return failure(); // no unmasked equivalent
5331  case MaskFormat::AllFalse:
5332  rewriter.replaceOp(gather, gather.getPassThru());
5333  return success();
5334  case MaskFormat::Unknown:
5335  return failure();
5336  }
5337  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5338  }
5339 };
5340 
5341 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5342 /// maskedload. Only 1D fixed vectors are supported for now.
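/// An illustrative sketch (assuming standard vector dialect syntax):
/// ```
/// %idx = vector.step : vector<8xindex>
/// %g = vector.gather %base[%c0][%idx], %mask, %pass
///     : memref<16xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
///       into vector<8xf32>
/// ```
/// rewrites to the equivalent vector.maskedload of %base[%c0].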
5343 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5344 public:
5346  LogicalResult matchAndRewrite(GatherOp op,
5347  PatternRewriter &rewriter) const override {
5348  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5349  return failure();
5350 
5351  rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5352  op.getIndices(), op.getMask(),
5353  op.getPassThru());
5354  return success();
5355  }
5356 };
5357 } // namespace
5358 
5359 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5360  MLIRContext *context) {
5361  results.add<GatherFolder, FoldContiguousGather>(context);
5362 }
5363 
5364 //===----------------------------------------------------------------------===//
5365 // ScatterOp
5366 //===----------------------------------------------------------------------===//
5367 
5368 LogicalResult ScatterOp::verify() {
5369  VectorType indVType = getIndexVectorType();
5370  VectorType maskVType = getMaskVectorType();
5371  VectorType valueVType = getVectorType();
5372  MemRefType memType = getMemRefType();
5373 
5374  if (valueVType.getElementType() != memType.getElementType())
5375  return emitOpError("base and valueToStore element type should match");
5376  if (llvm::size(getIndices()) != memType.getRank())
5377  return emitOpError("requires ") << memType.getRank() << " indices";
5378  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5379  return emitOpError("expected valueToStore dim to match indices dim");
5380  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5381  return emitOpError("expected valueToStore dim to match mask dim");
5382  return success();
5383 }
5384 
5385 namespace {
5386 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5387 public:
5389  LogicalResult matchAndRewrite(ScatterOp scatter,
5390  PatternRewriter &rewriter) const override {
5391  switch (getMaskFormat(scatter.getMask())) {
5392  case MaskFormat::AllTrue:
5393  return failure(); // no unmasked equivalent
5394  case MaskFormat::AllFalse:
5395  rewriter.eraseOp(scatter);
5396  return success();
5397  case MaskFormat::Unknown:
5398  return failure();
5399  }
5400  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5401  }
5402 };
5403 
5404 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5405 /// maskedstore. Only 1D fixed vectors are supported for now.
5406 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5407 public:
5409  LogicalResult matchAndRewrite(ScatterOp op,
5410  PatternRewriter &rewriter) const override {
5411  if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5412  return failure();
5413 
5414  rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5415  op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5416  return success();
5417  }
5418 };
5419 } // namespace
5420 
5421 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5422  MLIRContext *context) {
5423  results.add<ScatterFolder, FoldContiguousScatter>(context);
5424 }
5425 
5426 //===----------------------------------------------------------------------===//
5427 // ExpandLoadOp
5428 //===----------------------------------------------------------------------===//
5429 
5430 LogicalResult ExpandLoadOp::verify() {
5431  VectorType maskVType = getMaskVectorType();
5432  VectorType passVType = getPassThruVectorType();
5433  VectorType resVType = getVectorType();
5434  MemRefType memType = getMemRefType();
5435 
5436  if (resVType.getElementType() != memType.getElementType())
5437  return emitOpError("base and result element type should match");
5438  if (llvm::size(getIndices()) != memType.getRank())
5439  return emitOpError("requires ") << memType.getRank() << " indices";
5440  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5441  return emitOpError("expected result dim to match mask dim");
5442  if (resVType != passVType)
5443  return emitOpError("expected pass_thru of same type as result type");
5444  return success();
5445 }
5446 
5447 namespace {
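/// Canonicalizes expanding loads whose mask is statically known. An
/// illustrative sketch (assuming standard vector dialect syntax):
/// ```
/// %0 = vector.expandload %base[%c0], %mask, %pass
///     : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// becomes a plain vector.load when %mask is all-true (the expansion is then a
/// no-op), and folds to %pass when %mask is all-false.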
5448 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5449 public:
5451  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5452  PatternRewriter &rewriter) const override {
5453  switch (getMaskFormat(expand.getMask())) {
5454  case MaskFormat::AllTrue:
5455  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5456  expand, expand.getType(), expand.getBase(), expand.getIndices());
5457  return success();
5458  case MaskFormat::AllFalse:
5459  rewriter.replaceOp(expand, expand.getPassThru());
5460  return success();
5461  case MaskFormat::Unknown:
5462  return failure();
5463  }
5464  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5465  }
5466 };
5467 } // namespace
5468 
5469 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5470  MLIRContext *context) {
5471  results.add<ExpandLoadFolder>(context);
5472 }
5473 
5474 //===----------------------------------------------------------------------===//
5475 // CompressStoreOp
5476 //===----------------------------------------------------------------------===//
5477 
5478 LogicalResult CompressStoreOp::verify() {
5479  VectorType maskVType = getMaskVectorType();
5480  VectorType valueVType = getVectorType();
5481  MemRefType memType = getMemRefType();
5482 
5483  if (valueVType.getElementType() != memType.getElementType())
5484  return emitOpError("base and valueToStore element type should match");
5485  if (llvm::size(getIndices()) != memType.getRank())
5486  return emitOpError("requires ") << memType.getRank() << " indices";
5487  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5488  return emitOpError("expected valueToStore dim to match mask dim");
5489  return success();
5490 }
5491 
5492 namespace {
5493 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5494 public:
5496  LogicalResult matchAndRewrite(CompressStoreOp compress,
5497  PatternRewriter &rewriter) const override {
5498  switch (getMaskFormat(compress.getMask())) {
5499  case MaskFormat::AllTrue:
5500  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5501  compress, compress.getValueToStore(), compress.getBase(),
5502  compress.getIndices());
5503  return success();
5504  case MaskFormat::AllFalse:
5505  rewriter.eraseOp(compress);
5506  return success();
5507  case MaskFormat::Unknown:
5508  return failure();
5509  }
5510  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5511  }
5512 };
5513 } // namespace
5514 
5515 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5516  MLIRContext *context) {
5517  results.add<CompressStoreFolder>(context);
5518 }
5519 
5520 //===----------------------------------------------------------------------===//
5521 // ShapeCastOp
5522 //===----------------------------------------------------------------------===//
5523 
5524 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5525  SetIntRangeFn setResultRanges) {
5526  setResultRanges(getResult(), argRanges.front());
5527 }
5528 
5529 /// Returns true if each element of 'a' is equal to the product of a contiguous
5530 /// sequence of the elements of 'b'. Returns false otherwise.
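/// For example (a sketch of the intended behaviour): with a = [4, 6] and
/// b = [2, 2, 6] the result is true (4 = 2 * 2 and 6 = 6), whereas with
/// a = [4, 6] and b = [2, 3, 4] it is false since no contiguous prefix of b
/// multiplies to 4.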
5531 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5532  unsigned rankA = a.size();
5533  unsigned rankB = b.size();
5534  assert(rankA < rankB);
5535 
5536  auto isOne = [](int64_t v) { return v == 1; };
5537 
5538  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5539  // casted to a 0-d vector.
5540  if (rankA == 0 && llvm::all_of(b, isOne))
5541  return true;
5542 
5543  unsigned i = 0;
5544  unsigned j = 0;
5545  while (i < rankA && j < rankB) {
5546  int64_t dimA = a[i];
5547  int64_t dimB = 1;
5548  while (dimB < dimA && j < rankB)
5549  dimB *= b[j++];
5550  if (dimA != dimB)
5551  break;
5552  ++i;
5553 
5554  // Handle the case when trailing dimensions are of size 1.
5555  // Include them into the contiguous sequence.
5556  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5557  i = rankA;
5558  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5559  j = rankB;
5560  }
5561 
5562  return i == rankA && j == rankB;
5563 }
5564 
5565 static LogicalResult verifyVectorShapeCast(Operation *op,
5566  VectorType sourceVectorType,
5567  VectorType resultVectorType) {
5568  // Check that element type is the same.
5569  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5570  return op->emitOpError("source/result vectors must have same element type");
5571  auto sourceShape = sourceVectorType.getShape();
5572  auto resultShape = resultVectorType.getShape();
5573 
5574  // Check that product of source dim sizes matches product of result dim sizes.
5575  int64_t sourceDimProduct = std::accumulate(
5576  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5577  int64_t resultDimProduct = std::accumulate(
5578  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5579  if (sourceDimProduct != resultDimProduct)
5580  return op->emitOpError("source/result number of elements must match");
5581 
5582  // Check the expanding/contracting rank cases.
5583  unsigned sourceRank = sourceVectorType.getRank();
5584  unsigned resultRank = resultVectorType.getRank();
5585  if (sourceRank < resultRank) {
5586  if (!isValidShapeCast(sourceShape, resultShape))
5587  return op->emitOpError("invalid shape cast");
5588  } else if (sourceRank > resultRank) {
5589  if (!isValidShapeCast(resultShape, sourceShape))
5590  return op->emitOpError("invalid shape cast");
5591  }
5592 
5593  // Check that (non-)scalability is preserved
5594  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5595  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5596  if (sourceNScalableDims != resultNScalableDims)
5597  return op->emitOpError("different number of scalable dims at source (")
5598  << sourceNScalableDims << ") and result (" << resultNScalableDims
5599  << ")";
5601 
5602  return success();
5603 }
5604 
5605 LogicalResult ShapeCastOp::verify() {
5606  auto sourceVectorType =
5607  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5608  auto resultVectorType =
5609  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5610 
5611  // Check if source/result are of vector type.
5612  if (sourceVectorType && resultVectorType)
5613  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5614 
5615  return success();
5616 }
5617 
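// Folding sketch for the transitive shape_cast case handled below (IR is
// illustrative, assuming standard vector dialect syntax):
//
//   %1 = vector.shape_cast %0 : vector<2x4xf32> to vector<8xf32>
//   %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x2x2xf32>
//
// %2 can be folded to a single shape_cast of %0, since [2, 4] -> [2, 2, 2] is
// itself a valid shape cast; a cast back to vector<2x4xf32> would instead fold
// to %0 directly.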
5618 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5619  // No-op shape cast.
5620  if (getSource().getType() == getResult().getType())
5621  return getSource();
5622 
5623  // Canceling shape casts.
5624  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5625  if (getResult().getType() == otherOp.getSource().getType())
5626  return otherOp.getSource();
5627 
5628  // Only allows valid transitive folding.
5629  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5630  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5631  if (srcType.getRank() < resultType.getRank()) {
5632  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5633  return {};
5634  } else if (srcType.getRank() > resultType.getRank()) {
5635  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5636  return {};
5637  } else {
5638  return {};
5639  }
5640 
5641  setOperand(otherOp.getSource());
5642  return getResult();
5643  }
5644 
5645  // Cancelling broadcast and shape cast ops.
5646  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5647  if (bcastOp.getSourceType() == getType())
5648  return bcastOp.getSource();
5649  }
5650 
5651  return {};
5652 }
5653 
5654 namespace {
5655 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5656 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5657 public:
5659 
5660  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5661  PatternRewriter &rewriter) const override {
5662  auto constantOp =
5663  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5664  if (!constantOp)
5665  return failure();
5666  // Only handle splat for now.
5667  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5668  if (!dense)
5669  return failure();
5670  auto newAttr =
5671  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5672  dense.getSplatValue<Attribute>());
5673  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5674  return success();
5675  }
5676 };
5677 
5678 /// Helper function that computes a new vector type based on the input vector
5679 /// type by removing the trailing one dims:
5680 ///
5681 /// vector<4x1x1xi1> --> vector<4xi1>
5682 ///
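/// Note (added illustration): trailing *scalable* unit dims are kept, e.g.
///
/// vector<4x1x[1]xi1> --> vector<4x1x[1]xi1>
///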
5683 static VectorType trimTrailingOneDims(VectorType oldType) {
5684  ArrayRef<int64_t> oldShape = oldType.getShape();
5685  ArrayRef<int64_t> newShape = oldShape;
5686 
5687  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5688  ArrayRef<bool> newScalableDims = oldScalableDims;
5689 
5690  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5691  newShape = newShape.drop_back(1);
5692  newScalableDims = newScalableDims.drop_back(1);
5693  }
5694 
5695  // Make sure we have at least 1 dimension.
5696  // TODO: Add support for 0-D vectors.
5697  if (newShape.empty()) {
5698  newShape = oldShape.take_back();
5699  newScalableDims = oldScalableDims.take_back();
5700  }
5701 
5702  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5703 }
5704 
5705 /// Folds a qualifying shape_cast(create_mask) into a new create_mask; the
5706 /// analogous fold also handles shape_cast(constant_mask).
5707 ///
5708 /// Looks at `vector.shape_cast` ops that simply "drop" trailing unit
5709 /// dimensions. If the `vector.create_mask` operands for those dims are 1
5710 /// (e.g. `%c1` below), it is safe to fold the shape_cast into the create_mask.
5711 ///
5712 /// BEFORE:
5713 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5714 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5715 /// AFTER:
5716 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
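///
/// An analogous rewrite (illustrative, not from a test) applies when the mask
/// comes from `vector.constant_mask`:
///   %1 = vector.constant_mask [1, 3, 1] : vector<1x4x1xi1>
///   %2 = vector.shape_cast %1 : vector<1x4x1xi1> to vector<1x4xi1>
/// becomes:
///   %0 = vector.constant_mask [1, 3] : vector<1x4xi1>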
5717 class ShapeCastCreateMaskFolderTrailingOneDim final
5718  : public OpRewritePattern<ShapeCastOp> {
5719 public:
5720  using OpRewritePattern::OpRewritePattern;
5721 
5722  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5723  PatternRewriter &rewriter) const override {
5724  Value shapeOpSrc = shapeOp->getOperand(0);
5725  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5726  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5727  if (!createMaskOp && !constantMaskOp)
5728  return failure();
5729 
5730  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5731  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5732 
5733  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5734  if (newVecType != shapeOpResTy)
5735  return failure();
5736 
5737  auto numDimsToDrop =
5738  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5739 
5740  // No unit dims to drop
5741  if (!numDimsToDrop)
5742  return failure();
5743 
5744  if (createMaskOp) {
5745  auto maskOperands = createMaskOp.getOperands();
5746  auto numMaskOperands = maskOperands.size();
5747 
5748  // Check every mask dim size to see whether it can be dropped
5749  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5750  --i) {
5751  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5752  if (!constant || (constant.value() != 1))
5753  return failure();
5754  }
5755  SmallVector<Value> newMaskOperands =
5756  maskOperands.drop_back(numDimsToDrop);
5757 
5758  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5759  newMaskOperands);
5760  return success();
5761  }
5762 
5763  if (constantMaskOp) {
5764  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5765  auto numMaskOperands = maskDimSizes.size();
5766 
5767  // Check every mask dim size to see whether it can be dropped
5768  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5769  --i) {
5770  if (maskDimSizes[i] != 1)
5771  return failure();
5772  }
5773 
5774  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5775  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5776  newMaskOperands);
5777  return success();
5778  }
5779 
5780  return failure();
5781  }
5782 };
5783 
5784 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5785 /// This only applies when the shape of the broadcast source
5786 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5787 /// reshape is expressive enough to capture the result in a single op), or
5788 /// 2. has the same element count as the shape cast result.
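///
/// Illustrative examples (not from the original comments):
/// Case 1: %b = vector.broadcast %s : vector<4xf32> to vector<2x3x4xf32>
///         %r = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
///         folds to vector.broadcast %s : vector<4xf32> to vector<6x4xf32>.
/// Case 2: %b = vector.broadcast %s : vector<2x4xf32> to vector<1x2x4xf32>
///         %r = vector.shape_cast %b : vector<1x2x4xf32> to vector<8xf32>
///         folds to vector.shape_cast %s : vector<2x4xf32> to vector<8xf32>.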
5789 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5790 public:
5791  using OpRewritePattern::OpRewritePattern;
5792 
5793  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5794  PatternRewriter &rewriter) const override {
5795  auto broadcastOp =
5796  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5797  if (!broadcastOp)
5798  return failure();
5799 
5800  ArrayRef<int64_t> broadcastSourceShape;
5801  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5802  broadcastSourceShape = srcType.getShape();
5803  ArrayRef<int64_t> shapeCastTargetShape =
5804  shapeCastOp.getResultVectorType().getShape();
5805 
5806  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5807  // with a broadcast to the final shape.
5808  if (broadcastSourceShape ==
5809  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5810  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5811  shapeCastOp, shapeCastOp.getResultVectorType(),
5812  broadcastOp.getSource());
5813  return success();
5814  }
5815 
5816  // Otherwise, if the final result has the same element count, we can replace
5817  // with a shape cast.
5818  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5819  if (srcType.getNumElements() ==
5820  shapeCastOp.getResultVectorType().getNumElements()) {
5821  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5822  shapeCastOp, shapeCastOp.getResultVectorType(),
5823  broadcastOp.getSource());
5824  return success();
5825  }
5826  }
5827 
5828  return failure();
5829  }
5830 };
5831 
5832 } // namespace
5833 
5834 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5835  MLIRContext *context) {
5836  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5837  ShapeCastBroadcastFolder>(context);
5838 }
5839 
5840 //===----------------------------------------------------------------------===//
5841 // VectorBitCastOp
5842 //===----------------------------------------------------------------------===//
5843 
5844 LogicalResult BitCastOp::verify() {
5845  auto sourceVectorType = getSourceVectorType();
5846  auto resultVectorType = getResultVectorType();
5847 
5848  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5849  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5850  return emitOpError("dimension size mismatch at: ") << i;
5851  }
5852 
5853  DataLayout dataLayout = DataLayout::closest(*this);
5854  auto sourceElementBits =
5855  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5856  auto resultElementBits =
5857  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5858 
5859  if (sourceVectorType.getRank() == 0) {
5860  if (sourceElementBits != resultElementBits)
5861  return emitOpError("source/result bitwidth of the 0-D vector element "
5862  "types must be equal");
5863  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5864  resultElementBits * resultVectorType.getShape().back()) {
5865  return emitOpError(
5866  "source/result bitwidth of the minor 1-D vectors must be equal");
5867  }
5868 
5869  return success();
5870 }
5871 
5872 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5873  // No-op cast.
5874  if (getSource().getType() == getResult().getType())
5875  return getSource();
5876 
5877  // Canceling bitcasts.
5878  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5879  if (getResult().getType() == otherOp.getSource().getType())
5880  return otherOp.getSource();
5881 
5882  setOperand(otherOp.getSource());
5883  return getResult();
5884  }
5885 
5886  Attribute sourceConstant = adaptor.getSource();
5887  if (!sourceConstant)
5888  return {};
5889 
5890  Type srcElemType = getSourceVectorType().getElementType();
5891  Type dstElemType = getResultVectorType().getElementType();
5892 
5893  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5894  if (floatPack.isSplat()) {
5895  auto splat = floatPack.getSplatValue<FloatAttr>();
5896 
5897  // Casting fp16 into fp32.
5898  if (srcElemType.isF16() && dstElemType.isF32()) {
5899  uint32_t bits = static_cast<uint32_t>(
5900  splat.getValue().bitcastToAPInt().getZExtValue());
5901  // Duplicate the 16-bit pattern.
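  // (Worked example, for illustration: an f16 splat of 1.0 has bit pattern
  // 0x3C00, so the resulting f32 splat carries the bits 0x3C003C00.)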
5902  bits = (bits << 16) | (bits & 0xffff);
5903  APInt intBits(32, bits);
5904  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5905  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5906  }
5907  }
5908  }
5909 
5910  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5911  if (intPack.isSplat()) {
5912  auto splat = intPack.getSplatValue<IntegerAttr>();
5913 
5914  if (llvm::isa<IntegerType>(dstElemType)) {
5915  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5916  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5917 
5918  // Casting to a larger integer bit width.
5919  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5920  APInt intBits = splat.getValue().zext(dstBitWidth);
5921 
5922  // Duplicate the lower width element.
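  // (Worked example, for illustration: an i8 splat of value 1 bitcast to i32
  // produces the 32-bit splat value 0x01010101.)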
5923  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5924  intBits = (intBits << srcBitWidth) | intBits;
5925  return DenseElementsAttr::get(getResultVectorType(), intBits);
5926  }
5927  }
5928  }
5929  }
5930 
5931  return {};
5932 }
5933 
5934 //===----------------------------------------------------------------------===//
5935 // TypeCastOp
5936 //===----------------------------------------------------------------------===//
5937 
5938 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5939  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5940  SmallVector<int64_t, 8> res(memRefType.getShape());
5941  if (vectorType)
5942  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5943  return res;
5944 }
5945 
5946 /// Build the canonical memRefType with a single vector.
5947 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5948 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5949  Value source) {
5950  result.addOperands(source);
5951  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5952  VectorType vectorType =
5953  VectorType::get(extractShape(memRefType),
5954  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5955  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5956  memRefType.getMemorySpace()));
5957 }
5958 
5959 LogicalResult TypeCastOp::verify() {
5960  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
5961  if (!canonicalType.getLayout().isIdentity())
5962  return emitOpError("expects operand to be a memref with identity layout");
5963  if (!getResultMemRefType().getLayout().isIdentity())
5964  return emitOpError("expects result to be a memref with identity layout");
5965  if (getResultMemRefType().getMemorySpace() !=
5966  getMemRefType().getMemorySpace())
5967  return emitOpError("expects result in same memory space");
5968 
5969  auto sourceType = getMemRefType();
5970  auto resultType = getResultMemRefType();
5971  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5972  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5973  return emitOpError(
5974  "expects result and operand with same underlying scalar type: ")
5975  << resultType;
5976  if (extractShape(sourceType) != extractShape(resultType))
5977  return emitOpError(
5978  "expects concatenated result and operand shapes to be equal: ")
5979  << resultType;
5980  return success();
5981 }
5982 
5983 //===----------------------------------------------------------------------===//
5984 // TransposeOp
5985 //===----------------------------------------------------------------------===//
5986 
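/// Example use of this builder (illustrative): building a transpose of a
/// vector<2x[4]xf32> with permutation {1, 0} infers the result type
/// vector<[4]x2xf32>.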
5987 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5988  Value vector, ArrayRef<int64_t> permutation) {
5989  VectorType vt = llvm::cast<VectorType>(vector.getType());
5990  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5991  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5992  for (unsigned i = 0; i < permutation.size(); ++i) {
5993  transposedShape[i] = vt.getShape()[permutation[i]];
5994  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5995  }
5996 
5997  result.addOperands(vector);
5998  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5999  transposedScalableDims));
6000  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6001  builder.getDenseI64ArrayAttr(permutation));
6002 }
6003 
6004 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6005  // Eliminate splat constant transpose ops.
6006  if (auto attr =
6007  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
6008  if (attr.isSplat())
6009  return attr.reshape(getResultVectorType());
6010 
6011  // Eliminate identity transpose ops. This happens when the dimensions of the
6012  // input vector remain in their original order after the transpose operation.
6013  ArrayRef<int64_t> perm = getPermutation();
6014 
6015  // Check if the permutation of the dimensions contains sequential values:
6016  // {0, 1, 2, ...}.
6017  for (int64_t i = 0, e = perm.size(); i < e; i++) {
6018  if (perm[i] != i)
6019  return {};
6020  }
6021 
6022  return getVector();
6023 }
6024 
6025 LogicalResult vector::TransposeOp::verify() {
6026  VectorType vectorType = getSourceVectorType();
6027  VectorType resultType = getResultVectorType();
6028  int64_t rank = resultType.getRank();
6029  if (vectorType.getRank() != rank)
6030  return emitOpError("vector result rank mismatch: ") << rank;
6031  // Verify transposition array.
6032  ArrayRef<int64_t> perm = getPermutation();
6033  int64_t size = perm.size();
6034  if (rank != size)
6035  return emitOpError("transposition length mismatch: ") << size;
6036  SmallVector<bool, 8> seen(rank, false);
6037  for (const auto &ta : llvm::enumerate(perm)) {
6038  if (ta.value() < 0 || ta.value() >= rank)
6039  return emitOpError("transposition index out of range: ") << ta.value();
6040  if (seen[ta.value()])
6041  return emitOpError("duplicate position index: ") << ta.value();
6042  seen[ta.value()] = true;
6043  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6044  return emitOpError("dimension size mismatch at: ") << ta.value();
6045  }
6046  return success();
6047 }
6048 
6049 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6050  return llvm::to_vector<4>(getResultVectorType().getShape());
6051 }
6052 
6053 namespace {
6054 
6055 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
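// For instance (illustrative permutations, not from a test): a transpose with
// permutation [1, 2, 0] applied to the result of a transpose with permutation
// [2, 0, 1] composes to the identity permutation [0, 1, 2], which the identity
// fold can then remove entirely.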
6056 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6057 public:
6058  using OpRewritePattern::OpRewritePattern;
6059 
6060  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6061  PatternRewriter &rewriter) const override {
6062  // Composes two permutations: result[i] = permutation1[permutation2[i]].
6063  auto composePermutations = [](ArrayRef<int64_t> permutation1,
6064  ArrayRef<int64_t> permutation2) {
6065  SmallVector<int64_t, 4> result;
6066  for (auto index : permutation2)
6067  result.push_back(permutation1[index]);
6068  return result;
6069  };
6070 
6071  // Return if the input of 'transposeOp' is not defined by another transpose.
6072  vector::TransposeOp parentTransposeOp =
6073  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6074  if (!parentTransposeOp)
6075  return failure();
6076 
6077  SmallVector<int64_t, 4> permutation = composePermutations(
6078  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6079  // Replace 'transposeOp' with a new transpose operation.
6080  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6081  transposeOp, transposeOp.getResult().getType(),
6082  parentTransposeOp.getVector(), permutation);
6083  return success();
6084  }
6085 };
6086 
6087 // Folds transpose(broadcast(<scalar or 1-element vector>)) into a broadcast.
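// For example (illustrative): with %s : f32,
//   %b = vector.broadcast %s : f32 to vector<2x3xf32>
//   %t = vector.transpose %b, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
// is rewritten into a single vector.broadcast %s : f32 to vector<3x2xf32>.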
6088 struct FoldTransposedScalarBroadcast final
6089  : public OpRewritePattern<vector::TransposeOp> {
6090  using OpRewritePattern::OpRewritePattern;
6091 
6092  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6093  PatternRewriter &rewriter) const override {
6094  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6095  if (!bcastOp)
6096  return failure();
6097 
6098  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6099  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6100  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6101  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6102  return success();
6103  }
6104 
6105  return failure();
6106  }
6107 };
6108 
6109 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6110 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6111 public:
6112  using OpRewritePattern::OpRewritePattern;
6113 
6114  LogicalResult matchAndRewrite(TransposeOp transposeOp,
6115  PatternRewriter &rewriter) const override {
6116  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6117  if (!splatOp)
6118  return failure();
6119 
6120  rewriter.replaceOpWithNewOp<vector::SplatOp>(
6121  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6122  return success();
6123  }
6124 };
6125 
6126 /// Folds transpose(create_mask) into a new transposed create_mask.