1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/IR/TypeUtilities.h"
37 #include "mlir/Support/LLVM.h"
39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
45 
46 #include <cassert>
47 #include <cstdint>
48 #include <numeric>
49 
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
51 // Pull in all enum type and utility function definitions.
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
53 
54 using namespace mlir;
55 using namespace mlir::vector;
56 
57 /// Helper enum to classify mask value.
58 enum class MaskFormat {
59  AllTrue = 0,
60  AllFalse = 1,
61  Unknown = 2,
62 };
63 
64 /// Helper method to classify a mask value. Currently, the method
65 /// looks "under the hood" of constant dense values, constant mask
66 /// operations, and create_mask operations (since the client may be
67 /// called at various stages during progressive lowering).
68 static MaskFormat getMaskFormat(Value mask) {
69  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
70  // Inspect constant dense values. We count up for bits that
71  // are set, count down for bits that are cleared, and bail
72  // when a mix is detected.
73  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
74  int64_t val = 0;
75  for (bool b : denseElts.getValues<bool>())
76  if (b && val >= 0)
77  val++;
78  else if (!b && val <= 0)
79  val--;
80  else
81  return MaskFormat::Unknown;
82  if (val > 0)
83  return MaskFormat::AllTrue;
84  if (val < 0)
85  return MaskFormat::AllFalse;
86  }
87  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
88  // Inspect constant mask index. If the index exceeds the
89  // dimension size, all bits are set. If the index is zero
90  // or less, no bits are set.
91  ArrayAttr masks = m.getMaskDimSizes();
92  auto shape = m.getType().getShape();
93  bool allTrue = true;
94  bool allFalse = true;
95  for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96  int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
97  if (i < dimSize)
98  allTrue = false;
99  if (i > 0)
100  allFalse = false;
101  }
102  if (allTrue)
103  return MaskFormat::AllTrue;
104  if (allFalse)
105  return MaskFormat::AllFalse;
106  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107  // Finds all-false create_masks. An all-true create_mask requires all
108  // dims to be constants, so that'll be folded to a constant_mask, then
109  // detected in the constant_mask case.
110  auto maskOperands = m.getOperands();
111  for (Value operand : maskOperands) {
112  if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113  int64_t dimSize =
114  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
115  if (dimSize <= 0)
116  return MaskFormat::AllFalse;
117  }
118  }
119  return MaskFormat::Unknown;
120  }
121  return MaskFormat::Unknown;
122 }
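
// As a minimal standalone illustration of the counting trick above (a
// hypothetical helper, not part of the upstream file), the same
// classification over a plain sequence of mask bits looks like this:
[[maybe_unused]] static MaskFormat classifyMaskBits(ArrayRef<bool> bits) {
  int64_t val = 0;
  for (bool b : bits) {
    if (b && val >= 0)
      ++val; // Only set bits seen so far.
    else if (!b && val <= 0)
      --val; // Only cleared bits seen so far.
    else
      return MaskFormat::Unknown; // A mix of set and cleared bits.
  }
  if (val > 0)
    return MaskFormat::AllTrue;
  if (val < 0)
    return MaskFormat::AllFalse;
  return MaskFormat::Unknown; // Empty mask.
}
// E.g. {1, 1, 1} -> AllTrue, {0, 0} -> AllFalse, {1, 0, 1} -> Unknown.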
123 
124 /// Default callback to build a region with a 'vector.yield' terminator with no
125 /// arguments.
126 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127  builder.create<vector::YieldOp>(loc);
128 }
129 
130 // Helper for verifying combining kinds in contractions and reductions.
131 static bool isSupportedCombiningKind(CombiningKind combiningKind,
132  Type elementType) {
133  switch (combiningKind) {
134  case CombiningKind::ADD:
135  case CombiningKind::MUL:
136  return elementType.isIntOrIndexOrFloat();
137  case CombiningKind::MINUI:
138  case CombiningKind::MINSI:
139  case CombiningKind::MAXUI:
140  case CombiningKind::MAXSI:
141  case CombiningKind::AND:
142  case CombiningKind::OR:
143  case CombiningKind::XOR:
144  return elementType.isIntOrIndex();
145  case CombiningKind::MINNUMF:
146  case CombiningKind::MAXNUMF:
147  case CombiningKind::MINIMUMF:
148  case CombiningKind::MAXIMUMF:
149  return llvm::isa<FloatType>(elementType);
150  }
151  return false;
152 }
153 
154 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
155  VectorType vectorType) {
156  int64_t elementVectorRank = 0;
157  VectorType elementVectorType =
158  llvm::dyn_cast<VectorType>(shapedType.getElementType());
159  if (elementVectorType)
160  elementVectorRank += elementVectorType.getRank();
161  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162  // TODO: replace once we have 0-d vectors.
163  if (shapedType.getRank() == 0 &&
164  vectorType.getShape() == ArrayRef<int64_t>{1})
165  return AffineMap::get(
166  /*numDims=*/0, /*numSymbols=*/0,
167  getAffineConstantExpr(0, shapedType.getContext()));
168  return AffineMap::getMinorIdentityMap(
169  shapedType.getRank(), vectorType.getRank() - elementVectorRank,
170  shapedType.getContext());
171 }
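
// To make the resulting "minor identity" map concrete (an illustrative sketch
// only; AffineMap::getMinorIdentityMap computes this itself): for a shaped
// type of rank `memRank` and a transfer vector of rank `vecRank`, the map
// keeps exactly the trailing `vecRank` dimensions.
[[maybe_unused]] static SmallVector<int64_t> minorIdentityDims(int64_t memRank,
                                                               int64_t vecRank) {
  assert(vecRank <= memRank && "vector rank must not exceed shaped rank");
  SmallVector<int64_t> dims;
  for (int64_t d = memRank - vecRank; d < memRank; ++d)
    dims.push_back(d);
  return dims;
  // E.g. memRank = 4, vecRank = 2 -> {2, 3}, i.e. (d0, d1, d2, d3) -> (d2, d3).
}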
172 
173 /// Check if `write` is of a constant splat and the masked `read` is padded with
174 /// the same splat value -- meaning it could be the same value as the initial
175 /// constant splat.
176 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
177  vector::TransferReadOp read) {
178  auto readMask = read.getMask();
179  auto writeMask = write.getMask();
180  // Check if the masks are consistent. The splat value could be the same if the
181  // read is masked (and padded with the splat value), and the write is unmasked
182  // or has the same mask. Note this does not allow the case where the write is
183  // masked and the read is unmasked, as then the read could be of more elements
184  // than the write (which may not be the same value).
185  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
186  if (!couldBeSameSplat)
187  return false;
188  // Check for constant splat (as the source of the write).
189  DenseElementsAttr splatAttr;
190  if (!matchPattern(write.getVector(),
191  m_Constant<DenseElementsAttr>(&splatAttr)) ||
192  !splatAttr.isSplat()) {
193  return false;
194  }
195  // The padding of the read and the constant splat value must be the same.
196  Attribute padAttr;
197  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
198  return false;
199  return padAttr == splatAttr.getSplatValue<Attribute>();
200 }
201 
202 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
203  vector::TransferReadOp read) {
204  return !defWrite.hasOutOfBoundsDim() &&
205  defWrite.getIndices() == read.getIndices() &&
206  defWrite.getVectorType() == read.getVectorType() &&
207  defWrite.getPermutationMap() == read.getPermutationMap() &&
208  ((!defWrite.getMask() && !read.getMask()) ||
209  isSplatWriteConsistentWithMaskedRead(defWrite, read));
210 }
211 
212 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
213  vector::TransferWriteOp priorWrite) {
214  return priorWrite.getIndices() == write.getIndices() &&
215  priorWrite.getMask() == write.getMask() &&
216  priorWrite.getVectorType() == write.getVectorType() &&
217  priorWrite.getPermutationMap() == write.getPermutationMap();
218 }
219 
220 bool mlir::vector::isDisjointTransferIndices(
221  VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
222  bool testDynamicValueUsingBounds) {
223  // For simplicity only look at transfer of same type.
224  if (transferA.getVectorType() != transferB.getVectorType())
225  return false;
226  unsigned rankOffset = transferA.getLeadingShapedRank();
227  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
228  Value indexA = transferA.getIndices()[i];
229  Value indexB = transferB.getIndices()[i];
230  std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
231  std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
232 
233  if (i < rankOffset) {
234  // For leading dimensions, if we can prove that the indices are different
235  // we know we are accessing disjoint slices.
236  if (cstIndexA.has_value() && cstIndexB.has_value()) {
237  if (*cstIndexA != *cstIndexB)
238  return true;
239  continue;
240  }
241  if (testDynamicValueUsingBounds) {
242  // First try to see if we can fully compose and simplify the affine
243  // expression as a fast track.
244  FailureOr<uint64_t> delta =
245  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
246  if (succeeded(delta) && *delta != 0)
247  return true;
248 
249  FailureOr<bool> testEqual =
250  ValueBoundsConstraintSet::areEqual(indexA, indexB);
251  if (succeeded(testEqual) && !testEqual.value())
252  return true;
253  }
254  } else {
255  // For this dimension we slice a part of the memref, so we need to make
256  // sure the intervals accessed don't overlap.
257  int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
258  if (cstIndexA.has_value() && cstIndexB.has_value()) {
259  int64_t distance = std::abs(*cstIndexA - *cstIndexB);
260  if (distance >= vectorDim)
261  return true;
262  continue;
263  }
264  if (testDynamicValueUsingBounds) {
265  // First try to see if we can fully compose and simplify the affine
266  // expression as a fast track.
267  FailureOr<int64_t> delta =
268  affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
269  if (succeeded(delta) && std::abs(*delta) >= vectorDim)
270  return true;
271 
272  FailureOr<int64_t> computeDelta =
273  ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
274  if (succeeded(computeDelta)) {
275  if (std::abs(computeDelta.value()) >= vectorDim)
276  return true;
277  }
278  }
279  }
280  }
281  return false;
282 }
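
// The trailing-dimension test above reduces to an interval-overlap check: two
// transfers that touch `vectorDim` contiguous elements starting at constant
// indices are disjoint iff the start indices are at least `vectorDim` apart.
// A standalone sketch of that check (hypothetical helper):
[[maybe_unused]] static bool disjointConstantIntervals(int64_t idxA,
                                                       int64_t idxB,
                                                       int64_t vectorDim) {
  // [idxA, idxA + vectorDim) and [idxB, idxB + vectorDim) do not overlap.
  return std::abs(idxA - idxB) >= vectorDim;
}
// E.g. idxA = 0, idxB = 4, vectorDim = 4 -> disjoint; idxA = 0, idxB = 3,
// vectorDim = 4 -> overlapping.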
283 
284 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
285  VectorTransferOpInterface transferB,
286  bool testDynamicValueUsingBounds) {
287  if (transferA.getSource() != transferB.getSource())
288  return false;
289  return isDisjointTransferIndices(transferA, transferB,
290  testDynamicValueUsingBounds);
291 }
292 
293 // Helper to iterate over n-D vector slice elements. Calculates the next
294 // `position` in the n-D vector of size `shape`, applying the offsets in
295 // `offsets`. Modifies `position` in place. Returns failure when `position`
296 // becomes the end position.
297 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
298  ArrayRef<int64_t> shape,
299  ArrayRef<int64_t> offsets) {
300  for (auto [posInDim, dimSize, offsetInDim] :
301  llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302  ++posInDim;
303  if (posInDim < dimSize + offsetInDim)
304  return success();
305 
306  // Carry the overflow to the next loop iteration.
307  posInDim = offsetInDim;
308  }
309 
310  return failure();
311 }
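
// For example, with shape = {2, 3} and offsets = {0, 0}, starting from
// position = {0, 0} the successive calls yield {0, 1}, {0, 2}, {1, 0}, {1, 1},
// {1, 2} and then return failure(): a row-major "odometer" walk. A minimal
// illustrative driver (hypothetical helper) collecting the start position and
// all of its successors:
[[maybe_unused]] static SmallVector<SmallVector<int64_t>>
enumerateSlicePositions(ArrayRef<int64_t> shape, ArrayRef<int64_t> offsets) {
  SmallVector<SmallVector<int64_t>> visited;
  SmallVector<int64_t> position(offsets.begin(), offsets.end());
  visited.push_back(position);
  while (succeeded(incSlicePosition(position, shape, offsets)))
    visited.push_back(position);
  return visited;
}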
312 
313 /// Returns the integer numbers in `values`. `values` are expected to be
314 /// constant operations.
315 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
316  SmallVector<int64_t> ints;
317  llvm::transform(values, std::back_inserter(ints), [](Value value) {
318  auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
319  assert(constOp && "Unexpected non-constant index");
320  return constOp.value();
321  });
322  return ints;
323 }
324 
325 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
326 /// be constant operations.
327 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
328  SmallVector<int64_t> ints;
329  llvm::transform(
330  foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
331  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
332  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
333  });
334  return ints;
335 }
336 
337 /// Convert `foldResults` into Values. Integer attributes are converted to
338 /// constant op.
339 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
340  ArrayRef<OpFoldResult> foldResults) {
341  SmallVector<Value> values;
342  llvm::transform(foldResults, std::back_inserter(values),
343  [&](OpFoldResult foldResult) {
344  if (auto attr = foldResult.dyn_cast<Attribute>())
345  return builder
346  .create<arith::ConstantIndexOp>(
347  loc, cast<IntegerAttr>(attr).getInt())
348  .getResult();
349 
350  return foldResult.get<Value>();
351  });
352  return values;
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // CombiningKindAttr
357 //===----------------------------------------------------------------------===//
358 
359 namespace mlir {
360 namespace vector {
361 namespace detail {
362 struct BitmaskEnumStorage : public AttributeStorage {
363  using KeyTy = uint64_t;
364 
365  BitmaskEnumStorage(KeyTy val) : value(val) {}
366 
367  bool operator==(const KeyTy &key) const { return value == key; }
368 
369  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
370  const KeyTy &key) {
371  return new (allocator.allocate<BitmaskEnumStorage>())
372  BitmaskEnumStorage(key);
373  }
374 
375  KeyTy value = 0;
376 };
377 } // namespace detail
378 } // namespace vector
379 } // namespace mlir
380 
381 //===----------------------------------------------------------------------===//
382 // VectorDialect
383 //===----------------------------------------------------------------------===//
384 
385 namespace {
386 /// This class defines the interface for handling inlining with vector dialect
387 /// operations.
388 struct VectorInlinerInterface : public DialectInlinerInterface {
389  using DialectInlinerInterface::DialectInlinerInterface;
390 
391  /// All vector dialect ops can be inlined.
392  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
393  return true;
394  }
395 };
396 } // namespace
397 
398 void VectorDialect::initialize() {
399  addAttributes<
400 #define GET_ATTRDEF_LIST
401 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
402  >();
403 
404  addOperations<
405 #define GET_OP_LIST
406 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
407  >();
408 
409  addInterfaces<VectorInlinerInterface>();
410 
411  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
412  TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
413  YieldOp>();
414  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
415  TransferWriteOp>();
416  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
417  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
418 }
419 
420 /// Materialize a single constant operation from a given attribute value with
421 /// the desired resultant type.
422 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
423  Attribute value, Type type,
424  Location loc) {
425  return arith::ConstantOp::materialize(builder, value, type, loc);
426 }
427 
428 Type vector::getVectorSubscriptType(Builder &builder) {
429  return builder.getIntegerType(64);
430 }
431 
432 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
433  ArrayRef<int64_t> values) {
434  return builder.getI64ArrayAttr(values);
435 }
436 
437 //===----------------------------------------------------------------------===//
438 // MultiDimReductionOp
439 //===----------------------------------------------------------------------===//
440 
441 void vector::MultiDimReductionOp::build(OpBuilder &builder,
442  OperationState &result, Value source,
443  Value acc, ArrayRef<bool> reductionMask,
444  CombiningKind kind) {
445  SmallVector<int64_t> reductionDims;
446  for (const auto &en : llvm::enumerate(reductionMask))
447  if (en.value())
448  reductionDims.push_back(en.index());
449  build(builder, result, kind, source, acc,
450  builder.getI64ArrayAttr(reductionDims));
451 }
452 
453 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
454  // Single parallel dim, this is a noop.
455  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
456  return getSource();
457  return {};
458 }
459 
460 std::optional<SmallVector<int64_t, 4>>
461 MultiDimReductionOp::getShapeForUnroll() {
462  return llvm::to_vector<4>(getSourceVectorType().getShape());
463 }
464 
465 LogicalResult MultiDimReductionOp::verify() {
466  SmallVector<int64_t> targetShape;
467  SmallVector<bool> scalableDims;
468  Type inferredReturnType;
469  auto sourceScalableDims = getSourceVectorType().getScalableDims();
470  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
471  if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
472  return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
473  })) {
474  targetShape.push_back(it.value());
475  scalableDims.push_back(sourceScalableDims[it.index()]);
476  }
477  // TODO: update to also allow 0-d vectors when available.
478  if (targetShape.empty())
479  inferredReturnType = getSourceVectorType().getElementType();
480  else
481  inferredReturnType = VectorType::get(
482  targetShape, getSourceVectorType().getElementType(), scalableDims);
483  if (getType() != inferredReturnType)
484  return emitOpError() << "destination type " << getType()
485  << " is incompatible with source type "
486  << getSourceVectorType();
487 
488  return success();
489 }
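
// The inferred type above simply drops every reduced dimension from the
// source shape. A standalone sketch of that shape computation (hypothetical
// helper; ignores scalable dims and the 0-d special case):
[[maybe_unused]] static SmallVector<int64_t>
inferReducedShape(ArrayRef<int64_t> srcShape, ArrayRef<int64_t> reductionDims) {
  SmallVector<int64_t> resultShape;
  for (int64_t dim = 0, e = srcShape.size(); dim < e; ++dim)
    if (!llvm::is_contained(reductionDims, dim))
      resultShape.push_back(srcShape[dim]); // Keep non-reduced dims only.
  return resultShape;
  // E.g. srcShape = {4, 8, 16}, reductionDims = {1} -> {4, 16}.
}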
490 
491 /// Returns the mask type expected by this operation.
492 Type MultiDimReductionOp::getExpectedMaskType() {
493  auto vecType = getSourceVectorType();
494  return VectorType::get(vecType.getShape(),
495  IntegerType::get(vecType.getContext(), /*width=*/1),
496  vecType.getScalableDims());
497 }
498 
499 namespace {
500 // Only unit dimensions that are being reduced are folded. If the dimension is
501 // unit, but not reduced, it is not folded, thereby keeping the output type the
502 // same. If not all dimensions which are reduced are of unit dimension, this
503 // transformation does nothing. This is just a generalization of
504 // ElideSingleElementReduction for ReduceOp.
505 struct ElideUnitDimsInMultiDimReduction
506  : public OpRewritePattern<MultiDimReductionOp> {
507  using OpRewritePattern::OpRewritePattern;
508 
509  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
510  PatternRewriter &rewriter) const override {
511  ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
512  for (const auto &dim : enumerate(shape)) {
513  if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
514  return failure();
515  }
516 
517  // Vector mask setup.
518  OpBuilder::InsertionGuard guard(rewriter);
519  Operation *rootOp;
520  Value mask;
521  if (reductionOp.isMasked()) {
522  rewriter.setInsertionPoint(reductionOp.getMaskingOp());
523  rootOp = reductionOp.getMaskingOp();
524  mask = reductionOp.getMaskingOp().getMask();
525  } else {
526  rootOp = reductionOp;
527  }
528 
529  Location loc = reductionOp.getLoc();
530  Value acc = reductionOp.getAcc();
531  Value cast;
532  if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
533  if (mask) {
534  VectorType newMaskType =
535  VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
536  dstVecType.getScalableDims());
537  mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
538  }
539  cast = rewriter.create<vector::ShapeCastOp>(
540  loc, reductionOp.getDestType(), reductionOp.getSource());
541  } else {
542  // This means we are reducing all the dimensions, and all reduction
543  // dimensions are of size 1. So a simple extraction would do.
544  SmallVector<int64_t> zeroIdx(shape.size(), 0);
545  if (mask)
546  mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
547  cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
548  zeroIdx);
549  }
550 
551  Value result =
552  vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
553  cast, /*fastmath=*/nullptr, mask);
554  rewriter.replaceOp(rootOp, result);
555  return success();
556  }
557 };
558 } // namespace
559 
560 void MultiDimReductionOp::getCanonicalizationPatterns(
561  RewritePatternSet &results, MLIRContext *context) {
562  results.add<ElideUnitDimsInMultiDimReduction>(context);
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // ReductionOp
567 //===----------------------------------------------------------------------===//
568 
569 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
570  CombiningKind kind, Value vector,
571  arith::FastMathFlags fastMathFlags) {
572  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
573 }
574 
575 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
576  CombiningKind kind, Value vector, Value acc,
577  arith::FastMathFlags fastMathFlags) {
578  build(builder, result,
579  llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
580  acc, fastMathFlags);
581 }
582 
583 LogicalResult ReductionOp::verify() {
584  // Verify for 0-D and 1-D vector.
585  int64_t rank = getSourceVectorType().getRank();
586  if (rank > 1)
587  return emitOpError("unsupported reduction rank: ") << rank;
588 
589  // Verify supported reduction kind.
590  Type eltType = getDest().getType();
591  if (!isSupportedCombiningKind(getKind(), eltType))
592  return emitOpError("unsupported reduction type '")
593  << eltType << "' for kind '" << stringifyCombiningKind(getKind())
594  << "'";
595 
596  return success();
597 }
598 
599 // MaskableOpInterface methods.
600 
601 /// Returns the mask type expected by this operation.
602 Type ReductionOp::getExpectedMaskType() {
603  auto vecType = getSourceVectorType();
604  return VectorType::get(vecType.getShape(),
605  IntegerType::get(vecType.getContext(), /*width=*/1),
606  vecType.getScalableDims());
607 }
608 
609 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
610  OpBuilder &builder, Location loc,
611  Value vector) {
612  switch (op) {
613  case arith::AtomicRMWKind::addf:
614  case arith::AtomicRMWKind::addi:
615  return builder.create<vector::ReductionOp>(vector.getLoc(),
616  CombiningKind::ADD, vector);
617  case arith::AtomicRMWKind::mulf:
618  case arith::AtomicRMWKind::muli:
619  return builder.create<vector::ReductionOp>(vector.getLoc(),
620  CombiningKind::MUL, vector);
621  case arith::AtomicRMWKind::minimumf:
622  return builder.create<vector::ReductionOp>(vector.getLoc(),
623  CombiningKind::MINIMUMF, vector);
624  case arith::AtomicRMWKind::mins:
625  return builder.create<vector::ReductionOp>(vector.getLoc(),
626  CombiningKind::MINSI, vector);
627  case arith::AtomicRMWKind::minu:
628  return builder.create<vector::ReductionOp>(vector.getLoc(),
629  CombiningKind::MINUI, vector);
630  case arith::AtomicRMWKind::maximumf:
631  return builder.create<vector::ReductionOp>(vector.getLoc(),
632  CombiningKind::MAXIMUMF, vector);
633  case arith::AtomicRMWKind::maxs:
634  return builder.create<vector::ReductionOp>(vector.getLoc(),
635  CombiningKind::MAXSI, vector);
636  case arith::AtomicRMWKind::maxu:
637  return builder.create<vector::ReductionOp>(vector.getLoc(),
638  CombiningKind::MAXUI, vector);
639  case arith::AtomicRMWKind::andi:
640  return builder.create<vector::ReductionOp>(vector.getLoc(),
641  CombiningKind::AND, vector);
642  case arith::AtomicRMWKind::ori:
643  return builder.create<vector::ReductionOp>(vector.getLoc(),
644  CombiningKind::OR, vector);
645  // TODO: Add remaining reduction operations.
646  default:
647  (void)emitOptionalError(loc, "Reduction operation type not supported");
648  break;
649  }
650  return nullptr;
651 }
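
// Usage sketch (illustrative only): a caller lowering a floating-point add
// reduction would simply forward the kind, e.g.:
[[maybe_unused]] static Value buildAddReduction(OpBuilder &builder,
                                                Location loc, Value vec) {
  // Builds `vector.reduction <add>, %vec`; unsupported kinds return null.
  return getVectorReductionOp(arith::AtomicRMWKind::addf, builder, loc, vec);
}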
652 
653 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
654  return llvm::to_vector<4>(getSourceVectorType().getShape());
655 }
656 
657 namespace {
658 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
659  using OpRewritePattern::OpRewritePattern;
660 
661  LogicalResult matchAndRewrite(ReductionOp reductionOp,
662  PatternRewriter &rewriter) const override {
663  // Vector mask setup.
664  OpBuilder::InsertionGuard guard(rewriter);
665  auto maskableOp =
666  cast<vector::MaskableOpInterface>(reductionOp.getOperation());
667  Operation *rootOp;
668  Value mask;
669  if (maskableOp.isMasked()) {
670  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
671  rootOp = maskableOp.getMaskingOp();
672  mask = maskableOp.getMaskingOp().getMask();
673  } else {
674  rootOp = reductionOp;
675  }
676 
677  auto vectorType = reductionOp.getSourceVectorType();
678  if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
679  return failure();
680 
681  Location loc = reductionOp.getLoc();
682  Value result;
683  if (vectorType.getRank() == 0) {
684  if (mask)
685  mask = rewriter.create<ExtractElementOp>(loc, mask);
686  result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
687  } else {
688  if (mask)
689  mask = rewriter.create<ExtractOp>(loc, mask, 0);
690  result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
691  }
692 
693  if (Value acc = reductionOp.getAcc())
694  result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
695  result, acc,
696  reductionOp.getFastmathAttr(), mask);
697 
698  rewriter.replaceOp(rootOp, result);
699  return success();
700  }
701 };
702 } // namespace
703 
704 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
705  MLIRContext *context) {
706  results.add<ElideSingleElementReduction>(context);
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // ContractionOp
711 //===----------------------------------------------------------------------===//
712 
713 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
714  Value lhs, Value rhs, Value acc,
715  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
716  ArrayRef<IteratorType> iteratorTypes) {
717  result.addOperands({lhs, rhs, acc});
718  result.addTypes(acc.getType());
719  result.addAttribute(
720  getIndexingMapsAttrName(result.name),
721  builder.getAffineMapArrayAttr(
722  AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
723  result.addAttribute(
724  getIteratorTypesAttrName(result.name),
725  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
726  iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
727  return IteratorTypeAttr::get(builder.getContext(), t);
728  }))));
729 }
730 
731 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
732  Value lhs, Value rhs, Value acc,
733  ArrayAttr indexingMaps,
734  ArrayAttr iteratorTypes) {
735  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
736  ContractionOp::getDefaultKind());
737 }
738 
739 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
740  Value lhs, Value rhs, Value acc,
741  ArrayAttr indexingMaps,
742  ArrayAttr iteratorTypes, CombiningKind kind) {
743  result.addOperands({lhs, rhs, acc});
744  result.addTypes(acc.getType());
745  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
746  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
747  result.addAttribute(getKindAttrName(result.name),
748  CombiningKindAttr::get(builder.getContext(), kind));
749 }
750 
751 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
752  OpAsmParser::UnresolvedOperand lhsInfo;
753  OpAsmParser::UnresolvedOperand rhsInfo;
754  OpAsmParser::UnresolvedOperand accInfo;
755  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
756  SmallVector<Type, 2> types;
757  Type resultType;
758  auto loc = parser.getCurrentLocation();
759  DictionaryAttr dictAttr;
760  // TODO: Unify linalg op attribute parsing.
761  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
762  parser.parseComma() || parser.parseOperand(rhsInfo) ||
763  parser.parseComma() || parser.parseOperand(accInfo) ||
764  parser.parseTrailingOperandList(masksInfo) ||
765  parser.parseOptionalAttrDict(result.attributes) ||
766  parser.parseColonTypeList(types) ||
767  parser.parseKeywordType("into", resultType) ||
768  parser.resolveOperand(lhsInfo, types[0], result.operands) ||
769  parser.resolveOperand(rhsInfo, types[1], result.operands) ||
770  parser.resolveOperand(accInfo, resultType, result.operands) ||
771  parser.addTypeToList(resultType, result.types))
772  return failure();
773  result.attributes.append(dictAttr.getValue().begin(),
774  dictAttr.getValue().end());
775 
776  // Convert the array of strings into an array of IteratorType enums. This is
777  // needed because tests still use the old format in which the 'iterator_types'
778  // attribute is represented as an array of strings.
779  // TODO: Remove this conversion once tests are fixed.
780  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
781  result.attributes.get(getIteratorTypesAttrName(result.name)));
782 
783  SmallVector<Attribute> iteratorTypeAttrs;
784 
785  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
786  auto maybeIteratorType = symbolizeIteratorType(s);
787  if (!maybeIteratorType.has_value())
788  return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
789 
790  iteratorTypeAttrs.push_back(
791  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
792  }
793  result.attributes.set(getIteratorTypesAttrName(result.name),
794  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
795 
796  if (!result.attributes.get(getKindAttrName(result.name))) {
797  result.addAttribute(
798  getKindAttrName(result.name),
799  CombiningKindAttr::get(parser.getContext(),
800  ContractionOp::getDefaultKind()));
801  }
802  if (masksInfo.empty())
803  return success();
804  if (masksInfo.size() != 2)
805  return parser.emitError(parser.getNameLoc(),
806  "expected zero or exactly 2 vector mask operands");
807  auto lhsType = llvm::cast<VectorType>(types[0]);
808  auto rhsType = llvm::cast<VectorType>(types[1]);
809  auto maskElementType = parser.getBuilder().getI1Type();
810  std::array<VectorType, 2> maskTypes = {
811  VectorType::Builder(lhsType).setElementType(maskElementType),
812  VectorType::Builder(rhsType).setElementType(maskElementType)};
813  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
814  return failure();
815  return success();
816 }
817 
818 void ContractionOp::print(OpAsmPrinter &p) {
819  // TODO: Unify printing code with linalg ops.
820  auto attrNames = getTraitAttrNames();
821  llvm::StringSet<> traitAttrsSet;
822  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
823  SmallVector<NamedAttribute> attrs;
824  for (auto attr : (*this)->getAttrs()) {
825  if (attr.getName() == getIteratorTypesAttrName()) {
826  auto iteratorTypes =
827  llvm::cast<ArrayAttr>(attr.getValue())
828  .getAsValueRange<IteratorTypeAttr, IteratorType>();
829  // Convert IteratorType enums into their string representation. This is
830  // needed because tests still use the old format in which the 'iterator_types'
831  // attribute is represented as an array of strings.
832  // TODO: Remove this conversion once tests are fixed.
833  SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
834  llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
835  return StringAttr::get(getContext(), stringifyIteratorType(t));
836  }));
837 
838  attrs.emplace_back(getIteratorTypesAttrName(),
839  ArrayAttr::get(getContext(), iteratorTypeNames));
840  } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
841  attrs.push_back(attr);
842  }
843 
844  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
845  p << " " << dictAttr << " " << getLhs() << ", ";
846  p << getRhs() << ", " << getAcc();
847 
848  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
849  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
850  << getResultType();
851 }
852 
853 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
854  const std::vector<std::pair<int64_t, int64_t>> &map) {
855  for (auto &dimPair : map) {
856  if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
857  dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
858  lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
859  return false;
860  }
861  return true;
862 }
863 
864 static LogicalResult verifyOutputShape(
865  ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
866  Type resType,
867  const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
868  const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
869  DenseSet<int64_t> lhsContractingDimSet;
870  DenseSet<int64_t> rhsContractingDimSet;
871  for (auto &dimPair : contractingDimMap) {
872  lhsContractingDimSet.insert(dimPair.first);
873  rhsContractingDimSet.insert(dimPair.second);
874  }
875  DenseSet<int64_t> rhsBatchDimSet;
876  for (auto &dimPair : batchDimMap)
877  rhsBatchDimSet.insert(dimPair.second);
878 
879  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
880  SmallVector<int64_t, 4> expectedResultDims;
881  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
882  if (lhsContractingDimSet.count(i) > 0)
883  continue;
884  expectedResultDims.push_back(lhsType.getDimSize(i));
885  }
886 
887  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
888  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
889  if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
890  continue;
891  expectedResultDims.push_back(rhsType.getDimSize(i));
892  }
893 
894  // Verify 'expectedResultDims'.
895  if (expectedResultDims.empty()) {
896  // No batch or free dimension implies a scalar result.
897  if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
898  return op.emitOpError("invalid accumulator/result vector shape");
899  } else {
900  // At least one batch or free dimension implies a vector result.
901  auto resVectorType = llvm::dyn_cast<VectorType>(resType);
902  auto accVectorType = llvm::dyn_cast<VectorType>(accType);
903  if (!resVectorType || !accVectorType)
904  return op.emitOpError("invalid accumulator/result vector shape");
905 
906  // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
907  // types fully define the result vector type. This assumes the affine maps
908  // are well-formed, which must have been verified already.
909  MLIRContext *ctx = op.getContext();
910  AffineMap lhsMap = op.getIndexingMapsArray()[0];
911  AffineMap rhsMap = op.getIndexingMapsArray()[1];
912  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
913  return op.emitOpError(
914  "expected all dimensions to be either a LHS or a RHS dimension");
915  SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
916  for (auto pair :
917  {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
918  VectorType v = pair.first;
919  auto map = pair.second;
920  for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
921  unsigned pos = map.getDimPosition(idx);
922  if (!extents[pos])
923  extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
924  }
925  }
926  if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
927  return op.emitOpError("expected all dimensions to get an extent as "
928  "either a LHS or a RHS dimension");
929 
930  AffineMap resMap = op.getIndexingMapsArray()[2];
931  auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
932  /*symbolCount=*/0, extents, ctx);
933  // Compose the resMap with the extentsMap, which is a constant map.
934  AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
935  assert(llvm::all_of(expectedMap.getResults(),
936  llvm::IsaPred<AffineConstantExpr>) &&
937  "expected constant extent along all dimensions.");
938  // Extract the expected shape and build the type.
939  auto expectedShape = llvm::to_vector<4>(
940  llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
941  return cast<AffineConstantExpr>(e).getValue();
942  }));
943  auto expected =
944  VectorType::get(expectedShape, resVectorType.getElementType(),
945  resVectorType.getScalableDims());
946  if (resVectorType != expected || accVectorType != expected)
947  return op.emitOpError(
948  "invalid accumulator/result vector shape, expected: ")
949  << expected;
950  }
951  return success();
952 }
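
// In other words, the expected result shape is the concatenation of the
// non-contracted lhs dimensions (batch + free) with the non-contracted,
// non-batch rhs dimensions, in operand order. A simplified standalone sketch
// (hypothetical helper, static shapes only):
[[maybe_unused]] static SmallVector<int64_t>
expectedContractionResultShape(ArrayRef<int64_t> lhsShape,
                               ArrayRef<int64_t> rhsShape,
                               ArrayRef<int64_t> lhsContractingDims,
                               ArrayRef<int64_t> rhsContractingDims,
                               ArrayRef<int64_t> rhsBatchDims) {
  SmallVector<int64_t> result;
  for (int64_t i = 0, e = lhsShape.size(); i < e; ++i)
    if (!llvm::is_contained(lhsContractingDims, i))
      result.push_back(lhsShape[i]); // lhs batch + free dims.
  for (int64_t i = 0, e = rhsShape.size(); i < e; ++i)
    if (!llvm::is_contained(rhsContractingDims, i) &&
        !llvm::is_contained(rhsBatchDims, i))
      result.push_back(rhsShape[i]); // rhs free dims.
  return result;
  // E.g. a plain matmul: lhsShape = {M, K}, rhsShape = {K, N}, contracting
  // dims lhs {1} / rhs {0}, no batch dims -> {M, N}.
}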
953 
954 LogicalResult ContractionOp::verify() {
955  VectorType lhsType = getLhsType();
956  VectorType rhsType = getRhsType();
957  Type accType = getAccType();
958  Type resType = getResultType();
959 
960  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
961  if (!lhsType.getElementType().isSignlessInteger())
962  return emitOpError("only supports signless integer types");
963  }
964 
965  // Verify that an indexing map was specified for each vector operand.
966  if (getIndexingMapsArray().size() != 3)
967  return emitOpError("expected an indexing map for each vector operand");
968 
969  // Verify that each index map has 'numIterators' inputs, no symbols, and
970  // that the number of map outputs equals the rank of its associated
971  // vector operand.
972  unsigned numIterators = getIteratorTypes().getValue().size();
973  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
974  auto index = it.index();
975  auto map = it.value();
976  if (map.getNumSymbols() != 0)
977  return emitOpError("expected indexing map ")
978  << index << " to have no symbols";
979  auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
980  unsigned rank = vectorType ? vectorType.getShape().size() : 0;
981  // Verify that the map has the right number of inputs, outputs, and indices.
982  // This also correctly accounts for (..) -> () for rank-0 results.
983  if (map.getNumDims() != numIterators)
984  return emitOpError("expected indexing map ")
985  << index << " to have " << numIterators << " number of inputs";
986  if (map.getNumResults() != rank)
987  return emitOpError("expected indexing map ")
988  << index << " to have " << rank << " number of outputs";
989  if (!map.isProjectedPermutation())
990  return emitOpError("expected indexing map ")
991  << index << " to be a projected permutation of its inputs";
992  }
993 
994  auto contractingDimMap = getContractingDimMap();
995  auto batchDimMap = getBatchDimMap();
996 
997  // Verify at least one contracting dimension pair was specified.
998  if (contractingDimMap.empty())
999  return emitOpError("expected at least one contracting dimension pair");
1000 
1001  // Verify contracting dimension map was properly constructed.
1002  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1003  return emitOpError("invalid contracting dimension map");
1004 
1005  // Verify batch dimension map was properly constructed.
1006  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1007  return emitOpError("invalid batch dimension map");
1008 
1009  // Verify 'accType' and 'resType' shape.
1010  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1011  contractingDimMap, batchDimMap)))
1012  return failure();
1013 
1014  // Verify supported combining kind.
1015  auto vectorType = llvm::dyn_cast<VectorType>(resType);
1016  auto elementType = vectorType ? vectorType.getElementType() : resType;
1017  if (!isSupportedCombiningKind(getKind(), elementType))
1018  return emitOpError("unsupported contraction type");
1019 
1020  return success();
1021 }
1022 
1023 // MaskableOpInterface methods.
1024 
1025 /// Returns the mask type expected by this operation. Mostly used for
1026 /// verification purposes. It requires the operation to be vectorized.
1027 Type ContractionOp::getExpectedMaskType() {
1028  auto indexingMaps = this->getIndexingMapsArray();
1029  AffineMap lhsIdxMap = indexingMaps[0];
1030  AffineMap rhsIdxMap = indexingMaps[1];
1031  VectorType lhsType = this->getLhsType();
1032  VectorType rhsType = this->getRhsType();
1033 
1034  unsigned numVecDims = lhsIdxMap.getNumDims();
1035  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1036  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1037 
1038  // Using the information in the indexing maps, extract the size of each
1039  // dimension in the vector.contract operation from the two input operands.
1040  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1041  maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1042  maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1043  lhsType.getScalableDims()[dimIdx];
1044  }
1045  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1046  maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1047  maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1048  rhsType.getScalableDims()[dimIdx];
1049  }
1050 
1051  assert(!ShapedType::isDynamicShape(maskShape) &&
1052  "Mask shape couldn't be computed");
1053 
1054  return VectorType::get(maskShape,
1055  IntegerType::get(lhsType.getContext(), /*width=*/1),
1056  maskShapeScalableDims);
1057 }
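
// Put differently, the mask has one i1 dimension per iteration-space
// dimension, sized from whichever operand provides that dimension. A
// standalone sketch of the shape computation (hypothetical helper;
// `lhsDimPositions[i]` / `rhsDimPositions[i]` stand in for `getDimPosition(i)`
// of the lhs / rhs indexing maps):
[[maybe_unused]] static SmallVector<int64_t>
expectedMaskShape(int64_t numIterationDims, ArrayRef<int64_t> lhsShape,
                  ArrayRef<int64_t> lhsDimPositions, ArrayRef<int64_t> rhsShape,
                  ArrayRef<int64_t> rhsDimPositions) {
  SmallVector<int64_t> maskShape(numIterationDims, ShapedType::kDynamic);
  for (auto [idx, size] : llvm::enumerate(lhsShape))
    maskShape[lhsDimPositions[idx]] = size;
  for (auto [idx, size] : llvm::enumerate(rhsShape))
    maskShape[rhsDimPositions[idx]] = size;
  return maskShape; // Every entry must be covered by the lhs or the rhs.
  // E.g. matmul over (i, j, k): lhs {M, K} maps to positions {0, 2}, rhs
  // {K, N} maps to positions {2, 1}, so the mask type is vector<MxNxKxi1>.
}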
1058 
1059 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1060  return SmallVector<StringRef>{getIndexingMapsAttrName(),
1061  getIteratorTypesAttrName(), getKindAttrName()};
1062 }
1063 
1064 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1065  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1066  if (targetExpr == map.getResult(i))
1067  return i;
1068  return -1;
1069 }
1070 
1071 static std::vector<std::pair<int64_t, int64_t>>
1072 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1073  IteratorType targetIteratorType, MLIRContext *context) {
1074  std::vector<std::pair<int64_t, int64_t>> dimMap;
1075  for (const auto &it : llvm::enumerate(iteratorTypes)) {
1076  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1077  if (iteratorType != targetIteratorType)
1078  continue;
1079  // Search lhs/rhs map results for 'targetExpr'.
1080  auto targetExpr = getAffineDimExpr(it.index(), context);
1081  int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1082  int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1083  if (lhsDim >= 0 && rhsDim >= 0)
1084  dimMap.emplace_back(lhsDim, rhsDim);
1085  }
1086  return dimMap;
1087 }
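
// A standalone sketch of what the dimension map contains (hypothetical
// helper): pair up the lhs/rhs result positions of every iteration dimension
// that matches the target iterator kind and appears in both operands.
// `lhsDimToResult[d]` / `rhsDimToResult[d]` give the result index of iteration
// dimension `d` in the lhs / rhs map, or -1 if it does not appear.
[[maybe_unused]] static SmallVector<std::pair<int64_t, int64_t>>
pairSharedDims(ArrayRef<int64_t> lhsDimToResult,
               ArrayRef<int64_t> rhsDimToResult,
               ArrayRef<bool> isTargetIterator) {
  SmallVector<std::pair<int64_t, int64_t>> dimMap;
  for (int64_t d = 0, e = isTargetIterator.size(); d < e; ++d)
    if (isTargetIterator[d] && lhsDimToResult[d] >= 0 && rhsDimToResult[d] >= 0)
      dimMap.emplace_back(lhsDimToResult[d], rhsDimToResult[d]);
  return dimMap;
  // E.g. matmul over (i, j, k): lhsDimToResult = {0, -1, 1}, rhsDimToResult =
  // {-1, 1, 0}, reduction iterators = {false, false, true} -> {{1, 0}}, i.e.
  // k is lhs dim 1 and rhs dim 0.
}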
1088 
1089 void ContractionOp::getIterationBounds(
1090  SmallVectorImpl<int64_t> &iterationBounds) {
1091  auto lhsShape = getLhsType().getShape();
1092  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1093  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1094  SmallVector<int64_t, 2> iterationShape;
1095  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1096  // Search lhs/rhs map results for 'targetExpr'.
1097  auto targetExpr = getAffineDimExpr(it.index(), getContext());
1098  auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1099  if (iteratorType == IteratorType::reduction) {
1100  // Get reduction dim size from lhs shape (same size in rhsShape).
1101  int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1102  assert(lhsDimIndex >= 0);
1103  iterationBounds.push_back(lhsShape[lhsDimIndex]);
1104  continue;
1105  }
1106  // Get parallel dimension size from result shape.
1107  int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1108  assert(resDimIndex >= 0);
1109  assert(resVectorType != nullptr);
1110  iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1111  }
1112 }
1113 
1114 void ContractionOp::getIterationIndexMap(
1115  std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1116  unsigned numMaps = getIndexingMapsArray().size();
1117  iterationIndexMap.resize(numMaps);
1118  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1119  auto index = it.index();
1120  auto map = it.value();
1121  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1122  auto dim = cast<AffineDimExpr>(map.getResult(i));
1123  iterationIndexMap[index][dim.getPosition()] = i;
1124  }
1125  }
1126 }
1127 
1128 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1129  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1130  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1131  getContext());
1132 }
1133 
1134 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1135  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1136  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1137  getContext());
1138 }
1139 
1140 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1141  SmallVector<int64_t, 4> shape;
1142  getIterationBounds(shape);
1143  return shape;
1144 }
1145 
1146 /// Return a fused vector::ContractionOp which represents a pattern such as:
1147 ///
1148 /// ```mlir
1149 /// %c0 = vector.constant 0: ...
1150 /// %c = vector.contract %a, %b, %c0: ...
1151 /// %e = add %c, %d: ...
1152 /// ```
1153 ///
1154 /// by:
1155 ///
1156 /// ```mlir
1157 /// %e = vector.contract %a, %b, %d: ...
1158 /// ```
1159 ///
1160 /// Return null if the canonicalization does not apply.
1161 // TODO: This should be a folding of Add into Contract in core but while they
1162 // live in different dialects, it is not possible without unnatural
1163 // dependencies.
1164 template <typename AddOpType>
1165 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1166  using OpRewritePattern<AddOpType>::OpRewritePattern;
1167 
1168  LogicalResult matchAndRewrite(AddOpType addOp,
1169  PatternRewriter &rewriter) const override {
1170  auto canonicalize = [&](Value maybeContraction,
1171  Value otherOperand) -> vector::ContractionOp {
1172  vector::ContractionOp contractionOp =
1173  dyn_cast_or_null<vector::ContractionOp>(
1174  maybeContraction.getDefiningOp());
1175  if (!contractionOp)
1176  return vector::ContractionOp();
1177  if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1178  contractionOp.getAcc().getDefiningOp())) {
1179  if (maybeZero.getValue() ==
1180  rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1181  IRMapping bvm;
1182  bvm.map(contractionOp.getAcc(), otherOperand);
1183  auto newContraction =
1184  cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1185  rewriter.replaceOp(addOp, newContraction.getResult());
1186  return newContraction;
1187  }
1188  }
1189  return vector::ContractionOp();
1190  };
1191 
1192  Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1193  vector::ContractionOp contract = canonicalize(a, b);
1194  contract = contract ? contract : canonicalize(b, a);
1195  return contract ? success() : failure();
1196  }
1197 };
1198 
1199 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1200  MLIRContext *context) {
1201  results.add<CanonicalizeContractAdd<arith::AddIOp>,
1202  CanonicalizeContractAdd<arith::AddFOp>>(context);
1203 }
1204 
1205 //===----------------------------------------------------------------------===//
1206 // ExtractElementOp
1207 //===----------------------------------------------------------------------===//
1208 
1209 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1210  Value source) {
1211  result.addOperands({source});
1212  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1213 }
1214 
1215 LogicalResult vector::ExtractElementOp::verify() {
1216  VectorType vectorType = getSourceVectorType();
1217  if (vectorType.getRank() == 0) {
1218  if (getPosition())
1219  return emitOpError("expected position to be empty with 0-D vector");
1220  return success();
1221  }
1222  if (vectorType.getRank() != 1)
1223  return emitOpError("unexpected >1 vector rank");
1224  if (!getPosition())
1225  return emitOpError("expected position for 1-D vector");
1226  return success();
1227 }
1228 
1229 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1230  // Skip the 0-D vector case for now.
1231  if (!adaptor.getPosition())
1232  return {};
1233 
1234  // Fold extractelement (splat X) -> X.
1235  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1236  return splat.getInput();
1237 
1238  // Fold extractelement(broadcast(X)) -> X.
1239  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1240  if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1241  return broadcast.getSource();
1242 
1243  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1244  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1245  if (!pos || !src)
1246  return {};
1247 
1248  auto srcElements = src.getValues<Attribute>();
1249 
1250  uint64_t posIdx = pos.getInt();
1251  if (posIdx >= srcElements.size())
1252  return {};
1253 
1254  return srcElements[posIdx];
1255 }
1256 
1257 //===----------------------------------------------------------------------===//
1258 // ExtractOp
1259 //===----------------------------------------------------------------------===//
1260 
1261 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1262  Value source, int64_t position) {
1263  build(builder, result, source, ArrayRef<int64_t>{position});
1264 }
1265 
1266 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1267  Value source, OpFoldResult position) {
1268  build(builder, result, source, ArrayRef<OpFoldResult>{position});
1269 }
1270 
1271 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1272  Value source, ArrayRef<int64_t> position) {
1273  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1274  builder.getDenseI64ArrayAttr(position));
1275 }
1276 
1277 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1278  Value source, ArrayRef<OpFoldResult> position) {
1279  SmallVector<int64_t> staticPos;
1280  SmallVector<Value> dynamicPos;
1281  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1282  build(builder, result, source, dynamicPos,
1283  builder.getDenseI64ArrayAttr(staticPos));
1284 }
1285 
1286 LogicalResult
1287 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1288  ExtractOp::Adaptor adaptor,
1289  SmallVectorImpl<Type> &inferredReturnTypes) {
1290  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1291  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1292  vectorType.getRank()) {
1293  inferredReturnTypes.push_back(vectorType.getElementType());
1294  } else {
1295  auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1296  vectorType.getRank());
1297  inferredReturnTypes.push_back(VectorType::get(
1298  vectorType.getShape().drop_front(n), vectorType.getElementType(),
1299  vectorType.getScalableDims().drop_front(n)));
1300  }
1301  return success();
1302 }
1303 
1304 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1305  // Allow extracting 1-element vectors instead of scalars.
1306  auto isCompatible = [](TypeRange l, TypeRange r) {
1307  auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1308  return vectorType && vectorType.getShape().equals({1}) &&
1309  vectorType.getElementType() == r.front();
1310  };
1311  if (l.size() == 1 && r.size() == 1 &&
1312  (isCompatible(l, r) || isCompatible(r, l)))
1313  return true;
1314  return l == r;
1315 }
1316 
1317 LogicalResult vector::ExtractOp::verify() {
1318  // Note: This check must come before getMixedPosition() to prevent a crash.
1319  auto dynamicMarkersCount =
1320  llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1321  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1322  return emitOpError(
1323  "mismatch between dynamic and static positions (kDynamic marker but no "
1324  "corresponding dynamic position) -- this can only happen due to an "
1325  "incorrect fold/rewrite");
1326  auto position = getMixedPosition();
1327  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1328  return emitOpError(
1329  "expected position attribute of rank no greater than vector rank");
1330  for (auto [idx, pos] : llvm::enumerate(position)) {
1331  if (pos.is<Attribute>()) {
1332  int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
1333  if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1334  return emitOpError("expected position attribute #")
1335  << (idx + 1)
1336  << " to be a non-negative integer smaller than the "
1337  "corresponding vector dimension";
1338  }
1339  }
1340  }
1341  return success();
1342 }
1343 
1344 template <typename IntType>
1345 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1346  return llvm::to_vector<4>(llvm::map_range(
1347  arrayAttr.getAsRange<IntegerAttr>(),
1348  [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1349 }
1350 
1351 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1352 /// positions.
1353 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1354  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1355  return failure();
1356 
1357  // TODO: Canonicalization for dynamic position not implemented yet.
1358  if (extractOp.hasDynamicPosition())
1359  return failure();
1360 
1361  SmallVector<int64_t> globalPosition;
1362  ExtractOp currentOp = extractOp;
1363  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1364  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1365  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1366  currentOp = nextOp;
1367  // TODO: Canonicalization for dynamic position not implemented yet.
1368  if (currentOp.hasDynamicPosition())
1369  return failure();
1370  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1371  globalPosition.append(extrPos.rbegin(), extrPos.rend());
1372  }
1373  extractOp.setOperand(0, currentOp.getVector());
1374  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1375  OpBuilder b(extractOp.getContext());
1376  std::reverse(globalPosition.begin(), globalPosition.end());
1377  extractOp.setStaticPosition(globalPosition);
1378  return success();
1379 }
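
// The net effect is that extract(extract(%v, p1), p2) becomes
// extract(%v, p1 ++ p2). The loop above gathers the positions innermost-first
// in reverse and reverses once at the end, which is equivalent to this
// standalone sketch (hypothetical helper):
[[maybe_unused]] static SmallVector<int64_t>
concatExtractPositions(ArrayRef<int64_t> producerPos,
                       ArrayRef<int64_t> consumerPos) {
  SmallVector<int64_t> folded(producerPos.begin(), producerPos.end());
  folded.append(consumerPos.begin(), consumerPos.end());
  return folded;
  // E.g. %a = vector.extract %v[1, 2] followed by vector.extract %a[3] folds
  // to vector.extract %v[1, 2, 3].
}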
1380 
1381 namespace {
1382 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1383 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1384 /// Compose TransposeOp permutations as we walk back.
1385 /// This helper class keeps an updated extraction position `extractPosition`
1386 /// with extra trailing sentinels.
1387 /// The sentinels encode the internal transposition status of the result vector.
1388 /// As we iterate, extractPosition is permuted and updated.
1389 class ExtractFromInsertTransposeChainState {
1390 public:
1391  ExtractFromInsertTransposeChainState(ExtractOp e);
1392 
1393  /// Iterate over producing insert and transpose ops until we find a fold.
1394  Value fold();
1395 
1396 private:
1397  /// Return true if the vector at position `a` is contained within the vector
1398  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1399  /// is a prefix of `b`.
1400  template <typename ContainerA, typename ContainerB>
1401  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1402  return a.size() <= b.size() &&
1403  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1404  }
1405 
1406  /// Return true if the vector at position `a` intersects the vector at
1407  /// position `b`. Under insert/extract semantics, this is the same as equality
1408  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1409  /// Comparison is on the common prefix (i.e. zip).
1410  template <typename ContainerA, typename ContainerB>
1411  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1412  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1413  if (elemA < 0 || elemB < 0)
1414  continue;
1415  if (elemA != elemB)
1416  return false;
1417  }
1418  return true;
1419  }
1420 
1421  /// Folding is only possible in the absence of an internal permutation in the
1422  /// result vector.
1423  bool canFold() {
1424  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1425  }
1426 
1427  // Helper to get the next defining op of interest.
1428  void updateStateForNextIteration(Value v) {
1429  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1430  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1431  };
1432 
1433  // Case 1. If we hit a transpose, just compose the map and iterate.
1434  // Invariant: insert + transpose do not change rank, we can always compose.
1435  LogicalResult handleTransposeOp();
1436 
1437  // Case 2: the insert position matches extractPosition exactly, early return.
1438  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1439 
1440  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1441  /// portion of the source of the insert.
1442  /// Example:
1443  /// ```
1444  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1445  /// // extractPosition == [1, 2, 3]
1446  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1447  /// // can fold to vector.extract %source[0, 3]
1448  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1449  /// ```
1450  /// To traverse through %source, we need to set the leading dims to 0 and
1451  /// drop the extra leading dims.
1452  /// This method updates the internal state.
1453  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1454 
1455  /// Try to fold in place to extract(source, extractPosition) and return the
1456  /// folded result. Return null if folding is not possible (e.g. due to an
1457  /// internal transposition in the result).
1458  Value tryToFoldExtractOpInPlace(Value source);
1459 
1460  ExtractOp extractOp;
1461  int64_t vectorRank;
1462  int64_t extractedRank;
1463 
1464  InsertOp nextInsertOp;
1465  TransposeOp nextTransposeOp;
1466 
1467  /// Sentinel values that encode the internal permutation status of the result.
1468  /// They are set to (-1, ... , -k) at the beginning and appended to
1469  /// `extractPosition`.
1470  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1471  /// ensure that there is no internal transposition.
1472  /// Internal transposition cannot be accounted for with a folding pattern.
1473  // TODO: We could relax the internal transposition with an extra transposition
1474  // operation in a future canonicalizer.
1475  SmallVector<int64_t> sentinels;
1476  SmallVector<int64_t> extractPosition;
1477 };
1478 } // namespace
1479 
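/// Illustrative example: for an extract at position [1, 2] from a rank-4
/// vector, the constructor below sets extractedRank = 2, sentinels = [-1, -2]
/// and extractPosition = [1, 2, -1, -2].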
1480 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1481  ExtractOp e)
1482  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1483  extractedRank(extractOp.getNumIndices()) {
1484  assert(vectorRank >= extractedRank && "Extracted position overflow");
1485  sentinels.reserve(vectorRank - extractedRank);
1486  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1487  sentinels.push_back(-(i + 1));
1488  extractPosition.assign(extractOp.getStaticPosition().begin(),
1489  extractOp.getStaticPosition().end());
1490  llvm::append_range(extractPosition, sentinels);
1491 }
1492 
1493 // Case 1. If we hit a transpose, just compose the map and iterate.
1494 // Invariant: insert + transpose do not change rank, we can always compose.
1495 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1496  // TODO: Canonicalization for dynamic position not implemented yet.
1497  if (extractOp.hasDynamicPosition())
1498  return failure();
1499 
1500  if (!nextTransposeOp)
1501  return failure();
1502  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1503  nextTransposeOp.getPermutation(), extractOp.getContext()));
1504  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1505  return success();
1506 }
1507 
1508 // Case 2: the insert position matches extractPosition exactly, early return.
1509 LogicalResult
1510 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1511  Value &res) {
1512  // TODO: Canonicalization for dynamic position not implemented yet.
1513  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1514  return failure();
1515 
1516  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1517  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1518  return failure();
1519  // Case 2.a. early-exit fold.
1520  res = nextInsertOp.getSource();
1521  // Case 2.b. if internal transposition is present, canFold will be false.
1522  return success(canFold());
1523 }
1524 
1525 /// Case 3: if inserted position is a prefix of extractPosition,
1526 /// extract a portion of the source of the insertion.
1527 /// This method updates the internal state.
1528 LogicalResult
1529 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1530  // TODO: Canonicalization for dynamic position not implemented yet.
1531  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1532  return failure();
1533 
1534  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1535  if (!isContainedWithin(insertedPos, extractPosition))
1536  return failure();
1537  // Set leading dims to zero.
1538  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1539  // Drop extra leading dims.
1540  extractPosition.erase(extractPosition.begin(),
1541  extractPosition.begin() + insertedPos.size());
1542  extractedRank = extractPosition.size() - sentinels.size();
1543  // Case 3.a. early-exit fold (break and delegate to post-while path).
1544  res = nextInsertOp.getSource();
1545  // Case 3.b. if internal transposition is present, canFold will be false.
1546  return success();
1547 }
1548 
1549 /// Try to fold in place to extract(source, extractPosition) and return the
1550 /// folded result. Return null if folding is not possible (e.g. due to an
1551 /// internal transposition in the result).
1552 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1553  Value source) {
1554  // TODO: Canonicalization for dynamic position not implemented yet.
1555  if (extractOp.hasDynamicPosition())
1556  return Value();
1557 
1558  // If we can't fold (either internal transposition, or nothing to fold), bail.
1559  bool nothingToFold = (source == extractOp.getVector());
1560  if (nothingToFold || !canFold())
1561  return Value();
1562 
1563  // Otherwise, fold by updating the op inplace and return its result.
1564  OpBuilder b(extractOp.getContext());
1565  extractOp.setStaticPosition(
1566  ArrayRef(extractPosition).take_front(extractedRank));
1567  extractOp.getVectorMutable().assign(source);
1568  return extractOp.getResult();
1569 }
1570 
1571 /// Iterate over producing insert and transpose ops until we find a fold.
1572 Value ExtractFromInsertTransposeChainState::fold() {
1573  // TODO: Canonicalization for dynamic position not implemented yet.
1574  if (extractOp.hasDynamicPosition())
1575  return Value();
1576 
1577  Value valueToExtractFrom = extractOp.getVector();
1578  updateStateForNextIteration(valueToExtractFrom);
1579  while (nextInsertOp || nextTransposeOp) {
1580  // Case 1. If we hit a transpose, just compose the map and iterate.
1581  // Invariant: insert + transpose do not change rank, we can always compose.
1582  if (succeeded(handleTransposeOp())) {
1583  valueToExtractFrom = nextTransposeOp.getVector();
1584  updateStateForNextIteration(valueToExtractFrom);
1585  continue;
1586  }
1587 
1588  Value result;
1589  // Case 2: the insert position matches extractPosition exactly.
1590  if (succeeded(handleInsertOpWithMatchingPos(result)))
1591  return result;
1592 
1593  // Case 3: if the inserted position is a prefix of extractPosition, we can
1594  // just extract a portion of the source of the insert.
1595  if (succeeded(handleInsertOpWithPrefixPos(result)))
1596  return tryToFoldExtractOpInPlace(result);
1597 
1598  // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1599  // values. This is a more difficult case and we bail.
1600  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1601  if (isContainedWithin(extractPosition, insertedPos) ||
1602  intersectsWhereNonNegative(extractPosition, insertedPos))
1603  return Value();
1604 
1605  // Case 5: No intersection, we forward the extract to insertOp.dest().
1606  valueToExtractFrom = nextInsertOp.getDest();
1607  updateStateForNextIteration(valueToExtractFrom);
1608  }
1609  // If after all this we can fold, go for it.
1610  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1611 }
1612 
1613 /// Returns true if the operation has a 0-D vector type operand or result.
1614 static bool hasZeroDimVectors(Operation *op) {
1615  auto hasZeroDimVectorType = [](Type type) -> bool {
1616  auto vecType = dyn_cast<VectorType>(type);
1617  return vecType && vecType.getRank() == 0;
1618  };
1619 
1620  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1621  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1622 }
1623 
1624 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
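/// Illustrative example:
///   %b = vector.broadcast %src : f32 to vector<4xf32>
///   %e = vector.extract %b[1] : f32 from vector<4xf32>
/// folds to %src.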
1625 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1626  // TODO: Canonicalization for dynamic position not implemented yet.
1627  if (extractOp.hasDynamicPosition())
1628  return Value();
1629 
1630  Operation *defOp = extractOp.getVector().getDefiningOp();
1631  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1632  return Value();
1633 
1634  Value source = defOp->getOperand(0);
1635  if (extractOp.getType() == source.getType())
1636  return source;
1637  auto getRank = [](Type type) {
1638  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1639  : 0;
1640  };
1641 
1642  // If splat or broadcast from a scalar, just return the source scalar.
1643  unsigned broadcastSrcRank = getRank(source.getType());
1644  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1645  return source;
1646 
1647  unsigned extractResultRank = getRank(extractOp.getType());
1648  if (extractResultRank >= broadcastSrcRank)
1649  return Value();
1650  // Check that the dimensions of the result haven't been broadcasted.
1651  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1652  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1653  if (extractVecType && broadcastVecType &&
1654  extractVecType.getShape() !=
1655  broadcastVecType.getShape().take_back(extractResultRank))
1656  return Value();
1657 
1658  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1659  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1660 
1661  // Detect all the positions that come from "dim-1" broadcasting. These
1662  // dimensions correspond to "dim-1" broadcasted dims; set the matching
1663  // extract position to `0` when extracting from the source operand.
1664  llvm::SetVector<int64_t> broadcastedUnitDims =
1665  broadcastOp.computeBroadcastedUnitDims();
1666  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1667  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1668  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1669  if (broadcastedUnitDims.contains(i))
1670  extractPos[i] = 0;
1671  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1672  // matching extract position when extracting from the source operand.
1673  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1674  extractPos.erase(extractPos.begin(),
1675  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1676  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1677  OpBuilder b(extractOp.getContext());
1678  extractOp.setOperand(0, source);
1679  extractOp.setStaticPosition(extractPos);
1680  return extractOp.getResult();
1681 }
1682 
1683 // Fold extractOp with source coming from ShapeCast op.
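// Illustrative example:
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %e  = vector.extract %sc[5] : f32 from vector<8xf32>
// folds to vector.extract %src[1, 1] (the linearized position 5 is
// delinearized using the strides [4, 1] of the shape_cast source).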
1684 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1685  // TODO: Canonicalization for dynamic position not implemented yet.
1686  if (extractOp.hasDynamicPosition())
1687  return Value();
1688 
1689  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1690  if (!shapeCastOp)
1691  return Value();
1692 
1693  // 0-D vectors not supported.
1694  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1695  if (hasZeroDimVectors(shapeCastOp))
1696  return Value();
1697 
1698  // Get the nth dimension size starting from lowest dimension.
1699  auto getDimReverse = [](VectorType type, int64_t n) {
1700  return type.getShape().take_back(n + 1).front();
1701  };
1702  int64_t destinationRank =
1703  llvm::isa<VectorType>(extractOp.getType())
1704  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1705  : 0;
1706  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1707  return Value();
1708  if (destinationRank > 0) {
1709  auto destinationType =
1710  llvm::cast<VectorType>(extractOp.getResult().getType());
1711  for (int64_t i = 0; i < destinationRank; i++) {
1712  // The lowest dimension of the destination must match the lowest
1713  // dimension of the shapecast op source.
1714  // TODO: This case could be supported in a canonicalization pattern.
1715  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1716  getDimReverse(destinationType, i))
1717  return Value();
1718  }
1719  }
1720  // Extract the strides associated with the extract op vector source. Then use
1721  // this to calculate a linearized position for the extract.
1722  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1723  std::reverse(extractedPos.begin(), extractedPos.end());
1724  SmallVector<int64_t, 4> strides;
1725  int64_t stride = 1;
1726  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1727  strides.push_back(stride);
1728  stride *=
1729  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1730  }
1731 
1732  int64_t position = linearize(extractedPos, strides);
1733  // Then extract the strides associated to the shapeCast op vector source and
1734  // delinearize the position using those strides.
1735  SmallVector<int64_t, 4> newStrides;
1736  int64_t numDimension =
1737  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1738  stride = 1;
1739  for (int64_t i = 0; i < numDimension; i++) {
1740  newStrides.push_back(stride);
1741  stride *=
1742  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1743  }
1744  std::reverse(newStrides.begin(), newStrides.end());
1745  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1746  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1747  OpBuilder b(extractOp.getContext());
1748  extractOp.setStaticPosition(newPosition);
1749  extractOp.setOperand(0, shapeCastOp.getSource());
1750  return extractOp.getResult();
1751 }
1752 
1753 /// Fold an ExtractOp from ExtractStridedSliceOp.
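/// Illustrative example:
///   %s = vector.extract_strided_slice %v
///          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
///          : vector<4x8xf32> to vector<2x4xf32>
///   %e = vector.extract %s[0, 1] : f32 from vector<2x4xf32>
/// folds to vector.extract %v[1, 3] (the slice offsets are added to the
/// extract position).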
1754 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1755  // TODO: Canonicalization for dynamic position not implemented yet.
1756  if (extractOp.hasDynamicPosition())
1757  return Value();
1758 
1759  auto extractStridedSliceOp =
1760  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1761  if (!extractStridedSliceOp)
1762  return Value();
1763 
1764  // 0-D vectors not supported.
1765  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1766  if (hasZeroDimVectors(extractStridedSliceOp))
1767  return Value();
1768 
1769  // Return if 'extractStridedSliceOp' has non-unit strides.
1770  if (extractStridedSliceOp.hasNonUnitStrides())
1771  return Value();
1772 
1773  // Trim offsets for dimensions fully extracted.
1774  auto sliceOffsets =
1775  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1776  while (!sliceOffsets.empty()) {
1777  size_t lastOffset = sliceOffsets.size() - 1;
1778  if (sliceOffsets.back() != 0 ||
1779  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1780  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1781  break;
1782  sliceOffsets.pop_back();
1783  }
1784  unsigned destinationRank = 0;
1785  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1786  destinationRank = vecType.getRank();
1787  // The dimensions of the result need to be untouched by the
1788  // extractStridedSlice op.
1789  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1790  sliceOffsets.size())
1791  return Value();
1792 
1793  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1794  assert(extractedPos.size() >= sliceOffsets.size());
1795  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1796  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1797  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1798 
1799  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1800  OpBuilder b(extractOp.getContext());
1801  extractOp.setStaticPosition(extractedPos);
1802  return extractOp.getResult();
1803 }
1804 
1805 /// Fold extract_op fed from a chain of insertStridedSlice ops.
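/// Illustrative example:
///   %i = vector.insert_strided_slice %src, %dst
///          {offsets = [2, 0], strides = [1, 1]}
///          : vector<2x8xf32> into vector<4x8xf32>
///   %e = vector.extract %i[3] : vector<8xf32> from vector<4x8xf32>
/// folds to vector.extract %src[1] because row 3 lies entirely within the
/// inserted chunk; a disjoint extract is instead forwarded to %dst.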
1806 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1807  // TODO: Canonicalization for dynamic position not implemented yet.
1808  if (extractOp.hasDynamicPosition())
1809  return Value();
1810 
1811  int64_t destinationRank =
1812  llvm::isa<VectorType>(extractOp.getType())
1813  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1814  : 0;
1815  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1816  if (!insertOp)
1817  return Value();
1818 
1819  // 0-D vectors not supported.
1820  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1821  if (hasZeroDimVectors(insertOp))
1822  return Value();
1823 
1824  while (insertOp) {
1825  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1826  insertOp.getSourceVectorType().getRank();
1827  if (destinationRank > insertOp.getSourceVectorType().getRank())
1828  return Value();
1829  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1830  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1831 
1832  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1833  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1834  }))
1835  return Value();
1836  bool disjoint = false;
1837  SmallVector<int64_t, 4> offsetDiffs;
1838  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1839  int64_t start = insertOffsets[dim];
1840  int64_t size =
1841  (dim < insertRankDiff)
1842  ? 1
1843  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1844  int64_t end = start + size;
1845  int64_t offset = extractOffsets[dim];
1846  // Check if the start of the extract offset is in the interval inserted.
1847  if (start <= offset && offset < end) {
1848  if (dim >= insertRankDiff)
1849  offsetDiffs.push_back(offset - start);
1850  continue;
1851  }
1852  disjoint = true;
1853  break;
1854  }
1855  // The extracted element chunk overlaps with the inserted vector.
1856  if (!disjoint) {
1857  // If any of the inner dimensions are only partially inserted we have a
1858  // partial overlap.
1859  int64_t srcRankDiff =
1860  insertOp.getSourceVectorType().getRank() - destinationRank;
1861  for (int64_t i = 0; i < destinationRank; i++) {
1862  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1863  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1864  insertRankDiff))
1865  return Value();
1866  }
1867  extractOp.getVectorMutable().assign(insertOp.getSource());
1868  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1869  OpBuilder b(extractOp.getContext());
1870  extractOp.setStaticPosition(offsetDiffs);
1871  return extractOp.getResult();
1872  }
1873  // If the chunk extracted is disjoint from the chunk inserted, keep
1874  // looking in the insert chain.
1875  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1876  }
1877  return Value();
1878 }
1879 
1880 /// Try to fold the extraction of a scalar from a vector defined by
1881 /// vector.from_elements. E.g.:
1882 ///
1883 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1884 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1885 /// ==> fold to %a
1886 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1887  // Dynamic extractions cannot be folded.
1888  if (extractOp.hasDynamicPosition())
1889  return {};
1890 
1891  // Look for extract(from_elements).
1892  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1893  if (!fromElementsOp)
1894  return {};
1895 
1896  // Scalable vectors are not supported.
1897  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1898  if (vecType.isScalable())
1899  return {};
1900 
1901  // Only extractions of scalars are supported.
1902  int64_t rank = vecType.getRank();
1903  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1904  if (extractOp.getType() != vecType.getElementType())
1905  return {};
1906  assert(static_cast<int64_t>(indices.size()) == rank &&
1907  "unexpected number of indices");
1908 
1909  // Compute flattened/linearized index and fold to operand.
1910  int flatIndex = 0;
1911  int stride = 1;
1912  for (int i = rank - 1; i >= 0; --i) {
1913  flatIndex += indices[i] * stride;
1914  stride *= vecType.getDimSize(i);
1915  }
1916  return fromElementsOp.getElements()[flatIndex];
1917 }
1918 
1919 OpFoldResult ExtractOp::fold(FoldAdaptor) {
1920  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
1921  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
1922  // mismatch).
1923  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
1924  return getVector();
1925  if (succeeded(foldExtractOpFromExtractChain(*this)))
1926  return getResult();
1927  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1928  return res;
1929  if (auto res = foldExtractFromBroadcast(*this))
1930  return res;
1931  if (auto res = foldExtractFromShapeCast(*this))
1932  return res;
1933  if (auto val = foldExtractFromExtractStrided(*this))
1934  return val;
1935  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1936  return val;
1937  if (auto val = foldScalarExtractFromFromElements(*this))
1938  return val;
1939  return OpFoldResult();
1940 }
1941 
1942 namespace {
1943 
1944 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
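// Illustrative example:
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// rewrites to vector.broadcast %src : vector<4xf32> to vector<3x4xf32>.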
1945 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1946 public:
1947  using OpRewritePattern::OpRewritePattern;
1948 
1949  LogicalResult matchAndRewrite(ExtractOp extractOp,
1950  PatternRewriter &rewriter) const override {
1951  Operation *defOp = extractOp.getVector().getDefiningOp();
1952  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1953  return failure();
1954 
1955  Value source = defOp->getOperand(0);
1956  if (extractOp.getType() == source.getType())
1957  return failure();
1958  auto getRank = [](Type type) {
1959  return llvm::isa<VectorType>(type)
1960  ? llvm::cast<VectorType>(type).getRank()
1961  : 0;
1962  };
1963  unsigned broadcastSrcRank = getRank(source.getType());
1964  unsigned extractResultRank = getRank(extractOp.getType());
1965  // We only consider the case where the rank of the source is less than or
1966  // equal to the rank of the extract dst. The other cases are handled in the
1967  // folding patterns.
1968  if (extractResultRank < broadcastSrcRank)
1969  return failure();
1970 
1971  // Special case if broadcast src is a 0D vector.
1972  if (extractResultRank == 0) {
1973  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
1974  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
1975  return success();
1976  }
1977  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1978  extractOp, extractOp.getType(), source);
1979  return success();
1980  }
1981 };
1982 
1983 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
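// Illustrative example:
//   %c = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %c[2] : vector<8xf32> from vector<4x8xf32>
// rewrites to arith.constant dense<1.0> : vector<8xf32>.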
1984 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1985 public:
1986  using OpRewritePattern::OpRewritePattern;
1987 
1988  LogicalResult matchAndRewrite(ExtractOp extractOp,
1989  PatternRewriter &rewriter) const override {
1990  // Return if 'ExtractOp' operand is not defined by a splat vector
1991  // ConstantOp.
1992  Value sourceVector = extractOp.getVector();
1993  Attribute vectorCst;
1994  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1995  return failure();
1996  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1997  if (!splat)
1998  return failure();
1999  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2000  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2001  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2002  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2003  return success();
2004  }
2005 };
2006 
2007 // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2008 class ExtractOpNonSplatConstantFolder final
2009  : public OpRewritePattern<ExtractOp> {
2010 public:
2011  using OpRewritePattern::OpRewritePattern;
2012 
2013  LogicalResult matchAndRewrite(ExtractOp extractOp,
2014  PatternRewriter &rewriter) const override {
2015  // TODO: Canonicalization for dynamic position not implemented yet.
2016  if (extractOp.hasDynamicPosition())
2017  return failure();
2018 
2019  // Return if 'ExtractOp' operand is not defined by a compatible vector
2020  // ConstantOp.
2021  Value sourceVector = extractOp.getVector();
2022  Attribute vectorCst;
2023  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2024  return failure();
2025 
2026  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2027  if (vecTy.isScalable())
2028  return failure();
2029 
2030  // The splat case is handled by `ExtractOpSplatConstantFolder`.
2031  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2032  if (!dense || dense.isSplat())
2033  return failure();
2034 
2035  // Calculate the linearized position of the continuous chunk of elements to
2036  // extract.
2037  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2038  copy(extractOp.getStaticPosition(), completePositions.begin());
2039  int64_t elemBeginPosition =
2040  linearize(completePositions, computeStrides(vecTy.getShape()));
2041  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2042 
2043  TypedAttr newAttr;
2044  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2045  SmallVector<Attribute> elementValues(
2046  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2047  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2048  } else {
2049  newAttr = *denseValuesBegin;
2050  }
2051 
2052  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2053  return success();
2054  }
2055 };
2056 
2057 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
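// Illustrative example, assuming %c2 is an arith.constant 2 : index:
//   %m = vector.create_mask %c2, %d : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// rewrites to vector.create_mask %d : vector<8xi1>, since position 1 lies
// within the all-true region of the leading dimension.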
2058 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2059 public:
2060  using OpRewritePattern::OpRewritePattern;
2061 
2062  LogicalResult matchAndRewrite(ExtractOp extractOp,
2063  PatternRewriter &rewriter) const override {
2064  auto createMaskOp =
2065  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2066  if (!createMaskOp)
2067  return failure();
2068 
2069  VectorType extractedMaskType =
2070  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2071 
2072  if (!extractedMaskType)
2073  return failure();
2074 
2075  auto maskOperands = createMaskOp.getOperands();
2076  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2077  VectorType maskType = createMaskOp.getVectorType();
2078 
2079  bool containsUnknownDims = false;
2080  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2081 
2082  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2083  dimIdx++) {
2084  int64_t pos = extractOpPos[dimIdx];
2085  Value operand = maskOperands[dimIdx];
2086  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2087  if (!constantOp) {
2088  // Bounds of this dim unknown.
2089  containsUnknownDims = true;
2090  continue;
2091  }
2092 
2093  int64_t createMaskBound =
2094  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2095 
2096  if (pos != ShapedType::kDynamic) {
2097  // If any position is outside the range from the `create_mask`, then the
2098  // extracted mask will be all-false.
2099  allFalse |= pos >= createMaskBound;
2100  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2101  // This dim is not all-true and since this is a dynamic index we don't
2102  // know if the extraction is within the true or false region.
2103  // Note: Zero dims have already been handled via getMaskFormat().
2104  containsUnknownDims = true;
2105  }
2106  }
2107 
2108  if (allFalse) {
2109  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2110  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2111  } else if (!containsUnknownDims) {
2112  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2113  extractOp, extractedMaskType,
2114  maskOperands.drop_front(extractOpPos.size()));
2115  } else {
2116  return failure();
2117  }
2118  return success();
2119  }
2120 };
2121 
2122 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2123 // does not change.
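// Illustrative example:
//   %sc = vector.shape_cast %v : vector<2x2xf32> to vector<1x4xf32>
//   %e  = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
// rewrites to vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>.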
2124 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2125  PatternRewriter &rewriter) {
2126  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2127  if (!castOp)
2128  return failure();
2129 
2130  VectorType sourceType = castOp.getSourceVectorType();
2131  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2132  if (!targetType)
2133  return failure();
2134 
2135  if (sourceType.getNumElements() != targetType.getNumElements())
2136  return failure();
2137 
2138  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2139  castOp.getSource());
2140  return success();
2141 }
2142 
2143 /// Try to canonicalize the extraction of a subvector from a vector defined by
2144 /// vector.from_elements. E.g.:
2145 ///
2146 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2147 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2148 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2149 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2150  PatternRewriter &rewriter) {
2151  // Dynamic positions are not supported.
2152  if (extractOp.hasDynamicPosition())
2153  return failure();
2154 
2155  // Scalar extracts are handled by the folder.
2156  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2157  if (!resultType)
2158  return failure();
2159 
2160  // Look for extracts from a from_elements op.
2161  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2162  if (!fromElementsOp)
2163  return failure();
2164  VectorType inputType = fromElementsOp.getType();
2165 
2166  // Scalable vectors are not supported.
2167  if (resultType.isScalable() || inputType.isScalable())
2168  return failure();
2169 
2170  // Compute the position of first extracted element and flatten/linearize the
2171  // position.
2172  SmallVector<int64_t> firstElementPos =
2173  llvm::to_vector(extractOp.getStaticPosition());
2174  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2175  int flatIndex = 0;
2176  int stride = 1;
2177  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2178  flatIndex += firstElementPos[i] * stride;
2179  stride *= inputType.getDimSize(i);
2180  }
2181 
2182  // Replace the op with a smaller from_elements op.
2183  rewriter.replaceOpWithNewOp<FromElementsOp>(
2184  extractOp, resultType,
2185  fromElementsOp.getElements().slice(flatIndex,
2186  resultType.getNumElements()));
2187  return success();
2188 }
2189 } // namespace
2190 
2191 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2192  MLIRContext *context) {
2193  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2194  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2195  results.add(foldExtractFromShapeCastToShapeCast);
2196  results.add(foldExtractFromFromElements);
2197 }
2198 
2199 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2200  SmallVectorImpl<int64_t> &results) {
2201  for (auto attr : arrayAttr)
2202  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2203 }
2204 
2205 //===----------------------------------------------------------------------===//
2206 // FmaOp
2207 //===----------------------------------------------------------------------===//
2208 
2209 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2210  return llvm::to_vector<4>(getVectorType().getShape());
2211 }
2212 
2213 //===----------------------------------------------------------------------===//
2214 // FromElementsOp
2215 //===----------------------------------------------------------------------===//
2216 
2217 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2218 /// same SSA value. E.g.:
2219 ///
2220 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2221 /// ==> rewrite to vector.splat %a : vector<3xf32>
2222 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2223  PatternRewriter &rewriter) {
2224  if (!llvm::all_equal(fromElementsOp.getElements()))
2225  return failure();
2226  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2227  fromElementsOp.getElements().front());
2228  return success();
2229 }
2230 
2231 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2232  MLIRContext *context) {
2233  results.add(rewriteFromElementsAsSplat);
2234 }
2235 
2236 //===----------------------------------------------------------------------===//
2237 // BroadcastOp
2238 //===----------------------------------------------------------------------===//
2239 
2240 /// Return the dimensions of the result vector that were formerly ones in the
2241 /// source tensor and thus correspond to "dim-1" broadcasting.
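/// For example, computeBroadcastedUnitDims([1, 4, 1], [2, 3, 4, 5]) returns
/// {1, 3}: the two size-1 source dims are stretched to sizes 3 and 5.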
2242 static llvm::SetVector<int64_t>
2243 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2244  ArrayRef<int64_t> dstShape) {
2245  int64_t rankDiff = dstShape.size() - srcShape.size();
2246  int64_t dstDim = rankDiff;
2247  llvm::SetVector<int64_t> res;
2248  for (auto [s1, s2] :
2249  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2250  if (s1 != s2) {
2251  assert(s1 == 1 && "expected dim-1 broadcasting");
2252  res.insert(dstDim);
2253  }
2254  ++dstDim;
2255  }
2256  return res;
2257 }
2258 
2259 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2260  // A scalar broadcast involves no unit-dim broadcasting.
2261  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2262  if (!srcVectorType)
2263  return {};
2264  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2265  getResultVectorType().getShape());
2266 }
2267 
2268 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2269 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2270 /// This requires (and asserts) that the broadcast is free of dim-1
2271 /// broadcasting.
2272 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2273 /// vector.transpose may be inserted to make the broadcast possible.
2274 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2275 /// the helper will assert. This means:
2276 /// 1. `dstShape` must not be empty.
2277 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2278 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2279 ///    must match the `value` shape.
2280 Value BroadcastOp::createOrFoldBroadcastOp(
2281  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2282  const llvm::SetVector<int64_t> &broadcastedDims) {
2283  assert(!dstShape.empty() && "unexpected empty dst shape");
2284 
2285  // Well-formedness check.
2286  SmallVector<int64_t> checkShape;
2287  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2288  if (broadcastedDims.contains(i))
2289  continue;
2290  checkShape.push_back(dstShape[i]);
2291  }
2292  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2293  "ill-formed broadcastedDims contains values not confined to "
2294  "destVectorShape");
2295 
2296  Location loc = value.getLoc();
2297  Type elementType = getElementTypeOrSelf(value.getType());
2298  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2299  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2300 
2301  // Step 2. If scalar -> dstShape broadcast, just do it.
2302  if (!srcVectorType) {
2303  assert(checkShape.empty() &&
2304  "ill-formed createOrFoldBroadcastOp arguments");
2305  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2306  }
2307 
2308  assert(srcVectorType.getShape().equals(checkShape) &&
2309  "ill-formed createOrFoldBroadcastOp arguments");
2310 
2311  // Step 3. Since vector.broadcast only allows creating leading dims,
2312  // vector -> dstShape broadcast may require a transpose.
2313  // Traverse the dims in order and construct:
2314  // 1. The leading entries of the broadcastShape that is guaranteed to be
2315  // achievable by a simple broadcast.
2316  // 2. The induced permutation for the subsequent vector.transpose that will
2317  // bring us from `broadcastShape` back to the desired `dstShape`.
2318  // If the induced permutation is not the identity, create a vector.transpose.
2319  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2320  broadcastShape.reserve(dstShape.size());
2321  // Consider the example:
2322  // srcShape = 2x4
2323  // dstShape = 1x2x3x4x5
2324  // broadcastedDims = [0, 2, 4]
2325  //
2326  // We want to build:
2327  // broadcastShape = 1x3x5x2x4
2328  // permutation = [0, 2, 4, 1, 3]
2329  // ---V--- -----V-----
2330  // leading broadcast part src shape part
2331  //
2332  // Note that the trailing dims of broadcastShape are exactly the srcShape
2333  // by construction.
2334  // nextSrcShapeDim is used to keep track of where in the permutation the
2335  // "src shape part" occurs.
2336  int64_t nextSrcShapeDim = broadcastedDims.size();
2337  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2338  if (broadcastedDims.contains(i)) {
2339  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2340  // bring it to the head of the broadcastShape.
2341  // It will need to be permuted back from `broadcastShape.size() - 1` into
2342  // position `i`.
2343  broadcastShape.push_back(dstShape[i]);
2344  permutation[i] = broadcastShape.size() - 1;
2345  } else {
2346  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2347  // shape and needs to be permuted into position `i`.
2348  // Don't touch `broadcastShape` here, the whole srcShape will be
2349  // appended after.
2350  permutation[i] = nextSrcShapeDim++;
2351  }
2352  }
2353  // 3.c. Append the srcShape.
2354  llvm::append_range(broadcastShape, srcVectorType.getShape());
2355 
2356  // Ensure there are no dim-1 broadcasts.
2357  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2358  .empty() &&
2359  "unexpected dim-1 broadcast");
2360 
2361  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2362  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2363  vector::BroadcastableToResult::Success &&
2364  "must be broadcastable");
2365  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2366  // Step 4. If we find any dimension that indeed needs to be permuted,
2367  // immediately return a new vector.transpose.
2368  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2369  if (permutation[i] != i)
2370  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2371  // Otherwise return res.
2372  return res;
2373 }
2374 
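/// Illustrative examples: vector<1x4xf32> is broadcastable to
/// vector<3x2x4xf32> (the missing leading dim is duplicated and the size-1 dim
/// is stretched), while vector<3xf32> is not broadcastable to vector<2x4xf32>
/// (trailing dimension mismatch: 3 vs. 4).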
2375 BroadcastableToResult
2376 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
2377  std::pair<int, int> *mismatchingDims) {
2378  // Broadcast scalar to vector of the same element type.
2379  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2380  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2381  return BroadcastableToResult::Success;
2382  // From now on, only vectors broadcast.
2383  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2384  if (!srcVectorType)
2385  return BroadcastableToResult::SourceTypeNotAVector;
2386 
2387  int64_t srcRank = srcVectorType.getRank();
2388  int64_t dstRank = dstVectorType.getRank();
2389  if (srcRank > dstRank)
2390  return BroadcastableToResult::SourceRankHigher;
2391  // Source has an exact match or singleton value for all trailing dimensions
2392  // (all leading dimensions are simply duplicated).
2393  int64_t lead = dstRank - srcRank;
2394  for (int64_t r = 0; r < srcRank; ++r) {
2395  int64_t srcDim = srcVectorType.getDimSize(r);
2396  int64_t dstDim = dstVectorType.getDimSize(lead + r);
2397  if (srcDim != 1 && srcDim != dstDim) {
2398  if (mismatchingDims) {
2399  mismatchingDims->first = srcDim;
2400  mismatchingDims->second = dstDim;
2401  }
2402  return BroadcastableToResult::DimensionMismatch;
2403  }
2404  }
2405 
2406  return BroadcastableToResult::Success;
2407 }
2408 
2409 LogicalResult BroadcastOp::verify() {
2410  std::pair<int, int> mismatchingDims;
2411  BroadcastableToResult res = isBroadcastableTo(
2412  getSourceType(), getResultVectorType(), &mismatchingDims);
2413  if (res == BroadcastableToResult::Success)
2414  return success();
2415  if (res == BroadcastableToResult::SourceRankHigher)
2416  return emitOpError("source rank higher than destination rank");
2417  if (res == BroadcastableToResult::DimensionMismatch)
2418  return emitOpError("dimension mismatch (")
2419  << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
2420  if (res == BroadcastableToResult::SourceTypeNotAVector)
2421  return emitOpError("source type is not a vector");
2422  llvm_unreachable("unexpected vector.broadcast op error");
2423 }
2424 
2425 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2426  if (getSourceType() == getResultVectorType())
2427  return getSource();
2428  if (!adaptor.getSource())
2429  return {};
2430  auto vectorType = getResultVectorType();
2431  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2432  return DenseElementsAttr::get(vectorType, adaptor.getSource());
2433  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2434  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2435  return {};
2436 }
2437 
2438 namespace {
2439 
2440 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2441 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2442  using OpRewritePattern::OpRewritePattern;
2443 
2444  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2445  PatternRewriter &rewriter) const override {
2446  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2447  if (!srcBroadcast)
2448  return failure();
2449  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2450  broadcastOp.getResultVectorType(),
2451  srcBroadcast.getSource());
2452  return success();
2453  }
2454 };
2455 } // namespace
2456 
2457 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2458  MLIRContext *context) {
2459  // BroadcastToShapeCast is not a default canonicalization; it is opted into
2460  // by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2461  results.add<BroadcastFolder>(context);
2462 }
2463 
2464 //===----------------------------------------------------------------------===//
2465 // ShuffleOp
2466 //===----------------------------------------------------------------------===//
2467 
2468 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
2469  Value v2, ArrayRef<int64_t> mask) {
2470  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
2471 }
2472 
2473 LogicalResult ShuffleOp::verify() {
2474  VectorType resultType = getResultVectorType();
2475  VectorType v1Type = getV1VectorType();
2476  VectorType v2Type = getV2VectorType();
2477  // Verify ranks.
2478  int64_t resRank = resultType.getRank();
2479  int64_t v1Rank = v1Type.getRank();
2480  int64_t v2Rank = v2Type.getRank();
2481  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2482  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2483  if (!wellFormed0DCase && !wellFormedNDCase)
2484  return emitOpError("rank mismatch");
2485 
2486  // Verify all but leading dimension sizes.
2487  for (int64_t r = 1; r < v1Rank; ++r) {
2488  int64_t resDim = resultType.getDimSize(r);
2489  int64_t v1Dim = v1Type.getDimSize(r);
2490  int64_t v2Dim = v2Type.getDimSize(r);
2491  if (resDim != v1Dim || v1Dim != v2Dim)
2492  return emitOpError("dimension mismatch");
2493  }
2494  // Verify mask length.
2495  auto maskAttr = getMask().getValue();
2496  int64_t maskLength = maskAttr.size();
2497  if (maskLength <= 0)
2498  return emitOpError("invalid mask length");
2499  if (maskLength != resultType.getDimSize(0))
2500  return emitOpError("mask length mismatch");
2501  // Verify all indices.
2502  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2503  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2504  for (const auto &en : llvm::enumerate(maskAttr)) {
2505  auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2506  if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2507  return emitOpError("mask index #") << (en.index() + 1) << " out of range";
2508  }
2509  return success();
2510 }
2511 
2512 LogicalResult
2513 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2514  ShuffleOp::Adaptor adaptor,
2515  SmallVectorImpl<Type> &inferredReturnTypes) {
2516  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2517  auto v1Rank = v1Type.getRank();
2518  // Construct resulting type: leading dimension matches mask
2519  // length, all trailing dimensions match the operands.
2520  SmallVector<int64_t> shape;
2521  shape.reserve(v1Rank);
2522  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2523  // In the 0-D case there is no trailing shape to append.
2524  if (v1Rank > 0)
2525  llvm::append_range(shape, v1Type.getShape().drop_front());
2526  inferredReturnTypes.push_back(
2527  VectorType::get(shape, v1Type.getElementType()));
2528  return success();
2529 }
2530 
2531 static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2532  uint64_t expected = begin;
2533  return idxArr.size() == width &&
2534  llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2535  [&expected](auto attr) {
2536  return attr.getZExtValue() == expected++;
2537  });
2538 }
2539 
2540 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2541  VectorType v1Type = getV1VectorType();
2542  // For consistency, the 0-D shuffle return type is 1-D, so this cannot be a
2543  // folding and must instead be canonicalized into a vector.broadcast.
2544  if (v1Type.getRank() == 0)
2545  return {};
2546 
2547  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2548  if (!v1Type.isScalable() &&
2549  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2550  return getV1();
2551  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2552  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2553  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2554  getV2VectorType().getDimSize(0)))
2555  return getV2();
2556 
2557  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2558  if (!lhs || !rhs)
2559  return {};
2560 
2561  auto lhsType =
2562  llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2563  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2564  // manipulation.
2565  if (lhsType.getRank() != 1)
2566  return {};
2567  int64_t lhsSize = lhsType.getDimSize(0);
2568 
2569  SmallVector<Attribute> results;
2570  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2571  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2572  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2573  int64_t i = index.getZExtValue();
2574  if (i >= lhsSize) {
2575  results.push_back(rhsElements[i - lhsSize]);
2576  } else {
2577  results.push_back(lhsElements[i]);
2578  }
2579  }
2580 
2581  return DenseElementsAttr::get(getResultVectorType(), results);
2582 }
2583 
2584 namespace {
2585 
2586 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2587 // to a broadcast.
2588 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2589  using OpRewritePattern::OpRewritePattern;
2590 
2591  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2592  PatternRewriter &rewriter) const override {
2593  VectorType v1VectorType = shuffleOp.getV1VectorType();
2594  ArrayAttr mask = shuffleOp.getMask();
2595  if (v1VectorType.getRank() > 0)
2596  return failure();
2597  if (mask.size() != 1)
2598  return failure();
2599  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2600  if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2601  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2602  shuffleOp.getV1());
2603  else
2604  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2605  shuffleOp.getV2());
2606  return success();
2607  }
2608 };
2609 
2610 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2611 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2612 public:
2613  using OpRewritePattern::OpRewritePattern;
2614 
2615  LogicalResult matchAndRewrite(ShuffleOp op,
2616  PatternRewriter &rewriter) const override {
2617  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2618  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2619 
2620  if (!v1Splat || !v2Splat)
2621  return failure();
2622 
2623  if (v1Splat.getInput() != v2Splat.getInput())
2624  return failure();
2625 
2626  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2627  return success();
2628  }
2629 };
2630 
2631 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2632 /// vector.interleave.
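/// Illustrative example:
///   vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///     : vector<4xf32>, vector<4xf32>
/// matches the interleaving mask and rewrites to vector.interleave %a, %b.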
2633 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2634 public:
2635  using OpRewritePattern::OpRewritePattern;
2636 
2637  LogicalResult matchAndRewrite(ShuffleOp op,
2638  PatternRewriter &rewriter) const override {
2639  VectorType resultType = op.getResultVectorType();
2640  if (resultType.isScalable())
2641  return rewriter.notifyMatchFailure(
2642  op, "ShuffleOp can't represent a scalable interleave");
2643 
2644  if (resultType.getRank() != 1)
2645  return rewriter.notifyMatchFailure(
2646  op, "ShuffleOp can't represent an n-D interleave");
2647 
2648  VectorType sourceType = op.getV1VectorType();
2649  if (sourceType != op.getV2VectorType() ||
2650  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2651  return rewriter.notifyMatchFailure(
2652  op, "ShuffleOp types don't match an interleave");
2653  }
2654 
2655  ArrayAttr shuffleMask = op.getMask();
2656  int64_t resultVectorSize = resultType.getNumElements();
2657  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2658  int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2659  int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2660  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2661  return rewriter.notifyMatchFailure(op,
2662  "ShuffleOp mask not interleaving");
2663  }
2664 
2665  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2666  return success();
2667  }
2668 };
2669 
2670 } // namespace
2671 
2672 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2673  MLIRContext *context) {
2674  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2675  context);
2676 }
2677 
2678 //===----------------------------------------------------------------------===//
2679 // InsertElementOp
2680 //===----------------------------------------------------------------------===//
2681 
2682 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2683  Value source, Value dest) {
2684  build(builder, result, source, dest, {});
2685 }
2686 
2687 LogicalResult InsertElementOp::verify() {
2688  auto dstVectorType = getDestVectorType();
2689  if (dstVectorType.getRank() == 0) {
2690  if (getPosition())
2691  return emitOpError("expected position to be empty with 0-D vector");
2692  return success();
2693  }
2694  if (dstVectorType.getRank() != 1)
2695  return emitOpError("unexpected >1 vector rank");
2696  if (!getPosition())
2697  return emitOpError("expected position for 1-D vector");
2698  return success();
2699 }
2700 
2701 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2702  // Skip the 0-D vector here.
2703  if (!adaptor.getPosition())
2704  return {};
2705 
2706  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2707  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2708  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2709  if (!src || !dst || !pos)
2710  return {};
2711 
2712  if (src.getType() != getDestVectorType().getElementType())
2713  return {};
2714 
2715  auto dstElements = dst.getValues<Attribute>();
2716 
2717  SmallVector<Attribute> results(dstElements);
2718 
2719  uint64_t posIdx = pos.getInt();
2720  if (posIdx >= results.size())
2721  return {};
2722  results[posIdx] = src;
2723 
2724  return DenseElementsAttr::get(getDestVectorType(), results);
2725 }
2726 
2727 //===----------------------------------------------------------------------===//
2728 // InsertOp
2729 //===----------------------------------------------------------------------===//
2730 
2731 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2732  Value source, Value dest, int64_t position) {
2733  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2734 }
2735 
2736 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2737  Value source, Value dest, OpFoldResult position) {
2738  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2739 }
2740 
2741 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2742  Value source, Value dest,
2743  ArrayRef<int64_t> position) {
2744  SmallVector<OpFoldResult> posVals;
2745  posVals.reserve(position.size());
2746  llvm::transform(position, std::back_inserter(posVals),
2747  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2748  build(builder, result, source, dest, posVals);
2749 }
2750 
2751 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2752  Value source, Value dest,
2753  ArrayRef<OpFoldResult> position) {
2754  SmallVector<int64_t> staticPos;
2755  SmallVector<Value> dynamicPos;
2756  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2757  build(builder, result, source, dest, dynamicPos,
2758  builder.getDenseI64ArrayAttr(staticPos));
2759 }
2760 
2761 LogicalResult InsertOp::verify() {
2762  SmallVector<OpFoldResult> position = getMixedPosition();
2763  auto destVectorType = getDestVectorType();
2764  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2765  return emitOpError(
2766  "expected position attribute of rank no greater than dest vector rank");
2767  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2768  if (srcVectorType &&
2769  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2770  static_cast<unsigned>(destVectorType.getRank())))
2771  return emitOpError("expected position attribute rank + source rank to "
2772  "match dest vector rank");
2773  if (!srcVectorType &&
2774  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2775  return emitOpError(
2776  "expected position attribute rank to match the dest vector rank");
2777  for (auto [idx, pos] : llvm::enumerate(position)) {
2778  if (auto attr = pos.dyn_cast<Attribute>()) {
2779  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2780  if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2781  return emitOpError("expected position attribute #")
2782  << (idx + 1)
2783  << " to be a non-negative integer smaller than the "
2784  "corresponding "
2785  "dest vector dimension";
2786  }
2787  }
2788  }
2789  return success();
2790 }
2791 
2792 namespace {
2793 
2794 // If insertOp is only inserting unit dimensions it can be transformed to a
2795 // broadcast.
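// Illustrative example:
//   %i = vector.insert %src, %dst[0, 0] : vector<4xf32> into vector<1x1x4xf32>
// rewrites to vector.broadcast %src : vector<4xf32> to vector<1x1x4xf32>.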
2796 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2797 public:
2798  using OpRewritePattern::OpRewritePattern;
2799 
2800  LogicalResult matchAndRewrite(InsertOp insertOp,
2801  PatternRewriter &rewriter) const override {
2802  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2803  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2804  srcVecType.getNumElements())
2805  return failure();
2806  rewriter.replaceOpWithNewOp<BroadcastOp>(
2807  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2808  return success();
2809  }
2810 };
2811 
2812 /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
2813 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2814 public:
2815  using OpRewritePattern::OpRewritePattern;
2816 
2817  LogicalResult matchAndRewrite(InsertOp op,
2818  PatternRewriter &rewriter) const override {
2819  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2820  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2821 
2822  if (!srcSplat || !dstSplat)
2823  return failure();
2824 
2825  if (srcSplat.getInput() != dstSplat.getInput())
2826  return failure();
2827 
2828  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2829  return success();
2830  }
2831 };
2832 
2833 // Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2834 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2835 public:
2836  using OpRewritePattern::OpRewritePattern;
2837 
2838  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2839  // unless the source vector constant has a single use.
2840  static constexpr int64_t vectorSizeFoldThreshold = 256;
2841 
2842  LogicalResult matchAndRewrite(InsertOp op,
2843  PatternRewriter &rewriter) const override {
2844  // TODO: Canonicalization for dynamic position not implemented yet.
2845  if (op.hasDynamicPosition())
2846  return failure();
2847 
2848  // Return if 'InsertOp' operand is not defined by a compatible vector
2849  // ConstantOp.
2850  TypedValue<VectorType> destVector = op.getDest();
2851  Attribute vectorDestCst;
2852  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2853  return failure();
2854  auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2855  if (!denseDest)
2856  return failure();
2857 
2858  VectorType destTy = destVector.getType();
2859  if (destTy.isScalable())
2860  return failure();
2861 
2862  // Make sure we do not create too many large constants.
2863  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2864  !destVector.hasOneUse())
2865  return failure();
2866 
2867  Value sourceValue = op.getSource();
2868  Attribute sourceCst;
2869  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2870  return failure();
2871 
2872  // Calculate the linearized position of the contiguous chunk of elements to
2873  // insert.
2874  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2875  copy(op.getStaticPosition(), completePositions.begin());
2876  int64_t insertBeginPosition =
2877  linearize(completePositions, computeStrides(destTy.getShape()));
2878 
2879  SmallVector<Attribute> insertedValues;
2880  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2881  llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
2882  else
2883  insertedValues.push_back(sourceCst);
2884 
2885  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2886  copy(insertedValues, allValues.begin() + insertBeginPosition);
2887  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2888 
2889  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2890  return success();
2891  }
2892 };
2893 
2894 } // namespace
2895 
2896 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2897  MLIRContext *context) {
2898  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2899  InsertOpConstantFolder>(context);
2900 }
2901 
2902 // Eliminates insert operations that produce values identical to their source
2903 // value. This happens when the position is empty, i.e. the source and
2904 // destination vector types are identical.
2905 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2906  if (getNumIndices() == 0)
2907  return getSource();
2908  return {};
2909 }
2910 
2911 //===----------------------------------------------------------------------===//
2912 // InsertStridedSliceOp
2913 //===----------------------------------------------------------------------===//
2914 
2915 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2916  Value source, Value dest,
2917  ArrayRef<int64_t> offsets,
2918  ArrayRef<int64_t> strides) {
2919  result.addOperands({source, dest});
2920  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2921  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2922  result.addTypes(dest.getType());
2923  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
2924  offsetsAttr);
2925  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
2926  stridesAttr);
2927 }
2928 
2929 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
2930 template <typename OpType>
2931 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2932  ArrayAttr arrayAttr,
2933  ArrayRef<int64_t> shape,
2934  StringRef attrName) {
2935  if (arrayAttr.size() > shape.size())
2936  return op.emitOpError("expected ")
2937  << attrName << " attribute of rank no greater than vector rank";
2938  return success();
2939 }
2940 
2941 // Returns success if all integers in `arrayAttr` are in the interval
2942 // [min, max). If `halfOpen` is false, the admissible interval is [min, max]
2943 // instead.
2944 template <typename OpType>
2945 static LogicalResult
2946 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2947  int64_t max, StringRef attrName,
2948  bool halfOpen = true) {
2949  for (auto attr : arrayAttr) {
2950  auto val = llvm::cast<IntegerAttr>(attr).getInt();
2951  auto upper = max;
2952  if (!halfOpen)
2953  upper += 1;
2954  if (val < min || val >= upper)
2955  return op.emitOpError("expected ") << attrName << " to be confined to ["
2956  << min << ", " << upper << ")";
2957  }
2958  return success();
2959 }
2960 
2961 // Returns success if, for each index i, `arrayAttr[i]` is in the interval
2962 // [min, shape[i]). If `halfOpen` is false, the admissible interval is
2963 // [min, shape[i]] instead.
2964 template <typename OpType>
2965 static LogicalResult
2966 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2967  ArrayRef<int64_t> shape, StringRef attrName,
2968  bool halfOpen = true, int64_t min = 0) {
2969  for (auto [index, attrDimPair] :
2970  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
2971  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
2972  int64_t max = std::get<1>(attrDimPair);
2973  if (!halfOpen)
2974  max += 1;
2975  if (val < min || val >= max)
2976  return op.emitOpError("expected ")
2977  << attrName << " dimension " << index << " to be confined to ["
2978  << min << ", " << max << ")";
2979  }
2980  return success();
2981 }
2982 
2983 // Returns success if, for each index i in [0, shape.size()), the sum
2984 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
2985 // is in the interval [min, shape[i]).
2986 // If `halfOpen` is false, the admissible interval is [min, shape[i]]
2987 // instead.
2988 template <typename OpType>
2990  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2991  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2992  bool halfOpen = true, int64_t min = 1) {
2993  assert(arrayAttr1.size() <= shape.size());
2994  assert(arrayAttr2.size() <= shape.size());
2995  for (auto [index, it] :
2996  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
2997  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
2998  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
2999  int64_t max = std::get<2>(it);
3000  if (!halfOpen)
3001  max += 1;
3002  if (val1 + val2 < 0 || val1 + val2 >= max)
3003  return op.emitOpError("expected sum(")
3004  << attrName1 << ", " << attrName2 << ") dimension " << index
3005  << " to be confined to [" << min << ", " << max << ")";
3006  }
3007  return success();
3008 }
3009 
3010 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3011  MLIRContext *context) {
3012  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3013  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3014  });
3015  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3016 }
3017 
3018 LogicalResult InsertStridedSliceOp::verify() {
3019  auto sourceVectorType = getSourceVectorType();
3020  auto destVectorType = getDestVectorType();
3021  auto offsets = getOffsetsAttr();
3022  auto strides = getStridesAttr();
3023  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3024  return emitOpError(
3025  "expected offsets of same size as destination vector rank");
3026  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3027  return emitOpError("expected strides of same size as source vector rank");
3028  if (sourceVectorType.getRank() > destVectorType.getRank())
3029  return emitOpError(
3030  "expected source rank to be no greater than destination rank");
3031 
3032  auto sourceShape = sourceVectorType.getShape();
3033  auto destShape = destVectorType.getShape();
3034  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3035  destShape.size() - sourceShape.size(), 0);
3036  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3037  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3038  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3039  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3040  offName)) ||
3041  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3042  /*max=*/1, stridesName,
3043  /*halfOpen=*/false)) ||
3045  *this, offsets,
3046  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3047  offName, "source vector shape",
3048  /*halfOpen=*/false, /*min=*/1)))
3049  return failure();
3050 
3051  unsigned rankDiff = destShape.size() - sourceShape.size();
3052  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3053  if (sourceVectorType.getScalableDims()[idx] !=
3054  destVectorType.getScalableDims()[idx + rankDiff]) {
3055  return emitOpError("mismatching scalable flags (at source vector idx=")
3056  << idx << ")";
3057  }
3058  if (sourceVectorType.getScalableDims()[idx]) {
3059  auto sourceSize = sourceShape[idx];
3060  auto destSize = destShape[idx + rankDiff];
3061  if (sourceSize != destSize) {
3062  return emitOpError("expected size at idx=")
3063  << idx
3064  << (" to match the corresponding base size from the input "
3065  "vector (")
3066  << sourceSize << (" vs ") << destSize << (")");
3067  }
3068  }
3069  }
3070 
3071  return success();
3072 }
3073 
3074 namespace {
3075 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3076 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
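///
/// Illustrative example (IR shown for exposition only; values are made up):
/// ```
/// %small = vector.splat %x : vector<2xf32>
/// %big = vector.splat %x : vector<8xf32>
/// %0 = vector.insert_strided_slice %small, %big
///        {offsets = [4], strides = [1]} : vector<2xf32> into vector<8xf32>
/// ```
/// folds to `%big`, since inserting a splat of %x into another splat of %x
/// changes nothing.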
3077 class FoldInsertStridedSliceSplat final
3078  : public OpRewritePattern<InsertStridedSliceOp> {
3079 public:
3081 
3082  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3083  PatternRewriter &rewriter) const override {
3084  auto srcSplatOp =
3085  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3086  auto destSplatOp =
3087  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3088 
3089  if (!srcSplatOp || !destSplatOp)
3090  return failure();
3091 
3092  if (srcSplatOp.getInput() != destSplatOp.getInput())
3093  return failure();
3094 
3095  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3096  return success();
3097  }
3098 };
3099 
3100 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3101 /// to dst.
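///
/// Illustrative example (IR shown for exposition only; values are made up):
/// ```
/// %slice = vector.extract_strided_slice %dst
///            {offsets = [2], sizes = [4], strides = [1]}
///            : vector<8xf32> to vector<4xf32>
/// %0 = vector.insert_strided_slice %slice, %dst
///        {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// ```
/// folds to `%dst`: the slice is written back to the exact place it was
/// extracted from.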
3102 class FoldInsertStridedSliceOfExtract final
3103  : public OpRewritePattern<InsertStridedSliceOp> {
3104 public:
3106 
3107  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3108  PatternRewriter &rewriter) const override {
3109  auto extractStridedSliceOp =
3110  insertStridedSliceOp.getSource()
3111  .getDefiningOp<vector::ExtractStridedSliceOp>();
3112 
3113  if (!extractStridedSliceOp)
3114  return failure();
3115 
3116  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3117  return failure();
3118 
3119  // Check if they have the same strides and offsets.
3120  if (extractStridedSliceOp.getStrides() !=
3121  insertStridedSliceOp.getStrides() ||
3122  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3123  return failure();
3124 
3125  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3126  return success();
3127  }
3128 };
3129 
3130 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3131 // ConstantOp.
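//
// Illustrative example (IR shown for exposition only; values are made up):
// ```
// %dst = arith.constant dense<0.0> : vector<4xf32>
// %src = arith.constant dense<[1.0, 2.0]> : vector<2xf32>
// %0 = vector.insert_strided_slice %src, %dst
//        {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
// ```
// folds to
// ```
// %0 = arith.constant dense<[0.0, 1.0, 2.0, 0.0]> : vector<4xf32>
// ```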
3132 class InsertStridedSliceConstantFolder final
3133  : public OpRewritePattern<InsertStridedSliceOp> {
3134 public:
3136 
3137  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3138  // unless the destination vector constant has a single use.
3139  static constexpr int64_t vectorSizeFoldThreshold = 256;
3140 
3141  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3142  PatternRewriter &rewriter) const override {
3143  // Return if the 'InsertStridedSliceOp' destination is not defined by a
3144  // compatible vector ConstantOp.
3145  TypedValue<VectorType> destVector = op.getDest();
3146  Attribute vectorDestCst;
3147  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3148  return failure();
3149 
3150  VectorType destTy = destVector.getType();
3151  if (destTy.isScalable())
3152  return failure();
3153 
3154  // Make sure we do not create too many large constants.
3155  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3156  !destVector.hasOneUse())
3157  return failure();
3158 
3159  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3160 
3161  TypedValue<VectorType> sourceValue = op.getSource();
3162  Attribute sourceCst;
3163  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3164  return failure();
3165 
3166  // TODO: Handle non-unit strides when they become available.
3167  if (op.hasNonUnitStrides())
3168  return failure();
3169 
3170  VectorType sliceVecTy = sourceValue.getType();
3171  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3172  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3173  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3174  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3175 
3176  // Calculate the destination element indices by enumerating all slice
3177  // positions within the destination and linearizing them. The enumeration
3178  // order is lexicographic which yields a sequence of monotonically
3179  // increasing linearized position indices.
3180  // Because the destination may have higher dimensionality than the slice,
3181  // we keep track of two overlapping sets of positions and offsets.
3182  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3183  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3184  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3185  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3186  MutableArrayRef<int64_t> currSlicePosition(
3187  currDestPosition.begin() + rankDifference, currDestPosition.end());
3188  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3189  offsets.end());
3190  do {
3191  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3192  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3193  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3194  "Invalid slice element");
3195  newValues[linearizedPosition] = *sliceValuesIt;
3196  ++sliceValuesIt;
3197  } while (succeeded(
3198  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3199 
3200  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3201  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3202  return success();
3203  }
3204 };
3205 
3206 } // namespace
3207 
3208 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3209  RewritePatternSet &results, MLIRContext *context) {
3210  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3211  InsertStridedSliceConstantFolder>(context);
3212 }
3213 
3214 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3215  if (getSourceVectorType() == getDestVectorType())
3216  return getSource();
3217  return {};
3218 }
3219 
3220 //===----------------------------------------------------------------------===//
3221 // OuterProductOp
3222 //===----------------------------------------------------------------------===//
3223 
3224 /// Build an op without mask, use the type of `acc` as the return type.
3225 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3226  Value lhs, Value rhs, Value acc) {
3227  result.addOperands({lhs, rhs, acc});
3228  result.addTypes(acc.getType());
3229 }
3230 
3232  p << " " << getLhs() << ", " << getRhs();
3233  if (getAcc()) {
3234  p << ", " << getAcc();
3235  p.printOptionalAttrDict((*this)->getAttrs());
3236  }
3237  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3238 }
3239 
3240 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3242  Type tLHS, tRHS;
3243  if (parser.parseOperandList(operandsInfo) ||
3244  parser.parseOptionalAttrDict(result.attributes) ||
3245  parser.parseColonType(tLHS) || parser.parseComma() ||
3246  parser.parseType(tRHS))
3247  return failure();
3248  if (operandsInfo.size() < 2)
3249  return parser.emitError(parser.getNameLoc(),
3250  "expected at least 2 operands");
3251  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3252  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3253  if (!vLHS)
3254  return parser.emitError(parser.getNameLoc(),
3255  "expected vector type for operand #1");
3256 
3257  VectorType resType;
3258  if (vRHS) {
3259  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3260  vRHS.getScalableDims()[0]};
3261  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3262  vLHS.getElementType(), scalableDimsRes);
3263  } else {
3264  // Scalar RHS operand
3265  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3266  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3267  scalableDimsRes);
3268  }
3269 
3270  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3271  result.attributes.append(
3272  OuterProductOp::getKindAttrName(result.name),
3274  OuterProductOp::getDefaultKind()));
3275  }
3276 
3277  return failure(
3278  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3279  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3280  (operandsInfo.size() > 2 &&
3281  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3282  parser.addTypeToList(resType, result.types));
3283 }
3284 
3285 LogicalResult OuterProductOp::verify() {
3286  Type tRHS = getOperandTypeRHS();
3287  VectorType vLHS = getOperandVectorTypeLHS(),
3288  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3289  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3290 
3291  if (vLHS.getRank() != 1)
3292  return emitOpError("expected 1-d vector for operand #1");
3293 
3294  if (vRHS) {
3295  // Proper OUTER operation.
3296  if (vRHS.getRank() != 1)
3297  return emitOpError("expected 1-d vector for operand #2");
3298  if (vRES.getRank() != 2)
3299  return emitOpError("expected 2-d vector result");
3300  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3301  return emitOpError("expected #1 operand dim to match result dim #1");
3302  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3303  return emitOpError("expected #2 operand dim to match result dim #2");
3304  if (vLHS.isScalable() && !vRHS.isScalable()) {
3305  // This restriction reflects what's currently supported in terms of
3306  // scalable vectors. However, we could relax this if there's a use case.
3307  return emitOpError(
3308  "expected either both or only #2 operand dim to be scalable");
3309  }
3310  } else {
3311  // An AXPY operation.
3312  if (vRES.getRank() != 1)
3313  return emitOpError("expected 1-d vector result");
3314  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3315  return emitOpError("expected #1 operand dim to match result dim #1");
3316  }
3317 
3318  if (vACC && vACC != vRES)
3319  return emitOpError("expected operand #3 of same type as result type");
3320 
3321  // Verify supported combining kind.
3322  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3323  return emitOpError("unsupported outerproduct type");
3324 
3325  return success();
3326 }
3327 
3328 // MaskableOpInterface methods.
3329 
3330 /// Returns the mask type expected by this operation. Mostly used for
3331 /// verification purposes. It requires the operation to be vectorized.
3332 Type OuterProductOp::getExpectedMaskType() {
3333  auto vecType = this->getResultVectorType();
3334  return VectorType::get(vecType.getShape(),
3335  IntegerType::get(vecType.getContext(), /*width=*/1),
3336  vecType.getScalableDims());
3337 }
3338 
3339 //===----------------------------------------------------------------------===//
3340 // ReshapeOp
3341 //===----------------------------------------------------------------------===//
3342 
3343 LogicalResult ReshapeOp::verify() {
3344  // Verify that the input/output shape rank plus the number of fixed vector sizes matches the corresponding vector rank.
3345  auto inputVectorType = getInputVectorType();
3346  auto outputVectorType = getOutputVectorType();
3347  int64_t inputShapeRank = getNumInputShapeSizes();
3348  int64_t outputShapeRank = getNumOutputShapeSizes();
3349  SmallVector<int64_t, 4> fixedVectorSizes;
3350  getFixedVectorSizes(fixedVectorSizes);
3351  int64_t numFixedVectorSizes = fixedVectorSizes.size();
3352 
3353  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
3354  return emitError("invalid input shape for vector type ") << inputVectorType;
3355 
3356  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
3357  return emitError("invalid output shape for vector type ")
3358  << outputVectorType;
3359 
3360  // Verify that the 'fixedVectorSizes' match an input/output vector shape
3361  // suffix.
3362  unsigned inputVectorRank = inputVectorType.getRank();
3363  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3364  unsigned index = inputVectorRank - numFixedVectorSizes - i;
3365  if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
3366  return emitError("fixed vector size must match input vector for dim ")
3367  << i;
3368  }
3369 
3370  unsigned outputVectorRank = outputVectorType.getRank();
3371  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3372  unsigned index = outputVectorRank - numFixedVectorSizes - i;
3373  if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
3374  return emitError("fixed vector size must match output vector for dim ")
3375  << i;
3376  }
3377 
3378  // If all shape operands are produced by constant ops, verify that the
3379  // products of the input and output shape dimensions match.
3380  auto isDefByConstant = [](Value operand) {
3381  return getConstantIntValue(operand).has_value();
3382  };
3383  if (llvm::all_of(getInputShape(), isDefByConstant) &&
3384  llvm::all_of(getOutputShape(), isDefByConstant)) {
3385  int64_t numInputElements = 1;
3386  for (auto operand : getInputShape())
3387  numInputElements *= getConstantIntValue(operand).value();
3388  int64_t numOutputElements = 1;
3389  for (auto operand : getOutputShape())
3390  numOutputElements *= getConstantIntValue(operand).value();
3391  if (numInputElements != numOutputElements)
3392  return emitError("product of input and output shape sizes must match");
3393  }
3394  return success();
3395 }
3396 
3397 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
3398  populateFromInt64AttrArray(getFixedVectorSizes(), results);
3399 }
3400 
3401 //===----------------------------------------------------------------------===//
3402 // ExtractStridedSliceOp
3403 //===----------------------------------------------------------------------===//
3404 
3405 // Inference works as follows:
3406 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3407 // 2. Add sizes from 'vectorType' for remaining dims.
3408 // Scalable flags are inherited from 'vectorType'.
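//
// Illustrative example (for exposition only): for a source of type
// vector<4x8x16xf32> with offsets = [0, 2], sizes = [2, 4], strides = [1, 1],
// the inferred result type is vector<2x4x16xf32>: the leading dims come from
// 'sizes' and the trailing dim is carried over from the source type.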
3409 static Type inferStridedSliceOpResultType(VectorType vectorType,
3410  ArrayAttr offsets, ArrayAttr sizes,
3411  ArrayAttr strides) {
3412  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3414  shape.reserve(vectorType.getRank());
3415  unsigned idx = 0;
3416  for (unsigned e = offsets.size(); idx < e; ++idx)
3417  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3418  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3419  shape.push_back(vectorType.getShape()[idx]);
3420 
3421  return VectorType::get(shape, vectorType.getElementType(),
3422  vectorType.getScalableDims());
3423 }
3424 
3425 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3426  Value source, ArrayRef<int64_t> offsets,
3427  ArrayRef<int64_t> sizes,
3428  ArrayRef<int64_t> strides) {
3429  result.addOperands(source);
3430  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3431  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3432  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3433  result.addTypes(
3434  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3435  offsetsAttr, sizesAttr, stridesAttr));
3436  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3437  offsetsAttr);
3438  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3439  sizesAttr);
3440  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3441  stridesAttr);
3442 }
3443 
3444 LogicalResult ExtractStridedSliceOp::verify() {
3445  auto type = getSourceVectorType();
3446  auto offsets = getOffsetsAttr();
3447  auto sizes = getSizesAttr();
3448  auto strides = getStridesAttr();
3449  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3450  return emitOpError(
3451  "expected offsets, sizes and strides attributes of same size");
3452 
3453  auto shape = type.getShape();
3454  auto offName = getOffsetsAttrName();
3455  auto sizesName = getSizesAttrName();
3456  auto stridesName = getStridesAttrName();
3457  if (failed(
3458  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3459  failed(
3460  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3461  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3462  stridesName)) ||
3463  failed(
3464  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3465  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3466  /*halfOpen=*/false,
3467  /*min=*/1)) ||
3468  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3469  /*max=*/1, stridesName,
3470  /*halfOpen=*/false)) ||
3471  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3472  shape, offName, sizesName,
3473  /*halfOpen=*/false)))
3474  return failure();
3475 
3476  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3477  offsets, sizes, strides);
3478  if (getResult().getType() != resultType)
3479  return emitOpError("expected result type to be ") << resultType;
3480 
3481  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3482  if (type.getScalableDims()[idx]) {
3483  auto inputDim = type.getShape()[idx];
3484  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3485  if (inputDim != inputSize)
3486  return emitOpError("expected size at idx=")
3487  << idx
3488  << (" to match the corresponding base size from the input "
3489  "vector (")
3490  << inputSize << (" vs ") << inputDim << (")");
3491  }
3492  }
3493 
3494  return success();
3495 }
3496 
3497 // When the source of an ExtractStridedSliceOp comes from a chain of
3498 // InsertStridedSliceOps, try to use the source of one of those inserts when we
3499 // can detect that the extracted vector is a subset of an inserted vector.
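//
// Illustrative example (IR shown for exposition only; values are made up):
// ```
// %w = vector.insert_strided_slice %v, %dst
//        {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
// %r = vector.extract_strided_slice %w
//        {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8xf32> to vector<4xf32>
// ```
// can be folded so that %r reads %v directly (with the offsets rebased to [0]).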
3500 static LogicalResult
3501 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3502  // Helper to extract integer out of ArrayAttr.
3503  auto getElement = [](ArrayAttr array, int idx) {
3504  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3505  };
3506  ArrayAttr extractOffsets = op.getOffsets();
3507  ArrayAttr extractStrides = op.getStrides();
3508  ArrayAttr extractSizes = op.getSizes();
3509  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3510  while (insertOp) {
3511  if (op.getSourceVectorType().getRank() !=
3512  insertOp.getSourceVectorType().getRank())
3513  return failure();
3514  ArrayAttr insertOffsets = insertOp.getOffsets();
3515  ArrayAttr insertStrides = insertOp.getStrides();
3516  // If the rank of the extract is greater than the rank of the insert, we are
3517  // likely extracting a partial chunk of the inserted vector; bail out.
3518  if (extractOffsets.size() > insertOffsets.size())
3519  return failure();
3520  bool partialOverlap = false;
3521  bool disjoint = false;
3522  SmallVector<int64_t, 4> offsetDiffs;
3523  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3524  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3525  return failure();
3526  int64_t start = getElement(insertOffsets, dim);
3527  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3528  int64_t offset = getElement(extractOffsets, dim);
3529  int64_t size = getElement(extractSizes, dim);
3530  // Check if the start offset of the extract lies within the inserted interval.
3531  if (start <= offset && offset < end) {
3532  // If the extract interval overlaps but is not fully included we may
3533  // have a partial overlap that will prevent any folding.
3534  if (offset + size > end)
3535  partialOverlap = true;
3536  offsetDiffs.push_back(offset - start);
3537  continue;
3538  }
3539  disjoint = true;
3540  break;
3541  }
3542  // The extracted chunk is a subset of the inserted chunk.
3543  if (!disjoint && !partialOverlap) {
3544  op.setOperand(insertOp.getSource());
3545  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3546  OpBuilder b(op.getContext());
3547  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3548  return success();
3549  }
3550  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3551  // in the insert chain.
3552  if (disjoint)
3553  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3554  else {
3555  // The extracted vector partially overlaps the inserted vector; we cannot
3556  // fold.
3557  return failure();
3558  }
3559  }
3560  return failure();
3561 }
3562 
3563 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3564  if (getSourceVectorType() == getResult().getType())
3565  return getVector();
3566  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3567  return getResult();
3568  return {};
3569 }
3570 
3571 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3572  populateFromInt64AttrArray(getOffsets(), results);
3573 }
3574 
3575 namespace {
3576 
3577 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3578 // ConstantMaskOp.
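//
// Illustrative example (IR shown for exposition only; values are made up):
// ```
// %m = vector.constant_mask [6] : vector<8xi1>
// %0 = vector.extract_strided_slice %m
//        {offsets = [4], sizes = [4], strides = [1]}
//        : vector<8xi1> to vector<4xi1>
// ```
// can be rewritten as
// ```
// %0 = vector.constant_mask [2] : vector<4xi1>
// ```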
3579 class StridedSliceConstantMaskFolder final
3580  : public OpRewritePattern<ExtractStridedSliceOp> {
3581 public:
3583 
3584  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3585  PatternRewriter &rewriter) const override {
3586  // Return if 'extractStridedSliceOp' operand is not defined by a
3587  // ConstantMaskOp.
3588  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3589  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3590  if (!constantMaskOp)
3591  return failure();
3592  // Return if 'extractStridedSliceOp' has non-unit strides.
3593  if (extractStridedSliceOp.hasNonUnitStrides())
3594  return failure();
3595  // Gather constant mask dimension sizes.
3596  SmallVector<int64_t, 4> maskDimSizes;
3597  populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
3598  // Gather strided slice offsets and sizes.
3599  SmallVector<int64_t, 4> sliceOffsets;
3600  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3601  sliceOffsets);
3602  SmallVector<int64_t, 4> sliceSizes;
3603  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3604 
3605  // Compute slice of vector mask region.
3606  SmallVector<int64_t, 4> sliceMaskDimSizes;
3607  sliceMaskDimSizes.reserve(maskDimSizes.size());
3608  for (auto [maskDimSize, sliceOffset, sliceSize] :
3609  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3610  int64_t sliceMaskDimSize = std::max(
3611  static_cast<int64_t>(0),
3612  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3613  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3614  }
3615  // Add unchanged dimensions.
3616  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3617  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3618  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3619  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3620  // region is a conjunction of mask dim intervals).
3621  if (llvm::is_contained(sliceMaskDimSizes, 0))
3622  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3623 
3624  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3625  // region.
3626  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3627  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3628  vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
3629  return success();
3630  }
3631 };
3632 
3633 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
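//
// Illustrative example (IR shown for exposition only; values are made up):
// ```
// %v = arith.constant dense<1.0> : vector<8xf32>
// %0 = vector.extract_strided_slice %v
//        {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8xf32> to vector<4xf32>
// ```
// folds to
// ```
// %0 = arith.constant dense<1.0> : vector<4xf32>
// ```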
3634 class StridedSliceSplatConstantFolder final
3635  : public OpRewritePattern<ExtractStridedSliceOp> {
3636 public:
3638 
3639  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3640  PatternRewriter &rewriter) const override {
3641  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3642  // ConstantOp.
3643  Value sourceVector = extractStridedSliceOp.getVector();
3644  Attribute vectorCst;
3645  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3646  return failure();
3647 
3648  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3649  if (!splat)
3650  return failure();
3651 
3652  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3653  splat.getSplatValue<Attribute>());
3654  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3655  newAttr);
3656  return success();
3657  }
3658 };
3659 
3660 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3661 // ConstantOp.
3662 class StridedSliceNonSplatConstantFolder final
3663  : public OpRewritePattern<ExtractStridedSliceOp> {
3664 public:
3666 
3667  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3668  PatternRewriter &rewriter) const override {
3669  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3670  // ConstantOp.
3671  Value sourceVector = extractStridedSliceOp.getVector();
3672  Attribute vectorCst;
3673  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3674  return failure();
3675 
3676  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3677  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3678  if (!dense || dense.isSplat())
3679  return failure();
3680 
3681  // TODO: Handle non-unit strides when they become available.
3682  if (extractStridedSliceOp.hasNonUnitStrides())
3683  return failure();
3684 
3685  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3686  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3687  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3688 
3689  VectorType sliceVecTy = extractStridedSliceOp.getType();
3690  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3691  int64_t sliceRank = sliceVecTy.getRank();
3692 
3693  // Expand offsets and sizes to match the vector rank.
3694  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3695  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3696 
3697  SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
3698  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3699 
3700  // Calculate the slice elements by enumerating all slice positions and
3701  // linearizing them. The enumeration order is lexicographic which yields a
3702  // sequence of monotonically increasing linearized position indices.
3703  auto denseValuesBegin = dense.value_begin<Attribute>();
3704  SmallVector<Attribute> sliceValues;
3705  sliceValues.reserve(sliceVecTy.getNumElements());
3706  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3707  do {
3708  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3709  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3710  "Invalid index");
3711  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3712  } while (
3713  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3714 
3715  assert(static_cast<int64_t>(sliceValues.size()) ==
3716  sliceVecTy.getNumElements() &&
3717  "Invalid number of slice elements");
3718  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3719  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3720  newAttr);
3721  return success();
3722  }
3723 };
3724 
3725 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3726 // BroadcastOp(ExtractStridedSliceOp).
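//
// Illustrative example (IR shown for exposition only; values are made up): when
// the broadcast source already matches the extracted inner dimensions,
// ```
// %b = vector.broadcast %src : vector<4xf32> to vector<8x4xf32>
// %0 = vector.extract_strided_slice %b
//        {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]}
//        : vector<8x4xf32> to vector<2x4xf32>
// ```
// can be rewritten as
// ```
// %0 = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
// ```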
3727 class StridedSliceBroadcast final
3728  : public OpRewritePattern<ExtractStridedSliceOp> {
3729 public:
3731 
3732  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3733  PatternRewriter &rewriter) const override {
3734  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3735  if (!broadcast)
3736  return failure();
3737  auto srcVecType =
3738  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3739  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3740  auto dstVecType = llvm::cast<VectorType>(op.getType());
3741  unsigned dstRank = dstVecType.getRank();
3742  unsigned rankDiff = dstRank - srcRank;
3743  // Check if the innermost dimensions of the source of the broadcast are the
3744  // same as those of the destination of the extract. If so, we can simply use a
3745  // broadcast, as the original dimensions are untouched.
3746  bool lowerDimMatch = true;
3747  for (unsigned i = 0; i < srcRank; i++) {
3748  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3749  lowerDimMatch = false;
3750  break;
3751  }
3752  }
3753  Value source = broadcast.getSource();
3754  // If the inner dimensions don't match, it means we need to extract from the
3755  // source of the original broadcast and then broadcast the extracted value.
3756  // We also need to handle degenerate cases where the source is effectively
3757  // just a single scalar.
3758  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3759  if (!lowerDimMatch && !isScalarSrc) {
3760  source = rewriter.create<ExtractStridedSliceOp>(
3761  op->getLoc(), source,
3762  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3763  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3764  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3765  }
3766  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3767  return success();
3768  }
3769 };
3770 
3771 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
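///
/// Illustrative example (IR shown for exposition only; values are made up):
/// ```
/// %s = vector.splat %x : vector<8xf32>
/// %0 = vector.extract_strided_slice %s
///        {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
/// ```
/// becomes
/// ```
/// %0 = vector.splat %x : vector<4xf32>
/// ```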
3772 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3773 public:
3775 
3776  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3777  PatternRewriter &rewriter) const override {
3778  auto splat = op.getVector().getDefiningOp<SplatOp>();
3779  if (!splat)
3780  return failure();
3781  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3782  return success();
3783  }
3784 };
3785 
3786 } // namespace
3787 
3788 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3789  RewritePatternSet &results, MLIRContext *context) {
3790  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3791  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3792  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3793  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3794  StridedSliceSplat>(context);
3795 }
3796 
3797 //===----------------------------------------------------------------------===//
3798 // TransferReadOp
3799 //===----------------------------------------------------------------------===//
3800 
3801 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3802 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3803  VectorType vectorType, Value source,
3804  ValueRange indices, AffineMapAttr permutationMapAttr,
3805  /*optional*/ ArrayAttr inBoundsAttr) {
3806  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3807  Value padding = builder.create<arith::ConstantOp>(
3808  result.location, elemType, builder.getZeroAttr(elemType));
3809  build(builder, result, vectorType, source, indices, permutationMapAttr,
3810  padding, /*mask=*/Value(), inBoundsAttr);
3811 }
3812 
3813 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3814 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3815  VectorType vectorType, Value source,
3816  ValueRange indices, AffineMap permutationMap,
3817  std::optional<ArrayRef<bool>> inBounds) {
3818  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3819  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3820  ? builder.getBoolArrayAttr(inBounds.value())
3821  : builder.getBoolArrayAttr(
3822  SmallVector<bool>(vectorType.getRank(), false));
3823  build(builder, result, vectorType, source, indices, permutationMapAttr,
3824  inBoundsAttr);
3825 }
3826 
3827 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3828 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3829  VectorType vectorType, Value source,
3830  ValueRange indices, Value padding,
3831  std::optional<ArrayRef<bool>> inBounds) {
3832  AffineMap permutationMap = getTransferMinorIdentityMap(
3833  llvm::cast<ShapedType>(source.getType()), vectorType);
3834  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3835  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3836  ? builder.getBoolArrayAttr(inBounds.value())
3837  : builder.getBoolArrayAttr(
3838  SmallVector<bool>(vectorType.getRank(), false));
3839  build(builder, result, vectorType, source, indices, permutationMapAttr,
3840  padding,
3841  /*mask=*/Value(), inBoundsAttr);
3842 }
3843 
3844 /// 4. Builder that sets padding to zero and permutation map to
3845 /// 'getMinorIdentityMap'.
3846 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3847  VectorType vectorType, Value source,
3848  ValueRange indices,
3849  std::optional<ArrayRef<bool>> inBounds) {
3850  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3851  Value padding = builder.create<arith::ConstantOp>(
3852  result.location, elemType, builder.getZeroAttr(elemType));
3853  build(builder, result, vectorType, source, indices, padding, inBounds);
3854 }
3855 
3856 template <typename EmitFun>
3857 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3858  EmitFun emitOpError) {
3859  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3860  for (auto expr : permutationMap.getResults()) {
3861  auto dim = dyn_cast<AffineDimExpr>(expr);
3862  auto zero = dyn_cast<AffineConstantExpr>(expr);
3863  if (zero) {
3864  if (zero.getValue() != 0) {
3865  return emitOpError(
3866  "requires a projected permutation_map (at most one dim or the zero "
3867  "constant can appear in each result)");
3868  }
3869  continue;
3870  }
3871  if (!dim) {
3872  return emitOpError("requires a projected permutation_map (at most one "
3873  "dim or the zero constant can appear in each result)");
3874  }
3875  if (seen[dim.getPosition()]) {
3876  return emitOpError(
3877  "requires a permutation_map that is a permutation (found one dim "
3878  "used more than once)");
3879  }
3880  seen[dim.getPosition()] = true;
3881  }
3882  return success();
3883 }
3884 
3885 static LogicalResult
3886 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3887  VectorType vectorType, VectorType maskType,
3888  VectorType inferredMaskType, AffineMap permutationMap,
3889  ArrayAttr inBounds) {
3890  if (op->hasAttr("masked")) {
3891  return op->emitOpError("masked attribute has been removed. "
3892  "Use in_bounds instead.");
3893  }
3894 
3895  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3896  return op->emitOpError(
3897  "requires source to be a memref or ranked tensor type");
3898 
3899  auto elementType = shapedType.getElementType();
3900  DataLayout dataLayout = DataLayout::closest(op);
3901  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3902  // Memref or tensor has vector element type.
3903  unsigned sourceVecSize =
3904  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3905  vectorElementType.getShape().back();
3906  unsigned resultVecSize =
3907  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3908  vectorType.getShape().back();
3909  if (resultVecSize % sourceVecSize != 0)
3910  return op->emitOpError(
3911  "requires the bitwidth of the minor 1-D vector to be an integral "
3912  "multiple of the bitwidth of the minor 1-D vector of the source");
3913 
3914  unsigned sourceVecEltRank = vectorElementType.getRank();
3915  unsigned resultVecRank = vectorType.getRank();
3916  if (sourceVecEltRank > resultVecRank)
3917  return op->emitOpError(
3918  "requires source vector element and vector result ranks to match.");
3919  unsigned rankOffset = resultVecRank - sourceVecEltRank;
3920  // Check that permutation map results match 'rankOffset' of vector type.
3921  if (permutationMap.getNumResults() != rankOffset)
3922  return op->emitOpError("requires a permutation_map with result dims of "
3923  "the same rank as the vector type");
3924 
3925  if (maskType)
3926  return op->emitOpError("does not support masks with vector element type");
3927  } else {
3928  // Memref or tensor has scalar element type.
3929  unsigned minorSize =
3930  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3931  unsigned resultVecSize =
3932  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3933  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3934  return op->emitOpError(
3935  "requires the bitwidth of the minor 1-D vector to be an integral "
3936  "multiple of the bitwidth of the source element type");
3937 
3938  // Check that permutation map results match rank of vector type.
3939  if (permutationMap.getNumResults() != vectorType.getRank())
3940  return op->emitOpError("requires a permutation_map with result dims of "
3941  "the same rank as the vector type");
3942  }
3943 
3944  if (permutationMap.getNumSymbols() != 0)
3945  return op->emitOpError("requires permutation_map without symbols");
3946 
3947  if (permutationMap.getNumInputs() != shapedType.getRank())
3948  return op->emitOpError("requires a permutation_map with input dims of the "
3949  "same rank as the source type");
3950 
3951  if (maskType && maskType != inferredMaskType)
3952  return op->emitOpError("inferred mask type (")
3953  << inferredMaskType << ") and mask operand type (" << maskType
3954  << ") don't match";
3955 
3956  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3957  return op->emitOpError("expects the in_bounds attr of same rank "
3958  "as permutation_map results: ")
3959  << AffineMapAttr::get(permutationMap)
3960  << " vs inBounds of size: " << inBounds.size();
3961  for (unsigned int i = 0, e = permutationMap.getNumResults(); i < e; ++i)
3962  if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
3963  !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3964  return op->emitOpError("requires broadcast dimensions to be in-bounds");
3965 
3966  return success();
3967 }
3968 
3969 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3970  SmallVector<StringRef, 3> elidedAttrs;
3971  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3972  if (op.getPermutationMap().isMinorIdentity())
3973  elidedAttrs.push_back(op.getPermutationMapAttrName());
3974  // Elide in_bounds attribute if all dims are out-of-bounds.
3975  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
3976  elidedAttrs.push_back(op.getInBoundsAttrName());
3977  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3978 }
3979 
3981  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3982  if (getMask())
3983  p << ", " << getMask();
3984  printTransferAttrs(p, *this);
3985  p << " : " << getShapedType() << ", " << getVectorType();
3986 }
3987 
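/// Illustrative example (for exposition only): the inferred mask is shaped
/// according to the permuted, memory-side ordering rather than the result
/// vector. For permutation_map = affine_map<(d0, d1) -> (d1, d0)> and a
/// vector type vector<2x4xf32>, the inverted permutation yields the mask type
/// vector<4x2xi1>.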
3988 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
3989  AffineMap permMap) {
3990  auto i1Type = IntegerType::get(permMap.getContext(), 1);
3991  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3992  assert(invPermMap && "Inversed permutation map couldn't be computed");
3993  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3994 
3995  SmallVector<bool> scalableDims =
3996  applyPermutationMap(invPermMap, vecType.getScalableDims());
3997 
3998  return VectorType::get(maskShape, i1Type, scalableDims);
3999 }
4000 
4001 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4002  auto &builder = parser.getBuilder();
4003  SMLoc typesLoc;
4004  OpAsmParser::UnresolvedOperand sourceInfo;
4006  OpAsmParser::UnresolvedOperand paddingInfo;
4007  SmallVector<Type, 2> types;
4009  // Parsing with support for paddingValue.
4010  if (parser.parseOperand(sourceInfo) ||
4011  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4012  parser.parseComma() || parser.parseOperand(paddingInfo))
4013  return failure();
4014  ParseResult hasMask = parser.parseOptionalComma();
4015  if (hasMask.succeeded()) {
4016  if (parser.parseOperand(maskInfo))
4017  return failure();
4018  }
4019  if (parser.parseOptionalAttrDict(result.attributes) ||
4020  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4021  return failure();
4022  if (types.size() != 2)
4023  return parser.emitError(typesLoc, "requires two types");
4024  auto indexType = builder.getIndexType();
4025  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4026  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4027  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4028  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4029  if (!vectorType)
4030  return parser.emitError(typesLoc, "requires vector type");
4031  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4032  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4033  AffineMap permMap;
4034  if (!permMapAttr) {
4035  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4036  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4037  } else {
4038  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4039  }
4040  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4041  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4042  if (!inBoundsAttr) {
4043  result.addAttribute(inBoundsAttrName,
4044  builder.getBoolArrayAttr(
4045  SmallVector<bool>(permMap.getNumResults(), false)));
4046  }
4047  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4048  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4049  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4050  result.operands))
4051  return failure();
4052  if (hasMask.succeeded()) {
4053  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4054  return parser.emitError(
4055  maskInfo.location, "does not support masks with vector element type");
4056  if (vectorType.getRank() != permMap.getNumResults()) {
4057  return parser.emitError(typesLoc,
4058  "expected the same rank for the vector and the "
4059  "results of the permutation map");
4060  }
4061  // Instead of adding the mask type as an op type, compute it based on the
4062  // vector type and the permutation map (to keep the type signature small).
4063  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4064  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4065  return failure();
4066  }
4067  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4068  builder.getDenseI32ArrayAttr(
4069  {1, static_cast<int32_t>(indexInfo.size()), 1,
4070  static_cast<int32_t>(hasMask.succeeded())}));
4071  return parser.addTypeToList(vectorType, result.types);
4072 }
4073 
4074 LogicalResult TransferReadOp::verify() {
4075  // Consistency of elemental types in source and vector.
4076  ShapedType shapedType = getShapedType();
4077  VectorType vectorType = getVectorType();
4078  VectorType maskType = getMaskType();
4079  auto paddingType = getPadding().getType();
4080  auto permutationMap = getPermutationMap();
4081  VectorType inferredMaskType =
4082  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4083  : VectorType();
4084  auto sourceElementType = shapedType.getElementType();
4085 
4086  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4087  return emitOpError("requires ") << shapedType.getRank() << " indices";
4088 
4089  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4090  shapedType, vectorType, maskType,
4091  inferredMaskType, permutationMap, getInBounds())))
4092  return failure();
4093 
4094  if (auto sourceVectorElementType =
4095  llvm::dyn_cast<VectorType>(sourceElementType)) {
4096  // Source has vector element type.
4097  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4098  if (sourceVectorElementType != paddingType)
4099  return emitOpError(
4100  "requires source element type and padding type to match.");
4101 
4102  } else {
4103  // Check that 'paddingType' is valid to store in a vector type.
4104  if (!VectorType::isValidElementType(paddingType))
4105  return emitOpError("requires valid padding vector elemental type");
4106 
4107  // Check that padding type and vector element types match.
4108  if (paddingType != sourceElementType)
4109  return emitOpError(
4110  "requires formal padding and source of the same elemental type");
4111  }
4112 
4113  return verifyPermutationMap(permutationMap,
4114  [&](Twine t) { return emitOpError(t); });
4115 }
4116 
4117 // MaskableOpInterface methods.
4118 
4119 /// Returns the mask type expected by this operation. Mostly used for
4120 /// verification purposes. It requires the operation to be vectorized.
4121 Type TransferReadOp::getExpectedMaskType() {
4122  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4123 }
4124 
4125 template <typename TransferOp>
4126 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4127  // TODO: support more aggressive createOrFold on:
4128  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4129  if (op.getShapedType().isDynamicDim(indicesIdx))
4130  return false;
4131  Value index = op.getIndices()[indicesIdx];
4132  std::optional<int64_t> cstOp = getConstantIntValue(index);
4133  if (!cstOp.has_value())
4134  return false;
4135 
4136  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4137  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4138 
4139  return cstOp.value() + vectorSize <= sourceSize;
4140 }
4141 
4142 template <typename TransferOp>
4143 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4144  // TODO: support 0-d corner case.
4145  // TODO: Be less conservative.
4146  if (op.getTransferRank() == 0)
4147  return failure();
4148  AffineMap permutationMap = op.getPermutationMap();
4149  bool changed = false;
4150  SmallVector<bool, 4> newInBounds;
4151  newInBounds.reserve(op.getTransferRank());
4152  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4153  // Already marked as in-bounds, nothing to see here.
4154  if (op.isDimInBounds(i)) {
4155  newInBounds.push_back(true);
4156  continue;
4157  }
4158  // Currently out-of-bounds, check whether we can statically determine it is
4159  // inBounds.
4160  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4161  assert(dimExpr && "Broadcast dims must be in-bounds");
4162  auto inBounds =
4163  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
4164  newInBounds.push_back(inBounds);
4165  // We commit the pattern if it is "more inbounds".
4166  changed |= inBounds;
4167  }
4168  if (!changed)
4169  return failure();
4170  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4171  OpBuilder b(op.getContext());
4172  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4173  return success();
4174 }
4175 
4176 template <typename TransferOp>
4177 static LogicalResult foldTransferFullMask(TransferOp op) {
4178  auto mask = op.getMask();
4179  if (!mask)
4180  return failure();
4181 
4182  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4183  return failure();
4184 
4185  op.getMaskMutable().clear();
4186  return success();
4187 }
4188 
4189 /// ```
4190 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4191 /// : vector<1x4xf32>, tensor<4x4xf32>
4192 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4193 /// : tensor<4x4xf32>, vector<1x4xf32>
4194 /// ```
4195 /// -> Folds into
4196 /// ```
4197 /// %v0
4198 /// ```
4199 static Value foldRAW(TransferReadOp readOp) {
4200  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4201  return {};
4202  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4203  while (defWrite) {
4204  if (checkSameValueRAW(defWrite, readOp))
4205  return defWrite.getVector();
4207  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4208  cast<VectorTransferOpInterface>(readOp.getOperation())))
4209  break;
4210  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4211  }
4212  return {};
4213 }
4214 
4215 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4216  if (Value vec = foldRAW(*this))
4217  return vec;
4218  /// transfer_read(memrefcast) -> transfer_read
4219  if (succeeded(foldTransferInBoundsAttribute(*this)))
4220  return getResult();
4221  if (succeeded(foldTransferFullMask(*this)))
4222  return getResult();
4223  if (succeeded(memref::foldMemRefCast(*this)))
4224  return getResult();
4225  if (succeeded(tensor::foldTensorCast(*this)))
4226  return getResult();
4227  return OpFoldResult();
4228 }
4229 
4230 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4231  return llvm::to_vector<4>(getVectorType().getShape());
4232 }
4233 
4234 void TransferReadOp::getEffects(
4235  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4236  &effects) {
4237  if (llvm::isa<MemRefType>(getShapedType()))
4238  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4239  SideEffects::DefaultResource::get());
4240 }
4241 
4242 namespace {
4243 /// Store to load forwarding for transfer operations with permutation maps.
4244 /// Even if the permutation maps are different we can still propagate the store
4245 /// into the load if the size of the dimensions read and written match. Then we
4246 /// can replace the transfer_read + transfer_write by vector.broadcast and
4247 /// vector.transpose.
4248 /// Example:
4249 /// ```
4250 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4251 /// {in_bounds = [true, true],
4252 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4253 /// vector<4x1xf32>, tensor<4x4x4xf32>
4254 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4255 /// {in_bounds = [true, true, true, true],
4256 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4257 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4258 /// ```
4259 /// To:
4260 /// ```
4261 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4262 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4263 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4264 /// ```
4265 struct TransferReadAfterWriteToBroadcast
4266  : public OpRewritePattern<TransferReadOp> {
4267  using OpRewritePattern::OpRewritePattern;
4268 
4269  LogicalResult matchAndRewrite(TransferReadOp readOp,
4270  PatternRewriter &rewriter) const override {
4271  if (readOp.hasOutOfBoundsDim() ||
4272  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4273  return failure();
4274  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4275  if (!defWrite)
4276  return failure();
4277  // TODO: If the written transfer chunk is a superset of the read transfer
4278  // chunk we could do an extract_strided_slice.
4279  if (readOp.getTransferChunkAccessed() !=
4280  defWrite.getTransferChunkAccessed())
4281  return failure();
4282  // TODO: Support cases where a dim is explicitly written but implicitly
4283  // read (i.e., a unit dim that is rank reduced).
4284  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4285  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4286  return failure();
4287  if (readOp.getIndices() != defWrite.getIndices() ||
4288  readOp.getMask() != defWrite.getMask())
4289  return failure();
4290  Value vec = defWrite.getVector();
4291  // TODO: loop through the chain of transfer_write if we can prove that they
4292  // don't overlap with the transfer_read. This requires improving
4293  // `isDisjointTransferIndices` helper.
4294  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4295  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4296  AffineMap map = readMap.compose(writeMap);
4297  if (map.getNumResults() == 0)
4298  return failure();
4299  // Calculate the permutation to apply to go from the vector stored to the
4300  // vector read.
4301  SmallVector<unsigned> permutation;
4302  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4303  return failure();
4304 
4305  Location loc = readOp.getLoc();
4306  // Calculate the broadcast shape by applying the reverse permutation to the
4307  // final shape we want.
4308  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4309  SmallVector<int64_t> broadcastShape(destShape.size());
4310  SmallVector<bool> broadcastScalableFlags(destShape.size());
4311  for (const auto &pos : llvm::enumerate(permutation)) {
4312  broadcastShape[pos.value()] = destShape[pos.index()];
4313  broadcastScalableFlags[pos.value()] =
4314  readOp.getVectorType().getScalableDims()[pos.index()];
4315  }
4316  VectorType broadcastedType = VectorType::get(
4317  broadcastShape, defWrite.getVectorType().getElementType(),
4318  broadcastScalableFlags);
4319  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4320  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4321  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4322  transposePerm);
4323  return success();
4324  }
4325 };
4326 } // namespace
4327 
4328 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4329  MLIRContext *context) {
4330  results.add<TransferReadAfterWriteToBroadcast>(context);
4331 }
4332 
4333 //===----------------------------------------------------------------------===//
4334 // TransferWriteOp
4335 //===----------------------------------------------------------------------===//
4336 
4337 /// 1. Builder with type inference.
4338 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4339  Value vector, Value dest, ValueRange indices,
4340  AffineMapAttr permutationMapAttr,
4341  /*optional*/ Value mask,
4342  /*optional*/ ArrayAttr inBoundsAttr) {
4343  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4344  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4345  mask, inBoundsAttr);
4346 }
4347 
4348 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4349 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4350  Value vector, Value dest, ValueRange indices,
4351  AffineMapAttr permutationMapAttr,
4352  /*optional*/ ArrayAttr inBoundsAttr) {
4353  build(builder, result, vector, dest, indices, permutationMapAttr,
4354  /*mask=*/Value(), inBoundsAttr);
4355 }
4356 
4357 /// 3. Builder with type inference that sets an empty mask (variant without
4358 /// attrs)
4359 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4360  Value vector, Value dest, ValueRange indices,
4361  AffineMap permutationMap,
4362  std::optional<ArrayRef<bool>> inBounds) {
4363  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4364  auto inBoundsAttr =
4365  (inBounds && !inBounds.value().empty())
4366  ? builder.getBoolArrayAttr(inBounds.value())
4367  : builder.getBoolArrayAttr(SmallVector<bool>(
4368  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4369  build(builder, result, vector, dest, indices, permutationMapAttr,
4370  /*mask=*/Value(), inBoundsAttr);
4371 }
4372 
4373 /// 4. Builder with type inference that sets an empty mask and sets permutation
4374 /// map to 'getMinorIdentityMap'.
4375 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4376  Value vector, Value dest, ValueRange indices,
4377  std::optional<ArrayRef<bool>> inBounds) {
4378  auto vectorType = llvm::cast<VectorType>(vector.getType());
4379  AffineMap permutationMap = getTransferMinorIdentityMap(
4380  llvm::cast<ShapedType>(dest.getType()), vectorType);
4381  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4382 }
4383 
4384 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4385  OperationState &result) {
4386  auto &builder = parser.getBuilder();
4387  SMLoc typesLoc;
4388  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4389  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4390  SmallVector<Type, 2> types;
4391  OpAsmParser::UnresolvedOperand maskInfo;
4392  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4393  parser.parseOperand(sourceInfo) ||
4394  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4395  return failure();
4396  ParseResult hasMask = parser.parseOptionalComma();
4397  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4398  return failure();
4399  if (parser.parseOptionalAttrDict(result.attributes) ||
4400  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4401  return failure();
4402  if (types.size() != 2)
4403  return parser.emitError(typesLoc, "requires two types");
4404  auto indexType = builder.getIndexType();
4405  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4406  if (!vectorType)
4407  return parser.emitError(typesLoc, "requires vector type");
4408  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4409  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4410  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4411  auto permMapAttrName =
4412  TransferWriteOp::getPermutationMapAttrName(result.name);
4413  auto permMapAttr = result.attributes.get(permMapAttrName);
4414  AffineMap permMap;
4415  if (!permMapAttr) {
4416  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4417  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4418  } else {
4419  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4420  }
4421  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4422  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4423  if (!inBoundsAttr) {
4424  result.addAttribute(inBoundsAttrName,
4425  builder.getBoolArrayAttr(
4426  SmallVector<bool>(permMap.getNumResults(), false)));
4427  }
4428  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4429  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4430  parser.resolveOperands(indexInfo, indexType, result.operands))
4431  return failure();
4432  if (hasMask.succeeded()) {
4433  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4434  return parser.emitError(
4435  maskInfo.location, "does not support masks with vector element type");
4436  if (vectorType.getRank() != permMap.getNumResults()) {
4437  return parser.emitError(typesLoc,
4438  "expected the same rank for the vector and the "
4439  "results of the permutation map");
4440  }
4441  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4442  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4443  return failure();
4444  }
4445  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4446  builder.getDenseI32ArrayAttr(
4447  {1, 1, static_cast<int32_t>(indexInfo.size()),
4448  static_cast<int32_t>(hasMask.succeeded())}));
4449  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4450  parser.addTypeToList(shapedType, result.types));
4451 }
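// For reference, a sketch of the textual forms this parser accepts
// (hypothetical operands):
//   vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]}
//       : vector<1x4xf32>, tensor<4x4xf32>
//   vector.transfer_write %v, %m0[%c0], %mask : vector<4xf32>, memref<16xf32>
// The optional mask operand follows the indices, and the trailing types are
// always the vector type and the shaped (memref or tensor) type.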
4452 
4453 void TransferWriteOp::print(OpAsmPrinter &p) {
4454  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4455  if (getMask())
4456  p << ", " << getMask();
4457  printTransferAttrs(p, *this);
4458  p << " : " << getVectorType() << ", " << getShapedType();
4459 }
4460 
4461 LogicalResult TransferWriteOp::verify() {
4462  // Consistency of elemental types in shape and vector.
4463  ShapedType shapedType = getShapedType();
4464  VectorType vectorType = getVectorType();
4465  VectorType maskType = getMaskType();
4466  auto permutationMap = getPermutationMap();
4467  VectorType inferredMaskType =
4468  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4469  : VectorType();
4470 
4471  if (llvm::size(getIndices()) != shapedType.getRank())
4472  return emitOpError("requires ") << shapedType.getRank() << " indices";
4473 
4474  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4475  // as the semantics is unclear. This can be revisited later if necessary.
4476  if (hasBroadcastDim())
4477  return emitOpError("should not have broadcast dimensions");
4478 
4479  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4480  shapedType, vectorType, maskType,
4481  inferredMaskType, permutationMap, getInBounds())))
4482  return failure();
4483 
4484  return verifyPermutationMap(permutationMap,
4485  [&](Twine t) { return emitOpError(t); });
4486 }
4487 
4488 // MaskableOpInterface methods.
4489 
4490 /// Returns the mask type expected by this operation. Mostly used for
4491 /// verification purposes.
4492 Type TransferWriteOp::getExpectedMaskType() {
4493  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4494 }
4495 
4496 /// Fold:
4497 /// ```
4498 /// %t1 = ...
4499 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4500 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4501 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4502 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4503 /// ```
4504 ///
4505 /// into:
4506 ///
4507 /// ```
4508 /// %t0
4509 /// ```
4510 ///
4511 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4512 /// block argument or has side effects.
4513 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4514  ArrayRef<Attribute>,
4515  SmallVectorImpl<OpFoldResult> &results) {
4516  // TODO: support 0-d corner case.
4517  if (write.getTransferRank() == 0)
4518  return failure();
4519  auto rankedTensorType =
4520  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4521  // If not operating on tensors, bail.
4522  if (!rankedTensorType)
4523  return failure();
4524  // If no read, bail.
4525  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4526  if (!read)
4527  return failure();
4528  // TODO: support 0-d corner case.
4529  if (read.getTransferRank() == 0)
4530  return failure();
4531  // For now, only accept minor identity maps. Future: also accept maps whose composition is a minor identity.
4532  if (!read.getPermutationMap().isMinorIdentity() ||
4533  !write.getPermutationMap().isMinorIdentity())
4534  return failure();
4535  // Bail on mismatching ranks.
4536  if (read.getTransferRank() != write.getTransferRank())
4537  return failure();
4538  // Bail on potential out-of-bounds accesses.
4539  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4540  return failure();
4541  // Tensor types must be the same.
4542  if (read.getSource().getType() != rankedTensorType)
4543  return failure();
4544  // Vector types must be the same.
4545  if (read.getVectorType() != write.getVectorType())
4546  return failure();
4547  // Vector and Tensor shapes must match.
4548  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4549  return failure();
4550  // Bail if any index is nonzero.
4551  auto isNotConstantZero = [](Value v) {
4552  auto cstOp = getConstantIntValue(v);
4553  return !cstOp.has_value() || cstOp.value() != 0;
4554  };
4555  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4556  llvm::any_of(write.getIndices(), isNotConstantZero))
4557  return failure();
4558  // Success.
4559  results.push_back(read.getSource());
4560  return success();
4561 }
4562 
4563 static bool checkSameValueWAR(vector::TransferReadOp read,
4564  vector::TransferWriteOp write) {
4565  return read.getSource() == write.getSource() &&
4566  read.getIndices() == write.getIndices() &&
4567  read.getPermutationMap() == write.getPermutationMap() &&
4568  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4569  !write.getMask();
4570 }
4571 /// Fold transfer_write write after read:
4572 /// ```
4573 /// %t0 = ...
4574 /// %v = vector.transfer_read %t0[%c0...] :
4575 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4576 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4577 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4578 /// ```
4579 ///
4580 /// into:
4581 ///
4582 /// ```
4583 /// %t0
4584 /// ```
4585 static LogicalResult foldWAR(TransferWriteOp write,
4586  SmallVectorImpl<OpFoldResult> &results) {
4587  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4588  return failure();
4589  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4590  if (!read)
4591  return failure();
4592 
4593  if (!checkSameValueWAR(read, write))
4594  return failure();
4595  results.push_back(read.getSource());
4596  return success();
4597 }
4598 
4599 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4600  SmallVectorImpl<OpFoldResult> &results) {
4601  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4602  return success();
4603  if (succeeded(foldWAR(*this, results)))
4604  return success();
4605  if (succeeded(foldTransferInBoundsAttribute(*this)))
4606  return success();
4607  if (succeeded(foldTransferFullMask(*this)))
4608  return success();
4609  return memref::foldMemRefCast(*this);
4610 }
4611 
4612 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4613  return llvm::to_vector<4>(getVectorType().getShape());
4614 }
4615 
4616 void TransferWriteOp::getEffects(
4617  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4618  &effects) {
4619  if (llvm::isa<MemRefType>(getShapedType()))
4620  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4621  SideEffects::DefaultResource::get());
4622 }
4623 
4624 namespace {
4625 /// Remove a dead transfer write from the SSA chain so that it can be
4626 /// eliminated by DCE.
4627 /// ```
4628 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4629 /// : vector<1x4xf32>, tensor<4x4xf32>
4630 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4631 /// : vector<1x4xf32>, tensor<4x4xf32>
4632 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4633 /// : vector<1x4xf32>, tensor<4x4xf32>
4634 /// ```
4635 ///
4636 /// into:
4637 ///
4638 /// ```
4639 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4640 /// : vector<1x4xf32>, tensor<4x4xf32>
4641 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4642 /// : vector<1x4xf32>, tensor<4x4xf32>
4643 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4644 /// : vector<1x4xf32>, tensor<4x4xf32>
4645 /// ```
4646 ///
4647 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4648 /// any other uses.
4649 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4650 public:
4651  using OpRewritePattern::OpRewritePattern;
4652  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4653  PatternRewriter &rewriter) const override {
4654  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4655  return failure();
4656  vector::TransferWriteOp writeToModify = writeOp;
4657 
4658  auto defWrite =
4659  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4660  while (defWrite) {
4661  if (checkSameValueWAW(writeOp, defWrite)) {
4662  rewriter.modifyOpInPlace(writeToModify, [&]() {
4663  writeToModify.getSourceMutable().assign(defWrite.getSource());
4664  });
4665  return success();
4666  }
4667  if (!isDisjointTransferIndices(
4668  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4669  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4670  break;
4671  // If the previous write op doesn't have any other use we can safely look
4672  // at the previous store to see if it can be removed.
4673  if (!defWrite->hasOneUse())
4674  break;
4675  writeToModify = defWrite;
4676  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4677  }
4678  return failure();
4679  }
4680 };
4681 
4682 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4683 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4684 /// overwritten and inserted into another tensor. After this rewrite, the
4685 /// operations bufferize in-place since all of them work on the same slice.
4686 ///
4687 /// For example:
4688 /// ```mlir
4689 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4690 /// : vector<8x16xf32>, tensor<8x16xf32>
4691 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4692 /// : tensor<8x16xf32> to tensor<?x?xf32>
4693 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4694 /// : tensor<?x?xf32> into tensor<27x37xf32>
4695 /// ```
4696 /// folds to
4697 /// ```mlir
4698 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4699 /// : tensor<27x37xf32> to tensor<?x?xf32>
4700 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4701 /// : vector<8x16xf32>, tensor<?x?xf32>
4702 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4703 /// : tensor<?x?xf32> into tensor<27x37xf32>
4704 /// ```
4705 struct SwapExtractSliceOfTransferWrite
4706  : public OpRewritePattern<tensor::InsertSliceOp> {
4707 public:
4708  using OpRewritePattern::OpRewritePattern;
4709 
4710  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4711  PatternRewriter &rewriter) const override {
4712  if (!insertOp.hasUnitStride())
4713  return failure();
4714  auto extractOp =
4715  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4716  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4717  return failure();
4718  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4719  if (!transferOp || !transferOp->hasOneUse())
4720  return failure();
4721 
4722  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4723  // rank-reducing.
4724  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4725  return rewriter.notifyMatchFailure(insertOp,
4726  "use-def chain is rank-reducing");
4727  }
4728 
4729  // Fail if tensor::ExtractSliceOp has non-zero offset.
4730  if (!extractOp.hasZeroOffset()) {
4731  return rewriter.notifyMatchFailure(insertOp,
4732  "ExtractSliceOp has non-zero offset");
4733  }
4734 
4735  // Fail if vector::TransferWriteOp has non-zero offset.
4736  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4737  return getConstantIntValue(value) == static_cast<int64_t>(0);
4738  })) {
4739  return rewriter.notifyMatchFailure(insertOp,
4740  "TransferWriteOp has non-zero offset");
4741  }
4742 
4743  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4744  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4745  return rewriter.notifyMatchFailure(
4746  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4747  }
4748 
4749  for (auto [insertSize, extractSize] :
4750  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4751  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4752  return rewriter.notifyMatchFailure(
4753  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4754  }
4755  }
4756 
4757  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4758  assert(transferOp.getVectorType().hasStaticShape() &&
4759  "expected vector to have a static shape");
4760  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4761  SmallVector<int64_t> resultShape = applyPermutationMap(
4762  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4763  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4764  return rewriter.notifyMatchFailure(
4765  insertOp, "TransferWriteOp may not write the full tensor.");
4766  }
4767 
4768  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4769  // Set all in_bounds to false and let the folder infer them.
4770  SmallVector<bool> newInBounds(vectorShape.size(), false);
4771  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4772  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4773  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4774  insertOp.getMixedStrides());
4775  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4776  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4777  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4778  rewriter.getBoolArrayAttr(newInBounds));
4779  rewriter.modifyOpInPlace(insertOp, [&]() {
4780  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4781  });
4782  return success();
4783  }
4784 };
4785 
4786 } // namespace
4787 
4788 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4789  MLIRContext *context) {
4790  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4791 }
4792 
4793 //===----------------------------------------------------------------------===//
4794 // LoadOp
4795 //===----------------------------------------------------------------------===//
4796 
4797 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4798  MemRefType memRefTy) {
4799  if (!isLastMemrefDimUnitStride(memRefTy))
4800  return op->emitOpError("most minor memref dim must have unit stride");
4801  return success();
4802 }
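// Illustrative example (hypothetical types): a base of type memref<4x8xf32>
// (identity layout, so the innermost dimension has stride 1) passes this
// check, while memref<4x8xf32, strided<[1, 4]>> is rejected because its
// innermost dimension has stride 4.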
4803 
4804 LogicalResult vector::LoadOp::verify() {
4805  VectorType resVecTy = getVectorType();
4806  MemRefType memRefTy = getMemRefType();
4807 
4808  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4809  return failure();
4810 
4811  // Checks for vector memrefs.
4812  Type memElemTy = memRefTy.getElementType();
4813  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4814  if (memVecTy != resVecTy)
4815  return emitOpError("base memref and result vector types should match");
4816  memElemTy = memVecTy.getElementType();
4817  }
4818 
4819  if (resVecTy.getElementType() != memElemTy)
4820  return emitOpError("base and result element types should match");
4821  if (llvm::size(getIndices()) != memRefTy.getRank())
4822  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4823  return success();
4824 }
4825 
4826 OpFoldResult LoadOp::fold(FoldAdaptor) {
4827  if (succeeded(memref::foldMemRefCast(*this)))
4828  return getResult();
4829  return OpFoldResult();
4830 }
4831 
4832 //===----------------------------------------------------------------------===//
4833 // StoreOp
4834 //===----------------------------------------------------------------------===//
4835 
4836 LogicalResult vector::StoreOp::verify() {
4837  VectorType valueVecTy = getVectorType();
4838  MemRefType memRefTy = getMemRefType();
4839 
4840  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4841  return failure();
4842 
4843  // Checks for vector memrefs.
4844  Type memElemTy = memRefTy.getElementType();
4845  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4846  if (memVecTy != valueVecTy)
4847  return emitOpError(
4848  "base memref and valueToStore vector types should match");
4849  memElemTy = memVecTy.getElementType();
4850  }
4851 
4852  if (valueVecTy.getElementType() != memElemTy)
4853  return emitOpError("base and valueToStore element type should match");
4854  if (llvm::size(getIndices()) != memRefTy.getRank())
4855  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4856  return success();
4857 }
4858 
4859 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4860  SmallVectorImpl<OpFoldResult> &results) {
4861  return memref::foldMemRefCast(*this);
4862 }
4863 
4864 //===----------------------------------------------------------------------===//
4865 // MaskedLoadOp
4866 //===----------------------------------------------------------------------===//
4867 
4868 LogicalResult MaskedLoadOp::verify() {
4869  VectorType maskVType = getMaskVectorType();
4870  VectorType passVType = getPassThruVectorType();
4871  VectorType resVType = getVectorType();
4872  MemRefType memType = getMemRefType();
4873 
4874  if (resVType.getElementType() != memType.getElementType())
4875  return emitOpError("base and result element type should match");
4876  if (llvm::size(getIndices()) != memType.getRank())
4877  return emitOpError("requires ") << memType.getRank() << " indices";
4878  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4879  return emitOpError("expected result dim to match mask dim");
4880  if (resVType != passVType)
4881  return emitOpError("expected pass_thru of same type as result type");
4882  return success();
4883 }
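// Illustrative example (hypothetical IR) of a masked load that satisfies the
// constraints above: mask, pass-through and result share the leading
// dimension and element types are consistent with the base memref:
//   %v = vector.maskedload %base[%c0], %mask, %pass
//       : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>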
4884 
4885 namespace {
4886 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4887 public:
4888  using OpRewritePattern::OpRewritePattern;
4889  LogicalResult matchAndRewrite(MaskedLoadOp load,
4890  PatternRewriter &rewriter) const override {
4891  switch (getMaskFormat(load.getMask())) {
4892  case MaskFormat::AllTrue:
4893  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4894  load, load.getType(), load.getBase(), load.getIndices());
4895  return success();
4896  case MaskFormat::AllFalse:
4897  rewriter.replaceOp(load, load.getPassThru());
4898  return success();
4899  case MaskFormat::Unknown:
4900  return failure();
4901  }
4902  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4903  }
4904 };
4905 } // namespace
4906 
4907 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4908  MLIRContext *context) {
4909  results.add<MaskedLoadFolder>(context);
4910 }
4911 
4912 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
4913  if (succeeded(memref::foldMemRefCast(*this)))
4914  return getResult();
4915  return OpFoldResult();
4916 }
4917 
4918 //===----------------------------------------------------------------------===//
4919 // MaskedStoreOp
4920 //===----------------------------------------------------------------------===//
4921 
4922 LogicalResult MaskedStoreOp::verify() {
4923  VectorType maskVType = getMaskVectorType();
4924  VectorType valueVType = getVectorType();
4925  MemRefType memType = getMemRefType();
4926 
4927  if (valueVType.getElementType() != memType.getElementType())
4928  return emitOpError("base and valueToStore element type should match");
4929  if (llvm::size(getIndices()) != memType.getRank())
4930  return emitOpError("requires ") << memType.getRank() << " indices";
4931  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4932  return emitOpError("expected valueToStore dim to match mask dim");
4933  return success();
4934 }
4935 
4936 namespace {
4937 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4938 public:
4939  using OpRewritePattern::OpRewritePattern;
4940  LogicalResult matchAndRewrite(MaskedStoreOp store,
4941  PatternRewriter &rewriter) const override {
4942  switch (getMaskFormat(store.getMask())) {
4943  case MaskFormat::AllTrue:
4944  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4945  store, store.getValueToStore(), store.getBase(), store.getIndices());
4946  return success();
4947  case MaskFormat::AllFalse:
4948  rewriter.eraseOp(store);
4949  return success();
4950  case MaskFormat::Unknown:
4951  return failure();
4952  }
4953  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4954  }
4955 };
4956 } // namespace
4957 
4958 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4959  MLIRContext *context) {
4960  results.add<MaskedStoreFolder>(context);
4961 }
4962 
4963 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4964  SmallVectorImpl<OpFoldResult> &results) {
4965  return memref::foldMemRefCast(*this);
4966 }
4967 
4968 //===----------------------------------------------------------------------===//
4969 // GatherOp
4970 //===----------------------------------------------------------------------===//
4971 
4972 LogicalResult GatherOp::verify() {
4973  VectorType indVType = getIndexVectorType();
4974  VectorType maskVType = getMaskVectorType();
4975  VectorType resVType = getVectorType();
4976  ShapedType baseType = getBaseType();
4977 
4978  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4979  return emitOpError("requires base to be a memref or ranked tensor type");
4980 
4981  if (resVType.getElementType() != baseType.getElementType())
4982  return emitOpError("base and result element type should match");
4983  if (llvm::size(getIndices()) != baseType.getRank())
4984  return emitOpError("requires ") << baseType.getRank() << " indices";
4985  if (resVType.getShape() != indVType.getShape())
4986  return emitOpError("expected result dim to match indices dim");
4987  if (resVType.getShape() != maskVType.getShape())
4988  return emitOpError("expected result dim to match mask dim");
4989  if (resVType != getPassThruVectorType())
4990  return emitOpError("expected pass_thru of same type as result type");
4991  return success();
4992 }
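// Illustrative example (hypothetical IR) of a gather that satisfies the
// constraints above: the index, mask and pass-through vectors all share the
// result shape:
//   %g = vector.gather %base[%c0][%idxs], %mask, %pass
//       : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
//         into vector<4xf32>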
4993 
4994 // MaskableOpInterface methods.
4995 
4996 /// Returns the mask type expected by this operation. Mostly used for
4997 /// verification purposes. It requires the operation to be vectorized.
4998 Type GatherOp::getExpectedMaskType() {
4999  auto vecType = this->getIndexVectorType();
5000  return VectorType::get(vecType.getShape(),
5001  IntegerType::get(vecType.getContext(), /*width=*/1),
5002  vecType.getScalableDims());
5003 }
5004 
5005 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5006  return llvm::to_vector<4>(getVectorType().getShape());
5007 }
5008 
5009 namespace {
5010 class GatherFolder final : public OpRewritePattern<GatherOp> {
5011 public:
5012  using OpRewritePattern::OpRewritePattern;
5013  LogicalResult matchAndRewrite(GatherOp gather,
5014  PatternRewriter &rewriter) const override {
5015  switch (getMaskFormat(gather.getMask())) {
5016  case MaskFormat::AllTrue:
5017  return failure(); // no unmasked equivalent
5018  case MaskFormat::AllFalse:
5019  rewriter.replaceOp(gather, gather.getPassThru());
5020  return success();
5021  case MaskFormat::Unknown:
5022  return failure();
5023  }
5024  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5025  }
5026 };
5027 } // namespace
5028 
5029 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5030  MLIRContext *context) {
5031  results.add<GatherFolder>(context);
5032 }
5033 
5034 //===----------------------------------------------------------------------===//
5035 // ScatterOp
5036 //===----------------------------------------------------------------------===//
5037 
5038 LogicalResult ScatterOp::verify() {
5039  VectorType indVType = getIndexVectorType();
5040  VectorType maskVType = getMaskVectorType();
5041  VectorType valueVType = getVectorType();
5042  MemRefType memType = getMemRefType();
5043 
5044  if (valueVType.getElementType() != memType.getElementType())
5045  return emitOpError("base and valueToStore element type should match");
5046  if (llvm::size(getIndices()) != memType.getRank())
5047  return emitOpError("requires ") << memType.getRank() << " indices";
5048  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5049  return emitOpError("expected valueToStore dim to match indices dim");
5050  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5051  return emitOpError("expected valueToStore dim to match mask dim");
5052  return success();
5053 }
5054 
5055 namespace {
5056 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5057 public:
5058  using OpRewritePattern::OpRewritePattern;
5059  LogicalResult matchAndRewrite(ScatterOp scatter,
5060  PatternRewriter &rewriter) const override {
5061  switch (getMaskFormat(scatter.getMask())) {
5062  case MaskFormat::AllTrue:
5063  return failure(); // no unmasked equivalent
5064  case MaskFormat::AllFalse:
5065  rewriter.eraseOp(scatter);
5066  return success();
5067  case MaskFormat::Unknown:
5068  return failure();
5069  }
5070  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5071  }
5072 };
5073 } // namespace
5074 
5075 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5076  MLIRContext *context) {
5077  results.add<ScatterFolder>(context);
5078 }
5079 
5080 //===----------------------------------------------------------------------===//
5081 // ExpandLoadOp
5082 //===----------------------------------------------------------------------===//
5083 
5084 LogicalResult ExpandLoadOp::verify() {
5085  VectorType maskVType = getMaskVectorType();
5086  VectorType passVType = getPassThruVectorType();
5087  VectorType resVType = getVectorType();
5088  MemRefType memType = getMemRefType();
5089 
5090  if (resVType.getElementType() != memType.getElementType())
5091  return emitOpError("base and result element type should match");
5092  if (llvm::size(getIndices()) != memType.getRank())
5093  return emitOpError("requires ") << memType.getRank() << " indices";
5094  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5095  return emitOpError("expected result dim to match mask dim");
5096  if (resVType != passVType)
5097  return emitOpError("expected pass_thru of same type as result type");
5098  return success();
5099 }
5100 
5101 namespace {
5102 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5103 public:
5104  using OpRewritePattern::OpRewritePattern;
5105  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5106  PatternRewriter &rewriter) const override {
5107  switch (getMaskFormat(expand.getMask())) {
5108  case MaskFormat::AllTrue:
5109  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5110  expand, expand.getType(), expand.getBase(), expand.getIndices());
5111  return success();
5112  case MaskFormat::AllFalse:
5113  rewriter.replaceOp(expand, expand.getPassThru());
5114  return success();
5115  case MaskFormat::Unknown:
5116  return failure();
5117  }
5118  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5119  }
5120 };
5121 } // namespace
5122 
5123 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5124  MLIRContext *context) {
5125  results.add<ExpandLoadFolder>(context);
5126 }
5127 
5128 //===----------------------------------------------------------------------===//
5129 // CompressStoreOp
5130 //===----------------------------------------------------------------------===//
5131 
5132 LogicalResult CompressStoreOp::verify() {
5133  VectorType maskVType = getMaskVectorType();
5134  VectorType valueVType = getVectorType();
5135  MemRefType memType = getMemRefType();
5136 
5137  if (valueVType.getElementType() != memType.getElementType())
5138  return emitOpError("base and valueToStore element type should match");
5139  if (llvm::size(getIndices()) != memType.getRank())
5140  return emitOpError("requires ") << memType.getRank() << " indices";
5141  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5142  return emitOpError("expected valueToStore dim to match mask dim");
5143  return success();
5144 }
5145 
5146 namespace {
5147 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5148 public:
5149  using OpRewritePattern::OpRewritePattern;
5150  LogicalResult matchAndRewrite(CompressStoreOp compress,
5151  PatternRewriter &rewriter) const override {
5152  switch (getMaskFormat(compress.getMask())) {
5153  case MaskFormat::AllTrue:
5154  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5155  compress, compress.getValueToStore(), compress.getBase(),
5156  compress.getIndices());
5157  return success();
5158  case MaskFormat::AllFalse:
5159  rewriter.eraseOp(compress);
5160  return success();
5161  case MaskFormat::Unknown:
5162  return failure();
5163  }
5164  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5165  }
5166 };
5167 } // namespace
5168 
5169 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5170  MLIRContext *context) {
5171  results.add<CompressStoreFolder>(context);
5172 }
5173 
5174 //===----------------------------------------------------------------------===//
5175 // ShapeCastOp
5176 //===----------------------------------------------------------------------===//
5177 
5178 /// Returns true if each element of 'a' is equal to the product of a contiguous
5179 /// sequence of the elements of 'b'. Returns false otherwise.
5180 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5181  unsigned rankA = a.size();
5182  unsigned rankB = b.size();
5183  assert(rankA < rankB);
5184 
5185  auto isOne = [](int64_t v) { return v == 1; };
5186 
5187  // Special case for n-D to 0-D shape cast. 'b' must be all ones to be shape
5188  // cast to a 0-D vector.
5189  if (rankA == 0 && llvm::all_of(b, isOne))
5190  return true;
5191 
5192  unsigned i = 0;
5193  unsigned j = 0;
5194  while (i < rankA && j < rankB) {
5195  int64_t dimA = a[i];
5196  int64_t dimB = 1;
5197  while (dimB < dimA && j < rankB)
5198  dimB *= b[j++];
5199  if (dimA != dimB)
5200  break;
5201  ++i;
5202 
5203  // Handle the case when trailing dimensions are of size 1.
5204  // Include them into the contiguous sequence.
5205  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5206  i = rankA;
5207  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5208  j = rankB;
5209  }
5210 
5211  return i == rankA && j == rankB;
5212 }
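// Worked example (hypothetical shapes): for a = [8, 4] and b = [2, 4, 4] the
// loop groups b as (2 * 4) and (4), matching 8 and 4, so the cast is valid;
// for a = [8, 4] and b = [4, 4, 2] the first group already overshoots
// (4 * 4 = 16 != 8), so the cast is rejected.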
5213 
5214 static LogicalResult verifyVectorShapeCast(Operation *op,
5215  VectorType sourceVectorType,
5216  VectorType resultVectorType) {
5217  // Check that element type is the same.
5218  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5219  return op->emitOpError("source/result vectors must have same element type");
5220  auto sourceShape = sourceVectorType.getShape();
5221  auto resultShape = resultVectorType.getShape();
5222 
5223  // Check that product of source dim sizes matches product of result dim sizes.
5224  int64_t sourceDimProduct = std::accumulate(
5225  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5226  int64_t resultDimProduct = std::accumulate(
5227  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5228  if (sourceDimProduct != resultDimProduct)
5229  return op->emitOpError("source/result number of elements must match");
5230 
5231  // Check the expanding/contracting rank cases.
5232  unsigned sourceRank = sourceVectorType.getRank();
5233  unsigned resultRank = resultVectorType.getRank();
5234  if (sourceRank < resultRank) {
5235  if (!isValidShapeCast(sourceShape, resultShape))
5236  return op->emitOpError("invalid shape cast");
5237  } else if (sourceRank > resultRank) {
5238  if (!isValidShapeCast(resultShape, sourceShape))
5239  return op->emitOpError("invalid shape cast");
5240  }
5241 
5242  // Check that (non-)scalability is preserved
5243  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5244  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5245  if (sourceNScalableDims != resultNScalableDims)
5246  return op->emitOpError("different number of scalable dims at source (")
5247  << sourceNScalableDims << ") and result (" << resultNScalableDims
5248  << ")";
5249  sourceVectorType.getNumDynamicDims();
5250 
5251  return success();
5252 }
5253 
5254 LogicalResult ShapeCastOp::verify() {
5255  auto sourceVectorType =
5256  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5257  auto resultVectorType =
5258  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5259 
5260  // Check if source/result are of vector type.
5261  if (sourceVectorType && resultVectorType)
5262  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5263 
5264  return success();
5265 }
5266 
5267 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5268  // No-op shape cast.
5269  if (getSource().getType() == getResult().getType())
5270  return getSource();
5271 
5272  // Canceling shape casts.
5273  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5274  if (getResult().getType() == otherOp.getSource().getType())
5275  return otherOp.getSource();
5276 
5277  // Only allows valid transitive folding.
5278  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5279  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5280  if (srcType.getRank() < resultType.getRank()) {
5281  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5282  return {};
5283  } else if (srcType.getRank() > resultType.getRank()) {
5284  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5285  return {};
5286  } else {
5287  return {};
5288  }
5289 
5290  setOperand(otherOp.getSource());
5291  return getResult();
5292  }
5293 
5294  // Cancelling broadcast and shape cast ops.
5295  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5296  if (bcastOp.getSourceType() == getType())
5297  return bcastOp.getSource();
5298  }
5299 
5300  return {};
5301 }
5302 
5303 namespace {
5304 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5305 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5306 public:
5307  using OpRewritePattern::OpRewritePattern;
5308 
5309  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5310  PatternRewriter &rewriter) const override {
5311  auto constantOp =
5312  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5313  if (!constantOp)
5314  return failure();
5315  // Only handle splat for now.
5316  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5317  if (!dense)
5318  return failure();
5319  auto newAttr =
5320  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5321  dense.getSplatValue<Attribute>());
5322  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5323  return success();
5324  }
5325 };
5326 
5327 /// Helper function that computes a new vector type based on the input vector
5328 /// type by removing the trailing one dims:
5329 ///
5330 /// vector<4x1x1xi1> --> vector<4xi1>
5331 ///
5332 static VectorType trimTrailingOneDims(VectorType oldType) {
5333  ArrayRef<int64_t> oldShape = oldType.getShape();
5334  ArrayRef<int64_t> newShape = oldShape;
5335 
5336  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5337  ArrayRef<bool> newScalableDims = oldScalableDims;
5338 
5339  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5340  newShape = newShape.drop_back(1);
5341  newScalableDims = newScalableDims.drop_back(1);
5342  }
5343 
5344  // Make sure we have at least 1 dimension.
5345  // TODO: Add support for 0-D vectors.
5346  if (newShape.empty()) {
5347  newShape = oldShape.take_back();
5348  newScalableDims = oldScalableDims.take_back();
5349  }
5350 
5351  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5352 }
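// Illustrative example (hypothetical types): vector<4x1x1xi1> is trimmed to
// vector<4xi1>, while vector<4x[1]xi1> is left untouched because the trailing
// unit dimension is scalable, and vector<1x1xf32> becomes vector<1xf32> since
// at least one dimension is always kept.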
5353 
5354 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5355 ///
5356 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5357 /// dimension. If the input vector comes from `vector.create_mask` for which
5358 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5359 /// to fold shape_cast into create_mask.
5360 ///
5361 /// BEFORE:
5362 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5363 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5364 /// AFTER:
5365 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5366 class ShapeCastCreateMaskFolderTrailingOneDim final
5367  : public OpRewritePattern<ShapeCastOp> {
5368 public:
5369  using OpRewritePattern::OpRewritePattern;
5370 
5371  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5372  PatternRewriter &rewriter) const override {
5373  Value shapeOpSrc = shapeOp->getOperand(0);
5374  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5375  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5376  if (!createMaskOp && !constantMaskOp)
5377  return failure();
5378 
5379  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5380  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5381 
5382  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5383  if (newVecType != shapeOpResTy)
5384  return failure();
5385 
5386  auto numDimsToDrop =
5387  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5388 
5389  // No unit dims to drop
5390  if (!numDimsToDrop)
5391  return failure();
5392 
5393  if (createMaskOp) {
5394  auto maskOperands = createMaskOp.getOperands();
5395  auto numMaskOperands = maskOperands.size();
5396 
5397  // Check every mask dim size to see whether it can be dropped
5398  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5399  --i) {
5400  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5401  if (!constant || (constant.value() != 1))
5402  return failure();
5403  }
5404  SmallVector<Value> newMaskOperands =
5405  maskOperands.drop_back(numDimsToDrop);
5406 
5407  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5408  newMaskOperands);
5409  return success();
5410  }
5411 
5412  if (constantMaskOp) {
5413  auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5414  auto numMaskOperands = maskDimSizes.size();
5415 
5416  // Check every mask dim size to see whether it can be dropped
5417  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5418  --i) {
5419  if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5420  return failure();
5421  }
5422 
5423  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5424  ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
5425 
5426  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5427  newMaskOperandsAttr);
5428  return success();
5429  }
5430 
5431  return failure();
5432  }
5433 };
5434 
5435 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5436 /// This only applies when the shape of the broadcast source
5437 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5438 /// reshape is expressive enough to capture the result in a single op), or
5439 /// 2. has the same element count as the shape cast result.
5440 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5441 public:
5442  using OpRewritePattern::OpRewritePattern;
5443 
5444  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5445  PatternRewriter &rewriter) const override {
5446  auto broadcastOp =
5447  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5448  if (!broadcastOp)
5449  return failure();
5450 
5451  ArrayRef<int64_t> broadcastSourceShape;
5452  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5453  broadcastSourceShape = srcType.getShape();
5454  ArrayRef<int64_t> shapeCastTargetShape =
5455  shapeCastOp.getResultVectorType().getShape();
5456 
5457  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5458  // with a broadcast to the final shape.
5459  if (broadcastSourceShape ==
5460  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5461  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5462  shapeCastOp, shapeCastOp.getResultVectorType(),
5463  broadcastOp.getSource());
5464  return success();
5465  }
5466 
5467  // Otherwise, if the final result has the same element count, we can replace
5468  // with a shape cast.
5469  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5470  if (srcType.getNumElements() ==
5471  shapeCastOp.getResultVectorType().getNumElements()) {
5472  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5473  shapeCastOp, shapeCastOp.getResultVectorType(),
5474  broadcastOp.getSource());
5475  return success();
5476  }
5477  }
5478 
5479  return failure();
5480  }
5481 };
5482 
5483 } // namespace
5484 
5485 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5486  MLIRContext *context) {
5487  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5488  ShapeCastBroadcastFolder>(context);
5489 }
5490 
5491 //===----------------------------------------------------------------------===//
5492 // VectorBitCastOp
5493 //===----------------------------------------------------------------------===//
5494 
5495 LogicalResult BitCastOp::verify() {
5496  auto sourceVectorType = getSourceVectorType();
5497  auto resultVectorType = getResultVectorType();
5498 
5499  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5500  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5501  return emitOpError("dimension size mismatch at: ") << i;
5502  }
5503 
5504  DataLayout dataLayout = DataLayout::closest(*this);
5505  auto sourceElementBits =
5506  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5507  auto resultElementBits =
5508  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5509 
5510  if (sourceVectorType.getRank() == 0) {
5511  if (sourceElementBits != resultElementBits)
5512  return emitOpError("source/result bitwidth of the 0-D vector element "
5513  "types must be equal");
5514  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5515  resultElementBits * resultVectorType.getShape().back()) {
5516  return emitOpError(
5517  "source/result bitwidth of the minor 1-D vectors must be equal");
5518  }
5519 
5520  return success();
5521 }
5522 
5523 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5524  // Nop cast.
5525  if (getSource().getType() == getResult().getType())
5526  return getSource();
5527 
5528  // Canceling bitcasts.
5529  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5530  if (getResult().getType() == otherOp.getSource().getType())
5531  return otherOp.getSource();
5532 
5533  setOperand(otherOp.getSource());
5534  return getResult();
5535  }
5536 
5537  Attribute sourceConstant = adaptor.getSource();
5538  if (!sourceConstant)
5539  return {};
5540 
5541  Type srcElemType = getSourceVectorType().getElementType();
5542  Type dstElemType = getResultVectorType().getElementType();
5543 
5544  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5545  if (floatPack.isSplat()) {
5546  auto splat = floatPack.getSplatValue<FloatAttr>();
5547 
5548  // Casting fp16 into fp32.
5549  if (srcElemType.isF16() && dstElemType.isF32()) {
5550  uint32_t bits = static_cast<uint32_t>(
5551  splat.getValue().bitcastToAPInt().getZExtValue());
5552  // Duplicate the 16-bit pattern.
5553  bits = (bits << 16) | (bits & 0xffff);
5554  APInt intBits(32, bits);
5555  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5556  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5557  }
5558  }
5559  }
5560 
5561  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5562  if (intPack.isSplat()) {
5563  auto splat = intPack.getSplatValue<IntegerAttr>();
5564 
5565  if (llvm::isa<IntegerType>(dstElemType)) {
5566  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5567  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5568 
5569  // Casting to a larger integer bit width.
5570  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5571  APInt intBits = splat.getValue().zext(dstBitWidth);
5572 
5573  // Duplicate the lower width element.
5574  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5575  intBits = (intBits << srcBitWidth) | intBits;
5576  return DenseElementsAttr::get(getResultVectorType(), intBits);
5577  }
5578  }
5579  }
5580  }
5581 
5582  return {};
5583 }
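// Worked example (hypothetical constant): a splat vector of f16 1.0 has the
// 16-bit pattern 0x3C00; bitcasting that splat to an f32 vector concatenates
// two lanes into the single 32-bit pattern 0x3C003C00, which is the value the
// fold above materializes as the new splat constant.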
5584 
5585 //===----------------------------------------------------------------------===//
5586 // TypeCastOp
5587 //===----------------------------------------------------------------------===//
5588 
5589 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5590  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5591  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
5592  memRefType.getShape().end());
5593  if (vectorType)
5594  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5595  return res;
5596 }
5597 
5598 /// Build the canonical memRefType with a single vector.
5599 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5600 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5601  Value source) {
5602  result.addOperands(source);
5603  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5604  VectorType vectorType =
5605  VectorType::get(extractShape(memRefType),
5606  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5607  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5608  memRefType.getMemorySpace()));
5609 }
5610 
5611 LogicalResult TypeCastOp::verify() {
5612  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5613  if (!canonicalType.getLayout().isIdentity())
5614  return emitOpError("expects operand to be a memref with identity layout");
5615  if (!getResultMemRefType().getLayout().isIdentity())
5616  return emitOpError("expects result to be a memref with identity layout");
5617  if (getResultMemRefType().getMemorySpace() !=
5618  getMemRefType().getMemorySpace())
5619  return emitOpError("expects result in same memory space");
5620 
5621  auto sourceType = getMemRefType();
5622  auto resultType = getResultMemRefType();
5623  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5624  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5625  return emitOpError(
5626  "expects result and operand with same underlying scalar type: ")
5627  << resultType;
5628  if (extractShape(sourceType) != extractShape(resultType))
5629  return emitOpError(
5630  "expects concatenated result and operand shapes to be equal: ")
5631  << resultType;
5632  return success();
5633 }
5634 
5635 //===----------------------------------------------------------------------===//
5636 // TransposeOp
5637 //===----------------------------------------------------------------------===//
5638 
5639 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5640  Value vector, ArrayRef<int64_t> permutation) {
5641  VectorType vt = llvm::cast<VectorType>(vector.getType());
5642  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5643  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5644  for (unsigned i = 0; i < permutation.size(); ++i) {
5645  transposedShape[i] = vt.getShape()[permutation[i]];
5646  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5647  }
5648 
5649  result.addOperands(vector);
5650  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5651  transposedScalableDims));
5652  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5653  builder.getDenseI64ArrayAttr(permutation));
5654 }
5655 
5656 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5657  // Eliminate splat constant transpose ops.
5658  if (auto attr =
5659  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5660  if (attr.isSplat())
5661  return attr.reshape(getResultVectorType());
5662 
5663  // Eliminate identity transpose ops. This happens when the dimensions of the
5664  // input vector remain in their original order after the transpose operation.
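 // Illustrative sketch (hypothetical IR):
 //   vector.transpose %v, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
 // folds to %v.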
5665  ArrayRef<int64_t> perm = getPermutation();
5666 
5667  // Check if the permutation of the dimensions contains sequential values:
5668  // {0, 1, 2, ...}.
5669  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5670  if (perm[i] != i)
5671  return {};
5672  }
5673 
5674  return getVector();
5675 }
5676 
5677 LogicalResult vector::TransposeOp::verify() {
5678  VectorType vectorType = getSourceVectorType();
5679  VectorType resultType = getResultVectorType();
5680  int64_t rank = resultType.getRank();
5681  if (vectorType.getRank() != rank)
5682  return emitOpError("vector result rank mismatch: ") << rank;
5683  // Verify transposition array.
5684  ArrayRef<int64_t> perm = getPermutation();
5685  int64_t size = perm.size();
5686  if (rank != size)
5687  return emitOpError("transposition length mismatch: ") << size;
5688  SmallVector<bool, 8> seen(rank, false);
5689  for (const auto &ta : llvm::enumerate(perm)) {
5690  if (ta.value() < 0 || ta.value() >= rank)
5691  return emitOpError("transposition index out of range: ") << ta.value();
5692  if (seen[ta.value()])
5693  return emitOpError("duplicate position index: ") << ta.value();
5694  seen[ta.value()] = true;
5695  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5696  return emitOpError("dimension size mismatch at: ") << ta.value();
5697  }
5698  return success();
5699 }
5700 
5701 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5702  return llvm::to_vector<4>(getResultVectorType().getShape());
5703 }
5704 
5705 namespace {
5706 
5707 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
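// Illustrative sketch (hypothetical IR):
//   %0 = vector.transpose %v, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
//   %1 = vector.transpose %0, [2, 1, 0] : vector<3x2x4xf32> to vector<4x2x3xf32>
// is rewritten into the single composed transpose:
//   %1 = vector.transpose %v, [2, 0, 1] : vector<2x3x4xf32> to vector<4x2x3xf32>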
5708 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5709 public:
5710  using OpRewritePattern::OpRewritePattern;
5711 
5712  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5713  PatternRewriter &rewriter) const override {
5714  // Composes two permutations: result[i] = permutation1[permutation2[i]].
5715  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5716  ArrayRef<int64_t> permutation2) {
5717  SmallVector<int64_t, 4> result;
5718  for (auto index : permutation2)
5719  result.push_back(permutation1[index]);
5720  return result;
5721  };
5722 
5723  // Return if the input of 'transposeOp' is not defined by another transpose.
5724  vector::TransposeOp parentTransposeOp =
5725  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5726  if (!parentTransposeOp)
5727  return failure();
5728 
5729  SmallVector<int64_t, 4> permutation = composePermutations(
5730  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5731  // Replace 'transposeOp' with a new transpose operation.
5732  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5733  transposeOp, transposeOp.getResult().getType(),
5734  parentTransposeOp.getVector(), permutation);
5735  return success();
5736  }
5737 };
5738 
5739 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
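// Illustrative sketch (hypothetical IR):
//   %b = vector.broadcast %x : f32 to vector<4x1xf32>
//   %t = vector.transpose %b, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
// becomes:
//   %t = vector.broadcast %x : f32 to vector<1x4xf32>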
5740 struct FoldTransposedScalarBroadcast final
5741  : public OpRewritePattern<vector::TransposeOp> {
5742  using OpRewritePattern::OpRewritePattern;
5743 
5744  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5745  PatternRewriter &rewriter) const override {
5746  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5747  if (!bcastOp)
5748  return failure();
5749 
5750  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5751  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5752  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5753  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5754  return success();
5755  }
5756 
5757  return failure();
5758  }
5759 };
5760 
5761 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
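// Illustrative sketch (hypothetical IR):
//   %s = vector.splat %x : vector<4x2xf32>
//   %t = vector.transpose %s, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// becomes:
//   %t = vector.splat %x : vector<2x4xf32>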
5762 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5763 public:
5764  using OpRewritePattern::OpRewritePattern;
5765 
5766  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5767  PatternRewriter &rewriter) const override {
5768  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5769  if (!splatOp)
5770  return failure();
5771 
5772  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5773  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5774  return success();
5775  }
5776 };
5777 
5778 /// Folds transpose(create_mask) into a new transposed create_mask.
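/// Illustrative sketch (hypothetical IR):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes:
///   %t = vector.create_mask %b, %a : vector<8x4xi1>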
5779 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5780 public:
5781  using OpRewritePattern::OpRewritePattern;
5782 
5783  LogicalResult matchAndRewrite(TransposeOp transpOp,
5784  PatternRewriter &rewriter) const override {
5785  Value transposeSrc = transpOp.getVector();
5786  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5787  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5788  if (!createMaskOp && !constantMaskOp)
5789  return failure();
5790 
5791  // Get the transpose permutation and apply it to the vector.create_mask or
5792  // vector.constant_mask operands.
5793  ArrayRef<int64_t> permutation = transpOp.getPermutation();
5794 
5795  if (createMaskOp) {
5796  auto maskOperands = createMaskOp.getOperands();
5797  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
5798  applyPermutationToVector(newOperands, permutation);
5799 
5800  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5801  transpOp, transpOp.getResultVectorType(), newOperands);
5802  return success();
5803  }
5804 
5805  // ConstantMaskOp case.
5806  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5807  SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
5808  applyPermutationToVector(newMaskDimSizes, permutation);
5809 
5810  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5811  transpOp, transpOp.getResultVectorType(),
5812  ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
5813  return success();
5814  }
5815 };
5816 
5817 } // namespace
5818 
5819 void vector::TransposeOp::getCanonicalizationPatterns(
5820  RewritePatternSet &results, MLIRContext *context) {
5821  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5822  TransposeFolder, FoldTransposeSplat>(context);
5823 }
5824 
5825 //===----------------------------------------------------------------------===//
5826 // ConstantMaskOp
5827 //===----------------------------------------------------------------------===//
5828 
5829 LogicalResult ConstantMaskOp::verify() {
5830  auto resultType = llvm::cast<VectorType>(getResult().getType());
5831  // Check the corner case of 0-D vectors first.
5832  if (resultType.getRank() == 0) {
5833  if (getMaskDimSizes().size() != 1)
5834  return emitError("array attr must have length 1 for 0-D vectors");
5835  auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5836  if (dim != 0 && dim != 1)
5837  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5838  return success();
5839  }
5840 
5841  // Verify that array attr size matches the rank of the vector result.
5842  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5843  return emitOpError(
5844  "must specify array attr of size equal vector result rank");
5845  // Verify that each array attr element is in bounds of corresponding vector
5846  // result dimension size.
5847  auto resultShape = resultType.getShape();
5848  auto resultScalableDims = resultType.getScalableDims();
5849  SmallVector<int64_t, 4> maskDimSizes;
5850  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
5851  int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5852  if (maskDimSize < 0 || maskDimSize > resultShape[index])
5853  return emitOpError(
5854  "array attr of size out of bounds of vector result dimension size");
5855  if (resultScalableDims[index] && maskDimSize != 0 &&
5856  maskDimSize != resultShape[index])
5857  return emitOpError(
5858  "only supports 'none set' or 'all set' scalable dimensions");
5859  maskDimSizes.push_back(maskDimSize);
5860  }
5861  // Verify that if one mask dim size is zero, they all should be zero (because
5862  // the mask region is a conjunction of each mask dimension interval).
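 // E.g. (illustrative, hypothetical IR) vector.constant_mask [2, 0]
 // : vector<4x3xi1> would describe an empty mask region and must be written
 // as [0, 0] instead.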
5863  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5864  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5865  if (anyZeros && !allZeros)
5866  return emitOpError("expected all mask dim sizes to be zeros, "
5867  "as a result of conjunction with zero mask dim");
5868  return success();
5869 }
5870 
5871 bool ConstantMaskOp::isAllOnesMask() {
5872  auto resultType = getVectorType();
5873  // Check the corner case of 0-D vectors first.
5874  if (resultType.getRank() == 0) {
5875  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5876  return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5877  }
5878  for (const auto [resultSize, intAttr] :
5879  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5880  int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5881  if (maskDimSize < resultSize)
5882  return false;
5883  }
5884  return true;
5885 }
5886 
5887 //===----------------------------------------------------------------------===//
5888 // CreateMaskOp
5889 //===----------------------------------------------------------------------===//
5890 
5891 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
5892  VectorType type,
5893  ArrayRef<OpFoldResult> mixedOperands) {
5894  SmallVector<Value> operands =
5895  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
5896  build(builder, result, type, operands);
5897 }
5898 
5899 LogicalResult CreateMaskOp::verify() {
5900  auto vectorType = llvm::cast<VectorType>(getResult().getType());
5901  // Verify that an operand was specified for each result vector dimension.
5902  if (vectorType.getRank() == 0) {
5903  if (getNumOperands() != 1)
5904  return emitOpError(
5905  "must specify exactly one operand for 0-D create_mask");
5906  } else if (getNumOperands() !=
5907  llvm::cast<VectorType>(getResult().getType()).getRank()) {
5908  return emitOpError(
5909  "must specify an operand for each result vector dimension");
5910  }
5911  return success();
5912 }
5913 
5914 namespace {
5915 
5916 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5917 ///
5918 /// Ex 1:
5919 /// %c2 = arith.constant 2 : index
5920 /// %c3 = arith.constant 3 : index
5921 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5922 /// Becomes:
5923 /// vector.constant_mask [3, 2] : vector<4x3xi1>
5924 ///
5925 /// Ex 2:
5926 /// %c_neg_1 = arith.constant -1 : index
5927 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5928 /// becomes:
5929 /// vector.constant_mask [0] : vector<[8]xi1>
5930 ///
5931 /// Ex 3:
5932 /// %c8 = arith.constant 8 : index
5933 /// %c16 = arith.constant 16 : index
5934 /// %0 = vector.vscale
5935 /// %1 = arith.muli %0, %c16 : index
5936 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5937 /// becomes:
5938 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
5939 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5940 public:
5941  using OpRewritePattern::OpRewritePattern;
5942 
5943  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5944  PatternRewriter &rewriter) const override {
5945  VectorType retTy = createMaskOp.getResult().getType();
5946  bool isScalable = retTy.isScalable();
5947 
5948  // Check every mask operand
5949  for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
5950  if (auto cst = getConstantIntValue(operand)) {
5951  // Most basic case - this operand is a constant value. Note that for
5952  // scalable dimensions, CreateMaskOp can be folded only if the
5953  // corresponding operand is negative or zero.
5954  if (retTy.getScalableDims()[opIdx] && *cst > 0)
5955  return failure();
5956 
5957  continue;
5958  }
5959 
5960  // Non-constant operands are not allowed for non-scalable vectors.
5961  if (!isScalable)
5962  return failure();
5963 
5964  // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5965  // true" mask, so can also be treated as constant.
5966  auto mul = operand.getDefiningOp<arith::MulIOp>();
5967  if (!mul)
5968  return failure();
5969  auto mulLHS = mul.getRhs();
5970  auto mulRHS = mul.getLhs();
5971  bool isOneOpVscale =
5972  (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
5973  isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5974 
5975  auto isConstantValMatchingDim =
5976  [=, dim = retTy.getShape()[opIdx]](Value operand) {
5977  auto constantVal = getConstantIntValue(operand);
5978  return (constantVal.has_value() && constantVal.value() == dim);
5979  };
5980 
5981  bool isOneOpConstantMatchingDim =
5982  isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5983 
5984  if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5985  return failure();
5986  }
5987 
5988  // Gather constant mask dimension sizes.
5989  SmallVector<int64_t, 4> maskDimSizes;
5990  maskDimSizes.reserve(createMaskOp->getNumOperands());
5991  for (auto [operand, maxDimSize] : llvm::zip_equal(
5992  createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5993  std::optional dimSize = getConstantIntValue(operand);
5994  if (!dimSize) {
5995  // Although not a constant, it is safe to assume that `operand` is
5996  // "vscale * maxDimSize".
5997  maskDimSizes.push_back(maxDimSize);
5998  continue;
5999  }
6000  int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
6001  // If one of dim sizes is zero, set all dims to zero.
6002  if (dimSize <= 0) {
6003  maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
6004  break;
6005  }
6006  maskDimSizes.push_back(dimSizeVal);
6007  }
6008 
6009  // Replace 'createMaskOp' with ConstantMaskOp.
6010  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
6011  createMaskOp, retTy,
6012  vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
6013  return success();
6014  }
6015 };
6016 
6017 } // namespace
6018 
6019 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6020  MLIRContext *context) {
6021  results.add<CreateMaskFolder>(context);
6022 }
6023 
6024 //===----------------------------------------------------------------------===//
6025 // MaskOp
6026 //===----------------------------------------------------------------------===//
6027 
6028 void MaskOp::build(
6029  OpBuilder &builder, OperationState &result, Value mask,
6030  Operation *maskableOp,
6031  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6032  assert(maskRegionBuilder &&
6033  "builder callback for 'maskRegion' must be present");
6034 
6035  result.addOperands(mask);
6036  OpBuilder::InsertionGuard guard(builder);
6037  Region *maskRegion = result.addRegion();
6038  builder.createBlock(maskRegion);
6039  maskRegionBuilder(builder, maskableOp);
6040 }
6041 
6042 void MaskOp::build(
6043  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6044  Value mask, Operation *maskableOp,
6045  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6046  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6047  maskRegionBuilder);
6048 }
6049 
6050 void MaskOp::build(
6051  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6052  Value mask, Value passthru, Operation *maskableOp,
6053  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6054  build(builder, result, mask, maskableOp, maskRegionBuilder);
6055  if (passthru)
6056  result.addOperands(passthru);
6057  result.addTypes(resultTypes);
6058 }
6059 
6060 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6061  // Create the op region.
6062  result.regions.reserve(1);
6063  Region &maskRegion = *result.addRegion();
6064 
6065  auto &builder = parser.getBuilder();
6066 
6067  // Parse all the operands.
6068  OpAsmParser::UnresolvedOperand mask;
6069  if (parser.parseOperand(mask))
6070  return failure();
6071 
6072  // Optional passthru operand.
6073  OpAsmParser::UnresolvedOperand passthru;
6074  ParseResult parsePassthru = parser.parseOptionalComma();
6075  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6076  return failure();
6077 
6078  // Parse op region.
6079  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6080  return failure();
6081 
6082  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6083 
6084  // Parse the optional attribute list.
6085  if (parser.parseOptionalAttrDict(result.attributes))
6086  return failure();
6087 
6088  // Parse all the types.
6089  Type maskType;
6090  if (parser.parseColonType(maskType))
6091  return failure();
6092 
6093  SmallVector<Type> resultTypes;
6094  if (parser.parseOptionalArrowTypeList(resultTypes))
6095  return failure();
6096  result.types.append(resultTypes);
6097 
6098  // Resolve operands.
6099  if (parser.resolveOperand(mask, maskType, result.operands))
6100  return failure();
6101 
6102  if (parsePassthru.succeeded())
6103  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6104  return failure();
6105 
6106  return success();
6107 }
6108 
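// Prints the op in a form such as (illustrative sketch, assuming a masked
// transfer_read):
//   vector.mask %m { vector.transfer_read ... } : vector<8xi1> -> vector<8xf32>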
6109 void MaskOp::print(OpAsmPrinter &p) {
6110  p << " " << getMask();
6111  if (getPassthru())
6112  p << ", " << getPassthru();
6113 
6114  // Print single masked operation and skip terminator.
6115  p << " { ";
6116  Block *singleBlock = &getMaskRegion().getBlocks().front();
6117  if (singleBlock && !singleBlock->getOperations().empty())
6118  p.printCustomOrGenericOp(&singleBlock->front());
6119  p << " }";
6120 
6121  p.printOptionalAttrDict(getOperation()->getAttrs());
6122 
6123  p << " : " << getMask().getType();
6124  if (getNumResults() > 0)
6125  p << " -> " << getResultTypes();
6126 }
6127 
6128 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6129  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6130  MaskOp>::ensureTerminator(region, builder, loc);
6131  // Keep the default yield terminator if the number of masked operations is
6132  // not the expected one. This case will trigger a verification failure.
6133  Block &block = region.front();
6134  if (block.getOperations().size() != 2)
6135  return;
6136 
6137  // Replace default yield terminator with a new one that returns the results
6138  // from the masked operation.
6139  OpBuilder opBuilder(builder.getContext());
6140  Operation *maskedOp = &block.front();
6141  Operation *oldYieldOp = &block.back();
6142  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6143 
6144  // Empty vector.mask op.
6145  if (maskedOp == oldYieldOp)
6146  return;
6147 
6148  opBuilder.setInsertionPoint(oldYieldOp);
6149  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6150  oldYieldOp->dropAllReferences();
6151  oldYieldOp->erase();
6152 }
6153 
6154 LogicalResult MaskOp::verify() {
6155  // Structural checks.
6156  Block &block = getMaskRegion().getBlocks().front();
6157  if (block.getOperations().empty())
6158  return emitOpError("expects a terminator within the mask region");
6159  if (block.getOperations().size() > 2)
6160  return emitOpError("expects only one operation to mask");
6161 
6162  // Terminator checks.
6163  auto terminator = dyn_cast<vector::YieldOp>(block.back());
6164  if (!terminator)
6165  return emitOpError("expects a terminator within the mask region");
6166 
6167  if (terminator->getNumOperands() != getNumResults())
6168  return emitOpError(
6169  "expects number of results to match mask region yielded values");
6170 
6171  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6172  // Empty vector.mask. Nothing else to check.
6173  if (!maskableOp)
6174  return success();
6175 
6176  // Result checks.
6177  if (maskableOp->getNumResults() != getNumResults())
6178  return emitOpError("expects number of results to match maskable operation "
6179  "number of results");
6180 
6181  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6182  return emitOpError(
6183  "expects result type to match maskable operation result type");
6184 
6185  if (llvm::count_if(maskableOp->getResultTypes(),
6186  [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6187  return emitOpError("multiple vector results not supported");
6188 
6189  // Mask checks.
6190  Type expectedMaskType = maskableOp.getExpectedMaskType();
6191  if (getMask().getType() != expectedMaskType)
6192  return emitOpError("expects a ")
6193  << expectedMaskType << " mask for the maskable operation";
6194 
6195  // Passthru checks.
6196  Value passthru = getPassthru();
6197  if (passthru) {
6198  if (!maskableOp.supportsPassthru())
6199  return emitOpError(
6200  "doesn't expect a passthru argument for this maskable operation");
6201 
6202  if (maskableOp->getNumResults() != 1)
6203  return emitOpError("expects result when passthru argument is provided");
6204 
6205  if (passthru.getType() != maskableOp->getResultTypes()[0])
6206  return emitOpError("expects passthru type to match result type");
6207  }
6208 
6209  return success();
6210 }
6211 
6212 /// Folds vector.mask ops with an all-true mask.
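/// E.g. (illustrative sketch, hypothetical IR) with an all-true %m,
///   %r = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
///        : vector<8xi1> -> f32
/// folds to the unmasked vector.reduction.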
6213 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6214  SmallVectorImpl<OpFoldResult> &results) {
6215  MaskFormat maskFormat = getMaskFormat(getMask());
6216  if (isEmpty())
6217  return failure();
6218 
6219  if (maskFormat != MaskFormat::AllTrue)
6220  return failure();
6221 
6222  // Move maskable operation outside of the `vector.mask` region.
6223  Operation *maskableOp = getMaskableOp();
6224  maskableOp->dropAllUses();
6225  maskableOp->moveBefore(getOperation());
6226 
6227  llvm::append_range(results, maskableOp->getResults());
6228  return success();
6229 }
6230 
6231 // Elides empty vector.mask operations with or without return values.
6232 // Propagates the values yielded by the vector.yield terminator, if any, or
6233 // erases the op otherwise.
6234 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6235  using OpRewritePattern::OpRewritePattern;
6236 
6237  LogicalResult matchAndRewrite(MaskOp maskOp,
6238  PatternRewriter &rewriter) const override {
6239  auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6240  if (maskingOp.getMaskableOp())
6241  return failure();
6242 
6243  if (!maskOp.isEmpty())
6244  return failure();
6245 
6246  Block *block = maskOp.getMaskBlock();
6247  auto terminator = cast<vector::YieldOp>(block->front());
6248  if (terminator.getNumOperands() == 0)
6249  rewriter.eraseOp(maskOp);
6250  else
6251  rewriter.replaceOp(maskOp, terminator.getOperands());
6252 
6253  return success();
6254  }
6255 };
6256 
6257 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6258  MLIRContext *context) {
6259  results.add<ElideEmptyMaskOp>(context);
6260 }
6261 
6262 // MaskingOpInterface definitions.
6263 
6264 /// Returns the operation masked by this 'vector.mask'.
6265 Operation *MaskOp::getMaskableOp() {
6266  Block *block = getMaskBlock();
6267  if (block->getOperations().size() < 2)
6268  return nullptr;
6269 
6270  return &block->front();
6271 }
6272 
6273 /// Returns true if 'vector.mask' has a passthru value.
6274 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6275 
6276 //===----------------------------------------------------------------------===//
6277 // ScanOp
6278 //===----------------------------------------------------------------------===//
6279 
6280 LogicalResult ScanOp::verify() {
6281  VectorType srcType = getSourceType();
6282  VectorType initialType = getInitialValueType();
6283  // Check reduction dimension < rank.
6284  int64_t srcRank = srcType.getRank();
6285  int64_t reductionDim = getReductionDim();
6286  if (reductionDim >= srcRank)
6287  return emitOpError("reduction dimension ")
6288  << reductionDim << " has to be less than " << srcRank;
6289 
6290  // Check that rank(initial_value) = rank(src) - 1.
6291  int64_t initialValueRank = initialType.getRank();
6292  if (initialValueRank != srcRank - 1)
6293  return emitOpError("initial value rank ")
6294  << initialValueRank << " has to be equal to " << srcRank - 1;
6295 
6296  // Check shapes of initial value and src.
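 // E.g. (illustrative) a vector<4x8x16xf32> source reduced along dimension 1
 // requires an initial value of shape vector<4x16xf32>.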
6297  ArrayRef<int64_t> srcShape = srcType.getShape();
6298  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6299  SmallVector<int64_t> expectedShape;
6300  for (int i = 0; i < srcRank; i++) {
6301  if (i != reductionDim)
6302  expectedShape.push_back(srcShape[i]);
6303  }
6304  if (!llvm::equal(initialValueShapes, expectedShape)) {
6305  return emitOpError("incompatible input/initial value shapes");
6306  }
6307 
6308  // Verify supported reduction kind.
6309  Type eltType = getDestType().getElementType();
6310  if (!isSupportedCombiningKind(getKind(), eltType))
6311  return emitOpError("unsupported reduction type ")
6312  << eltType << " for kind '" << stringifyCombiningKind(getKind())
6313  << "'";
6314 
6315  return success();
6316 }
6317 
6318 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6319  RewritePatternSet &patterns, PatternBenefit benefit) {
6320  patterns
6321  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6322  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6323  StridedSliceConstantMaskFolder, TransposeFolder>(
6324  patterns.getContext(), benefit);
6325 }
6326 
6327 //===----------------------------------------------------------------------===//
6328 // SplatOp
6329 //===----------------------------------------------------------------------===//
6330 
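// E.g. (illustrative sketch) vector.splat %c : vector<4xf32>, with %c an
// arith.constant 1.0 : f32, folds to dense<1.000000e+00> : vector<4xf32>.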
6331 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6332  auto constOperand = adaptor.getInput();
6333  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6334  return {};
6335 
6336  // SplatElementsAttr::get treats single value for second arg as being a splat.
6337  return SplatElementsAttr::get(getType(), {constOperand});
6338 }
6339 
6340 //===----------------------------------------------------------------------===//
6341 // StepOp
6342 //===----------------------------------------------------------------------===//
6343 
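// E.g. (illustrative sketch) vector.step : vector<4xindex> folds to
// dense<[0, 1, 2, 3]> : vector<4xindex>; scalable results are not folded.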
6344 OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
6345  auto resultType = cast<VectorType>(getType());
6346  if (resultType.isScalable())
6347  return nullptr;
6348  SmallVector<APInt> indices;
6349  for (unsigned i = 0; i < resultType.getNumElements(); i++)
6350  indices.push_back(APInt(/*width=*/64, i));
6351  return DenseElementsAttr::get(resultType, indices);
6352 }
6353 
6354 //===----------------------------------------------------------------------===//
6355 // WarpExecuteOnLane0Op
6356 //===----------------------------------------------------------------------===//
6357 
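// Prints the op in a form such as (illustrative sketch, hypothetical values):
//   vector.warp_execute_on_lane_0(%laneid)[32] args(%v : vector<4xf32>)
//       -> (vector<1xf32>) { ... }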
6358 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
6359  p << "(" << getLaneid() << ")";
6360 
6361  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
6362  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
6363  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
6364 
6365  if (!getArgs().empty())
6366  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
6367  if (!getResults().empty())
6368  p << " -> (" << getResults().getTypes() << ')';
6369  p << " ";
6370  p.printRegion(getRegion(),
6371  /*printEntryBlockArgs=*/true,
6372  /*printBlockTerminators=*/!getResults().empty());
6373  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
6374 }
6375 
6376 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
6377  OperationState &result) {
6378  // Create the region.
6379  result.regions.reserve(1);
6380  Region *warpRegion = result.addRegion();
6381 
6382  auto &builder = parser.getBuilder();
6383  OpAsmParser::UnresolvedOperand laneId;
6384 
6385  // Parse predicate operand.
6386  if (parser.parseLParen() ||
6387  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
6388  parser.parseRParen())
6389  return failure();
6390 
6391  int64_t warpSize;
6392  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
6393  parser.parseRSquare())
6394  return failure();
6395  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
6396  builder.getContext())),
6397  builder.getI64IntegerAttr(warpSize));
6398 
6399  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
6400  return failure();
6401 
6402  llvm::SMLoc inputsOperandsLoc;
6403  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
6404  SmallVector<Type> inputTypes;
6405  if (succeeded(parser.parseOptionalKeyword("args"))) {
6406  if (parser.parseLParen())
6407  return failure();
6408 
6409  inputsOperandsLoc = parser.getCurrentLocation();
6410  if (parser.parseOperandList(inputsOperands) ||
6411  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
6412  return failure();
6413  }
6414  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
6415  result.operands))
6416  return failure();
6417 
6418  // Parse optional results type list.
6419  if (parser.parseOptionalArrowTypeList(result.types))
6420  return failure();
6421  // Parse the region.
6422  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
6423  /*argTypes=*/{}))
6424  return failure();
6425  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
6426 
6427  // Parse the optional attribute list.
6428  if (parser.parseOptionalAttrDict(result.attributes))
6429  return failure();
6430  return success();
6431 }
6432 
6433 void WarpExecuteOnLane0Op::getSuccessorRegions(
6434  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
6435  if (!point.isParent()) {
6436  regions.push_back(RegionSuccessor(getResults()));
6437  return;
6438  }
6439 
6440  // The warp region is always executed.
6441  regions.push_back(RegionSuccessor(&getWarpRegion()));
6442 }
6443 
6444 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6445  TypeRange resultTypes, Value laneId,
6446  int64_t warpSize) {
6447  build(builder, result, resultTypes, laneId, warpSize,
6448  /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
6449 }
6450 
6451 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6452  TypeRange resultTypes, Value laneId,
6453  int64_t warpSize, ValueRange args,
6454  TypeRange blockArgTypes) {
6455  result.addOperands(laneId);
6456  result.addAttribute(getAttributeNames()[0],
6457  builder.getI64IntegerAttr(warpSize));
6458  result.addTypes(resultTypes);
6459  result.addOperands(args);
6460  assert(args.size() == blockArgTypes.size());
6461  OpBuilder::InsertionGuard guard(builder);
6462  Region *warpRegion = result.addRegion();
6463  Block *block = builder.createBlock(warpRegion);
6464  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
6465  block->addArgument(type, arg.getLoc());
6466 }
6467 
6468 /// Helper to check that the distributed vector type is consistent with the
6469 /// expanded type and the distributed size.
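/// E.g. (illustrative) with warpSize = 32, an expanded vector<64x4xf32> may be
/// distributed as vector<2x4xf32>: the per-dimension scales are 32 and 1 and
/// their product equals the warp size.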
6470 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
6471  int64_t warpSize, Operation *op) {
6472  // If the types match there is no distribution.
6473  if (expanded == distributed)
6474  return success();
6475  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6476  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6477  if (!expandedVecType || !distributedVecType)
6478  return op->emitOpError("expected vector type for distributed operands.");
6479  if (expandedVecType.getRank() != distributedVecType.getRank() ||
6480  expandedVecType.getElementType() != distributedVecType.getElementType())
6481  return op->emitOpError(
6482  "expected distributed vectors to have same rank and element type.");
6483 
6484  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
6485  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
6486  int64_t eDim = expandedVecType.getDimSize(i);
6487  int64_t dDim = distributedVecType.getDimSize(i);
6488  if (eDim == dDim)
6489  continue;
6490  if (eDim % dDim != 0)
6491  return op->emitOpError()
6492  << "expected expanded vector dimension #" << i << " (" << eDim
6493  << ") to be a multiple of the distributed vector dimension ("
6494  << dDim << ")";
6495  scales[i] = eDim / dDim;
6496  }
6497  if (std::accumulate(scales.begin(), scales.end(), 1,
6498  std::multiplies<int64_t>()) != warpSize)
6499  return op->emitOpError()
6500  << "incompatible distribution dimensions from " << expandedVecType
6501  << " to " << distributedVecType << " with warp size = " << warpSize;
6502 
6503  return success();
6504 }
6505 
6506 LogicalResult WarpExecuteOnLane0Op::verify() {
6507  if (getArgs().size() != getWarpRegion().getNumArguments())
6508  return emitOpError(
6509  "expected same number of op arguments and block arguments.");
6510  auto yield =
6511  cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
6512  if (yield.getNumOperands() != getNumResults())
6513  return emitOpError(
6514  "expected same number of yield operands and return values.");
6515  int64_t warpSize = getWarpSize();
6516  for (auto [regionArg, arg] :
6517  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
6518  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
6519  warpSize, getOperation())))
6520  return failure();
6521  }
6522  for (auto [yieldOperand, result] :
6523  llvm::zip_equal(yield.getOperands(), getResults())) {
6524  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
6525  warpSize, getOperation())))
6526  return failure();
6527  }
6528  return success();
6529 }
6530 
6531 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
6532  return succeeded(
6533  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
6534 }
6535 
6536 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6537  CombiningKind kind, Value v1, Value acc,
6538  arith::FastMathFlagsAttr fastmath,
6539  Value mask) {
6540  Type t1 = getElementTypeOrSelf(v1.getType());
6541  Type tAcc = getElementTypeOrSelf(acc.getType());
6542  Value result;
6543 
6544  switch (kind) {
6545  case CombiningKind::ADD:
6546  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6547  result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6548  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6549  result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6550  else
6551  llvm_unreachable("invalid value types for ADD reduction");
6552  break;
6553  case CombiningKind::AND:
6554  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6555  result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6556  break;
6557  case CombiningKind::MAXNUMF:
6558  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6559  "expected float values");
6560  result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6561  break;
6562  case CombiningKind::MAXIMUMF:
6563  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6564  "expected float values");
6565  result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6566  break;
6567  case CombiningKind::MINNUMF:
6568  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6569  "expected float values");
6570  result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6571  break;
6572  case CombiningKind::MINIMUMF:
6573  assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6574  "expected float values");
6575  result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6576  break;
6577  case CombiningKind::MAXSI:
6578  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6579  result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6580  break;
6581  case CombiningKind::MINSI:
6582  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6583  result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6584  break;
6585  case CombiningKind::MAXUI:
6586  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6587  result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6588  break;
6589  case CombiningKind::MINUI:
6590  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6591  result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6592  break;
6593  case CombiningKind::MUL:
6594  if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6595  result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6596  else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6597  result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6598  else
6599  llvm_unreachable("invalid value types for MUL reduction");
6600  break;
6601  case CombiningKind::OR:
6602  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6603  result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6604  break;
6605  case CombiningKind::XOR:
6606  assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6607  result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6608  break;
6609  };
6610 
6611  assert(result && "unknown CombiningKind");
6612  return selectPassthru(b, mask, result, acc);
6613 }
6614 
6615 //===----------------------------------------------------------------------===//
6616 // Vector Masking Utilities
6617 //===----------------------------------------------------------------------===//
6618 
6619 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6620 /// as masked operation.
6621 static void createMaskOpRegion(OpBuilder &builder,
6622  Operation *maskableOp) {
6623  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6624  Block *insBlock = builder.getInsertionBlock();
6625  // Create a block and move the op to that block.
6626  insBlock->getOperations().splice(
6627  insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6628  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6629 }
6630 
6631 /// Creates a vector.mask operation around a maskable operation. Returns the
6632 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6633 /// the maskable operation itself.
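/// E.g. (illustrative sketch, hypothetical IR) masking an op with mask %m and
/// passthru %p yields:
///   %0 = vector.mask %m, %p { vector.transfer_read ... } : vector<8xi1> -> vector<8xf32>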
6634 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6635  Operation *maskableOp, Value mask,
6636  Value passthru) {
6637  if (!mask)
6638  return maskableOp;
6639  if (passthru)
6640  return builder.create<MaskOp>(maskableOp->getLoc(),
6641  maskableOp->getResultTypes(), mask, passthru,
6642  maskableOp, createMaskOpRegion);
6643  return builder.create<MaskOp>(maskableOp->getLoc(),
6644  maskableOp->getResultTypes(), mask, maskableOp,
6645  createMaskOpRegion);
6646 }
6647 
6648 /// Creates a vector select operation that picks values from `newValue` or
6649 /// `passthru` for each result vector lane based on `mask`. This utility is used
6650 /// to propagate the pass-thru value of vector.mask or for cases where only the
6651 /// pass-thru value propagation is needed. VP intrinsics do not support
6652 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
6653 /// are usually able to match op + select patterns and fold them into native
6654 /// target instructions.
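/// E.g. (illustrative sketch, hypothetical values) selectPassthru(b, %mask,
/// %new, %passthru) materializes:
///   %r = arith.select %mask, %new, %passthru : vector<8xi1>, vector<8xf32>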
6655 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6656  Value newValue, Value passthru) {
6657  if (!mask)
6658  return newValue;
6659 
6660  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6661  mask, newValue, passthru);
6662 }
6663 
6664 //===----------------------------------------------------------------------===//
6665 // TableGen'd op method definitions
6666 //===----------------------------------------------------------------------===//
6667 
6668 #define GET_ATTRDEF_CLASSES
6669 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6670 
6671 #define GET_OP_CLASSES
6672 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:124
bool isF32() const
Definition: Types.cpp:52
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:116
bool isF16() const
Definition: Types.cpp:50
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:311
Builder & setShape(ArrayRef< int64_t > newShape, ArrayRef< bool > newIsScalableDim={})
Definition: BuiltinTypes.h:323
Builder & setElementType(Type newElementType)
Definition: BuiltinTypes.h:330
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta between the given two values.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Fraction abs(const Fraction &f)
Definition: Fraction.h:106
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:358
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:432
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:126
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:154
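A hedged sketch of the helper above; memrefType and vecType are placeholder types assumed to be transfer-compatible:
  // For a rank-3 memref and a rank-2 vector this yields (d0, d1, d2) -> (d1, d2).
  AffineMap permMap = vector::getTransferMinorIdentityMap(memrefType, vecType);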
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Definition: VectorOps.cpp:3988
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
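A hedged sketch tying together the masking helpers above, reusing vecType and permMap from the previous sketch; builder, reductionOp, mask, newValue, and passthru are assumed placeholders:
  VectorType maskTy = vector::inferTransferOpMaskType(vecType, permMap);     // mask shape for the transfer
  Operation *masked = vector::maskOperation(builder, reductionOp, mask);     // wrap in vector.mask
  Value merged = vector::selectPassthru(builder, mask, newValue, passthru);  // per-lane select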
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
Definition: VectorOps.cpp:220
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:284
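A hedged sketch of querying the disjointness helpers above; readOp and writeOp are assumed vector.transfer_read / vector.transfer_write ops:
  if (vector::isDisjointTransferSet(
          cast<VectorTransferOpInterface>(writeOp.getOperation()),
          cast<VectorTransferOpInterface>(readOp.getOperation()),
          /*testDynamicValueUsingBounds=*/true)) {
    // Proven non-overlapping accesses: the read cannot observe the write.
  }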
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2376
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-writes the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:315
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
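A hedged usage sketch; the pass context and the greedy driver call are assumptions about the caller, not something this file provides:
  RewritePatternSet patterns(ctx);
  vector::populateVectorToVectorCanonicalizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));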
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:339
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:609
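A hedged sketch of the two reduction helpers (makeArithReduction earlier in this list and getVectorReductionOp here); b, loc, lhs, acc, and vec are assumed placeholders:
  Value sum = vector::makeArithReduction(b, loc, vector::CombiningKind::ADD, lhs, acc);  // lhs + acc
  Value scalar = vector::getVectorReductionOp(arith::AtomicRMWKind::addf, b, loc, vec);  // reduce to scalar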
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:65
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:428
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:750
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:657
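A hedged worked example for applyPermutationMap; ctx is an assumed MLIRContext*:
  // map = (d0, d1, d2) -> (d2, d0)
  AffineMap map = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0,
                                 {getAffineDimExpr(2, ctx), getAffineDimExpr(0, ctx)}, ctx);
  SmallVector<int64_t> res = applyPermutationMap(map, ArrayRef<int64_t>{5, 6, 7});  // {7, 5}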
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:496
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:768
MemRefType canonicalizeStridedLayout(MemRefType t)
Return a version of t with identity layout if it can be determined statically that the layout is the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:699
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
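A hedged sketch of the OpFoldResult helpers above; mixedOffsets, b, and loc are assumed placeholders:
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);       // split static/dynamic
  Value first = getValueOrCreateConstantIndexOp(b, loc, mixedOffsets.front());   // materialize a Value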
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:631
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
Definition: AffineMap.cpp:905
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
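A hedged arithmetic example combining computeStrides (listed earlier) with linearize; the shape is illustrative:
  SmallVector<int64_t> strides = computeStrides({2, 3, 4});   // row-major suffix products: {12, 4, 1}
  int64_t flat = linearize({1, 2, 3}, strides);               // 1*12 + 2*4 + 3*1 == 23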
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
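A hedged sketch pairing m_Constant with matchPattern (listed earlier); value is an assumed Value:
  IntegerAttr cst;
  if (matchPattern(value, m_Constant(&cst)))
    llvm::dbgs() << "constant operand: " << cst.getInt() << "\n";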
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:607
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
Return a fused vector::ContractionOp which represents a pattern such as:
Definition: VectorOps.cpp:1165
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Definition: VectorOps.cpp:1168
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
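A hedged skeleton of the OpRewritePattern shape described above; the pattern name and the matched op are illustrative and not patterns defined in this file:
  struct EraseTrivialBroadcast : public OpRewritePattern<vector::BroadcastOp> {
    using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                  PatternRewriter &rewriter) const override {
      if (op.getSourceType() != op.getResultVectorType())
        return rewriter.notifyMatchFailure(op, "not a no-op broadcast");
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
  };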
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:329
bool operator==(const KeyTy &key) const
Definition: VectorOps.cpp:367
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
Definition: VectorOps.cpp:369
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.