VectorOps.cpp
1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/TypeUtilities.h"
37 #include "mlir/Support/LLVM.h"
39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of a constant value with dense attributes
66 /// and a constant mask operation (since the client may be called at
67 /// various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayRef<int64_t> masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  if (maskIdx < dimSize)
97  allTrue = false;
98  if (maskIdx > 0)
99  allFalse = false;
100  }
101  if (allTrue)
102  return MaskFormat::AllTrue;
103  if (allFalse)
104  return MaskFormat::AllFalse;
105  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
106  // Finds all-false create_masks. An all-true create_mask requires all
107  // dims to be constants, so that'll be folded to a constant_mask, then
108  // detected in the constant_mask case.
109  auto maskOperands = m.getOperands();
110  for (Value operand : maskOperands) {
111  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
112  int64_t dimSize =
113  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
114  if (dimSize <= 0)
115  return MaskFormat::AllFalse;
116  }
117  }
118  return MaskFormat::Unknown;
119  }
120  return MaskFormat::Unknown;
121 }
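// Illustrative classification (hypothetical IR): a constant mask whose sizes
// cover every dimension is AllTrue, zero sizes give AllFalse, and anything
// partial is Unknown:
//
//   %t = vector.constant_mask [4, 3] : vector<4x3xi1>   // MaskFormat::AllTrue
//   %f = vector.constant_mask [0, 0] : vector<4x3xi1>   // MaskFormat::AllFalse
//   %u = vector.constant_mask [2, 3] : vector<4x3xi1>   // MaskFormat::Unknown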
122 
123 /// Default callback to build a region with a 'vector.yield' terminator with no
124 /// arguments.
125 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
126  builder.create<vector::YieldOp>(loc);
127 }
128 
129 // Helper for verifying combining kinds in contractions and reductions.
130 static bool isSupportedCombiningKind(CombiningKind combiningKind,
131  Type elementType) {
132  switch (combiningKind) {
133  case CombiningKind::ADD:
134  case CombiningKind::MUL:
135  return elementType.isIntOrIndexOrFloat();
136  case CombiningKind::MINUI:
137  case CombiningKind::MINSI:
138  case CombiningKind::MAXUI:
139  case CombiningKind::MAXSI:
140  case CombiningKind::AND:
141  case CombiningKind::OR:
142  case CombiningKind::XOR:
143  return elementType.isIntOrIndex();
144  case CombiningKind::MINNUMF:
145  case CombiningKind::MAXNUMF:
146  case CombiningKind::MINIMUMF:
147  case CombiningKind::MAXIMUMF:
148  return llvm::isa<FloatType>(elementType);
149  }
150  return false;
151 }
152 
153 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
154  VectorType vectorType) {
155  int64_t elementVectorRank = 0;
156  VectorType elementVectorType =
157  llvm::dyn_cast<VectorType>(shapedType.getElementType());
158  if (elementVectorType)
159  elementVectorRank += elementVectorType.getRank();
160  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
161  // TODO: replace once we have 0-d vectors.
162  if (shapedType.getRank() == 0 &&
163  vectorType.getShape() == ArrayRef<int64_t>{1})
164  return AffineMap::get(
165  /*numDims=*/0, /*numSymbols=*/0,
166  getAffineConstantExpr(0, shapedType.getContext()));
167  return AffineMap::getMinorIdentityMap(
168  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169  shapedType.getContext());
170 }
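// Illustrative results (hypothetical types): the minor identity map keeps the
// trailing (minor) dimensions of the shaped type:
//
//   memref<?x?x?xf32>, vector<4xf32>   -> affine_map<(d0, d1, d2) -> (d2)>
//   memref<?x?xf32>,   vector<4x8xf32> -> affine_map<(d0, d1) -> (d0, d1)>
//   memref<f32>,       vector<1xf32>   -> affine_map<() -> (0)>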
171 
172 /// Check if `write` is of a constant splat and the masked `read` is padded with
173 /// the same splat value -- meaning it could be the same value as the initial
174 /// constant splat.
175 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
176  vector::TransferReadOp read) {
177  auto readMask = read.getMask();
178  auto writeMask = write.getMask();
179  // Check if the masks are consistent. The splat value could be the same if the
180  // read is masked (and padded with the splat value), and the write is unmasked
181  // or has the same mask. Note this does not allow the case where the write is
182  // masked and the read is unmasked, as then the read could be of more elements
183  // than the write (which may not be the same value).
184  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185  if (!couldBeSameSplat)
186  return false;
187  // Check for constant splat (as the source of the write).
188  DenseElementsAttr splatAttr;
189  if (!matchPattern(write.getVector(),
190  m_Constant<DenseElementsAttr>(&splatAttr)) ||
191  !splatAttr.isSplat()) {
192  return false;
193  }
194  // The padding of the read and the constant splat value must be the same.
195  Attribute padAttr;
196  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
197  return false;
198  return padAttr == splatAttr.getSplatValue<Attribute>();
199 }
200 
201 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
202  vector::TransferReadOp read) {
203  return !defWrite.hasOutOfBoundsDim() &&
204  defWrite.getIndices() == read.getIndices() &&
205  defWrite.getVectorType() == read.getVectorType() &&
206  defWrite.getPermutationMap() == read.getPermutationMap() &&
207  ((!defWrite.getMask() && !read.getMask()) ||
208  isSplatWriteConsistentWithMaskedRead(defWrite, read));
209 }
210 
211 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
212  vector::TransferWriteOp priorWrite) {
213  return priorWrite.getIndices() == write.getIndices() &&
214  priorWrite.getMask() == write.getMask() &&
215  priorWrite.getVectorType() == write.getVectorType() &&
216  priorWrite.getPermutationMap() == write.getPermutationMap();
217 }
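// Illustrative RAW pair (hypothetical IR): the read below returns the same
// value as the vector written by the preceding in-bounds write, since the
// indices, types and permutation maps match and neither op is masked:
//
//   %t1 = vector.transfer_write %v, %t0[%i, %j] {in_bounds = [true]}
//       : vector<4xf32>, tensor<?x?xf32>
//   %r  = vector.transfer_read %t1[%i, %j], %pad {in_bounds = [true]}
//       : tensor<?x?xf32>, vector<4xf32>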
218 
219 bool mlir::vector::isDisjointTransferIndices(
220  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221  bool testDynamicValueUsingBounds) {
222  // For simplicity only look at transfer of same type.
223  if (transferA.getVectorType() != transferB.getVectorType())
224  return false;
225  unsigned rankOffset = transferA.getLeadingShapedRank();
226  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227  Value indexA = transferA.getIndices()[i];
228  Value indexB = transferB.getIndices()[i];
229  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
230  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
231 
232  if (i < rankOffset) {
233  // For leading dimensions, if we can prove that the indices are different
234  // we know we are accessing disjoint slices.
235  if (cstIndexA.has_value() && cstIndexB.has_value()) {
236  if (*cstIndexA != *cstIndexB)
237  return true;
238  continue;
239  }
240  if (testDynamicValueUsingBounds) {
241  // First try to see if we can fully compose and simplify the affine
242  // expression as a fast track.
243  FailureOr<uint64_t> delta =
244  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
245  if (succeeded(delta) && *delta != 0)
246  return true;
247 
248  FailureOr<bool> testEqual =
249  ValueBoundsConstraintSet::areEqual(indexA, indexB);
250  if (succeeded(testEqual) && !testEqual.value())
251  return true;
252  }
253  } else {
254  // For this dimension, we slice a part of the memref, so we need to make
255  // sure the intervals accessed don't overlap.
256  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257  if (cstIndexA.has_value() && cstIndexB.has_value()) {
258  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
259  if (distance >= vectorDim)
260  return true;
261  continue;
262  }
263  if (testDynamicValueUsingBounds) {
264  // First try to see if we can fully compose and simplify the affine
265  // expression as a fast track.
266  FailureOr<int64_t> delta =
267  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
268  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
269  return true;
270 
271  FailureOr<int64_t> computeDelta =
272  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
273  if (succeeded(computeDelta)) {
274  if (std::abs(computeDelta.value()) >= vectorDim)
275  return true;
276  }
277  }
278  }
279  }
280  return false;
281 }
282 
283 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
284  VectorTransferOpInterface transferB,
285  bool testDynamicValueUsingBounds) {
286  if (transferA.getSource() != transferB.getSource())
287  return false;
288  return isDisjointTransferIndices(transferA, transferB,
289  testDynamicValueUsingBounds);
290 }
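// Illustrative example (hypothetical IR): with constant indices, the two
// reads below access disjoint slices because the index distance (8) is at
// least the vector size (4):
//
//   %a = vector.transfer_read %mem[%c0], %pad : memref<?xf32>, vector<4xf32>
//   %b = vector.transfer_read %mem[%c8], %pad : memref<?xf32>, vector<4xf32>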
291 
292 // Helper to iterate over n-D vector slice elements. Calculate the next
293 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
294 // Modifies the `position` in place. Returns a failure when `position` becomes
295 // the end position.
296 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
297  ArrayRef<int64_t> shape,
298  ArrayRef<int64_t> offsets) {
299  for (auto [posInDim, dimSize, offsetInDim] :
300  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
301  ++posInDim;
302  if (posInDim < dimSize + offsetInDim)
303  return success();
304 
305  // Carry the overflow to the next loop iteration.
306  posInDim = offsetInDim;
307  }
308 
309  return failure();
310 }
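// Illustrative trace (hypothetical values): with shape = [2, 3] and
// offsets = [0, 0], successive calls starting from position = [0, 0] yield
// [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], and the next call returns failure().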
311 
312 /// Returns the integer numbers in `values`. `values` are expected to be
313 /// constant operations.
314 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
315  SmallVector<int64_t> ints;
316  llvm::transform(values, std::back_inserter(ints), [](Value value) {
317  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
318  assert(constOp && "Unexpected non-constant index");
319  return constOp.value();
320  });
321  return ints;
322 }
323 
324 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
325 /// be constant operations.
326 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
327  SmallVector<int64_t> ints;
328  llvm::transform(
329  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
330  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
331  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
332  });
333  return ints;
334 }
335 
336 /// Convert `foldResults` into Values. Integer attributes are converted to
337 /// constant op.
338 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
339  ArrayRef<OpFoldResult> foldResults) {
340  SmallVector<Value> values;
341  llvm::transform(foldResults, std::back_inserter(values),
342  [&](OpFoldResult foldResult) {
343  if (auto attr = foldResult.dyn_cast<Attribute>())
344  return builder
345  .create<arith::ConstantIndexOp>(
346  loc, cast<IntegerAttr>(attr).getInt())
347  .getResult();
348 
349  return foldResult.get<Value>();
350  });
351  return values;
352 }
353 
354 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
355  if (value.getDefiningOp<vector::VectorScaleOp>())
356  return 1;
357  auto mul = value.getDefiningOp<arith::MulIOp>();
358  if (!mul)
359  return {};
360  auto lhs = mul.getLhs();
361  auto rhs = mul.getRhs();
362  if (lhs.getDefiningOp<vector::VectorScaleOp>())
363  return getConstantIntValue(rhs);
364  if (rhs.getDefiningOp<vector::VectorScaleOp>())
365  return getConstantIntValue(lhs);
366  return {};
367 }
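// Illustrative example (hypothetical IR): the helper returns 4 for %n below,
// 1 for %vs itself, and std::nullopt for anything else:
//
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index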
368 
369 //===----------------------------------------------------------------------===//
370 // CombiningKindAttr
371 //===----------------------------------------------------------------------===//
372 
373 namespace mlir {
374 namespace vector {
375 namespace detail {
376 struct BitmaskEnumStorage : public AttributeStorage {
377  using KeyTy = uint64_t;
378 
379  BitmaskEnumStorage(KeyTy val) : value(val) {}
380 
381  bool operator==(const KeyTy &key) const { return value == key; }
382 
383  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
384  const KeyTy &key) {
385  return new (allocator.allocate<BitmaskEnumStorage>())
386  BitmaskEnumStorage(key);
387  }
388 
389  KeyTy value = 0;
390 };
391 } // namespace detail
392 } // namespace vector
393 } // namespace mlir
394 
395 //===----------------------------------------------------------------------===//
396 // VectorDialect
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 /// This class defines the interface for handling inlining with vector dialect
401 /// operations.
402 struct VectorInlinerInterface : public DialectInlinerInterface {
403  using DialectInlinerInterface::DialectInlinerInterface;
404 
405  /// All vector dialect ops can be inlined.
406  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
407  return true;
408  }
409 };
410 } // namespace
411 
412 void VectorDialect::initialize() {
413  addAttributes<
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
416  >();
417 
418  addOperations<
419 #define GET_OP_LIST
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
421  >();
422 
423  addInterfaces<VectorInlinerInterface>();
424 
425  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
427  YieldOp>();
428  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
429  TransferWriteOp>();
430  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432 }
433 
434 /// Materialize a single constant operation from a given attribute value with
435 /// the desired resultant type.
436 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
437  Attribute value, Type type,
438  Location loc) {
439  return arith::ConstantOp::materialize(builder, value, type, loc);
440 }
441 
442 Type vector::getVectorSubscriptType(Builder &builder) {
443  return builder.getIntegerType(64);
444 }
445 
446 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
447  ArrayRef<int64_t> values) {
448  return builder.getI64ArrayAttr(values);
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // MultiDimReductionOp
453 //===----------------------------------------------------------------------===//
454 
455 void vector::MultiDimReductionOp::build(OpBuilder &builder,
456  OperationState &result, Value source,
457  Value acc, ArrayRef<bool> reductionMask,
458  CombiningKind kind) {
459  SmallVector<int64_t> reductionDims;
460  for (const auto &en : llvm::enumerate(reductionMask))
461  if (en.value())
462  reductionDims.push_back(en.index());
463  build(builder, result, kind, source, acc, reductionDims);
464 }
465 
466 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
467  // Single parallel dim, this is a noop.
468  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
469  return getSource();
470  return {};
471 }
472 
473 std::optional<SmallVector<int64_t, 4>>
474 MultiDimReductionOp::getShapeForUnroll() {
475  return llvm::to_vector<4>(getSourceVectorType().getShape());
476 }
477 
478 LogicalResult MultiDimReductionOp::verify() {
479  SmallVector<int64_t> targetShape;
480  SmallVector<bool> scalableDims;
481  Type inferredReturnType;
482  auto sourceScalableDims = getSourceVectorType().getScalableDims();
483  for (auto [dimIdx, dimSize] :
484  llvm::enumerate(getSourceVectorType().getShape()))
485  if (!llvm::any_of(getReductionDims(),
486  [dimIdx = dimIdx](int64_t reductionDimIdx) {
487  return reductionDimIdx == static_cast<int64_t>(dimIdx);
488  })) {
489  targetShape.push_back(dimSize);
490  scalableDims.push_back(sourceScalableDims[dimIdx]);
491  }
492  // TODO: update to also allow 0-d vectors when available.
493  if (targetShape.empty())
494  inferredReturnType = getSourceVectorType().getElementType();
495  else
496  inferredReturnType = VectorType::get(
497  targetShape, getSourceVectorType().getElementType(), scalableDims);
498  if (getType() != inferredReturnType)
499  return emitOpError() << "destination type " << getType()
500  << " is incompatible with source type "
501  << getSourceVectorType();
502 
503  return success();
504 }
505 
506 /// Returns the mask type expected by this operation.
507 Type MultiDimReductionOp::getExpectedMaskType() {
508  auto vecType = getSourceVectorType();
509  return VectorType::get(vecType.getShape(),
510  IntegerType::get(vecType.getContext(), /*width=*/1),
511  vecType.getScalableDims());
512 }
513 
514 namespace {
515 // Only unit dimensions that are being reduced are folded. If the dimension is
516 // unit, but not reduced, it is not folded, thereby keeping the output type the
517 // same. If not all dimensions which are reduced are of unit dimension, this
518 // transformation does nothing. This is just a generalization of
519 // ElideSingleElementReduction for ReduceOp.
520 struct ElideUnitDimsInMultiDimReduction
521  : public OpRewritePattern<MultiDimReductionOp> {
522  using OpRewritePattern::OpRewritePattern;
523 
524  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
525  PatternRewriter &rewriter) const override {
526  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
527  for (const auto &dim : enumerate(shape)) {
528  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
529  return failure();
530  }
531 
532  // Vector mask setup.
533  OpBuilder::InsertionGuard guard(rewriter);
534  Operation *rootOp;
535  Value mask;
536  if (reductionOp.isMasked()) {
537  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
538  rootOp = reductionOp.getMaskingOp();
539  mask = reductionOp.getMaskingOp().getMask();
540  } else {
541  rootOp = reductionOp;
542  }
543 
544  Location loc = reductionOp.getLoc();
545  Value acc = reductionOp.getAcc();
546  Value cast;
547  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
548  if (mask) {
549  VectorType newMaskType =
550  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
551  dstVecType.getScalableDims());
552  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
553  }
554  cast = rewriter.create<vector::ShapeCastOp>(
555  loc, reductionOp.getDestType(), reductionOp.getSource());
556  } else {
557  // This means we are reducing all the dimensions, and all reduction
558  // dimensions are of size 1. So a simple extraction would do.
559  SmallVector<int64_t> zeroIdx(shape.size(), 0);
560  if (mask)
561  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
562  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
563  zeroIdx);
564  }
565 
566  Value result =
567  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
568  cast, /*fastmath=*/nullptr, mask);
569  rewriter.replaceOp(rootOp, result);
570  return success();
571  }
572 };
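// Illustrative rewrite (hypothetical IR): reducing only a unit dimension is
// just a shape change, so
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//       : vector<4x1xf32> to vector<4xf32>
//
// becomes
//
//   %c = vector.shape_cast %src : vector<4x1xf32> to vector<4xf32>
//   %r = arith.addf %c, %acc : vector<4xf32>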
573 } // namespace
574 
575 void MultiDimReductionOp::getCanonicalizationPatterns(
576  RewritePatternSet &results, MLIRContext *context) {
577  results.add<ElideUnitDimsInMultiDimReduction>(context);
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // ReductionOp
582 //===----------------------------------------------------------------------===//
583 
584 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
585  CombiningKind kind, Value vector,
586  arith::FastMathFlags fastMathFlags) {
587  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
588 }
589 
590 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
591  CombiningKind kind, Value vector, Value acc,
592  arith::FastMathFlags fastMathFlags) {
593  build(builder, result,
594  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
595  acc, fastMathFlags);
596 }
597 
598 LogicalResult ReductionOp::verify() {
599  // Verify for 0-D and 1-D vector.
600  int64_t rank = getSourceVectorType().getRank();
601  if (rank > 1)
602  return emitOpError("unsupported reduction rank: ") << rank;
603 
604  // Verify supported reduction kind.
605  Type eltType = getDest().getType();
606  if (!isSupportedCombiningKind(getKind(), eltType))
607  return emitOpError("unsupported reduction type '")
608  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
609  << "'";
610 
611  return success();
612 }
613 
614 // MaskableOpInterface methods.
615 
616 /// Returns the mask type expected by this operation.
617 Type ReductionOp::getExpectedMaskType() {
618  auto vecType = getSourceVectorType();
619  return VectorType::get(vecType.getShape(),
620  IntegerType::get(vecType.getContext(), /*width=*/1),
621  vecType.getScalableDims());
622 }
623 
624 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
625  OpBuilder &builder, Location loc,
626  Value vector) {
627  switch (op) {
628  case arith::AtomicRMWKind::addf:
629  case arith::AtomicRMWKind::addi:
630  return builder.create<vector::ReductionOp>(vector.getLoc(),
631  CombiningKind::ADD, vector);
632  case arith::AtomicRMWKind::mulf:
633  case arith::AtomicRMWKind::muli:
634  return builder.create<vector::ReductionOp>(vector.getLoc(),
635  CombiningKind::MUL, vector);
636  case arith::AtomicRMWKind::minimumf:
637  return builder.create<vector::ReductionOp>(vector.getLoc(),
638  CombiningKind::MINIMUMF, vector);
639  case arith::AtomicRMWKind::mins:
640  return builder.create<vector::ReductionOp>(vector.getLoc(),
641  CombiningKind::MINSI, vector);
642  case arith::AtomicRMWKind::minu:
643  return builder.create<vector::ReductionOp>(vector.getLoc(),
644  CombiningKind::MINUI, vector);
645  case arith::AtomicRMWKind::maximumf:
646  return builder.create<vector::ReductionOp>(vector.getLoc(),
647  CombiningKind::MAXIMUMF, vector);
648  case arith::AtomicRMWKind::maxs:
649  return builder.create<vector::ReductionOp>(vector.getLoc(),
650  CombiningKind::MAXSI, vector);
651  case arith::AtomicRMWKind::maxu:
652  return builder.create<vector::ReductionOp>(vector.getLoc(),
653  CombiningKind::MAXUI, vector);
654  case arith::AtomicRMWKind::andi:
655  return builder.create<vector::ReductionOp>(vector.getLoc(),
656  CombiningKind::AND, vector);
657  case arith::AtomicRMWKind::ori:
658  return builder.create<vector::ReductionOp>(vector.getLoc(),
659  CombiningKind::OR, vector);
660  // TODO: Add remaining reduction operations.
661  default:
662  (void)emitOptionalError(loc, "Reduction operation type not supported");
663  break;
664  }
665  return nullptr;
666 }
667 
668 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
669  return llvm::to_vector<4>(getSourceVectorType().getShape());
670 }
671 
672 namespace {
673 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
674  using OpRewritePattern::OpRewritePattern;
675 
676  LogicalResult matchAndRewrite(ReductionOp reductionOp,
677  PatternRewriter &rewriter) const override {
678  // Vector mask setup.
679  OpBuilder::InsertionGuard guard(rewriter);
680  auto maskableOp =
681  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
682  Operation *rootOp;
683  Value mask;
684  if (maskableOp.isMasked()) {
685  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
686  rootOp = maskableOp.getMaskingOp();
687  mask = maskableOp.getMaskingOp().getMask();
688  } else {
689  rootOp = reductionOp;
690  }
691 
692  auto vectorType = reductionOp.getSourceVectorType();
693  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
694  return failure();
695 
696  Location loc = reductionOp.getLoc();
697  Value result;
698  if (vectorType.getRank() == 0) {
699  if (mask)
700  mask = rewriter.create<ExtractElementOp>(loc, mask);
701  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
702  } else {
703  if (mask)
704  mask = rewriter.create<ExtractOp>(loc, mask, 0);
705  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
706  }
707 
708  if (Value acc = reductionOp.getAcc())
709  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
710  result, acc,
711  reductionOp.getFastmathAttr(), mask);
712 
713  rewriter.replaceOp(rootOp, result);
714  return success();
715  }
716 };
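// Illustrative rewrite (hypothetical IR): a reduction over a single element
// needs no reduction at all:
//
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
//
// becomes
//
//   %e = vector.extract %v[0] : f32 from vector<1xf32>
//   %r = arith.addf %e, %acc : f32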
717 } // namespace
718 
719 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
720  MLIRContext *context) {
721  results.add<ElideSingleElementReduction>(context);
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // ContractionOp
726 //===----------------------------------------------------------------------===//
727 
728 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
729  Value lhs, Value rhs, Value acc,
730  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
731  ArrayRef<IteratorType> iteratorTypes) {
732  result.addOperands({lhs, rhs, acc});
733  result.addTypes(acc.getType());
734  result.addAttribute(
735  getIndexingMapsAttrName(result.name),
736  builder.getAffineMapArrayAttr(
737  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
738  result.addAttribute(
739  getIteratorTypesAttrName(result.name),
740  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
741  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
742  return IteratorTypeAttr::get(builder.getContext(), t);
743  }))));
744 }
745 
746 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
747  Value lhs, Value rhs, Value acc,
748  ArrayAttr indexingMaps,
749  ArrayAttr iteratorTypes) {
750  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
751  ContractionOp::getDefaultKind());
752 }
753 
754 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
755  Value lhs, Value rhs, Value acc,
756  ArrayAttr indexingMaps,
757  ArrayAttr iteratorTypes, CombiningKind kind) {
758  result.addOperands({lhs, rhs, acc});
759  result.addTypes(acc.getType());
760  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
761  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
762  result.addAttribute(getKindAttrName(result.name),
763  CombiningKindAttr::get(builder.getContext(), kind));
764 }
765 
766 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
767  OpAsmParser::UnresolvedOperand lhsInfo;
768  OpAsmParser::UnresolvedOperand rhsInfo;
769  OpAsmParser::UnresolvedOperand accInfo;
770  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
771  SmallVector<Type, 2> types;
772  Type resultType;
773  auto loc = parser.getCurrentLocation();
774  DictionaryAttr dictAttr;
775  // TODO: Unify linalg op attribute parsing.
776  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
777  parser.parseComma() || parser.parseOperand(rhsInfo) ||
778  parser.parseComma() || parser.parseOperand(accInfo) ||
779  parser.parseTrailingOperandList(masksInfo) ||
780  parser.parseOptionalAttrDict(result.attributes) ||
781  parser.parseColonTypeList(types) ||
782  parser.parseKeywordType("into", resultType) ||
783  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
784  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
785  parser.resolveOperand(accInfo, resultType, result.operands) ||
786  parser.addTypeToList(resultType, result.types))
787  return failure();
788  result.attributes.append(dictAttr.getValue().begin(),
789  dictAttr.getValue().end());
790 
791  // Convert array of strings into an array of IteratorType enums. This is needed,
792  // because tests still use the old format when 'iterator_types' attribute is
793  // represented as an array of strings.
794  // TODO: Remove this conversion once tests are fixed.
795  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
796  result.attributes.get(getIteratorTypesAttrName(result.name)));
797 
798  SmallVector<Attribute> iteratorTypeAttrs;
799 
800  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
801  auto maybeIteratorType = symbolizeIteratorType(s);
802  if (!maybeIteratorType.has_value())
803  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
804 
805  iteratorTypeAttrs.push_back(
806  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
807  }
808  result.attributes.set(getIteratorTypesAttrName(result.name),
809  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
810 
811  if (!result.attributes.get(getKindAttrName(result.name))) {
812  result.addAttribute(
813  getKindAttrName(result.name),
814  CombiningKindAttr::get(result.getContext(),
815  ContractionOp::getDefaultKind()));
816  }
817  if (masksInfo.empty())
818  return success();
819  if (masksInfo.size() != 2)
820  return parser.emitError(parser.getNameLoc(),
821  "expected zero or exactly 2 vector mask operands");
822  auto lhsType = llvm::cast<VectorType>(types[0]);
823  auto rhsType = llvm::cast<VectorType>(types[1]);
824  auto maskElementType = parser.getBuilder().getI1Type();
825  std::array<VectorType, 2> maskTypes = {
826  VectorType::Builder(lhsType).setElementType(maskElementType),
827  VectorType::Builder(rhsType).setElementType(maskElementType)};
828  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
829  return failure();
830  return success();
831 }
832 
833 void ContractionOp::print(OpAsmPrinter &p) {
834  // TODO: Unify printing code with linalg ops.
835  auto attrNames = getTraitAttrNames();
836  llvm::StringSet<> traitAttrsSet;
837  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
838  SmallVector<NamedAttribute> attrs;
839  for (auto attr : (*this)->getAttrs()) {
840  if (attr.getName() == getIteratorTypesAttrName()) {
841  auto iteratorTypes =
842  llvm::cast<ArrayAttr>(attr.getValue())
843  .getAsValueRange<IteratorTypeAttr, IteratorType>();
844  // Convert IteratorType enums into the string representation. This is
845  // needed, because tests still use the old format when 'iterator_types'
846  // attribute is represented as an array of strings.
847  // TODO: Remove this conversion once tests are fixed.
848  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
849  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
850  return StringAttr::get(getContext(), stringifyIteratorType(t));
851  }));
852 
853  attrs.emplace_back(getIteratorTypesAttrName(),
854  ArrayAttr::get(getContext(), iteratorTypeNames));
855  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
856  attrs.push_back(attr);
857  }
858 
859  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
860  p << " " << dictAttr << " " << getLhs() << ", ";
861  p << getRhs() << ", " << getAcc();
862 
863  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
864  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
865  << getResultType();
866 }
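// Illustrative example of the assembly format handled by the parser and
// printer above (a matrix-multiply style contraction, hypothetical IR):
//
//   #map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
//   #map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
//   #map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
//   %c = vector.contract {indexing_maps = [#map_a, #map_b, #map_c],
//                         iterator_types = ["parallel", "parallel", "reduction"],
//                         kind = #vector.kind<add>}
//       %a, %b, %acc : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>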
867 
868 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
869  const std::vector<std::pair<int64_t, int64_t>> &map) {
870  for (auto &dimPair : map) {
871  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
872  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
873  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
874  return false;
875  }
876  return true;
877 }
878 
879 static LogicalResult verifyOutputShape(
880  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
881  Type resType,
882  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
883  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
884  DenseSet<int64_t> lhsContractingDimSet;
885  DenseSet<int64_t> rhsContractingDimSet;
886  for (auto &dimPair : contractingDimMap) {
887  lhsContractingDimSet.insert(dimPair.first);
888  rhsContractingDimSet.insert(dimPair.second);
889  }
890  DenseSet<int64_t> rhsBatchDimSet;
891  for (auto &dimPair : batchDimMap)
892  rhsBatchDimSet.insert(dimPair.second);
893 
894  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
895  SmallVector<int64_t, 4> expectedResultDims;
896  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
897  if (lhsContractingDimSet.count(i) > 0)
898  continue;
899  expectedResultDims.push_back(lhsType.getDimSize(i));
900  }
901 
902  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
903  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
904  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
905  continue;
906  expectedResultDims.push_back(rhsType.getDimSize(i));
907  }
908 
909  // Verify 'expectedResultDims'.
910  if (expectedResultDims.empty()) {
911  // No batch or free dimension implies a scalar result.
912  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
913  return op.emitOpError("invalid accumulator/result vector shape");
914  } else {
915  // At least one batch or free dimension implies a vector result.
916  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
917  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
918  if (!resVectorType || !accVectorType)
919  return op.emitOpError("invalid accumulator/result vector shape");
920 
921  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
922  // types fully define the result vector type. This assumes the affine maps
923  // are well-formed, which must have been verified already.
924  MLIRContext *ctx = op.getContext();
925  AffineMap lhsMap = op.getIndexingMapsArray()[0];
926  AffineMap rhsMap = op.getIndexingMapsArray()[1];
927  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
928  return op.emitOpError(
929  "expected all dimensions to be either a LHS or a RHS dimension");
930  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
931  for (auto pair :
932  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
933  VectorType v = pair.first;
934  auto map = pair.second;
935  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
936  unsigned pos = map.getDimPosition(idx);
937  if (!extents[pos])
938  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
939  }
940  }
941  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
942  return op.emitOpError("expected all dimensions to get an extent as "
943  "either a LHS or a RHS dimension");
944 
945  AffineMap resMap = op.getIndexingMapsArray()[2];
946  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
947  /*symbolCount=*/0, extents, ctx);
948  // Compose the resMap with the extentsMap, which is a constant map.
949  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
950  assert(llvm::all_of(expectedMap.getResults(),
951  llvm::IsaPred<AffineConstantExpr>) &&
952  "expected constant extent along all dimensions.");
953  // Extract the expected shape and build the type.
954  auto expectedShape = llvm::to_vector<4>(
955  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
956  return cast<AffineConstantExpr>(e).getValue();
957  }));
958  auto expected =
959  VectorType::get(expectedShape, resVectorType.getElementType(),
960  resVectorType.getScalableDims());
961  if (resVectorType != expected || accVectorType != expected)
962  return op.emitOpError(
963  "invalid accumulator/result vector shape, expected: ")
964  << expected;
965  }
966  return success();
967 }
968 
969 LogicalResult ContractionOp::verify() {
970  VectorType lhsType = getLhsType();
971  VectorType rhsType = getRhsType();
972  Type accType = getAccType();
973  Type resType = getResultType();
974 
975  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
976  if (!lhsType.getElementType().isSignlessInteger())
977  return emitOpError("only supports signless integer types");
978  }
979 
980  // Verify that an indexing map was specified for each vector operand.
981  if (getIndexingMapsArray().size() != 3)
982  return emitOpError("expected an indexing map for each vector operand");
983 
984  // Verify that each index map has 'numIterators' inputs, no symbols, and
985  // that the number of map outputs equals the rank of its associated
986  // vector operand.
987  unsigned numIterators = getIteratorTypes().getValue().size();
988  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
989  auto index = it.index();
990  auto map = it.value();
991  if (map.getNumSymbols() != 0)
992  return emitOpError("expected indexing map ")
993  << index << " to have no symbols";
994  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
995  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
996  // Verify that the map has the right number of inputs, outputs, and indices.
997  // This also correctly accounts for (..) -> () for rank-0 results.
998  if (map.getNumDims() != numIterators)
999  return emitOpError("expected indexing map ")
1000  << index << " to have " << numIterators << " number of inputs";
1001  if (map.getNumResults() != rank)
1002  return emitOpError("expected indexing map ")
1003  << index << " to have " << rank << " number of outputs";
1004  if (!map.isProjectedPermutation())
1005  return emitOpError("expected indexing map ")
1006  << index << " to be a projected permutation of its inputs";
1007  }
1008 
1009  auto contractingDimMap = getContractingDimMap();
1010  auto batchDimMap = getBatchDimMap();
1011 
1012  // Verify at least one contracting dimension pair was specified.
1013  if (contractingDimMap.empty())
1014  return emitOpError("expected at least one contracting dimension pair");
1015 
1016  // Verify contracting dimension map was properly constructed.
1017  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1018  return emitOpError("invalid contracting dimension map");
1019 
1020  // Verify batch dimension map was properly constructed.
1021  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1022  return emitOpError("invalid batch dimension map");
1023 
1024  // Verify 'accType' and 'resType' shape.
1025  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1026  contractingDimMap, batchDimMap)))
1027  return failure();
1028 
1029  // Verify supported combining kind.
1030  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1031  auto elementType = vectorType ? vectorType.getElementType() : resType;
1032  if (!isSupportedCombiningKind(getKind(), elementType))
1033  return emitOpError("unsupported contraction type");
1034 
1035  return success();
1036 }
1037 
1038 // MaskableOpInterface methods.
1039 
1040 /// Returns the mask type expected by this operation. Mostly used for
1041 /// verification purposes. It requires the operation to be vectorized.
1042 Type ContractionOp::getExpectedMaskType() {
1043  auto indexingMaps = this->getIndexingMapsArray();
1044  AffineMap lhsIdxMap = indexingMaps[0];
1045  AffineMap rhsIdxMap = indexingMaps[1];
1046  VectorType lhsType = this->getLhsType();
1047  VectorType rhsType = this->getRhsType();
1048 
1049  unsigned numVecDims = lhsIdxMap.getNumDims();
1050  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1051  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1052 
1053  // Using the information in the indexing maps, extract the size of each
1054  // dimension in the vector.contract operation from the two input operands.
1055  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1056  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1057  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1058  lhsType.getScalableDims()[dimIdx];
1059  }
1060  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1061  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1062  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1063  rhsType.getScalableDims()[dimIdx];
1064  }
1065 
1066  assert(!ShapedType::isDynamicShape(maskShape) &&
1067  "Mask shape couldn't be computed");
1068 
1069  return VectorType::get(maskShape,
1070  IntegerType::get(lhsType.getContext(), /*width=*/1),
1071  maskShapeScalableDims);
1072 }
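// Illustrative example: for a matmul-style contraction where the LHS is
// vector<4x8xf32> indexed by (d0, d2) and the RHS is vector<8x16xf32> indexed
// by (d2, d1), the iteration space (d0, d1, d2) has sizes (4, 16, 8), so the
// expected mask type is vector<4x16x8xi1>.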
1073 
1074 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1075  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1076  getIteratorTypesAttrName(), getKindAttrName()};
1077 }
1078 
1079 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1080  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1081  if (targetExpr == map.getResult(i))
1082  return i;
1083  return -1;
1084 }
1085 
1086 static std::vector<std::pair<int64_t, int64_t>>
1087 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1088  IteratorType targetIteratorType, MLIRContext *context) {
1089  std::vector<std::pair<int64_t, int64_t>> dimMap;
1090  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1091  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1092  if (iteratorType != targetIteratorType)
1093  continue;
1094  // Search lhs/rhs map results for 'targetExpr'.
1095  auto targetExpr = getAffineDimExpr(it.index(), context);
1096  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1097  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1098  if (lhsDim >= 0 && rhsDim >= 0)
1099  dimMap.emplace_back(lhsDim, rhsDim);
1100  }
1101  return dimMap;
1102 }
1103 
1104 void ContractionOp::getIterationBounds(
1105  SmallVectorImpl<int64_t> &iterationBounds) {
1106  auto lhsShape = getLhsType().getShape();
1107  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1108  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1109  SmallVector<int64_t, 2> iterationShape;
1110  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1111  // Search lhs/rhs map results for 'targetExpr'.
1112  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1113  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1114  if (iteratorType == IteratorType::reduction) {
1115  // Get reduction dim size from lhs shape (same size in rhsShape).
1116  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1117  assert(lhsDimIndex >= 0);
1118  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1119  continue;
1120  }
1121  // Get parallel dimension size from result shape.
1122  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1123  assert(resDimIndex >= 0);
1124  assert(resVectorType != nullptr);
1125  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1126  }
1127 }
1128 
1129 void ContractionOp::getIterationIndexMap(
1130  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1131  unsigned numMaps = getIndexingMapsArray().size();
1132  iterationIndexMap.resize(numMaps);
1133  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1134  auto index = it.index();
1135  auto map = it.value();
1136  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1137  auto dim = cast<AffineDimExpr>(map.getResult(i));
1138  iterationIndexMap[index][dim.getPosition()] = i;
1139  }
1140  }
1141 }
1142 
1143 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1144  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1145  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1146  getContext());
1147 }
1148 
1149 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1150  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1151  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1152  getContext());
1153 }
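// Illustrative example: with indexing maps (d0, d1, d2) -> (d0, d2) for the
// LHS and (d0, d1, d2) -> (d2, d1) for the RHS, and iterator types
// ["parallel", "parallel", "reduction"], the contracting dim map is {(1, 0)}
// (LHS dim 1 pairs with RHS dim 0) and the batch dim map is empty, since no
// parallel dimension appears in both operands.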
1154 
1155 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1156  SmallVector<int64_t, 4> shape;
1157  getIterationBounds(shape);
1158  return shape;
1159 }
1160 
1161 /// Return a fused vector::ContractionOp which represents a pattern such as:
1162 ///
1163 /// ```mlir
1164 /// %c0 = vector.constant 0: ...
1165 /// %c = vector.contract %a, %b, %c0: ...
1166 /// %e = add %c, %d: ...
1167 /// ```
1168 ///
1169 /// by:
1170 ///
1171 /// ```mlir
1172 /// %e = vector.contract %a, %b, %d: ...
1173 /// ```
1174 ///
1175 /// Return null if the canonicalization does not apply.
1176 // TODO: This should be a folding of Add into Contract in core but while they
1177 // live in different dialects, it is not possible without unnatural
1178 // dependencies.
1179 template <typename AddOpType>
1180 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1181  using OpRewritePattern<AddOpType>::OpRewritePattern;
1182 
1183  LogicalResult matchAndRewrite(AddOpType addOp,
1184  PatternRewriter &rewriter) const override {
1185  auto canonicalize = [&](Value maybeContraction,
1186  Value otherOperand) -> vector::ContractionOp {
1187  vector::ContractionOp contractionOp =
1188  dyn_cast_or_null<vector::ContractionOp>(
1189  maybeContraction.getDefiningOp());
1190  if (!contractionOp)
1191  return vector::ContractionOp();
1192  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1193  contractionOp.getAcc().getDefiningOp())) {
1194  if (maybeZero.getValue() ==
1195  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1196  IRMapping bvm;
1197  bvm.map(contractionOp.getAcc(), otherOperand);
1198  auto newContraction =
1199  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1200  rewriter.replaceOp(addOp, newContraction.getResult());
1201  return newContraction;
1202  }
1203  }
1204  return vector::ContractionOp();
1205  };
1206 
1207  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1208  vector::ContractionOp contract = canonicalize(a, b);
1209  contract = contract ? contract : canonicalize(b, a);
1210  return contract ? success() : failure();
1211  }
1212 };
1213 
1214 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215  MLIRContext *context) {
1216  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1217  CanonicalizeContractAdd<arith::AddFOp>>(context);
1218 }
1219 
1220 //===----------------------------------------------------------------------===//
1221 // ExtractElementOp
1222 //===----------------------------------------------------------------------===//
1223 
1224 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1225  SetIntRangeFn setResultRanges) {
1226  setResultRanges(getResult(), argRanges.front());
1227 }
1228 
1229 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1230  Value source) {
1231  result.addOperands({source});
1232  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1233 }
1234 
1235 LogicalResult vector::ExtractElementOp::verify() {
1236  VectorType vectorType = getSourceVectorType();
1237  if (vectorType.getRank() == 0) {
1238  if (getPosition())
1239  return emitOpError("expected position to be empty with 0-D vector");
1240  return success();
1241  }
1242  if (vectorType.getRank() != 1)
1243  return emitOpError("unexpected >1 vector rank");
1244  if (!getPosition())
1245  return emitOpError("expected position for 1-D vector");
1246  return success();
1247 }
1248 
1249 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1250  // Skip the 0-D vector here now.
1251  if (!adaptor.getPosition())
1252  return {};
1253 
1254  // Fold extractelement (splat X) -> X.
1255  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1256  return splat.getInput();
1257 
1258  // Fold extractelement(broadcast(X)) -> X.
1259  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1260  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1261  return broadcast.getSource();
1262 
1263  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1264  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1265  if (!pos || !src)
1266  return {};
1267 
1268  auto srcElements = src.getValues<Attribute>();
1269 
1270  uint64_t posIdx = pos.getInt();
1271  if (posIdx >= srcElements.size())
1272  return {};
1273 
1274  return srcElements[posIdx];
1275 }
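// Illustrative fold (hypothetical IR): extracting from a splat (or a scalar
// broadcast) yields the scalar operand directly when the position is constant:
//
//   %c0 = arith.constant 0 : i32
//   %s  = vector.splat %x : vector<8xf32>
//   %e  = vector.extractelement %s[%c0 : i32] : vector<8xf32>   // folds to %x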
1276 
1277 //===----------------------------------------------------------------------===//
1278 // ExtractOp
1279 //===----------------------------------------------------------------------===//
1280 
1281 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1282  SetIntRangeFn setResultRanges) {
1283  setResultRanges(getResult(), argRanges.front());
1284 }
1285 
1286 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1287  Value source, int64_t position) {
1288  build(builder, result, source, ArrayRef<int64_t>{position});
1289 }
1290 
1291 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1292  Value source, OpFoldResult position) {
1293  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1294 }
1295 
1296 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1297  Value source, ArrayRef<int64_t> position) {
1298  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1299  builder.getDenseI64ArrayAttr(position));
1300 }
1301 
1302 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1303  Value source, ArrayRef<OpFoldResult> position) {
1304  SmallVector<int64_t> staticPos;
1305  SmallVector<Value> dynamicPos;
1306  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1307  build(builder, result, source, dynamicPos,
1308  builder.getDenseI64ArrayAttr(staticPos));
1309 }
1310 
1311 LogicalResult
1312 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1313  ExtractOp::Adaptor adaptor,
1314  SmallVectorImpl<Type> &inferredReturnTypes) {
1315  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1316  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1317  vectorType.getRank()) {
1318  inferredReturnTypes.push_back(vectorType.getElementType());
1319  } else {
1320  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1321  vectorType.getRank());
1322  inferredReturnTypes.push_back(VectorType::get(
1323  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1324  vectorType.getScalableDims().drop_front(n)));
1325  }
1326  return success();
1327 }
1328 
1329 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1330  // Allow extracting 1-element vectors instead of scalars.
1331  auto isCompatible = [](TypeRange l, TypeRange r) {
1332  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1333  return vectorType && vectorType.getShape().equals({1}) &&
1334  vectorType.getElementType() == r.front();
1335  };
1336  if (l.size() == 1 && r.size() == 1 &&
1337  (isCompatible(l, r) || isCompatible(r, l)))
1338  return true;
1339  return l == r;
1340 }
1341 
1342 LogicalResult vector::ExtractOp::verify() {
1343  // Note: This check must come before getMixedPosition() to prevent a crash.
1344  auto dynamicMarkersCount =
1345  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1346  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1347  return emitOpError(
1348  "mismatch between dynamic and static positions (kDynamic marker but no "
1349  "corresponding dynamic position) -- this can only happen due to an "
1350  "incorrect fold/rewrite");
1351  auto position = getMixedPosition();
1352  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1353  return emitOpError(
1354  "expected position attribute of rank no greater than vector rank");
1355  for (auto [idx, pos] : llvm::enumerate(position)) {
1356  if (pos.is<Attribute>()) {
1357  int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
1358  if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1359  return emitOpError("expected position attribute #")
1360  << (idx + 1)
1361  << " to be a non-negative integer smaller than the "
1362  "corresponding vector dimension";
1363  }
1364  }
1365  }
1366  return success();
1367 }
1368 
1369 template <typename IntType>
1370 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1371  return llvm::to_vector<4>(llvm::map_range(
1372  arrayAttr.getAsRange<IntegerAttr>(),
1373  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1374 }
1375 
1376 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1377 /// positions.
1378 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1379  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1380  return failure();
1381 
1382  // TODO: Canonicalization for dynamic position not implemented yet.
1383  if (extractOp.hasDynamicPosition())
1384  return failure();
1385 
1386  SmallVector<int64_t> globalPosition;
1387  ExtractOp currentOp = extractOp;
1388  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1389  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1390  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1391  currentOp = nextOp;
1392  // TODO: Canonicalization for dynamic position not implemented yet.
1393  if (currentOp.hasDynamicPosition())
1394  return failure();
1395  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1396  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1397  }
1398  extractOp.setOperand(0, currentOp.getVector());
1399  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1400  OpBuilder b(extractOp.getContext());
1401  std::reverse(globalPosition.begin(), globalPosition.end());
1402  extractOp.setStaticPosition(globalPosition);
1403  return success();
1404 }
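// Illustrative fold (hypothetical IR): the chain
//
//   %a = vector.extract %v[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[2] : vector<4xf32> from vector<3x4xf32>
//
// is rewritten in place so that %b extracts directly from %v:
//
//   %b = vector.extract %v[1, 2] : vector<4xf32> from vector<2x3x4xf32>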
1405 
1406 namespace {
1407 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1408 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1409 /// Compose TransposeOp permutations as we walk back.
1410 /// This helper class keeps an updated extraction position `extractPosition`
1411 /// with extra trailing sentinels.
1412 /// The sentinels encode the internal transposition status of the result vector.
1413 /// As we iterate, extractPosition is permuted and updated.
1414 class ExtractFromInsertTransposeChainState {
1415 public:
1416  ExtractFromInsertTransposeChainState(ExtractOp e);
1417 
1418  /// Iterate over producing insert and transpose ops until we find a fold.
1419  Value fold();
1420 
1421 private:
1422  /// Return true if the vector at position `a` is contained within the vector
1423  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1424  /// is a prefix of `b`.
1425  template <typename ContainerA, typename ContainerB>
1426  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1427  return a.size() <= b.size() &&
1428  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1429  }
1430 
1431  /// Return true if the vector at position `a` intersects the vector at
1432  /// position `b`. Under insert/extract semantics, this is the same as equality
1433  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1434  /// Comparison is on the common prefix (i.e. zip).
1435  template <typename ContainerA, typename ContainerB>
1436  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1437  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1438  if (elemA < 0 || elemB < 0)
1439  continue;
1440  if (elemA != elemB)
1441  return false;
1442  }
1443  return true;
1444  }
1445 
1446  /// Folding is only possible in the absence of an internal permutation in the
1447  /// result vector.
1448  bool canFold() {
1449  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1450  }
1451 
1452  // Helper to get the next defining op of interest.
1453  void updateStateForNextIteration(Value v) {
1454  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1455  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1456  };
1457 
1458  // Case 1. If we hit a transpose, just compose the map and iterate.
1459  // Invariant: insert + transpose do not change rank, we can always compose.
1460  LogicalResult handleTransposeOp();
1461 
1462  // Case 2: the insert position matches extractPosition exactly, early return.
1463  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1464 
1465  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1466  /// portion of the source of the insert.
1467  /// Example:
1468  /// ```
1469  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1470  /// // extractPosition == [1, 2, 3]
1471  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1472  /// // can fold to vector.extract %source[0, 3]
1473  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1474  /// ```
1475  /// To traverse through %source, we need to set the leading dims to 0 and
1476  /// drop the extra leading dims.
1477  /// This method updates the internal state.
1478  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1479 
1480  /// Try to fold in place to extract(source, extractPosition) and return the
1481  /// folded result. Return null if folding is not possible (e.g. due to an
1482  /// internal transposition in the result).
1483  Value tryToFoldExtractOpInPlace(Value source);
1484 
1485  ExtractOp extractOp;
1486  int64_t vectorRank;
1487  int64_t extractedRank;
1488 
1489  InsertOp nextInsertOp;
1490  TransposeOp nextTransposeOp;
1491 
1492  /// Sentinel values that encode the internal permutation status of the result.
1493  /// They are set to (-1, ... , -k) at the beginning and appended to
1494  /// `extractPosition`.
1495  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1496  /// ensure that there is no internal transposition.
1497  /// Internal transposition cannot be accounted for with a folding pattern.
1498  // TODO: We could relax the internal transposition with an extra transposition
1499  // operation in a future canonicalizer.
1500  SmallVector<int64_t> sentinels;
1501  SmallVector<int64_t> extractPosition;
1502 };
1503 } // namespace
1504 
1505 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1506  ExtractOp e)
1507  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1508  extractedRank(extractOp.getNumIndices()) {
1509  assert(vectorRank >= extractedRank && "Extracted position overflow");
1510  sentinels.reserve(vectorRank - extractedRank);
1511  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1512  sentinels.push_back(-(i + 1));
1513  extractPosition.assign(extractOp.getStaticPosition().begin(),
1514  extractOp.getStaticPosition().end());
1515  llvm::append_range(extractPosition, sentinels);
1516 }
1517 
1518 // Case 1. If we hit a transpose, just compose the map and iterate.
1519 // Invariant: insert + transpose do not change rank, we can always compose.
1520 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1521  // TODO: Canonicalization for dynamic position not implemented yet.
1522  if (extractOp.hasDynamicPosition())
1523  return failure();
1524 
1525  if (!nextTransposeOp)
1526  return failure();
1527  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1528  nextTransposeOp.getPermutation(), extractOp.getContext()));
1529  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1530  return success();
1531 }
1532 
1533 // Case 2: the insert position matches extractPosition exactly, early return.
1534 LogicalResult
1535 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1536  Value &res) {
1537  // TODO: Canonicalization for dynamic position not implemented yet.
1538  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1539  return failure();
1540 
1541  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1542  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1543  return failure();
1544  // Case 2.a. early-exit fold.
1545  res = nextInsertOp.getSource();
1546  // Case 2.b. if internal transposition is present, canFold will be false.
1547  return success(canFold());
1548 }
1549 
1550 /// Case 3: if inserted position is a prefix of extractPosition,
1551 /// extract a portion of the source of the insertion.
1552 /// This method updates the internal state.
1553 LogicalResult
1554 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1555  // TODO: Canonicalization for dynamic position not implemented yet.
1556  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1557  return failure();
1558 
1559  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1560  if (!isContainedWithin(insertedPos, extractPosition))
1561  return failure();
1562  // Set leading dims to zero.
1563  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1564  // Drop extra leading dims.
1565  extractPosition.erase(extractPosition.begin(),
1566  extractPosition.begin() + insertedPos.size());
1567  extractedRank = extractPosition.size() - sentinels.size();
1568  // Case 3.a. early-exit fold (break and delegate to post-while path).
1569  res = nextInsertOp.getSource();
1570  // Case 3.b. if internal transposition is present, canFold will be false.
1571  return success();
1572 }
1573 
1574 /// Try to fold in place to extract(source, extractPosition) and return the
1575 /// folded result. Return null if folding is not possible (e.g. due to an
1576 /// internal transposition in the result).
1577 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1578  Value source) {
1579  // TODO: Canonicalization for dynamic position not implemented yet.
1580  if (extractOp.hasDynamicPosition())
1581  return Value();
1582 
1583  // If we can't fold (either internal transposition, or nothing to fold), bail.
1584  bool nothingToFold = (source == extractOp.getVector());
1585  if (nothingToFold || !canFold())
1586  return Value();
1587 
1588  // Otherwise, fold by updating the op inplace and return its result.
1589  OpBuilder b(extractOp.getContext());
1590  extractOp.setStaticPosition(
1591  ArrayRef(extractPosition).take_front(extractedRank));
1592  extractOp.getVectorMutable().assign(source);
1593  return extractOp.getResult();
1594 }
1595 
1596 /// Iterate over producing insert and transpose ops until we find a fold.
1597 Value ExtractFromInsertTransposeChainState::fold() {
1598  // TODO: Canonicalization for dynamic position not implemented yet.
1599  if (extractOp.hasDynamicPosition())
1600  return Value();
1601 
1602  Value valueToExtractFrom = extractOp.getVector();
1603  updateStateForNextIteration(valueToExtractFrom);
1604  while (nextInsertOp || nextTransposeOp) {
1605  // Case 1. If we hit a transpose, just compose the map and iterate.
1606  // Invariant: insert + transpose do not change rank, we can always compose.
1607  if (succeeded(handleTransposeOp())) {
1608  valueToExtractFrom = nextTransposeOp.getVector();
1609  updateStateForNextIteration(valueToExtractFrom);
1610  continue;
1611  }
1612 
1613  Value result;
1614  // Case 2: the positions match exactly.
1615  if (succeeded(handleInsertOpWithMatchingPos(result)))
1616  return result;
1617 
1618  // Case 3: if the inserted position is a prefix of extractPosition, we can
1619  // just extract a portion of the source of the insert.
1620  if (succeeded(handleInsertOpWithPrefixPos(result)))
1621  return tryToFoldExtractOpInPlace(result);
1622 
1623  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1624  // values. This is a more difficult case and we bail.
1625  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1626  if (isContainedWithin(extractPosition, insertedPos) ||
1627  intersectsWhereNonNegative(extractPosition, insertedPos))
1628  return Value();
1629 
1630  // Case 5: No intersection, we forward the extract to insertOp.dest().
1631  valueToExtractFrom = nextInsertOp.getDest();
1632  updateStateForNextIteration(valueToExtractFrom);
1633  }
1634  // If after all this we can fold, go for it.
1635  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1636 }
1637 
1638 /// Returns true if the operation has a 0-D vector type operand or result.
1639 static bool hasZeroDimVectors(Operation *op) {
1640  auto hasZeroDimVectorType = [](Type type) -> bool {
1641  auto vecType = dyn_cast<VectorType>(type);
1642  return vecType && vecType.getRank() == 0;
1643  };
1644 
1645  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1646  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1647 }
1648 
1649 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
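/// Illustrative example (shapes and SSA names assumed for exposition):
///   %b = vector.broadcast %s : f32 to vector<2x3xf32>
///   %e = vector.extract %b[0, 1] : f32 from vector<2x3xf32>
/// is expected to fold directly to %s.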
1650 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1651  // TODO: Canonicalization for dynamic position not implemented yet.
1652  if (extractOp.hasDynamicPosition())
1653  return Value();
1654 
1655  Operation *defOp = extractOp.getVector().getDefiningOp();
1656  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1657  return Value();
1658 
1659  Value source = defOp->getOperand(0);
1660  if (extractOp.getType() == source.getType())
1661  return source;
1662  auto getRank = [](Type type) {
1663  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1664  : 0;
1665  };
1666 
1667  // If splat or broadcast from a scalar, just return the source scalar.
1668  unsigned broadcastSrcRank = getRank(source.getType());
1669  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1670  return source;
1671 
1672  unsigned extractResultRank = getRank(extractOp.getType());
1673  if (extractResultRank >= broadcastSrcRank)
1674  return Value();
1675  // Check that the dimensions of the result haven't been broadcasted.
1676  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1677  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1678  if (extractVecType && broadcastVecType &&
1679  extractVecType.getShape() !=
1680  broadcastVecType.getShape().take_back(extractResultRank))
1681  return Value();
1682 
1683  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1684  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1685 
1686  // Detect all the positions that come from "dim-1" broadcasting.
1687  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1688  // extract position to `0` when extracting from the source operand.
1689  llvm::SetVector<int64_t> broadcastedUnitDims =
1690  broadcastOp.computeBroadcastedUnitDims();
1691  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1692  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1693  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1694  if (broadcastedUnitDims.contains(i))
1695  extractPos[i] = 0;
1696  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1697  // matching extract position when extracting from the source operand.
1698  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1699  extractPos.erase(extractPos.begin(),
1700  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1701  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1702  OpBuilder b(extractOp.getContext());
1703  extractOp.setOperand(0, source);
1704  extractOp.setStaticPosition(extractPos);
1705  return extractOp.getResult();
1706 }
1707 
1708 /// Fold extractOp coming from ShuffleOp.
1709 ///
1710 /// Example:
1711 ///
1712 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1713 /// : vector<8xf32>, vector<8xf32>
1714 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1715 /// ->
1716 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1717 ///
1718 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1719  // Dynamic positions are not folded as the resulting code would be more
1720  // complex than the input code.
1721  if (extractOp.hasDynamicPosition())
1722  return Value();
1723 
1724  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1725  if (!shuffleOp)
1726  return Value();
1727 
1728  // TODO: 0-D or multi-dimensional vectors not supported yet.
1729  if (shuffleOp.getResultVectorType().getRank() != 1)
1730  return Value();
1731 
1732  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1733  auto shuffleMask = shuffleOp.getMask();
1734  int64_t extractIdx = extractOp.getStaticPosition()[0];
1735  int64_t shuffleIdx = shuffleMask[extractIdx];
1736 
1737  // Find the shuffled vector to extract from based on the shuffle index.
1738  if (shuffleIdx < inputVecSize) {
1739  extractOp.setOperand(0, shuffleOp.getV1());
1740  extractOp.setStaticPosition({shuffleIdx});
1741  } else {
1742  extractOp.setOperand(0, shuffleOp.getV2());
1743  extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1744  }
1745 
1746  return extractOp.getResult();
1747 }
1748 
1749 // Fold extractOp with source coming from ShapeCast op.
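// Illustrative example (shapes assumed for exposition): the extract position is
// linearized in the result space and delinearized in the shape_cast source space:
//   %sc = vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>
//   %e  = vector.extract %sc[5] : f32 from vector<8xf32>
// is expected to fold to
//   %e  = vector.extract %v[1, 1] : f32 from vector<2x4xf32>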
1750 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1751  // TODO: Canonicalization for dynamic position not implemented yet.
1752  if (extractOp.hasDynamicPosition())
1753  return Value();
1754 
1755  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1756  if (!shapeCastOp)
1757  return Value();
1758 
1759  // 0-D vectors not supported.
1760  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1761  if (hasZeroDimVectors(shapeCastOp))
1762  return Value();
1763 
1764  // Get the nth dimension size starting from lowest dimension.
1765  auto getDimReverse = [](VectorType type, int64_t n) {
1766  return type.getShape().take_back(n + 1).front();
1767  };
1768  int64_t destinationRank =
1769  llvm::isa<VectorType>(extractOp.getType())
1770  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1771  : 0;
1772  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1773  return Value();
1774  if (destinationRank > 0) {
1775  auto destinationType =
1776  llvm::cast<VectorType>(extractOp.getResult().getType());
1777  for (int64_t i = 0; i < destinationRank; i++) {
1778  // The lowest dimension of the destination must match the lowest
1779  // dimension of the shapecast op source.
1780  // TODO: This case could be supported in a canonicalization pattern.
1781  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1782  getDimReverse(destinationType, i))
1783  return Value();
1784  }
1785  }
1786  // Extract the strides associated with the extract op vector source. Then use
1787  // this to calculate a linearized position for the extract.
1788  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1789  std::reverse(extractedPos.begin(), extractedPos.end());
1790  SmallVector<int64_t, 4> strides;
1791  int64_t stride = 1;
1792  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1793  strides.push_back(stride);
1794  stride *=
1795  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1796  }
1797 
1798  int64_t position = linearize(extractedPos, strides);
1799  // Then extract the strides associated to the shapeCast op vector source and
1800  // delinearize the position using those strides.
1801  SmallVector<int64_t, 4> newStrides;
1802  int64_t numDimension =
1803  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1804  stride = 1;
1805  for (int64_t i = 0; i < numDimension; i++) {
1806  newStrides.push_back(stride);
1807  stride *=
1808  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1809  }
1810  std::reverse(newStrides.begin(), newStrides.end());
1811  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1812  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1813  OpBuilder b(extractOp.getContext());
1814  extractOp.setStaticPosition(newPosition);
1815  extractOp.setOperand(0, shapeCastOp.getSource());
1816  return extractOp.getResult();
1817 }
1818 
1819 /// Fold an ExtractOp from ExtractStridedSliceOp.
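/// Illustrative example (shapes assumed for exposition): the slice offsets are
/// simply added to the extract position:
///   %s = vector.extract_strided_slice %v {offsets = [1, 2], sizes = [2, 2],
///        strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
///   %e = vector.extract %s[0, 1] : f32 from vector<2x2xf32>
/// is expected to fold to
///   %e = vector.extract %v[1, 3] : f32 from vector<4x8xf32>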
1820 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1821  // TODO: Canonicalization for dynamic position not implemented yet.
1822  if (extractOp.hasDynamicPosition())
1823  return Value();
1824 
1825  auto extractStridedSliceOp =
1826  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1827  if (!extractStridedSliceOp)
1828  return Value();
1829 
1830  // 0-D vectors not supported.
1831  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1832  if (hasZeroDimVectors(extractStridedSliceOp))
1833  return Value();
1834 
1835  // Return if 'extractStridedSliceOp' has non-unit strides.
1836  if (extractStridedSliceOp.hasNonUnitStrides())
1837  return Value();
1838 
1839  // Trim offsets for dimensions fully extracted.
1840  auto sliceOffsets =
1841  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1842  while (!sliceOffsets.empty()) {
1843  size_t lastOffset = sliceOffsets.size() - 1;
1844  if (sliceOffsets.back() != 0 ||
1845  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1846  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1847  break;
1848  sliceOffsets.pop_back();
1849  }
1850  unsigned destinationRank = 0;
1851  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1852  destinationRank = vecType.getRank();
1853  // The dimensions of the result need to be untouched by the
1854  // extractStridedSlice op.
1855  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1856  sliceOffsets.size())
1857  return Value();
1858 
1859  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1860  assert(extractedPos.size() >= sliceOffsets.size());
1861  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1862  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1863  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1864 
1865  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1866  OpBuilder b(extractOp.getContext());
1867  extractOp.setStaticPosition(extractedPos);
1868  return extractOp.getResult();
1869 }
1870 
1871 /// Fold extract_op fed from a chain of insertStridedSlice ops.
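/// Illustrative example (shapes assumed for exposition): when the extracted
/// chunk lies entirely inside the inserted chunk, extract from its source:
///   %i = vector.insert_strided_slice %src, %dst {offsets = [2, 0],
///        strides = [1, 1]} : vector<2x4xf32> into vector<8x4xf32>
///   %e = vector.extract %i[3] : vector<4xf32> from vector<8x4xf32>
/// is expected to fold to
///   %e = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>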
1872 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1873  // TODO: Canonicalization for dynamic position not implemented yet.
1874  if (extractOp.hasDynamicPosition())
1875  return Value();
1876 
1877  int64_t destinationRank =
1878  llvm::isa<VectorType>(extractOp.getType())
1879  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1880  : 0;
1881  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1882  if (!insertOp)
1883  return Value();
1884 
1885  // 0-D vectors not supported.
1886  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1887  if (hasZeroDimVectors(insertOp))
1888  return Value();
1889 
1890  while (insertOp) {
1891  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1892  insertOp.getSourceVectorType().getRank();
1893  if (destinationRank > insertOp.getSourceVectorType().getRank())
1894  return Value();
1895  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1896  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1897 
1898  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1899  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1900  }))
1901  return Value();
1902  bool disjoint = false;
1903  SmallVector<int64_t, 4> offsetDiffs;
1904  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1905  int64_t start = insertOffsets[dim];
1906  int64_t size =
1907  (dim < insertRankDiff)
1908  ? 1
1909  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1910  int64_t end = start + size;
1911  int64_t offset = extractOffsets[dim];
1912  // Check if the start of the extract offset is in the interval inserted.
1913  if (start <= offset && offset < end) {
1914  if (dim >= insertRankDiff)
1915  offsetDiffs.push_back(offset - start);
1916  continue;
1917  }
1918  disjoint = true;
1919  break;
1920  }
1921  // The extracted element chunk overlaps with the inserted vector.
1922  if (!disjoint) {
1923  // If any of the inner dimensions are only partially inserted we have a
1924  // partial overlap.
1925  int64_t srcRankDiff =
1926  insertOp.getSourceVectorType().getRank() - destinationRank;
1927  for (int64_t i = 0; i < destinationRank; i++) {
1928  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1929  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1930  insertRankDiff))
1931  return Value();
1932  }
1933  extractOp.getVectorMutable().assign(insertOp.getSource());
1934  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1935  OpBuilder b(extractOp.getContext());
1936  extractOp.setStaticPosition(offsetDiffs);
1937  return extractOp.getResult();
1938  }
1939  // If the chunk extracted is disjoint from the chunk inserted, keep
1940  // looking in the insert chain.
1941  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1942  }
1943  return Value();
1944 }
1945 
1946 /// Try to fold the extraction of a scalar from a vector defined by
1947 /// vector.from_elements. E.g.:
1948 ///
1949 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1950 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1951 /// ==> fold to %a
1952 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1953  // Dynamic extractions cannot be folded.
1954  if (extractOp.hasDynamicPosition())
1955  return {};
1956 
1957  // Look for extract(from_elements).
1958  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1959  if (!fromElementsOp)
1960  return {};
1961 
1962  // Scalable vectors are not supported.
1963  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1964  if (vecType.isScalable())
1965  return {};
1966 
1967  // Only extractions of scalars are supported.
1968  int64_t rank = vecType.getRank();
1969  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1970  if (extractOp.getType() != vecType.getElementType())
1971  return {};
1972  assert(static_cast<int64_t>(indices.size()) == rank &&
1973  "unexpected number of indices");
1974 
1975  // Compute flattened/linearized index and fold to operand.
1976  int flatIndex = 0;
1977  int stride = 1;
1978  for (int i = rank - 1; i >= 0; --i) {
1979  flatIndex += indices[i] * stride;
1980  stride *= vecType.getDimSize(i);
1981  }
1982  return fromElementsOp.getElements()[flatIndex];
1983 }
1984 
1985 OpFoldResult ExtractOp::fold(FoldAdaptor) {
1986  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
1987  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
1988  // mismatch).
1989  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
1990  return getVector();
1991  if (succeeded(foldExtractOpFromExtractChain(*this)))
1992  return getResult();
1993  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1994  return res;
1995  if (auto res = foldExtractFromBroadcast(*this))
1996  return res;
1997  if (auto res = foldExtractFromShuffle(*this))
1998  return res;
1999  if (auto res = foldExtractFromShapeCast(*this))
2000  return res;
2001  if (auto val = foldExtractFromExtractStrided(*this))
2002  return val;
2003  if (auto val = foldExtractStridedOpFromInsertChain(*this))
2004  return val;
2005  if (auto val = foldScalarExtractFromFromElements(*this))
2006  return val;
2007  return OpFoldResult();
2008 }
2009 
2010 namespace {
2011 
2012 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2013 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2014 public:
2015  using OpRewritePattern::OpRewritePattern;
2016 
2017  LogicalResult matchAndRewrite(ExtractOp extractOp,
2018  PatternRewriter &rewriter) const override {
2019  Operation *defOp = extractOp.getVector().getDefiningOp();
2020  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2021  return failure();
2022 
2023  Value source = defOp->getOperand(0);
2024  if (extractOp.getType() == source.getType())
2025  return failure();
2026  auto getRank = [](Type type) {
2027  return llvm::isa<VectorType>(type)
2028  ? llvm::cast<VectorType>(type).getRank()
2029  : 0;
2030  };
2031  unsigned broadcastSrcRank = getRank(source.getType());
2032  unsigned extractResultRank = getRank(extractOp.getType());
2033  // We only consider the case where the rank of the source is less than or
2034  // equal to the rank of the extract dst. The other cases are handled in the
2035  // folding patterns.
2036  if (extractResultRank < broadcastSrcRank)
2037  return failure();
2038 
2039  // Special case if broadcast src is a 0D vector.
2040  if (extractResultRank == 0) {
2041  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2042  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2043  return success();
2044  }
2045  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2046  extractOp, extractOp.getType(), source);
2047  return success();
2048  }
2049 };
2050 
2051 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
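// Illustrative example (values assumed for exposition):
//   %c = arith.constant dense<1.0> : vector<4x2xf32>
//   %e = vector.extract %c[1] : vector<2xf32> from vector<4x2xf32>
// is expected to be rewritten to
//   %e = arith.constant dense<1.0> : vector<2xf32>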
2052 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2053 public:
2054  using OpRewritePattern::OpRewritePattern;
2055 
2056  LogicalResult matchAndRewrite(ExtractOp extractOp,
2057  PatternRewriter &rewriter) const override {
2058  // Return if 'ExtractOp' operand is not defined by a splat vector
2059  // ConstantOp.
2060  Value sourceVector = extractOp.getVector();
2061  Attribute vectorCst;
2062  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2063  return failure();
2064  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2065  if (!splat)
2066  return failure();
2067  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2068  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2069  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2070  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2071  return success();
2072  }
2073 };
2074 
2075 // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
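// Illustrative example (values assumed for exposition):
//   %c = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
//   %e = vector.extract %c[1] : vector<2xi32> from vector<2x2xi32>
// is expected to be rewritten to
//   %e = arith.constant dense<[2, 3]> : vector<2xi32>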
2076 class ExtractOpNonSplatConstantFolder final
2077  : public OpRewritePattern<ExtractOp> {
2078 public:
2079  using OpRewritePattern::OpRewritePattern;
2080 
2081  LogicalResult matchAndRewrite(ExtractOp extractOp,
2082  PatternRewriter &rewriter) const override {
2083  // TODO: Canonicalization for dynamic position not implemented yet.
2084  if (extractOp.hasDynamicPosition())
2085  return failure();
2086 
2087  // Return if 'ExtractOp' operand is not defined by a compatible vector
2088  // ConstantOp.
2089  Value sourceVector = extractOp.getVector();
2090  Attribute vectorCst;
2091  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2092  return failure();
2093 
2094  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2095  if (vecTy.isScalable())
2096  return failure();
2097 
2098  // The splat case is handled by `ExtractOpSplatConstantFolder`.
2099  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2100  if (!dense || dense.isSplat())
2101  return failure();
2102 
2103  // Calculate the linearized position of the continuous chunk of elements to
2104  // extract.
2105  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2106  copy(extractOp.getStaticPosition(), completePositions.begin());
2107  int64_t elemBeginPosition =
2108  linearize(completePositions, computeStrides(vecTy.getShape()));
2109  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2110 
2111  TypedAttr newAttr;
2112  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2113  SmallVector<Attribute> elementValues(
2114  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2115  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2116  } else {
2117  newAttr = *denseValuesBegin;
2118  }
2119 
2120  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2121  return success();
2122  }
2123 };
2124 
2125 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
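// Illustrative example (operands assumed for exposition; %c2 is a constant 2,
// %i is a dynamic index):
//   %m = vector.create_mask %c2, %i : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// is expected to be rewritten to (row 1 lies within the first %c2 = 2 rows)
//   %e = vector.create_mask %i : vector<8xi1>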
2126 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2127 public:
2128  using OpRewritePattern::OpRewritePattern;
2129 
2130  LogicalResult matchAndRewrite(ExtractOp extractOp,
2131  PatternRewriter &rewriter) const override {
2132  auto createMaskOp =
2133  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2134  if (!createMaskOp)
2135  return failure();
2136 
2137  VectorType extractedMaskType =
2138  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2139 
2140  if (!extractedMaskType)
2141  return failure();
2142 
2143  auto maskOperands = createMaskOp.getOperands();
2144  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2145  VectorType maskType = createMaskOp.getVectorType();
2146 
2147  bool containsUnknownDims = false;
2148  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2149 
2150  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2151  dimIdx++) {
2152  int64_t pos = extractOpPos[dimIdx];
2153  Value operand = maskOperands[dimIdx];
2154  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2155  if (!constantOp) {
2156  // Bounds of this dim unknown.
2157  containsUnknownDims = true;
2158  continue;
2159  }
2160 
2161  int64_t createMaskBound =
2162  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2163 
2164  if (pos != ShapedType::kDynamic) {
2165  // If any position is outside the range from the `create_mask`, then the
2166  // extracted mask will be all-false.
2167  allFalse |= pos >= createMaskBound;
2168  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2169  // This dim is not all-true and since this is a dynamic index we don't
2170  // know if the extraction is within the true or false region.
2171  // Note: Zero dims have already been handled via getMaskFormat().
2172  containsUnknownDims = true;
2173  }
2174  }
2175 
2176  if (allFalse) {
2177  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2178  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2179  } else if (!containsUnknownDims) {
2180  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2181  extractOp, extractedMaskType,
2182  maskOperands.drop_front(extractOpPos.size()));
2183  } else {
2184  return failure();
2185  }
2186  return success();
2187  }
2188 };
2189 
2190 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2191 // does not change.
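// Illustrative example (shapes assumed for exposition; element count preserved):
//   %sc = vector.shape_cast %v : vector<2x2xf32> to vector<1x4xf32>
//   %e  = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
// is expected to be rewritten to
//   %e  = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>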
2192 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2193  PatternRewriter &rewriter) {
2194  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2195  if (!castOp)
2196  return failure();
2197 
2198  VectorType sourceType = castOp.getSourceVectorType();
2199  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2200  if (!targetType)
2201  return failure();
2202 
2203  if (sourceType.getNumElements() != targetType.getNumElements())
2204  return failure();
2205 
2206  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2207  castOp.getSource());
2208  return success();
2209 }
2210 
2211 /// Try to canonicalize the extraction of a subvector from a vector defined by
2212 /// vector.from_elements. E.g.:
2213 ///
2214 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2215 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2216 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2217 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2218  PatternRewriter &rewriter) {
2219  // Dynamic positions are not supported.
2220  if (extractOp.hasDynamicPosition())
2221  return failure();
2222 
2223  // Scalar extracts are handled by the folder.
2224  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2225  if (!resultType)
2226  return failure();
2227 
2228  // Look for extracts from a from_elements op.
2229  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2230  if (!fromElementsOp)
2231  return failure();
2232  VectorType inputType = fromElementsOp.getType();
2233 
2234  // Scalable vectors are not supported.
2235  if (resultType.isScalable() || inputType.isScalable())
2236  return failure();
2237 
2238  // Compute the position of the first extracted element and flatten/linearize the
2239  // position.
2240  SmallVector<int64_t> firstElementPos =
2241  llvm::to_vector(extractOp.getStaticPosition());
2242  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2243  int flatIndex = 0;
2244  int stride = 1;
2245  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2246  flatIndex += firstElementPos[i] * stride;
2247  stride *= inputType.getDimSize(i);
2248  }
2249 
2250  // Replace the op with a smaller from_elements op.
2251  rewriter.replaceOpWithNewOp<FromElementsOp>(
2252  extractOp, resultType,
2253  fromElementsOp.getElements().slice(flatIndex,
2254  resultType.getNumElements()));
2255  return success();
2256 }
2257 } // namespace
2258 
2259 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2260  MLIRContext *context) {
2261  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2262  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2263  results.add(foldExtractFromShapeCastToShapeCast);
2264  results.add(foldExtractFromFromElements);
2265 }
2266 
2267 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2268  SmallVectorImpl<int64_t> &results) {
2269  for (auto attr : arrayAttr)
2270  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2271 }
2272 
2273 //===----------------------------------------------------------------------===//
2274 // FmaOp
2275 //===----------------------------------------------------------------------===//
2276 
2277 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2278  return llvm::to_vector<4>(getVectorType().getShape());
2279 }
2280 
2281 //===----------------------------------------------------------------------===//
2282 // FromElementsOp
2283 //===----------------------------------------------------------------------===//
2284 
2285 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2286 /// same SSA value. E.g.:
2287 ///
2288 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2289 /// ==> rewrite to vector.splat %a : vector<3xf32>
2290 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2291  PatternRewriter &rewriter) {
2292  if (!llvm::all_equal(fromElementsOp.getElements()))
2293  return failure();
2294  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2295  fromElementsOp.getElements().front());
2296  return success();
2297 }
2298 
2299 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2300  MLIRContext *context) {
2301  results.add(rewriteFromElementsAsSplat);
2302 }
2303 
2304 //===----------------------------------------------------------------------===//
2305 // BroadcastOp
2306 //===----------------------------------------------------------------------===//
2307 
2308 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2309  SetIntRangeFn setResultRanges) {
2310  setResultRanges(getResult(), argRanges.front());
2311 }
2312 
2313 /// Return the dimensions of the result vector that were formerly ones in the
2314 /// source vector and thus correspond to "dim-1" broadcasting.
2315 static llvm::SetVector<int64_t>
2316 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2317  ArrayRef<int64_t> dstShape) {
2318  int64_t rankDiff = dstShape.size() - srcShape.size();
2319  int64_t dstDim = rankDiff;
2320  llvm::SetVector<int64_t> res;
2321  for (auto [s1, s2] :
2322  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2323  if (s1 != s2) {
2324  assert(s1 == 1 && "expected dim-1 broadcasting");
2325  res.insert(dstDim);
2326  }
2327  ++dstDim;
2328  }
2329  return res;
2330 }
2331 
2332 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2333  // Scalar broadcast is without any unit dim broadcast.
2334  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2335  if (!srcVectorType)
2336  return {};
2337  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2338  getResultVectorType().getShape());
2339 }
2340 
2341 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2342 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2343 /// This requires (and asserts) that the broadcast is free of dim-1
2344 /// broadcasting.
2345 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2346 /// vector.transpose may be inserted to make the broadcast possible.
2347 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2348 /// the helper will assert. This means:
2349 /// 1. `dstShape` must not be empty.
2350 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2351 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2352 ///    must match the `value` shape.
2353 Value BroadcastOp::createOrFoldBroadcastOp(
2354  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2355  const llvm::SetVector<int64_t> &broadcastedDims) {
2356  assert(!dstShape.empty() && "unexpected empty dst shape");
2357 
2358  // Well-formedness check.
2359  SmallVector<int64_t> checkShape;
2360  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2361  if (broadcastedDims.contains(i))
2362  continue;
2363  checkShape.push_back(dstShape[i]);
2364  }
2365  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2366  "ill-formed broadcastedDims contains values not confined to "
2367  "destVectorShape");
2368 
2369  Location loc = value.getLoc();
2370  Type elementType = getElementTypeOrSelf(value.getType());
2371  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2372  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2373 
2374  // Step 2. If scalar -> dstShape broadcast, just do it.
2375  if (!srcVectorType) {
2376  assert(checkShape.empty() &&
2377  "ill-formed createOrFoldBroadcastOp arguments");
2378  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2379  }
2380 
2381  assert(srcVectorType.getShape().equals(checkShape) &&
2382  "ill-formed createOrFoldBroadcastOp arguments");
2383 
2384  // Step 3. Since vector.broadcast only allows creating leading dims,
2385  // vector -> dstShape broadcast may require a transpose.
2386  // Traverse the dims in order and construct:
2387  // 1. The leading entries of the broadcastShape that is guaranteed to be
2388  // achievable by a simple broadcast.
2389  // 2. The induced permutation for the subsequent vector.transpose that will
2390  // bring us from `broadcastShape` back to the desired `dstShape`.
2391  // If the induced permutation is not the identity, create a vector.transpose.
2392  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2393  broadcastShape.reserve(dstShape.size());
2394  // Consider the example:
2395  // srcShape = 2x4
2396  // dstShape = 1x2x3x4x5
2397  // broadcastedDims = [0, 2, 4]
2398  //
2399  // We want to build:
2400  // broadcastShape = 1x3x5x2x4
2401  // permutation = [0, 2, 4, 1, 3]
2402  // ---V--- -----V-----
2403  // leading broadcast part src shape part
2404  //
2405  // Note that the trailing dims of broadcastShape are exactly the srcShape
2406  // by construction.
2407  // nextSrcShapeDim is used to keep track of where in the permutation the
2408  // "src shape part" occurs.
2409  int64_t nextSrcShapeDim = broadcastedDims.size();
2410  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2411  if (broadcastedDims.contains(i)) {
2412  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2413  // bring it to the head of the broadcastShape.
2414  // It will need to be permuted back from `broadcastShape.size() - 1` into
2415  // position `i`.
2416  broadcastShape.push_back(dstShape[i]);
2417  permutation[i] = broadcastShape.size() - 1;
2418  } else {
2419  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2420  // shape and needs to be permuted into position `i`.
2421  // Don't touch `broadcastShape` here, the whole srcShape will be
2422  // appended after.
2423  permutation[i] = nextSrcShapeDim++;
2424  }
2425  }
2426  // 3.c. Append the srcShape.
2427  llvm::append_range(broadcastShape, srcVectorType.getShape());
2428 
2429  // Ensure there are no dim-1 broadcasts.
2430  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2431  .empty() &&
2432  "unexpected dim-1 broadcast");
2433 
2434  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2435  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2436  vector::BroadcastableToResult::Success &&
2437  "must be broadcastable");
2438  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2439  // Step 4. If we find any dimension that indeed needs to be permuted,
2440  // immediately return a new vector.transpose.
2441  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2442  if (permutation[i] != i)
2443  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2444  // Otherwise return res.
2445  return res;
2446 }
2447 
2448 BroadcastableToResult mlir::vector::isBroadcastableTo(
2449  Type srcType, VectorType dstVectorType,
2450  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2451  // Broadcast scalar to vector of the same element type.
2452  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2453  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2454  return BroadcastableToResult::Success;
2455  // From now on, only vectors broadcast.
2456  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2457  if (!srcVectorType)
2458  return BroadcastableToResult::SourceTypeNotAVector;
2459 
2460  int64_t srcRank = srcVectorType.getRank();
2461  int64_t dstRank = dstVectorType.getRank();
2462  if (srcRank > dstRank)
2463  return BroadcastableToResult::SourceRankHigher;
2464  // Source has an exact match or singleton value for all trailing dimensions
2465  // (all leading dimensions are simply duplicated).
2466  int64_t lead = dstRank - srcRank;
2467  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2468  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2469  // encountered?
2470  bool foundMismatchingDims = false;
2471 
2472  // Check fixed-width dims.
2473  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2474  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2475  if (srcDim != 1 && srcDim != dstDim)
2476  foundMismatchingDims = true;
2477 
2478  // Check scalable flags.
2479  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2480  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2481  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2482  // 1 -> [N] is fine, everything else should be rejected when mixing
2483  // fixed-width and scalable dims
2484  (srcDimScalableFlag != dstDimScalableFlag &&
2485  (srcDim != 1 || srcDimScalableFlag)))
2486  foundMismatchingDims = true;
2487 
2488  if (foundMismatchingDims) {
2489  if (mismatchingDims != nullptr) {
2490  mismatchingDims->first.dim = srcDim;
2491  mismatchingDims->first.isScalable = srcDimScalableFlag;
2492 
2493  mismatchingDims->second.dim = dstDim;
2494  mismatchingDims->second.isScalable = dstDimScalableFlag;
2495  }
2496  return BroadcastableToResult::DimensionMismatch;
2497  }
2498  }
2499 
2500  return BroadcastableToResult::Success;
2501 }
2502 
2503 LogicalResult BroadcastOp::verify() {
2504  std::pair<VectorDim, VectorDim> mismatchingDims;
2505  BroadcastableToResult res = isBroadcastableTo(
2506  getSourceType(), getResultVectorType(), &mismatchingDims);
2507  if (res == BroadcastableToResult::Success)
2508  return success();
2509  if (res == BroadcastableToResult::SourceRankHigher)
2510  return emitOpError("source rank higher than destination rank");
2511  if (res == BroadcastableToResult::DimensionMismatch) {
2512  return emitOpError("dimension mismatch (")
2513  << (mismatchingDims.first.isScalable ? "[" : "")
2514  << mismatchingDims.first.dim
2515  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2516  << (mismatchingDims.second.isScalable ? "[" : "")
2517  << mismatchingDims.second.dim
2518  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2519  }
2520  if (res == BroadcastableToResult::SourceTypeNotAVector)
2521  return emitOpError("source type is not a vector");
2522  llvm_unreachable("unexpected vector.broadcast op error");
2523 }
2524 
2525 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2526  if (getSourceType() == getResultVectorType())
2527  return getSource();
2528  if (!adaptor.getSource())
2529  return {};
2530  auto vectorType = getResultVectorType();
2531  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2532  return DenseElementsAttr::get(vectorType, adaptor.getSource());
2533  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2534  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2535  return {};
2536 }
2537 
2538 namespace {
2539 
2540 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2541 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2542  using OpRewritePattern::OpRewritePattern;
2543 
2544  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2545  PatternRewriter &rewriter) const override {
2546  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2547  if (!srcBroadcast)
2548  return failure();
2549  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2550  broadcastOp.getResultVectorType(),
2551  srcBroadcast.getSource());
2552  return success();
2553  }
2554 };
2555 } // namespace
2556 
2557 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2558  MLIRContext *context) {
2559  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2560  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2561  results.add<BroadcastFolder>(context);
2562 }
2563 
2564 //===----------------------------------------------------------------------===//
2565 // ShuffleOp
2566 //===----------------------------------------------------------------------===//
2567 
2568 LogicalResult ShuffleOp::verify() {
2569  VectorType resultType = getResultVectorType();
2570  VectorType v1Type = getV1VectorType();
2571  VectorType v2Type = getV2VectorType();
2572  // Verify ranks.
2573  int64_t resRank = resultType.getRank();
2574  int64_t v1Rank = v1Type.getRank();
2575  int64_t v2Rank = v2Type.getRank();
2576  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2577  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2578  if (!wellFormed0DCase && !wellFormedNDCase)
2579  return emitOpError("rank mismatch");
2580 
2581  // Verify all but leading dimension sizes.
2582  for (int64_t r = 1; r < v1Rank; ++r) {
2583  int64_t resDim = resultType.getDimSize(r);
2584  int64_t v1Dim = v1Type.getDimSize(r);
2585  int64_t v2Dim = v2Type.getDimSize(r);
2586  if (resDim != v1Dim || v1Dim != v2Dim)
2587  return emitOpError("dimension mismatch");
2588  }
2589  // Verify mask length.
2590  ArrayRef<int64_t> mask = getMask();
2591  int64_t maskLength = mask.size();
2592  if (maskLength <= 0)
2593  return emitOpError("invalid mask length");
2594  if (maskLength != resultType.getDimSize(0))
2595  return emitOpError("mask length mismatch");
2596  // Verify all indices.
2597  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2598  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2599  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2600  if (maskPos < 0 || maskPos >= indexSize)
2601  return emitOpError("mask index #") << (idx + 1) << " out of range";
2602  }
2603  return success();
2604 }
2605 
2606 LogicalResult
2607 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2608  ShuffleOp::Adaptor adaptor,
2609  SmallVectorImpl<Type> &inferredReturnTypes) {
2610  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2611  auto v1Rank = v1Type.getRank();
2612  // Construct resulting type: leading dimension matches mask
2613  // length, all trailing dimensions match the operands.
2614  SmallVector<int64_t, 4> shape;
2615  shape.reserve(v1Rank);
2616  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2617  // In the 0-D case there is no trailing shape to append.
2618  if (v1Rank > 0)
2619  llvm::append_range(shape, v1Type.getShape().drop_front());
2620  inferredReturnTypes.push_back(
2621  VectorType::get(shape, v1Type.getElementType()));
2622  return success();
2623 }
2624 
2625 template <typename T>
2626 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2627  T expected = begin;
2628  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2629  return value == expected++;
2630  });
2631 }
2632 
2633 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2634  VectorType v1Type = getV1VectorType();
2635  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
2636  // but must be a canonicalization into a vector.broadcast.
2637  if (v1Type.getRank() == 0)
2638  return {};
2639 
2640  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2641  if (!v1Type.isScalable() &&
2642  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2643  return getV1();
2644  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2645  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2646  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2647  getV2VectorType().getDimSize(0)))
2648  return getV2();
2649 
2650  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2651  if (!lhs || !rhs)
2652  return {};
2653 
2654  auto lhsType =
2655  llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2656  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2657  // manipulation.
2658  if (lhsType.getRank() != 1)
2659  return {};
2660  int64_t lhsSize = lhsType.getDimSize(0);
2661 
2662  SmallVector<Attribute> results;
2663  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2664  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2665  for (int64_t i : this->getMask()) {
2666  if (i >= lhsSize) {
2667  results.push_back(rhsElements[i - lhsSize]);
2668  } else {
2669  results.push_back(lhsElements[i]);
2670  }
2671  }
2672 
2673  return DenseElementsAttr::get(getResultVectorType(), results);
2674 }
2675 
2676 namespace {
2677 
2678 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2679 // to a broadcast.
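// Illustrative example (0-D operands assumed for exposition):
//   %s = vector.shuffle %a, %b [1] : vector<f32>, vector<f32>
// is expected to be rewritten to
//   %s = vector.broadcast %b : vector<f32> to vector<1xf32>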
2680 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2681  using OpRewritePattern::OpRewritePattern;
2682 
2683  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2684  PatternRewriter &rewriter) const override {
2685  VectorType v1VectorType = shuffleOp.getV1VectorType();
2686  ArrayRef<int64_t> mask = shuffleOp.getMask();
2687  if (v1VectorType.getRank() > 0)
2688  return failure();
2689  if (mask.size() != 1)
2690  return failure();
2691  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2692  if (mask[0] == 0)
2693  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2694  shuffleOp.getV1());
2695  else
2696  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2697  shuffleOp.getV2());
2698  return success();
2699  }
2700 };
2701 
2702 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
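/// Illustrative example (%x is an assumed common splat input):
///   %v1 = vector.splat %x : vector<4xf32>
///   %v2 = vector.splat %x : vector<2xf32>
///   %s  = vector.shuffle %v1, %v2 [0, 4, 1, 5] : vector<4xf32>, vector<2xf32>
/// is expected to be rewritten to
///   %s  = vector.splat %x : vector<4xf32>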
2703 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2704 public:
2705  using OpRewritePattern::OpRewritePattern;
2706 
2707  LogicalResult matchAndRewrite(ShuffleOp op,
2708  PatternRewriter &rewriter) const override {
2709  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2710  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2711 
2712  if (!v1Splat || !v2Splat)
2713  return failure();
2714 
2715  if (v1Splat.getInput() != v2Splat.getInput())
2716  return failure();
2717 
2718  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2719  return success();
2720  }
2721 };
2722 
2723 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2724 /// vector.interleave.
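/// Illustrative example (fixed-size 1-D operands assumed for exposition):
///   %s = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///          : vector<4xf32>, vector<4xf32>
/// is expected to be rewritten to a vector.interleave of %a and %b yielding a
/// vector<8xf32>.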
2725 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2726 public:
2727  using OpRewritePattern::OpRewritePattern;
2728 
2729  LogicalResult matchAndRewrite(ShuffleOp op,
2730  PatternRewriter &rewriter) const override {
2731  VectorType resultType = op.getResultVectorType();
2732  if (resultType.isScalable())
2733  return rewriter.notifyMatchFailure(
2734  op, "ShuffleOp can't represent a scalable interleave");
2735 
2736  if (resultType.getRank() != 1)
2737  return rewriter.notifyMatchFailure(
2738  op, "ShuffleOp can't represent an n-D interleave");
2739 
2740  VectorType sourceType = op.getV1VectorType();
2741  if (sourceType != op.getV2VectorType() ||
2742  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2743  return rewriter.notifyMatchFailure(
2744  op, "ShuffleOp types don't match an interleave");
2745  }
2746 
2747  ArrayRef<int64_t> shuffleMask = op.getMask();
2748  int64_t resultVectorSize = resultType.getNumElements();
2749  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2750  int64_t maskValueA = shuffleMask[i * 2];
2751  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2752  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2753  return rewriter.notifyMatchFailure(op,
2754  "ShuffleOp mask not interleaving");
2755  }
2756 
2757  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2758  return success();
2759  }
2760 };
2761 
2762 } // namespace
2763 
2764 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2765  MLIRContext *context) {
2766  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2767  context);
2768 }
2769 
2770 //===----------------------------------------------------------------------===//
2771 // InsertElementOp
2772 //===----------------------------------------------------------------------===//
2773 
2774 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2775  SetIntRangeFn setResultRanges) {
2776  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2777 }
2778 
2779 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2780  Value source, Value dest) {
2781  build(builder, result, source, dest, {});
2782 }
2783 
2784 LogicalResult InsertElementOp::verify() {
2785  auto dstVectorType = getDestVectorType();
2786  if (dstVectorType.getRank() == 0) {
2787  if (getPosition())
2788  return emitOpError("expected position to be empty with 0-D vector");
2789  return success();
2790  }
2791  if (dstVectorType.getRank() != 1)
2792  return emitOpError("unexpected >1 vector rank");
2793  if (!getPosition())
2794  return emitOpError("expected position for 1-D vector");
2795  return success();
2796 }
2797 
2798 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2799  // Skip the 0-D vector here.
2800  if (!adaptor.getPosition())
2801  return {};
2802 
2803  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2804  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2805  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2806  if (!src || !dst || !pos)
2807  return {};
2808 
2809  if (src.getType() != getDestVectorType().getElementType())
2810  return {};
2811 
2812  auto dstElements = dst.getValues<Attribute>();
2813 
2814  SmallVector<Attribute> results(dstElements);
2815 
2816  uint64_t posIdx = pos.getInt();
2817  if (posIdx >= results.size())
2818  return {};
2819  results[posIdx] = src;
2820 
2821  return DenseElementsAttr::get(getDestVectorType(), results);
2822 }
2823 
2824 //===----------------------------------------------------------------------===//
2825 // InsertOp
2826 //===----------------------------------------------------------------------===//
2827 
2828 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2829  SetIntRangeFn setResultRanges) {
2830  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2831 }
2832 
2833 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2834  Value source, Value dest, int64_t position) {
2835  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2836 }
2837 
2838 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2839  Value source, Value dest, OpFoldResult position) {
2840  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2841 }
2842 
2843 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2844  Value source, Value dest,
2845  ArrayRef<int64_t> position) {
2846  SmallVector<OpFoldResult> posVals;
2847  posVals.reserve(position.size());
2848  llvm::transform(position, std::back_inserter(posVals),
2849  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2850  build(builder, result, source, dest, posVals);
2851 }
2852 
2853 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2854  Value source, Value dest,
2855  ArrayRef<OpFoldResult> position) {
2856  SmallVector<int64_t> staticPos;
2857  SmallVector<Value> dynamicPos;
2858  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2859  build(builder, result, source, dest, dynamicPos,
2860  builder.getDenseI64ArrayAttr(staticPos));
2861 }
2862 
2863 LogicalResult InsertOp::verify() {
2864  SmallVector<OpFoldResult> position = getMixedPosition();
2865  auto destVectorType = getDestVectorType();
2866  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2867  return emitOpError(
2868  "expected position attribute of rank no greater than dest vector rank");
2869  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2870  if (srcVectorType &&
2871  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2872  static_cast<unsigned>(destVectorType.getRank())))
2873  return emitOpError("expected position attribute rank + source rank to "
2874  "match dest vector rank");
2875  if (!srcVectorType &&
2876  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2877  return emitOpError(
2878  "expected position attribute rank to match the dest vector rank");
2879  for (auto [idx, pos] : llvm::enumerate(position)) {
2880  if (auto attr = pos.dyn_cast<Attribute>()) {
2881  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2882  if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2883  return emitOpError("expected position attribute #")
2884  << (idx + 1)
2885  << " to be a non-negative integer smaller than the "
2886  "corresponding "
2887  "dest vector dimension";
2888  }
2889  }
2890  }
2891  return success();
2892 }
2893 
2894 namespace {
2895 
2896 // If insertOp is only inserting unit dimensions it can be transformed to a
2897 // broadcast.
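// Illustrative example (hypothetical IR, not from this file; %src and %dst are
// placeholders): an insert such as
//   %r = vector.insert %src, %dst [0, 0] : vector<8xf32> into vector<1x1x8xf32>
// can be rewritten as
//   %r = vector.broadcast %src : vector<8xf32> to vector<1x1x8xf32>
// because the source already supplies every element of the result.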
2898 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2899 public:
2900  using OpRewritePattern::OpRewritePattern;
2901 
2902  LogicalResult matchAndRewrite(InsertOp insertOp,
2903  PatternRewriter &rewriter) const override {
2904  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2905  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2906  srcVecType.getNumElements())
2907  return failure();
2908  rewriter.replaceOpWithNewOp<BroadcastOp>(
2909  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2910  return success();
2911  }
2912 };
2913 
2914 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
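/// Illustrative example (hypothetical IR; %x is a placeholder f32 value):
///   %a = vector.splat %x : vector<4xf32>
///   %b = vector.splat %x : vector<2x4xf32>
///   %c = vector.insert %a, %b [0] : vector<4xf32> into vector<2x4xf32>
/// would be rewritten so that %c becomes vector.splat %x : vector<2x4xf32>.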
2915 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2916 public:
2917  using OpRewritePattern::OpRewritePattern;
2918 
2919  LogicalResult matchAndRewrite(InsertOp op,
2920  PatternRewriter &rewriter) const override {
2921  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2922  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2923 
2924  if (!srcSplat || !dstSplat)
2925  return failure();
2926 
2927  if (srcSplat.getInput() != dstSplat.getInput())
2928  return failure();
2929 
2930  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2931  return success();
2932  }
2933 };
2934 
2935 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
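// Illustrative example (hypothetical IR): with %one = arith.constant 1 : i32,
//   %v = arith.constant dense<0> : vector<4xi32>
//   %r = vector.insert %one, %v [1] : i32 into vector<4xi32>
// folds to
//   %r = arith.constant dense<[0, 1, 0, 0]> : vector<4xi32>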
2936 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2937 public:
2938  using OpRewritePattern::OpRewritePattern;
2939 
2940  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2941  // unless the destination vector constant has a single use.
2942  static constexpr int64_t vectorSizeFoldThreshold = 256;
2943 
2944  LogicalResult matchAndRewrite(InsertOp op,
2945  PatternRewriter &rewriter) const override {
2946  // TODO: Canonicalization for dynamic position not implemented yet.
2947  if (op.hasDynamicPosition())
2948  return failure();
2949 
2950  // Return if 'InsertOp' operand is not defined by a compatible vector
2951  // ConstantOp.
2952  TypedValue<VectorType> destVector = op.getDest();
2953  Attribute vectorDestCst;
2954  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2955  return failure();
2956  auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2957  if (!denseDest)
2958  return failure();
2959 
2960  VectorType destTy = destVector.getType();
2961  if (destTy.isScalable())
2962  return failure();
2963 
2964  // Make sure we do not create too many large constants.
2965  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2966  !destVector.hasOneUse())
2967  return failure();
2968 
2969  Value sourceValue = op.getSource();
2970  Attribute sourceCst;
2971  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2972  return failure();
2973 
2974  // Calculate the linearized position of the contiguous chunk of elements to
2975  // insert.
2976  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2977  copy(op.getStaticPosition(), completePositions.begin());
2978  int64_t insertBeginPosition =
2979  linearize(completePositions, computeStrides(destTy.getShape()));
2980 
2981  SmallVector<Attribute> insertedValues;
2982  Type destEltType = destTy.getElementType();
2983 
2984  // The `convertIntegerAttr` method specifically handles the case
2985  // for `llvm.mlir.constant` which can hold an attribute with a
2986  // different type than the return type.
2987  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2988  for (auto value : denseSource.getValues<Attribute>())
2989  insertedValues.push_back(convertIntegerAttr(value, destEltType));
2990  } else {
2991  insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
2992  }
2993 
2994  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2995  copy(insertedValues, allValues.begin() + insertBeginPosition);
2996  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2997 
2998  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2999  return success();
3000  }
3001 
3002 private:
3003  /// Converts `attr` to an IntegerAttr of the expected type if there is
3004  /// a type mismatch.
3005  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3006  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3007  if (intAttr.getType() != expectedType)
3008  return IntegerAttr::get(expectedType, intAttr.getInt());
3009  }
3010  return attr;
3011  }
3012 };
3013 
3014 } // namespace
3015 
3016 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3017  MLIRContext *context) {
3018  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3019  InsertOpConstantFolder>(context);
3020 }
3021 
3022 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3023  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
3024  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3025  // (type mismatch).
3026  if (getNumIndices() == 0 && getSourceType() == getType())
3027  return getSource();
3028  return {};
3029 }
3030 
3031 //===----------------------------------------------------------------------===//
3032 // InsertStridedSliceOp
3033 //===----------------------------------------------------------------------===//
3034 
3035 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3036  Value source, Value dest,
3037  ArrayRef<int64_t> offsets,
3038  ArrayRef<int64_t> strides) {
3039  result.addOperands({source, dest});
3040  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3041  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3042  result.addTypes(dest.getType());
3043  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3044  offsetsAttr);
3045  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3046  stridesAttr);
3047 }
3048 
3049 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3050 template <typename OpType>
3051 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3052  ArrayAttr arrayAttr,
3053  ArrayRef<int64_t> shape,
3054  StringRef attrName) {
3055  if (arrayAttr.size() > shape.size())
3056  return op.emitOpError("expected ")
3057  << attrName << " attribute of rank no greater than vector rank";
3058  return success();
3059 }
3060 
3061 // Returns success if all integers in `arrayAttr` are confined to the interval
3062 // [min, max). If `halfOpen` is true then the admissible interval is [min, max).
3063 // Otherwise, the admissible interval is [min, max].
3064 template <typename OpType>
3065 static LogicalResult
3066 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3067  int64_t max, StringRef attrName,
3068  bool halfOpen = true) {
3069  for (auto attr : arrayAttr) {
3070  auto val = llvm::cast<IntegerAttr>(attr).getInt();
3071  auto upper = max;
3072  if (!halfOpen)
3073  upper += 1;
3074  if (val < min || val >= upper)
3075  return op.emitOpError("expected ") << attrName << " to be confined to ["
3076  << min << ", " << upper << ")";
3077  }
3078  return success();
3079 }
3080 
3081 // Returns success if all integers in `arrayAttr` are confined, per dimension i,
3082 // to the interval [min, shape[i]). If `halfOpen` is true then the admissible
3083 // interval is [min, shape[i]). Otherwise, the admissible interval is [min, shape[i]].
3084 template <typename OpType>
3085 static LogicalResult
3086 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3087  ArrayRef<int64_t> shape, StringRef attrName,
3088  bool halfOpen = true, int64_t min = 0) {
3089  for (auto [index, attrDimPair] :
3090  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3091  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3092  int64_t max = std::get<1>(attrDimPair);
3093  if (!halfOpen)
3094  max += 1;
3095  if (val < min || val >= max)
3096  return op.emitOpError("expected ")
3097  << attrName << " dimension " << index << " to be confined to ["
3098  << min << ", " << max << ")";
3099  }
3100  return success();
3101 }
3102 
3103 // Returns success if, for all indices i = 0..shape.size()-1, the value
3104 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3105 // is confined to the interval [min, shape[i]).
3106 // If `halfOpen` is true then the admissible interval is [min, shape[i]).
3107 // Otherwise, the admissible interval is [min, shape[i]].
3108 template <typename OpType>
3109 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3110  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3111  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3112  bool halfOpen = true, int64_t min = 1) {
3113  assert(arrayAttr1.size() <= shape.size());
3114  assert(arrayAttr2.size() <= shape.size());
3115  for (auto [index, it] :
3116  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3117  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3118  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3119  int64_t max = std::get<2>(it);
3120  if (!halfOpen)
3121  max += 1;
3122  if (val1 + val2 < 0 || val1 + val2 >= max)
3123  return op.emitOpError("expected sum(")
3124  << attrName1 << ", " << attrName2 << ") dimension " << index
3125  << " to be confined to [" << min << ", " << max << ")";
3126  }
3127  return success();
3128 }
3129 
3130 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3131  MLIRContext *context) {
3132  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3133  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3134  });
3135  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3136 }
3137 
3138 LogicalResult InsertStridedSliceOp::verify() {
3139  auto sourceVectorType = getSourceVectorType();
3140  auto destVectorType = getDestVectorType();
3141  auto offsets = getOffsetsAttr();
3142  auto strides = getStridesAttr();
3143  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3144  return emitOpError(
3145  "expected offsets of same size as destination vector rank");
3146  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3147  return emitOpError("expected strides of same size as source vector rank");
3148  if (sourceVectorType.getRank() > destVectorType.getRank())
3149  return emitOpError(
3150  "expected source rank to be no greater than destination rank");
3151 
3152  auto sourceShape = sourceVectorType.getShape();
3153  auto destShape = destVectorType.getShape();
3154  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3155  destShape.size() - sourceShape.size(), 0);
3156  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3157  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3158  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3159  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3160  offName)) ||
3161  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3162  /*max=*/1, stridesName,
3163  /*halfOpen=*/false)) ||
3164  failed(isSumOfIntegerArrayAttrConfinedToShape(
3165  *this, offsets,
3166  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3167  offName, "source vector shape",
3168  /*halfOpen=*/false, /*min=*/1)))
3169  return failure();
3170 
3171  unsigned rankDiff = destShape.size() - sourceShape.size();
3172  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3173  if (sourceVectorType.getScalableDims()[idx] !=
3174  destVectorType.getScalableDims()[idx + rankDiff]) {
3175  return emitOpError("mismatching scalable flags (at source vector idx=")
3176  << idx << ")";
3177  }
3178  if (sourceVectorType.getScalableDims()[idx]) {
3179  auto sourceSize = sourceShape[idx];
3180  auto destSize = destShape[idx + rankDiff];
3181  if (sourceSize != destSize) {
3182  return emitOpError("expected size at idx=")
3183  << idx
3184  << (" to match the corresponding base size from the input "
3185  "vector (")
3186  << sourceSize << (" vs ") << destSize << (")");
3187  }
3188  }
3189  }
3190 
3191  return success();
3192 }
3193 
3194 namespace {
3195 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3196 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
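/// Illustrative example (hypothetical IR; %x is a placeholder f32 value):
///   %a = vector.splat %x : vector<2xf32>
///   %b = vector.splat %x : vector<4xf32>
///   %c = vector.insert_strided_slice %a, %b {offsets = [1], strides = [1]}
///        : vector<2xf32> into vector<4xf32>
/// is replaced by the destination splat %b, since inserting a splat of %x into
/// another splat of %x changes nothing.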
3197 class FoldInsertStridedSliceSplat final
3198  : public OpRewritePattern<InsertStridedSliceOp> {
3199 public:
3200  using OpRewritePattern::OpRewritePattern;
3201 
3202  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3203  PatternRewriter &rewriter) const override {
3204  auto srcSplatOp =
3205  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3206  auto destSplatOp =
3207  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3208 
3209  if (!srcSplatOp || !destSplatOp)
3210  return failure();
3211 
3212  if (srcSplatOp.getInput() != destSplatOp.getInput())
3213  return failure();
3214 
3215  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3216  return success();
3217  }
3218 };
3219 
3220 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3221 /// to dst.
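/// Illustrative example (hypothetical IR): re-inserting a slice that was
/// extracted from the same destination at the same offsets and strides,
///   %s = vector.extract_strided_slice %dst {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %s, %dst {offsets = [2], strides = [1]}
///        : vector<4xf32> into vector<8xf32>
/// is replaced by %dst directly.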
3222 class FoldInsertStridedSliceOfExtract final
3223  : public OpRewritePattern<InsertStridedSliceOp> {
3224 public:
3225  using OpRewritePattern::OpRewritePattern;
3226 
3227  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3228  PatternRewriter &rewriter) const override {
3229  auto extractStridedSliceOp =
3230  insertStridedSliceOp.getSource()
3231  .getDefiningOp<vector::ExtractStridedSliceOp>();
3232 
3233  if (!extractStridedSliceOp)
3234  return failure();
3235 
3236  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3237  return failure();
3238 
3239  // Check if they have the same strides and offsets.
3240  if (extractStridedSliceOp.getStrides() !=
3241  insertStridedSliceOp.getStrides() ||
3242  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3243  return failure();
3244 
3245  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3246  return success();
3247  }
3248 };
3249 
3250 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3251 // ConstantOp.
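// Illustrative example (hypothetical IR): with both operands constant,
//   %a = arith.constant dense<0> : vector<4xi32>
//   %b = arith.constant dense<1> : vector<2xi32>
//   %r = vector.insert_strided_slice %b, %a {offsets = [1], strides = [1]}
//        : vector<2xi32> into vector<4xi32>
// folds to
//   %r = arith.constant dense<[0, 1, 1, 0]> : vector<4xi32>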
3252 class InsertStridedSliceConstantFolder final
3253  : public OpRewritePattern<InsertStridedSliceOp> {
3254 public:
3255  using OpRewritePattern::OpRewritePattern;
3256 
3257  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3258  // unless the destination vector constant has a single use.
3259  static constexpr int64_t vectorSizeFoldThreshold = 256;
3260 
3261  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3262  PatternRewriter &rewriter) const override {
3263  // Return if 'InsertOp' operand is not defined by a compatible vector
3264  // ConstantOp.
3265  TypedValue<VectorType> destVector = op.getDest();
3266  Attribute vectorDestCst;
3267  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3268  return failure();
3269 
3270  VectorType destTy = destVector.getType();
3271  if (destTy.isScalable())
3272  return failure();
3273 
3274  // Make sure we do not create too many large constants.
3275  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3276  !destVector.hasOneUse())
3277  return failure();
3278 
3279  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3280 
3281  TypedValue<VectorType> sourceValue = op.getSource();
3282  Attribute sourceCst;
3283  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3284  return failure();
3285 
3286  // TODO: Handle non-unit strides when they become available.
3287  if (op.hasNonUnitStrides())
3288  return failure();
3289 
3290  VectorType sliceVecTy = sourceValue.getType();
3291  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3292  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3293  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3294  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3295 
3296  // Calculate the destination element indices by enumerating all slice
3297  // positions within the destination and linearizing them. The enumeration
3298  // order is lexicographic, which yields a sequence of monotonically
3299  // increasing linearized position indices.
3300  // Because the destination may have higher dimensionality than the slice,
3301  // we keep track of two overlapping sets of positions and offsets.
3302  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3303  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3304  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3305  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3306  MutableArrayRef<int64_t> currSlicePosition(
3307  currDestPosition.begin() + rankDifference, currDestPosition.end());
3308  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3309  offsets.end());
3310  do {
3311  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3312  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3313  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3314  "Invalid slice element");
3315  newValues[linearizedPosition] = *sliceValuesIt;
3316  ++sliceValuesIt;
3317  } while (succeeded(
3318  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3319 
3320  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3321  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3322  return success();
3323  }
3324 };
3325 
3326 } // namespace
3327 
3328 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3329  RewritePatternSet &results, MLIRContext *context) {
3330  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3331  InsertStridedSliceConstantFolder>(context);
3332 }
3333 
3334 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3335  if (getSourceVectorType() == getDestVectorType())
3336  return getSource();
3337  return {};
3338 }
3339 
3340 //===----------------------------------------------------------------------===//
3341 // OuterProductOp
3342 //===----------------------------------------------------------------------===//
3343 
3344 /// Build an op without mask, use the type of `acc` as the return type.
3345 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3346  Value lhs, Value rhs, Value acc) {
3347  result.addOperands({lhs, rhs, acc});
3348  result.addTypes(acc.getType());
3349 }
3350 
3351 void OuterProductOp::print(OpAsmPrinter &p) {
3352  p << " " << getLhs() << ", " << getRhs();
3353  if (getAcc()) {
3354  p << ", " << getAcc();
3355  p.printOptionalAttrDict((*this)->getAttrs());
3356  }
3357  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3358 }
3359 
3360 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3361  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3362  Type tLHS, tRHS;
3363  if (parser.parseOperandList(operandsInfo) ||
3364  parser.parseOptionalAttrDict(result.attributes) ||
3365  parser.parseColonType(tLHS) || parser.parseComma() ||
3366  parser.parseType(tRHS))
3367  return failure();
3368  if (operandsInfo.size() < 2)
3369  return parser.emitError(parser.getNameLoc(),
3370  "expected at least 2 operands");
3371  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3372  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3373  if (!vLHS)
3374  return parser.emitError(parser.getNameLoc(),
3375  "expected vector type for operand #1");
3376 
3377  VectorType resType;
3378  if (vRHS) {
3379  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3380  vRHS.getScalableDims()[0]};
3381  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3382  vLHS.getElementType(), scalableDimsRes);
3383  } else {
3384  // Scalar RHS operand
3385  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3386  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3387  scalableDimsRes);
3388  }
3389 
3390  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3391  result.attributes.append(
3392  OuterProductOp::getKindAttrName(result.name),
3393  CombiningKindAttr::get(result.getContext(),
3394  OuterProductOp::getDefaultKind()));
3395  }
3396 
3397  return failure(
3398  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3399  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3400  (operandsInfo.size() > 2 &&
3401  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3402  parser.addTypeToList(resType, result.types));
3403 }
3404 
3405 LogicalResult OuterProductOp::verify() {
3406  Type tRHS = getOperandTypeRHS();
3407  VectorType vLHS = getOperandVectorTypeLHS(),
3408  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3409  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3410 
3411  if (vLHS.getRank() != 1)
3412  return emitOpError("expected 1-d vector for operand #1");
3413 
3414  if (vRHS) {
3415  // Proper OUTER operation.
3416  if (vRHS.getRank() != 1)
3417  return emitOpError("expected 1-d vector for operand #2");
3418  if (vRES.getRank() != 2)
3419  return emitOpError("expected 2-d vector result");
3420  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3421  return emitOpError("expected #1 operand dim to match result dim #1");
3422  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3423  return emitOpError("expected #2 operand dim to match result dim #2");
3424  if (vLHS.isScalable() && !vRHS.isScalable()) {
3425  // This restriction reflects what's currently supported in terms of
3426  // scalable vectors. However, we could relax this if there's a use case.
3427  return emitOpError(
3428  "expected either both or only #2 operand dim to be scalable");
3429  }
3430  } else {
3431  // An AXPY operation.
3432  if (vRES.getRank() != 1)
3433  return emitOpError("expected 1-d vector result");
3434  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3435  return emitOpError("expected #1 operand dim to match result dim #1");
3436  }
3437 
3438  if (vACC && vACC != vRES)
3439  return emitOpError("expected operand #3 of same type as result type");
3440 
3441  // Verify supported combining kind.
3442  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3443  return emitOpError("unsupported outerproduct type");
3444 
3445  return success();
3446 }
3447 
3448 // MaskableOpInterface methods.
3449 
3450 /// Returns the mask type expected by this operation. Mostly used for
3451 /// verification purposes. It requires the operation to be vectorized.
3452 Type OuterProductOp::getExpectedMaskType() {
3453  auto vecType = this->getResultVectorType();
3454  return VectorType::get(vecType.getShape(),
3455  IntegerType::get(vecType.getContext(), /*width=*/1),
3456  vecType.getScalableDims());
3457 }
3458 
3459 //===----------------------------------------------------------------------===//
3460 // ExtractStridedSliceOp
3461 //===----------------------------------------------------------------------===//
3462 
3463 // Inference works as follows:
3464 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3465 // 2. Add sizes from 'vectorType' for remaining dims.
3466 // Scalable flags are inherited from 'vectorType'.
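// Illustrative example (assumed inputs): for a source of type vector<4x8x16xf32>
// with offsets = [0, 2], sizes = [2, 4] and strides = [1, 1], the first two
// result dims come from 'sizes' and the remaining dim from the source type,
// giving vector<2x4x16xf32>.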
3467 static Type inferStridedSliceOpResultType(VectorType vectorType,
3468  ArrayAttr offsets, ArrayAttr sizes,
3469  ArrayAttr strides) {
3470  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3471  SmallVector<int64_t, 4> shape;
3472  shape.reserve(vectorType.getRank());
3473  unsigned idx = 0;
3474  for (unsigned e = offsets.size(); idx < e; ++idx)
3475  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3476  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3477  shape.push_back(vectorType.getShape()[idx]);
3478 
3479  return VectorType::get(shape, vectorType.getElementType(),
3480  vectorType.getScalableDims());
3481 }
3482 
3483 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3484  Value source, ArrayRef<int64_t> offsets,
3485  ArrayRef<int64_t> sizes,
3486  ArrayRef<int64_t> strides) {
3487  result.addOperands(source);
3488  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3489  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3490  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3491  result.addTypes(
3492  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3493  offsetsAttr, sizesAttr, stridesAttr));
3494  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3495  offsetsAttr);
3496  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3497  sizesAttr);
3498  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3499  stridesAttr);
3500 }
3501 
3502 LogicalResult ExtractStridedSliceOp::verify() {
3503  auto type = getSourceVectorType();
3504  auto offsets = getOffsetsAttr();
3505  auto sizes = getSizesAttr();
3506  auto strides = getStridesAttr();
3507  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3508  return emitOpError(
3509  "expected offsets, sizes and strides attributes of same size");
3510 
3511  auto shape = type.getShape();
3512  auto offName = getOffsetsAttrName();
3513  auto sizesName = getSizesAttrName();
3514  auto stridesName = getStridesAttrName();
3515  if (failed(
3516  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3517  failed(
3518  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3519  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3520  stridesName)) ||
3521  failed(
3522  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3523  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3524  /*halfOpen=*/false,
3525  /*min=*/1)) ||
3526  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3527  /*max=*/1, stridesName,
3528  /*halfOpen=*/false)) ||
3529  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3530  shape, offName, sizesName,
3531  /*halfOpen=*/false)))
3532  return failure();
3533 
3534  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3535  offsets, sizes, strides);
3536  if (getResult().getType() != resultType)
3537  return emitOpError("expected result type to be ") << resultType;
3538 
3539  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3540  if (type.getScalableDims()[idx]) {
3541  auto inputDim = type.getShape()[idx];
3542  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3543  if (inputDim != inputSize)
3544  return emitOpError("expected size at idx=")
3545  << idx
3546  << (" to match the corresponding base size from the input "
3547  "vector (")
3548  << inputSize << (" vs ") << inputDim << (")");
3549  }
3550  }
3551 
3552  return success();
3553 }
3554 
3555 // When the source of an ExtractStridedSliceOp comes from a chain of
3556 // InsertStridedSliceOp ops, try to use the source of one of those inserts when
3557 // we can detect that the extracted vector is a subset of an inserted vector.
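// Illustrative example (hypothetical IR): given
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//        : vector<4xf32> into vector<8xf32>
//   %1 = vector.extract_strided_slice %0 {offsets = [3], sizes = [2], strides = [1]}
//        : vector<8xf32> to vector<2xf32>
// the extracted range [3, 5) lies entirely inside the inserted range [2, 6), so
// the extract can instead read from %src with rebased offsets:
//   %1 = vector.extract_strided_slice %src {offsets = [1], sizes = [2], strides = [1]}
//        : vector<4xf32> to vector<2xf32>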
3558 static LogicalResult
3559 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3560  // Helper to extract integer out of ArrayAttr.
3561  auto getElement = [](ArrayAttr array, int idx) {
3562  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3563  };
3564  ArrayAttr extractOffsets = op.getOffsets();
3565  ArrayAttr extractStrides = op.getStrides();
3566  ArrayAttr extractSizes = op.getSizes();
3567  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3568  while (insertOp) {
3569  if (op.getSourceVectorType().getRank() !=
3570  insertOp.getSourceVectorType().getRank())
3571  return failure();
3572  ArrayAttr insertOffsets = insertOp.getOffsets();
3573  ArrayAttr insertStrides = insertOp.getStrides();
3574  // If the rank of extract is greater than the rank of insert, we are likely
3575  // extracting a partial chunk of the vector inserted.
3576  if (extractOffsets.size() > insertOffsets.size())
3577  return failure();
3578  bool partialOverlap = false;
3579  bool disjoint = false;
3580  SmallVector<int64_t, 4> offsetDiffs;
3581  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3582  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3583  return failure();
3584  int64_t start = getElement(insertOffsets, dim);
3585  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3586  int64_t offset = getElement(extractOffsets, dim);
3587  int64_t size = getElement(extractSizes, dim);
3588  // Check if the start of the extract offset is in the interval inserted.
3589  if (start <= offset && offset < end) {
3590  // If the extract interval overlaps but is not fully included, we may
3591  // have a partial overlap that will prevent any folding.
3592  if (offset + size > end)
3593  partialOverlap = true;
3594  offsetDiffs.push_back(offset - start);
3595  continue;
3596  }
3597  disjoint = true;
3598  break;
3599  }
3600  // The extracted chunk is a subset of the inserted chunk.
3601  if (!disjoint && !partialOverlap) {
3602  op.setOperand(insertOp.getSource());
3603  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3604  OpBuilder b(op.getContext());
3605  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3606  return success();
3607  }
3608  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3609  // in the insert chain.
3610  if (disjoint)
3611  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3612  else {
3613  // The extracted vector partially overlaps the inserted vector; we cannot
3614  // fold.
3615  return failure();
3616  }
3617  }
3618  return failure();
3619 }
3620 
3621 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3622  if (getSourceVectorType() == getResult().getType())
3623  return getVector();
3624  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3625  return getResult();
3626  return {};
3627 }
3628 
3629 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3630  populateFromInt64AttrArray(getOffsets(), results);
3631 }
3632 
3633 namespace {
3634 
3635 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3636 // ConstantMaskOp.
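// Illustrative example (hypothetical IR): slicing a constant mask,
//   %m = vector.constant_mask [4] : vector<8xi1>
//   %r = vector.extract_strided_slice %m {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8xi1> to vector<4xi1>
// folds to
//   %r = vector.constant_mask [2] : vector<4xi1>
// since only the first two elements of the slice fall inside the masked region.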
3637 class StridedSliceConstantMaskFolder final
3638  : public OpRewritePattern<ExtractStridedSliceOp> {
3639 public:
3640  using OpRewritePattern::OpRewritePattern;
3641 
3642  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3643  PatternRewriter &rewriter) const override {
3644  // Return if 'extractStridedSliceOp' operand is not defined by a
3645  // ConstantMaskOp.
3646  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3647  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3648  if (!constantMaskOp)
3649  return failure();
3650  // Return if 'extractStridedSliceOp' has non-unit strides.
3651  if (extractStridedSliceOp.hasNonUnitStrides())
3652  return failure();
3653  // Gather constant mask dimension sizes.
3654  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3655  // Gather strided slice offsets and sizes.
3656  SmallVector<int64_t, 4> sliceOffsets;
3657  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3658  sliceOffsets);
3659  SmallVector<int64_t, 4> sliceSizes;
3660  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3661 
3662  // Compute slice of vector mask region.
3663  SmallVector<int64_t, 4> sliceMaskDimSizes;
3664  sliceMaskDimSizes.reserve(maskDimSizes.size());
3665  for (auto [maskDimSize, sliceOffset, sliceSize] :
3666  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3667  int64_t sliceMaskDimSize = std::max(
3668  static_cast<int64_t>(0),
3669  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3670  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3671  }
3672  // Add unchanged dimensions.
3673  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3674  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3675  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3676  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3677  // region is a conjunction of mask dim intervals).
3678  if (llvm::is_contained(sliceMaskDimSizes, 0))
3679  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3680 
3681  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3682  // region.
3683  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3684  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3685  sliceMaskDimSizes);
3686  return success();
3687  }
3688 };
3689 
3690 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3691 class StridedSliceSplatConstantFolder final
3692  : public OpRewritePattern<ExtractStridedSliceOp> {
3693 public:
3694  using OpRewritePattern::OpRewritePattern;
3695 
3696  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3697  PatternRewriter &rewriter) const override {
3698  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3699  // ConstantOp.
3700  Value sourceVector = extractStridedSliceOp.getVector();
3701  Attribute vectorCst;
3702  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3703  return failure();
3704 
3705  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3706  if (!splat)
3707  return failure();
3708 
3709  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3710  splat.getSplatValue<Attribute>());
3711  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3712  newAttr);
3713  return success();
3714  }
3715 };
3716 
3717 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3718 // ConstantOp.
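// Illustrative example (hypothetical IR): slicing a non-splat constant,
//   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %r = vector.extract_strided_slice %cst {offsets = [1], sizes = [2], strides = [1]}
//        : vector<4xi32> to vector<2xi32>
// folds to
//   %r = arith.constant dense<[1, 2]> : vector<2xi32>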
3719 class StridedSliceNonSplatConstantFolder final
3720  : public OpRewritePattern<ExtractStridedSliceOp> {
3721 public:
3722  using OpRewritePattern::OpRewritePattern;
3723 
3724  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3725  PatternRewriter &rewriter) const override {
3726  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3727  // ConstantOp.
3728  Value sourceVector = extractStridedSliceOp.getVector();
3729  Attribute vectorCst;
3730  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3731  return failure();
3732 
3733  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3734  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3735  if (!dense || dense.isSplat())
3736  return failure();
3737 
3738  // TODO: Handle non-unit strides when they become available.
3739  if (extractStridedSliceOp.hasNonUnitStrides())
3740  return failure();
3741 
3742  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3743  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3744  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3745 
3746  VectorType sliceVecTy = extractStridedSliceOp.getType();
3747  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3748  int64_t sliceRank = sliceVecTy.getRank();
3749 
3750  // Expand offsets and sizes to match the vector rank.
3751  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3752  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3753 
3754  SmallVector<int64_t, 4> sizes(sourceShape);
3755  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3756 
3757  // Calculate the slice elements by enumerating all slice positions and
3758  // linearizing them. The enumeration order is lexicographic which yields a
3759  // sequence of monotonically increasing linearized position indices.
3760  auto denseValuesBegin = dense.value_begin<Attribute>();
3761  SmallVector<Attribute> sliceValues;
3762  sliceValues.reserve(sliceVecTy.getNumElements());
3763  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3764  do {
3765  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3766  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3767  "Invalid index");
3768  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3769  } while (
3770  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3771 
3772  assert(static_cast<int64_t>(sliceValues.size()) ==
3773  sliceVecTy.getNumElements() &&
3774  "Invalid number of slice elements");
3775  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3776  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3777  newAttr);
3778  return success();
3779  }
3780 };
3781 
3782 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3783 // BroadcastOp(ExtractStridedSliceOp).
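// Illustrative example (hypothetical IR): when the broadcast's source already
// matches the inner dims of the slice,
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %r = vector.extract_strided_slice %b {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]}
//        : vector<4x8xf32> to vector<2x8xf32>
// can be rewritten as
//   %r = vector.broadcast %src : vector<8xf32> to vector<2x8xf32>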
3784 class StridedSliceBroadcast final
3785  : public OpRewritePattern<ExtractStridedSliceOp> {
3786 public:
3787  using OpRewritePattern::OpRewritePattern;
3788 
3789  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3790  PatternRewriter &rewriter) const override {
3791  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3792  if (!broadcast)
3793  return failure();
3794  auto srcVecType =
3795  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3796  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3797  auto dstVecType = llvm::cast<VectorType>(op.getType());
3798  unsigned dstRank = dstVecType.getRank();
3799  unsigned rankDiff = dstRank - srcRank;
3800  // Check if the innermost dimensions of the source of the broadcast are the
3801  // same as the destination of the extract. If this is the case, we can just
3802  // use a broadcast as the original dimensions are untouched.
3803  bool lowerDimMatch = true;
3804  for (unsigned i = 0; i < srcRank; i++) {
3805  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3806  lowerDimMatch = false;
3807  break;
3808  }
3809  }
3810  Value source = broadcast.getSource();
3811  // If the inner dimensions don't match, it means we need to extract from the
3812  // source of the original broadcast and then broadcast the extracted value.
3813  // We also need to handle degenerate cases where the source is effectively
3814  // just a single scalar.
3815  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3816  if (!lowerDimMatch && !isScalarSrc) {
3817  source = rewriter.create<ExtractStridedSliceOp>(
3818  op->getLoc(), source,
3819  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3820  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3821  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3822  }
3823  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3824  return success();
3825  }
3826 };
3827 
3828 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3829 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3830 public:
3831  using OpRewritePattern::OpRewritePattern;
3832 
3833  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3834  PatternRewriter &rewriter) const override {
3835  auto splat = op.getVector().getDefiningOp<SplatOp>();
3836  if (!splat)
3837  return failure();
3838  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3839  return success();
3840  }
3841 };
3842 
3843 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3844 /// slice is contiguous, into extract and shape_cast.
3845 ///
3846 /// Example:
3847 /// Before:
3848 /// %1 = vector.extract_strided_slice %arg0 {
3849 /// offsets = [0, 0, 0, 0, 0],
3850 /// sizes = [1, 1, 1, 1, 8],
3851 /// strides = [1, 1, 1, 1, 1]
3852 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3853 /// After:
3854 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3855 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3856 /// %1 = vector.shape_cast %0
3857 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3858 ///
3859 class ContiguousExtractStridedSliceToExtract final
3860  : public OpRewritePattern<ExtractStridedSliceOp> {
3861 public:
3862  using OpRewritePattern::OpRewritePattern;
3863 
3864  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3865  PatternRewriter &rewriter) const override {
3866  if (op.hasNonUnitStrides())
3867  return failure();
3868  Value source = op.getOperand();
3869  auto sourceType = cast<VectorType>(source.getType());
3870  if (sourceType.isScalable() || sourceType.getRank() == 0)
3871  return failure();
3872 
3873  // Compute the number of offsets to pass to ExtractOp::build. That is the
3874  // difference between the source rank and the desired slice rank. We walk
3875  // the dimensions from innermost out, and stop when the next slice dimension
3876  // is not full-size.
3877  SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3878  int numOffsets;
3879  for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3880  if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3881  break;
3882  }
3883 
3884  // If the created extract op would have no offsets, then this whole
3885  // extract_strided_slice is the identity and should have been handled by
3886  // other canonicalizations.
3887  if (numOffsets == 0)
3888  return failure();
3889 
3890  // If not even the inner-most dimension is full-size, this op can't be
3891  // rewritten as an ExtractOp.
3892  if (numOffsets == sourceType.getRank() &&
3893  static_cast<int>(sizes.size()) == sourceType.getRank())
3894  return failure();
3895 
3896  // The outer dimensions must have unit size.
3897  for (int i = 0; i < numOffsets; ++i) {
3898  if (sizes[i] != 1)
3899  return failure();
3900  }
3901 
3902  // Avoid generating slices that have leading unit dimensions. The shape_cast
3903  // op that we create below would otherwise go through inefficient generic
3904  // fallback patterns (ShapeCastOpRewritePattern).
3905  while (sizes[numOffsets] == 1 &&
3906  numOffsets < static_cast<int>(sizes.size()) - 1) {
3907  ++numOffsets;
3908  }
3909 
3910  SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
3911  auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
3912  Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
3913  extractOffsets);
3914  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
3915  return success();
3916  }
3917 };
3918 
3919 } // namespace
3920 
3921 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3922  RewritePatternSet &results, MLIRContext *context) {
3923  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3924  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3925  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3926  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3927  StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
3928  context);
3929 }
3930 
3931 //===----------------------------------------------------------------------===//
3932 // TransferReadOp
3933 //===----------------------------------------------------------------------===//
3934 
3935 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3936 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3937  VectorType vectorType, Value source,
3938  ValueRange indices, AffineMapAttr permutationMapAttr,
3939  /*optional*/ ArrayAttr inBoundsAttr) {
3940  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3941  Value padding = builder.create<arith::ConstantOp>(
3942  result.location, elemType, builder.getZeroAttr(elemType));
3943  build(builder, result, vectorType, source, indices, permutationMapAttr,
3944  padding, /*mask=*/Value(), inBoundsAttr);
3945 }
3946 
3947 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3948 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3949  VectorType vectorType, Value source,
3950  ValueRange indices, AffineMap permutationMap,
3951  std::optional<ArrayRef<bool>> inBounds) {
3952  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3953  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3954  ? builder.getBoolArrayAttr(inBounds.value())
3955  : builder.getBoolArrayAttr(
3956  SmallVector<bool>(vectorType.getRank(), false));
3957  build(builder, result, vectorType, source, indices, permutationMapAttr,
3958  inBoundsAttr);
3959 }
3960 
3961 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3962 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3963  VectorType vectorType, Value source,
3964  ValueRange indices, Value padding,
3965  std::optional<ArrayRef<bool>> inBounds) {
3966  AffineMap permutationMap = getTransferMinorIdentityMap(
3967  llvm::cast<ShapedType>(source.getType()), vectorType);
3968  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3969  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3970  ? builder.getBoolArrayAttr(inBounds.value())
3971  : builder.getBoolArrayAttr(
3972  SmallVector<bool>(vectorType.getRank(), false));
3973  build(builder, result, vectorType, source, indices, permutationMapAttr,
3974  padding,
3975  /*mask=*/Value(), inBoundsAttr);
3976 }
3977 
3978 /// 4. Builder that sets padding to zero and permutation map to
3979 /// 'getMinorIdentityMap'.
3980 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3981  VectorType vectorType, Value source,
3982  ValueRange indices,
3983  std::optional<ArrayRef<bool>> inBounds) {
3984  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3985  Value padding = builder.create<arith::ConstantOp>(
3986  result.location, elemType, builder.getZeroAttr(elemType));
3987  build(builder, result, vectorType, source, indices, padding, inBounds);
3988 }
3989 
3990 template <typename EmitFun>
3991 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3992  EmitFun emitOpError) {
3993  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3994  for (auto expr : permutationMap.getResults()) {
3995  auto dim = dyn_cast<AffineDimExpr>(expr);
3996  auto zero = dyn_cast<AffineConstantExpr>(expr);
3997  if (zero) {
3998  if (zero.getValue() != 0) {
3999  return emitOpError(
4000  "requires a projected permutation_map (at most one dim or the zero "
4001  "constant can appear in each result)");
4002  }
4003  continue;
4004  }
4005  if (!dim) {
4006  return emitOpError("requires a projected permutation_map (at most one "
4007  "dim or the zero constant can appear in each result)");
4008  }
4009  if (seen[dim.getPosition()]) {
4010  return emitOpError(
4011  "requires a permutation_map that is a permutation (found one dim "
4012  "used more than once)");
4013  }
4014  seen[dim.getPosition()] = true;
4015  }
4016  return success();
4017 }
4018 
4019 static LogicalResult
4020 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4021  VectorType vectorType, VectorType maskType,
4022  VectorType inferredMaskType, AffineMap permutationMap,
4023  ArrayAttr inBounds) {
4024  if (op->hasAttr("masked")) {
4025  return op->emitOpError("masked attribute has been removed. "
4026  "Use in_bounds instead.");
4027  }
4028 
4029  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4030  return op->emitOpError(
4031  "requires source to be a memref or ranked tensor type");
4032 
4033  auto elementType = shapedType.getElementType();
4034  DataLayout dataLayout = DataLayout::closest(op);
4035  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4036  // Memref or tensor has vector element type.
4037  unsigned sourceVecSize =
4038  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4039  vectorElementType.getShape().back();
4040  unsigned resultVecSize =
4041  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4042  vectorType.getShape().back();
4043  if (resultVecSize % sourceVecSize != 0)
4044  return op->emitOpError(
4045  "requires the bitwidth of the minor 1-D vector to be an integral "
4046  "multiple of the bitwidth of the minor 1-D vector of the source");
4047 
4048  unsigned sourceVecEltRank = vectorElementType.getRank();
4049  unsigned resultVecRank = vectorType.getRank();
4050  if (sourceVecEltRank > resultVecRank)
4051  return op->emitOpError(
4052  "requires source vector element and vector result ranks to match.");
4053  unsigned rankOffset = resultVecRank - sourceVecEltRank;
4054  // Check that permutation map results match 'rankOffset' of vector type.
4055  if (permutationMap.getNumResults() != rankOffset)
4056  return op->emitOpError("requires a permutation_map with result dims of "
4057  "the same rank as the vector type");
4058 
4059  if (maskType)
4060  return op->emitOpError("does not support masks with vector element type");
4061  } else {
4062  // Memref or tensor has scalar element type.
4063  unsigned minorSize =
4064  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4065  unsigned resultVecSize =
4066  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4067  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4068  return op->emitOpError(
4069  "requires the bitwidth of the minor 1-D vector to be an integral "
4070  "multiple of the bitwidth of the source element type");
4071 
4072  // Check that permutation map results match rank of vector type.
4073  if (permutationMap.getNumResults() != vectorType.getRank())
4074  return op->emitOpError("requires a permutation_map with result dims of "
4075  "the same rank as the vector type");
4076  }
4077 
4078  if (permutationMap.getNumSymbols() != 0)
4079  return op->emitOpError("requires permutation_map without symbols");
4080 
4081  if (permutationMap.getNumInputs() != shapedType.getRank())
4082  return op->emitOpError("requires a permutation_map with input dims of the "
4083  "same rank as the source type");
4084 
4085  if (maskType && maskType != inferredMaskType)
4086  return op->emitOpError("inferred mask type (")
4087  << inferredMaskType << ") and mask operand type (" << maskType
4088  << ") don't match";
4089 
4090  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4091  return op->emitOpError("expects the in_bounds attr of same rank "
4092  "as permutation_map results: ")
4093  << AffineMapAttr::get(permutationMap)
4094  << " vs inBounds of size: " << inBounds.size();
4095 
4096  return success();
4097 }
4098 
4099 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4100  SmallVector<StringRef, 3> elidedAttrs;
4101  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4102  if (op.getPermutationMap().isMinorIdentity())
4103  elidedAttrs.push_back(op.getPermutationMapAttrName());
4104  // Elide in_bounds attribute if all dims are out-of-bounds.
4105  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4106  elidedAttrs.push_back(op.getInBoundsAttrName());
4107  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4108 }
4109 
4110 void TransferReadOp::print(OpAsmPrinter &p) {
4111  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4112  if (getMask())
4113  p << ", " << getMask();
4114  printTransferAttrs(p, *this);
4115  p << " : " << getShapedType() << ", " << getVectorType();
4116 }
4117 
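// Note (illustrative, assumed example): the inferred mask type follows the
// permuted (memory) order of the transfer. For instance, for a vector type
// vector<2x4xf32> and a transposing permutation map (d0, d1) -> (d1, d0), the
// inferred mask type would be vector<4x2xi1>.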
4118 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4119  AffineMap permMap) {
4120  auto i1Type = IntegerType::get(permMap.getContext(), 1);
4121  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4122  assert(invPermMap && "Inverted permutation map couldn't be computed");
4123  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4124 
4125  SmallVector<bool> scalableDims =
4126  applyPermutationMap(invPermMap, vecType.getScalableDims());
4127 
4128  return VectorType::get(maskShape, i1Type, scalableDims);
4129 }
4130 
4131 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4132  auto &builder = parser.getBuilder();
4133  SMLoc typesLoc;
4134  OpAsmParser::UnresolvedOperand sourceInfo;
4135  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4136  OpAsmParser::UnresolvedOperand paddingInfo;
4137  SmallVector<Type, 2> types;
4138  OpAsmParser::UnresolvedOperand maskInfo;
4139  // Parsing with support for paddingValue.
4140  if (parser.parseOperand(sourceInfo) ||
4141  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4142  parser.parseComma() || parser.parseOperand(paddingInfo))
4143  return failure();
4144  ParseResult hasMask = parser.parseOptionalComma();
4145  if (hasMask.succeeded()) {
4146  if (parser.parseOperand(maskInfo))
4147  return failure();
4148  }
4149  if (parser.parseOptionalAttrDict(result.attributes) ||
4150  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4151  return failure();
4152  if (types.size() != 2)
4153  return parser.emitError(typesLoc, "requires two types");
4154  auto indexType = builder.getIndexType();
4155  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4156  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4157  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4158  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4159  if (!vectorType)
4160  return parser.emitError(typesLoc, "requires vector type");
4161  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4162  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4163  AffineMap permMap;
4164  if (!permMapAttr) {
4165  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4166  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4167  } else {
4168  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4169  }
4170  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4171  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4172  if (!inBoundsAttr) {
4173  result.addAttribute(inBoundsAttrName,
4174  builder.getBoolArrayAttr(
4175  SmallVector<bool>(permMap.getNumResults(), false)));
4176  }
4177  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4178  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4179  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4180  result.operands))
4181  return failure();
4182  if (hasMask.succeeded()) {
4183  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4184  return parser.emitError(
4185  maskInfo.location, "does not support masks with vector element type");
4186  if (vectorType.getRank() != permMap.getNumResults()) {
4187  return parser.emitError(typesLoc,
4188  "expected the same rank for the vector and the "
4189  "results of the permutation map");
4190  }
4191  // Instead of adding the mask type as an op type, compute it based on the
4192  // vector type and the permutation map (to keep the type signature small).
4193  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4194  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4195  return failure();
4196  }
4197  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4198  builder.getDenseI32ArrayAttr(
4199  {1, static_cast<int32_t>(indexInfo.size()), 1,
4200  static_cast<int32_t>(hasMask.succeeded())}));
4201  return parser.addTypeToList(vectorType, result.types);
4202 }
4203 
4204 LogicalResult TransferReadOp::verify() {
4205  // Consistency of elemental types in source and vector.
4206  ShapedType shapedType = getShapedType();
4207  VectorType vectorType = getVectorType();
4208  VectorType maskType = getMaskType();
4209  auto paddingType = getPadding().getType();
4210  auto permutationMap = getPermutationMap();
4211  VectorType inferredMaskType =
4212  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4213  : VectorType();
4214  auto sourceElementType = shapedType.getElementType();
4215 
4216  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4217  return emitOpError("requires ") << shapedType.getRank() << " indices";
4218 
4219  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4220  shapedType, vectorType, maskType,
4221  inferredMaskType, permutationMap, getInBounds())))
4222  return failure();
4223 
4224  if (auto sourceVectorElementType =
4225  llvm::dyn_cast<VectorType>(sourceElementType)) {
4226  // Source has vector element type.
4227  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4228  if (sourceVectorElementType != paddingType)
4229  return emitOpError(
4230  "requires source element type and padding type to match.");
4231 
4232  } else {
4233  // Check that 'paddingType' is valid to store in a vector type.
4234  if (!VectorType::isValidElementType(paddingType))
4235  return emitOpError("requires valid padding vector elemental type");
4236 
4237  // Check that padding type and vector element types match.
4238  if (paddingType != sourceElementType)
4239  return emitOpError(
4240  "requires formal padding and source of the same elemental type");
4241  }
4242 
4243  return verifyPermutationMap(permutationMap,
4244  [&](Twine t) { return emitOpError(t); });
4245 }
4246 
4247 // MaskableOpInterface methods.
4248 
4249 /// Returns the mask type expected by this operation. Mostly used for
4250 /// verification purposes. It requires the operation to be vectorized.
4251 Type TransferReadOp::getExpectedMaskType() {
4252  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4253 }
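// Illustrative example (hypothetical types, added for exposition): with a
// minor identity permutation map, a transfer_read producing vector<2x4xf32>
// expects a mask of type vector<2x4xi1>; broadcast dimensions in the map are
// dropped from the inferred mask shape.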
4254 
4255 template <typename TransferOp>
4256 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4257  // TODO: support more aggressive createOrFold on:
4258  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4259  if (op.getShapedType().isDynamicDim(indicesIdx))
4260  return false;
4261  Value index = op.getIndices()[indicesIdx];
4262  std::optional<int64_t> cstOp = getConstantIntValue(index);
4263  if (!cstOp.has_value())
4264  return false;
4265 
4266  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4267  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4268 
4269  return cstOp.value() + vectorSize <= sourceSize;
4270 }
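// Worked example (hypothetical values, added for exposition): for a source
// dimension of static size 8, a constant index of 2 and a vector dimension of
// size 4, the check 2 + 4 <= 8 holds and the dimension can be marked
// in-bounds; with index 6, 6 + 4 > 8 and it cannot.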
4271 
4272 template <typename TransferOp>
4273 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4274  // TODO: support 0-d corner case.
4275  // TODO: Be less conservative.
4276  if (op.getTransferRank() == 0)
4277  return failure();
4278  AffineMap permutationMap = op.getPermutationMap();
4279  bool changed = false;
4280  SmallVector<bool, 4> newInBounds;
4281  newInBounds.reserve(op.getTransferRank());
4282  // Indices of non-broadcast dims - used when analyzing broadcast dims.
4283  SmallVector<unsigned> nonBcastDims;
4284 
4285  // 1. Process non-broadcast dims
4286  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4287  // 1.1. Already marked as in-bounds, nothing to see here.
4288  if (op.isDimInBounds(i)) {
4289  newInBounds.push_back(true);
4290  continue;
4291  }
4292  // 1.2. Currently out-of-bounds, check whether we can statically determine
4293  // it is inBounds.
4294  bool inBounds = false;
4295  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4296  if (dimExpr) {
4297  inBounds = isInBounds(op, /*resultIdx=*/i,
4298  /*indicesIdx=*/dimExpr.getPosition());
4299  nonBcastDims.push_back(i);
4300  }
4301 
4302  newInBounds.push_back(inBounds);
4303  // We commit the pattern if it is "more inbounds".
4304  changed |= inBounds;
4305  }
4306 
4307  // 2. Handle broadcast dims
4308  // If all non-broadcast dims are "in bounds", then all bcast dims should be
4309  // "in bounds" as well.
4310  bool allNonBcastDimsInBounds = llvm::all_of(
4311  nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4312  if (allNonBcastDimsInBounds) {
4313  for (size_t idx : permutationMap.getBroadcastDims()) {
4314  changed |= !newInBounds[idx];
4315  newInBounds[idx] = true;
4316  }
4317  }
4318 
4319  if (!changed)
4320  return failure();
4321  // OpBuilder is only used as a helper to build a BoolArrayAttr.
4322  OpBuilder b(op.getContext());
4323  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4324  return success();
4325 }
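// Illustrative example (hypothetical IR, added for exposition): given
//   %0 = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<4xf32>
// the constant index 0 and the static dim size 8 satisfy 0 + 4 <= 8, so this
// fold rewrites the op to carry in_bounds = [true].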
4326 
4327 template <typename TransferOp>
4328 static LogicalResult foldTransferFullMask(TransferOp op) {
4329  auto mask = op.getMask();
4330  if (!mask)
4331  return failure();
4332 
4333  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4334  return failure();
4335 
4336  op.getMaskMutable().clear();
4337  return success();
4338 }
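// Illustrative example (hypothetical IR, added for exposition): a mask defined
// as
//   %m = vector.constant_mask [4] : vector<4xi1>
// is classified as all-true, so a masked transfer_read/transfer_write using %m
// is rewritten without the mask operand.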
4339 
4340 /// ```
4341 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4342 /// : vector<1x4xf32>, tensor<4x4xf32>
4343 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4344 /// : tensor<4x4xf32>, vector<1x4xf32>
4345 /// ```
4346 /// -> Folds into
4347 /// ```
4348 /// %v0
4349 /// ```
4350 static Value foldRAW(TransferReadOp readOp) {
4351  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4352  return {};
4353  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4354  while (defWrite) {
4355  if (checkSameValueRAW(defWrite, readOp))
4356  return defWrite.getVector();
4357  if (!isDisjointTransferIndices(
4358  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4359  cast<VectorTransferOpInterface>(readOp.getOperation())))
4360  break;
4361  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4362  }
4363  return {};
4364 }
4365 
4366 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4367  if (Value vec = foldRAW(*this))
4368  return vec;
4369  /// transfer_read(memrefcast) -> transfer_read (handled by foldMemRefCast below).
4370  if (succeeded(foldTransferInBoundsAttribute(*this)))
4371  return getResult();
4372  if (succeeded(foldTransferFullMask(*this)))
4373  return getResult();
4374  if (succeeded(memref::foldMemRefCast(*this)))
4375  return getResult();
4376  if (succeeded(tensor::foldTensorCast(*this)))
4377  return getResult();
4378  return OpFoldResult();
4379 }
4380 
4381 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4382  return llvm::to_vector<4>(getVectorType().getShape());
4383 }
4384 
4385 void TransferReadOp::getEffects(
4386  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4387  &effects) {
4388  if (llvm::isa<MemRefType>(getShapedType()))
4389  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4390  SideEffects::DefaultResource::get());
4391 }
4392 
4393 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4394  if (hasPureTensorSemantics())
4395  return Speculation::Speculatable;
4396  return Speculation::NotSpeculatable;
4397 }
4398 
4399 namespace {
4400 /// Store to load forwarding for transfer operations with permutation maps.
4401 /// Even if the permutation maps are different we can still propagate the store
4402 /// into the load if the size of the dimensions read and written match. Then we
4403 /// can replace the transfer_read + transfer_write by vector.broadcast and
4404 /// vector.transpose.
4405 /// Example:
4406 /// ```
4407 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4408 /// {in_bounds = [true, true],
4409 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4410 /// vector<4x1xf32>, tensor<4x4x4xf32>
4411 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4412 /// {in_bounds = [true, true, true, true],
4413 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4414 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4415 /// ```
4416 /// To:
4417 /// ```
4418 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4419 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4420 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4421 /// ```
4422 struct TransferReadAfterWriteToBroadcast
4423  : public OpRewritePattern<TransferReadOp> {
4424  using OpRewritePattern::OpRewritePattern;
4425 
4426  LogicalResult matchAndRewrite(TransferReadOp readOp,
4427  PatternRewriter &rewriter) const override {
4428  if (readOp.hasOutOfBoundsDim() ||
4429  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4430  return failure();
4431  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4432  if (!defWrite)
4433  return failure();
4434  // TODO: If the written transfer chunk is a superset of the read transfer
4435  // chunk we could do an extract_strided_slice.
4436  if (readOp.getTransferChunkAccessed() !=
4437  defWrite.getTransferChunkAccessed())
4438  return failure();
4439  // TODO: Support cases where a dim is explicitly written but implicitly
4440  // read (i.e., a unit dim that is rank reduced).
4441  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4442  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4443  return failure();
4444  if (readOp.getIndices() != defWrite.getIndices() ||
4445  readOp.getMask() != defWrite.getMask())
4446  return failure();
4447  Value vec = defWrite.getVector();
4448  // TODO: loop through the chain of transfer_write if we can prove that they
4449  // don't overlap with the transfer_read. This requires improving
4450  // `isDisjointTransferIndices` helper.
4451  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4452  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4453  AffineMap map = readMap.compose(writeMap);
4454  if (map.getNumResults() == 0)
4455  return failure();
4456  // Calculate the permutation to apply to go from the vector stored to the
4457  // vector read.
4458  SmallVector<unsigned> permutation;
4459  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4460  return failure();
4461 
4462  Location loc = readOp.getLoc();
4463  // Calculate the broadcast shape by applying the reverse permutation to the
4464  // final shape we want.
4465  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4466  SmallVector<int64_t> broadcastShape(destShape.size());
4467  SmallVector<bool> broadcastScalableFlags(destShape.size());
4468  for (const auto &pos : llvm::enumerate(permutation)) {
4469  broadcastShape[pos.value()] = destShape[pos.index()];
4470  broadcastScalableFlags[pos.value()] =
4471  readOp.getVectorType().getScalableDims()[pos.index()];
4472  }
4473  VectorType broadcastedType = VectorType::get(
4474  broadcastShape, defWrite.getVectorType().getElementType(),
4475  broadcastScalableFlags);
4476  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4477  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4478  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4479  transposePerm);
4480  return success();
4481  }
4482 };
4483 } // namespace
4484 
4485 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4486  MLIRContext *context) {
4487  results.add<TransferReadAfterWriteToBroadcast>(context);
4488 }
4489 
4490 //===----------------------------------------------------------------------===//
4491 // TransferWriteOp
4492 //===----------------------------------------------------------------------===//
4493 
4494 /// 1. Builder with type inference.
4495 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4496  Value vector, Value dest, ValueRange indices,
4497  AffineMapAttr permutationMapAttr,
4498  /*optional*/ Value mask,
4499  /*optional*/ ArrayAttr inBoundsAttr) {
4500  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4501  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4502  mask, inBoundsAttr);
4503 }
4504 
4505 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4506 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4507  Value vector, Value dest, ValueRange indices,
4508  AffineMapAttr permutationMapAttr,
4509  /*optional*/ ArrayAttr inBoundsAttr) {
4510  build(builder, result, vector, dest, indices, permutationMapAttr,
4511  /*mask=*/Value(), inBoundsAttr);
4512 }
4513 
4514 /// 3. Builder with type inference that sets an empty mask (variant without
4515 /// attrs)
4516 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4517  Value vector, Value dest, ValueRange indices,
4518  AffineMap permutationMap,
4519  std::optional<ArrayRef<bool>> inBounds) {
4520  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4521  auto inBoundsAttr =
4522  (inBounds && !inBounds.value().empty())
4523  ? builder.getBoolArrayAttr(inBounds.value())
4524  : builder.getBoolArrayAttr(SmallVector<bool>(
4525  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4526  build(builder, result, vector, dest, indices, permutationMapAttr,
4527  /*mask=*/Value(), inBoundsAttr);
4528 }
4529 
4530 /// 4. Builder with type inference that sets an empty mask and sets permutation
4531 /// map to 'getMinorIdentityMap'.
4532 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4533  Value vector, Value dest, ValueRange indices,
4534  std::optional<ArrayRef<bool>> inBounds) {
4535  auto vectorType = llvm::cast<VectorType>(vector.getType());
4536  AffineMap permutationMap = getTransferMinorIdentityMap(
4537  llvm::cast<ShapedType>(dest.getType()), vectorType);
4538  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4539 }
4540 
4541 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4542  OperationState &result) {
4543  auto &builder = parser.getBuilder();
4544  SMLoc typesLoc;
4545  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4546  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4547  SmallVector<Type, 2> types;
4548  OpAsmParser::UnresolvedOperand maskInfo;
4549  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4550  parser.parseOperand(sourceInfo) ||
4551  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4552  return failure();
4553  ParseResult hasMask = parser.parseOptionalComma();
4554  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4555  return failure();
4556  if (parser.parseOptionalAttrDict(result.attributes) ||
4557  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4558  return failure();
4559  if (types.size() != 2)
4560  return parser.emitError(typesLoc, "requires two types");
4561  auto indexType = builder.getIndexType();
4562  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4563  if (!vectorType)
4564  return parser.emitError(typesLoc, "requires vector type");
4565  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4566  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4567  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4568  auto permMapAttrName =
4569  TransferWriteOp::getPermutationMapAttrName(result.name);
4570  auto permMapAttr = result.attributes.get(permMapAttrName);
4571  AffineMap permMap;
4572  if (!permMapAttr) {
4573  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4574  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4575  } else {
4576  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4577  }
4578  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4579  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4580  if (!inBoundsAttr) {
4581  result.addAttribute(inBoundsAttrName,
4582  builder.getBoolArrayAttr(
4583  SmallVector<bool>(permMap.getNumResults(), false)));
4584  }
4585  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4586  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4587  parser.resolveOperands(indexInfo, indexType, result.operands))
4588  return failure();
4589  if (hasMask.succeeded()) {
4590  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4591  return parser.emitError(
4592  maskInfo.location, "does not support masks with vector element type");
4593  if (vectorType.getRank() != permMap.getNumResults()) {
4594  return parser.emitError(typesLoc,
4595  "expected the same rank for the vector and the "
4596  "results of the permutation map");
4597  }
4598  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4599  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4600  return failure();
4601  }
4602  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4603  builder.getDenseI32ArrayAttr(
4604  {1, 1, static_cast<int32_t>(indexInfo.size()),
4605  static_cast<int32_t>(hasMask.succeeded())}));
4606  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4607  parser.addTypeToList(shapedType, result.types));
4608 }
4609 
4610 void TransferWriteOp::print(OpAsmPrinter &p) {
4611  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4612  if (getMask())
4613  p << ", " << getMask();
4614  printTransferAttrs(p, *this);
4615  p << " : " << getVectorType() << ", " << getShapedType();
4616 }
4617 
4618 LogicalResult TransferWriteOp::verify() {
4619  // Consistency of elemental types in shape and vector.
4620  ShapedType shapedType = getShapedType();
4621  VectorType vectorType = getVectorType();
4622  VectorType maskType = getMaskType();
4623  auto permutationMap = getPermutationMap();
4624  VectorType inferredMaskType =
4625  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4626  : VectorType();
4627 
4628  if (llvm::size(getIndices()) != shapedType.getRank())
4629  return emitOpError("requires ") << shapedType.getRank() << " indices";
4630 
4631  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4632  // as the semantics is unclear. This can be revisited later if necessary.
4633  if (hasBroadcastDim())
4634  return emitOpError("should not have broadcast dimensions");
4635 
4636  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4637  shapedType, vectorType, maskType,
4638  inferredMaskType, permutationMap, getInBounds())))
4639  return failure();
4640 
4641  return verifyPermutationMap(permutationMap,
4642  [&](Twine t) { return emitOpError(t); });
4643 }
4644 
4645 // MaskableOpInterface methods.
4646 
4647 /// Returns the mask type expected by this operation. Mostly used for
4648 /// verification purposes.
4649 Type TransferWriteOp::getExpectedMaskType() {
4650  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4651 }
4652 
4653 /// Fold:
4654 /// ```
4655 /// %t1 = ...
4656 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4657 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4658 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4659 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4660 /// ```
4661 ///
4662 /// into:
4663 ///
4664 /// ```
4665 /// %t0
4666 /// ```
4667 ///
4668 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4669 /// block argument or has side effects.
4670 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4671  ArrayRef<Attribute>,
4672  SmallVectorImpl<OpFoldResult> &results) {
4673  // TODO: support 0-d corner case.
4674  if (write.getTransferRank() == 0)
4675  return failure();
4676  auto rankedTensorType =
4677  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4678  // If not operating on tensors, bail.
4679  if (!rankedTensorType)
4680  return failure();
4681  // If no read, bail.
4682  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4683  if (!read)
4684  return failure();
4685  // TODO: support 0-d corner case.
4686  if (read.getTransferRank() == 0)
4687  return failure();
4688  // For now, only accept minor identity maps. Future: also accept maps whose composition is a minor identity.
4689  if (!read.getPermutationMap().isMinorIdentity() ||
4690  !write.getPermutationMap().isMinorIdentity())
4691  return failure();
4692  // Bail on mismatching ranks.
4693  if (read.getTransferRank() != write.getTransferRank())
4694  return failure();
4695  // Bail on potential out-of-bounds accesses.
4696  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4697  return failure();
4698  // Tensor types must be the same.
4699  if (read.getSource().getType() != rankedTensorType)
4700  return failure();
4701  // Vector types must be the same.
4702  if (read.getVectorType() != write.getVectorType())
4703  return failure();
4704  // Vector and Tensor shapes must match.
4705  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4706  return failure();
4707  // Bail if any read or write index is nonzero.
4708  auto isNotConstantZero = [](Value v) {
4709  auto cstOp = getConstantIntValue(v);
4710  return !cstOp.has_value() || cstOp.value() != 0;
4711  };
4712  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4713  llvm::any_of(write.getIndices(), isNotConstantZero))
4714  return failure();
4715  // Success.
4716  results.push_back(read.getSource());
4717  return success();
4718 }
4719 
4720 static bool checkSameValueWAR(vector::TransferReadOp read,
4721  vector::TransferWriteOp write) {
4722  return read.getSource() == write.getSource() &&
4723  read.getIndices() == write.getIndices() &&
4724  read.getPermutationMap() == write.getPermutationMap() &&
4725  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4726  !write.getMask();
4727 }
4728 /// Fold a transfer_write that writes back the value it just read (write-after-read):
4729 /// ```
4730 /// %t0 = ...
4731 /// %v = vector.transfer_read %t0[%c0...] :
4732 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4733 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4734 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4735 /// ```
4736 ///
4737 /// into:
4738 ///
4739 /// ```
4740 /// %t0
4741 /// ```
4742 static LogicalResult foldWAR(TransferWriteOp write,
4743  SmallVectorImpl<OpFoldResult> &results) {
4744  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4745  return failure();
4746  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4747  if (!read)
4748  return failure();
4749 
4750  if (!checkSameValueWAR(read, write))
4751  return failure();
4752  results.push_back(read.getSource());
4753  return success();
4754 }
4755 
4756 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4757  SmallVectorImpl<OpFoldResult> &results) {
4758  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4759  return success();
4760  if (succeeded(foldWAR(*this, results)))
4761  return success();
4762  if (succeeded(foldTransferInBoundsAttribute(*this)))
4763  return success();
4764  if (succeeded(foldTransferFullMask(*this)))
4765  return success();
4766  return memref::foldMemRefCast(*this);
4767 }
4768 
4769 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4770  return llvm::to_vector<4>(getVectorType().getShape());
4771 }
4772 
4773 void TransferWriteOp::getEffects(
4774  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4775  &effects) {
4776  if (llvm::isa<MemRefType>(getShapedType()))
4777  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4778  SideEffects::DefaultResource::get());
4779 }
4780 
4781 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4782  if (hasPureTensorSemantics())
4783  return Speculation::Speculatable;
4784  return Speculation::NotSpeculatable;
4785 }
4786 
4787 namespace {
4788 /// Remove a dead transfer write from the SSA chain so that it can be
4789 /// eliminated by DCE:
4790 /// ```
4791 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4792 /// : vector<1x4xf32>, tensor<4x4xf32>
4793 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4794 /// : vector<1x4xf32>, tensor<4x4xf32>
4795 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4796 /// : vector<1x4xf32>, tensor<4x4xf32>
4797 /// ```
4798 ///
4799 /// into:
4800 ///
4801 /// ```
4802 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4803 /// : vector<1x4xf32>, tensor<4x4xf32>
4804 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4805 /// : vector<1x4xf32>, tensor<4x4xf32>
4806 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4807 /// : vector<1x4xf32>, tensor<4x4xf32>
4808 /// ```
4809 ///
4810 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4811 /// any other uses.
4812 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4813 public:
4814  using OpRewritePattern::OpRewritePattern;
4815  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4816  PatternRewriter &rewriter) const override {
4817  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4818  return failure();
4819  vector::TransferWriteOp writeToModify = writeOp;
4820 
4821  auto defWrite =
4822  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4823  while (defWrite) {
4824  if (checkSameValueWAW(writeOp, defWrite)) {
4825  rewriter.modifyOpInPlace(writeToModify, [&]() {
4826  writeToModify.getSourceMutable().assign(defWrite.getSource());
4827  });
4828  return success();
4829  }
4830  if (!isDisjointTransferIndices(
4831  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4832  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4833  break;
4834  // If the previous write op doesn't have any other use we can safely look
4835  // at the previous store to see if it can be removed.
4836  if (!defWrite->hasOneUse())
4837  break;
4838  writeToModify = defWrite;
4839  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4840  }
4841  return failure();
4842  }
4843 };
4844 
4845 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4846 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4847 /// overwritten and inserted into another tensor. After this rewrite, the
4848 /// operations bufferize in-place since all of them work on the same slice.
4849 ///
4850 /// For example:
4851 /// ```mlir
4852 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4853 /// : vector<8x16xf32>, tensor<8x16xf32>
4854 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4855 /// : tensor<8x16xf32> to tensor<?x?xf32>
4856 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4857 /// : tensor<?x?xf32> into tensor<27x37xf32>
4858 /// ```
4859 /// folds to
4860 /// ```mlir
4861 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4862 /// : tensor<27x37xf32> to tensor<?x?xf32>
4863 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4864 /// : vector<8x16xf32>, tensor<?x?xf32>
4865 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4866 /// : tensor<?x?xf32> into tensor<27x37xf32>
4867 /// ```
4868 struct SwapExtractSliceOfTransferWrite
4869  : public OpRewritePattern<tensor::InsertSliceOp> {
4870 public:
4871  using OpRewritePattern::OpRewritePattern;
4872 
4873  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4874  PatternRewriter &rewriter) const override {
4875  if (!insertOp.hasUnitStride())
4876  return failure();
4877  auto extractOp =
4878  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4879  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4880  return failure();
4881  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4882  if (!transferOp || !transferOp->hasOneUse())
4883  return failure();
4884 
4885  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4886  // rank-reducing.
4887  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4888  return rewriter.notifyMatchFailure(insertOp,
4889  "use-def chain is rank-reducing");
4890  }
4891 
4892  // Fail if tensor::ExtractSliceOp has non-zero offset.
4893  if (!extractOp.hasZeroOffset()) {
4894  return rewriter.notifyMatchFailure(insertOp,
4895  "ExtractSliceOp has non-zero offset");
4896  }
4897 
4898  // Fail if vector::TransferWriteOp has non-zero offset.
4899  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4900  return getConstantIntValue(value) == static_cast<int64_t>(0);
4901  })) {
4902  return rewriter.notifyMatchFailure(insertOp,
4903  "TranferWriteOp has non-zero offset");
4904  }
4905 
4906  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4907  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4908  return rewriter.notifyMatchFailure(
4909  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4910  }
4911 
4912  for (auto [insertSize, extractSize] :
4913  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4914  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4915  return rewriter.notifyMatchFailure(
4916  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4917  }
4918  }
4919 
4920  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4921  assert(transferOp.getVectorType().hasStaticShape() &&
4922  "expected vector to have a static shape");
4923  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4924  SmallVector<int64_t> resultShape = applyPermutationMap(
4925  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4926  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4927  return rewriter.notifyMatchFailure(
4928  insertOp, "TransferWriteOp may not write the full tensor.");
4929  }
4930 
4931  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4932  // Set all in_bounds to false and let the folder infer them.
4933  SmallVector<bool> newInBounds(vectorShape.size(), false);
4934  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4935  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4936  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4937  insertOp.getMixedStrides());
4938  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4939  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4940  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4941  rewriter.getBoolArrayAttr(newInBounds));
4942  rewriter.modifyOpInPlace(insertOp, [&]() {
4943  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4944  });
4945  return success();
4946  }
4947 };
4948 
4949 } // namespace
4950 
4951 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4952  MLIRContext *context) {
4953  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4954 }
4955 
4956 //===----------------------------------------------------------------------===//
4957 // LoadOp
4958 //===----------------------------------------------------------------------===//
4959 
4960 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4961  VectorType vecTy,
4962  MemRefType memRefTy) {
4963  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
4964  // need any stride limitations.
4965  if (!vecTy.isScalable() &&
4966  (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
4967  return success();
4968 
4969  if (!isLastMemrefDimUnitStride(memRefTy))
4970  return op->emitOpError("most minor memref dim must have unit stride");
4971  return success();
4972 }
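// Illustrative example (hypothetical types, added for exposition):
// memref<4x8xf32> with the default identity layout passes this check, whereas
// memref<4x8xf32, strided<[16, 2]>> is rejected for a multi-element vector
// because its innermost stride is 2 rather than 1.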
4973 
4974 LogicalResult vector::LoadOp::verify() {
4975  VectorType resVecTy = getVectorType();
4976  MemRefType memRefTy = getMemRefType();
4977 
4978  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
4979  return failure();
4980 
4981  // Checks for vector memrefs.
4982  Type memElemTy = memRefTy.getElementType();
4983  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4984  if (memVecTy != resVecTy)
4985  return emitOpError("base memref and result vector types should match");
4986  memElemTy = memVecTy.getElementType();
4987  }
4988 
4989  if (resVecTy.getElementType() != memElemTy)
4990  return emitOpError("base and result element types should match");
4991  if (llvm::size(getIndices()) != memRefTy.getRank())
4992  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4993  return success();
4994 }
4995 
4996 OpFoldResult LoadOp::fold(FoldAdaptor) {
4997  if (succeeded(memref::foldMemRefCast(*this)))
4998  return getResult();
4999  return OpFoldResult();
5000 }
5001 
5002 //===----------------------------------------------------------------------===//
5003 // StoreOp
5004 //===----------------------------------------------------------------------===//
5005 
5006 LogicalResult vector::StoreOp::verify() {
5007  VectorType valueVecTy = getVectorType();
5008  MemRefType memRefTy = getMemRefType();
5009 
5010  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5011  return failure();
5012 
5013  // Checks for vector memrefs.
5014  Type memElemTy = memRefTy.getElementType();
5015  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5016  if (memVecTy != valueVecTy)
5017  return emitOpError(
5018  "base memref and valueToStore vector types should match");
5019  memElemTy = memVecTy.getElementType();
5020  }
5021 
5022  if (valueVecTy.getElementType() != memElemTy)
5023  return emitOpError("base and valueToStore element type should match");
5024  if (llvm::size(getIndices()) != memRefTy.getRank())
5025  return emitOpError("requires ") << memRefTy.getRank() << " indices";
5026  return success();
5027 }
5028 
5029 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5030  SmallVectorImpl<OpFoldResult> &results) {
5031  return memref::foldMemRefCast(*this);
5032 }
5033 
5034 //===----------------------------------------------------------------------===//
5035 // MaskedLoadOp
5036 //===----------------------------------------------------------------------===//
5037 
5038 LogicalResult MaskedLoadOp::verify() {
5039  VectorType maskVType = getMaskVectorType();
5040  VectorType passVType = getPassThruVectorType();
5041  VectorType resVType = getVectorType();
5042  MemRefType memType = getMemRefType();
5043 
5044  if (resVType.getElementType() != memType.getElementType())
5045  return emitOpError("base and result element type should match");
5046  if (llvm::size(getIndices()) != memType.getRank())
5047  return emitOpError("requires ") << memType.getRank() << " indices";
5048  if (resVType.getShape() != maskVType.getShape())
5049  return emitOpError("expected result shape to match mask shape");
5050  if (resVType != passVType)
5051  return emitOpError("expected pass_thru of same type as result type");
5052  return success();
5053 }
5054 
5055 namespace {
5056 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5057 public:
5058  using OpRewritePattern::OpRewritePattern;
5059  LogicalResult matchAndRewrite(MaskedLoadOp load,
5060  PatternRewriter &rewriter) const override {
5061  switch (getMaskFormat(load.getMask())) {
5062  case MaskFormat::AllTrue:
5063  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5064  load, load.getType(), load.getBase(), load.getIndices());
5065  return success();
5066  case MaskFormat::AllFalse:
5067  rewriter.replaceOp(load, load.getPassThru());
5068  return success();
5069  case MaskFormat::Unknown:
5070  return failure();
5071  }
5072  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5073  }
5074 };
5075 } // namespace
5076 
5077 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5078  MLIRContext *context) {
5079  results.add<MaskedLoadFolder>(context);
5080 }
5081 
5082 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5083  if (succeeded(memref::foldMemRefCast(*this)))
5084  return getResult();
5085  return OpFoldResult();
5086 }
5087 
5088 //===----------------------------------------------------------------------===//
5089 // MaskedStoreOp
5090 //===----------------------------------------------------------------------===//
5091 
5092 LogicalResult MaskedStoreOp::verify() {
5093  VectorType maskVType = getMaskVectorType();
5094  VectorType valueVType = getVectorType();
5095  MemRefType memType = getMemRefType();
5096 
5097  if (valueVType.getElementType() != memType.getElementType())
5098  return emitOpError("base and valueToStore element type should match");
5099  if (llvm::size(getIndices()) != memType.getRank())
5100  return emitOpError("requires ") << memType.getRank() << " indices";
5101  if (valueVType.getShape() != maskVType.getShape())
5102  return emitOpError("expected valueToStore shape to match mask shape");
5103  return success();
5104 }
5105 
5106 namespace {
5107 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5108 public:
5109  using OpRewritePattern::OpRewritePattern;
5110  LogicalResult matchAndRewrite(MaskedStoreOp store,
5111  PatternRewriter &rewriter) const override {
5112  switch (getMaskFormat(store.getMask())) {
5113  case MaskFormat::AllTrue:
5114  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5115  store, store.getValueToStore(), store.getBase(), store.getIndices());
5116  return success();
5117  case MaskFormat::AllFalse:
5118  rewriter.eraseOp(store);
5119  return success();
5120  case MaskFormat::Unknown:
5121  return failure();
5122  }
5123  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5124  }
5125 };
5126 } // namespace
5127 
5128 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5129  MLIRContext *context) {
5130  results.add<MaskedStoreFolder>(context);
5131 }
5132 
5133 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5134  SmallVectorImpl<OpFoldResult> &results) {
5135  return memref::foldMemRefCast(*this);
5136 }
5137 
5138 //===----------------------------------------------------------------------===//
5139 // GatherOp
5140 //===----------------------------------------------------------------------===//
5141 
5142 LogicalResult GatherOp::verify() {
5143  VectorType indVType = getIndexVectorType();
5144  VectorType maskVType = getMaskVectorType();
5145  VectorType resVType = getVectorType();
5146  ShapedType baseType = getBaseType();
5147 
5148  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5149  return emitOpError("requires base to be a memref or ranked tensor type");
5150 
5151  if (resVType.getElementType() != baseType.getElementType())
5152  return emitOpError("base and result element type should match");
5153  if (llvm::size(getIndices()) != baseType.getRank())
5154  return emitOpError("requires ") << baseType.getRank() << " indices";
5155  if (resVType.getShape() != indVType.getShape())
5156  return emitOpError("expected result dim to match indices dim");
5157  if (resVType.getShape() != maskVType.getShape())
5158  return emitOpError("expected result dim to match mask dim");
5159  if (resVType != getPassThruVectorType())
5160  return emitOpError("expected pass_thru of same type as result type");
5161  return success();
5162 }
5163 
5164 // MaskableOpInterface methods.
5165 
5166 /// Returns the mask type expected by this operation. Mostly used for
5167 /// verification purposes. It requires the operation to be vectorized.
5168 Type GatherOp::getExpectedMaskType() {
5169  auto vecType = this->getIndexVectorType();
5170  return VectorType::get(vecType.getShape(),
5171  IntegerType::get(vecType.getContext(), /*width=*/1),
5172  vecType.getScalableDims());
5173 }
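// Illustrative example (hypothetical types, added for exposition): a gather
// whose index vector has type vector<16xi32> expects a mask of type
// vector<16xi1>, i.e. the same shape with an i1 element type.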
5174 
5175 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5176  return llvm::to_vector<4>(getVectorType().getShape());
5177 }
5178 
5179 namespace {
5180 class GatherFolder final : public OpRewritePattern<GatherOp> {
5181 public:
5182  using OpRewritePattern::OpRewritePattern;
5183  LogicalResult matchAndRewrite(GatherOp gather,
5184  PatternRewriter &rewriter) const override {
5185  switch (getMaskFormat(gather.getMask())) {
5186  case MaskFormat::AllTrue:
5187  return failure(); // no unmasked equivalent
5188  case MaskFormat::AllFalse:
5189  rewriter.replaceOp(gather, gather.getPassThru());
5190  return success();
5191  case MaskFormat::Unknown:
5192  return failure();
5193  }
5194  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5195  }
5196 };
5197 } // namespace
5198 
5199 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5200  MLIRContext *context) {
5201  results.add<GatherFolder>(context);
5202 }
5203 
5204 //===----------------------------------------------------------------------===//
5205 // ScatterOp
5206 //===----------------------------------------------------------------------===//
5207 
5208 LogicalResult ScatterOp::verify() {
5209  VectorType indVType = getIndexVectorType();
5210  VectorType maskVType = getMaskVectorType();
5211  VectorType valueVType = getVectorType();
5212  MemRefType memType = getMemRefType();
5213 
5214  if (valueVType.getElementType() != memType.getElementType())
5215  return emitOpError("base and valueToStore element type should match");
5216  if (llvm::size(getIndices()) != memType.getRank())
5217  return emitOpError("requires ") << memType.getRank() << " indices";
5218  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5219  return emitOpError("expected valueToStore dim to match indices dim");
5220  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5221  return emitOpError("expected valueToStore dim to match mask dim");
5222  return success();
5223 }
5224 
5225 namespace {
5226 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5227 public:
5228  using OpRewritePattern::OpRewritePattern;
5229  LogicalResult matchAndRewrite(ScatterOp scatter,
5230  PatternRewriter &rewriter) const override {
5231  switch (getMaskFormat(scatter.getMask())) {
5232  case MaskFormat::AllTrue:
5233  return failure(); // no unmasked equivalent
5234  case MaskFormat::AllFalse:
5235  rewriter.eraseOp(scatter);
5236  return success();
5237  case MaskFormat::Unknown:
5238  return failure();
5239  }
5240  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5241  }
5242 };
5243 } // namespace
5244 
5245 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5246  MLIRContext *context) {
5247  results.add<ScatterFolder>(context);
5248 }
5249 
5250 //===----------------------------------------------------------------------===//
5251 // ExpandLoadOp
5252 //===----------------------------------------------------------------------===//
5253 
5254 LogicalResult ExpandLoadOp::verify() {
5255  VectorType maskVType = getMaskVectorType();
5256  VectorType passVType = getPassThruVectorType();
5257  VectorType resVType = getVectorType();
5258  MemRefType memType = getMemRefType();
5259 
5260  if (resVType.getElementType() != memType.getElementType())
5261  return emitOpError("base and result element type should match");
5262  if (llvm::size(getIndices()) != memType.getRank())
5263  return emitOpError("requires ") << memType.getRank() << " indices";
5264  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5265  return emitOpError("expected result dim to match mask dim");
5266  if (resVType != passVType)
5267  return emitOpError("expected pass_thru of same type as result type");
5268  return success();
5269 }
5270 
5271 namespace {
5272 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5273 public:
5274  using OpRewritePattern::OpRewritePattern;
5275  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5276  PatternRewriter &rewriter) const override {
5277  switch (getMaskFormat(expand.getMask())) {
5278  case MaskFormat::AllTrue:
5279  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5280  expand, expand.getType(), expand.getBase(), expand.getIndices());
5281  return success();
5282  case MaskFormat::AllFalse:
5283  rewriter.replaceOp(expand, expand.getPassThru());
5284  return success();
5285  case MaskFormat::Unknown:
5286  return failure();
5287  }
5288  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5289  }
5290 };
5291 } // namespace
5292 
5293 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5294  MLIRContext *context) {
5295  results.add<ExpandLoadFolder>(context);
5296 }
5297 
5298 //===----------------------------------------------------------------------===//
5299 // CompressStoreOp
5300 //===----------------------------------------------------------------------===//
5301 
5302 LogicalResult CompressStoreOp::verify() {
5303  VectorType maskVType = getMaskVectorType();
5304  VectorType valueVType = getVectorType();
5305  MemRefType memType = getMemRefType();
5306 
5307  if (valueVType.getElementType() != memType.getElementType())
5308  return emitOpError("base and valueToStore element type should match");
5309  if (llvm::size(getIndices()) != memType.getRank())
5310  return emitOpError("requires ") << memType.getRank() << " indices";
5311  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5312  return emitOpError("expected valueToStore dim to match mask dim");
5313  return success();
5314 }
5315 
5316 namespace {
5317 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5318 public:
5319  using OpRewritePattern::OpRewritePattern;
5320  LogicalResult matchAndRewrite(CompressStoreOp compress,
5321  PatternRewriter &rewriter) const override {
5322  switch (getMaskFormat(compress.getMask())) {
5323  case MaskFormat::AllTrue:
5324  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5325  compress, compress.getValueToStore(), compress.getBase(),
5326  compress.getIndices());
5327  return success();
5328  case MaskFormat::AllFalse:
5329  rewriter.eraseOp(compress);
5330  return success();
5331  case MaskFormat::Unknown:
5332  return failure();
5333  }
5334  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5335  }
5336 };
5337 } // namespace
5338 
5339 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5340  MLIRContext *context) {
5341  results.add<CompressStoreFolder>(context);
5342 }
5343 
5344 //===----------------------------------------------------------------------===//
5345 // ShapeCastOp
5346 //===----------------------------------------------------------------------===//
5347 
5348 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5349  SetIntRangeFn setResultRanges) {
5350  setResultRanges(getResult(), argRanges.front());
5351 }
5352 
5353 /// Returns true if each element of 'a' is equal to the product of a contiguous
5354 /// sequence of the elements of 'b'. Returns false otherwise.
5355 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5356  unsigned rankA = a.size();
5357  unsigned rankB = b.size();
5358  assert(rankA < rankB);
5359 
5360  auto isOne = [](int64_t v) { return v == 1; };
5361 
5362  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5363  // casted to a 0-d vector.
5364  if (rankA == 0 && llvm::all_of(b, isOne))
5365  return true;
5366 
5367  unsigned i = 0;
5368  unsigned j = 0;
5369  while (i < rankA && j < rankB) {
5370  int64_t dimA = a[i];
5371  int64_t dimB = 1;
5372  while (dimB < dimA && j < rankB)
5373  dimB *= b[j++];
5374  if (dimA != dimB)
5375  break;
5376  ++i;
5377 
5378  // Handle the case when trailing dimensions are of size 1.
5379  // Include them into the contiguous sequence.
5380  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5381  i = rankA;
5382  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5383  j = rankB;
5384  }
5385 
5386  return i == rankA && j == rankB;
5387 }
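// Worked examples (hypothetical shapes, added for exposition): a = [8],
// b = [2, 4] is valid since 8 == 2 * 4; a = [4, 2], b = [2, 2, 2] is valid
// ([4] <- [2, 2], [2] <- [2]); a = [6], b = [4, 2] is invalid because no
// contiguous prefix of b multiplies to 6.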
5388 
5389 static LogicalResult verifyVectorShapeCast(Operation *op,
5390  VectorType sourceVectorType,
5391  VectorType resultVectorType) {
5392  // Check that element type is the same.
5393  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5394  return op->emitOpError("source/result vectors must have same element type");
5395  auto sourceShape = sourceVectorType.getShape();
5396  auto resultShape = resultVectorType.getShape();
5397 
5398  // Check that product of source dim sizes matches product of result dim sizes.
5399  int64_t sourceDimProduct = std::accumulate(
5400  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5401  int64_t resultDimProduct = std::accumulate(
5402  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5403  if (sourceDimProduct != resultDimProduct)
5404  return op->emitOpError("source/result number of elements must match");
5405 
5406  // Check the expanding/contracting rank cases.
5407  unsigned sourceRank = sourceVectorType.getRank();
5408  unsigned resultRank = resultVectorType.getRank();
5409  if (sourceRank < resultRank) {
5410  if (!isValidShapeCast(sourceShape, resultShape))
5411  return op->emitOpError("invalid shape cast");
5412  } else if (sourceRank > resultRank) {
5413  if (!isValidShapeCast(resultShape, sourceShape))
5414  return op->emitOpError("invalid shape cast");
5415  }
5416 
5417  // Check that (non-)scalability is preserved
5418  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5419  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5420  if (sourceNScalableDims != resultNScalableDims)
5421  return op->emitOpError("different number of scalable dims at source (")
5422  << sourceNScalableDims << ") and result (" << resultNScalableDims
5423  << ")";
5424  sourceVectorType.getNumDynamicDims();
5425 
5426  return success();
5427 }
5428 
5429 LogicalResult ShapeCastOp::verify() {
5430  auto sourceVectorType =
5431  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5432  auto resultVectorType =
5433  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5434 
5435  // Check if source/result are of vector type.
5436  if (sourceVectorType && resultVectorType)
5437  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5438 
5439  return success();
5440 }
5441 
5442 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5443  // No-op shape cast.
5444  if (getSource().getType() == getResult().getType())
5445  return getSource();
5446 
5447  // Canceling shape casts.
5448  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5449  if (getResult().getType() == otherOp.getSource().getType())
5450  return otherOp.getSource();
5451 
5452  // Only allows valid transitive folding.
5453  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5454  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5455  if (srcType.getRank() < resultType.getRank()) {
5456  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5457  return {};
5458  } else if (srcType.getRank() > resultType.getRank()) {
5459  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5460  return {};
5461  } else {
5462  return {};
5463  }
5464 
5465  setOperand(otherOp.getSource());
5466  return getResult();
5467  }
5468 
5469  // Canceling broadcast and shape cast ops.
5470  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5471  if (bcastOp.getSourceType() == getType())
5472  return bcastOp.getSource();
5473  }
5474 
5475  return {};
5476 }
5477 
5478 namespace {
5479 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5480 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5481 public:
5482  using OpRewritePattern::OpRewritePattern;
5483 
5484  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5485  PatternRewriter &rewriter) const override {
5486  auto constantOp =
5487  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5488  if (!constantOp)
5489  return failure();
5490  // Only handle splat for now.
5491  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5492  if (!dense)
5493  return failure();
5494  auto newAttr =
5495  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5496  dense.getSplatValue<Attribute>());
5497  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5498  return success();
5499  }
5500 };
5501 
5502 /// Helper function that computes a new vector type based on the input vector
5503 /// type by removing the trailing one dims:
5504 ///
5505 /// vector<4x1x1xi1> --> vector<4xi1>
5506 ///
5507 static VectorType trimTrailingOneDims(VectorType oldType) {
5508  ArrayRef<int64_t> oldShape = oldType.getShape();
5509  ArrayRef<int64_t> newShape = oldShape;
5510 
5511  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5512  ArrayRef<bool> newScalableDims = oldScalableDims;
5513 
5514  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5515  newShape = newShape.drop_back(1);
5516  newScalableDims = newScalableDims.drop_back(1);
5517  }
5518 
5519  // Make sure we have at least 1 dimension.
5520  // TODO: Add support for 0-D vectors.
5521  if (newShape.empty()) {
5522  newShape = oldShape.take_back();
5523  newScalableDims = oldScalableDims.take_back();
5524  }
5525 
5526  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5527 }
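// Illustrative example (hypothetical type, added for exposition):
// vector<4x[1]x1xi1> is trimmed to vector<4x[1]xi1>; the scalable [1]
// dimension is kept because only static unit dims are dropped.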
5528 
5529 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5530 ///
5531 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5532 /// dimension. If the input vector comes from `vector.create_mask` for which
5533 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5534 /// to fold shape_cast into create_mask.
5535 ///
5536 /// BEFORE:
5537 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5538 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5539 /// AFTER:
5540 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5541 class ShapeCastCreateMaskFolderTrailingOneDim final
5542  : public OpRewritePattern<ShapeCastOp> {
5543 public:
5544  using OpRewritePattern::OpRewritePattern;
5545 
5546  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5547  PatternRewriter &rewriter) const override {
5548  Value shapeOpSrc = shapeOp->getOperand(0);
5549  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5550  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5551  if (!createMaskOp && !constantMaskOp)
5552  return failure();
5553 
5554  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5555  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5556 
5557  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5558  if (newVecType != shapeOpResTy)
5559  return failure();
5560 
5561  auto numDimsToDrop =
5562  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5563 
5564  // No unit dims to drop
5565  if (!numDimsToDrop)
5566  return failure();
5567 
5568  if (createMaskOp) {
5569  auto maskOperands = createMaskOp.getOperands();
5570  auto numMaskOperands = maskOperands.size();
5571 
5572  // Check every mask dim size to see whether it can be dropped
5573  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5574  --i) {
5575  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5576  if (!constant || (constant.value() != 1))
5577  return failure();
5578  }
5579  SmallVector<Value> newMaskOperands =
5580  maskOperands.drop_back(numDimsToDrop);
5581 
5582  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5583  newMaskOperands);
5584  return success();
5585  }
5586 
5587  if (constantMaskOp) {
5588  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5589  auto numMaskOperands = maskDimSizes.size();
5590 
5591  // Check every mask dim size to see whether it can be dropped
5592  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5593  --i) {
5594  if (maskDimSizes[i] != 1)
5595  return failure();
5596  }
5597 
5598  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5599  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5600  newMaskOperands);
5601  return success();
5602  }
5603 
5604  return failure();
5605  }
5606 };
5607 
5608 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5609 /// This only applies when the shape of the broadcast source
5610 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5611 /// reshape is expressive enough to capture the result in a single op), or
5612 /// 2. has the same element count as the shape cast result.
5613 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5614 public:
5615  using OpRewritePattern::OpRewritePattern;
5616 
5617  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5618  PatternRewriter &rewriter) const override {
5619  auto broadcastOp =
5620  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5621  if (!broadcastOp)
5622  return failure();
5623 
5624  ArrayRef<int64_t> broadcastSourceShape;
5625  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5626  broadcastSourceShape = srcType.getShape();
5627  ArrayRef<int64_t> shapeCastTargetShape =
5628  shapeCastOp.getResultVectorType().getShape();
5629 
5630  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5631  // with a broadcast to the final shape.
5632  if (broadcastSourceShape ==
5633  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5634  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5635  shapeCastOp, shapeCastOp.getResultVectorType(),
5636  broadcastOp.getSource());
5637  return success();
5638  }
5639 
5640  // Otherwise, if the final result has the same element count, we can replace
5641  // with a shape cast.
5642  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5643  if (srcType.getNumElements() ==
5644  shapeCastOp.getResultVectorType().getNumElements()) {
5645  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5646  shapeCastOp, shapeCastOp.getResultVectorType(),
5647  broadcastOp.getSource());
5648  return success();
5649  }
5650  }
5651 
5652  return failure();
5653  }
5654 };
5655 
5656 } // namespace
5657 
5658 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5659  MLIRContext *context) {
5660  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5661  ShapeCastBroadcastFolder>(context);
5662 }
5663 
5664 //===----------------------------------------------------------------------===//
5665 // VectorBitCastOp
5666 //===----------------------------------------------------------------------===//
5667 
5668 LogicalResult BitCastOp::verify() {
5669  auto sourceVectorType = getSourceVectorType();
5670  auto resultVectorType = getResultVectorType();
5671 
5672  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5673  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5674  return emitOpError("dimension size mismatch at: ") << i;
5675  }
5676 
5677  DataLayout dataLayout = DataLayout::closest(*this);
5678  auto sourceElementBits =
5679  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5680  auto resultElementBits =
5681  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5682 
5683  if (sourceVectorType.getRank() == 0) {
5684  if (sourceElementBits != resultElementBits)
5685  return emitOpError("source/result bitwidth of the 0-D vector element "
5686  "types must be equal");
5687  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5688  resultElementBits * resultVectorType.getShape().back()) {
5689  return emitOpError(
5690  "source/result bitwidth of the minor 1-D vectors must be equal");
5691  }
5692 
5693  return success();
5694 }
5695 
5696 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5697  // Nop cast.
5698  if (getSource().getType() == getResult().getType())
5699  return getSource();
5700 
5701  // Canceling bitcasts.
5702  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5703  if (getResult().getType() == otherOp.getSource().getType())
5704  return otherOp.getSource();
5705 
5706  setOperand(otherOp.getSource());
5707  return getResult();
5708  }
5709 
5710  Attribute sourceConstant = adaptor.getSource();
5711  if (!sourceConstant)
5712  return {};
5713 
5714  Type srcElemType = getSourceVectorType().getElementType();
5715  Type dstElemType = getResultVectorType().getElementType();
5716 
5717  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5718  if (floatPack.isSplat()) {
5719  auto splat = floatPack.getSplatValue<FloatAttr>();
5720 
5721  // Casting fp16 into fp32.
5722  if (srcElemType.isF16() && dstElemType.isF32()) {
5723  uint32_t bits = static_cast<uint32_t>(
5724  splat.getValue().bitcastToAPInt().getZExtValue());
5725  // Duplicate the 16-bit pattern.
5726  bits = (bits << 16) | (bits & 0xffff);
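  // Worked example (illustrative): a vector<4xf16> splat of 1.0 (0x3C00)
  // bitcast to vector<2xf32> yields the 32-bit pattern 0x3C003C00 per lane.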
5727  APInt intBits(32, bits);
5728  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5729  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5730  }
5731  }
5732  }
5733 
5734  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5735  if (intPack.isSplat()) {
5736  auto splat = intPack.getSplatValue<IntegerAttr>();
5737 
5738  if (llvm::isa<IntegerType>(dstElemType)) {
5739  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5740  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5741 
5742  // Casting to a larger integer bit width.
5743  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5744  APInt intBits = splat.getValue().zext(dstBitWidth);
5745 
5746  // Duplicate the lower width element.
5747  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5748  intBits = (intBits << srcBitWidth) | intBits;
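  // Worked example (illustrative): a vector<4xi8> splat of 0x01 bitcast to
  // vector<1xi32> yields the splat value 0x01010101.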
5749  return DenseElementsAttr::get(getResultVectorType(), intBits);
5750  }
5751  }
5752  }
5753  }
5754 
5755  return {};
5756 }
5757 
5758 //===----------------------------------------------------------------------===//
5759 // TypeCastOp
5760 //===----------------------------------------------------------------------===//
5761 
5762 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5763  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5764  SmallVector<int64_t, 8> res(memRefType.getShape());
5765  if (vectorType)
5766  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5767  return res;
5768 }
5769 
5770 /// Build the canonical memRefType with a single vector.
5771 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5772 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5773  Value source) {
5774  result.addOperands(source);
5775  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5776  VectorType vectorType =
5777  VectorType::get(extractShape(memRefType),
5778  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5779  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5780  memRefType.getMemorySpace()));
5781 }
5782 
5783 LogicalResult TypeCastOp::verify() {
5784  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5785  if (!canonicalType.getLayout().isIdentity())
5786  return emitOpError("expects operand to be a memref with identity layout");
5787  if (!getResultMemRefType().getLayout().isIdentity())
5788  return emitOpError("expects result to be a memref with identity layout");
5789  if (getResultMemRefType().getMemorySpace() !=
5790  getMemRefType().getMemorySpace())
5791  return emitOpError("expects result in same memory space");
5792 
5793  auto sourceType = getMemRefType();
5794  auto resultType = getResultMemRefType();
5795  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5796  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5797  return emitOpError(
5798  "expects result and operand with same underlying scalar type: ")
5799  << resultType;
5800  if (extractShape(sourceType) != extractShape(resultType))
5801  return emitOpError(
5802  "expects concatenated result and operand shapes to be equal: ")
5803  << resultType;
5804  return success();
5805 }
5806 
5807 //===----------------------------------------------------------------------===//
5808 // TransposeOp
5809 //===----------------------------------------------------------------------===//
5810 
5811 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5812  Value vector, ArrayRef<int64_t> permutation) {
5813  VectorType vt = llvm::cast<VectorType>(vector.getType());
5814  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5815  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5816  for (unsigned i = 0; i < permutation.size(); ++i) {
5817  transposedShape[i] = vt.getShape()[permutation[i]];
5818  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5819  }
5820 
5821  result.addOperands(vector);
5822  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5823  transposedScalableDims));
5824  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5825  builder.getDenseI64ArrayAttr(permutation));
5826 }
5827 
5828 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5829  // Eliminate splat constant transpose ops.
5830  if (auto attr =
5831  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5832  if (attr.isSplat())
5833  return attr.reshape(getResultVectorType());
5834 
5835  // Eliminate identity transpose ops. This happens when the dimensions of the
5836  // input vector remain in their original order after the transpose operation.
5837  ArrayRef<int64_t> perm = getPermutation();
5838 
5839  // Check if the permutation of the dimensions contains sequential values:
5840  // {0, 1, 2, ...}.
5841  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5842  if (perm[i] != i)
5843  return {};
5844  }
5845 
5846  return getVector();
5847 }
5848 
5849 LogicalResult vector::TransposeOp::verify() {
5850  VectorType vectorType = getSourceVectorType();
5851  VectorType resultType = getResultVectorType();
5852  int64_t rank = resultType.getRank();
5853  if (vectorType.getRank() != rank)
5854  return emitOpError("vector result rank mismatch: ") << rank;
5855  // Verify transposition array.
5856  ArrayRef<int64_t> perm = getPermutation();
5857  int64_t size = perm.size();
5858  if (rank != size)
5859  return emitOpError("transposition length mismatch: ") << size;
5860  SmallVector<bool, 8> seen(rank, false);
5861  for (const auto &ta : llvm::enumerate(perm)) {
5862  if (ta.value() < 0 || ta.value() >= rank)
5863  return emitOpError("transposition index out of range: ") << ta.value();
5864  if (seen[ta.value()])
5865  return emitOpError("duplicate position index: ") << ta.value();
5866  seen[ta.value()] = true;
5867  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5868  return emitOpError("dimension size mismatch at: ") << ta.value();
5869  }
5870  return success();
5871 }
5872 
5873 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5874  return llvm::to_vector<4>(getResultVectorType().getShape());
5875 }
5876 
5877 namespace {
5878 
5879 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
5880 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5881 public:
5882  using OpRewritePattern::OpRewritePattern;
5883
5884  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5885  PatternRewriter &rewriter) const override {
5886  // Composes two permutations: result[i] = permutation1[permutation2[i]].
5887  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5888  ArrayRef<int64_t> permutation2) {
5889  SmallVector<int64_t, 4> result;
5890  for (auto index : permutation2)
5891  result.push_back(permutation1[index]);
5892  return result;
5893  };
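  // Illustrative example: composing permutation1 = [1, 2, 0] with
  // permutation2 = [2, 0, 1] gives [0, 1, 2], i.e. the two transposes cancel
  // and the combined transpose is an identity (removed by the identity fold).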
5894 
5895  // Return if the input of 'transposeOp' is not defined by another transpose.
5896  vector::TransposeOp parentTransposeOp =
5897  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5898  if (!parentTransposeOp)
5899  return failure();
5900 
5901  SmallVector<int64_t, 4> permutation = composePermutations(
5902  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5903  // Replace 'transposeOp' with a new transpose operation.
5904  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5905  transposeOp, transposeOp.getResult().getType(),
5906  parentTransposeOp.getVector(), permutation);
5907  return success();
5908  }
5909 };
5910 
5911 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
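// Illustrative example (assumed IR, not from the upstream tests):
//   %b = vector.broadcast %s : f32 to vector<1x4xf32>
//   %t = vector.transpose %b, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
// is rewritten to
//   %t = vector.broadcast %s : f32 to vector<4x1xf32>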
5912 struct FoldTransposedScalarBroadcast final
5913  : public OpRewritePattern<vector::TransposeOp> {
5914  using OpRewritePattern::OpRewritePattern;
5915
5916  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5917  PatternRewriter &rewriter) const override {
5918  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5919  if (!bcastOp)
5920  return failure();
5921 
5922  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5923  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5924  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5925  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5926  return success();
5927  }
5928 
5929  return failure();
5930  }
5931 };
5932 
5933 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5934 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5935 public:
5936  using OpRewritePattern::OpRewritePattern;
5937
5938  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5939  PatternRewriter &rewriter) const override {
5940  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5941  if (!splatOp)
5942  return failure();
5943 
5944  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5945  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5946  return success();
5947  }
5948 };
5949 
5950 /// Folds transpose(create_mask) into a new transposed create_mask.
5951 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5952 public:
5953  using OpRewritePattern::OpRewritePattern;
5954
5955  LogicalResult matchAndRewrite(TransposeOp transpOp,
5956  PatternRewriter &rewriter) const override {
5957  Value transposeSrc = transpOp.getVector();
5958  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5959  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5960  if (!createMaskOp && !constantMaskOp)
5961  return failure();
5962 
5963  // Get the transpose permutation and apply it to the vector.create_mask or
5964  // vector.constant_mask operands.
5965  ArrayRef<int64_t> permutation = transpOp.getPermutation();
5966 
5967  if (createMaskOp) {
5968  auto maskOperands = createMaskOp.getOperands();
5969  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
5970  applyPermutationToVector(newOperands, permutation);
5971 
5972  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5973  transpOp, transpOp.getResultVectorType(), newOperands);
5974  return success();
5975  }
5976 
5977  // ConstantMaskOp case.
5978  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5979  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
5980 
5981  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5982  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
5983  return success();
5984  }
5985 };
5986 
5987 } // namespace
5988 
5989 void vector::TransposeOp::getCanonicalizationPatterns(
5990  RewritePatternSet &results, MLIRContext *context) {
5991  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5992  TransposeFolder, FoldTransposeSplat>(context);
5993 }
5994 
5995 //===----------------------------------------------------------------------===//
5996 // ConstantMaskOp
5997 //===----------------------------------------------------------------------===//
5998 
5999 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6000  VectorType type, ConstantMaskKind kind) {
6001  assert(kind == ConstantMaskKind::AllTrue ||
6002  kind == ConstantMaskKind::AllFalse);
6003  build(builder, result, type,
6004  kind == ConstantMaskKind::AllTrue
6005  ? type.getShape()
6006  : SmallVector<int64_t>(type.getRank(), 0));
6007 }
6008 
6009 LogicalResult ConstantMaskOp::verify() {
6010  auto resultType = llvm::cast<VectorType>(getResult().getType());
6011  // Check the corner case of 0-D vectors first.
6012  if (resultType.getRank() == 0) {
6013  if (getMaskDimSizes().size() != 1)
6014  return emitError("array attr must have length 1 for 0-D vectors");
6015  auto dim = getMaskDimSizes()[0];
6016  if (dim != 0 && dim != 1)
6017  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6018  return success();
6019  }
6020 
6021  // Verify that array attr size matches the rank of the vector result.
6022  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6023  return emitOpError(
6024  "must specify array attr of size equal vector result rank");
6025  // Verify that each array attr element is in bounds of corresponding vector
6026  // result dimension size.
6027  auto resultShape = resultType.getShape();
6028  auto resultScalableDims = resultType.getScalableDims();
6029  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6030  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6031  if (maskDimSize < 0 || maskDimSize > resultShape[index])
6032  return emitOpError(
6033  "array attr of size out of bounds of vector result dimension size");
6034  if (resultScalableDims[index] && maskDimSize != 0 &&
6035  maskDimSize != resultShape[index])
6036  return emitOpError(
6037  "only supports 'none set' or 'all set' scalable dimensions");
6038  }
6039  // Verify that if one mask dim size is zero, they all should be zero (because
6040  // the mask region is a conjunction of each mask dimension interval).
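  // Illustrative example: mask dim sizes [2, 0] are rejected; an all-false
  // mask over a 2-D vector must be spelled [0, 0].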
6041  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6042  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6043  if (anyZeros && !allZeros)
6044  return emitOpError("expected all mask dim sizes to be zeros, "
6045  "as a result of conjunction with zero mask dim");
6046  return success();
6047 }
6048 
6049 bool ConstantMaskOp::isAllOnesMask() {
6050  auto resultType = getVectorType();
6051  // Check the corner case of 0-D vectors first.
6052  if (resultType.getRank() == 0) {
6053  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6054  return getMaskDimSizes()[0] == 1;
6055  }
6056  for (const auto [resultSize, maskDimSize] :
6057  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6058  if (maskDimSize < resultSize)
6059  return false;
6060  }
6061  return true;
6062 }
6063 
6064 //===----------------------------------------------------------------------===//
6065 // CreateMaskOp
6066 //===----------------------------------------------------------------------===//
6067 
6068 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6069  VectorType type,
6070  ArrayRef<OpFoldResult> mixedOperands) {
6071  SmallVector<Value> operands =
6072  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6073  build(builder, result, type, operands);
6074 }
6075 
6076 LogicalResult CreateMaskOp::verify() {
6077  auto vectorType = llvm::cast<VectorType>(getResult().getType());
6078  // Verify that an operand was specified for each result vector dimension.
6079  if (vectorType.getRank() == 0) {
6080  if (getNumOperands() != 1)
6081  return emitOpError(
6082  "must specify exactly one operand for 0-D create_mask");
6083  } else if (getNumOperands() !=
6084  llvm::cast<VectorType>(getResult().getType()).getRank()) {
6085  return emitOpError(
6086  "must specify an operand for each result vector dimension");
6087  }
6088  return success();
6089 }
6090 
6091 namespace {
6092 
6093 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6094 ///
6095 /// Ex 1:
6096 /// %c2 = arith.constant 2 : index
6097 /// %c3 = arith.constant 3 : index
6098 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6099 /// Becomes:
6100 /// vector.constant_mask [3, 2] : vector<4x3xi1>
6101 ///
6102 /// Ex 2:
6103 /// %c_neg_1 = arith.constant -1 : index
6104 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6105 /// becomes:
6106 /// vector.constant_mask [0] : vector<[8]xi1>
6107 ///
6108 /// Ex 3:
6109 /// %c8 = arith.constant 8 : index
6110 /// %c16 = arith.constant 16 : index
6111 /// %0 = vector.vscale
6112 /// %1 = arith.muli %0, %c16 : index
6113 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6114 /// becomes:
6115 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6116 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6117 public:
6118  using OpRewritePattern::OpRewritePattern;
6119
6120  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6121  PatternRewriter &rewriter) const override {
6122  VectorType maskType = createMaskOp.getVectorType();
6123  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6124  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6125 
6126  // Special case: Rank zero shape.
6127  constexpr std::array<int64_t, 1> rankZeroShape{1};
6128  constexpr std::array<bool, 1> rankZeroScalableDims{false};
6129  if (maskType.getRank() == 0) {
6130  maskTypeDimSizes = rankZeroShape;
6131  maskTypeDimScalableFlags = rankZeroScalableDims;
6132  }
6133 
6134  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6135  // collect the `constantDims` (for the ConstantMaskOp).
6136  SmallVector<int64_t, 4> constantDims;
6137  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6138  if (auto intSize = getConstantIntValue(dimSize)) {
6139  // Constant value.
6140  // If the mask dim is non-scalable this can be any value.
6141  // If the mask dim is scalable only zero (all-false) is supported.
6142  if (maskTypeDimScalableFlags[i] && intSize >= 0)
6143  return failure();
6144  constantDims.push_back(*intSize);
6145  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6146  // Constant vscale multiple (e.g. 4 x vscale).
6147  // Must be all-true to fold to a ConstantMask.
6148  if (vscaleMultiplier < maskTypeDimSizes[i])
6149  return failure();
6150  constantDims.push_back(*vscaleMultiplier);
6151  } else {
6152  return failure();
6153  }
6154  }
6155 
6156  // Clamp values to constant_mask bounds.
6157  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6158  value = std::clamp<int64_t>(value, 0, maskDimSize);
6159 
6160  // If one of dim sizes is zero, set all dims to zero.
6161  if (llvm::is_contained(constantDims, 0))
6162  constantDims.assign(constantDims.size(), 0);
6163 
6164  // Replace 'createMaskOp' with ConstantMaskOp.
6165  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6166  constantDims);
6167  return success();
6168  }
6169 };
6170 
6171 } // namespace
6172 
6173 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6174  MLIRContext *context) {
6175  results.add<CreateMaskFolder>(context);
6176 }
6177 
6178 //===----------------------------------------------------------------------===//
6179 // MaskOp
6180 //===----------------------------------------------------------------------===//
6181 
6182 void MaskOp::build(
6183  OpBuilder &builder, OperationState &result, Value mask,
6184  Operation *maskableOp,
6185  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6186  assert(maskRegionBuilder &&
6187  "builder callback for 'maskRegion' must be present");
6188 
6189  result.addOperands(mask);
6190  OpBuilder::InsertionGuard guard(builder);
6191  Region *maskRegion = result.addRegion();
6192  builder.createBlock(maskRegion);
6193  maskRegionBuilder(builder, maskableOp);
6194 }
6195 
6196 void MaskOp::build(
6197  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6198  Value mask, Operation *maskableOp,
6199  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6200  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6201  maskRegionBuilder);
6202 }
6203 
6204 void MaskOp::build(
6205  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6206  Value mask, Value passthru, Operation *maskableOp,
6207  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6208  build(builder, result, mask, maskableOp, maskRegionBuilder);
6209  if (passthru)
6210  result.addOperands(passthru);
6211  result.addTypes(resultTypes);
6212 }
6213 
6214 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6215  // Create the op region.
6216  result.regions.reserve(1);
6217  Region &maskRegion = *result.addRegion();
6218 
6219  auto &builder = parser.getBuilder();
6220 
6221  // Parse all the operands.
6222  OpAsmParser::UnresolvedOperand mask;
6223  if (parser.parseOperand(mask))
6224  return failure();
6225 
6226  // Optional passthru operand.
6227  OpAsmParser::UnresolvedOperand passthru;
6228  ParseResult parsePassthru = parser.parseOptionalComma();
6229  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6230  return failure();
6231 
6232  // Parse op region.
6233  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6234  return failure();
6235 
6236  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6237 
6238  // Parse the optional attribute list.
6239  if (parser.parseOptionalAttrDict(result.attributes))
6240  return failure();
6241 
6242  // Parse all the types.
6243  Type maskType;
6244  if (parser.parseColonType(maskType))
6245  return failure();
6246 
6247  SmallVector<Type> resultTypes;
6248  if (parser.parseOptionalArrowTypeList(resultTypes))
6249  return failure();
6250  result.types.append(resultTypes);
6251 
6252  // Resolve operands.
6253  if (parser.resolveOperand(mask, maskType, result.operands))
6254  return failure();
6255 
6256  if (parsePassthru.succeeded())
6257  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6258  return failure();
6259 
6260  return success();
6261 }
6262 
6263 void MaskOp::print(OpAsmPrinter &p) {
6264  p << " " << getMask();
6265  if (getPassthru())
6266  p << ", " << getPassthru();
6267 
6268  // Print single masked operation and skip terminator.
6269  p << " { ";
6270  Block *singleBlock = &getMaskRegion().getBlocks().front();
6271  if (singleBlock && !singleBlock->getOperations().empty())
6272  p.printCustomOrGenericOp(&singleBlock->front());
6273  p << " }";
6274 
6275  p.printOptionalAttrDict(getOperation()->getAttrs());
6276 
6277  p << " : " << getMask().getType();
6278  if (getNumResults() > 0)
6279  p << " -> " << getResultTypes();
6280 }
6281 
6282 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6283  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6284  MaskOp>::ensureTerminator(region, builder, loc);
6285  // Keep the default yield terminator if the number of masked operations is
6286  // not the expected one. This case will trigger a verification failure.
6287  Block &block = region.front();
6288  if (block.getOperations().size() != 2)
6289  return;
6290 
6291  // Replace default yield terminator with a new one that returns the results
6292  // from the masked operation.
6293  OpBuilder opBuilder(builder.getContext());
6294  Operation *maskedOp = &block.front();
6295  Operation *oldYieldOp = &block.back();
6296  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6297 
6298  // Empty vector.mask op.
6299  if (maskedOp == oldYieldOp)
6300  return;
6301 
6302  opBuilder.setInsertionPoint(oldYieldOp);
6303  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6304  oldYieldOp->dropAllReferences();
6305  oldYieldOp->erase();
6306 }
6307 
6308 LogicalResult MaskOp::verify() {
6309  // Structural checks.
6310  Block &block = getMaskRegion().getBlocks().front();
6311  if (block.getOperations().empty())
6312  return emitOpError("expects a terminator within the mask region");
6313 
6314  unsigned numMaskRegionOps = block.getOperations().size();
6315  if (numMaskRegionOps > 2)
6316  return emitOpError("expects only one operation to mask");
6317 
6318  // Terminator checks.
6319  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6320  if (!terminator)
6321  return emitOpError("expects a terminator within the mask region");
6322 
6323  if (terminator->getNumOperands() != getNumResults())
6324  return emitOpError(
6325  "expects number of results to match mask region yielded values");
6326 
6327  // Empty vector.mask. Nothing else to check.
6328  if (numMaskRegionOps == 1)
6329  return success();
6330 
6331  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6332  if (!maskableOp)
6333  return emitOpError("expects a MaskableOpInterface within the mask region");
6334 
6335  // Result checks.
6336  if (maskableOp->getNumResults() != getNumResults())
6337  return emitOpError("expects number of results to match maskable operation "
6338  "number of results");
6339 
6340  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6341  return emitOpError(
6342  "expects result type to match maskable operation result type");
6343 
6344  if (llvm::count_if(maskableOp->getResultTypes(),
6345  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6346  return emitOpError("multiple vector results not supported");
6347 
6348  // Mask checks.
6349  Type expectedMaskType = maskableOp.getExpectedMaskType();
6350  if (getMask().getType() != expectedMaskType)
6351  return emitOpError("expects a ")
6352  << expectedMaskType << " mask for the maskable operation";
6353 
6354  // Passthru checks.
6355  Value passthru = getPassthru();
6356  if (passthru) {
6357  if (!maskableOp.supportsPassthru())
6358  return emitOpError(
6359  "doesn't expect a passthru argument for this maskable operation");
6360 
6361  if (maskableOp->getNumResults() != 1)
6362  return emitOpError("expects result when passthru argument is provided");
6363 
6364  if (passthru.getType() != maskableOp->getResultTypes()[0])
6365  return emitOpError("expects passthru type to match result type");
6366  }
6367 
6368  return success();
6369 }
6370 
6371 /// Folds vector.mask ops with an all-true mask.
6372 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6373  SmallVectorImpl<OpFoldResult> &results) {
6374  MaskFormat maskFormat = getMaskFormat(getMask());
6375  if (isEmpty())
6376  return failure();
6377 
6378  if (maskFormat != MaskFormat::AllTrue)
6379  return failure();
6380 
6381  // Move maskable operation outside of the `vector.mask` region.
6382  Operation *maskableOp = getMaskableOp();
6383  maskableOp->dropAllUses();
6384  maskableOp->moveBefore(getOperation());
6385 
6386  llvm::append_range(results, maskableOp->getResults());
6387  return success();
6388 }
6389 
6390 // Elides empty vector.mask operations with or without return values.
6391 // Propagates the values yielded by the vector.yield terminator, if any, or
6392 // erases the op otherwise.
6393 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6394  using OpRewritePattern::OpRewritePattern;
6395
6396  LogicalResult matchAndRewrite(MaskOp maskOp,
6397  PatternRewriter &rewriter) const override {
6398  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6399  if (maskingOp.getMaskableOp())
6400  return failure();
6401 
6402  if (!maskOp.isEmpty())
6403  return failure();
6404 
6405  Block *block = maskOp.getMaskBlock();
6406  auto terminator = cast<vector::YieldOp>(block->front());
6407  if (terminator.getNumOperands() == 0)
6408  rewriter.eraseOp(maskOp);
6409  else
6410  rewriter.replaceOp(maskOp, terminator.getOperands());
6411 
6412  return success();
6413  }
6414 };
6415 
6416 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6417  MLIRContext *context) {
6418  results.add<ElideEmptyMaskOp>(context);
6419 }
6420 
6421 // MaskingOpInterface definitions.
6422 
6423 /// Returns the operation masked by this 'vector.mask'.
6424 Operation *MaskOp::getMaskableOp() {
6425  Block *block = getMaskBlock();
6426  if (block->getOperations().size() < 2)
6427  return nullptr;
6428 
6429  return &block->front();
6430 }
6431 
6432 /// Returns true if 'vector.mask' has a passthru value.
6433 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6434 
6435 //===----------------------------------------------------------------------===//
6436 // ScanOp
6437 //===----------------------------------------------------------------------===//
6438 
6439 LogicalResult ScanOp::verify() {
6440  VectorType srcType = getSourceType();
6441  VectorType initialType = getInitialValueType();
6442  // Check reduction dimension < rank.
6443  int64_t srcRank = srcType.getRank();
6444  int64_t reductionDim = getReductionDim();
6445  if (reductionDim >= srcRank)
6446  return emitOpError("reduction dimension ")
6447  << reductionDim << " has to be less than " << srcRank;
6448 
6449  // Check that rank(initial_value) = rank(src) - 1.
6450  int64_t initialValueRank = initialType.getRank();
6451  if (initialValueRank != srcRank - 1)
6452  return emitOpError("initial value rank ")
6453  << initialValueRank << " has to be equal to " << srcRank - 1;
6454 
6455  // Check shapes of initial value and src.
6456  ArrayRef<int64_t> srcShape = srcType.getShape();
6457  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6458  SmallVector<int64_t> expectedShape;
6459  for (int i = 0; i < srcRank; i++) {
6460  if (i != reductionDim)
6461  expectedShape.push_back(srcShape[i]);
6462  }
6463  if (!llvm::equal(initialValueShapes, expectedShape)) {
6464  return emitOpError("incompatible input/initial value shapes");
6465  }
6466 
6467  // Verify supported reduction kind.
6468  Type eltType = getDestType().getElementType();
6469  if (!isSupportedCombiningKind(getKind(), eltType))
6470  return emitOpError("unsupported reduction type ")
6471  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6472  << "'";
6473 
6474  return success();
6475 }
6476 
6477 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6478  RewritePatternSet &patterns, PatternBenefit benefit) {
6479  patterns
6480  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6481  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6482  StridedSliceConstantMaskFolder, TransposeFolder>(
6483  patterns.getContext(), benefit);
6484 }
6485 
6486 //===----------------------------------------------------------------------===//
6487 // SplatOp
6488 //===----------------------------------------------------------------------===//
6489 
6490 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6491  auto constOperand = adaptor.getInput();
6492  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6493  return {};
6494 
6495  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6496  return SplatElementsAttr::get(getType(), {constOperand});
6497 }
6498 
6499 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6500  SetIntRangeFn setResultRanges) {
6501  setResultRanges(getResult(), argRanges.front());
6502 }
6503 
6504 //===----------------------------------------------------------------------===//
6505 // WarpExecuteOnLane0Op
6506 //===----------------------------------------------------------------------===//
6507 
6508 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
6509  p << "(" << getLaneid() << ")";
6510 
6511  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
6512  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
6513  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
6514 
6515  if (!getArgs().empty())
6516  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
6517  if (!getResults().empty())
6518  p << " -> (" << getResults().getTypes() << ')';
6519  p << " ";
6520  p.printRegion(getRegion(),
6521  /*printEntryBlockArgs=*/true,
6522  /*printBlockTerminators=*/!getResults().empty());
6523  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
6524 }
6525 
6526 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
6527  OperationState &result) {
6528  // Create the region.
6529  result.regions.reserve(1);
6530  Region *warpRegion = result.addRegion();
6531 
6532  auto &builder = parser.getBuilder();
6533  OpAsmParser::UnresolvedOperand laneId;
6534
6535  // Parse predicate operand.
6536  if (parser.parseLParen() ||
6537  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
6538  parser.parseRParen())
6539  return failure();
6540 
6541  int64_t warpSize;
6542  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
6543  parser.parseRSquare())
6544  return failure();
6545  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
6546  builder.getContext())),
6547  builder.getI64IntegerAttr(warpSize));
6548 
6549  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
6550  return failure();
6551 
6552  llvm::SMLoc inputsOperandsLoc;
6553  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
6554  SmallVector<Type> inputTypes;
6555  if (succeeded(parser.parseOptionalKeyword("args"))) {
6556  if (parser.parseLParen())
6557  return failure();
6558 
6559  inputsOperandsLoc = parser.getCurrentLocation();
6560  if (parser.parseOperandList(inputsOperands) ||
6561  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
6562  return failure();
6563  }
6564  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
6565  result.operands))
6566  return failure();
6567 
6568  // Parse optional results type list.
6569  if (parser.parseOptionalArrowTypeList(result.types))
6570  return failure();
6571  // Parse the region.
6572  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
6573  /*argTypes=*/{}))
6574  return failure();
6575  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
6576 
6577  // Parse the optional attribute list.
6578  if (parser.parseOptionalAttrDict(result.attributes))
6579  return failure();
6580  return success();
6581 }
6582 
6583 void WarpExecuteOnLane0Op::getSuccessorRegions(
6584  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
6585  if (!point.isParent()) {
6586  regions.push_back(RegionSuccessor(getResults()));
6587  return;
6588  }
6589 
6590  // The warp region is always executed
6591  regions.push_back(RegionSuccessor(&getWarpRegion()));
6592 }
6593 
6594 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6595  TypeRange resultTypes, Value laneId,
6596  int64_t warpSize) {
6597  build(builder, result, resultTypes, laneId, warpSize,
6598  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
6599 }
6600 
6601 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6602  TypeRange resultTypes, Value laneId,
6603  int64_t warpSize, ValueRange args,
6604  TypeRange blockArgTypes) {
6605  result.addOperands(laneId);
6606  result.addAttribute(getAttributeNames()[0],
6607  builder.getI64IntegerAttr(warpSize));
6608  result.addTypes(resultTypes);
6609  result.addOperands(args);
6610  assert(args.size() == blockArgTypes.size());
6611  OpBuilder::InsertionGuard guard(builder);
6612  Region *warpRegion = result.addRegion();
6613  Block *block = builder.createBlock(warpRegion);
6614  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
6615  block->addArgument(type, arg.getLoc());
6616 }
6617 
6618 /// Helper check if the distributed vector type is consistent with the expanded
6619 /// type and distributed size.
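/// For example (illustrative numbers): with warpSize = 32, an expanded
/// vector<32x4xf32> is consistent with a distributed vector<1x4xf32>, since
/// the per-dimension scales are 32 and 1 and their product equals the warp
/// size.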
6620 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
6621  int64_t warpSize, Operation *op) {
6622  // If the types matches there is no distribution.
6623  if (expanded == distributed)
6624  return success();
6625  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6626  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6627  if (!expandedVecType || !distributedVecType)
6628  return op->emitOpError("expected vector type for distributed operands.");
6629  if (expandedVecType.getRank() != distributedVecType.getRank() ||
6630  expandedVecType.getElementType() != distributedVecType.getElementType())
6631  return op->emitOpError(
6632  "expected distributed vectors to have same rank and element type.");
6633 
6634  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
6635  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
6636  int64_t eDim = expandedVecType.getDimSize(i);
6637  int64_t dDim = distributedVecType.getDimSize(i);
6638  if (eDim == dDim)
6639  continue;
6640  if (eDim % dDim != 0)
6641  return op->emitOpError()
6642  << "expected expanded vector dimension #" << i << " (" << eDim
6643  << ") to be a multiple of the distributed vector dimension ("
6644  << dDim << ")";
6645  scales[i] = eDim / dDim;
6646  }
6647  if (std::accumulate(scales.begin(), scales.end(), 1,
6648  std::multiplies<int64_t>()) != warpSize)
6649  return op->emitOpError()
6650  << "incompatible distribution dimensions from " << expandedVecType
6651  << " to " << distributedVecType << " with warp size = " << warpSize;
6652 
6653  return success();
6654 }
6655 
6656 LogicalResult WarpExecuteOnLane0Op::verify() {
6657  if (getArgs().size() != getWarpRegion().getNumArguments())
6658  return emitOpError(
6659  "expected same number of op arguments and block arguments.");
6660  auto yield =
6661  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
6662  if (yield.getNumOperands() != getNumResults())
6663  return emitOpError(
6664  "expected same number of yield operands and return values.");
6665  int64_t warpSize = getWarpSize();
6666  for (auto [regionArg, arg] :
6667  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
6668  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
6669  warpSize, getOperation())))
6670  return failure();
6671  }
6672  for (auto [yieldOperand, result] :
6673  llvm::zip_equal(yield.getOperands(), getResults())) {
6674  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
6675  warpSize, getOperation())))
6676  return failure();
6677  }
6678  return success();
6679 }
6680 
6681 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
6682  return succeeded(
6683  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
6684 }
6685 
6686 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6687  CombiningKind kind, Value v1, Value acc,
6688  arith::FastMathFlagsAttr fastmath,
6689  Value mask) {
6690  Type t1 = getElementTypeOrSelf(v1.getType());
6691  Type tAcc = getElementTypeOrSelf(acc.getType());
6692  Value result;
6693 
6694  switch (kind) {
6695  case CombiningKind::ADD:
6696  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6697  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6698  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6699  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6700  else
6701  llvm_unreachable("invalid value types for ADD reduction");
6702  break;
6703  case CombiningKind::AND:
6704  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6705  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6706  break;
6707  case CombiningKind::MAXNUMF:
6708  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6709  "expected float values");
6710  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6711  break;
6712  case CombiningKind::MAXIMUMF:
6713  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6714  "expected float values");
6715  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6716  break;
6717  case CombiningKind::MINNUMF:
6718  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6719  "expected float values");
6720  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6721  break;
6722  case CombiningKind::MINIMUMF:
6723  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6724  "expected float values");
6725  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6726  break;
6727  case CombiningKind::MAXSI:
6728  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6729  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6730  break;
6731  case CombiningKind::MINSI:
6732  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6733  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6734  break;
6735  case CombiningKind::MAXUI:
6736  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6737  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6738  break;
6739  case CombiningKind::MINUI:
6740  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6741  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6742  break;
6743  case CombiningKind::MUL:
6744  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6745  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6746  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6747  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6748  else
6749  llvm_unreachable("invalid value types for MUL reduction");
6750  break;
6751  case CombiningKind::OR:
6752  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6753  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6754  break;
6755  case CombiningKind::XOR:
6756  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6757  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6758  break;
6759  };
6760 
6761  assert(result && "unknown CombiningKind");
6762  return selectPassthru(b, mask, result, acc);
6763 }
6764 
6765 //===----------------------------------------------------------------------===//
6766 // Vector Masking Utilities
6767 //===----------------------------------------------------------------------===//
6768 
6769 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6770 /// as the masked operation.
6771 static void createMaskOpRegion(OpBuilder &builder,
6772  Operation *maskableOp) {
6773  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6774  Block *insBlock = builder.getInsertionBlock();
6775  // Create a block and move the op to that block.
6776  insBlock->getOperations().splice(
6777  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6778  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6779 }
6780 
6781 /// Creates a vector.mask operation around a maskable operation. Returns the
6782 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6783 /// the maskable operation itself.
6784 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6785  Operation *maskableOp, Value mask,
6786  Value passthru) {
6787  if (!mask)
6788  return maskableOp;
6789  if (passthru)
6790  return builder.create<MaskOp>(maskableOp->getLoc(),
6791  maskableOp->getResultTypes(), mask, passthru,
6792  maskableOp, createMaskOpRegion);
6793  return builder.create<MaskOp>(maskableOp->getLoc(),
6794  maskableOp->getResultTypes(), mask, maskableOp,
6795  createMaskOpRegion);
6796 }
6797 
6798 /// Creates a vector select operation that picks values from `newValue` or
6799 /// `passthru` for each result vector lane based on `mask`. This utility is used
6800 /// to propagate the pass-thru value of vector.mask or for cases where only the
6801 /// pass-thru value propagation is needed. VP intrinsics do not support
6802 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
6803 /// are usually able to match op + select patterns and fold them into native
6804 /// target instructions.
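/// A minimal sketch of the emitted IR (illustrative types only):
///   %r = arith.select %mask, %newValue, %passthru : vector<8xi1>, vector<8xf32>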
6805 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6806  Value newValue, Value passthru) {
6807  if (!mask)
6808  return newValue;
6809 
6810  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6811  mask, newValue, passthru);
6812 }
6813 
6814 //===----------------------------------------------------------------------===//
6815 // TableGen'd op method definitions
6816 //===----------------------------------------------------------------------===//
6817 
6818 #define GET_ATTRDEF_CLASSES
6819 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6820 
6821 #define GET_OP_CLASSES
6822 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition: VectorOps.cpp:68
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
Definition: VectorOps.cpp:1087
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1378
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
Definition: VectorOps.cpp:1639
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:4099
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
Definition: VectorOps.cpp:1952
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1820
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1650
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:2267
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:130
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:58
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:3130
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
Definition: VectorOps.cpp:296
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:868
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
Definition: VectorOps.cpp:2316
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:2626
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:3066
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:1079
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:3086
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
Definition: VectorOps.cpp:175
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:3109
static LogicalResult foldTransferFullMask(TransferOp op)
Definition: VectorOps.cpp:4328
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1370
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.
Definition: VectorOps.cpp:2290
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:3991
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:879
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:3051
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1750
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:4020
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:4256
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:3467
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
Definition: VectorOps.cpp:1718
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:4273
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1872
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:135
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:216
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
Definition: AffineMap.cpp:161
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation & back()
Definition: Block.h:152
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:97
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:310
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:450
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:584
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:131
bool isF32() const
Definition: Types.cpp:59
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:123
bool isF16() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:329
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:336
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:357
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
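A minimal usage sketch, assuming a builder b, a location loc, and two f32 values lhs and acc that are not part of this file:
  // With floating-point operands and CombiningKind::ADD this is expected
  // to materialize an arith.addf combining lhs and acc.
  Value sum = vector::makeArithReduction(b, loc, vector::CombiningKind::ADD,
                                         lhs, acc);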
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:446
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:125
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
Definition: VectorOps.cpp:354
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:153
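A rough illustration under assumed types (memrefType = memref<4x8x16xf32>, vecType = vector<16xf32>); these values are not taken from this file:
  // The minor identity map for a rank-3 source and a rank-1 vector is
  // expected to be (d0, d1, d2) -> (d2).
  AffineMap map = vector::getTransferMinorIdentityMap(memrefType, vecType);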
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition: VectorOps.h:60
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
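A hedged sketch of wrapping an operation in vector.mask; builder, readOp, mask, and passthru are assumed values, not from this file:
  // Moves the (assumed maskable) readOp into the region of a newly created
  // vector.mask op guarded by mask, with passthru as the fallback value.
  Operation *masked = vector::maskOperation(builder, readOp, mask, passthru);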
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2448
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:4118
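As an assumed example (the concrete types are illustrative, not from this file): for a transfer of vector<3x4xf32> with permutation map (d0, d1) -> (d1, d0), the mask shape is the vector shape pushed through the inverse permutation:
  // Expected result under the assumptions above: vector<4x3xi1>.
  VectorType maskTy = vector::inferTransferOpMaskType(vecType, permMap);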
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
Definition: VectorOps.cpp:219
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:283
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:314
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:338
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:624
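A minimal usage sketch, assuming a builder, a location, and a vector value vec that are not defined in this file:
  // An addf reduction kind is expected to lower to a vector.reduction <add>.
  Value scalar = vector::getVectorReductionOp(arith::AtomicRMWKind::addf,
                                              builder, loc, vec);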
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:68
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:442
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
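A small worked example under assumed inputs (map and shape are illustrative, not from this file):
  // map = (d0, d1, d2) -> (d2, d0), shape = {2, 3, 4}.
  // Each result of the map selects the source element at that dim position,
  // so the expected output is {4, 2}.
  SmallVector<int64_t> shape = {2, 3, 4};
  SmallVector<int64_t> permuted =
      applyPermutationMap(map, ArrayRef<int64_t>(shape));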
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
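A hedged example with a hypothetical permutation map (not from this file):
  // For map = (d0, d1, d2) -> (d1, d2, d0) the inverse is expected to be
  // (d0, d1, d2) -> (d2, d0, d1), i.e. inv.compose(map) is the identity.
  AffineMap inv = inversePermutation(map);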
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:722
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:928
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
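A small arithmetic sketch combining this with computeStrides above; the sizes and offsets are assumptions chosen for illustration:
  // Row-major strides of sizes {2, 3, 4} are expected to be {12, 4, 1};
  // linearizing offsets {1, 2, 3} then gives 1*12 + 2*4 + 3*1 = 23.
  SmallVector<int64_t> strides = computeStrides(ArrayRef<int64_t>{2, 3, 4});
  int64_t idx = linearize(ArrayRef<int64_t>{1, 2, 3}, strides);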
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1180
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1183
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:329
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:381
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:383
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.