1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/TypeUtilities.h"
37 #include "mlir/Support/LLVM.h"
39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify a mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of a constant value with dense attributes
66 /// and a constant mask operation (since the client may be called at
67 /// various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index equals or exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayAttr masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
97  if (i < dimSize)
98  allTrue = false;
99  if (i > 0)
100  allFalse = false;
101  }
102  if (allTrue)
103  return MaskFormat::AllTrue;
104  if (allFalse)
105  return MaskFormat::AllFalse;
106  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107  // Finds all-false create_masks. An all-true create_mask requires all
108  // dims to be constants, so that'll be folded to a constant_mask, then
109  // detected in the constant_mask case.
110  auto maskOperands = m.getOperands();
111  for (Value operand : maskOperands) {
112  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113  int64_t dimSize =
114  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
115  if (dimSize <= 0)
116  return MaskFormat::AllFalse;
117  }
118  }
119  return MaskFormat::Unknown;
120  }
121  return MaskFormat::Unknown;
122 }
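// As an illustrative sketch (not part of the original source), assuming
// standard vector dialect syntax, the masks below would be classified as
// follows by this helper:
// ```mlir
// %t = vector.constant_mask [8] : vector<8xi1>  // AllTrue: index covers the dim.
// %f = vector.constant_mask [0] : vector<8xi1>  // AllFalse: index is zero.
// %c0 = arith.constant 0 : index
// %u = vector.create_mask %c0 : vector<8xi1>    // AllFalse: constant dim <= 0.
// ```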
123 
124 /// Default callback to build a region with a 'vector.yield' terminator with no
125 /// arguments.
126 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127  builder.create<vector::YieldOp>(loc);
128 }
129 
130 // Helper for verifying combining kinds in contractions and reductions.
131 static bool isSupportedCombiningKind(CombiningKind combiningKind,
132  Type elementType) {
133  switch (combiningKind) {
134  case CombiningKind::ADD:
135  case CombiningKind::MUL:
136  return elementType.isIntOrIndexOrFloat();
137  case CombiningKind::MINUI:
138  case CombiningKind::MINSI:
139  case CombiningKind::MAXUI:
140  case CombiningKind::MAXSI:
141  case CombiningKind::AND:
142  case CombiningKind::OR:
143  case CombiningKind::XOR:
144  return elementType.isIntOrIndex();
145  case CombiningKind::MINNUMF:
146  case CombiningKind::MAXNUMF:
147  case CombiningKind::MINIMUMF:
148  case CombiningKind::MAXIMUMF:
149  return llvm::isa<FloatType>(elementType);
150  }
151  return false;
152 }
153 
154 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
155  VectorType vectorType) {
156  int64_t elementVectorRank = 0;
157  VectorType elementVectorType =
158  llvm::dyn_cast<VectorType>(shapedType.getElementType());
159  if (elementVectorType)
160  elementVectorRank += elementVectorType.getRank();
161  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162  // TODO: replace once we have 0-d vectors.
163  if (shapedType.getRank() == 0 &&
164  vectorType.getShape() == ArrayRef<int64_t>{1})
165  return AffineMap::get(
166  /*numDims=*/0, /*numSymbols=*/0,
167  getAffineConstantExpr(0, shapedType.getContext()));
168  return AffineMap::getMinorIdentityMap(
169  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
170  shapedType.getContext());
171 }
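// For illustration (a sketch, not from the source file; %mem, %c0 and %pad are
// hypothetical): for a transfer between memref<2x4x8xf32> and vector<4x8xf32>,
// the minor identity map built here is (d0, d1, d2) -> (d1, d2), i.e. the
// vector maps onto the innermost memref dimensions.
// ```mlir
// %v = vector.transfer_read %mem[%c0, %c0, %c0], %pad
//        {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
//        : memref<2x4x8xf32>, vector<4x8xf32>
// ```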
172 
173 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
174  vector::TransferReadOp read) {
175  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
176  !read.getMask() && defWrite.getIndices() == read.getIndices() &&
177  defWrite.getVectorType() == read.getVectorType() &&
178  defWrite.getPermutationMap() == read.getPermutationMap();
179 }
180 
181 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
182  vector::TransferWriteOp priorWrite) {
183  return priorWrite.getIndices() == write.getIndices() &&
184  priorWrite.getMask() == write.getMask() &&
185  priorWrite.getVectorType() == write.getVectorType() &&
186  priorWrite.getPermutationMap() == write.getPermutationMap();
187 }
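// A hedged example of the RAW case (illustrative only; %val, %t, %c0 and %pad
// are hypothetical): the write and read below use the same indices, types and
// permutation maps, with no masks and no out-of-bounds dims, so
// checkSameValueRAW returns true and the read can be replaced by %val.
// ```mlir
// %w = vector.transfer_write %val, %t[%c0, %c0] {in_bounds = [true, true]}
//        : vector<4x8xf32>, tensor<16x16xf32>
// %r = vector.transfer_read %w[%c0, %c0], %pad {in_bounds = [true, true]}
//        : tensor<16x16xf32>, vector<4x8xf32>
// ```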
188 
189 bool mlir::vector::isDisjointTransferIndices(
190  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
191  bool testDynamicValueUsingBounds) {
192  // For simplicity, only look at transfers of the same type.
193  if (transferA.getVectorType() != transferB.getVectorType())
194  return false;
195  unsigned rankOffset = transferA.getLeadingShapedRank();
196  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
197  Value indexA = transferA.getIndices()[i];
198  Value indexB = transferB.getIndices()[i];
199  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
200  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
201 
202  if (i < rankOffset) {
203  // For leading dimensions, if we can prove that the indices are
204  // different we know we are accessing disjoint slices.
205  if (cstIndexA.has_value() && cstIndexB.has_value()) {
206  if (*cstIndexA != *cstIndexB)
207  return true;
208  continue;
209  }
210  if (testDynamicValueUsingBounds) {
211  // First try to see if we can fully compose and simplify the affine
212  // expression as a fast track.
213  FailureOr<uint64_t> delta =
214  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
215  if (succeeded(delta) && *delta != 0)
216  return true;
217 
218  FailureOr<bool> testEqual =
219  ValueBoundsConstraintSet::areEqual(indexA, indexB);
220  if (succeeded(testEqual) && !testEqual.value())
221  return true;
222  }
223  } else {
224  // For this dimension, we slice a part of the memref, so we need to make
225  // sure the intervals accessed don't overlap.
226  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
227  if (cstIndexA.has_value() && cstIndexB.has_value()) {
228  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
229  if (distance >= vectorDim)
230  return true;
231  continue;
232  }
233  if (testDynamicValueUsingBounds) {
234  // First try to see if we can fully compose and simplify the affine
235  // expression as a fast track.
236  FailureOr<int64_t> delta =
237  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
238  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
239  return true;
240 
241  FailureOr<int64_t> computeDelta =
242  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
243  if (succeeded(computeDelta)) {
244  if (std::abs(computeDelta.value()) >= vectorDim)
245  return true;
246  }
247  }
248  }
249  }
250  return false;
251 }
252 
253 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
254  VectorTransferOpInterface transferB,
255  bool testDynamicValueUsingBounds) {
256  if (transferA.getSource() != transferB.getSource())
257  return false;
258  return isDisjointTransferIndices(transferA, transferB,
259  testDynamicValueUsingBounds);
260 }
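// Illustrative sketch (not from the source; %v, %m, %c0 and %c8 are
// hypothetical): two writes to the same memref are reported as disjoint when
// the constant distance between their indices is at least the vector size
// along that dimension (|8 - 0| >= 4 below).
// ```mlir
// vector.transfer_write %v, %m[%c0] : vector<4xf32>, memref<16xf32>
// vector.transfer_write %v, %m[%c8] : vector<4xf32>, memref<16xf32>
// ```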
261 
262 // Helper to iterate over n-D vector slice elements. Calculate the next
263 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
264 // Modifies the `position` in place. Returns a failure when `position` becomes
265 // the end position.
266 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
267  ArrayRef<int64_t> shape,
268  ArrayRef<int64_t> offsets) {
269  for (auto [posInDim, dimSize, offsetInDim] :
270  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
271  ++posInDim;
272  if (posInDim < dimSize + offsetInDim)
273  return success();
274 
275  // Carry the overflow to the next loop iteration.
276  posInDim = offsetInDim;
277  }
278 
279  return failure();
280 }
281 
282 /// Returns the integer numbers in `values`. `values` are expected to be
283 /// constant operations.
284 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
285  SmallVector<int64_t> ints;
286  llvm::transform(values, std::back_inserter(ints), [](Value value) {
287  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
288  assert(constOp && "Unexpected non-constant index");
289  return constOp.value();
290  });
291  return ints;
292 }
293 
294 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
295 /// be constant operations.
296 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
297  SmallVector<int64_t> ints;
298  llvm::transform(
299  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
300  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
301  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
302  });
303  return ints;
304 }
305 
306 /// Convert `foldResults` into Values. Integer attributes are converted to
307 /// constant ops.
308 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
309  ArrayRef<OpFoldResult> foldResults) {
310  SmallVector<Value> values;
311  llvm::transform(foldResults, std::back_inserter(values),
312  [&](OpFoldResult foldResult) {
313  if (auto attr = foldResult.dyn_cast<Attribute>())
314  return builder
315  .create<arith::ConstantIndexOp>(
316  loc, cast<IntegerAttr>(attr).getInt())
317  .getResult();
318 
319  return foldResult.get<Value>();
320  });
321  return values;
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // CombiningKindAttr
326 //===----------------------------------------------------------------------===//
327 
328 namespace mlir {
329 namespace vector {
330 namespace detail {
331 struct BitmaskEnumStorage : public AttributeStorage {
332  using KeyTy = uint64_t;
333 
334  BitmaskEnumStorage(KeyTy val) : value(val) {}
335 
336  bool operator==(const KeyTy &key) const { return value == key; }
337 
338  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
339  const KeyTy &key) {
340  return new (allocator.allocate<BitmaskEnumStorage>())
341  BitmaskEnumStorage(key);
342  }
343 
344  KeyTy value = 0;
345 };
346 } // namespace detail
347 } // namespace vector
348 } // namespace mlir
349 
350 //===----------------------------------------------------------------------===//
351 // VectorDialect
352 //===----------------------------------------------------------------------===//
353 
354 namespace {
355 /// This class defines the interface for handling inlining with vector dialect
356 /// operations.
357 struct VectorInlinerInterface : public DialectInlinerInterface {
358  using DialectInlinerInterface::DialectInlinerInterface;
359 
360  /// All vector dialect ops can be inlined.
361  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
362  return true;
363  }
364 };
365 } // namespace
366 
367 void VectorDialect::initialize() {
368  addAttributes<
369 #define GET_ATTRDEF_LIST
370 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
371  >();
372 
373  addOperations<
374 #define GET_OP_LIST
375 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
376  >();
377 
378  addInterfaces<VectorInlinerInterface>();
379 
380  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
381  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
382  YieldOp>();
383  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
384  TransferWriteOp>();
385  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
386  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
387 }
388 
389 /// Materialize a single constant operation from a given attribute value with
390 /// the desired resultant type.
391 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
392  Attribute value, Type type,
393  Location loc) {
394  return arith::ConstantOp::materialize(builder, value, type, loc);
395 }
396 
397 Type mlir::vector::getVectorSubscriptType(Builder &builder) {
398  return builder.getIntegerType(64);
399 }
400 
401 ArrayAttr mlir::vector::getVectorSubscriptAttr(Builder &builder,
402  ArrayRef<int64_t> values) {
403  return builder.getI64ArrayAttr(values);
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // MultiDimReductionOp
408 //===----------------------------------------------------------------------===//
409 
410 void vector::MultiDimReductionOp::build(OpBuilder &builder,
411  OperationState &result, Value source,
412  Value acc, ArrayRef<bool> reductionMask,
413  CombiningKind kind) {
414  SmallVector<int64_t> reductionDims;
415  for (const auto &en : llvm::enumerate(reductionMask))
416  if (en.value())
417  reductionDims.push_back(en.index());
418  build(builder, result, kind, source, acc,
419  builder.getI64ArrayAttr(reductionDims));
420 }
421 
422 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
423  // Single parallel dim, this is a noop.
424  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
425  return getSource();
426  return {};
427 }
428 
429 std::optional<SmallVector<int64_t, 4>>
430 MultiDimReductionOp::getShapeForUnroll() {
431  return llvm::to_vector<4>(getSourceVectorType().getShape());
432 }
433 
434 LogicalResult MultiDimReductionOp::verify() {
435  SmallVector<int64_t> targetShape;
436  SmallVector<bool> scalableDims;
437  Type inferredReturnType;
438  auto sourceScalableDims = getSourceVectorType().getScalableDims();
439  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
440  if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
441  return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
442  })) {
443  targetShape.push_back(it.value());
444  scalableDims.push_back(sourceScalableDims[it.index()]);
445  }
446  // TODO: update to also allow 0-d vectors when available.
447  if (targetShape.empty())
448  inferredReturnType = getSourceVectorType().getElementType();
449  else
450  inferredReturnType = VectorType::get(
451  targetShape, getSourceVectorType().getElementType(), scalableDims);
452  if (getType() != inferredReturnType)
453  return emitOpError() << "destination type " << getType()
454  << " is incompatible with source type "
455  << getSourceVectorType();
456 
457  return success();
458 }
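// For example (an illustrative sketch, not from the source), reducing dim 1
// of a 4x8 source yields the inferred destination type vector<4xf32>:
// ```mlir
// %r = vector.multi_reduction <add>, %src, %acc [1]
//        : vector<4x8xf32> to vector<4xf32>
// ```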
459 
460 /// Returns the mask type expected by this operation.
461 Type MultiDimReductionOp::getExpectedMaskType() {
462  auto vecType = getSourceVectorType();
463  return VectorType::get(vecType.getShape(),
464  IntegerType::get(vecType.getContext(), /*width=*/1),
465  vecType.getScalableDims());
466 }
467 
468 namespace {
469 // Only unit dimensions that are being reduced are folded. If the dimension is
470 // unit, but not reduced, it is not folded, thereby keeping the output type the
471 // same. If not all dimensions which are reduced are of unit dimension, this
472 // transformation does nothing. This is just a generalization of
473 // ElideSingleElementReduction for ReduceOp.
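// A rough before/after sketch of this pattern (illustrative, not from the
// source): the reduced dimension has size 1, so the reduction degenerates to a
// shape_cast combined with the accumulator.
// ```mlir
// %red = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// // roughly becomes:
// %cast = vector.shape_cast %src : vector<4x1xf32> to vector<4xf32>
// %red  = arith.addf %acc, %cast : vector<4xf32>
// ```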
474 struct ElideUnitDimsInMultiDimReduction
475  : public OpRewritePattern<MultiDimReductionOp> {
476  using OpRewritePattern::OpRewritePattern;
477 
478  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
479  PatternRewriter &rewriter) const override {
480  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
481  for (const auto &dim : enumerate(shape)) {
482  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
483  return failure();
484  }
485 
486  // Vector mask setup.
487  OpBuilder::InsertionGuard guard(rewriter);
488  Operation *rootOp;
489  Value mask;
490  if (reductionOp.isMasked()) {
491  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
492  rootOp = reductionOp.getMaskingOp();
493  mask = reductionOp.getMaskingOp().getMask();
494  } else {
495  rootOp = reductionOp;
496  }
497 
498  Location loc = reductionOp.getLoc();
499  Value acc = reductionOp.getAcc();
500  Value cast;
501  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
502  if (mask) {
503  VectorType newMaskType =
504  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
505  dstVecType.getScalableDims());
506  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
507  }
508  cast = rewriter.create<vector::ShapeCastOp>(
509  loc, reductionOp.getDestType(), reductionOp.getSource());
510  } else {
511  // This means we are reducing all the dimensions, and all reduction
512  // dimensions are of size 1. So a simple extraction would do.
513  SmallVector<int64_t> zeroIdx(shape.size(), 0);
514  if (mask)
515  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
516  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
517  zeroIdx);
518  }
519 
520  Value result =
521  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
522  cast, /*fastmath=*/nullptr, mask);
523  rewriter.replaceOp(rootOp, result);
524  return success();
525  }
526 };
527 } // namespace
528 
529 void MultiDimReductionOp::getCanonicalizationPatterns(
530  RewritePatternSet &results, MLIRContext *context) {
531  results.add<ElideUnitDimsInMultiDimReduction>(context);
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // ReductionOp
536 //===----------------------------------------------------------------------===//
537 
538 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
539  CombiningKind kind, Value vector,
540  arith::FastMathFlags fastMathFlags) {
541  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
542 }
543 
544 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
545  CombiningKind kind, Value vector, Value acc,
546  arith::FastMathFlags fastMathFlags) {
547  build(builder, result,
548  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
549  acc, fastMathFlags);
550 }
551 
552 LogicalResult ReductionOp::verify() {
553  // Verify for 0-D and 1-D vector.
554  int64_t rank = getSourceVectorType().getRank();
555  if (rank > 1)
556  return emitOpError("unsupported reduction rank: ") << rank;
557 
558  // Verify supported reduction kind.
559  Type eltType = getDest().getType();
560  if (!isSupportedCombiningKind(getKind(), eltType))
561  return emitOpError("unsupported reduction type '")
562  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
563  << "'";
564 
565  return success();
566 }
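// For example (illustrative sketch), both of these satisfy the verifier: a
// rank-1 source with a supported kind, with or without an accumulator.
// ```mlir
// %a = vector.reduction <add>, %v : vector<16xf32> into f32
// %b = vector.reduction <mul>, %v, %acc : vector<16xf32> into f32
// ```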
567 
568 // MaskableOpInterface methods.
569 
570 /// Returns the mask type expected by this operation.
571 Type ReductionOp::getExpectedMaskType() {
572  auto vecType = getSourceVectorType();
573  return VectorType::get(vecType.getShape(),
574  IntegerType::get(vecType.getContext(), /*width=*/1),
575  vecType.getScalableDims());
576 }
577 
578 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
579  OpBuilder &builder, Location loc,
580  Value vector) {
581  switch (op) {
582  case arith::AtomicRMWKind::addf:
583  case arith::AtomicRMWKind::addi:
584  return builder.create<vector::ReductionOp>(vector.getLoc(),
585  CombiningKind::ADD, vector);
586  case arith::AtomicRMWKind::mulf:
587  case arith::AtomicRMWKind::muli:
588  return builder.create<vector::ReductionOp>(vector.getLoc(),
589  CombiningKind::MUL, vector);
590  case arith::AtomicRMWKind::minimumf:
591  return builder.create<vector::ReductionOp>(vector.getLoc(),
592  CombiningKind::MINIMUMF, vector);
593  case arith::AtomicRMWKind::mins:
594  return builder.create<vector::ReductionOp>(vector.getLoc(),
595  CombiningKind::MINSI, vector);
596  case arith::AtomicRMWKind::minu:
597  return builder.create<vector::ReductionOp>(vector.getLoc(),
598  CombiningKind::MINUI, vector);
599  case arith::AtomicRMWKind::maximumf:
600  return builder.create<vector::ReductionOp>(vector.getLoc(),
601  CombiningKind::MAXIMUMF, vector);
602  case arith::AtomicRMWKind::maxs:
603  return builder.create<vector::ReductionOp>(vector.getLoc(),
604  CombiningKind::MAXSI, vector);
605  case arith::AtomicRMWKind::maxu:
606  return builder.create<vector::ReductionOp>(vector.getLoc(),
607  CombiningKind::MAXUI, vector);
608  case arith::AtomicRMWKind::andi:
609  return builder.create<vector::ReductionOp>(vector.getLoc(),
610  CombiningKind::AND, vector);
611  case arith::AtomicRMWKind::ori:
612  return builder.create<vector::ReductionOp>(vector.getLoc(),
613  CombiningKind::OR, vector);
614  // TODO: Add remaining reduction operations.
615  default:
616  (void)emitOptionalError(loc, "Reduction operation type not supported");
617  break;
618  }
619  return nullptr;
620 }
621 
622 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
623  return llvm::to_vector<4>(getSourceVectorType().getShape());
624 }
625 
626 namespace {
627 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
628  using OpRewritePattern::OpRewritePattern;
629 
630  LogicalResult matchAndRewrite(ReductionOp reductionOp,
631  PatternRewriter &rewriter) const override {
632  // Vector mask setup.
633  OpBuilder::InsertionGuard guard(rewriter);
634  auto maskableOp =
635  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
636  Operation *rootOp;
637  Value mask;
638  if (maskableOp.isMasked()) {
639  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
640  rootOp = maskableOp.getMaskingOp();
641  mask = maskableOp.getMaskingOp().getMask();
642  } else {
643  rootOp = reductionOp;
644  }
645 
646  auto vectorType = reductionOp.getSourceVectorType();
647  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
648  return failure();
649 
650  Location loc = reductionOp.getLoc();
651  Value result;
652  if (vectorType.getRank() == 0) {
653  if (mask)
654  mask = rewriter.create<ExtractElementOp>(loc, mask);
655  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
656  } else {
657  if (mask)
658  mask = rewriter.create<ExtractOp>(loc, mask, 0);
659  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
660  }
661 
662  if (Value acc = reductionOp.getAcc())
663  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
664  result, acc,
665  reductionOp.getFastmathAttr(), mask);
666 
667  rewriter.replaceOp(rootOp, result);
668  return success();
669  }
670 };
671 } // namespace
672 
673 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
674  MLIRContext *context) {
675  results.add<ElideSingleElementReduction>(context);
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // ContractionOp
680 //===----------------------------------------------------------------------===//
681 
682 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
683  Value lhs, Value rhs, Value acc,
684  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
685  ArrayRef<IteratorType> iteratorTypes) {
686  result.addOperands({lhs, rhs, acc});
687  result.addTypes(acc.getType());
688  result.addAttribute(
689  getIndexingMapsAttrName(result.name),
690  builder.getAffineMapArrayAttr(
691  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
692  result.addAttribute(
693  getIteratorTypesAttrName(result.name),
694  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
695  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
696  return IteratorTypeAttr::get(builder.getContext(), t);
697  }))));
698 }
699 
700 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
701  Value lhs, Value rhs, Value acc,
702  ArrayAttr indexingMaps,
703  ArrayAttr iteratorTypes) {
704  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
705  ContractionOp::getDefaultKind());
706 }
707 
708 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
709  Value lhs, Value rhs, Value acc,
710  ArrayAttr indexingMaps,
711  ArrayAttr iteratorTypes, CombiningKind kind) {
712  result.addOperands({lhs, rhs, acc});
713  result.addTypes(acc.getType());
714  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
715  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
716  result.addAttribute(getKindAttrName(result.name),
717  CombiningKindAttr::get(builder.getContext(), kind));
718 }
719 
720 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
721  OpAsmParser::UnresolvedOperand lhsInfo;
722  OpAsmParser::UnresolvedOperand rhsInfo;
723  OpAsmParser::UnresolvedOperand accInfo;
724  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
725  SmallVector<Type, 2> types;
726  Type resultType;
727  auto loc = parser.getCurrentLocation();
728  DictionaryAttr dictAttr;
729  // TODO: Unify linalg op attribute parsing.
730  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
731  parser.parseComma() || parser.parseOperand(rhsInfo) ||
732  parser.parseComma() || parser.parseOperand(accInfo) ||
733  parser.parseTrailingOperandList(masksInfo) ||
734  parser.parseOptionalAttrDict(result.attributes) ||
735  parser.parseColonTypeList(types) ||
736  parser.parseKeywordType("into", resultType) ||
737  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
738  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
739  parser.resolveOperand(accInfo, resultType, result.operands) ||
740  parser.addTypeToList(resultType, result.types))
741  return failure();
742  result.attributes.append(dictAttr.getValue().begin(),
743  dictAttr.getValue().end());
744 
745  // Convert the array of strings into an array of IteratorType enums. This is
746  // needed because tests still use the old format, in which the 'iterator_types'
747  // attribute is represented as an array of strings.
748  // TODO: Remove this conversion once tests are fixed.
749  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
750  result.attributes.get(getIteratorTypesAttrName(result.name)));
751 
752  SmallVector<Attribute> iteratorTypeAttrs;
753 
754  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
755  auto maybeIteratorType = symbolizeIteratorType(s);
756  if (!maybeIteratorType.has_value())
757  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
758 
759  iteratorTypeAttrs.push_back(
760  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
761  }
762  result.attributes.set(getIteratorTypesAttrName(result.name),
763  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
764 
765  if (!result.attributes.get(getKindAttrName(result.name))) {
766  result.addAttribute(
767  getKindAttrName(result.name),
768  CombiningKindAttr::get(parser.getContext(),
769  ContractionOp::getDefaultKind()));
770  }
771  if (masksInfo.empty())
772  return success();
773  if (masksInfo.size() != 2)
774  return parser.emitError(parser.getNameLoc(),
775  "expected zero or exactly 2 vector mask operands");
776  auto lhsType = llvm::cast<VectorType>(types[0]);
777  auto rhsType = llvm::cast<VectorType>(types[1]);
778  auto maskElementType = parser.getBuilder().getI1Type();
779  std::array<VectorType, 2> maskTypes = {
780  VectorType::Builder(lhsType).setElementType(maskElementType),
781  VectorType::Builder(rhsType).setElementType(maskElementType)};
782  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
783  return failure();
784  return success();
785 }
786 
787 void ContractionOp::print(OpAsmPrinter &p) {
788  // TODO: Unify printing code with linalg ops.
789  auto attrNames = getTraitAttrNames();
790  llvm::StringSet<> traitAttrsSet;
791  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
792  SmallVector<NamedAttribute> attrs;
793  for (auto attr : (*this)->getAttrs()) {
794  if (attr.getName() == getIteratorTypesAttrName()) {
795  auto iteratorTypes =
796  llvm::cast<ArrayAttr>(attr.getValue())
797  .getAsValueRange<IteratorTypeAttr, IteratorType>();
798  // Convert IteratorType enums into the string representation. This is
799  // needed because tests still use the old format, in which the 'iterator_types'
800  // attribute is represented as an array of strings.
801  // TODO: Remove this conversion once tests are fixed.
802  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
803  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
804  return StringAttr::get(getContext(), stringifyIteratorType(t));
805  }));
806 
807  attrs.emplace_back(getIteratorTypesAttrName(),
808  ArrayAttr::get(getContext(), iteratorTypeNames));
809  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
810  attrs.push_back(attr);
811  }
812 
813  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
814  p << " " << dictAttr << " " << getLhs() << ", ";
815  p << getRhs() << ", " << getAcc();
816 
817  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
818  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
819  << getResultType();
820 }
821 
822 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
823  const std::vector<std::pair<int64_t, int64_t>> &map) {
824  for (auto &dimPair : map) {
825  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
826  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
827  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
828  return false;
829  }
830  return true;
831 }
832 
833 static LogicalResult verifyOutputShape(
834  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
835  Type resType,
836  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
837  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
838  DenseSet<int64_t> lhsContractingDimSet;
839  DenseSet<int64_t> rhsContractingDimSet;
840  for (auto &dimPair : contractingDimMap) {
841  lhsContractingDimSet.insert(dimPair.first);
842  rhsContractingDimSet.insert(dimPair.second);
843  }
844  DenseSet<int64_t> rhsBatchDimSet;
845  for (auto &dimPair : batchDimMap)
846  rhsBatchDimSet.insert(dimPair.second);
847 
848  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
849  SmallVector<int64_t, 4> expectedResultDims;
850  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
851  if (lhsContractingDimSet.count(i) > 0)
852  continue;
853  expectedResultDims.push_back(lhsType.getDimSize(i));
854  }
855 
856  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
857  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
858  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
859  continue;
860  expectedResultDims.push_back(rhsType.getDimSize(i));
861  }
862 
863  // Verify 'expectedResultDims'.
864  if (expectedResultDims.empty()) {
865  // No batch or free dimension implies a scalar result.
866  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
867  return op.emitOpError("invalid accumulator/result vector shape");
868  } else {
869  // At least one batch or free dimension implies a vector result.
870  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
871  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
872  if (!resVectorType || !accVectorType)
873  return op.emitOpError("invalid accumulator/result vector shape");
874 
875  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
876  // types fully define the result vector type. This assumes the affine maps
877  // are well-formed, which must have been verified already.
878  MLIRContext *ctx = op.getContext();
879  AffineMap lhsMap = op.getIndexingMapsArray()[0];
880  AffineMap rhsMap = op.getIndexingMapsArray()[1];
881  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
882  return op.emitOpError(
883  "expected all dimensions to be either a LHS or a RHS dimension");
884  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
885  for (auto pair :
886  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
887  VectorType v = pair.first;
888  auto map = pair.second;
889  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
890  unsigned pos = map.getDimPosition(idx);
891  if (!extents[pos])
892  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
893  }
894  }
895  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
896  return op.emitOpError("expected all dimensions to get an extent as "
897  "either a LHS or a RHS dimension");
898 
899  AffineMap resMap = op.getIndexingMapsArray()[2];
900  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
901  /*symCount=*/0, extents, ctx);
902  // Compose the resMap with the extentsMap, which is a constant map.
903  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
904  assert(
905  llvm::all_of(expectedMap.getResults(),
906  [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
907  "expected constant extent along all dimensions.");
908  // Extract the expected shape and build the type.
909  auto expectedShape = llvm::to_vector<4>(
910  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
911  return cast<AffineConstantExpr>(e).getValue();
912  }));
913  auto expected =
914  VectorType::get(expectedShape, resVectorType.getElementType(),
915  resVectorType.getScalableDims());
916  if (resVectorType != expected || accVectorType != expected)
917  return op.emitOpError(
918  "invalid accumulator/result vector shape, expected: ")
919  << expected;
920  }
921  return success();
922 }
923 
924 LogicalResult ContractionOp::verify() {
925  VectorType lhsType = getLhsType();
926  VectorType rhsType = getRhsType();
927  Type accType = getAccType();
928  Type resType = getResultType();
929 
930  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
931  if (!lhsType.getElementType().isSignlessInteger())
932  return emitOpError("only supports signless integer types");
933  }
934 
935  // Verify that an indexing map was specified for each vector operand.
936  if (getIndexingMapsArray().size() != 3)
937  return emitOpError("expected an indexing map for each vector operand");
938 
939  // Verify that each index map has 'numIterators' inputs, no symbols, and
940  // that the number of map outputs equals the rank of its associated
941  // vector operand.
942  unsigned numIterators = getIteratorTypes().getValue().size();
943  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
944  auto index = it.index();
945  auto map = it.value();
946  if (map.getNumSymbols() != 0)
947  return emitOpError("expected indexing map ")
948  << index << " to have no symbols";
949  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
950  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
951  // Verify that the map has the right number of inputs, outputs, and indices.
952  // This also correctly accounts for (..) -> () for rank-0 results.
953  if (map.getNumDims() != numIterators)
954  return emitOpError("expected indexing map ")
955  << index << " to have " << numIterators << " number of inputs";
956  if (map.getNumResults() != rank)
957  return emitOpError("expected indexing map ")
958  << index << " to have " << rank << " number of outputs";
959  if (!map.isProjectedPermutation())
960  return emitOpError("expected indexing map ")
961  << index << " to be a projected permutation of its inputs";
962  }
963 
964  auto contractingDimMap = getContractingDimMap();
965  auto batchDimMap = getBatchDimMap();
966 
967  // Verify at least one contracting dimension pair was specified.
968  if (contractingDimMap.empty())
969  return emitOpError("expected at least one contracting dimension pair");
970 
971  // Verify contracting dimension map was properly constructed.
972  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
973  return emitOpError("invalid contracting dimension map");
974 
975  // Verify batch dimension map was properly constructed.
976  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
977  return emitOpError("invalid batch dimension map");
978 
979  // Verify 'accType' and 'resType' shape.
980  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
981  contractingDimMap, batchDimMap)))
982  return failure();
983 
984  // Verify supported combining kind.
985  auto vectorType = llvm::dyn_cast<VectorType>(resType);
986  auto elementType = vectorType ? vectorType.getElementType() : resType;
987  if (!isSupportedCombiningKind(getKind(), elementType))
988  return emitOpError("unsupported contraction type");
989 
990  return success();
991 }
992 
993 // MaskableOpInterface methods.
994 
995 /// Returns the mask type expected by this operation. Mostly used for
996 /// verification purposes. It requires the operation to be vectorized.
997 Type ContractionOp::getExpectedMaskType() {
998  auto indexingMaps = this->getIndexingMapsArray();
999  AffineMap lhsIdxMap = indexingMaps[0];
1000  AffineMap rhsIdxMap = indexingMaps[1];
1001  VectorType lhsType = this->getLhsType();
1002  VectorType rhsType = this->getRhsType();
1003 
1004  unsigned numVecDims = lhsIdxMap.getNumDims();
1005  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1006  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1007 
1008  // Using the information in the indexing maps, extract the size of each
1009  // dimension in the vector.contract operation from the two input operands.
1010  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1011  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1012  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1013  lhsType.getScalableDims()[dimIdx];
1014  }
1015  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1016  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1017  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1018  rhsType.getScalableDims()[dimIdx];
1019  }
1020 
1021  assert(!ShapedType::isDynamicShape(maskShape) &&
1022  "Mask shape couldn't be computed");
1023 
1024  return VectorType::get(maskShape,
1025  IntegerType::get(lhsType.getContext(), /*width=*/1),
1026  maskShapeScalableDims);
1027 }
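// Illustrative sketch (not from the source; %lhs, %rhs and %acc are
// hypothetical): for a matmul-shaped contraction with iterators (i, j, k) and
// lhs/rhs/acc maps (i, k), (k, j), (i, j),
// ```mlir
// %res = vector.contract {
//          indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
//                           affine_map<(i, j, k) -> (k, j)>,
//                           affine_map<(i, j, k) -> (i, j)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %lhs, %rhs, %acc : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
// ```
// the expected mask type is vector<4x16x8xi1>: one i1 per (i, j, k) iteration
// point, with the sizes taken from the lhs/rhs shapes via the indexing maps.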
1028 
1029 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1030  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1031  getIteratorTypesAttrName(), getKindAttrName()};
1032 }
1033 
1034 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1035  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1036  if (targetExpr == map.getResult(i))
1037  return i;
1038  return -1;
1039 }
1040 
1041 static std::vector<std::pair<int64_t, int64_t>>
1042 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1043  IteratorType targetIteratorType, MLIRContext *context) {
1044  std::vector<std::pair<int64_t, int64_t>> dimMap;
1045  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1046  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1047  if (iteratorType != targetIteratorType)
1048  continue;
1049  // Search lhs/rhs map results for 'targetExpr'.
1050  auto targetExpr = getAffineDimExpr(it.index(), context);
1051  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1052  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1053  if (lhsDim >= 0 && rhsDim >= 0)
1054  dimMap.emplace_back(lhsDim, rhsDim);
1055  }
1056  return dimMap;
1057 }
1058 
1059 void ContractionOp::getIterationBounds(
1060  SmallVectorImpl<int64_t> &iterationBounds) {
1061  auto lhsShape = getLhsType().getShape();
1062  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1063  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1064  SmallVector<int64_t, 2> iterationShape;
1065  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1066  // Search lhs/rhs map results for 'targetExpr'.
1067  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1068  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1069  if (iteratorType == IteratorType::reduction) {
1070  // Get reduction dim size from lhs shape (same size in rhsShape).
1071  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1072  assert(lhsDimIndex >= 0);
1073  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1074  continue;
1075  }
1076  // Get parallel dimension size from result shape.
1077  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1078  assert(resDimIndex >= 0);
1079  assert(resVectorType != nullptr);
1080  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1081  }
1082 }
1083 
1084 void ContractionOp::getIterationIndexMap(
1085  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1086  unsigned numMaps = getIndexingMapsArray().size();
1087  iterationIndexMap.resize(numMaps);
1088  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1089  auto index = it.index();
1090  auto map = it.value();
1091  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1092  auto dim = cast<AffineDimExpr>(map.getResult(i));
1093  iterationIndexMap[index][dim.getPosition()] = i;
1094  }
1095  }
1096 }
1097 
1098 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1099  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1100  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1101  getContext());
1102 }
1103 
1104 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1105  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1106  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1107  getContext());
1108 }
1109 
1110 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1111  SmallVector<int64_t, 4> shape;
1112  getIterationBounds(shape);
1113  return shape;
1114 }
1115 
1116 /// Return a fused vector::ContractionOp which represents a pattern such as:
1117 ///
1118 /// ```mlir
1119 /// %c0 = vector.constant 0: ...
1120 /// %c = vector.contract %a, %b, %c0: ...
1121 /// %e = add %c, %d: ...
1122 /// ```
1123 ///
1124 /// by:
1125 ///
1126 /// ```mlir
1127 /// %e = vector.contract %a, %b, %d: ...
1128 /// ```
1129 ///
1130 /// Return null if the canonicalization does not apply.
1131 // TODO: This should be a folding of Add into Contract in core but while they
1132 // live in different dialects, it is not possible without unnatural
1133 // dependencies.
1134 template <typename AddOpType>
1135 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1136  using OpRewritePattern<AddOpType>::OpRewritePattern;
1137 
1138  LogicalResult matchAndRewrite(AddOpType addOp,
1139  PatternRewriter &rewriter) const override {
1140  auto canonicalize = [&](Value maybeContraction,
1141  Value otherOperand) -> vector::ContractionOp {
1142  vector::ContractionOp contractionOp =
1143  dyn_cast_or_null<vector::ContractionOp>(
1144  maybeContraction.getDefiningOp());
1145  if (!contractionOp)
1146  return vector::ContractionOp();
1147  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1148  contractionOp.getAcc().getDefiningOp())) {
1149  if (maybeZero.getValue() ==
1150  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1151  IRMapping bvm;
1152  bvm.map(contractionOp.getAcc(), otherOperand);
1153  auto newContraction =
1154  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1155  rewriter.replaceOp(addOp, newContraction.getResult());
1156  return newContraction;
1157  }
1158  }
1159  return vector::ContractionOp();
1160  };
1161 
1162  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1163  vector::ContractionOp contract = canonicalize(a, b);
1164  contract = contract ? contract : canonicalize(b, a);
1165  return contract ? success() : failure();
1166  }
1167 };
1168 
1169 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1170  MLIRContext *context) {
1171  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1172  CanonicalizeContractAdd<arith::AddFOp>>(context);
1173 }
1174 
1175 //===----------------------------------------------------------------------===//
1176 // ExtractElementOp
1177 //===----------------------------------------------------------------------===//
1178 
1179 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1180  Value source) {
1181  result.addOperands({source});
1182  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1183 }
1184 
1185 LogicalResult vector::ExtractElementOp::verify() {
1186  VectorType vectorType = getSourceVectorType();
1187  if (vectorType.getRank() == 0) {
1188  if (getPosition())
1189  return emitOpError("expected position to be empty with 0-D vector");
1190  return success();
1191  }
1192  if (vectorType.getRank() != 1)
1193  return emitOpError("unexpected >1 vector rank");
1194  if (!getPosition())
1195  return emitOpError("expected position for 1-D vector");
1196  return success();
1197 }
1198 
1199 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1200  // Skip the 0-D vector case here for now.
1201  if (!adaptor.getPosition())
1202  return {};
1203 
1204  // Fold extractelement (splat X) -> X.
1205  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1206  return splat.getInput();
1207 
1208  // Fold extractelement(broadcast(X)) -> X.
1209  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1210  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1211  return broadcast.getSource();
1212 
1213  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1214  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1215  if (!pos || !src)
1216  return {};
1217 
1218  auto srcElements = src.getValues<Attribute>();
1219 
1220  uint64_t posIdx = pos.getInt();
1221  if (posIdx >= srcElements.size())
1222  return {};
1223 
1224  return srcElements[posIdx];
1225 }
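// Illustrative folds (a sketch, not from the source; %x, %y and %i are
// hypothetical values):
// ```mlir
// %s = vector.splat %x : vector<8xf32>
// %e = vector.extractelement %s[%i : index] : vector<8xf32>  // folds to %x
// %b = vector.broadcast %y : f32 to vector<8xf32>
// %f = vector.extractelement %b[%i : index] : vector<8xf32>  // folds to %y
// ```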
1226 
1227 //===----------------------------------------------------------------------===//
1228 // ExtractOp
1229 //===----------------------------------------------------------------------===//
1230 
1231 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1232  Value source, int64_t position) {
1233  build(builder, result, source, ArrayRef<int64_t>{position});
1234 }
1235 
1236 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1237  Value source, OpFoldResult position) {
1238  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1239 }
1240 
1241 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1242  Value source, ArrayRef<int64_t> position) {
1243  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1244  builder.getDenseI64ArrayAttr(position));
1245 }
1246 
1247 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1248  Value source, ArrayRef<OpFoldResult> position) {
1249  SmallVector<int64_t> staticPos;
1250  SmallVector<Value> dynamicPos;
1251  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1252  build(builder, result, source, dynamicPos,
1253  builder.getDenseI64ArrayAttr(staticPos));
1254 }
1255 
1256 LogicalResult
1257 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1258  ExtractOp::Adaptor adaptor,
1259  SmallVectorImpl<Type> &inferredReturnTypes) {
1260  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1261  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1262  vectorType.getRank()) {
1263  inferredReturnTypes.push_back(vectorType.getElementType());
1264  } else {
1265  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1266  vectorType.getRank());
1267  inferredReturnTypes.push_back(VectorType::get(
1268  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1269  vectorType.getScalableDims().drop_front(n)));
1270  }
1271  return success();
1272 }
1273 
1274 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1275  // Allow extracting 1-element vectors instead of scalars.
1276  auto isCompatible = [](TypeRange l, TypeRange r) {
1277  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1278  return vectorType && vectorType.getShape().equals({1}) &&
1279  vectorType.getElementType() == r.front();
1280  };
1281  if (l.size() == 1 && r.size() == 1 &&
1282  (isCompatible(l, r) || isCompatible(r, l)))
1283  return true;
1284  return l == r;
1285 }
1286 
1287 LogicalResult vector::ExtractOp::verify() {
1288  // Note: This check must come before getMixedPosition() to prevent a crash.
1289  auto dynamicMarkersCount =
1290  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1291  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1292  return emitOpError(
1293  "mismatch between dynamic and static positions (kDynamic marker but no "
1294  "corresponding dynamic position) -- this can only happen due to an "
1295  "incorrect fold/rewrite");
1296  auto position = getMixedPosition();
1297  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1298  return emitOpError(
1299  "expected position attribute of rank no greater than vector rank");
1300  for (auto [idx, pos] : llvm::enumerate(position)) {
1301  if (pos.is<Attribute>()) {
1302  int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
1303  if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1304  return emitOpError("expected position attribute #")
1305  << (idx + 1)
1306  << " to be a non-negative integer smaller than the "
1307  "corresponding vector dimension";
1308  }
1309  }
1310  }
1311  return success();
1312 }
1313 
1314 template <typename IntType>
1315 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1316  return llvm::to_vector<4>(llvm::map_range(
1317  arrayAttr.getAsRange<IntegerAttr>(),
1318  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1319 }
1320 
1321 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1322 /// positions.
1323 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1324  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1325  return failure();
1326 
1327  // TODO: Canonicalization for dynamic position not implemented yet.
1328  if (extractOp.hasDynamicPosition())
1329  return failure();
1330 
1331  SmallVector<int64_t> globalPosition;
1332  ExtractOp currentOp = extractOp;
1333  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1334  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1335  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1336  currentOp = nextOp;
1337  // TODO: Canonicalization for dynamic position not implemented yet.
1338  if (currentOp.hasDynamicPosition())
1339  return failure();
1340  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1341  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1342  }
1343  extractOp.setOperand(0, currentOp.getVector());
1344  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1345  OpBuilder b(extractOp.getContext());
1346  std::reverse(globalPosition.begin(), globalPosition.end());
1347  extractOp.setStaticPosition(globalPosition);
1348  return success();
1349 }
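// A small sketch of the chain fold (illustrative only; %src is hypothetical):
// the positions are concatenated and the op is updated in place to extract
// from the outer source.
// ```mlir
// %a = vector.extract %src[1] : vector<3x4xf32> from vector<2x3x4xf32>
// %b = vector.extract %a[2]   : vector<4xf32> from vector<3x4xf32>
// // %b folds to: vector.extract %src[1, 2] : vector<4xf32> from vector<2x3x4xf32>
// ```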
1350 
1351 namespace {
1352 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1353 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1354 /// Compose TransposeOp permutations as we walk back.
1355 /// This helper class keeps an updated extraction position `extractPosition`
1356 /// with extra trailing sentinels.
1357 /// The sentinels encode the internal transposition status of the result vector.
1358 /// As we iterate, extractPosition is permuted and updated.
1359 class ExtractFromInsertTransposeChainState {
1360 public:
1361  ExtractFromInsertTransposeChainState(ExtractOp e);
1362 
1363  /// Iterate over producing insert and transpose ops until we find a fold.
1364  Value fold();
1365 
1366 private:
1367  /// Return true if the vector at position `a` is contained within the vector
1368  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1369  /// is a prefix of `b`.
1370  template <typename ContainerA, typename ContainerB>
1371  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1372  return a.size() <= b.size() &&
1373  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1374  }
1375 
1376  /// Return true if the vector at position `a` intersects the vector at
1377  /// position `b`. Under insert/extract semantics, this is the same as equality
1378  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1379  /// Comparison is on the common prefix (i.e. zip).
1380  template <typename ContainerA, typename ContainerB>
1381  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1382  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1383  if (elemA < 0 || elemB < 0)
1384  continue;
1385  if (elemA != elemB)
1386  return false;
1387  }
1388  return true;
1389  }
1390 
1391  /// Folding is only possible in the absence of an internal permutation in the
1392  /// result vector.
1393  bool canFold() {
1394  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1395  }
1396 
1397  // Helper to get the next defining op of interest.
1398  void updateStateForNextIteration(Value v) {
1399  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1400  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1401  };
1402 
1403  // Case 1. If we hit a transpose, just compose the map and iterate.
1404  // Invariant: insert + transpose do not change rank, we can always compose.
1405  LogicalResult handleTransposeOp();
1406 
1407  // Case 2: the insert position matches extractPosition exactly, early return.
1408  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1409 
1410  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1411  /// portion of the source of the insert.
1412  /// Example:
1413  /// ```
1414  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1415  /// // extractPosition == [1, 2, 3]
1416  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1417  /// // can fold to vector.extract %source[0, 3]
1418  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1419  /// ```
1420  /// To traverse through %source, we need to set the leading dims to 0 and
1421  /// drop the extra leading dims.
1422  /// This method updates the internal state.
1423  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1424 
1425  /// Try to fold in place to extract(source, extractPosition) and return the
1426  /// folded result. Return null if folding is not possible (e.g. due to an
1427  /// internal transposition in the result).
1428  Value tryToFoldExtractOpInPlace(Value source);
1429 
1430  ExtractOp extractOp;
1431  int64_t vectorRank;
1432  int64_t extractedRank;
1433 
1434  InsertOp nextInsertOp;
1435  TransposeOp nextTransposeOp;
1436 
1437  /// Sentinel values that encode the internal permutation status of the result.
1438  /// They are set to (-1, ... , -k) at the beginning and appended to
1439  /// `extractPosition`.
1440  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1441  /// ensure that there is no internal transposition.
1442  /// Internal transposition cannot be accounted for with a folding pattern.
1443  // TODO: We could relax the internal transposition with an extra transposition
1444  // operation in a future canonicalizer.
1445  SmallVector<int64_t> sentinels;
1446  SmallVector<int64_t> extractPosition;
1447 };
1448 } // namespace
1449 
1450 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1451  ExtractOp e)
1452  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1453  extractedRank(extractOp.getNumIndices()) {
1454  assert(vectorRank >= extractedRank && "Extracted position overflow");
1455  sentinels.reserve(vectorRank - extractedRank);
1456  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1457  sentinels.push_back(-(i + 1));
1458  extractPosition.assign(extractOp.getStaticPosition().begin(),
1459  extractOp.getStaticPosition().end());
1460  llvm::append_range(extractPosition, sentinels);
1461 }
1462 
1463 // Case 1. If we hit a transpose, just compose the map and iterate.
1464 // Invariant: insert + transpose do not change rank, we can always compose.
1465 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1466  // TODO: Canonicalization for dynamic position not implemented yet.
1467  if (extractOp.hasDynamicPosition())
1468  return failure();
1469 
1470  if (!nextTransposeOp)
1471  return failure();
1472  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1473  nextTransposeOp.getPermutation(), extractOp.getContext()));
1474  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1475  return success();
1476 }
1477 
1478 // Case 2: the insert position matches extractPosition exactly, early return.
1479 LogicalResult
1480 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1481  Value &res) {
1482  // TODO: Canonicalization for dynamic position not implemented yet.
1483  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1484  return failure();
1485 
1486  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1487  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1488  return failure();
1489  // Case 2.a. early-exit fold.
1490  res = nextInsertOp.getSource();
1491  // Case 2.b. if internal transposition is present, canFold will be false.
1492  return success(canFold());
1493 }
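// Illustrative sketch (hypothetical values): given
//   %ins = vector.insert %s, %d[1, 2] : vector<4xf32> into vector<2x3x4xf32>
//   %ext = vector.extract %ins[1, 2] : vector<4xf32> from vector<2x3x4xf32>
// the positions match exactly, so %ext can be replaced by %s (case 2.a), unless
// a pending internal transposition makes canFold() return false (case 2.b).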
1494 
1495 /// Case 3: if inserted position is a prefix of extractPosition,
1496 /// extract a portion of the source of the insertion.
1497 /// This method updates the internal state.
1498 LogicalResult
1499 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1500  // TODO: Canonicalization for dynamic position not implemented yet.
1501  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1502  return failure();
1503 
1504  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1505  if (!isContainedWithin(insertedPos, extractPosition))
1506  return failure();
1507  // Set leading dims to zero.
1508  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1509  // Drop extra leading dims.
1510  extractPosition.erase(extractPosition.begin(),
1511  extractPosition.begin() + insertedPos.size());
1512  extractedRank = extractPosition.size() - sentinels.size();
1513  // Case 3.a. early-exit fold (break and delegate to post-while path).
1514  res = nextInsertOp.getSource();
1515  // Case 3.b. if internal transposition is present, canFold will be false.
1516  return success();
1517 }
1518 
1519 /// Try to fold in place to extract(source, extractPosition) and return the
1520 /// folded result. Return null if folding is not possible (e.g. due to an
1521 /// internal transposition in the result).
1522 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1523  Value source) {
1524  // TODO: Canonicalization for dynamic position not implemented yet.
1525  if (extractOp.hasDynamicPosition())
1526  return Value();
1527 
1528  // If we can't fold (either internal transposition, or nothing to fold), bail.
1529  bool nothingToFold = (source == extractOp.getVector());
1530  if (nothingToFold || !canFold())
1531  return Value();
1532 
1533  // Otherwise, fold by updating the op inplace and return its result.
1534  OpBuilder b(extractOp.getContext());
1535  extractOp.setStaticPosition(
1536  ArrayRef(extractPosition).take_front(extractedRank));
1537  extractOp.getVectorMutable().assign(source);
1538  return extractOp.getResult();
1539 }
1540 
1541 /// Iterate over producing insert and transpose ops until we find a fold.
1542 Value ExtractFromInsertTransposeChainState::fold() {
1543  // TODO: Canonicalization for dynamic position not implemented yet.
1544  if (extractOp.hasDynamicPosition())
1545  return Value();
1546 
1547  Value valueToExtractFrom = extractOp.getVector();
1548  updateStateForNextIteration(valueToExtractFrom);
1549  while (nextInsertOp || nextTransposeOp) {
1550  // Case 1. If we hit a transpose, just compose the map and iterate.
1551  // Invariant: insert + transpose do not change rank, we can always compose.
1552  if (succeeded(handleTransposeOp())) {
1553  valueToExtractFrom = nextTransposeOp.getVector();
1554  updateStateForNextIteration(valueToExtractFrom);
1555  continue;
1556  }
1557 
1558  Value result;
1559  // Case 2: the positions match exactly.
1560  if (succeeded(handleInsertOpWithMatchingPos(result)))
1561  return result;
1562 
1563  // Case 3: if the inserted position is a prefix of extractPosition, we can
1564  // just extract a portion of the source of the insert.
1565  if (succeeded(handleInsertOpWithPrefixPos(result)))
1566  return tryToFoldExtractOpInPlace(result);
1567 
1568  // Case 4: extractPosition intersects insertedPos on non-sentinel
1569  // values. This is a more difficult case and we bail.
1570  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1571  if (isContainedWithin(extractPosition, insertedPos) ||
1572  intersectsWhereNonNegative(extractPosition, insertedPos))
1573  return Value();
1574 
1575  // Case 5: No intersection, we forward the extract to insertOp.dest().
1576  valueToExtractFrom = nextInsertOp.getDest();
1577  updateStateForNextIteration(valueToExtractFrom);
1578  }
1579  // If after all this we can fold, go for it.
1580  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1581 }
1582 
1583 /// Returns true if the operation has a 0-D vector type operand or result.
1584 static bool hasZeroDimVectors(Operation *op) {
1585  auto hasZeroDimVectorType = [](Type type) -> bool {
1586  auto vecType = dyn_cast<VectorType>(type);
1587  return vecType && vecType.getRank() == 0;
1588  };
1589 
1590  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1591  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1592 }
1593 
1594 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1595 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1596  // TODO: Canonicalization for dynamic position not implemented yet.
1597  if (extractOp.hasDynamicPosition())
1598  return Value();
1599 
1600  Operation *defOp = extractOp.getVector().getDefiningOp();
1601  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1602  return Value();
1603 
1604  // 0-D vectors not supported.
1605  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1606  if (hasZeroDimVectors(defOp))
1607  return Value();
1608 
1609  Value source = defOp->getOperand(0);
1610  if (extractOp.getType() == source.getType())
1611  return source;
1612  auto getRank = [](Type type) {
1613  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1614  : 0;
1615  };
1616 
1617  // If splat or broadcast from a scalar, just return the source scalar.
1618  unsigned broadcastSrcRank = getRank(source.getType());
1619  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1620  return source;
1621 
1622  unsigned extractResultRank = getRank(extractOp.getType());
1623  if (extractResultRank >= broadcastSrcRank)
1624  return Value();
1625  // Check that the dimensions of the result haven't been broadcasted.
1626  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1627  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1628  if (extractVecType && broadcastVecType &&
1629  extractVecType.getShape() !=
1630  broadcastVecType.getShape().take_back(extractResultRank))
1631  return Value();
1632 
1633  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1634  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1635 
1636  // Detect all the positions that come from "dim-1" broadcasting.
1637  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1638  // extract position to `0` when extracting from the source operand.
1639  llvm::SetVector<int64_t> broadcastedUnitDims =
1640  broadcastOp.computeBroadcastedUnitDims();
1641  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1642  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1643  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1644  if (broadcastedUnitDims.contains(i))
1645  extractPos[i] = 0;
1646  // The leading dims added by the broadcast have no counterpart in the source;
1647  // drop the matching extract positions when extracting from the source operand.
1648  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1649  extractPos.erase(extractPos.begin(),
1650  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1651  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1652  OpBuilder b(extractOp.getContext());
1653  extractOp.setOperand(0, source);
1654  extractOp.setStaticPosition(extractPos);
1655  return extractOp.getResult();
1656 }
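// Illustrative sketch (hypothetical values): with
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1, 2] : vector<4xf32> from vector<2x3x4xf32>
// the extracted type equals the broadcast source type, so %e folds to %src.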
1657 
1658 // Fold extractOp with source coming from ShapeCast op.
1659 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1660  // TODO: Canonicalization for dynamic position not implemented yet.
1661  if (extractOp.hasDynamicPosition())
1662  return Value();
1663 
1664  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1665  if (!shapeCastOp)
1666  return Value();
1667 
1668  // 0-D vectors not supported.
1669  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1670  if (hasZeroDimVectors(shapeCastOp))
1671  return Value();
1672 
1673  // Get the nth dimension size starting from lowest dimension.
1674  auto getDimReverse = [](VectorType type, int64_t n) {
1675  return type.getShape().take_back(n + 1).front();
1676  };
1677  int64_t destinationRank =
1678  llvm::isa<VectorType>(extractOp.getType())
1679  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1680  : 0;
1681  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1682  return Value();
1683  if (destinationRank > 0) {
1684  auto destinationType =
1685  llvm::cast<VectorType>(extractOp.getResult().getType());
1686  for (int64_t i = 0; i < destinationRank; i++) {
1687  // The lowest dimension of the destination must match the lowest
1688  // dimension of the shapecast op source.
1689  // TODO: This case could be supported in a canonicalization pattern.
1690  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1691  getDimReverse(destinationType, i))
1692  return Value();
1693  }
1694  }
1695  // Extract the strides associated with the extract op vector source. Then use
1696  // this to calculate a linearized position for the extract.
1697  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1698  std::reverse(extractedPos.begin(), extractedPos.end());
1699  SmallVector<int64_t, 4> strides;
1700  int64_t stride = 1;
1701  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1702  strides.push_back(stride);
1703  stride *=
1704  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1705  }
1706 
1707  int64_t position = linearize(extractedPos, strides);
1708  // Then extract the strides associated with the shapeCast op vector source and
1709  // delinearize the position using those strides.
1710  SmallVector<int64_t, 4> newStrides;
1711  int64_t numDimension =
1712  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1713  stride = 1;
1714  for (int64_t i = 0; i < numDimension; i++) {
1715  newStrides.push_back(stride);
1716  stride *=
1717  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1718  }
1719  std::reverse(newStrides.begin(), newStrides.end());
1720  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1721  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1722  OpBuilder b(extractOp.getContext());
1723  extractOp.setStaticPosition(newPosition);
1724  extractOp.setOperand(0, shapeCastOp.getSource());
1725  return extractOp.getResult();
1726 }
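// Illustrative sketch (hypothetical values): with
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %e = vector.extract %sc[5] : f32 from vector<8xf32>
// the linearized position 5 delinearizes to [1, 1] in the 2x4 source, so the
// extract is updated in place to vector.extract %src[1, 1].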
1727 
1728 /// Fold an ExtractOp from ExtractStridedSliceOp.
1729 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1730  // TODO: Canonicalization for dynamic position not implemented yet.
1731  if (extractOp.hasDynamicPosition())
1732  return Value();
1733 
1734  auto extractStridedSliceOp =
1735  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1736  if (!extractStridedSliceOp)
1737  return Value();
1738 
1739  // 0-D vectors not supported.
1740  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1741  if (hasZeroDimVectors(extractStridedSliceOp))
1742  return Value();
1743 
1744  // Return if 'extractStridedSliceOp' has non-unit strides.
1745  if (extractStridedSliceOp.hasNonUnitStrides())
1746  return Value();
1747 
1748  // Trim offsets for dimensions fully extracted.
1749  auto sliceOffsets =
1750  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1751  while (!sliceOffsets.empty()) {
1752  size_t lastOffset = sliceOffsets.size() - 1;
1753  if (sliceOffsets.back() != 0 ||
1754  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1755  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1756  break;
1757  sliceOffsets.pop_back();
1758  }
1759  unsigned destinationRank = 0;
1760  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1761  destinationRank = vecType.getRank();
1762  // The dimensions of the result need to be untouched by the
1763  // extractStridedSlice op.
1764  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1765  sliceOffsets.size())
1766  return Value();
1767 
1768  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1769  assert(extractedPos.size() >= sliceOffsets.size());
1770  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1771  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1772  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1773 
1774  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1775  OpBuilder b(extractOp.getContext());
1776  extractOp.setStaticPosition(extractedPos);
1777  return extractOp.getResult();
1778 }
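// Illustrative sketch (hypothetical values): with
//   %s = vector.extract_strided_slice %v {offsets = [2], sizes = [4],
//        strides = [1]} : vector<8xf32> to vector<4xf32>
//   %e = vector.extract %s[1] : f32 from vector<4xf32>
// the slice offset is added to the extract position, giving
// vector.extract %v[3] : f32 from vector<8xf32>.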
1779 
1780 /// Fold extract_op fed from a chain of insertStridedSlice ops.
1781 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1782  // TODO: Canonicalization for dynamic position not implemented yet.
1783  if (extractOp.hasDynamicPosition())
1784  return Value();
1785 
1786  int64_t destinationRank =
1787  llvm::isa<VectorType>(extractOp.getType())
1788  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1789  : 0;
1790  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1791  if (!insertOp)
1792  return Value();
1793 
1794  // 0-D vectors not supported.
1795  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1796  if (hasZeroDimVectors(insertOp))
1797  return Value();
1798 
1799  while (insertOp) {
1800  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1801  insertOp.getSourceVectorType().getRank();
1802  if (destinationRank > insertOp.getSourceVectorType().getRank())
1803  return Value();
1804  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1805  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1806 
1807  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1808  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1809  }))
1810  return Value();
1811  bool disjoint = false;
1812  SmallVector<int64_t, 4> offsetDiffs;
1813  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1814  int64_t start = insertOffsets[dim];
1815  int64_t size =
1816  (dim < insertRankDiff)
1817  ? 1
1818  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1819  int64_t end = start + size;
1820  int64_t offset = extractOffsets[dim];
1821  // Check if the start of the extract offset is in the interval inserted.
1822  if (start <= offset && offset < end) {
1823  if (dim >= insertRankDiff)
1824  offsetDiffs.push_back(offset - start);
1825  continue;
1826  }
1827  disjoint = true;
1828  break;
1829  }
1830  // The extracted chunk overlaps with the inserted vector.
1831  if (!disjoint) {
1832  // If any of the inner dimensions are only partially inserted we have a
1833  // partial overlap.
1834  int64_t srcRankDiff =
1835  insertOp.getSourceVectorType().getRank() - destinationRank;
1836  for (int64_t i = 0; i < destinationRank; i++) {
1837  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1838  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1839  insertRankDiff))
1840  return Value();
1841  }
1842  extractOp.getVectorMutable().assign(insertOp.getSource());
1843  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1844  OpBuilder b(extractOp.getContext());
1845  extractOp.setStaticPosition(offsetDiffs);
1846  return extractOp.getResult();
1847  }
1848  // If the chunk extracted is disjoint from the chunk inserted, keep
1849  // looking in the insert chain.
1850  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1851  }
1852  return Value();
1853 }
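// Illustrative sketch (hypothetical values): with
//   %i = vector.insert_strided_slice %src, %dst {offsets = [2, 0],
//        strides = [1, 1]} : vector<2x8xf32> into vector<4x8xf32>
//   %e = vector.extract %i[3] : vector<8xf32> from vector<4x8xf32>
// the extract overlaps the inserted rows [2, 4), so %e is rewritten in place to
// vector.extract %src[1]; an extract at [0] or [1] is disjoint from the inserted
// chunk and the search continues up the insert chain instead.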
1854 
1855 OpFoldResult ExtractOp::fold(FoldAdaptor) {
1856  if (getNumIndices() == 0)
1857  return getVector();
1858  if (succeeded(foldExtractOpFromExtractChain(*this)))
1859  return getResult();
1860  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1861  return res;
1862  if (auto res = foldExtractFromBroadcast(*this))
1863  return res;
1864  if (auto res = foldExtractFromShapeCast(*this))
1865  return res;
1866  if (auto val = foldExtractFromExtractStrided(*this))
1867  return val;
1868  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1869  return val;
1870  return OpFoldResult();
1871 }
1872 
1873 namespace {
1874 
1875 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
1876 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1877 public:
1878  using OpRewritePattern::OpRewritePattern;
1879 
1880  LogicalResult matchAndRewrite(ExtractOp extractOp,
1881  PatternRewriter &rewriter) const override {
1882  Operation *defOp = extractOp.getVector().getDefiningOp();
1883  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1884  return failure();
1885 
1886  Value source = defOp->getOperand(0);
1887  if (extractOp.getType() == source.getType())
1888  return failure();
1889  auto getRank = [](Type type) {
1890  return llvm::isa<VectorType>(type)
1891  ? llvm::cast<VectorType>(type).getRank()
1892  : 0;
1893  };
1894  unsigned broadcastSrcRank = getRank(source.getType());
1895  unsigned extractResultRank = getRank(extractOp.getType());
1896  // We only consider the case where the rank of the source is less than or
1897  // equal to the rank of the extract dst. The other cases are handled in the
1898  // folding patterns.
1899  if (extractResultRank < broadcastSrcRank)
1900  return failure();
1901 
1902  // Special case if broadcast src is a 0D vector.
1903  if (extractResultRank == 0) {
1904  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
1905  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
1906  return success();
1907  }
1908  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1909  extractOp, extractOp.getType(), source);
1910  return success();
1911  }
1912 };
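// Illustrative sketch (hypothetical values): with
//   %b = vector.broadcast %s : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[0] : vector<3x4xf32> from vector<2x3x4xf32>
// the extract result rank (2) is >= the broadcast source rank (1), so the pair
// is rewritten to vector.broadcast %s : vector<4xf32> to vector<3x4xf32>.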
1913 
1914 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
1915 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1916 public:
1917  using OpRewritePattern::OpRewritePattern;
1918 
1919  LogicalResult matchAndRewrite(ExtractOp extractOp,
1920  PatternRewriter &rewriter) const override {
1921  // Return if 'ExtractOp' operand is not defined by a splat vector
1922  // ConstantOp.
1923  Value sourceVector = extractOp.getVector();
1924  Attribute vectorCst;
1925  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1926  return failure();
1927  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1928  if (!splat)
1929  return failure();
1930  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
1931  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1932  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1933  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1934  return success();
1935  }
1936 };
1937 
1938 // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
1939 class ExtractOpNonSplatConstantFolder final
1940  : public OpRewritePattern<ExtractOp> {
1941 public:
1942  using OpRewritePattern::OpRewritePattern;
1943 
1944  LogicalResult matchAndRewrite(ExtractOp extractOp,
1945  PatternRewriter &rewriter) const override {
1946  // TODO: Canonicalization for dynamic position not implemented yet.
1947  if (extractOp.hasDynamicPosition())
1948  return failure();
1949 
1950  // Return if 'ExtractOp' operand is not defined by a compatible vector
1951  // ConstantOp.
1952  Value sourceVector = extractOp.getVector();
1953  Attribute vectorCst;
1954  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1955  return failure();
1956 
1957  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
1958  if (vecTy.isScalable())
1959  return failure();
1960 
1961  // The splat case is handled by `ExtractOpSplatConstantFolder`.
1962  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
1963  if (!dense || dense.isSplat())
1964  return failure();
1965 
1966  // Calculate the linearized position of the contiguous chunk of elements to
1967  // extract.
1968  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
1969  copy(extractOp.getStaticPosition(), completePositions.begin());
1970  int64_t elemBeginPosition =
1971  linearize(completePositions, computeStrides(vecTy.getShape()));
1972  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
1973 
1974  TypedAttr newAttr;
1975  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
1976  SmallVector<Attribute> elementValues(
1977  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
1978  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
1979  } else {
1980  newAttr = *denseValuesBegin;
1981  }
1982 
1983  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1984  return success();
1985  }
1986 };
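// Illustrative sketch (hypothetical values): extracting position [1] from a
// constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32> linearizes to element
// offset 2 and is rewritten to an arith.constant dense<[2, 3]> : vector<2xi32>.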
1987 
1988 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
1989 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
1990 public:
1991  using OpRewritePattern::OpRewritePattern;
1992 
1993  LogicalResult matchAndRewrite(ExtractOp extractOp,
1994  PatternRewriter &rewriter) const override {
1995  auto createMaskOp =
1996  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
1997  if (!createMaskOp)
1998  return failure();
1999 
2000  VectorType extractedMaskType =
2001  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2002 
2003  if (!extractedMaskType)
2004  return failure();
2005 
2006  auto maskOperands = createMaskOp.getOperands();
2007  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2008  VectorType maskType = createMaskOp.getVectorType();
2009 
2010  bool containsUnknownDims = false;
2011  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2012 
2013  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2014  dimIdx++) {
2015  int64_t pos = extractOpPos[dimIdx];
2016  Value operand = maskOperands[dimIdx];
2017  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2018  if (!constantOp) {
2019  // Bounds of this dim unknown.
2020  containsUnknownDims = true;
2021  continue;
2022  }
2023 
2024  int64_t createMaskBound =
2025  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2026 
2027  if (pos != ShapedType::kDynamic) {
2028  // If any position is outside the range from the `create_mask`, then the
2029  // extracted mask will be all-false.
2030  allFalse |= pos >= createMaskBound;
2031  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2032  // This dim is not all-true and since this is a dynamic index we don't
2033  // know if the extraction is within the true or false region.
2034  // Note: Zero dims have already been handled via getMaskFormat().
2035  containsUnknownDims = true;
2036  }
2037  }
2038 
2039  if (allFalse) {
2040  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2041  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2042  } else if (!containsUnknownDims) {
2043  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2044  extractOp, extractedMaskType,
2045  maskOperands.drop_front(extractOpPos.size()));
2046  } else {
2047  return failure();
2048  }
2049  return success();
2050  }
2051 };
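// Illustrative sketch (hypothetical values): with constants %c2 = 2, %c3 = 3 and
//   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// row 1 lies inside the first %c2 rows, so %e becomes
// vector.create_mask %c3 : vector<8xi1>; extracting row 2 or 3 would instead
// produce an all-false constant.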
2052 
2053 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2054 // does not change.
2055 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2056  PatternRewriter &rewriter) {
2057  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2058  if (!castOp)
2059  return failure();
2060 
2061  VectorType sourceType = castOp.getSourceVectorType();
2062  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2063  if (!targetType)
2064  return failure();
2065 
2066  if (sourceType.getNumElements() != targetType.getNumElements())
2067  return failure();
2068 
2069  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2070  castOp.getSource());
2071  return success();
2072 }
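// Illustrative sketch (hypothetical values): with
//   %sc = vector.shape_cast %v : vector<2x4xf32> to vector<1x8xf32>
//   %e = vector.extract %sc[0] : vector<8xf32> from vector<1x8xf32>
// source and result both hold 8 elements, so the pair is rewritten to
// vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>.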
2073 
2074 } // namespace
2075 
2076 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2077  MLIRContext *context) {
2078  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2079  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2080  results.add(foldExtractFromShapeCastToShapeCast);
2081 }
2082 
2083 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2084  SmallVectorImpl<int64_t> &results) {
2085  for (auto attr : arrayAttr)
2086  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2087 }
2088 
2089 //===----------------------------------------------------------------------===//
2090 // FMAOp
2091 //===----------------------------------------------------------------------===//
2092 
2093 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2094  return llvm::to_vector<4>(getVectorType().getShape());
2095 }
2096 
2097 //===----------------------------------------------------------------------===//
2098 // BroadcastOp
2099 //===----------------------------------------------------------------------===//
2100 
2101 /// Return the dimensions of the result vector that were formerly ones in the
2102 /// source vector and thus correspond to "dim-1" broadcasting.
2103 static llvm::SetVector<int64_t>
2104 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2105                            ArrayRef<int64_t> dstShape) {
2106  int64_t rankDiff = dstShape.size() - srcShape.size();
2107  int64_t dstDim = rankDiff;
2108  llvm::SetVector<int64_t> res;
2109  for (auto [s1, s2] :
2110  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2111  if (s1 != s2) {
2112  assert(s1 == 1 && "expected dim-1 broadcasting");
2113  res.insert(dstDim);
2114  }
2115  ++dstDim;
2116  }
2117  return res;
2118 }
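// Illustrative sketch (hypothetical shapes): for srcShape = 2x1x4 and
// dstShape = 3x2x5x4, the leading 3 is a new broadcast dim (not reported) and
// the 1 -> 5 expansion is a "dim-1" broadcast, so the result is {2}.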
2119 
2120 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2121  // Scalar broadcast is without any unit dim broadcast.
2122  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2123  if (!srcVectorType)
2124  return {};
2125  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2126  getResultVectorType().getShape());
2127 }
2128 
2129 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2130 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2131 /// This requires (and asserts) that the broadcast is free of dim-1
2132 /// broadcasting.
2133 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2134 /// vector.transpose may be inserted to make the broadcast possible.
2135 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2136 /// the helper will assert. This means:
2137 /// 1. `dstShape` must not be empty.
2138 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2139 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2140 ///    must match the `value` shape.
2141 Value BroadcastOp::createOrFoldBroadcastOp(
2142  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2143  const llvm::SetVector<int64_t> &broadcastedDims) {
2144  assert(!dstShape.empty() && "unexpected empty dst shape");
2145 
2146  // Step 1. Well-formedness check.
2147  SmallVector<int64_t> checkShape;
2148  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2149  if (broadcastedDims.contains(i))
2150  continue;
2151  checkShape.push_back(dstShape[i]);
2152  }
2153  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2154  "ill-formed broadcastedDims contains values not confined to "
2155  "destVectorShape");
2156 
2157  Location loc = value.getLoc();
2158  Type elementType = getElementTypeOrSelf(value.getType());
2159  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2160  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2161 
2162  // Step 2. If scalar -> dstShape broadcast, just do it.
2163  if (!srcVectorType) {
2164  assert(checkShape.empty() &&
2165  "ill-formed createOrFoldBroadcastOp arguments");
2166  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2167  }
2168 
2169  assert(srcVectorType.getShape().equals(checkShape) &&
2170  "ill-formed createOrFoldBroadcastOp arguments");
2171 
2172  // Step 3. Since vector.broadcast only allows creating leading dims,
2173  // vector -> dstShape broadcast may require a transpose.
2174  // Traverse the dims in order and construct:
2175  // 1. The leading entries of the broadcastShape that is guaranteed to be
2176  // achievable by a simple broadcast.
2177  // 2. The induced permutation for the subsequent vector.transpose that will
2178  // bring us from `broadcastShape` back to the desired `dstShape`.
2179  // If the induced permutation is not the identity, create a vector.transpose.
2180  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2181  broadcastShape.reserve(dstShape.size());
2182  // Consider the example:
2183  // srcShape = 2x4
2184  // dstShape = 1x2x3x4x5
2185  // broadcastedDims = [0, 2, 4]
2186  //
2187  // We want to build:
2188  // broadcastShape = 1x3x5x2x4
2189  // permutation = [0, 2, 4, 1, 3]
2190  // ---V--- -----V-----
2191  // leading broadcast part src shape part
2192  //
2193  // Note that the trailing dims of broadcastShape are exactly the srcShape
2194  // by construction.
2195  // nextSrcShapeDim is used to keep track of where in the permutation the
2196  // "src shape part" occurs.
2197  int64_t nextSrcShapeDim = broadcastedDims.size();
2198  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2199  if (broadcastedDims.contains(i)) {
2200  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2201  // bring it to the head of the broadcastShape.
2202  // It will need to be permuted back from `broadcastShape.size() - 1` into
2203  // position `i`.
2204  broadcastShape.push_back(dstShape[i]);
2205  permutation[i] = broadcastShape.size() - 1;
2206  } else {
2207  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2208  // shape and needs to be permuted into position `i`.
2209  // Don't touch `broadcastShape` here, the whole srcShape will be
2210  // appended after.
2211  permutation[i] = nextSrcShapeDim++;
2212  }
2213  }
2214  // 3.c. Append the srcShape.
2215  llvm::append_range(broadcastShape, srcVectorType.getShape());
2216 
2217  // Ensure there are no dim-1 broadcasts.
2218  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2219  .empty() &&
2220  "unexpected dim-1 broadcast");
2221 
2222  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2223  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2224  vector::BroadcastableToResult::Success &&
2225  "must be broadcastable");
2226  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2227  // Step 4. If we find any dimension that indeed needs to be permuted,
2228  // immediately return a new vector.transpose.
2229  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2230  if (permutation[i] != i)
2231  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2232  // Otherwise return res.
2233  return res;
2234 }
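// Illustrative sketch (hypothetical values): with %v : vector<2x4xf32>,
// dstShape = 1x2x3x4x5 and broadcastedDims = {0, 2, 4}, the helper would emit
// roughly
//   %b = vector.broadcast %v : vector<2x4xf32> to vector<1x3x5x2x4xf32>
//   %t = vector.transpose %b, [0, 3, 1, 4, 2]
//        : vector<1x3x5x2x4xf32> to vector<1x2x3x4x5xf32>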
2235 
2236 BroadcastableToResult
2237 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
2238  std::pair<int, int> *mismatchingDims) {
2239  // Broadcast scalar to vector of the same element type.
2240  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2241  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2242  return BroadcastableToResult::Success;
2243  // From now on, only vectors broadcast.
2244  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2245  if (!srcVectorType)
2246  return BroadcastableToResult::SourceTypeNotAVector;
2247 
2248  int64_t srcRank = srcVectorType.getRank();
2249  int64_t dstRank = dstVectorType.getRank();
2250  if (srcRank > dstRank)
2251  return BroadcastableToResult::SourceRankHigher;
2252  // Source has an exact match or singleton value for all trailing dimensions
2253  // (all leading dimensions are simply duplicated).
2254  int64_t lead = dstRank - srcRank;
2255  for (int64_t r = 0; r < srcRank; ++r) {
2256  int64_t srcDim = srcVectorType.getDimSize(r);
2257  int64_t dstDim = dstVectorType.getDimSize(lead + r);
2258  if (srcDim != 1 && srcDim != dstDim) {
2259  if (mismatchingDims) {
2260  mismatchingDims->first = srcDim;
2261  mismatchingDims->second = dstDim;
2262  }
2263  return BroadcastableToResult::DimensionMismatch;
2264  }
2265  }
2266 
2267  return BroadcastableToResult::Success;
2268 }
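// Illustrative sketch: vector<1x4xf32> is broadcastable to vector<3x2x4xf32>
// (the leading 3 is duplicated and the unit dim is stretched to 2), whereas
// vector<2x4xf32> to vector<3x5x4xf32> reports DimensionMismatch (2 vs. 5).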
2269 
2270 LogicalResult BroadcastOp::verify() {
2271  std::pair<int, int> mismatchingDims;
2272  BroadcastableToResult res = isBroadcastableTo(
2273  getSourceType(), getResultVectorType(), &mismatchingDims);
2274  if (res == BroadcastableToResult::Success)
2275  return success();
2276  if (res == BroadcastableToResult::SourceRankHigher)
2277  return emitOpError("source rank higher than destination rank");
2278  if (res == BroadcastableToResult::DimensionMismatch)
2279  return emitOpError("dimension mismatch (")
2280  << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
2281  if (res == BroadcastableToResult::SourceTypeNotAVector)
2282  return emitOpError("source type is not a vector");
2283  llvm_unreachable("unexpected vector.broadcast op error");
2284 }
2285 
2286 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2287  if (getSourceType() == getResultVectorType())
2288  return getSource();
2289  if (!adaptor.getSource())
2290  return {};
2291  auto vectorType = getResultVectorType();
2292  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2293  return DenseElementsAttr::get(vectorType, adaptor.getSource());
2294  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2295  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2296  return {};
2297 }
2298 
2299 namespace {
2300 
2301 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2302 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2303  using OpRewritePattern::OpRewritePattern;
2304 
2305  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2306  PatternRewriter &rewriter) const override {
2307  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2308  if (!srcBroadcast)
2309  return failure();
2310  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2311  broadcastOp.getResultVectorType(),
2312  srcBroadcast.getSource());
2313  return success();
2314  }
2315 };
2316 } // namespace
2317 
2318 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2319  MLIRContext *context) {
2320  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2321  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2322  results.add<BroadcastFolder>(context);
2323 }
2324 
2325 //===----------------------------------------------------------------------===//
2326 // ShuffleOp
2327 //===----------------------------------------------------------------------===//
2328 
2329 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
2330  Value v2, ArrayRef<int64_t> mask) {
2331  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
2332 }
2333 
2334 LogicalResult ShuffleOp::verify() {
2335  VectorType resultType = getResultVectorType();
2336  VectorType v1Type = getV1VectorType();
2337  VectorType v2Type = getV2VectorType();
2338  // Verify ranks.
2339  int64_t resRank = resultType.getRank();
2340  int64_t v1Rank = v1Type.getRank();
2341  int64_t v2Rank = v2Type.getRank();
2342  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2343  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2344  if (!wellFormed0DCase && !wellFormedNDCase)
2345  return emitOpError("rank mismatch");
2346 
2347  // Verify all but leading dimension sizes.
2348  for (int64_t r = 1; r < v1Rank; ++r) {
2349  int64_t resDim = resultType.getDimSize(r);
2350  int64_t v1Dim = v1Type.getDimSize(r);
2351  int64_t v2Dim = v2Type.getDimSize(r);
2352  if (resDim != v1Dim || v1Dim != v2Dim)
2353  return emitOpError("dimension mismatch");
2354  }
2355  // Verify mask length.
2356  auto maskAttr = getMask().getValue();
2357  int64_t maskLength = maskAttr.size();
2358  if (maskLength <= 0)
2359  return emitOpError("invalid mask length");
2360  if (maskLength != resultType.getDimSize(0))
2361  return emitOpError("mask length mismatch");
2362  // Verify all indices.
2363  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2364  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2365  for (const auto &en : llvm::enumerate(maskAttr)) {
2366  auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2367  if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2368  return emitOpError("mask index #") << (en.index() + 1) << " out of range";
2369  }
2370  return success();
2371 }
2372 
2373 LogicalResult
2374 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2375  ShuffleOp::Adaptor adaptor,
2376  SmallVectorImpl<Type> &inferredReturnTypes) {
2377  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2378  auto v1Rank = v1Type.getRank();
2379  // Construct resulting type: leading dimension matches mask
2380  // length, all trailing dimensions match the operands.
2381  SmallVector<int64_t> shape;
2382  shape.reserve(v1Rank);
2383  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2384  // In the 0-D case there is no trailing shape to append.
2385  if (v1Rank > 0)
2386  llvm::append_range(shape, v1Type.getShape().drop_front());
2387  inferredReturnTypes.push_back(
2388  VectorType::get(shape, v1Type.getElementType()));
2389  return success();
2390 }
2391 
2392 static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2393  uint64_t expected = begin;
2394  return idxArr.size() == width &&
2395  llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2396  [&expected](auto attr) {
2397  return attr.getZExtValue() == expected++;
2398  });
2399 }
2400 
2401 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2402  VectorType v1Type = getV1VectorType();
2403  // For consistency: a 0-D shuffle returns a 1-D vector, so this cannot be
2404  // folded here; it must instead be canonicalized into a vector.broadcast.
2405  if (v1Type.getRank() == 0)
2406  return {};
2407 
2408  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2409  if (!v1Type.isScalable() &&
2410  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2411  return getV1();
2412  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2413  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2414  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2415  getV2VectorType().getDimSize(0)))
2416  return getV2();
2417 
2418  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2419  if (!lhs || !rhs)
2420  return {};
2421 
2422  auto lhsType =
2423  llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2424  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2425  // manipulation.
2426  if (lhsType.getRank() != 1)
2427  return {};
2428  int64_t lhsSize = lhsType.getDimSize(0);
2429 
2430  SmallVector<Attribute> results;
2431  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2432  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2433  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2434  int64_t i = index.getZExtValue();
2435  if (i >= lhsSize) {
2436  results.push_back(rhsElements[i - lhsSize]);
2437  } else {
2438  results.push_back(lhsElements[i]);
2439  }
2440  }
2441 
2442  return DenseElementsAttr::get(getResultVectorType(), results);
2443 }
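// Illustrative sketch (hypothetical values): with
//   %a = arith.constant dense<[0, 1]> : vector<2xi32>
//   %b = arith.constant dense<[2, 3]> : vector<2xi32>
//   %s = vector.shuffle %a, %b [3, 0] : vector<2xi32>, vector<2xi32>
// mask index 3 selects element 1 of %b and index 0 selects element 0 of %a, so
// %s constant-folds to dense<[3, 0]> : vector<2xi32>.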
2444 
2445 namespace {
2446 
2447 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2448 // to a broadcast.
2449 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2450  using OpRewritePattern::OpRewritePattern;
2451 
2452  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2453  PatternRewriter &rewriter) const override {
2454  VectorType v1VectorType = shuffleOp.getV1VectorType();
2455  ArrayAttr mask = shuffleOp.getMask();
2456  if (v1VectorType.getRank() > 0)
2457  return failure();
2458  if (mask.size() != 1)
2459  return failure();
2460  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2461  if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2462  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2463  shuffleOp.getV1());
2464  else
2465  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2466  shuffleOp.getV2());
2467  return success();
2468  }
2469 };
2470 
2471 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2472 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2473 public:
2474  using OpRewritePattern::OpRewritePattern;
2475 
2476  LogicalResult matchAndRewrite(ShuffleOp op,
2477  PatternRewriter &rewriter) const override {
2478  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2479  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2480 
2481  if (!v1Splat || !v2Splat)
2482  return failure();
2483 
2484  if (v1Splat.getInput() != v2Splat.getInput())
2485  return failure();
2486 
2487  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2488  return success();
2489  }
2490 };
2491 
2492 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2493 /// vector.interleave.
2494 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2495 public:
2496  using OpRewritePattern::OpRewritePattern;
2497 
2498  LogicalResult matchAndRewrite(ShuffleOp op,
2499  PatternRewriter &rewriter) const override {
2500  VectorType resultType = op.getResultVectorType();
2501  if (resultType.isScalable())
2502  return rewriter.notifyMatchFailure(
2503  op, "ShuffleOp can't represent a scalable interleave");
2504 
2505  if (resultType.getRank() != 1)
2506  return rewriter.notifyMatchFailure(
2507  op, "ShuffleOp can't represent an n-D interleave");
2508 
2509  VectorType sourceType = op.getV1VectorType();
2510  if (sourceType != op.getV2VectorType() ||
2511  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2512  return rewriter.notifyMatchFailure(
2513  op, "ShuffleOp types don't match an interleave");
2514  }
2515 
2516  ArrayAttr shuffleMask = op.getMask();
2517  int64_t resultVectorSize = resultType.getNumElements();
2518  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2519  int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2520  int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2521  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2522  return rewriter.notifyMatchFailure(op,
2523  "ShuffleOp mask not interleaving");
2524  }
2525 
2526  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2527  return success();
2528  }
2529 };
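// Illustrative sketch (hypothetical values):
//   vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xf32>, vector<4xf32>
// alternates elements of %a and %b, so it is rewritten to a vector.interleave of
// %a and %b.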
2530 
2531 } // namespace
2532 
2533 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2534  MLIRContext *context) {
2535  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2536  context);
2537 }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // InsertElementOp
2541 //===----------------------------------------------------------------------===//
2542 
2543 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2544  Value source, Value dest) {
2545  build(builder, result, source, dest, {});
2546 }
2547 
2548 LogicalResult InsertElementOp::verify() {
2549  auto dstVectorType = getDestVectorType();
2550  if (dstVectorType.getRank() == 0) {
2551  if (getPosition())
2552  return emitOpError("expected position to be empty with 0-D vector");
2553  return success();
2554  }
2555  if (dstVectorType.getRank() != 1)
2556  return emitOpError("unexpected >1 vector rank");
2557  if (!getPosition())
2558  return emitOpError("expected position for 1-D vector");
2559  return success();
2560 }
2561 
2562 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2563  // Skip the 0-D vector here.
2564  if (!adaptor.getPosition())
2565  return {};
2566 
2567  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2568  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2569  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2570  if (!src || !dst || !pos)
2571  return {};
2572 
2573  if (src.getType() != getDestVectorType().getElementType())
2574  return {};
2575 
2576  auto dstElements = dst.getValues<Attribute>();
2577 
2578  SmallVector<Attribute> results(dstElements);
2579 
2580  uint64_t posIdx = pos.getInt();
2581  if (posIdx >= results.size())
2582  return {};
2583  results[posIdx] = src;
2584 
2585  return DenseElementsAttr::get(getDestVectorType(), results);
2586 }
2587 
2588 //===----------------------------------------------------------------------===//
2589 // InsertOp
2590 //===----------------------------------------------------------------------===//
2591 
2592 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2593  Value source, Value dest, int64_t position) {
2594  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2595 }
2596 
2597 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2598  Value source, Value dest, OpFoldResult position) {
2599  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2600 }
2601 
2602 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2603  Value source, Value dest,
2604  ArrayRef<int64_t> position) {
2605  SmallVector<OpFoldResult> posVals;
2606  posVals.reserve(position.size());
2607  llvm::transform(position, std::back_inserter(posVals),
2608  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2609  build(builder, result, source, dest, posVals);
2610 }
2611 
2612 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2613  Value source, Value dest,
2614  ArrayRef<OpFoldResult> position) {
2615  SmallVector<int64_t> staticPos;
2616  SmallVector<Value> dynamicPos;
2617  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2618  build(builder, result, source, dest, dynamicPos,
2619  builder.getDenseI64ArrayAttr(staticPos));
2620 }
2621 
2622 LogicalResult InsertOp::verify() {
2623  SmallVector<OpFoldResult> position = getMixedPosition();
2624  auto destVectorType = getDestVectorType();
2625  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2626  return emitOpError(
2627  "expected position attribute of rank no greater than dest vector rank");
2628  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2629  if (srcVectorType &&
2630  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2631  static_cast<unsigned>(destVectorType.getRank())))
2632  return emitOpError("expected position attribute rank + source rank to "
2633  "match dest vector rank");
2634  if (!srcVectorType &&
2635  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2636  return emitOpError(
2637  "expected position attribute rank to match the dest vector rank");
2638  for (auto [idx, pos] : llvm::enumerate(position)) {
2639  if (auto attr = pos.dyn_cast<Attribute>()) {
2640  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2641  if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2642  return emitOpError("expected position attribute #")
2643  << (idx + 1)
2644  << " to be a non-negative integer smaller than the "
2645  "corresponding "
2646  "dest vector dimension";
2647  }
2648  }
2649  }
2650  return success();
2651 }
2652 
2653 namespace {
2654 
2655 // If insertOp is only inserting unit dimensions it can be transformed to a
2656 // broadcast.
2657 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2658 public:
2659  using OpRewritePattern::OpRewritePattern;
2660 
2661  LogicalResult matchAndRewrite(InsertOp insertOp,
2662  PatternRewriter &rewriter) const override {
2663  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2664  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2665  srcVecType.getNumElements())
2666  return failure();
2667  rewriter.replaceOpWithNewOp<BroadcastOp>(
2668  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2669  return success();
2670  }
2671 };
2672 
2673 /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
2674 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2675 public:
2676  using OpRewritePattern::OpRewritePattern;
2677 
2678  LogicalResult matchAndRewrite(InsertOp op,
2679  PatternRewriter &rewriter) const override {
2680  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2681  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2682 
2683  if (!srcSplat || !dstSplat)
2684  return failure();
2685 
2686  if (srcSplat.getInput() != dstSplat.getInput())
2687  return failure();
2688 
2689  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2690  return success();
2691  }
2692 };
2693 
2694 // Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2695 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2696 public:
2697  using OpRewritePattern::OpRewritePattern;
2698 
2699  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2700  // unless the source vector constant has a single use.
2701  static constexpr int64_t vectorSizeFoldThreshold = 256;
2702 
2703  LogicalResult matchAndRewrite(InsertOp op,
2704  PatternRewriter &rewriter) const override {
2705  // TODO: Canonicalization for dynamic position not implemented yet.
2706  if (op.hasDynamicPosition())
2707  return failure();
2708 
2709  // Return if 'InsertOp' operand is not defined by a compatible vector
2710  // ConstantOp.
2711  TypedValue<VectorType> destVector = op.getDest();
2712  Attribute vectorDestCst;
2713  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2714  return failure();
2715 
2716  VectorType destTy = destVector.getType();
2717  if (destTy.isScalable())
2718  return failure();
2719 
2720  // Make sure we do not create too many large constants.
2721  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2722  !destVector.hasOneUse())
2723  return failure();
2724 
2725  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
2726 
2727  Value sourceValue = op.getSource();
2728  Attribute sourceCst;
2729  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2730  return failure();
2731 
2732  // Calculate the linearized position of the contiguous chunk of elements to
2733  // insert.
2734  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2735  copy(op.getStaticPosition(), completePositions.begin());
2736  int64_t insertBeginPosition =
2737  linearize(completePositions, computeStrides(destTy.getShape()));
2738 
2739  SmallVector<Attribute> insertedValues;
2740  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2741  llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
2742  else
2743  insertedValues.push_back(sourceCst);
2744 
2745  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2746  copy(insertedValues, allValues.begin() + insertBeginPosition);
2747  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2748 
2749  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2750  return success();
2751  }
2752 };
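// Illustrative sketch (hypothetical values): inserting an arith.constant
// dense<[9, 9]> : vector<2xi32> at position [1] into an arith.constant
// dense<[[0, 1], [2, 3]]> : vector<2x2xi32> rewrites the whole op to an
// arith.constant dense<[[0, 1], [9, 9]]> : vector<2x2xi32>.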
2753 
2754 } // namespace
2755 
2756 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2757  MLIRContext *context) {
2758  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2759  InsertOpConstantFolder>(context);
2760 }
2761 
2762 // Eliminates insert operations that produce values identical to their source
2763 // value. This happens when the source and destination vectors have identical
2764 // sizes.
2765 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2766  if (getNumIndices() == 0)
2767  return getSource();
2768  return {};
2769 }
2770 
2771 //===----------------------------------------------------------------------===//
2772 // InsertStridedSliceOp
2773 //===----------------------------------------------------------------------===//
2774 
2775 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2776  Value source, Value dest,
2777  ArrayRef<int64_t> offsets,
2778  ArrayRef<int64_t> strides) {
2779  result.addOperands({source, dest});
2780  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2781  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2782  result.addTypes(dest.getType());
2783  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
2784  offsetsAttr);
2785  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
2786  stridesAttr);
2787 }
2788 
2789 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
2790 template <typename OpType>
2791 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2792  ArrayAttr arrayAttr,
2793  ArrayRef<int64_t> shape,
2794  StringRef attrName) {
2795  if (arrayAttr.size() > shape.size())
2796  return op.emitOpError("expected ")
2797  << attrName << " attribute of rank no greater than vector rank";
2798  return success();
2799 }
2800 
2801 // Returns success if all integers in `arrayAttr` are in the admissible
2802 // interval. If `halfOpen` is true then the admissible interval is [min, max).
2803 // Otherwise, the admissible interval is [min, max].
2804 template <typename OpType>
2805 static LogicalResult
2806 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2807  int64_t max, StringRef attrName,
2808  bool halfOpen = true) {
2809  for (auto attr : arrayAttr) {
2810  auto val = llvm::cast<IntegerAttr>(attr).getInt();
2811  auto upper = max;
2812  if (!halfOpen)
2813  upper += 1;
2814  if (val < min || val >= upper)
2815  return op.emitOpError("expected ") << attrName << " to be confined to ["
2816  << min << ", " << upper << ")";
2817  }
2818  return success();
2819 }
2820 
2821 // Returns success if all integers in `arrayAttr` are in the admissible
2822 // interval, where `max` for dimension i is `shape[i]`. If `halfOpen` is true
2823 // then the admissible interval is [min, max). Otherwise, it is [min, max].
2824 template <typename OpType>
2825 static LogicalResult
2826 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2827  ArrayRef<int64_t> shape, StringRef attrName,
2828  bool halfOpen = true, int64_t min = 0) {
2829  for (auto [index, attrDimPair] :
2830  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
2831  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
2832  int64_t max = std::get<1>(attrDimPair);
2833  if (!halfOpen)
2834  max += 1;
2835  if (val < min || val >= max)
2836  return op.emitOpError("expected ")
2837  << attrName << " dimension " << index << " to be confined to ["
2838  << min << ", " << max << ")";
2839  }
2840  return success();
2841 }
2842 
2843 // Returns success if, for all indices i = 0..shape.size()-1, the value
2844 // val = `arrayAttr1[i]` + `arrayAttr2[i]` is in the admissible interval,
2845 // where `max` is `shape[i]`.
2846 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
2847 // the admissible interval is [min, max].
2848 template <typename OpType>
2849 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2850  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2851  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2852  bool halfOpen = true, int64_t min = 1) {
2853  assert(arrayAttr1.size() <= shape.size());
2854  assert(arrayAttr2.size() <= shape.size());
2855  for (auto [index, it] :
2856  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
2857  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
2858  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
2859  int64_t max = std::get<2>(it);
2860  if (!halfOpen)
2861  max += 1;
2862  if (val1 + val2 < 0 || val1 + val2 >= max)
2863  return op.emitOpError("expected sum(")
2864  << attrName1 << ", " << attrName2 << ") dimension " << index
2865  << " to be confined to [" << min << ", " << max << ")";
2866  }
2867  return success();
2868 }
2869 
2870 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2871  MLIRContext *context) {
2872  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2873  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2874  });
2875  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2876 }
2877 
2878 LogicalResult InsertStridedSliceOp::verify() {
2879  auto sourceVectorType = getSourceVectorType();
2880  auto destVectorType = getDestVectorType();
2881  auto offsets = getOffsetsAttr();
2882  auto strides = getStridesAttr();
2883  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2884  return emitOpError(
2885  "expected offsets of same size as destination vector rank");
2886  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2887  return emitOpError("expected strides of same size as source vector rank");
2888  if (sourceVectorType.getRank() > destVectorType.getRank())
2889  return emitOpError(
2890  "expected source rank to be no greater than destination rank");
2891 
2892  auto sourceShape = sourceVectorType.getShape();
2893  auto destShape = destVectorType.getShape();
2894  SmallVector<int64_t, 4> sourceShapeAsDestShape(
2895  destShape.size() - sourceShape.size(), 0);
2896  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2897  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2898  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2899  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2900  offName)) ||
2901  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
2902  /*max=*/1, stridesName,
2903  /*halfOpen=*/false)) ||
2904  failed(isSumOfIntegerArrayAttrConfinedToShape(
2905  *this, offsets,
2906  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2907  offName, "source vector shape",
2908  /*halfOpen=*/false, /*min=*/1)))
2909  return failure();
2910 
2911  unsigned rankDiff = destShape.size() - sourceShape.size();
2912  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
2913  if (sourceVectorType.getScalableDims()[idx] !=
2914  destVectorType.getScalableDims()[idx + rankDiff]) {
2915  return emitOpError("mismatching scalable flags (at source vector idx=")
2916  << idx << ")";
2917  }
2918  if (sourceVectorType.getScalableDims()[idx]) {
2919  auto sourceSize = sourceShape[idx];
2920  auto destSize = destShape[idx + rankDiff];
2921  if (sourceSize != destSize) {
2922  return emitOpError("expected size at idx=")
2923  << idx
2924  << (" to match the corresponding base size from the input "
2925  "vector (")
2926  << sourceSize << (" vs ") << destSize << (")");
2927  }
2928  }
2929  }
2930 
2931  return success();
2932 }
2933 
2934 namespace {
2935 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2936 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
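/// As an illustrative sketch (hand-written IR, not taken from a test), the
/// following folds to %dst because both operands splat the same %x:
///
///   %src = vector.splat %x : vector<2x4xf32>
///   %dst = vector.splat %x : vector<4x8xf32>
///   %res = vector.insert_strided_slice %src, %dst
///       {offsets = [1, 2], strides = [1, 1]}
///       : vector<2x4xf32> into vector<4x8xf32>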
2937 class FoldInsertStridedSliceSplat final
2938  : public OpRewritePattern<InsertStridedSliceOp> {
2939 public:
2940  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2941 
2942  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2943  PatternRewriter &rewriter) const override {
2944  auto srcSplatOp =
2945  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2946  auto destSplatOp =
2947  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2948 
2949  if (!srcSplatOp || !destSplatOp)
2950  return failure();
2951 
2952  if (srcSplatOp.getInput() != destSplatOp.getInput())
2953  return failure();
2954 
2955  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2956  return success();
2957  }
2958 };
2959 
2960 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2961 /// to dst.
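/// For example (a hand-written sketch with made-up shapes), the insert below
/// folds to %dst because the slice is written back exactly where it was read:
///
///   %e = vector.extract_strided_slice %dst
///       {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
///       : vector<4x8xf32> to vector<2x4xf32>
///   %r = vector.insert_strided_slice %e, %dst
///       {offsets = [1, 0], strides = [1, 1]}
///       : vector<2x4xf32> into vector<4x8xf32>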
2962 class FoldInsertStridedSliceOfExtract final
2963  : public OpRewritePattern<InsertStridedSliceOp> {
2964 public:
2965  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2966 
2967  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2968  PatternRewriter &rewriter) const override {
2969  auto extractStridedSliceOp =
2970  insertStridedSliceOp.getSource()
2971  .getDefiningOp<vector::ExtractStridedSliceOp>();
2972 
2973  if (!extractStridedSliceOp)
2974  return failure();
2975 
2976  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2977  return failure();
2978 
2979  // Check that the insert and the extract use the same strides and offsets.
2980  if (extractStridedSliceOp.getStrides() !=
2981  insertStridedSliceOp.getStrides() ||
2982  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2983  return failure();
2984 
2985  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2986  return success();
2987  }
2988 };
2989 
2990 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
2991 // ConstantOp.
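// For instance (a small hand-written sketch):
//
//   %dst = arith.constant dense<0> : vector<4xi32>
//   %src = arith.constant dense<[7, 8]> : vector<2xi32>
//   %res = vector.insert_strided_slice %src, %dst
//       {offsets = [1], strides = [1]} : vector<2xi32> into vector<4xi32>
//
// becomes
//
//   %res = arith.constant dense<[0, 7, 8, 0]> : vector<4xi32>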
2992 class InsertStridedSliceConstantFolder final
2993  : public OpRewritePattern<InsertStridedSliceOp> {
2994 public:
2995  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
2996 
2997  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2998  // unless the destination vector constant has a single use.
2999  static constexpr int64_t vectorSizeFoldThreshold = 256;
3000 
3001  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3002  PatternRewriter &rewriter) const override {
3003  // Return if the 'InsertStridedSliceOp' destination is not defined by a
3004  // compatible vector ConstantOp.
3005  TypedValue<VectorType> destVector = op.getDest();
3006  Attribute vectorDestCst;
3007  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3008  return failure();
3009 
3010  VectorType destTy = destVector.getType();
3011  if (destTy.isScalable())
3012  return failure();
3013 
3014  // Make sure we do not create too many large constants.
3015  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3016  !destVector.hasOneUse())
3017  return failure();
3018 
3019  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3020 
3021  TypedValue<VectorType> sourceValue = op.getSource();
3022  Attribute sourceCst;
3023  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3024  return failure();
3025 
3026  // TODO: Handle non-unit strides when they become available.
3027  if (op.hasNonUnitStrides())
3028  return failure();
3029 
3030  VectorType sliceVecTy = sourceValue.getType();
3031  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3032  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3033  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3034  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3035 
3036  // Calculate the destination element indices by enumerating all slice
3037  // positions within the destination and linearizing them. The enumeration
3038  // order is lexicographic which yields a sequence of monotonically
3039  // increasing linearized position indices.
3040  // Because the destination may have higher dimensionality than the slice,
3041  // we keep track of two overlapping sets of positions and offsets.
3042  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3043  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3044  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3045  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3046  MutableArrayRef<int64_t> currSlicePosition(
3047  currDestPosition.begin() + rankDifference, currDestPosition.end());
3048  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3049  offsets.end());
3050  do {
3051  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3052  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3053  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3054  "Invalid slice element");
3055  newValues[linearizedPosition] = *sliceValuesIt;
3056  ++sliceValuesIt;
3057  } while (succeeded(
3058  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3059 
3060  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3061  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3062  return success();
3063  }
3064 };
3065 
3066 } // namespace
3067 
3068 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3069  RewritePatternSet &results, MLIRContext *context) {
3070  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3071  InsertStridedSliceConstantFolder>(context);
3072 }
3073 
3074 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3075  if (getSourceVectorType() == getDestVectorType())
3076  return getSource();
3077  return {};
3078 }
3079 
3080 //===----------------------------------------------------------------------===//
3081 // OuterProductOp
3082 //===----------------------------------------------------------------------===//
3083 
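// For orientation: an outer product takes a 1-D LHS vector and either a 1-D
// RHS vector (proper outer product, 2-D result) or a scalar RHS (AXPY, 1-D
// result). A rough, hand-written sketch of both forms (shapes illustrative):
//
//   %o = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
//       : vector<4xf32>, vector<8xf32>    // result: vector<4x8xf32>
//   %a = vector.outerproduct %lhs, %s, %acc1 {kind = #vector.kind<add>}
//       : vector<4xf32>, f32              // result: vector<4xf32>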
3084 /// Build an op without mask, use the type of `acc` as the return type.
3085 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3086  Value lhs, Value rhs, Value acc) {
3087  result.addOperands({lhs, rhs, acc});
3088  result.addTypes(acc.getType());
3089 }
3090 
3091 void OuterProductOp::print(OpAsmPrinter &p) {
3092  p << " " << getLhs() << ", " << getRhs();
3093  if (getAcc()) {
3094  p << ", " << getAcc();
3095  p.printOptionalAttrDict((*this)->getAttrs());
3096  }
3097  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3098 }
3099 
3100 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3101  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3102  Type tLHS, tRHS;
3103  if (parser.parseOperandList(operandsInfo) ||
3104  parser.parseOptionalAttrDict(result.attributes) ||
3105  parser.parseColonType(tLHS) || parser.parseComma() ||
3106  parser.parseType(tRHS))
3107  return failure();
3108  if (operandsInfo.size() < 2)
3109  return parser.emitError(parser.getNameLoc(),
3110  "expected at least 2 operands");
3111  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3112  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3113  if (!vLHS)
3114  return parser.emitError(parser.getNameLoc(),
3115  "expected vector type for operand #1");
3116 
3117  VectorType resType;
3118  if (vRHS) {
3119  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3120  vRHS.getScalableDims()[0]};
3121  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3122  vLHS.getElementType(), scalableDimsRes);
3123  } else {
3124  // Scalar RHS operand
3125  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3126  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3127  scalableDimsRes);
3128  }
3129 
3130  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3131  result.attributes.append(
3132  OuterProductOp::getKindAttrName(result.name),
3133  CombiningKindAttr::get(result.getContext(),
3134  OuterProductOp::getDefaultKind()));
3135  }
3136 
3137  return failure(
3138  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3139  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3140  (operandsInfo.size() > 2 &&
3141  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3142  parser.addTypeToList(resType, result.types));
3143 }
3144 
3145 LogicalResult OuterProductOp::verify() {
3146  Type tRHS = getOperandTypeRHS();
3147  VectorType vLHS = getOperandVectorTypeLHS(),
3148  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3149  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3150 
3151  if (vLHS.getRank() != 1)
3152  return emitOpError("expected 1-d vector for operand #1");
3153 
3154  if (vRHS) {
3155  // Proper OUTER operation.
3156  if (vRHS.getRank() != 1)
3157  return emitOpError("expected 1-d vector for operand #2");
3158  if (vRES.getRank() != 2)
3159  return emitOpError("expected 2-d vector result");
3160  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3161  return emitOpError("expected #1 operand dim to match result dim #1");
3162  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3163  return emitOpError("expected #2 operand dim to match result dim #2");
3164  if (vLHS.isScalable() && !vRHS.isScalable()) {
3165  // This restriction reflects what's currently supported in terms of
3166  // scalable vectors. However, we could relax this if there's a use case.
3167  return emitOpError(
3168  "expected either both or only #2 operand dim to be scalable");
3169  }
3170  } else {
3171  // An AXPY operation.
3172  if (vRES.getRank() != 1)
3173  return emitOpError("expected 1-d vector result");
3174  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3175  return emitOpError("expected #1 operand dim to match result dim #1");
3176  }
3177 
3178  if (vACC && vACC != vRES)
3179  return emitOpError("expected operand #3 of same type as result type");
3180 
3181  // Verify supported combining kind.
3182  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3183  return emitOpError("unsupported outerproduct type");
3184 
3185  return success();
3186 }
3187 
3188 // MaskableOpInterface methods.
3189 
3190 /// Returns the mask type expected by this operation. Mostly used for
3191 /// verification purposes. It requires the operation to be vectorized.
3192 Type OuterProductOp::getExpectedMaskType() {
3193  auto vecType = this->getResultVectorType();
3194  return VectorType::get(vecType.getShape(),
3195  IntegerType::get(vecType.getContext(), /*width=*/1),
3196  vecType.getScalableDims());
3197 }
3198 
3199 //===----------------------------------------------------------------------===//
3200 // ReshapeOp
3201 //===----------------------------------------------------------------------===//
3202 
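// For orientation, a well-formed vector.reshape satisfying the constraints
// checked below might look as follows (a sketch in the spirit of the op
// documentation, with a trailing fixed vector size of 4 and 3*6 == 2*9):
//
//   %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
//       : vector<3x2x4xf32> to vector<2x9x4xf32>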
3203 LogicalResult ReshapeOp::verify() {
3204  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
3205  auto inputVectorType = getInputVectorType();
3206  auto outputVectorType = getOutputVectorType();
3207  int64_t inputShapeRank = getNumInputShapeSizes();
3208  int64_t outputShapeRank = getNumOutputShapeSizes();
3209  SmallVector<int64_t, 4> fixedVectorSizes;
3210  getFixedVectorSizes(fixedVectorSizes);
3211  int64_t numFixedVectorSizes = fixedVectorSizes.size();
3212 
3213  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
3214  return emitError("invalid input shape for vector type ") << inputVectorType;
3215 
3216  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
3217  return emitError("invalid output shape for vector type ")
3218  << outputVectorType;
3219 
3220  // Verify that the 'fixedVectorSizes' match an input/output vector shape
3221  // suffix.
3222  unsigned inputVectorRank = inputVectorType.getRank();
3223  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3224  unsigned index = inputVectorRank - numFixedVectorSizes - i;
3225  if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
3226  return emitError("fixed vector size must match input vector for dim ")
3227  << i;
3228  }
3229 
3230  unsigned outputVectorRank = outputVectorType.getRank();
3231  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3232  unsigned index = outputVectorRank - numFixedVectorSizes - i;
3233  if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
3234  return emitError("fixed vector size must match output vector for dim ")
3235  << i;
3236  }
3237 
3238  // If all shape operands are produced by constant ops, verify that product
3239  // of dimensions for input/output shape match.
3240  auto isDefByConstant = [](Value operand) {
3241  return getConstantIntValue(operand).has_value();
3242  };
3243  if (llvm::all_of(getInputShape(), isDefByConstant) &&
3244  llvm::all_of(getOutputShape(), isDefByConstant)) {
3245  int64_t numInputElements = 1;
3246  for (auto operand : getInputShape())
3247  numInputElements *= getConstantIntValue(operand).value();
3248  int64_t numOutputElements = 1;
3249  for (auto operand : getOutputShape())
3250  numOutputElements *= getConstantIntValue(operand).value();
3251  if (numInputElements != numOutputElements)
3252  return emitError("product of input and output shape sizes must match");
3253  }
3254  return success();
3255 }
3256 
3257 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
3258  populateFromInt64AttrArray(getFixedVectorSizes(), results);
3259 }
3260 
3261 //===----------------------------------------------------------------------===//
3262 // ExtractStridedSliceOp
3263 //===----------------------------------------------------------------------===//
3264 
3265 // Inference works as follows:
3266 // 1. For the leading dims covered by 'offsets', take the corresponding 'sizes'.
3267 // 2. For the remaining dims, take the sizes from 'vectorType'.
3268 // Scalable flags are inherited from 'vectorType'.
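// For example (illustrative only): with 'vectorType' = vector<4x8x16xf32> and
// 'offsets'/'sizes'/'strides' of length 2 where 'sizes' = [2, 4], the inferred
// result type is vector<2x4x16xf32>.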
3269 static Type inferStridedSliceOpResultType(VectorType vectorType,
3270  ArrayAttr offsets, ArrayAttr sizes,
3271  ArrayAttr strides) {
3272  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3273  SmallVector<int64_t, 4> shape;
3274  shape.reserve(vectorType.getRank());
3275  unsigned idx = 0;
3276  for (unsigned e = offsets.size(); idx < e; ++idx)
3277  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3278  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3279  shape.push_back(vectorType.getShape()[idx]);
3280 
3281  return VectorType::get(shape, vectorType.getElementType(),
3282  vectorType.getScalableDims());
3283 }
3284 
3285 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3286  Value source, ArrayRef<int64_t> offsets,
3287  ArrayRef<int64_t> sizes,
3288  ArrayRef<int64_t> strides) {
3289  result.addOperands(source);
3290  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3291  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3292  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3293  result.addTypes(
3294  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3295  offsetsAttr, sizesAttr, stridesAttr));
3296  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3297  offsetsAttr);
3298  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3299  sizesAttr);
3300  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3301  stridesAttr);
3302 }
3303 
3304 LogicalResult ExtractStridedSliceOp::verify() {
3305  auto type = getSourceVectorType();
3306  auto offsets = getOffsetsAttr();
3307  auto sizes = getSizesAttr();
3308  auto strides = getStridesAttr();
3309  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3310  return emitOpError(
3311  "expected offsets, sizes and strides attributes of same size");
3312 
3313  auto shape = type.getShape();
3314  auto offName = getOffsetsAttrName();
3315  auto sizesName = getSizesAttrName();
3316  auto stridesName = getStridesAttrName();
3317  if (failed(
3318  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3319  failed(
3320  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3321  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3322  stridesName)) ||
3323  failed(
3324  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3325  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3326  /*halfOpen=*/false,
3327  /*min=*/1)) ||
3328  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3329  /*max=*/1, stridesName,
3330  /*halfOpen=*/false)) ||
3331  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3332  shape, offName, sizesName,
3333  /*halfOpen=*/false)))
3334  return failure();
3335 
3336  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3337  offsets, sizes, strides);
3338  if (getResult().getType() != resultType)
3339  return emitOpError("expected result type to be ") << resultType;
3340 
3341  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3342  if (type.getScalableDims()[idx]) {
3343  auto inputDim = type.getShape()[idx];
3344  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3345  if (inputDim != inputSize)
3346  return emitOpError("expected size at idx=")
3347  << idx
3348  << (" to match the corresponding base size from the input "
3349  "vector (")
3350  << inputSize << (" vs ") << inputDim << (")");
3351  }
3352  }
3353 
3354  return success();
3355 }
3356 
3357 // When the source of an ExtractStridedSliceOp comes from a chain of
3358 // InsertStridedSliceOps, try to use the source of one of those inserts if we
3359 // can detect that the extracted vector is fully contained in an inserted one.
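// A hand-written sketch of the kind of IR this targets (shapes illustrative):
//
//   %i = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//       : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [3], sizes = [2], strides = [1]}
//       : vector<8xf32> to vector<2xf32>
//
// The extracted chunk lies entirely inside the inserted chunk, so the extract
// can instead read from %src with its offsets rebased to [1].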
3360 static LogicalResult
3361 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3362  // Helper to extract integer out of ArrayAttr.
3363  auto getElement = [](ArrayAttr array, int idx) {
3364  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3365  };
3366  ArrayAttr extractOffsets = op.getOffsets();
3367  ArrayAttr extractStrides = op.getStrides();
3368  ArrayAttr extractSizes = op.getSizes();
3369  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3370  while (insertOp) {
3371  if (op.getSourceVectorType().getRank() !=
3372  insertOp.getSourceVectorType().getRank())
3373  return failure();
3374  ArrayAttr insertOffsets = insertOp.getOffsets();
3375  ArrayAttr insertStrides = insertOp.getStrides();
3376  // If the rank of extract is greater than the rank of insert, we are likely
3377  // extracting a partial chunk of the vector inserted.
3378  if (extractOffsets.size() > insertOffsets.size())
3379  return failure();
3380  bool partialOverlap = false;
3381  bool disjoint = false;
3382  SmallVector<int64_t, 4> offsetDiffs;
3383  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3384  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3385  return failure();
3386  int64_t start = getElement(insertOffsets, dim);
3387  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3388  int64_t offset = getElement(extractOffsets, dim);
3389  int64_t size = getElement(extractSizes, dim);
3390  // Check if the start of the extract offset is in the interval inserted.
3391  if (start <= offset && offset < end) {
3392  // If the extract interval overlaps but is not fully included we may
3393  // have a partial overlap that will prevent any folding.
3394  if (offset + size > end)
3395  partialOverlap = true;
3396  offsetDiffs.push_back(offset - start);
3397  continue;
3398  }
3399  disjoint = true;
3400  break;
3401  }
3402  // The extracted chunk is a subset of the inserted chunk.
3403  if (!disjoint && !partialOverlap) {
3404  op.setOperand(insertOp.getSource());
3405  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3406  OpBuilder b(op.getContext());
3407  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3408  return success();
3409  }
3410  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3411  // in the insert chain.
3412  if (disjoint)
3413  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3414  else {
3415  // The extracted vector partially overlaps the inserted vector; we cannot
3416  // fold.
3417  return failure();
3418  }
3419  }
3420  return failure();
3421 }
3422 
3423 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3424  if (getSourceVectorType() == getResult().getType())
3425  return getVector();
3426  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3427  return getResult();
3428  return {};
3429 }
3430 
3431 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3432  populateFromInt64AttrArray(getOffsets(), results);
3433 }
3434 
3435 namespace {
3436 
3437 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3438 // ConstantMaskOp.
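// For instance (a hand-written sketch), slicing a constant mask yields a
// constant mask covering the overlap of the slice with the masked region:
//
//   %m = vector.constant_mask [4, 3] : vector<4x4xi1>
//   %s = vector.extract_strided_slice %m
//       {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]}
//       : vector<4x4xi1> to vector<2x2xi1>
//
// becomes
//
//   %s = vector.constant_mask [2, 1] : vector<2x2xi1>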
3439 class StridedSliceConstantMaskFolder final
3440  : public OpRewritePattern<ExtractStridedSliceOp> {
3441 public:
3442  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3443 
3444  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3445  PatternRewriter &rewriter) const override {
3446  // Return if 'extractStridedSliceOp' operand is not defined by a
3447  // ConstantMaskOp.
3448  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3449  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3450  if (!constantMaskOp)
3451  return failure();
3452  // Return if 'extractStridedSliceOp' has non-unit strides.
3453  if (extractStridedSliceOp.hasNonUnitStrides())
3454  return failure();
3455  // Gather constant mask dimension sizes.
3456  SmallVector<int64_t, 4> maskDimSizes;
3457  populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
3458  // Gather strided slice offsets and sizes.
3459  SmallVector<int64_t, 4> sliceOffsets;
3460  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3461  sliceOffsets);
3462  SmallVector<int64_t, 4> sliceSizes;
3463  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3464 
3465  // Compute slice of vector mask region.
3466  SmallVector<int64_t, 4> sliceMaskDimSizes;
3467  sliceMaskDimSizes.reserve(maskDimSizes.size());
3468  for (auto [maskDimSize, sliceOffset, sliceSize] :
3469  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3470  int64_t sliceMaskDimSize = std::max(
3471  static_cast<int64_t>(0),
3472  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3473  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3474  }
3475  // Add unchanged dimensions.
3476  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3477  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3478  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3479  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3480  // region is a conjunction of mask dim intervals).
3481  if (llvm::is_contained(sliceMaskDimSizes, 0))
3482  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3483 
3484  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3485  // region.
3486  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3487  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3488  vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
3489  return success();
3490  }
3491 };
3492 
3493 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3494 class StridedSliceSplatConstantFolder final
3495  : public OpRewritePattern<ExtractStridedSliceOp> {
3496 public:
3497  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3498 
3499  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3500  PatternRewriter &rewriter) const override {
3501  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3502  // ConstantOp.
3503  Value sourceVector = extractStridedSliceOp.getVector();
3504  Attribute vectorCst;
3505  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3506  return failure();
3507 
3508  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3509  if (!splat)
3510  return failure();
3511 
3512  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3513  splat.getSplatValue<Attribute>());
3514  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3515  newAttr);
3516  return success();
3517  }
3518 };
3519 
3520 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3521 // ConstantOp.
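// E.g. (an illustrative sketch):
//
//   %c = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xi32>
//   %e = vector.extract_strided_slice %c
//       {offsets = [2], sizes = [3], strides = [1]}
//       : vector<8xi32> to vector<3xi32>
//
// becomes
//
//   %e = arith.constant dense<[2, 3, 4]> : vector<3xi32>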
3522 class StridedSliceNonSplatConstantFolder final
3523  : public OpRewritePattern<ExtractStridedSliceOp> {
3524 public:
3525  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3526 
3527  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3528  PatternRewriter &rewriter) const override {
3529  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3530  // ConstantOp.
3531  Value sourceVector = extractStridedSliceOp.getVector();
3532  Attribute vectorCst;
3533  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3534  return failure();
3535 
3536  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3537  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3538  if (!dense || dense.isSplat())
3539  return failure();
3540 
3541  // TODO: Handle non-unit strides when they become available.
3542  if (extractStridedSliceOp.hasNonUnitStrides())
3543  return failure();
3544 
3545  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3546  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3547  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3548 
3549  VectorType sliceVecTy = extractStridedSliceOp.getType();
3550  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3551  int64_t sliceRank = sliceVecTy.getRank();
3552 
3553  // Expand offsets and sizes to match the vector rank.
3554  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3555  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3556 
3557  SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
3558  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3559 
3560  // Calculate the slice elements by enumerating all slice positions and
3561  // linearizing them. The enumeration order is lexicographic which yields a
3562  // sequence of monotonically increasing linearized position indices.
3563  auto denseValuesBegin = dense.value_begin<Attribute>();
3564  SmallVector<Attribute> sliceValues;
3565  sliceValues.reserve(sliceVecTy.getNumElements());
3566  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3567  do {
3568  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3569  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3570  "Invalid index");
3571  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3572  } while (
3573  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3574 
3575  assert(static_cast<int64_t>(sliceValues.size()) ==
3576  sliceVecTy.getNumElements() &&
3577  "Invalid number of slice elements");
3578  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3579  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3580  newAttr);
3581  return success();
3582  }
3583 };
3584 
3585 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3586 // BroadcastOp(ExtractStridedSliceOp).
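// For example (a hand-written sketch), when the broadcast dims are untouched:
//
//   %b = vector.broadcast %v : vector<8xf32> to vector<4x8xf32>
//   %e = vector.extract_strided_slice %b
//       {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x8xf32>
//
// can be rewritten as
//
//   %e = vector.broadcast %v : vector<8xf32> to vector<2x8xf32>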
3587 class StridedSliceBroadcast final
3588  : public OpRewritePattern<ExtractStridedSliceOp> {
3589 public:
3590  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3591 
3592  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3593  PatternRewriter &rewriter) const override {
3594  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3595  if (!broadcast)
3596  return failure();
3597  auto srcVecType =
3598  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3599  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3600  auto dstVecType = llvm::cast<VectorType>(op.getType());
3601  unsigned dstRank = dstVecType.getRank();
3602  unsigned rankDiff = dstRank - srcRank;
3603  // Check if the innermost dimensions of the source of the broadcast are the
3604  // same as the innermost dimensions of the extract destination. If so we can
3605  // simply use a broadcast, as the original dimensions are untouched.
3606  bool lowerDimMatch = true;
3607  for (unsigned i = 0; i < srcRank; i++) {
3608  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3609  lowerDimMatch = false;
3610  break;
3611  }
3612  }
3613  Value source = broadcast.getSource();
3614  // If the inner dimensions don't match, we need to extract from the source
3615  // of the original broadcast and then broadcast the extracted value.
3616  // We also need to handle degenerate cases where the source is effectively
3617  // just a single scalar.
3618  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3619  if (!lowerDimMatch && !isScalarSrc) {
3620  source = rewriter.create<ExtractStridedSliceOp>(
3621  op->getLoc(), source,
3622  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3623  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3624  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3625  }
3626  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3627  return success();
3628  }
3629 };
3630 
3631 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3632 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3633 public:
3634  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
3635 
3636  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3637  PatternRewriter &rewriter) const override {
3638  auto splat = op.getVector().getDefiningOp<SplatOp>();
3639  if (!splat)
3640  return failure();
3641  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3642  return success();
3643  }
3644 };
3645 
3646 } // namespace
3647 
3648 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3649  RewritePatternSet &results, MLIRContext *context) {
3650  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3651  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3652  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3653  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3654  StridedSliceSplat>(context);
3655 }
3656 
3657 //===----------------------------------------------------------------------===//
3658 // TransferReadOp
3659 //===----------------------------------------------------------------------===//
3660 
3661 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3662 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3663  VectorType vectorType, Value source,
3664  ValueRange indices, AffineMapAttr permutationMapAttr,
3665  /*optional*/ ArrayAttr inBoundsAttr) {
3666  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3667  Value padding = builder.create<arith::ConstantOp>(
3668  result.location, elemType, builder.getZeroAttr(elemType));
3669  build(builder, result, vectorType, source, indices, permutationMapAttr,
3670  padding, /*mask=*/Value(), inBoundsAttr);
3671 }
3672 
3673 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3674 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3675  VectorType vectorType, Value source,
3676  ValueRange indices, AffineMap permutationMap,
3677  std::optional<ArrayRef<bool>> inBounds) {
3678  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3679  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3680  ? builder.getBoolArrayAttr(inBounds.value())
3681  : ArrayAttr();
3682  build(builder, result, vectorType, source, indices, permutationMapAttr,
3683  inBoundsAttr);
3684 }
3685 
3686 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3687 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3688  VectorType vectorType, Value source,
3689  ValueRange indices, Value padding,
3690  std::optional<ArrayRef<bool>> inBounds) {
3691  AffineMap permutationMap = getTransferMinorIdentityMap(
3692  llvm::cast<ShapedType>(source.getType()), vectorType);
3693  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3694  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3695  ? builder.getBoolArrayAttr(inBounds.value())
3696  : ArrayAttr();
3697  build(builder, result, vectorType, source, indices, permutationMapAttr,
3698  padding,
3699  /*mask=*/Value(), inBoundsAttr);
3700 }
3701 
3702 /// 4. Builder that sets padding to zero and permutation map to
3703 /// 'getMinorIdentityMap'.
3704 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3705  VectorType vectorType, Value source,
3706  ValueRange indices,
3707  std::optional<ArrayRef<bool>> inBounds) {
3708  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3709  Value padding = builder.create<arith::ConstantOp>(
3710  result.location, elemType, builder.getZeroAttr(elemType));
3711  build(builder, result, vectorType, source, indices, padding, inBounds);
3712 }
3713 
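/// Checks that `permutationMap` is a projected permutation, i.e. each result
/// is either a distinct input dim or the constant 0. For instance (sketch):
/// (d0, d1, d2) -> (d2, d0) and (d0, d1) -> (d1, 0) are admissible, while
/// (d0, d1) -> (d0, d0) and (d0, d1) -> (d0 + d1) are rejected.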
3714 template <typename EmitFun>
3715 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3716  EmitFun emitOpError) {
3717  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3718  for (auto expr : permutationMap.getResults()) {
3719  auto dim = dyn_cast<AffineDimExpr>(expr);
3720  auto zero = dyn_cast<AffineConstantExpr>(expr);
3721  if (zero) {
3722  if (zero.getValue() != 0) {
3723  return emitOpError(
3724  "requires a projected permutation_map (at most one dim or the zero "
3725  "constant can appear in each result)");
3726  }
3727  continue;
3728  }
3729  if (!dim) {
3730  return emitOpError("requires a projected permutation_map (at most one "
3731  "dim or the zero constant can appear in each result)");
3732  }
3733  if (seen[dim.getPosition()]) {
3734  return emitOpError(
3735  "requires a permutation_map that is a permutation (found one dim "
3736  "used more than once)");
3737  }
3738  seen[dim.getPosition()] = true;
3739  }
3740  return success();
3741 }
3742 
3743 static LogicalResult
3744 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3745  VectorType vectorType, VectorType maskType,
3746  VectorType inferredMaskType, AffineMap permutationMap,
3747  ArrayAttr inBounds) {
3748  if (op->hasAttr("masked")) {
3749  return op->emitOpError("masked attribute has been removed. "
3750  "Use in_bounds instead.");
3751  }
3752 
3753  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3754  return op->emitOpError(
3755  "requires source to be a memref or ranked tensor type");
3756 
3757  auto elementType = shapedType.getElementType();
3758  DataLayout dataLayout = DataLayout::closest(op);
3759  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3760  // Memref or tensor has vector element type.
3761  unsigned sourceVecSize =
3762  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3763  vectorElementType.getShape().back();
3764  unsigned resultVecSize =
3765  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3766  vectorType.getShape().back();
3767  if (resultVecSize % sourceVecSize != 0)
3768  return op->emitOpError(
3769  "requires the bitwidth of the minor 1-D vector to be an integral "
3770  "multiple of the bitwidth of the minor 1-D vector of the source");
3771 
3772  unsigned sourceVecEltRank = vectorElementType.getRank();
3773  unsigned resultVecRank = vectorType.getRank();
3774  if (sourceVecEltRank > resultVecRank)
3775  return op->emitOpError(
3776  "requires source vector element and vector result ranks to match.");
3777  unsigned rankOffset = resultVecRank - sourceVecEltRank;
3778  // Check that permutation map results match 'rankOffset' of vector type.
3779  if (permutationMap.getNumResults() != rankOffset)
3780  return op->emitOpError("requires a permutation_map with result dims of "
3781  "the same rank as the vector type");
3782 
3783  if (maskType)
3784  return op->emitOpError("does not support masks with vector element type");
3785  } else {
3786  // Memref or tensor has scalar element type.
3787  unsigned minorSize =
3788  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3789  unsigned resultVecSize =
3790  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3791  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3792  return op->emitOpError(
3793  "requires the bitwidth of the minor 1-D vector to be an integral "
3794  "multiple of the bitwidth of the source element type");
3795 
3796  // Check that permutation map results match rank of vector type.
3797  if (permutationMap.getNumResults() != vectorType.getRank())
3798  return op->emitOpError("requires a permutation_map with result dims of "
3799  "the same rank as the vector type");
3800  }
3801 
3802  if (permutationMap.getNumSymbols() != 0)
3803  return op->emitOpError("requires permutation_map without symbols");
3804 
3805  if (permutationMap.getNumInputs() != shapedType.getRank())
3806  return op->emitOpError("requires a permutation_map with input dims of the "
3807  "same rank as the source type");
3808 
3809  if (maskType && maskType != inferredMaskType)
3810  return op->emitOpError("inferred mask type (")
3811  << inferredMaskType << ") and mask operand type (" << maskType
3812  << ") don't match";
3813 
3814  if (inBounds) {
3815  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3816  return op->emitOpError("expects the optional in_bounds attr of same rank "
3817  "as permutation_map results: ")
3818  << AffineMapAttr::get(permutationMap)
3819  << " vs inBounds of size: " << inBounds.size();
3820  for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
3821  if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
3822  !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3823  return op->emitOpError("requires broadcast dimensions to be in-bounds");
3824  }
3825 
3826  return success();
3827 }
3828 
3829 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3830  SmallVector<StringRef, 3> elidedAttrs;
3831  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3832  if (op.getPermutationMap().isMinorIdentity())
3833  elidedAttrs.push_back(op.getPermutationMapAttrName());
3834  // Elide in_bounds attribute if all dims are out-of-bounds.
3835  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
3836  elidedAttrs.push_back(op.getInBoundsAttrName());
3837  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3838 }
3839 
3840 void TransferReadOp::print(OpAsmPrinter &p) {
3841  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3842  if (getMask())
3843  p << ", " << getMask();
3844  printTransferAttrs(p, *this);
3845  p << " : " << getShapedType() << ", " << getVectorType();
3846 }
3847 
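/// Note: the mask shape follows the order of the transferred source dims
/// (hence the inverse permutation below), not the vector's dim order. As an
/// illustrative sketch: with permutation_map (d0, d1) -> (d1, d0) and vector
/// type vector<2x4xf32>, the inferred mask type is vector<4x2xi1>.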
3848 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
3849  AffineMap permMap) {
3850  auto i1Type = IntegerType::get(permMap.getContext(), 1);
3851  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3852  assert(invPermMap && "Inversed permutation map couldn't be computed");
3853  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3854 
3855  SmallVector<bool> scalableDims =
3856  applyPermutationMap(invPermMap, vecType.getScalableDims());
3857 
3858  return VectorType::get(maskShape, i1Type, scalableDims);
3859 }
3860 
3861 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3862  auto &builder = parser.getBuilder();
3863  SMLoc typesLoc;
3864  OpAsmParser::UnresolvedOperand sourceInfo;
3865  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3866  OpAsmParser::UnresolvedOperand paddingInfo;
3867  SmallVector<Type, 2> types;
3868  OpAsmParser::UnresolvedOperand maskInfo;
3869  // Parsing with support for paddingValue.
3870  if (parser.parseOperand(sourceInfo) ||
3871  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
3872  parser.parseComma() || parser.parseOperand(paddingInfo))
3873  return failure();
3874  ParseResult hasMask = parser.parseOptionalComma();
3875  if (hasMask.succeeded()) {
3876  if (parser.parseOperand(maskInfo))
3877  return failure();
3878  }
3879  if (parser.parseOptionalAttrDict(result.attributes) ||
3880  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3881  return failure();
3882  if (types.size() != 2)
3883  return parser.emitError(typesLoc, "requires two types");
3884  auto indexType = builder.getIndexType();
3885  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
3886  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
3887  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3888  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
3889  if (!vectorType)
3890  return parser.emitError(typesLoc, "requires vector type");
3891  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
3892  Attribute permMapAttr = result.attributes.get(permMapAttrName);
3893  AffineMap permMap;
3894  if (!permMapAttr) {
3895  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3896  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3897  } else {
3898  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
3899  }
3900  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3901  parser.resolveOperands(indexInfo, indexType, result.operands) ||
3902  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
3903  result.operands))
3904  return failure();
3905  if (hasMask.succeeded()) {
3906  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
3907  return parser.emitError(
3908  maskInfo.location, "does not support masks with vector element type");
3909  if (vectorType.getRank() != permMap.getNumResults()) {
3910  return parser.emitError(typesLoc,
3911  "expected the same rank for the vector and the "
3912  "results of the permutation map");
3913  }
3914  // Instead of adding the mask type as an op type, compute it based on the
3915  // vector type and the permutation map (to keep the type signature small).
3916  auto maskType = inferTransferOpMaskType(vectorType, permMap);
3917  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3918  return failure();
3919  }
3920  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3921  builder.getDenseI32ArrayAttr(
3922  {1, static_cast<int32_t>(indexInfo.size()), 1,
3923  static_cast<int32_t>(hasMask.succeeded())}));
3924  return parser.addTypeToList(vectorType, result.types);
3925 }
3926 
3927 LogicalResult TransferReadOp::verify() {
3928  // Consistency of elemental types in source and vector.
3929  ShapedType shapedType = getShapedType();
3930  VectorType vectorType = getVectorType();
3931  VectorType maskType = getMaskType();
3932  auto paddingType = getPadding().getType();
3933  auto permutationMap = getPermutationMap();
3934  VectorType inferredMaskType =
3935  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
3936  : VectorType();
3937  auto sourceElementType = shapedType.getElementType();
3938 
3939  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3940  return emitOpError("requires ") << shapedType.getRank() << " indices";
3941 
3942  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3943  shapedType, vectorType, maskType,
3944  inferredMaskType, permutationMap,
3945  getInBounds() ? *getInBounds() : ArrayAttr())))
3946  return failure();
3947 
3948  if (auto sourceVectorElementType =
3949  llvm::dyn_cast<VectorType>(sourceElementType)) {
3950  // Source has vector element type.
3951  // Check that 'sourceVectorElementType' and 'paddingType' types match.
3952  if (sourceVectorElementType != paddingType)
3953  return emitOpError(
3954  "requires source element type and padding type to match.");
3955 
3956  } else {
3957  // Check that 'paddingType' is valid to store in a vector type.
3958  if (!VectorType::isValidElementType(paddingType))
3959  return emitOpError("requires valid padding vector elemental type");
3960 
3961  // Check that padding type and vector element types match.
3962  if (paddingType != sourceElementType)
3963  return emitOpError(
3964  "requires formal padding and source of the same elemental type");
3965  }
3966 
3967  return verifyPermutationMap(permutationMap,
3968  [&](Twine t) { return emitOpError(t); });
3969 }
3970 
3971 // MaskableOpInterface methods.
3972 
3973 /// Returns the mask type expected by this operation. Mostly used for
3974 /// verification purposes. It requires the operation to be vectorized.
3975 Type TransferReadOp::getExpectedMaskType() {
3976  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
3977 }
3978 
3979 template <typename TransferOp>
3980 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3981  // TODO: support more aggressive createOrFold on:
3982  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
3983  if (op.getShapedType().isDynamicDim(indicesIdx))
3984  return false;
3985  Value index = op.getIndices()[indicesIdx];
3986  std::optional<int64_t> cstOp = getConstantIntValue(index);
3987  if (!cstOp.has_value())
3988  return false;
3989 
3990  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3991  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3992 
3993  return cstOp.value() + vectorSize <= sourceSize;
3994 }
3995 
3996 template <typename TransferOp>
3997 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
3998  // TODO: support 0-d corner case.
3999  // TODO: Be less conservative.
4000  if (op.getTransferRank() == 0)
4001  return failure();
4002  AffineMap permutationMap = op.getPermutationMap();
4003  bool changed = false;
4004  SmallVector<bool, 4> newInBounds;
4005  newInBounds.reserve(op.getTransferRank());
4006  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4007  // Already marked as in-bounds, nothing to see here.
4008  if (op.isDimInBounds(i)) {
4009  newInBounds.push_back(true);
4010  continue;
4011  }
4012  // Currently out-of-bounds, check whether we can statically determine it is
4013  // inBounds.
4014  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4015  assert(dimExpr && "Broadcast dims must be in-bounds");
4016  auto inBounds =
4017  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
4018  newInBounds.push_back(inBounds);
4019  // We commit the pattern if it is "more inbounds".
4020  changed |= inBounds;
4021  }
4022  if (!changed)
4023  return failure();
4024  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4025  OpBuilder b(op.getContext());
4026  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4027  return success();
4028 }
4029 
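/// Drops the mask of a transfer op when it is statically known to be all-true.
/// A hand-written sketch of the effect (shapes and names are illustrative):
///
///   %m = vector.constant_mask [4] : vector<4xi1>
///   %r = vector.transfer_read %src[%i], %pad, %m
///       : memref<?xf32>, vector<4xf32>
///
/// simplifies to
///
///   %r = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<4xf32>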
4030 template <typename TransferOp>
4031 static LogicalResult foldTransferFullMask(TransferOp op) {
4032  auto mask = op.getMask();
4033  if (!mask)
4034  return failure();
4035 
4036  auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
4037  if (!constantMask)
4038  return failure();
4039 
4040  if (!constantMask.isAllOnesMask())
4041  return failure();
4042 
4043  op.getMaskMutable().clear();
4044  return success();
4045 }
4046 
4047 /// ```
4048 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4049 /// : vector<1x4xf32>, tensor<4x4xf32>
4050 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4051 /// : tensor<4x4xf32>, vector<1x4xf32>
4052 /// ```
4053 /// -> Folds into
4054 /// ```
4055 /// %v0
4056 /// ```
4057 static Value foldRAW(TransferReadOp readOp) {
4058  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4059  return {};
4060  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4061  while (defWrite) {
4062  if (checkSameValueRAW(defWrite, readOp))
4063  return defWrite.getVector();
4064  if (!isDisjointTransferIndices(
4065  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4066  cast<VectorTransferOpInterface>(readOp.getOperation())))
4067  break;
4068  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4069  }
4070  return {};
4071 }
4072 
4073 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4074  if (Value vec = foldRAW(*this))
4075  return vec;
4076  /// transfer_read(memrefcast) -> transfer_read
4077  if (succeeded(foldTransferInBoundsAttribute(*this)))
4078  return getResult();
4079  if (succeeded(foldTransferFullMask(*this)))
4080  return getResult();
4081  if (succeeded(memref::foldMemRefCast(*this)))
4082  return getResult();
4083  if (succeeded(tensor::foldTensorCast(*this)))
4084  return getResult();
4085  return OpFoldResult();
4086 }
4087 
4088 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4089  return llvm::to_vector<4>(getVectorType().getShape());
4090 }
4091 
4092 void TransferReadOp::getEffects(
4093  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4094  &effects) {
4095  if (llvm::isa<MemRefType>(getShapedType()))
4096  effects.emplace_back(MemoryEffects::Read::get(), getSource(),
4097  SideEffects::DefaultResource::get());
4098 }
4099 
4100 namespace {
4101 /// Store to load forwarding for transfer operations with permutation maps.
4102 /// Even if the permutation maps are different we can still propagate the store
4103 /// into the load if the size of the dimensions read and written match. Then we
4104 /// can replace the transfer_read + transfer_write by vector.broadcast and
4105 /// vector.transpose.
4106 /// Example:
4107 /// ```
4108 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4109 /// {in_bounds = [true, true],
4110 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4111 /// vector<4x1xf32>, tensor<4x4x4xf32>
4112 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4113 /// {in_bounds = [true, true, true, true],
4114 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4115 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4116 /// ```
4117 /// To:
4118 /// ```
4119 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4120 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4121 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4122 /// ```
4123 struct TransferReadAfterWriteToBroadcast
4124  : public OpRewritePattern<TransferReadOp> {
4125  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
4126 
4127  LogicalResult matchAndRewrite(TransferReadOp readOp,
4128  PatternRewriter &rewriter) const override {
4129  if (readOp.hasOutOfBoundsDim() ||
4130  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4131  return failure();
4132  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4133  if (!defWrite)
4134  return failure();
4135  // TODO: If the written transfer chunk is a superset of the read transfer
4136  // chunk we could do an extract_strided_slice.
4137  if (readOp.getTransferChunkAccessed() !=
4138  defWrite.getTransferChunkAccessed())
4139  return failure();
4140  // TODO: Support cases where a dim is explicitly written but implicitly
4141  // read (i.e., a unit dim that is rank reduced).
4142  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4143  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4144  return failure();
4145  if (readOp.getIndices() != defWrite.getIndices() ||
4146  readOp.getMask() != defWrite.getMask())
4147  return failure();
4148  Value vec = defWrite.getVector();
4149  // TODO: loop through the chain of transfer_write if we can prove that they
4150  // don't overlap with the transfer_read. This requires improving
4151  // `isDisjointTransferIndices` helper.
4152  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4153  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4154  AffineMap map = readMap.compose(writeMap);
4155  if (map.getNumResults() == 0)
4156  return failure();
4157  // Calculate the permutation to apply to go from the vector stored to the
4158  // vector read.
4159  SmallVector<unsigned> permutation;
4160  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4161  return failure();
4162 
4163  Location loc = readOp.getLoc();
4164  // Calculate the broadcast shape by applying the reverse permutation to the
4165  // final shape we want.
4166  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4167  SmallVector<int64_t> broadcastShape(destShape.size());
4168  SmallVector<bool> broadcastScalableFlags(destShape.size());
4169  for (const auto &pos : llvm::enumerate(permutation)) {
4170  broadcastShape[pos.value()] = destShape[pos.index()];
4171  broadcastScalableFlags[pos.value()] =
4172  readOp.getVectorType().getScalableDims()[pos.index()];
4173  }
4174  VectorType broadcastedType = VectorType::get(
4175  broadcastShape, defWrite.getVectorType().getElementType(),
4176  broadcastScalableFlags);
4177  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4178  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4179  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4180  transposePerm);
4181  return success();
4182  }
4183 };
4184 } // namespace
4185 
4186 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4187  MLIRContext *context) {
4188  results.add<TransferReadAfterWriteToBroadcast>(context);
4189 }
4190 
4191 //===----------------------------------------------------------------------===//
4192 // TransferWriteOp
4193 //===----------------------------------------------------------------------===//
4194 
4195 /// 1. Builder with type inference.
4196 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4197  Value vector, Value dest, ValueRange indices,
4198  AffineMapAttr permutationMapAttr,
4199  /*optional*/ Value mask,
4200  /*optional*/ ArrayAttr inBoundsAttr) {
4201  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4202  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4203  mask, inBoundsAttr);
4204 }
4205 
4206 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4207 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4208  Value vector, Value dest, ValueRange indices,
4209  AffineMapAttr permutationMapAttr,
4210  /*optional*/ ArrayAttr inBoundsAttr) {
4211  build(builder, result, vector, dest, indices, permutationMapAttr,
4212  /*mask=*/Value(), inBoundsAttr);
4213 }
4214 
4215 /// 3. Builder with type inference that sets an empty mask (variant without
4216 /// attrs)
4217 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4218  Value vector, Value dest, ValueRange indices,
4219  AffineMap permutationMap,
4220  std::optional<ArrayRef<bool>> inBounds) {
4221  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4222  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4223  ? builder.getBoolArrayAttr(inBounds.value())
4224  : ArrayAttr();
4225  build(builder, result, vector, dest, indices, permutationMapAttr,
4226  /*mask=*/Value(), inBoundsAttr);
4227 }
4228 
4229 /// 4. Builder with type inference that sets an empty mask and sets permutation
4230 /// map to 'getMinorIdentityMap'.
4231 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4232  Value vector, Value dest, ValueRange indices,
4233  std::optional<ArrayRef<bool>> inBounds) {
4234  auto vectorType = llvm::cast<VectorType>(vector.getType());
4235  AffineMap permutationMap = getTransferMinorIdentityMap(
4236  llvm::cast<ShapedType>(dest.getType()), vectorType);
4237  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4238 }
4239 
4240 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4241  OperationState &result) {
4242  auto &builder = parser.getBuilder();
4243  SMLoc typesLoc;
4244  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4245  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4246  SmallVector<Type, 2> types;
4247  OpAsmParser::UnresolvedOperand maskInfo;
4248  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4249  parser.parseOperand(sourceInfo) ||
4250  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4251  return failure();
4252  ParseResult hasMask = parser.parseOptionalComma();
4253  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4254  return failure();
4255  if (parser.parseOptionalAttrDict(result.attributes) ||
4256  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4257  return failure();
4258  if (types.size() != 2)
4259  return parser.emitError(typesLoc, "requires two types");
4260  auto indexType = builder.getIndexType();
4261  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4262  if (!vectorType)
4263  return parser.emitError(typesLoc, "requires vector type");
4264  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4265  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4266  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4267  auto permMapAttrName =
4268  TransferWriteOp::getPermutationMapAttrName(result.name);
4269  auto permMapAttr = result.attributes.get(permMapAttrName);
4270  AffineMap permMap;
4271  if (!permMapAttr) {
4272  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4273  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4274  } else {
4275  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4276  }
4277  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4278  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4279  parser.resolveOperands(indexInfo, indexType, result.operands))
4280  return failure();
4281  if (hasMask.succeeded()) {
4282  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4283  return parser.emitError(
4284  maskInfo.location, "does not support masks with vector element type");
4285  if (vectorType.getRank() != permMap.getNumResults()) {
4286  return parser.emitError(typesLoc,
4287  "expected the same rank for the vector and the "
4288  "results of the permutation map");
4289  }
4290  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4291  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4292  return failure();
4293  }
4294  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4295  builder.getDenseI32ArrayAttr(
4296  {1, 1, static_cast<int32_t>(indexInfo.size()),
4297  static_cast<int32_t>(hasMask.succeeded())}));
4298  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4299  parser.addTypeToList(shapedType, result.types));
4300 }
4301 
4302 void TransferWriteOp::print(OpAsmPrinter &p) {
4303  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4304  if (getMask())
4305  p << ", " << getMask();
4306  printTransferAttrs(p, *this);
4307  p << " : " << getVectorType() << ", " << getShapedType();
4308 }
4309 
4310 LogicalResult TransferWriteOp::verify() {
4311  // Consistency of elemental types in shape and vector.
4312  ShapedType shapedType = getShapedType();
4313  VectorType vectorType = getVectorType();
4314  VectorType maskType = getMaskType();
4315  auto permutationMap = getPermutationMap();
4316  VectorType inferredMaskType =
4317  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4318  : VectorType();
4319 
4320  if (llvm::size(getIndices()) != shapedType.getRank())
4321  return emitOpError("requires ") << shapedType.getRank() << " indices";
4322 
4323  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4324  // as the semantics is unclear. This can be revisited later if necessary.
4325  if (hasBroadcastDim())
4326  return emitOpError("should not have broadcast dimensions");
4327 
4328  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4329  shapedType, vectorType, maskType,
4330  inferredMaskType, permutationMap,
4331  getInBounds() ? *getInBounds() : ArrayAttr())))
4332  return failure();
4333 
4334  return verifyPermutationMap(permutationMap,
4335  [&](Twine t) { return emitOpError(t); });
4336 }
4337 
4338 // MaskableOpInterface methods.
4339 
4340 /// Returns the mask type expected by this operation. Mostly used for
4341 /// verification purposes.
4342 Type TransferWriteOp::getExpectedMaskType() {
4343  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4344 }
4345 
4346 /// Fold:
4347 /// ```
4348 /// %t1 = ...
4349 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4350 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4351 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4352 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4353 /// ```
4354 ///
4355 /// into:
4356 ///
4357 /// ```
4358 /// %t0
4359 /// ```
4360 ///
4361 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4362 /// block argument or has side effects.
4363 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4364  ArrayRef<Attribute>,
4365  SmallVectorImpl<OpFoldResult> &results) {
4366  // TODO: support 0-d corner case.
4367  if (write.getTransferRank() == 0)
4368  return failure();
4369  auto rankedTensorType =
4370  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4371  // If not operating on tensors, bail.
4372  if (!rankedTensorType)
4373  return failure();
4374  // If no read, bail.
4375  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4376  if (!read)
4377  return failure();
4378  // TODO: support 0-d corner case.
4379  if (read.getTransferRank() == 0)
4380  return failure();
4381  // For now, only accept minor identity maps. Future: accept cases where the composition is a minor identity.
4382  if (!read.getPermutationMap().isMinorIdentity() ||
4383  !write.getPermutationMap().isMinorIdentity())
4384  return failure();
4385  // Bail on mismatching ranks.
4386  if (read.getTransferRank() != write.getTransferRank())
4387  return failure();
4388  // Bail on potential out-of-bounds accesses.
4389  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4390  return failure();
4391  // Tensor types must be the same.
4392  if (read.getSource().getType() != rankedTensorType)
4393  return failure();
4394  // Vector types must be the same.
4395  if (read.getVectorType() != write.getVectorType())
4396  return failure();
4397  // Vector and Tensor shapes must match.
4398  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4399  return failure();
4400  // Bail if any read or write index is not a constant zero.
4401  auto isNotConstantZero = [](Value v) {
4402  auto cstOp = getConstantIntValue(v);
4403  return !cstOp.has_value() || cstOp.value() != 0;
4404  };
4405  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4406  llvm::any_of(write.getIndices(), isNotConstantZero))
4407  return failure();
4408  // Success.
4409  results.push_back(read.getSource());
4410  return success();
4411 }
4412 
4413 static bool checkSameValueWAR(vector::TransferReadOp read,
4414  vector::TransferWriteOp write) {
4415  return read.getSource() == write.getSource() &&
4416  read.getIndices() == write.getIndices() &&
4417  read.getPermutationMap() == write.getPermutationMap() &&
4418  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4419  !write.getMask();
4420 }
4421 /// Fold transfer_write write after read:
4422 /// ```
4423 /// %t0 = ...
4424 /// %v = vector.transfer_read %t0[%c0...] :
4425 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4426 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4427 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4428 /// ```
4429 ///
4430 /// into:
4431 ///
4432 /// ```
4433 /// %t0
4434 /// ```
4435 static LogicalResult foldWAR(TransferWriteOp write,
4436  SmallVectorImpl<OpFoldResult> &results) {
4437  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4438  return failure();
4439  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4440  if (!read)
4441  return failure();
4442 
4443  if (!checkSameValueWAR(read, write))
4444  return failure();
4445  results.push_back(read.getSource());
4446  return success();
4447 }
4448 
4449 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4450  SmallVectorImpl<OpFoldResult> &results) {
4451  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4452  return success();
4453  if (succeeded(foldWAR(*this, results)))
4454  return success();
4455  if (succeeded(foldTransferInBoundsAttribute(*this)))
4456  return success();
4457  if (succeeded(foldTransferFullMask(*this)))
4458  return success();
4459  return memref::foldMemRefCast(*this);
4460 }
4461 
4462 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4463  return llvm::to_vector<4>(getVectorType().getShape());
4464 }
4465 
4466 void TransferWriteOp::getEffects(
4467  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4468  &effects) {
4469  if (llvm::isa<MemRefType>(getShapedType()))
4470  effects.emplace_back(MemoryEffects::Write::get(), getSource(),
4471  SideEffects::DefaultResource::get());
4472 }
4473 
4474 namespace {
4475 /// Remove dead transfer write from the SSA chain so that it can be eliminated by
4476 /// DCE.
4477 /// ```
4478 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4479 /// : vector<1x4xf32>, tensor<4x4xf32>
4480 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4481 /// : vector<1x4xf32>, tensor<4x4xf32>
4482 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4483 /// : vector<1x4xf32>, tensor<4x4xf32>
4484 /// ```
4485 ///
4486 /// into:
4487 ///
4488 /// ```
4489 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4490 /// : vector<1x4xf32>, tensor<4x4xf32>
4491 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4492 /// : vector<1x4xf32>, tensor<4x4xf32>
4493 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4494 /// : vector<1x4xf32>, tensor<4x4xf32>
4495 /// ```
4496 ///
4497 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4498 /// any other uses.
4499 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4500 public:
4501  using OpRewritePattern::OpRewritePattern;
4502  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4503  PatternRewriter &rewriter) const override {
4504  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4505  return failure();
4506  vector::TransferWriteOp writeToModify = writeOp;
4507 
4508  auto defWrite =
4509  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4510  while (defWrite) {
4511  if (checkSameValueWAW(writeOp, defWrite)) {
4512  rewriter.modifyOpInPlace(writeToModify, [&]() {
4513  writeToModify.getSourceMutable().assign(defWrite.getSource());
4514  });
4515  return success();
4516  }
4517  if (!isDisjointTransferSets(
4518  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4519  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4520  break;
4521  // If the previous write op doesn't have any other use we can safely look
4522  // at the previous store to see if it can be removed.
4523  if (!defWrite->hasOneUse())
4524  break;
4525  writeToModify = defWrite;
4526  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4527  }
4528  return failure();
4529  }
4530 };
4531 
4532 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4533 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4534 /// overwritten and inserted into another tensor. After this rewrite, the
4535 /// operations bufferize in-place since all of them work on the same slice.
4536 ///
4537 /// For example:
4538 /// ```mlir
4539 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4540 /// : vector<8x16xf32>, tensor<8x16xf32>
4541 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4542 /// : tensor<8x16xf32> to tensor<?x?xf32>
4543 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4544 /// : tensor<?x?xf32> into tensor<27x37xf32>
4545 /// ```
4546 /// folds to
4547 /// ```mlir
4548 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4549 /// : tensor<27x37xf32> to tensor<?x?xf32>
4550 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4551 /// : vector<8x16xf32>, tensor<?x?xf32>
4552 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4553 /// : tensor<?x?xf32> into tensor<27x37xf32>
4554 /// ```
4555 struct SwapExtractSliceOfTransferWrite
4556  : public OpRewritePattern<tensor::InsertSliceOp> {
4557 public:
4558  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
4559 
4560  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4561  PatternRewriter &rewriter) const override {
4562  if (!insertOp.hasUnitStride())
4563  return failure();
4564  auto extractOp =
4565  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4566  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4567  return failure();
4568  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4569  if (!transferOp || !transferOp->hasOneUse())
4570  return failure();
4571 
4572  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4573  // rank-reducing.
4574  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4575  return rewriter.notifyMatchFailure(insertOp,
4576  "use-def chain is rank-reducing");
4577  }
4578 
4579  // Fail if tensor::ExtractSliceOp has non-zero offset.
4580  if (!extractOp.hasZeroOffset()) {
4581  return rewriter.notifyMatchFailure(insertOp,
4582  "ExtractSliceOp has non-zero offset");
4583  }
4584 
4585  // Fail if vector::TransferWriteOp has non-zero offset.
4586  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4587  return getConstantIntValue(value) == static_cast<int64_t>(0);
4588  })) {
4589  return rewriter.notifyMatchFailure(insertOp,
4590  "TransferWriteOp has non-zero offset");
4591  }
4592 
4593  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4594  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4595  return rewriter.notifyMatchFailure(
4596  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4597  }
4598 
4599  for (auto [insertSize, extractSize] :
4600  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4601  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4602  return rewriter.notifyMatchFailure(
4603  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4604  }
4605  }
4606 
4607  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4608  assert(transferOp.getVectorType().hasStaticShape() &&
4609  "expected vector to have a static shape");
4610  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4611  SmallVector<int64_t> resultShape = applyPermutationMap(
4612  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4613  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4614  return rewriter.notifyMatchFailure(
4615  insertOp, "TransferWriteOp may not write the full tensor.");
4616  }
4617 
4618  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4619  // Set all in_bounds to false and let the folder infer them.
4620  SmallVector<bool> newInBounds(vectorShape.size(), false);
4621  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4622  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4623  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4624  insertOp.getMixedStrides());
4625  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4626  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4627  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4628  rewriter.getBoolArrayAttr(newInBounds));
4629  rewriter.modifyOpInPlace(insertOp, [&]() {
4630  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4631  });
4632  return success();
4633  }
4634 };
4635 
4636 } // namespace
4637 
4638 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4639  MLIRContext *context) {
4640  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4641 }
4642 
4643 //===----------------------------------------------------------------------===//
4644 // LoadOp
4645 //===----------------------------------------------------------------------===//
4646 
4647 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4648  MemRefType memRefTy) {
4649  if (!isLastMemrefDimUnitStride(memRefTy))
4650  return op->emitOpError("most minor memref dim must have unit stride");
4651  return success();
4652 }
4653 
4654 LogicalResult vector::LoadOp::verify() {
4655  VectorType resVecTy = getVectorType();
4656  MemRefType memRefTy = getMemRefType();
4657 
4658  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4659  return failure();
4660 
4661  // Checks for vector memrefs.
4662  Type memElemTy = memRefTy.getElementType();
4663  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4664  if (memVecTy != resVecTy)
4665  return emitOpError("base memref and result vector types should match");
4666  memElemTy = memVecTy.getElementType();
4667  }
4668 
4669  if (resVecTy.getElementType() != memElemTy)
4670  return emitOpError("base and result element types should match");
4671  if (llvm::size(getIndices()) != memRefTy.getRank())
4672  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4673  return success();
4674 }
4675 
4676 OpFoldResult LoadOp::fold(FoldAdaptor) {
4677  if (succeeded(memref::foldMemRefCast(*this)))
4678  return getResult();
4679  return OpFoldResult();
4680 }
4681 
4682 //===----------------------------------------------------------------------===//
4683 // StoreOp
4684 //===----------------------------------------------------------------------===//
4685 
4686 LogicalResult vector::StoreOp::verify() {
4687  VectorType valueVecTy = getVectorType();
4688  MemRefType memRefTy = getMemRefType();
4689 
4690  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4691  return failure();
4692 
4693  // Checks for vector memrefs.
4694  Type memElemTy = memRefTy.getElementType();
4695  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4696  if (memVecTy != valueVecTy)
4697  return emitOpError(
4698  "base memref and valueToStore vector types should match");
4699  memElemTy = memVecTy.getElementType();
4700  }
4701 
4702  if (valueVecTy.getElementType() != memElemTy)
4703  return emitOpError("base and valueToStore element type should match");
4704  if (llvm::size(getIndices()) != memRefTy.getRank())
4705  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4706  return success();
4707 }
4708 
4709 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4710  SmallVectorImpl<OpFoldResult> &results) {
4711  return memref::foldMemRefCast(*this);
4712 }
4713 
4714 //===----------------------------------------------------------------------===//
4715 // MaskedLoadOp
4716 //===----------------------------------------------------------------------===//
4717 
4718 LogicalResult MaskedLoadOp::verify() {
4719  VectorType maskVType = getMaskVectorType();
4720  VectorType passVType = getPassThruVectorType();
4721  VectorType resVType = getVectorType();
4722  MemRefType memType = getMemRefType();
4723 
4724  if (resVType.getElementType() != memType.getElementType())
4725  return emitOpError("base and result element type should match");
4726  if (llvm::size(getIndices()) != memType.getRank())
4727  return emitOpError("requires ") << memType.getRank() << " indices";
4728  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4729  return emitOpError("expected result dim to match mask dim");
4730  if (resVType != passVType)
4731  return emitOpError("expected pass_thru of same type as result type");
4732  return success();
4733 }
4734 
4735 namespace {
4736 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4737 public:
4738  using OpRewritePattern::OpRewritePattern;
4739  LogicalResult matchAndRewrite(MaskedLoadOp load,
4740  PatternRewriter &rewriter) const override {
4741  switch (getMaskFormat(load.getMask())) {
4742  case MaskFormat::AllTrue:
4743  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4744  load, load.getType(), load.getBase(), load.getIndices());
4745  return success();
4746  case MaskFormat::AllFalse:
4747  rewriter.replaceOp(load, load.getPassThru());
4748  return success();
4749  case MaskFormat::Unknown:
4750  return failure();
4751  }
4752  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4753  }
4754 };
4755 } // namespace
4756 
4757 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4758  MLIRContext *context) {
4759  results.add<MaskedLoadFolder>(context);
4760 }
4761 
4762 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
4763  if (succeeded(memref::foldMemRefCast(*this)))
4764  return getResult();
4765  return OpFoldResult();
4766 }
4767 
4768 //===----------------------------------------------------------------------===//
4769 // MaskedStoreOp
4770 //===----------------------------------------------------------------------===//
4771 
4772 LogicalResult MaskedStoreOp::verify() {
4773  VectorType maskVType = getMaskVectorType();
4774  VectorType valueVType = getVectorType();
4775  MemRefType memType = getMemRefType();
4776 
4777  if (valueVType.getElementType() != memType.getElementType())
4778  return emitOpError("base and valueToStore element type should match");
4779  if (llvm::size(getIndices()) != memType.getRank())
4780  return emitOpError("requires ") << memType.getRank() << " indices";
4781  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4782  return emitOpError("expected valueToStore dim to match mask dim");
4783  return success();
4784 }
4785 
4786 namespace {
4787 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4788 public:
4789  using OpRewritePattern::OpRewritePattern;
4790  LogicalResult matchAndRewrite(MaskedStoreOp store,
4791  PatternRewriter &rewriter) const override {
4792  switch (getMaskFormat(store.getMask())) {
4793  case MaskFormat::AllTrue:
4794  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4795  store, store.getValueToStore(), store.getBase(), store.getIndices());
4796  return success();
4797  case MaskFormat::AllFalse:
4798  rewriter.eraseOp(store);
4799  return success();
4800  case MaskFormat::Unknown:
4801  return failure();
4802  }
4803  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4804  }
4805 };
4806 } // namespace
4807 
4808 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4809  MLIRContext *context) {
4810  results.add<MaskedStoreFolder>(context);
4811 }
4812 
4813 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4814  SmallVectorImpl<OpFoldResult> &results) {
4815  return memref::foldMemRefCast(*this);
4816 }
4817 
4818 //===----------------------------------------------------------------------===//
4819 // GatherOp
4820 //===----------------------------------------------------------------------===//
4821 
4822 LogicalResult GatherOp::verify() {
4823  VectorType indVType = getIndexVectorType();
4824  VectorType maskVType = getMaskVectorType();
4825  VectorType resVType = getVectorType();
4826  ShapedType baseType = getBaseType();
4827 
4828  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4829  return emitOpError("requires base to be a memref or ranked tensor type");
4830 
4831  if (resVType.getElementType() != baseType.getElementType())
4832  return emitOpError("base and result element type should match");
4833  if (llvm::size(getIndices()) != baseType.getRank())
4834  return emitOpError("requires ") << baseType.getRank() << " indices";
4835  if (resVType.getShape() != indVType.getShape())
4836  return emitOpError("expected result dim to match indices dim");
4837  if (resVType.getShape() != maskVType.getShape())
4838  return emitOpError("expected result dim to match mask dim");
4839  if (resVType != getPassThruVectorType())
4840  return emitOpError("expected pass_thru of same type as result type");
4841  return success();
4842 }
4843 
4844 // MaskableOpInterface methods.
4845 
4846 /// Returns the mask type expected by this operation. Mostly used for
4847 /// verification purposes. It requires the operation to be vectorized.
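/// E.g. for an index vector of type vector<4x[2]xindex>, the expected mask
/// type is vector<4x[2]xi1>: the same shape and scalability, with i1 elements.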
4848 Type GatherOp::getExpectedMaskType() {
4849  auto vecType = this->getIndexVectorType();
4850  return VectorType::get(vecType.getShape(),
4851  IntegerType::get(vecType.getContext(), /*width=*/1),
4852  vecType.getScalableDims());
4853 }
4854 
4855 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4856  return llvm::to_vector<4>(getVectorType().getShape());
4857 }
4858 
4859 namespace {
4860 class GatherFolder final : public OpRewritePattern<GatherOp> {
4861 public:
4862  using OpRewritePattern::OpRewritePattern;
4863  LogicalResult matchAndRewrite(GatherOp gather,
4864  PatternRewriter &rewriter) const override {
4865  switch (getMaskFormat(gather.getMask())) {
4866  case MaskFormat::AllTrue:
4867  return failure(); // no unmasked equivalent
4868  case MaskFormat::AllFalse:
4869  rewriter.replaceOp(gather, gather.getPassThru());
4870  return success();
4871  case MaskFormat::Unknown:
4872  return failure();
4873  }
4874  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4875  }
4876 };
4877 } // namespace
4878 
4879 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4880  MLIRContext *context) {
4881  results.add<GatherFolder>(context);
4882 }
4883 
4884 //===----------------------------------------------------------------------===//
4885 // ScatterOp
4886 //===----------------------------------------------------------------------===//
4887 
4888 LogicalResult ScatterOp::verify() {
4889  VectorType indVType = getIndexVectorType();
4890  VectorType maskVType = getMaskVectorType();
4891  VectorType valueVType = getVectorType();
4892  MemRefType memType = getMemRefType();
4893 
4894  if (valueVType.getElementType() != memType.getElementType())
4895  return emitOpError("base and valueToStore element type should match");
4896  if (llvm::size(getIndices()) != memType.getRank())
4897  return emitOpError("requires ") << memType.getRank() << " indices";
4898  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4899  return emitOpError("expected valueToStore dim to match indices dim");
4900  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4901  return emitOpError("expected valueToStore dim to match mask dim");
4902  return success();
4903 }
4904 
4905 namespace {
4906 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4907 public:
4908  using OpRewritePattern::OpRewritePattern;
4909  LogicalResult matchAndRewrite(ScatterOp scatter,
4910  PatternRewriter &rewriter) const override {
4911  switch (getMaskFormat(scatter.getMask())) {
4912  case MaskFormat::AllTrue:
4913  return failure(); // no unmasked equivalent
4914  case MaskFormat::AllFalse:
4915  rewriter.eraseOp(scatter);
4916  return success();
4917  case MaskFormat::Unknown:
4918  return failure();
4919  }
4920  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4921  }
4922 };
4923 } // namespace
4924 
4925 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4926  MLIRContext *context) {
4927  results.add<ScatterFolder>(context);
4928 }
4929 
4930 //===----------------------------------------------------------------------===//
4931 // ExpandLoadOp
4932 //===----------------------------------------------------------------------===//
4933 
4934 LogicalResult ExpandLoadOp::verify() {
4935  VectorType maskVType = getMaskVectorType();
4936  VectorType passVType = getPassThruVectorType();
4937  VectorType resVType = getVectorType();
4938  MemRefType memType = getMemRefType();
4939 
4940  if (resVType.getElementType() != memType.getElementType())
4941  return emitOpError("base and result element type should match");
4942  if (llvm::size(getIndices()) != memType.getRank())
4943  return emitOpError("requires ") << memType.getRank() << " indices";
4944  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4945  return emitOpError("expected result dim to match mask dim");
4946  if (resVType != passVType)
4947  return emitOpError("expected pass_thru of same type as result type");
4948  return success();
4949 }
4950 
4951 namespace {
4952 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4953 public:
4954  using OpRewritePattern::OpRewritePattern;
4955  LogicalResult matchAndRewrite(ExpandLoadOp expand,
4956  PatternRewriter &rewriter) const override {
4957  switch (getMaskFormat(expand.getMask())) {
4958  case MaskFormat::AllTrue:
4959  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4960  expand, expand.getType(), expand.getBase(), expand.getIndices());
4961  return success();
4962  case MaskFormat::AllFalse:
4963  rewriter.replaceOp(expand, expand.getPassThru());
4964  return success();
4965  case MaskFormat::Unknown:
4966  return failure();
4967  }
4968  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4969  }
4970 };
4971 } // namespace
4972 
4973 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4974  MLIRContext *context) {
4975  results.add<ExpandLoadFolder>(context);
4976 }
4977 
4978 //===----------------------------------------------------------------------===//
4979 // CompressStoreOp
4980 //===----------------------------------------------------------------------===//
4981 
4982 LogicalResult CompressStoreOp::verify() {
4983  VectorType maskVType = getMaskVectorType();
4984  VectorType valueVType = getVectorType();
4985  MemRefType memType = getMemRefType();
4986 
4987  if (valueVType.getElementType() != memType.getElementType())
4988  return emitOpError("base and valueToStore element type should match");
4989  if (llvm::size(getIndices()) != memType.getRank())
4990  return emitOpError("requires ") << memType.getRank() << " indices";
4991  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4992  return emitOpError("expected valueToStore dim to match mask dim");
4993  return success();
4994 }
4995 
4996 namespace {
4997 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
4998 public:
4999  using OpRewritePattern::OpRewritePattern;
5000  LogicalResult matchAndRewrite(CompressStoreOp compress,
5001  PatternRewriter &rewriter) const override {
5002  switch (getMaskFormat(compress.getMask())) {
5003  case MaskFormat::AllTrue:
5004  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5005  compress, compress.getValueToStore(), compress.getBase(),
5006  compress.getIndices());
5007  return success();
5008  case MaskFormat::AllFalse:
5009  rewriter.eraseOp(compress);
5010  return success();
5011  case MaskFormat::Unknown:
5012  return failure();
5013  }
5014  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5015  }
5016 };
5017 } // namespace
5018 
5019 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5020  MLIRContext *context) {
5021  results.add<CompressStoreFolder>(context);
5022 }
5023 
5024 //===----------------------------------------------------------------------===//
5025 // ShapeCastOp
5026 //===----------------------------------------------------------------------===//
5027 
5028 /// Returns true if each element of 'a' is equal to the product of a contiguous
5029 /// sequence of the elements of 'b'. Returns false otherwise.
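/// E.g. a = [4, 6] and b = [4, 3, 2] is valid (4 == 4 and 6 == 3 * 2), whereas
/// a = [4, 6] and b = [2, 12] is not, since no contiguous run of b's leading
/// elements multiplies to 4.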
5030 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5031  unsigned rankA = a.size();
5032  unsigned rankB = b.size();
5033  assert(rankA < rankB);
5034 
5035  auto isOne = [](int64_t v) { return v == 1; };
5036 
5037  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5038  // casted to a 0-d vector.
5039  if (rankA == 0 && llvm::all_of(b, isOne))
5040  return true;
5041 
5042  unsigned i = 0;
5043  unsigned j = 0;
5044  while (i < rankA && j < rankB) {
5045  int64_t dimA = a[i];
5046  int64_t dimB = 1;
5047  while (dimB < dimA && j < rankB)
5048  dimB *= b[j++];
5049  if (dimA != dimB)
5050  break;
5051  ++i;
5052 
5053  // Handle the case when trailing dimensions are of size 1.
5054  // Include them into the contiguous sequence.
5055  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5056  i = rankA;
5057  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5058  j = rankB;
5059  }
5060 
5061  return i == rankA && j == rankB;
5062 }
5063 
5064 static LogicalResult verifyVectorShapeCast(Operation *op,
5065  VectorType sourceVectorType,
5066  VectorType resultVectorType) {
5067  // Check that element type is the same.
5068  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5069  return op->emitOpError("source/result vectors must have same element type");
5070  auto sourceShape = sourceVectorType.getShape();
5071  auto resultShape = resultVectorType.getShape();
5072 
5073  // Check that product of source dim sizes matches product of result dim sizes.
5074  int64_t sourceDimProduct = std::accumulate(
5075  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5076  int64_t resultDimProduct = std::accumulate(
5077  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5078  if (sourceDimProduct != resultDimProduct)
5079  return op->emitOpError("source/result number of elements must match");
5080 
5081  // Check the expanding/contracting rank cases.
5082  unsigned sourceRank = sourceVectorType.getRank();
5083  unsigned resultRank = resultVectorType.getRank();
5084  if (sourceRank < resultRank) {
5085  if (!isValidShapeCast(sourceShape, resultShape))
5086  return op->emitOpError("invalid shape cast");
5087  } else if (sourceRank > resultRank) {
5088  if (!isValidShapeCast(resultShape, sourceShape))
5089  return op->emitOpError("invalid shape cast");
5090  }
5091  return success();
5092 }
5093 
5094 LogicalResult ShapeCastOp::verify() {
5095  auto sourceVectorType =
5096  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5097  auto resultVectorType =
5098  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5099 
5100  // Check if source/result are of vector type.
5101  if (sourceVectorType && resultVectorType)
5102  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5103 
5104  return success();
5105 }
5106 
5107 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5108  // No-op shape cast.
5109  if (getSource().getType() == getResult().getType())
5110  return getSource();
5111 
5112  // Canceling shape casts.
5113  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5114  if (getResult().getType() == otherOp.getSource().getType())
5115  return otherOp.getSource();
5116 
5117  // Only allows valid transitive folding.
5118  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5119  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5120  if (srcType.getRank() < resultType.getRank()) {
5121  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5122  return {};
5123  } else if (srcType.getRank() > resultType.getRank()) {
5124  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5125  return {};
5126  } else {
5127  return {};
5128  }
5129 
5130  setOperand(otherOp.getSource());
5131  return getResult();
5132  }
5133 
5134  // Cancelling broadcast and shape cast ops.
5135  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5136  if (bcastOp.getSourceType() == getType())
5137  return bcastOp.getSource();
5138  }
5139 
5140  return {};
5141 }
5142 
5143 namespace {
5144 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5145 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5146 public:
5147  using OpRewritePattern::OpRewritePattern;
5148 
5149  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5150  PatternRewriter &rewriter) const override {
5151  auto constantOp =
5152  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5153  if (!constantOp)
5154  return failure();
5155  // Only handle splat for now.
5156  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5157  if (!dense)
5158  return failure();
5159  auto newAttr =
5160  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5161  dense.getSplatValue<Attribute>());
5162  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5163  return success();
5164  }
5165 };
5166 
5167 /// Helper function that computes a new vector type based on the input vector
5168 /// type by removing the trailing one dims:
5169 ///
5170 /// vector<4x1x1xi1> --> vector<4xi1>
5171 ///
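/// Trailing scalable unit dims are not dropped, e.g. vector<4x[1]xi1> is left
/// unchanged.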
5172 static VectorType trimTrailingOneDims(VectorType oldType) {
5173  ArrayRef<int64_t> oldShape = oldType.getShape();
5174  ArrayRef<int64_t> newShape = oldShape;
5175 
5176  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5177  ArrayRef<bool> newScalableDims = oldScalableDims;
5178 
5179  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5180  newShape = newShape.drop_back(1);
5181  newScalableDims = newScalableDims.drop_back(1);
5182  }
5183 
5184  // Make sure we have at least 1 dimension.
5185  // TODO: Add support for 0-D vectors.
5186  if (newShape.empty()) {
5187  newShape = oldShape.take_back();
5188  newScalableDims = oldScalableDims.take_back();
5189  }
5190 
5191  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5192 }
5193 
5194 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5195 ///
5196 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5197 /// dimension. If the input vector comes from `vector.create_mask` for which
5198 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5199 /// to fold shape_cast into create_mask.
5200 ///
5201 /// BEFORE:
5202 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5203 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5204 /// AFTER:
5205 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5206 class ShapeCastCreateMaskFolderTrailingOneDim final
5207  : public OpRewritePattern<ShapeCastOp> {
5208 public:
5209  using OpRewritePattern::OpRewritePattern;
5210 
5211  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5212  PatternRewriter &rewriter) const override {
5213  Value shapeOpSrc = shapeOp->getOperand(0);
5214  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5215  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5216  if (!createMaskOp && !constantMaskOp)
5217  return failure();
5218 
5219  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5220  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5221 
5222  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5223  if (newVecType != shapeOpResTy)
5224  return failure();
5225 
5226  auto numDimsToDrop =
5227  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5228 
5229  // No unit dims to drop
5230  if (!numDimsToDrop)
5231  return failure();
5232 
5233  if (createMaskOp) {
5234  auto maskOperands = createMaskOp.getOperands();
5235  auto numMaskOperands = maskOperands.size();
5236 
5237  // Check every mask dim size to see whether it can be dropped
5238  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5239  --i) {
5240  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5241  if (!constant || (constant.value() != 1))
5242  return failure();
5243  }
5244  SmallVector<Value> newMaskOperands =
5245  maskOperands.drop_back(numDimsToDrop);
5246 
5247  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5248  newMaskOperands);
5249  return success();
5250  }
5251 
5252  if (constantMaskOp) {
5253  auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5254  auto numMaskOperands = maskDimSizes.size();
5255 
5256  // Check every mask dim size to see whether it can be dropped
5257  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5258  --i) {
5259  if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5260  return failure();
5261  }
5262 
5263  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5264  ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
5265 
5266  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5267  newMaskOperandsAttr);
5268  return success();
5269  }
5270 
5271  return failure();
5272  }
5273 };
5274 
5275 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5276 /// This only applies when the shape of the broadcast source
5277 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5278 /// reshape is expressive enough to capture the result in a single op), or
5279 /// 2. has the same element count as the shape cast result.
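/// E.g. (case 1):
///   %b = vector.broadcast %v : vector<4xf32> to vector<1x2x4xf32>
///   %r = vector.shape_cast %b : vector<1x2x4xf32> to vector<2x4xf32>
/// becomes a single broadcast:
///   %r = vector.broadcast %v : vector<4xf32> to vector<2x4xf32>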
5280 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5281 public:
5282  using OpRewritePattern::OpRewritePattern;
5283 
5284  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5285  PatternRewriter &rewriter) const override {
5286  auto broadcastOp =
5287  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5288  if (!broadcastOp)
5289  return failure();
5290 
5291  ArrayRef<int64_t> broadcastSourceShape;
5292  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5293  broadcastSourceShape = srcType.getShape();
5294  ArrayRef<int64_t> shapeCastTargetShape =
5295  shapeCastOp.getResultVectorType().getShape();
5296 
5297  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5298  // with a broadcast to the final shape.
5299  if (broadcastSourceShape ==
5300  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5301  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5302  shapeCastOp, shapeCastOp.getResultVectorType(),
5303  broadcastOp.getSource());
5304  return success();
5305  }
5306 
5307  // Otherwise, if the final result has the same element count, we can replace
5308  // with a shape cast.
5309  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5310  if (srcType.getNumElements() ==
5311  shapeCastOp.getResultVectorType().getNumElements()) {
5312  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5313  shapeCastOp, shapeCastOp.getResultVectorType(),
5314  broadcastOp.getSource());
5315  return success();
5316  }
5317  }
5318 
5319  return failure();
5320  }
5321 };
5322 
5323 } // namespace
5324 
5325 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5326  MLIRContext *context) {
5327  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5328  ShapeCastBroadcastFolder>(context);
5329 }
5330 
5331 //===----------------------------------------------------------------------===//
5332 // VectorBitCastOp
5333 //===----------------------------------------------------------------------===//
5334 
5335 LogicalResult BitCastOp::verify() {
5336  auto sourceVectorType = getSourceVectorType();
5337  auto resultVectorType = getResultVectorType();
5338 
5339  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5340  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5341  return emitOpError("dimension size mismatch at: ") << i;
5342  }
5343 
5344  DataLayout dataLayout = DataLayout::closest(*this);
5345  auto sourceElementBits =
5346  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5347  auto resultElementBits =
5348  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5349 
5350  if (sourceVectorType.getRank() == 0) {
5351  if (sourceElementBits != resultElementBits)
5352  return emitOpError("source/result bitwidth of the 0-D vector element "
5353  "types must be equal");
5354  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5355  resultElementBits * resultVectorType.getShape().back()) {
5356  return emitOpError(
5357  "source/result bitwidth of the minor 1-D vectors must be equal");
5358  }
5359 
5360  return success();
5361 }
5362 
5363 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5364  // Nop cast.
5365  if (getSource().getType() == getResult().getType())
5366  return getSource();
5367 
5368  // Canceling bitcasts.
5369  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5370  if (getResult().getType() == otherOp.getSource().getType())
5371  return otherOp.getSource();
5372 
5373  setOperand(otherOp.getSource());
5374  return getResult();
5375  }
5376 
5377  Attribute sourceConstant = adaptor.getSource();
5378  if (!sourceConstant)
5379  return {};
5380 
5381  Type srcElemType = getSourceVectorType().getElementType();
5382  Type dstElemType = getResultVectorType().getElementType();
5383 
5384  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5385  if (floatPack.isSplat()) {
5386  auto splat = floatPack.getSplatValue<FloatAttr>();
5387 
5388  // Casting fp16 into fp32.
5389  if (srcElemType.isF16() && dstElemType.isF32()) {
5390  uint32_t bits = static_cast<uint32_t>(
5391  splat.getValue().bitcastToAPInt().getZExtValue());
5392  // Duplicate the 16-bit pattern.
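// E.g. an f16 splat with bit pattern 0x3C00 yields f32 elements with bit
// pattern 0x3C003C00.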
5393  bits = (bits << 16) | (bits & 0xffff);
5394  APInt intBits(32, bits);
5395  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5396  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5397  }
5398  }
5399  }
5400 
5401  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5402  if (intPack.isSplat()) {
5403  auto splat = intPack.getSplatValue<IntegerAttr>();
5404 
5405  if (llvm::isa<IntegerType>(dstElemType)) {
5406  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5407  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5408 
5409  // Casting to a larger integer bit width.
5410  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5411  APInt intBits = splat.getValue().zext(dstBitWidth);
5412 
5413  // Duplicate the lower width element.
5414  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5415  intBits = (intBits << srcBitWidth) | intBits;
5416  return DenseElementsAttr::get(getResultVectorType(), intBits);
5417  }
5418  }
5419  }
5420  }
5421 
5422  return {};
5423 }
5424 
5425 //===----------------------------------------------------------------------===//
5426 // TypeCastOp
5427 //===----------------------------------------------------------------------===//
5428 
5429 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5430  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5431  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
5432  memRefType.getShape().end());
5433  if (vectorType)
5434  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5435  return res;
5436 }
5437 
5438 /// Build the canonical memRefType with a single vector.
5439 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5440 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5441  Value source) {
5442  result.addOperands(source);
5443  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5444  VectorType vectorType =
5445  VectorType::get(extractShape(memRefType),
5446  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5447  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5448  memRefType.getMemorySpace()));
5449 }
5450 
5451 LogicalResult TypeCastOp::verify() {
5452  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5453  if (!canonicalType.getLayout().isIdentity())
5454  return emitOpError("expects operand to be a memref with identity layout");
5455  if (!getResultMemRefType().getLayout().isIdentity())
5456  return emitOpError("expects result to be a memref with identity layout");
5457  if (getResultMemRefType().getMemorySpace() !=
5458  getMemRefType().getMemorySpace())
5459  return emitOpError("expects result in same memory space");
5460 
5461  auto sourceType = getMemRefType();
5462  auto resultType = getResultMemRefType();
5463  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5464  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5465  return emitOpError(
5466  "expects result and operand with same underlying scalar type: ")
5467  << resultType;
5468  if (extractShape(sourceType) != extractShape(resultType))
5469  return emitOpError(
5470  "expects concatenated result and operand shapes to be equal: ")
5471  << resultType;
5472  return success();
5473 }
5474 
5475 //===----------------------------------------------------------------------===//
5476 // TransposeOp
5477 //===----------------------------------------------------------------------===//
5478 
5479 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5480  Value vector, ArrayRef<int64_t> permutation) {
5481  VectorType vt = llvm::cast<VectorType>(vector.getType());
5482  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5483  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5484  for (unsigned i = 0; i < permutation.size(); ++i) {
5485  transposedShape[i] = vt.getShape()[permutation[i]];
5486  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5487  }
5488 
5489  result.addOperands(vector);
5490  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5491  transposedScalableDims));
5492  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5493  builder.getDenseI64ArrayAttr(permutation));
5494 }
5495 
5496 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5497  // Eliminate splat constant transpose ops.
5498  if (auto attr =
5499  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5500  if (attr.isSplat())
5501  return attr.reshape(getResultVectorType());
5502 
5503  // Eliminate identity transpose ops. This happens when the dimensions of the
5504  // input vector remain in their original order after the transpose operation.
5505  ArrayRef<int64_t> perm = getPermutation();
5506 
5507  // Check if the permutation of the dimensions contains sequential values:
5508  // {0, 1, 2, ...}.
5509  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5510  if (perm[i] != i)
5511  return {};
5512  }
5513 
5514  return getVector();
5515 }
5516 
5517 LogicalResult vector::TransposeOp::verify() {
5518  VectorType vectorType = getSourceVectorType();
5519  VectorType resultType = getResultVectorType();
5520  int64_t rank = resultType.getRank();
5521  if (vectorType.getRank() != rank)
5522  return emitOpError("vector result rank mismatch: ") << rank;
5523  // Verify transposition array.
5524  ArrayRef<int64_t> perm = getPermutation();
5525  int64_t size = perm.size();
5526  if (rank != size)
5527  return emitOpError("transposition length mismatch: ") << size;
5528  SmallVector<bool, 8> seen(rank, false);
5529  for (const auto &ta : llvm::enumerate(perm)) {
5530  if (ta.value() < 0 || ta.value() >= rank)
5531  return emitOpError("transposition index out of range: ") << ta.value();
5532  if (seen[ta.value()])
5533  return emitOpError("duplicate position index: ") << ta.value();
5534  seen[ta.value()] = true;
5535  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5536  return emitOpError("dimension size mismatch at: ") << ta.value();
5537  }
5538  return success();
5539 }
5540 
5541 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5542  return llvm::to_vector<4>(getResultVectorType().getShape());
5543 }
5544 
5545 namespace {
5546 
5547 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
5548 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5549 public:
5550  using OpRewritePattern::OpRewritePattern;
5551 
5552  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5553  PatternRewriter &rewriter) const override {
5554  // Composes two permutations: result[i] = permutation1[permutation2[i]].
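// E.g. permutation1 = [1, 2, 0] and permutation2 = [2, 0, 1] compose to
// [0, 1, 2], i.e. the two transposes cancel each other out.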
5555  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5556  ArrayRef<int64_t> permutation2) {
5557  SmallVector<int64_t, 4> result;
5558  for (auto index : permutation2)
5559  result.push_back(permutation1[index]);
5560  return result;
5561  };
5562 
5563  // Return if the input of 'transposeOp' is not defined by another transpose.
5564  vector::TransposeOp parentTransposeOp =
5565  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5566  if (!parentTransposeOp)
5567  return failure();
5568 
5569  SmallVector<int64_t, 4> permutation = composePermutations(
5570  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5571  // Replace 'transposeOp' with a new transpose operation.
5572  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5573  transposeOp, transposeOp.getResult().getType(),
5574  parentTransposeOp.getVector(), permutation);
5575  return success();
5576  }
5577 };
5578 
5579 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
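// The same rewrite applies when the broadcast source is a vector with a
// single element (e.g. vector<1x1xf32>), since transposing it is a no-op.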
5580 struct FoldTransposedScalarBroadcast final
5581  : public OpRewritePattern<vector::TransposeOp> {
5582  using OpRewritePattern::OpRewritePattern;
5583 
5584  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5585  PatternRewriter &rewriter) const override {
5586  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5587  if (!bcastOp)
5588  return failure();
5589 
5590  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5591  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5592  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5593  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5594  return success();
5595  }
5596 
5597  return failure();
5598  }
5599 };
5600 
5601 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5602 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5603 public:
5604  using OpRewritePattern::OpRewritePattern;
5605 
5606  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5607  PatternRewriter &rewriter) const override {
5608  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5609  if (!splatOp)
5610  return failure();
5611 
5612  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5613  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5614  return success();
5615  }
5616 };
5617 
5618 /// Folds transpose(create_mask) into a new transposed create_mask.
5619 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5620 public:
5621  using OpRewritePattern::OpRewritePattern;
5622 
5623  LogicalResult matchAndRewrite(TransposeOp transpOp,
5624  PatternRewriter &rewriter) const override {
5625  Value transposeSrc = transpOp.getVector();
5626  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5627  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5628  if (!createMaskOp && !constantMaskOp)
5629  return failure();
5630 
5631  // Get the transpose permutation and apply it to the vector.create_mask or
5632  // vector.constant_mask operands.
5633  ArrayRef<int64_t> permutation = transpOp.getPermutation();
5634 
5635  if (createMaskOp) {
5636  auto maskOperands = createMaskOp.getOperands();
5637  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
5638  applyPermutationToVector(newOperands, permutation);
5639 
5640  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5641  transpOp, transpOp.getResultVectorType(), newOperands);
5642  return success();
5643  }
5644 
5645  // ConstantMaskOp case.
5646  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5647  SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
5648  applyPermutationToVector(newMaskDimSizes, permutation);
5649 
5650  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5651  transpOp, transpOp.getResultVectorType(),
5652  ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
5653  return success();
5654  }
5655 };
5656 
5657 } // namespace
5658 
5659 void vector::TransposeOp::getCanonicalizationPatterns(
5660  RewritePatternSet &results, MLIRContext *context) {
5661  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5662  TransposeFolder, FoldTransposeSplat>(context);
5663 }
5664 
5665 //===----------------------------------------------------------------------===//
5666 // ConstantMaskOp
5667 //===----------------------------------------------------------------------===//
5668 
5669 LogicalResult ConstantMaskOp::verify() {
5670  auto resultType = llvm::cast<VectorType>(getResult().getType());
5671  // Check the corner case of 0-D vectors first.
5672  if (resultType.getRank() == 0) {
5673  if (getMaskDimSizes().size() != 1)
5674  return emitError("array attr must have length 1 for 0-D vectors");
5675  auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5676  if (dim != 0 && dim != 1)
5677  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5678  return success();
5679  }
5680 
5681  // Verify that array attr size matches the rank of the vector result.
5682  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5683  return emitOpError(
5684  "must specify array attr of size equal vector result rank");
5685  // Verify that each array attr element is in bounds of corresponding vector
5686  // result dimension size.
5687  auto resultShape = resultType.getShape();
5688  auto resultScalableDims = resultType.getScalableDims();
5689  SmallVector<int64_t, 4> maskDimSizes;
5690  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
5691  int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5692  if (maskDimSize < 0 || maskDimSize > resultShape[index])
5693  return emitOpError(
5694  "array attr of size out of bounds of vector result dimension size");
5695  if (resultScalableDims[index] && maskDimSize != 0 &&
5696  maskDimSize != resultShape[index])
5697  return emitOpError(
5698  "only supports 'none set' or 'all set' scalable dimensions");
5699  maskDimSizes.push_back(maskDimSize);
5700  }
5701  // Verify that if one mask dim size is zero, they all should be zero (because
5702  // the mask region is a conjunction of each mask dimension interval).
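// E.g. vector.constant_mask [2, 0] : vector<4x3xi1> selects no elements, so it
// must be written as [0, 0] instead.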
5703  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5704  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5705  if (anyZeros && !allZeros)
5706  return emitOpError("expected all mask dim sizes to be zeros, "
5707  "as a result of conjunction with zero mask dim");
5708  return success();
5709 }
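// For example (illustrative): for vector<[8]xi1> the only valid mask dim
// sizes are 0 ("none set") and 8 ("all set"); a scalable dimension with mask
// dim size 4 is rejected by the check above.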
5710 
5711 bool ConstantMaskOp::isAllOnesMask() {
5712  auto resultType = getVectorType();
5713  // Check the corner case of 0-D vectors first.
5714  if (resultType.getRank() == 0) {
5715  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5716  return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5717  }
5718  for (const auto [resultSize, intAttr] :
5719  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5720  int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5721  if (maskDimSize < resultSize)
5722  return false;
5723  }
5724  return true;
5725 }
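// For example (illustrative): vector.constant_mask [4, 3] : vector<4x3xi1>
// is an all-ones mask, while vector.constant_mask [3, 3] : vector<4x3xi1> is
// not, since 3 < 4 in the first dimension.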
5726 
5727 //===----------------------------------------------------------------------===//
5728 // CreateMaskOp
5729 //===----------------------------------------------------------------------===//
5730 
5731 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
5732  VectorType type,
5733  ArrayRef<OpFoldResult> mixedOperands) {
5734  SmallVector<Value> operands =
5735  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
5736  build(builder, result, type, operands);
5737 }
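// Usage sketch (hypothetical caller, illustrative only): mixed static and
// dynamic sizes, e.g. {builder.getIndexAttr(4), dynSizeValue}, are first
// materialized as index constants by getValueOrCreateConstantIndexOp and then
// forwarded to the ValueRange-based builder above.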
5738 
5739 LogicalResult CreateMaskOp::verify() {
5740  auto vectorType = llvm::cast<VectorType>(getResult().getType());
5741  // Verify that an operand was specified for each result vector dimension.
5742  if (vectorType.getRank() == 0) {
5743  if (getNumOperands() != 1)
5744  return emitOpError(
5745  "must specify exactly one operand for 0-D create_mask");
5746  } else if (getNumOperands() !=
5747  llvm::cast<VectorType>(getResult().getType()).getRank()) {
5748  return emitOpError(
5749  "must specify an operand for each result vector dimension");
5750  }
5751  return success();
5752 }
5753 
5754 namespace {
5755 
5756 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5757 ///
5758 /// Ex 1:
5759 /// %c2 = arith.constant 2 : index
5760 /// %c3 = arith.constant 3 : index
5761 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5762 /// Becomes:
5763 /// vector.constant_mask [3, 2] : vector<4x3xi1>
5764 ///
5765 /// Ex 2:
5766 /// %c_neg_1 = arith.constant -1 : index
5767 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5768 /// becomes:
5769 /// vector.constant_mask [0] : vector<[8]xi1>
5770 ///
5771 /// Ex 3:
5772 /// %c8 = arith.constant 8 : index
5773 /// %c16 = arith.constant 16 : index
5774 /// %0 = vector.vscale
5775 /// %1 = arith.muli %0, %c16 : index
5776 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5777 /// becomes:
5778 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
5779 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5780 public:
5781  using OpRewritePattern::OpRewritePattern;
5782 
5783  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5784  PatternRewriter &rewriter) const override {
5785  VectorType retTy = createMaskOp.getResult().getType();
5786  bool isScalable = retTy.isScalable();
5787 
5788  // Check every mask operand
5789  for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
5790  if (auto cst = getConstantIntValue(operand)) {
5791  // Most basic case - this operand is a constant value. Note that for
5792  // scalable dimensions, CreateMaskOp can be folded only if the
5793  // corresponding operand is negative or zero.
5794  if (retTy.getScalableDims()[opIdx] && *cst > 0)
5795  return failure();
5796 
5797  continue;
5798  }
5799 
5800  // Non-constant operands are not allowed for non-scalable vectors.
5801  if (!isScalable)
5802  return failure();
5803 
5804  // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5805  // true" mask, so can also be treated as constant.
5806  auto mul = operand.getDefiningOp<arith::MulIOp>();
5807  if (!mul)
5808  return failure();
5809  auto mulLHS = mul.getLhs();
5810  auto mulRHS = mul.getRhs();
5811  bool isOneOpVscale =
5812  (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
5813  isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5814 
5815  auto isConstantValMatchingDim =
5816  [=, dim = retTy.getShape()[opIdx]](Value operand) {
5817  auto constantVal = getConstantIntValue(operand);
5818  return (constantVal.has_value() && constantVal.value() == dim);
5819  };
5820 
5821  bool isOneOpConstantMatchingDim =
5822  isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5823 
5824  if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5825  return failure();
5826  }
5827 
5828  // Gather constant mask dimension sizes.
5829  SmallVector<int64_t, 4> maskDimSizes;
5830  maskDimSizes.reserve(createMaskOp->getNumOperands());
5831  for (auto [operand, maxDimSize] : llvm::zip_equal(
5832  createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5833  std::optional dimSize = getConstantIntValue(operand);
5834  if (!dimSize) {
5835  // Although not a constant, it is safe to assume that `operand` is
5836  // "vscale * maxDimSize".
5837  maskDimSizes.push_back(maxDimSize);
5838  continue;
5839  }
5840  int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
5841  // If one of the dim sizes is zero or negative, set all dims to zero.
5842  if (dimSize <= 0) {
5843  maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5844  break;
5845  }
5846  maskDimSizes.push_back(dimSizeVal);
5847  }
5848 
5849  // Replace 'createMaskOp' with ConstantMaskOp.
5850  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
5851  createMaskOp, retTy,
5852  vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
5853  return success();
5854  }
5855 };
5856 
5857 } // namespace
5858 
5859 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
5860  MLIRContext *context) {
5861  results.add<CreateMaskFolder>(context);
5862 }
5863 
5864 //===----------------------------------------------------------------------===//
5865 // MaskOp
5866 //===----------------------------------------------------------------------===//
5867 
5868 void MaskOp::build(
5869  OpBuilder &builder, OperationState &result, Value mask,
5870  Operation *maskableOp,
5871  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5872  assert(maskRegionBuilder &&
5873  "builder callback for 'maskRegion' must be present");
5874 
5875  result.addOperands(mask);
5876  OpBuilder::InsertionGuard guard(builder);
5877  Region *maskRegion = result.addRegion();
5878  builder.createBlock(maskRegion);
5879  maskRegionBuilder(builder, maskableOp);
5880 }
5881 
5882 void MaskOp::build(
5883  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5884  Value mask, Operation *maskableOp,
5885  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5886  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
5887  maskRegionBuilder);
5888 }
5889 
5890 void MaskOp::build(
5891  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5892  Value mask, Value passthru, Operation *maskableOp,
5893  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5894  build(builder, result, mask, maskableOp, maskRegionBuilder);
5895  if (passthru)
5896  result.addOperands(passthru);
5897  result.addTypes(resultTypes);
5898 }
5899 
5900 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
5901  // Create the op region.
5902  result.regions.reserve(1);
5903  Region &maskRegion = *result.addRegion();
5904 
5905  auto &builder = parser.getBuilder();
5906 
5907  // Parse all the operands.
5908  OpAsmParser::UnresolvedOperand mask;
5909  if (parser.parseOperand(mask))
5910  return failure();
5911 
5912  // Optional passthru operand.
5913  OpAsmParser::UnresolvedOperand passthru;
5914  ParseResult parsePassthru = parser.parseOptionalComma();
5915  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
5916  return failure();
5917 
5918  // Parse op region.
5919  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
5920  return failure();
5921 
5922  MaskOp::ensureTerminator(maskRegion, builder, result.location);
5923 
5924  // Parse the optional attribute list.
5925  if (parser.parseOptionalAttrDict(result.attributes))
5926  return failure();
5927 
5928  // Parse all the types.
5929  Type maskType;
5930  if (parser.parseColonType(maskType))
5931  return failure();
5932 
5933  SmallVector<Type> resultTypes;
5934  if (parser.parseOptionalArrowTypeList(resultTypes))
5935  return failure();
5936  result.types.append(resultTypes);
5937 
5938  // Resolve operands.
5939  if (parser.resolveOperand(mask, maskType, result.operands))
5940  return failure();
5941 
5942  if (parsePassthru.succeeded())
5943  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
5944  return failure();
5945 
5946  return success();
5947 }
5948 
5949 void MaskOp::print(OpAsmPrinter &p) {
5950  p << " " << getMask();
5951  if (getPassthru())
5952  p << ", " << getPassthru();
5953 
5954  // Print single masked operation and skip terminator.
5955  p << " { ";
5956  Block *singleBlock = &getMaskRegion().getBlocks().front();
5957  if (singleBlock && !singleBlock->getOperations().empty())
5958  p.printCustomOrGenericOp(&singleBlock->front());
5959  p << " }";
5960 
5961  p.printOptionalAttrDict(getOperation()->getAttrs());
5962 
5963  p << " : " << getMask().getType();
5964  if (getNumResults() > 0)
5965  p << " -> " << getResultTypes();
5966 }
5967 
5968 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
5969  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
5970  MaskOp>::ensureTerminator(region, builder, loc);
5971  // Keep the default yield terminator if the number of masked operations is
5972  // not the expected one. This case will trigger a verification failure.
5973  Block &block = region.front();
5974  if (block.getOperations().size() != 2)
5975  return;
5976 
5977  // Replace default yield terminator with a new one that returns the results
5978  // from the masked operation.
5979  OpBuilder opBuilder(builder.getContext());
5980  Operation *maskedOp = &block.front();
5981  Operation *oldYieldOp = &block.back();
5982  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
5983 
5984  // Empty vector.mask op.
5985  if (maskedOp == oldYieldOp)
5986  return;
5987 
5988  opBuilder.setInsertionPoint(oldYieldOp);
5989  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
5990  oldYieldOp->dropAllReferences();
5991  oldYieldOp->erase();
5992 }
5993 
5994 LogicalResult MaskOp::verify() {
5995  // Structural checks.
5996  Block &block = getMaskRegion().getBlocks().front();
5997  if (block.getOperations().empty())
5998  return emitOpError("expects a terminator within the mask region");
5999  if (block.getOperations().size() > 2)
6000  return emitOpError("expects only one operation to mask");
6001 
6002  // Terminator checks.
6003  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6004  if (!terminator)
6005  return emitOpError("expects a terminator within the mask region");
6006 
6007  if (terminator->getNumOperands() != getNumResults())
6008  return emitOpError(
6009  "expects number of results to match mask region yielded values");
6010 
6011  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6012  // Empty vector.mask. Nothing else to check.
6013  if (!maskableOp)
6014  return success();
6015 
6016  // Result checks.
6017  if (maskableOp->getNumResults() != getNumResults())
6018  return emitOpError("expects number of results to match maskable operation "
6019  "number of results");
6020 
6021  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6022  return emitOpError(
6023  "expects result type to match maskable operation result type");
6024 
6025  if (llvm::count_if(maskableOp->getResultTypes(),
6026  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6027  return emitOpError("multiple vector results not supported");
6028 
6029  // Mask checks.
6030  Type expectedMaskType = maskableOp.getExpectedMaskType();
6031  if (getMask().getType() != expectedMaskType)
6032  return emitOpError("expects a ")
6033  << expectedMaskType << " mask for the maskable operation";
6034 
6035  // Passthru checks.
6036  Value passthru = getPassthru();
6037  if (passthru) {
6038  if (!maskableOp.supportsPassthru())
6039  return emitOpError(
6040  "doesn't expect a passthru argument for this maskable operation");
6041 
6042  if (maskableOp->getNumResults() != 1)
6043  return emitOpError("expects result when passthru argument is provided");
6044 
6045  if (passthru.getType() != maskableOp->getResultTypes()[0])
6046  return emitOpError("expects passthru type to match result type");
6047  }
6048 
6049  return success();
6050 }
6051 
6052 /// Folds vector.mask ops with an all-true mask.
6053 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6054  SmallVectorImpl<OpFoldResult> &results) {
6055  MaskFormat maskFormat = getMaskFormat(getMask());
6056  if (isEmpty())
6057  return failure();
6058 
6059  if (maskFormat != MaskFormat::AllTrue)
6060  return failure();
6061 
6062  // Move maskable operation outside of the `vector.mask` region.
6063  Operation *maskableOp = getMaskableOp();
6064  maskableOp->dropAllUses();
6065  maskableOp->moveBefore(getOperation());
6066 
6067  llvm::append_range(results, maskableOp->getResults());
6068  return success();
6069 }
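// A minimal sketch of this fold (illustrative IR, names are hypothetical):
// with %m known to be all-true,
//   %0 = vector.mask %m { vector.transfer_read %src[%i], %pad
//          : memref<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
// folds to the unmasked vector.transfer_read, which is moved out of the
// vector.mask region.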
6070 
6071 // Elides empty vector.mask operations with or without return values. Propagates
6072 // the values yielded by the vector.yield terminator, if any, or erases the op
6073 // otherwise.
6074 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6075  using OpRewritePattern::OpRewritePattern;
6076 
6077  LogicalResult matchAndRewrite(MaskOp maskOp,
6078  PatternRewriter &rewriter) const override {
6079  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6080  if (maskingOp.getMaskableOp())
6081  return failure();
6082 
6083  if (!maskOp.isEmpty())
6084  return failure();
6085 
6086  Block *block = maskOp.getMaskBlock();
6087  auto terminator = cast<vector::YieldOp>(block->front());
6088  if (terminator.getNumOperands() == 0)
6089  rewriter.eraseOp(maskOp);
6090  else
6091  rewriter.replaceOp(maskOp, terminator.getOperands());
6092 
6093  return success();
6094  }
6095 };
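// For example (illustrative IR, names are hypothetical): an empty mask op
//   %0 = vector.mask %m { vector.yield %v : vector<8xf32> }
//          : vector<8xi1> -> vector<8xf32>
// is replaced by %v, and an empty vector.mask without results is erased.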
6096 
6097 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6098  MLIRContext *context) {
6099  results.add<ElideEmptyMaskOp>(context);
6100 }
6101 
6102 // MaskingOpInterface definitions.
6103 
6104 /// Returns the operation masked by this 'vector.mask'.
6105 Operation *MaskOp::getMaskableOp() {
6106  Block *block = getMaskBlock();
6107  if (block->getOperations().size() < 2)
6108  return nullptr;
6109 
6110  return &block->front();
6111 }
6112 
6113 /// Returns true if 'vector.mask' has a passthru value.
6114 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6115 
6116 //===----------------------------------------------------------------------===//
6117 // ScanOp
6118 //===----------------------------------------------------------------------===//
6119 
6120 LogicalResult ScanOp::verify() {
6121  VectorType srcType = getSourceType();
6122  VectorType initialType = getInitialValueType();
6123  // Check reduction dimension < rank.
6124  int64_t srcRank = srcType.getRank();
6125  int64_t reductionDim = getReductionDim();
6126  if (reductionDim >= srcRank)
6127  return emitOpError("reduction dimension ")
6128  << reductionDim << " has to be less than " << srcRank;
6129 
6130  // Check that rank(initial_value) = rank(src) - 1.
6131  int64_t initialValueRank = initialType.getRank();
6132  if (initialValueRank != srcRank - 1)
6133  return emitOpError("initial value rank ")
6134  << initialValueRank << " has to be equal to " << srcRank - 1;
6135 
6136  // Check shapes of initial value and src.
6137  ArrayRef<int64_t> srcShape = srcType.getShape();
6138  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6139  SmallVector<int64_t> expectedShape;
6140  for (int i = 0; i < srcRank; i++) {
6141  if (i != reductionDim)
6142  expectedShape.push_back(srcShape[i]);
6143  }
6144  if (!llvm::equal(initialValueShapes, expectedShape)) {
6145  return emitOpError("incompatible input/initial value shapes");
6146  }
6147 
6148  // Verify supported reduction kind.
6149  Type eltType = getDestType().getElementType();
6150  if (!isSupportedCombiningKind(getKind(), eltType))
6151  return emitOpError("unsupported reduction type ")
6152  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6153  << "'";
6154 
6155  return success();
6156 }
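// For example (illustrative): scanning a vector<4x8xf32> source along
// reduction dimension 1 requires a vector<4xf32> initial value; a
// vector<8xf32> initial value fails the shape check above.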
6157 
6158 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6159  RewritePatternSet &patterns, PatternBenefit benefit) {
6160  patterns
6161  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6162  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6163  StridedSliceConstantMaskFolder, TransposeFolder>(
6164  patterns.getContext(), benefit);
6165 }
6166 
6167 //===----------------------------------------------------------------------===//
6168 // SplatOp
6169 //===----------------------------------------------------------------------===//
6170 
6171 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6172  auto constOperand = adaptor.getInput();
6173  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
6174  return {};
6175 
6176  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6177  return SplatElementsAttr::get(getType(), {constOperand});
6178 }
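// For example (illustrative): splatting an arith.constant 1.0 : f32 into
// vector<4xf32> folds to a dense splat elements attribute holding 1.0.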
6179 
6180 //===----------------------------------------------------------------------===//
6181 // WarpExecuteOnLane0Op
6182 //===----------------------------------------------------------------------===//
6183 
6184 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
6185  p << "(" << getLaneid() << ")";
6186 
6187  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
6188  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
6189  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
6190 
6191  if (!getArgs().empty())
6192  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
6193  if (!getResults().empty())
6194  p << " -> (" << getResults().getTypes() << ')';
6195  p << " ";
6196  p.printRegion(getRegion(),
6197  /*printEntryBlockArgs=*/true,
6198  /*printBlockTerminators=*/!getResults().empty());
6199  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
6200 }
6201 
6202 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
6203  OperationState &result) {
6204  // Create the region.
6205  result.regions.reserve(1);
6206  Region *warpRegion = result.addRegion();
6207 
6208  auto &builder = parser.getBuilder();
6209  OpAsmParser::UnresolvedOperand laneId;
6210 
6211  // Parse predicate operand.
6212  if (parser.parseLParen() ||
6213  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
6214  parser.parseRParen())
6215  return failure();
6216 
6217  int64_t warpSize;
6218  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
6219  parser.parseRSquare())
6220  return failure();
6221  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
6222  builder.getContext())),
6223  builder.getI64IntegerAttr(warpSize));
6224 
6225  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
6226  return failure();
6227 
6228  llvm::SMLoc inputsOperandsLoc;
6229  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
6230  SmallVector<Type> inputTypes;
6231  if (succeeded(parser.parseOptionalKeyword("args"))) {
6232  if (parser.parseLParen())
6233  return failure();
6234 
6235  inputsOperandsLoc = parser.getCurrentLocation();
6236  if (parser.parseOperandList(inputsOperands) ||
6237  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
6238  return failure();
6239  }
6240  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
6241  result.operands))
6242  return failure();
6243 
6244  // Parse optional results type list.
6245  if (parser.parseOptionalArrowTypeList(result.types))
6246  return failure();
6247  // Parse the region.
6248  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
6249  /*argTypes=*/{}))
6250  return failure();
6251  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
6252 
6253  // Parse the optional attribute list.
6254  if (parser.parseOptionalAttrDict(result.attributes))
6255  return failure();
6256  return success();
6257 }
6258 
6259 void WarpExecuteOnLane0Op::getSuccessorRegions(
6260  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
6261  if (!point.isParent()) {
6262  regions.push_back(RegionSuccessor(getResults()));
6263  return;
6264  }
6265 
6266  // The warp region is always executed
6267  regions.push_back(RegionSuccessor(&getWarpRegion()));
6268 }
6269 
6270 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6271  TypeRange resultTypes, Value laneId,
6272  int64_t warpSize) {
6273  build(builder, result, resultTypes, laneId, warpSize,
6274  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
6275 }
6276 
6277 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6278  TypeRange resultTypes, Value laneId,
6279  int64_t warpSize, ValueRange args,
6280  TypeRange blockArgTypes) {
6281  result.addOperands(laneId);
6282  result.addAttribute(getAttributeNames()[0],
6283  builder.getI64IntegerAttr(warpSize));
6284  result.addTypes(resultTypes);
6285  result.addOperands(args);
6286  assert(args.size() == blockArgTypes.size());
6287  OpBuilder::InsertionGuard guard(builder);
6288  Region *warpRegion = result.addRegion();
6289  Block *block = builder.createBlock(warpRegion);
6290  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
6291  block->addArgument(type, arg.getLoc());
6292 }
6293 
6294 /// Helper to check if the distributed vector type is consistent with the
6295 /// expanded type and distributed size.
6296 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
6297  int64_t warpSize, Operation *op) {
6298  // If the types match, there is no distribution.
6299  if (expanded == distributed)
6300  return success();
6301  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6302  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6303  if (!expandedVecType || !distributedVecType)
6304  return op->emitOpError("expected vector type for distributed operands.");
6305  if (expandedVecType.getRank() != distributedVecType.getRank() ||
6306  expandedVecType.getElementType() != distributedVecType.getElementType())
6307  return op->emitOpError(
6308  "expected distributed vectors to have same rank and element type.");
6309 
6310  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
6311  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
6312  int64_t eDim = expandedVecType.getDimSize(i);
6313  int64_t dDim = distributedVecType.getDimSize(i);
6314  if (eDim == dDim)
6315  continue;
6316  if (eDim % dDim != 0)
6317  return op->emitOpError()
6318  << "expected expanded vector dimension #" << i << " (" << eDim
6319  << ") to be a multipler of the distributed vector dimension ("
6320  << dDim << ")";
6321  scales[i] = eDim / dDim;
6322  }
6323  if (std::accumulate(scales.begin(), scales.end(), 1,
6324  std::multiplies<int64_t>()) != warpSize)
6325  return op->emitOpError()
6326  << "incompatible distribution dimensions from " << expandedVecType
6327  << " to " << distributedVecType << " with warp size = " << warpSize;
6328 
6329  return success();
6330 }
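// For example (illustrative): expanded vector<32x4xf32> distributed as
// vector<1x4xf32> with warpSize = 32 verifies (scales 32 * 1 == 32), while
// vector<2x4xf32> is rejected because 16 * 1 != 32.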
6331 
6332 LogicalResult WarpExecuteOnLane0Op::verify() {
6333  if (getArgs().size() != getWarpRegion().getNumArguments())
6334  return emitOpError(
6335  "expected same number op arguments and block arguments.");
6336  auto yield =
6337  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
6338  if (yield.getNumOperands() != getNumResults())
6339  return emitOpError(
6340  "expected same number of yield operands and return values.");
6341  int64_t warpSize = getWarpSize();
6342  for (auto [regionArg, arg] :
6343  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
6344  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
6345  warpSize, getOperation())))
6346  return failure();
6347  }
6348  for (auto [yieldOperand, result] :
6349  llvm::zip_equal(yield.getOperands(), getResults())) {
6350  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
6351  warpSize, getOperation())))
6352  return failure();
6353  }
6354  return success();
6355 }
6356 
6357 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
6358  return succeeded(
6359  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
6360 }
6361 
6362 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6363  CombiningKind kind, Value v1, Value acc,
6364  arith::FastMathFlagsAttr fastmath,
6365  Value mask) {
6366  Type t1 = getElementTypeOrSelf(v1.getType());
6367  Type tAcc = getElementTypeOrSelf(acc.getType());
6368  Value result;
6369 
6370  switch (kind) {
6371  case CombiningKind::ADD:
6372  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6373  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6374  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6375  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6376  else
6377  llvm_unreachable("invalid value types for ADD reduction");
6378  break;
6379  case CombiningKind::AND:
6380  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6381  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6382  break;
6383  case CombiningKind::MAXNUMF:
6384  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6385  "expected float values");
6386  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6387  break;
6388  case CombiningKind::MAXIMUMF:
6389  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6390  "expected float values");
6391  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6392  break;
6393  case CombiningKind::MINNUMF:
6394  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6395  "expected float values");
6396  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6397  break;
6398  case CombiningKind::MINIMUMF:
6399  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6400  "expected float values");
6401  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6402  break;
6403  case CombiningKind::MAXSI:
6404  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6405  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6406  break;
6407  case CombiningKind::MINSI:
6408  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6409  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6410  break;
6411  case CombiningKind::MAXUI:
6412  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6413  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6414  break;
6415  case CombiningKind::MINUI:
6416  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6417  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6418  break;
6419  case CombiningKind::MUL:
6420  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6421  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6422  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6423  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6424  else
6425  llvm_unreachable("invalid value types for MUL reduction");
6426  break;
6427  case CombiningKind::OR:
6428  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6429  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6430  break;
6431  case CombiningKind::XOR:
6432  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6433  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6434  break;
6435  };
6436 
6437  assert(result && "unknown CombiningKind");
6438  return selectPassthru(b, mask, result, acc);
6439 }
6440 
6441 //===----------------------------------------------------------------------===//
6442 // Vector Masking Utilities
6443 //===----------------------------------------------------------------------===//
6444 
6445 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6446 /// as the masked operation.
6447 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
6448  Operation *maskableOp) {
6449  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6450  Block *insBlock = builder.getInsertionBlock();
6451  // Create a block and move the op to that block.
6452  insBlock->getOperations().splice(
6453  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6454  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6455 }
6456 
6457 /// Creates a vector.mask operation around a maskable operation. Returns the
6458 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6459 /// the maskable operation itself.
6460 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6461  Operation *maskableOp, Value mask,
6462  Value passthru) {
6463  if (!mask)
6464  return maskableOp;
6465  if (passthru)
6466  return builder.create<MaskOp>(maskableOp->getLoc(),
6467  maskableOp->getResultTypes(), mask, passthru,
6468  maskableOp, createMaskOpRegion);
6469  return builder.create<MaskOp>(maskableOp->getLoc(),
6470  maskableOp->getResultTypes(), mask, maskableOp,
6471  createMaskOpRegion);
6472 }
6473 
6474 /// Creates a vector select operation that picks values from `newValue` or
6475 /// `passthru` for each result vector lane based on `mask`. This utility is used
6476 /// to propagate the pass-thru value of vector.mask or for cases where only the
6477 /// pass-thru value propagation is needed. VP intrinsics do not support
6478 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
6479 /// are usually able to match op + select patterns and fold them into native
6480 /// target instructions.
6481 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6482  Value newValue, Value passthru) {
6483  if (!mask)
6484  return newValue;
6485 
6486  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6487  mask, newValue, passthru);
6488 }
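// For example (illustrative): with a vector<8xi1> mask this materializes
//   %r = arith.select %mask, %newValue, %passthru : vector<8xi1>, vector<8xf32>
// which backends can typically fold into the masked instruction itself.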
6489 
6490 //===----------------------------------------------------------------------===//
6491 // TableGen'd op method definitions
6492 //===----------------------------------------------------------------------===//
6493 
6494 #define GET_ATTRDEF_CLASSES
6495 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6496 
6497 #define GET_OP_CLASSES
6498 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static SmallVector< Value > delinearize(ImplicitLocOpBuilder &b, Value index, ArrayRef< Value > tripCounts)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static VectorShape vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition: VectorOps.cpp:68
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
Definition: VectorOps.cpp:1042
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
Definition: VectorOps.cpp:1323
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
Definition: VectorOps.cpp:1584
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
Definition: VectorOps.cpp:3829
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
Definition: VectorOps.cpp:1729
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
Definition: VectorOps.cpp:1595
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
Definition: VectorOps.cpp:2083
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width)
Definition: VectorOps.cpp:2392
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
Definition: VectorOps.cpp:131
MaskFormat
Helper enum to classify mask value.
Definition: VectorOps.cpp:58
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
Definition: VectorOps.cpp:2870
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
Definition: VectorOps.cpp:266
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
Definition: VectorOps.cpp:822
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
Definition: VectorOps.cpp:2104
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
Definition: VectorOps.cpp:2806
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
Definition: VectorOps.cpp:1034
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
Definition: VectorOps.cpp:2826
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
Definition: VectorOps.cpp:2849
static LogicalResult foldTransferFullMask(TransferOp op)
Definition: VectorOps.cpp:4031
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
Definition: VectorOps.cpp:1315
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
Definition: VectorOps.cpp:3715
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
Definition: VectorOps.cpp:833
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
Definition: VectorOps.cpp:2791
static Value foldExtractFromShapeCast(ExtractOp extractOp)
Definition: VectorOps.cpp:1659
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
Definition: VectorOps.cpp:3744
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Definition: VectorOps.cpp:3980
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
Definition: VectorOps.cpp:3269
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
Definition: VectorOps.cpp:3997
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
Definition: VectorOps.cpp:1781
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
Definition: AffineMap.cpp:132
MLIRContext * getContext() const
Definition: AffineMap.cpp:327
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:401
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:384
unsigned getNumDims() const
Definition: AffineMap.cpp:380
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:393
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
Definition: AffineMap.cpp:200
unsigned getNumResults() const
Definition: AffineMap.cpp:388
unsigned getNumInputs() const
Definition: AffineMap.cpp:389
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:397
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:542
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:296
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation & back()
Definition: Block.h:149
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
OpListType & getOperations()
Definition: Block.h:134
Operation & front()
Definition: Block.h:150
iterator begin()
Definition: Block.h:140
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition: Builders.cpp:277
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:830
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:584
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:812
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:123
bool isF32() const
Definition: Types.cpp:51
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:115
bool isF16() const
Definition: Types.cpp:49
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
static FailureOr< bool > areEqual(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute whether the given values/dimensions are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:317
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:324
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:90
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
std::optional< Operation::operand_range > getIndices(Operation *op)
Get and set the indices that the given load/store operation is operating on.
Definition: Utils.cpp:10
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:354
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:401
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:126
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:154
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:3848
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
Definition: VectorOps.cpp:189
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:253
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2237
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:284
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:308
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated with the given arith::AtomicRMWKind.
Definition: VectorOps.cpp:578
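For illustration (the wrapper is hypothetical), reducing a 1-D vector to a scalar with the reduction kind matching an arith::AtomicRMWKind:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// For arith::AtomicRMWKind::addf this emits a floating-point add
// vector.reduction over `vec` and returns its scalar result.
static Value sumLanes(OpBuilder &b, Location loc, Value vec) {
  return vector::getVectorReductionOp(arith::AtomicRMWKind::addf, b, loc, vec);
}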
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:65
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:397
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:736
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:650
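A small sketch (the helper is assumed) of permuting static indices with a projected-permutation map:

#include "mlir/IR/AffineMap.h"

using namespace mlir;

// With map = (d0, d1, d2) -> (d2, d0) and indices = {a, b, c}, the result
// is {c, a}; the map must be a projected permutation.
static SmallVector<int64_t> permuteIndices(AffineMap map,
                                           ArrayRef<int64_t> indices) {
  return applyPermutationMap(map, indices);
}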
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:755
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:685
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
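A hedged example of handling OpFoldResult uniformly; the function names are illustrative, and the two helpers used are the ones indexed above (getConstantIntValue and getValueOrCreateConstantIndexOp).

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

using namespace mlir;

// If `ofr` wraps an Attribute, an arith.constant of index type is created;
// if it already wraps a Value, that Value is returned unchanged.
static Value materializeIndex(OpBuilder &b, Location loc, OpFoldResult ofr) {
  return getValueOrCreateConstantIndexOp(b, loc, ofr);
}

// getConstantIntValue looks through both IntegerAttrs and constant-defined
// Values, so it handles either OpFoldResult alternative.
static bool isKnownZero(OpFoldResult ofr) {
  return getConstantIntValue(ofr) == 0;
}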
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:623
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:892
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
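A worked sketch combining computeStrides and linearize for a statically shaped 4x8x16 array in row-major order; the wrapper function is hypothetical.

#include "mlir/Dialect/Utils/IndexingUtils.h"

using namespace mlir;

// computeStrides({4, 8, 16}) returns the suffix products {128, 16, 1}, so
// linearize({i, j, k}, {128, 16, 1}) == i * 128 + j * 16 + k.
static int64_t flattenIndex(ArrayRef<int64_t> offsets) {
  SmallVector<int64_t> basis = computeStrides({4, 8, 16});
  return linearize(offsets, basis);
}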
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1135
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1138
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
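The canonicalization patterns defined in this file follow the shape sketched below. The pattern itself is a placeholder (vector.broadcast is used as an arbitrary illustrative root), not one of the file's actual patterns; in-tree patterns are collected into a RewritePatternSet, for example by populateVectorToVectorCanonicalizationPatterns listed above.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Placeholder pattern: erase a broadcast whose source already has the result
// type, since such a broadcast is the identity.
struct FoldTrivialBroadcast : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getSourceType() != op.getResultVectorType())
      return failure();
    rewriter.replaceOp(op, op.getSource());
    return success();
  }
};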
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:328
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:336
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:338
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.