//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
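/// For example (illustrative sketch; the IR below is not taken from any test):
///
/// ```mlir
/// %all  = vector.constant_mask [4] : vector<4xi1>  // classified as AllTrue
/// %none = vector.constant_mask [0] : vector<4xi1>  // classified as AllFalse
/// ```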
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<vector::YieldOp>(loc);
}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}

/// Check if `write` is of a constant splat and the masked `read` is padded with
/// the same splat value -- meaning it could be the same value as the initial
/// constant splat.
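/// A sketch of the IR shape this targets (illustrative; names and types are
/// arbitrary):
///
/// ```mlir
/// %cst = arith.constant dense<0.0> : vector<4xf32>
/// %w = vector.transfer_write %cst, %t[%c0] : vector<4xf32>, tensor<4xf32>
/// %pad = arith.constant 0.0 : f32
/// %r = vector.transfer_read %w[%c0], %pad, %mask
///        : tensor<4xf32>, vector<4xf32>
/// ```
///
/// Every lane of the masked read observes either the written splat or the
/// identical padding value, so the read could be the same splat constant.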
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // Check if the masks are consistent. The splat value could be the same if the
  // read is masked (and padded with the splat value), and the write is unmasked
  // or has the same mask. Note this does not allow the case where the write is
  // masked and the read is unmasked, as then the read could be of more elements
  // than the write (which may not be the same value).
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for constant splat (as the source of the write).
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat()) {
    return false;
  }
  // The padding of the read and the constant splat value must be the same.
  Attribute padAttr;
  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
    return false;
  return padAttr == splatAttr.getSplatValue<Attribute>();
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different we know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getSource() != transferB.getSource())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}

// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
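// For example (illustrative), with `shape = [2, 3]` and `offsets = [0, 0]`,
// starting from [0, 0] successive calls yield [0, 1], [0, 2], [1, 0], [1, 1],
// [1, 2], and then return failure.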
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();

    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }

  return failure();
}

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
        return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
      });
  return ints;
}

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = foldResult.dyn_cast<Attribute>())
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return foldResult.get<Value>();
                  });
  return values;
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
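//
// For example (an illustrative sketch of the rewrite, types chosen
// arbitrarily):
//
// ```mlir
// %r = vector.multi_reduction <add>, %src, %acc [1]
//        : vector<4x1xf32> to vector<4xf32>
// ```
//
// becomes, roughly:
//
// ```mlir
// %c = vector.shape_cast %src : vector<4x1xf32> to vector<4xf32>
// %r = arith.addf %acc, %c : vector<4xf32>
// ```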
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      SmallVector<int64_t> zeroIdx(shape.size(), 0);
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
                                                zeroIdx);
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    Value result;
    if (vectorType.getRank() == 0) {
      if (mask)
        mask = rewriter.create<ExtractElementOp>(loc, mask);
      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
    } else {
      if (mask)
        mask = rewriter.create<ExtractOp>(loc, mask, 0);
      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
    }

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed because tests still use the old format where the 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(parser.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed because tests still use the old format where the
      // 'iterator_types' attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
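/// For example (illustrative): for a matmul-style contraction with indexing
/// maps (i, j, k) -> (i, k), (i, j, k) -> (k, j), (i, j, k) -> (i, j) and
/// operands vector<4x8xf32> and vector<8x16xf32>, the expected mask type is
/// vector<4x16x8xi1>, i.e. the shape of the iteration space (i, j, k).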
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
/// %c0 = vector.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast(X)) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (pos.is<Attribute>()) {
      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension";
      }
    }
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
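/// For example (illustrative):
///
/// ```mlir
/// %a = vector.extract %v[0] : vector<3x4xf32> from vector<2x3x4xf32>
/// %b = vector.extract %a[1] : vector<4xf32> from vector<3x4xf32>
/// ```
///
/// folds the second extract in place to:
///
/// ```mlir
/// %b = vector.extract %v[0, 1] : vector<4xf32> from vector<2x3x4xf32>
/// ```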
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
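///
/// For example (illustrative): when the insert position is a prefix of the
/// extract position,
///
/// ```mlir
/// %i = vector.insert %src, %dst[2] : vector<4xf32> into vector<8x4xf32>
/// %e = vector.extract %i[2, 1] : f32 from vector<8x4xf32>
/// ```
///
/// can fold to:
///
/// ```mlir
/// %e = vector.extract %src[1] : f32 from vector<4xf32>
/// ```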
1389 class ExtractFromInsertTransposeChainState {
1390 public:
1391  ExtractFromInsertTransposeChainState(ExtractOp e);
1392 
1393  /// Iterate over producing insert and transpose ops until we find a fold.
1394  Value fold();
1395 
1396 private:
1397  /// Return true if the vector at position `a` is contained within the vector
1398  /// at position `b`. Under insert/extract semantics, this is the same as `a`
1399  /// is a prefix of `b`.
1400  template <typename ContainerA, typename ContainerB>
1401  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1402  return a.size() <= b.size() &&
1403  std::equal(a.begin(), a.begin() + a.size(), b.begin());
1404  }
1405 
1406  /// Return true if the vector at position `a` intersects the vector at
1407  /// position `b`. Under insert/extract semantics, this is the same as equality
1408  /// of all entries of `a` that are >=0 with the corresponding entries of b.
1409  /// Comparison is on the common prefix (i.e. zip).
1410  template <typename ContainerA, typename ContainerB>
1411  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1412  for (auto [elemA, elemB] : llvm::zip(a, b)) {
1413  if (elemA < 0 || elemB < 0)
1414  continue;
1415  if (elemA != elemB)
1416  return false;
1417  }
1418  return true;
1419  }
1420 
1421  /// Folding is only possible in the absence of an internal permutation in the
1422  /// result vector.
1423  bool canFold() {
1424  return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1425  }
1426 
1427  // Helper to get the next defining op of interest.
1428  void updateStateForNextIteration(Value v) {
1429  nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1430  nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1431  };
1432 
1433  // Case 1. If we hit a transpose, just compose the map and iterate.
1434  // Invariant: insert + transpose do not change rank, we can always compose.
1435  LogicalResult handleTransposeOp();
1436 
1437  // Case 2: the insert position matches extractPosition exactly, early return.
1438  LogicalResult handleInsertOpWithMatchingPos(Value &res);
1439 
1440  /// Case 3: if the insert position is a prefix of extractPosition, extract a
1441  /// portion of the source of the insert.
1442  /// Example:
1443  /// ```
1444  /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
1445  /// // extractPosition == [1, 2, 3]
1446  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1447  /// // can fold to vector.extract %source[0, 3]
1448  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1449  /// ```
1450  /// To traverse through %source, we need to set the leading dims to 0 and
1451  /// drop the extra leading dims.
1452  /// This method updates the internal state.
1453  LogicalResult handleInsertOpWithPrefixPos(Value &res);
1454 
1455  /// Try to fold in place to extract(source, extractPosition) and return the
1456  /// folded result. Return null if folding is not possible (e.g. due to an
1457  /// internal tranposition in the result).
1458  Value tryToFoldExtractOpInPlace(Value source);
1459 
1460  ExtractOp extractOp;
1461  int64_t vectorRank;
1462  int64_t extractedRank;
1463 
1464  InsertOp nextInsertOp;
1465  TransposeOp nextTransposeOp;
1466 
1467  /// Sentinel values that encode the internal permutation status of the result.
1468  /// They are set to (-1, ... , -k) at the beginning and appended to
1469  /// `extractPosition`.
1470  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1471  /// ensure that there is no internal transposition.
1472  /// Internal transposition cannot be accounted for with a folding pattern.
1473  // TODO: We could relax the internal transposition with an extra transposition
1474  // operation in a future canonicalizer.
1475  SmallVector<int64_t> sentinels;
1476  SmallVector<int64_t> extractPosition;
1477 };
1478 } // namespace
1479 
1480 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1481  ExtractOp e)
1482  : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1483  extractedRank(extractOp.getNumIndices()) {
1484  assert(vectorRank >= extractedRank && "Extracted position overflow");
1485  sentinels.reserve(vectorRank - extractedRank);
1486  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1487  sentinels.push_back(-(i + 1));
1488  extractPosition.assign(extractOp.getStaticPosition().begin(),
1489  extractOp.getStaticPosition().end());
1490  llvm::append_range(extractPosition, sentinels);
1491 }
1492 
1493 // Case 1. If we hit a transpose, just compose the map and iterate.
1494 // Invariant: insert + transpose do not change rank, we can always compose.
1495 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1496  // TODO: Canonicalization for dynamic position not implemented yet.
1497  if (extractOp.hasDynamicPosition())
1498  return failure();
1499 
1500  if (!nextTransposeOp)
1501  return failure();
1502  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1503  nextTransposeOp.getPermutation(), extractOp.getContext()));
1504  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1505  return success();
1506 }
1507 
1508 // Case 2: the insert position matches extractPosition exactly, early return.
1509 LogicalResult
1510 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1511  Value &res) {
1512  // TODO: Canonicalization for dynamic position not implemented yet.
1513  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1514  return failure();
1515 
1516  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1517  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1518  return failure();
1519  // Case 2.a. early-exit fold.
1520  res = nextInsertOp.getSource();
1521  // Case 2.b. if internal transposition is present, canFold will be false.
1522  return success(canFold());
1523 }
1524 
1525 /// Case 3: if inserted position is a prefix of extractPosition,
1526 /// extract a portion of the source of the insertion.
1527 /// This method updates the internal state.
1528 LogicalResult
1529 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1530  // TODO: Canonicalization for dynamic position not implemented yet.
1531  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1532  return failure();
1533 
1534  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1535  if (!isContainedWithin(insertedPos, extractPosition))
1536  return failure();
1537  // Set leading dims to zero.
1538  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1539  // Drop extra leading dims.
1540  extractPosition.erase(extractPosition.begin(),
1541  extractPosition.begin() + insertedPos.size());
1542  extractedRank = extractPosition.size() - sentinels.size();
1543  // Case 3.a. early-exit fold (break and delegate to post-while path).
1544  res = nextInsertOp.getSource();
1545  // Case 3.b. if internal transposition is present, canFold will be false.
1546  return success();
1547 }
1548 
1549 /// Try to fold in place to extract(source, extractPosition) and return the
1550 /// folded result. Return null if folding is not possible (e.g. due to an
1551 /// internal transposition in the result).
1552 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1553  Value source) {
1554  // TODO: Canonicalization for dynamic position not implemented yet.
1555  if (extractOp.hasDynamicPosition())
1556  return Value();
1557 
1558  // If we can't fold (either internal transposition, or nothing to fold), bail.
1559  bool nothingToFold = (source == extractOp.getVector());
1560  if (nothingToFold || !canFold())
1561  return Value();
1562 
1563  // Otherwise, fold by updating the op inplace and return its result.
1564  OpBuilder b(extractOp.getContext());
1565  extractOp.setStaticPosition(
1566  ArrayRef(extractPosition).take_front(extractedRank));
1567  extractOp.getVectorMutable().assign(source);
1568  return extractOp.getResult();
1569 }
1570 
1571 /// Iterate over producing insert and transpose ops until we find a fold.
1572 Value ExtractFromInsertTransposeChainState::fold() {
1573  // TODO: Canonicalization for dynamic position not implemented yet.
1574  if (extractOp.hasDynamicPosition())
1575  return Value();
1576 
1577  Value valueToExtractFrom = extractOp.getVector();
1578  updateStateForNextIteration(valueToExtractFrom);
1579  while (nextInsertOp || nextTransposeOp) {
1580  // Case 1. If we hit a transpose, just compose the map and iterate.
1581  // Invariant: insert + transpose do not change rank, we can always compose.
1582  if (succeeded(handleTransposeOp())) {
1583  valueToExtractFrom = nextTransposeOp.getVector();
1584  updateStateForNextIteration(valueToExtractFrom);
1585  continue;
1586  }
1587 
1588  Value result;
1589  // Case 2: the insert position matches extractPosition exactly.
1590  if (succeeded(handleInsertOpWithMatchingPos(result)))
1591  return result;
1592 
1593  // Case 3: if the inserted position is a prefix of extractPosition, we can
1594  // just extract a portion of the source of the insert.
1595  if (succeeded(handleInsertOpWithPrefixPos(result)))
1596  return tryToFoldExtractOpInPlace(result);
1597 
1598  // Case 4: extractPosition intersects insertedPos on non-sentinel
1599  // values. This is a more difficult case and we bail.
1600  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1601  if (isContainedWithin(extractPosition, insertedPos) ||
1602  intersectsWhereNonNegative(extractPosition, insertedPos))
1603  return Value();
1604 
1605  // Case 5: No intersection, we forward the extract to insertOp.dest().
1606  valueToExtractFrom = nextInsertOp.getDest();
1607  updateStateForNextIteration(valueToExtractFrom);
1608  }
1609  // If after all this we can fold, go for it.
1610  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1611 }
1612 
1613 /// Returns true if the operation has a 0-D vector type operand or result.
1614 static bool hasZeroDimVectors(Operation *op) {
1615  auto hasZeroDimVectorType = [](Type type) -> bool {
1616  auto vecType = dyn_cast<VectorType>(type);
1617  return vecType && vecType.getRank() == 0;
1618  };
1619 
1620  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1621  llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1622 }
1623 
1624 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1625 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1626  // TODO: Canonicalization for dynamic position not implemented yet.
1627  if (extractOp.hasDynamicPosition())
1628  return Value();
1629 
1630  Operation *defOp = extractOp.getVector().getDefiningOp();
1631  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1632  return Value();
1633 
1634  Value source = defOp->getOperand(0);
1635  if (extractOp.getType() == source.getType())
1636  return source;
1637  auto getRank = [](Type type) {
1638  return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1639  : 0;
1640  };
1641 
1642  // If splat or broadcast from a scalar, just return the source scalar.
1643  unsigned broadcastSrcRank = getRank(source.getType());
1644  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1645  return source;
1646 
1647  unsigned extractResultRank = getRank(extractOp.getType());
1648  if (extractResultRank >= broadcastSrcRank)
1649  return Value();
1650  // Check that the dimensions of the result haven't been broadcasted.
1651  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1652  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1653  if (extractVecType && broadcastVecType &&
1654  extractVecType.getShape() !=
1655  broadcastVecType.getShape().take_back(extractResultRank))
1656  return Value();
1657 
1658  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1659  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1660 
1661  // Detect all the positions that come from "dim-1" broadcasting.
1662  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1663  // extract position to `0` when extracting from the source operand.
1664  llvm::SetVector<int64_t> broadcastedUnitDims =
1665  broadcastOp.computeBroadcastedUnitDims();
1666  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1667  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1668  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1669  if (broadcastedUnitDims.contains(i))
1670  extractPos[i] = 0;
1671  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1672  // matching extract position when extracting from the source operand.
1673  int64_t rankDiff = broadcastSrcRank - extractResultRank;
1674  extractPos.erase(extractPos.begin(),
1675  std::next(extractPos.begin(), extractPos.size() - rankDiff));
1676  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1677  OpBuilder b(extractOp.getContext());
1678  extractOp.setOperand(0, source);
1679  extractOp.setStaticPosition(extractPos);
1680  return extractOp.getResult();
1681 }
1682 
1683 // Fold extractOp with source coming from ShapeCast op.
1684 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1685  // TODO: Canonicalization for dynamic position not implemented yet.
1686  if (extractOp.hasDynamicPosition())
1687  return Value();
1688 
1689  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1690  if (!shapeCastOp)
1691  return Value();
1692 
1693  // 0-D vectors not supported.
1694  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1695  if (hasZeroDimVectors(shapeCastOp))
1696  return Value();
1697 
1698  // Get the nth dimension size starting from lowest dimension.
1699  auto getDimReverse = [](VectorType type, int64_t n) {
1700  return type.getShape().take_back(n + 1).front();
1701  };
1702  int64_t destinationRank =
1703  llvm::isa<VectorType>(extractOp.getType())
1704  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1705  : 0;
1706  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1707  return Value();
1708  if (destinationRank > 0) {
1709  auto destinationType =
1710  llvm::cast<VectorType>(extractOp.getResult().getType());
1711  for (int64_t i = 0; i < destinationRank; i++) {
1712  // The lowest dimension of the destination must match the lowest
1713  // dimension of the shapecast op source.
1714  // TODO: This case could be supported in a canonicalization pattern.
1715  if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1716  getDimReverse(destinationType, i))
1717  return Value();
1718  }
1719  }
1720  // Extract the strides associated with the extract op vector source. Then use
1721  // this to calculate a linearized position for the extract.
1722  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1723  std::reverse(extractedPos.begin(), extractedPos.end());
1724  SmallVector<int64_t, 4> strides;
1725  int64_t stride = 1;
1726  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1727  strides.push_back(stride);
1728  stride *=
1729  getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1730  }
1731 
1732  int64_t position = linearize(extractedPos, strides);
1733  // Then extract the strides associated to the shapeCast op vector source and
1734  // delinearize the position using those strides.
1735  SmallVector<int64_t, 4> newStrides;
1736  int64_t numDimension =
1737  shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1738  stride = 1;
1739  for (int64_t i = 0; i < numDimension; i++) {
1740  newStrides.push_back(stride);
1741  stride *=
1742  getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1743  }
1744  std::reverse(newStrides.begin(), newStrides.end());
1745  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1746  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1747  OpBuilder b(extractOp.getContext());
1748  extractOp.setStaticPosition(newPosition);
1749  extractOp.setOperand(0, shapeCastOp.getSource());
1750  return extractOp.getResult();
1751 }
1752 
1753 /// Fold an ExtractOp from ExtractStridedSliceOp.
1754 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1755  // TODO: Canonicalization for dynamic position not implemented yet.
1756  if (extractOp.hasDynamicPosition())
1757  return Value();
1758 
1759  auto extractStridedSliceOp =
1760  extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1761  if (!extractStridedSliceOp)
1762  return Value();
1763 
1764  // 0-D vectors not supported.
1765  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1766  if (hasZeroDimVectors(extractStridedSliceOp))
1767  return Value();
1768 
1769  // Return if 'extractStridedSliceOp' has non-unit strides.
1770  if (extractStridedSliceOp.hasNonUnitStrides())
1771  return Value();
1772 
1773  // Trim offsets for dimensions fully extracted.
1774  auto sliceOffsets =
1775  extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1776  while (!sliceOffsets.empty()) {
1777  size_t lastOffset = sliceOffsets.size() - 1;
1778  if (sliceOffsets.back() != 0 ||
1779  extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1780  extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1781  break;
1782  sliceOffsets.pop_back();
1783  }
1784  unsigned destinationRank = 0;
1785  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1786  destinationRank = vecType.getRank();
1787  // The dimensions of the result need to be untouched by the
1788  // extractStridedSlice op.
1789  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1790  sliceOffsets.size())
1791  return Value();
1792 
1793  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1794  assert(extractedPos.size() >= sliceOffsets.size());
1795  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1796  extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1797  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1798 
1799  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1800  OpBuilder b(extractOp.getContext());
1801  extractOp.setStaticPosition(extractedPos);
1802  return extractOp.getResult();
1803 }
1804 
1805 /// Fold extract_op fed from a chain of insertStridedSlice ops.
1806 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1807  // TODO: Canonicalization for dynamic position not implemented yet.
1808  if (extractOp.hasDynamicPosition())
1809  return Value();
1810 
1811  int64_t destinationRank =
1812  llvm::isa<VectorType>(extractOp.getType())
1813  ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1814  : 0;
1815  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1816  if (!insertOp)
1817  return Value();
1818 
1819  // 0-D vectors not supported.
1820  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1821  if (hasZeroDimVectors(insertOp))
1822  return Value();
1823 
1824  while (insertOp) {
1825  int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1826  insertOp.getSourceVectorType().getRank();
1827  if (destinationRank > insertOp.getSourceVectorType().getRank())
1828  return Value();
1829  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1830  ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1831 
1832  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1833  return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1834  }))
1835  return Value();
1836  bool disjoint = false;
1837  SmallVector<int64_t, 4> offsetDiffs;
1838  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1839  int64_t start = insertOffsets[dim];
1840  int64_t size =
1841  (dim < insertRankDiff)
1842  ? 1
1843  : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1844  int64_t end = start + size;
1845  int64_t offset = extractOffsets[dim];
1846  // Check if the start of the extract offset is in the interval inserted.
1847  if (start <= offset && offset < end) {
1848  if (dim >= insertRankDiff)
1849  offsetDiffs.push_back(offset - start);
1850  continue;
1851  }
1852  disjoint = true;
1853  break;
1854  }
1855  // The extracted chunk overlaps with the inserted vector.
1856  if (!disjoint) {
1857  // If any of the inner dimensions are only partially inserted we have a
1858  // partial overlap.
1859  int64_t srcRankDiff =
1860  insertOp.getSourceVectorType().getRank() - destinationRank;
1861  for (int64_t i = 0; i < destinationRank; i++) {
1862  if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1863  insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1864  insertRankDiff))
1865  return Value();
1866  }
1867  extractOp.getVectorMutable().assign(insertOp.getSource());
1868  // OpBuilder is only used as a helper to build an I64ArrayAttr.
1869  OpBuilder b(extractOp.getContext());
1870  extractOp.setStaticPosition(offsetDiffs);
1871  return extractOp.getResult();
1872  }
1873  // If the chunk extracted is disjoint from the chunk inserted, keep
1874  // looking in the insert chain.
1875  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1876  }
1877  return Value();
1878 }
1879 
1880 /// Try to fold the extraction of a scalar from a vector defined by
1881 /// vector.from_elements. E.g.:
1882 ///
1883 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1884 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1885 /// ==> fold to %a
1886 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1887  // Dynamic extractions cannot be folded.
1888  if (extractOp.hasDynamicPosition())
1889  return {};
1890 
1891  // Look for extract(from_elements).
1892  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1893  if (!fromElementsOp)
1894  return {};
1895 
1896  // Scalable vectors are not supported.
1897  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1898  if (vecType.isScalable())
1899  return {};
1900 
1901  // Only extractions of scalars are supported.
1902  int64_t rank = vecType.getRank();
1903  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1904  if (extractOp.getType() != vecType.getElementType())
1905  return {};
1906  assert(static_cast<int64_t>(indices.size()) == rank &&
1907  "unexpected number of indices");
1908 
1909  // Compute flattened/linearized index and fold to operand.
1910  int flatIndex = 0;
1911  int stride = 1;
1912  for (int i = rank - 1; i >= 0; --i) {
1913  flatIndex += indices[i] * stride;
1914  stride *= vecType.getDimSize(i);
1915  }
1916  return fromElementsOp.getElements()[flatIndex];
1917 }
1918 
1919 OpFoldResult ExtractOp::fold(FoldAdaptor) {
1920  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
1921  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
1922  // mismatch).
1923  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
1924  return getVector();
1925  if (succeeded(foldExtractOpFromExtractChain(*this)))
1926  return getResult();
1927  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1928  return res;
1929  if (auto res = foldExtractFromBroadcast(*this))
1930  return res;
1931  if (auto res = foldExtractFromShapeCast(*this))
1932  return res;
1933  if (auto val = foldExtractFromExtractStrided(*this))
1934  return val;
1935  if (auto val = foldExtractStridedOpFromInsertChain(*this))
1936  return val;
1937  if (auto val = foldScalarExtractFromFromElements(*this))
1938  return val;
1939  return OpFoldResult();
1940 }
1941 
1942 namespace {
1943 
1944 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1945 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1946 public:
1948 
1949  LogicalResult matchAndRewrite(ExtractOp extractOp,
1950  PatternRewriter &rewriter) const override {
1951  Operation *defOp = extractOp.getVector().getDefiningOp();
1952  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1953  return failure();
1954 
1955  Value source = defOp->getOperand(0);
1956  if (extractOp.getType() == source.getType())
1957  return failure();
1958  auto getRank = [](Type type) {
1959  return llvm::isa<VectorType>(type)
1960  ? llvm::cast<VectorType>(type).getRank()
1961  : 0;
1962  };
1963  unsigned broadcastSrcRank = getRank(source.getType());
1964  unsigned extractResultRank = getRank(extractOp.getType());
1965  // We only consider the case where the rank of the source is less than or
1966  // equal to the rank of the extract dst. The other cases are handled in the
1967  // folding patterns.
1968  if (extractResultRank < broadcastSrcRank)
1969  return failure();
1970 
1971  // Special case if broadcast src is a 0D vector.
1972  if (extractResultRank == 0) {
1973  assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
1974  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
1975  return success();
1976  }
1977  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1978  extractOp, extractOp.getType(), source);
1979  return success();
1980  }
1981 };
1982 
1983 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
1984 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1985 public:
1987 
1988  LogicalResult matchAndRewrite(ExtractOp extractOp,
1989  PatternRewriter &rewriter) const override {
1990  // Return if 'ExtractOp' operand is not defined by a splat vector
1991  // ConstantOp.
1992  Value sourceVector = extractOp.getVector();
1993  Attribute vectorCst;
1994  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1995  return failure();
1996  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1997  if (!splat)
1998  return failure();
1999  TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2000  if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2001  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2002  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2003  return success();
2004  }
2005 };
2006 
2007 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2008 class ExtractOpNonSplatConstantFolder final
2009  : public OpRewritePattern<ExtractOp> {
2010 public:
2012 
2013  LogicalResult matchAndRewrite(ExtractOp extractOp,
2014  PatternRewriter &rewriter) const override {
2015  // TODO: Canonicalization for dynamic position not implemented yet.
2016  if (extractOp.hasDynamicPosition())
2017  return failure();
2018 
2019  // Return if 'ExtractOp' operand is not defined by a compatible vector
2020  // ConstantOp.
2021  Value sourceVector = extractOp.getVector();
2022  Attribute vectorCst;
2023  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2024  return failure();
2025 
2026  auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2027  if (vecTy.isScalable())
2028  return failure();
2029 
2030  // The splat case is handled by `ExtractOpSplatConstantFolder`.
2031  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2032  if (!dense || dense.isSplat())
2033  return failure();
2034 
2035  // Calculate the linearized position of the continuous chunk of elements to
2036  // extract.
2037  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2038  copy(extractOp.getStaticPosition(), completePositions.begin());
2039  int64_t elemBeginPosition =
2040  linearize(completePositions, computeStrides(vecTy.getShape()));
2041  auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2042 
2043  TypedAttr newAttr;
2044  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2045  SmallVector<Attribute> elementValues(
2046  denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2047  newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2048  } else {
2049  newAttr = *denseValuesBegin;
2050  }
2051 
2052  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2053  return success();
2054  }
2055 };
2056 
2057 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
2058 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2059 public:
2061 
2062  LogicalResult matchAndRewrite(ExtractOp extractOp,
2063  PatternRewriter &rewriter) const override {
2064  auto createMaskOp =
2065  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2066  if (!createMaskOp)
2067  return failure();
2068 
2069  VectorType extractedMaskType =
2070  llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2071 
2072  if (!extractedMaskType)
2073  return failure();
2074 
2075  auto maskOperands = createMaskOp.getOperands();
2076  ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2077  VectorType maskType = createMaskOp.getVectorType();
2078 
2079  bool containsUnknownDims = false;
2080  bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2081 
2082  for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2083  dimIdx++) {
2084  int64_t pos = extractOpPos[dimIdx];
2085  Value operand = maskOperands[dimIdx];
2086  auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2087  if (!constantOp) {
2088  // Bounds of this dim unknown.
2089  containsUnknownDims = true;
2090  continue;
2091  }
2092 
2093  int64_t createMaskBound =
2094  llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2095 
2096  if (pos != ShapedType::kDynamic) {
2097  // If any position is outside the range from the `create_mask`, then the
2098  // extracted mask will be all-false.
2099  allFalse |= pos >= createMaskBound;
2100  } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2101  // This dim is not all-true and since this is a dynamic index we don't
2102  // know if the extraction is within the true or false region.
2103  // Note: Zero dims have already been handled via getMaskFormat().
2104  containsUnknownDims = true;
2105  }
2106  }
2107 
2108  if (allFalse) {
2109  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2110  extractOp, DenseElementsAttr::get(extractedMaskType, false));
2111  } else if (!containsUnknownDims) {
2112  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2113  extractOp, extractedMaskType,
2114  maskOperands.drop_front(extractOpPos.size()));
2115  } else {
2116  return failure();
2117  }
2118  return success();
2119  }
2120 };
2121 
2122 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2123 // does not change.
2124 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2125  PatternRewriter &rewriter) {
2126  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2127  if (!castOp)
2128  return failure();
2129 
2130  VectorType sourceType = castOp.getSourceVectorType();
2131  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2132  if (!targetType)
2133  return failure();
2134 
2135  if (sourceType.getNumElements() != targetType.getNumElements())
2136  return failure();
2137 
2138  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2139  castOp.getSource());
2140  return success();
2141 }
2142 
2143 /// Try to canonicalize the extraction of a subvector from a vector defined by
2144 /// vector.from_elements. E.g.:
2145 ///
2146 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2147 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2148 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2149 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2150  PatternRewriter &rewriter) {
2151  // Dynamic positions are not supported.
2152  if (extractOp.hasDynamicPosition())
2153  return failure();
2154 
2155  // Scalar extracts are handled by the folder.
2156  auto resultType = dyn_cast<VectorType>(extractOp.getType());
2157  if (!resultType)
2158  return failure();
2159 
2160  // Look for extracts from a from_elements op.
2161  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2162  if (!fromElementsOp)
2163  return failure();
2164  VectorType inputType = fromElementsOp.getType();
2165 
2166  // Scalable vectors are not supported.
2167  if (resultType.isScalable() || inputType.isScalable())
2168  return failure();
2169 
2170  // Compute the position of the first extracted element and flatten/linearize
2171  // position.
2172  SmallVector<int64_t> firstElementPos =
2173  llvm::to_vector(extractOp.getStaticPosition());
2174  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2175  int flatIndex = 0;
2176  int stride = 1;
2177  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2178  flatIndex += firstElementPos[i] * stride;
2179  stride *= inputType.getDimSize(i);
2180  }
2181 
2182  // Replace the op with a smaller from_elements op.
2183  rewriter.replaceOpWithNewOp<FromElementsOp>(
2184  extractOp, resultType,
2185  fromElementsOp.getElements().slice(flatIndex,
2186  resultType.getNumElements()));
2187  return success();
2188 }
2189 } // namespace
2190 
2191 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2192  MLIRContext *context) {
2193  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2194  ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2195  results.add(foldExtractFromShapeCastToShapeCast);
2196  results.add(foldExtractFromFromElements);
2197 }
2198 
2199 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2200  SmallVectorImpl<int64_t> &results) {
2201  for (auto attr : arrayAttr)
2202  results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2203 }
2204 
2205 //===----------------------------------------------------------------------===//
2206 // FMAOp
2207 //===----------------------------------------------------------------------===//
2208 
2209 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2210  return llvm::to_vector<4>(getVectorType().getShape());
2211 }
2212 
2213 //===----------------------------------------------------------------------===//
2214 // FromElementsOp
2215 //===----------------------------------------------------------------------===//
2216 
2217 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2218 /// same SSA value. E.g.:
2219 ///
2220 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2221 /// ==> rewrite to vector.splat %a : vector<3xf32>
2222 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2223  PatternRewriter &rewriter) {
2224  if (!llvm::all_equal(fromElementsOp.getElements()))
2225  return failure();
2226  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2227  fromElementsOp.getElements().front());
2228  return success();
2229 }
2230 
2231 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2232  MLIRContext *context) {
2233  results.add(rewriteFromElementsAsSplat);
2234 }
2235 
2236 //===----------------------------------------------------------------------===//
2237 // BroadcastOp
2238 //===----------------------------------------------------------------------===//
2239 
2240 /// Return the dimensions of the result vector that were formerly ones in the
2241 /// source vector and thus correspond to "dim-1" broadcasting.
2242 static llvm::SetVector<int64_t>
2243 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2244  ArrayRef<int64_t> dstShape) {
2245  int64_t rankDiff = dstShape.size() - srcShape.size();
2246  int64_t dstDim = rankDiff;
2247  llvm::SetVector<int64_t> res;
2248  for (auto [s1, s2] :
2249  llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2250  if (s1 != s2) {
2251  assert(s1 == 1 && "expected dim-1 broadcasting");
2252  res.insert(dstDim);
2253  }
2254  ++dstDim;
2255  }
2256  return res;
2257 }
2258 
2259 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2260  // A scalar broadcast does not involve any unit-dim broadcasting.
2261  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2262  if (!srcVectorType)
2263  return {};
2264  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2265  getResultVectorType().getShape());
2266 }
2267 
2268 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2269 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2270 /// This requires (and asserts) that the broadcast is free of dim-1
2271 /// broadcasting.
2272 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2273 /// vector.transpose may be inserted to make the broadcast possible.
2274 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2275 /// the helper will assert. This means:
2276 /// 1. `dstShape` must not be empty.
2277 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2278 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2279 ///    must match the `value` shape.
2280 Value BroadcastOp::createOrFoldBroadcastOp(
2281  OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2282  const llvm::SetVector<int64_t> &broadcastedDims) {
2283  assert(!dstShape.empty() && "unexpected empty dst shape");
2284 
2285  // Well-formedness check.
2286  SmallVector<int64_t> checkShape;
2287  for (int i = 0, e = dstShape.size(); i < e; ++i) {
2288  if (broadcastedDims.contains(i))
2289  continue;
2290  checkShape.push_back(dstShape[i]);
2291  }
2292  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2293  "ill-formed broadcastedDims contains values not confined to "
2294  "destVectorShape");
2295 
2296  Location loc = value.getLoc();
2297  Type elementType = getElementTypeOrSelf(value.getType());
2298  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2299  VectorType dstVectorType = VectorType::get(dstShape, elementType);
2300 
2301  // Step 2. If scalar -> dstShape broadcast, just do it.
2302  if (!srcVectorType) {
2303  assert(checkShape.empty() &&
2304  "ill-formed createOrFoldBroadcastOp arguments");
2305  return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2306  }
2307 
2308  assert(srcVectorType.getShape().equals(checkShape) &&
2309  "ill-formed createOrFoldBroadcastOp arguments");
2310 
2311  // Step 3. Since vector.broadcast only allows creating leading dims,
2312  // vector -> dstShape broadcast may require a transpose.
2313  // Traverse the dims in order and construct:
2314  // 1. The leading entries of the broadcastShape that is guaranteed to be
2315  // achievable by a simple broadcast.
2316  // 2. The induced permutation for the subsequent vector.transpose that will
2317  // bring us from `broadcastShape` back to the desired `dstShape`.
2318  // If the induced permutation is not the identity, create a vector.transpose.
2319  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2320  broadcastShape.reserve(dstShape.size());
2321  // Consider the example:
2322  // srcShape = 2x4
2323  // dstShape = 1x2x3x4x5
2324  // broadcastedDims = [0, 2, 4]
2325  //
2326  // We want to build:
2327  // broadcastShape = 1x3x5x2x4
2328  // permutation = [0, 3, 1, 4, 2]
2329  // i.e. the broadcasted dst dims (0, 2, 4) read from the leading part of
2330  // broadcastShape and the remaining dst dims (1, 3) from the src shape part.
2331  //
2332  // Note that the trailing dims of broadcastShape are exactly the srcShape
2333  // by construction.
2334  // nextSrcShapeDim is used to keep track of where in the permutation the
2335  // "src shape part" occurs.
2336  int64_t nextSrcShapeDim = broadcastedDims.size();
2337  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2338  if (broadcastedDims.contains(i)) {
2339  // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2340  // bring it to the head of the broadcastShape.
2341  // It will need to be permuted back from `broadcastShape.size() - 1` into
2342  // position `i`.
2343  broadcastShape.push_back(dstShape[i]);
2344  permutation[i] = broadcastShape.size() - 1;
2345  } else {
2346  // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2347  // shape and needs to be permuted into position `i`.
2348  // Don't touch `broadcastShape` here, the whole srcShape will be
2349  // appended after.
2350  permutation[i] = nextSrcShapeDim++;
2351  }
2352  }
2353  // 3.c. Append the srcShape.
2354  llvm::append_range(broadcastShape, srcVectorType.getShape());
2355 
2356  // Ensure there are no dim-1 broadcasts.
2357  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2358  .empty() &&
2359  "unexpected dim-1 broadcast");
2360 
2361  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2362  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2363  vector::BroadcastableToResult::Success &&
2364  "must be broadcastable");
2365  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2366  // Step 4. If we find any dimension that indeed needs to be permuted,
2367  // immediately return a new vector.transpose.
2368  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2369  if (permutation[i] != i)
2370  return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2371  // Otherwise return res.
2372  return res;
2373 }
2374 
2375 BroadcastableToResult mlir::vector::isBroadcastableTo(
2376  Type srcType, VectorType dstVectorType,
2377  std::pair<VectorDim, VectorDim> *mismatchingDims) {
2378  // Broadcast scalar to vector of the same element type.
2379  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2380  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2381  return BroadcastableToResult::Success;
2382  // From now on, only vectors broadcast.
2383  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2384  if (!srcVectorType)
2385  return BroadcastableToResult::SourceTypeNotAVector;
2386 
2387  int64_t srcRank = srcVectorType.getRank();
2388  int64_t dstRank = dstVectorType.getRank();
2389  if (srcRank > dstRank)
2390  return BroadcastableToResult::SourceRankHigher;
2391  // Source has an exact match or singleton value for all trailing dimensions
2392  // (all leading dimensions are simply duplicated).
2393  int64_t lead = dstRank - srcRank;
2394  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2395  // Have mismatching dims (in the sense of vector.broadcast semantics) been
2396  // encountered?
2397  bool foundMismatchingDims = false;
2398 
2399  // Check fixed-width dims.
2400  int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2401  int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2402  if (srcDim != 1 && srcDim != dstDim)
2403  foundMismatchingDims = true;
2404 
2405  // Check scalable flags.
2406  bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2407  bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2408  if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2409  // 1 -> [N] is fine, everything else should be rejected when mixing
2410  // fixed-width and scalable dims
2411  (srcDimScalableFlag != dstDimScalableFlag &&
2412  (srcDim != 1 || srcDimScalableFlag)))
2413  foundMismatchingDims = true;
2414 
2415  if (foundMismatchingDims) {
2416  if (mismatchingDims != nullptr) {
2417  mismatchingDims->first.dim = srcDim;
2418  mismatchingDims->first.isScalable = srcDimScalableFlag;
2419 
2420  mismatchingDims->second.dim = dstDim;
2421  mismatchingDims->second.isScalable = dstDimScalableFlag;
2422  }
2423  return BroadcastableToResult::DimensionMismatch;
2424  }
2425  }
2426 
2427  return BroadcastableToResult::Success;
2428 }
2429 
2430 LogicalResult BroadcastOp::verify() {
2431  std::pair<VectorDim, VectorDim> mismatchingDims;
2432  BroadcastableToResult res = isBroadcastableTo(
2433  getSourceType(), getResultVectorType(), &mismatchingDims);
2434  if (res == BroadcastableToResult::Success)
2435  return success();
2436  if (res == BroadcastableToResult::SourceRankHigher)
2437  return emitOpError("source rank higher than destination rank");
2438  if (res == BroadcastableToResult::DimensionMismatch) {
2439  return emitOpError("dimension mismatch (")
2440  << (mismatchingDims.first.isScalable ? "[" : "")
2441  << mismatchingDims.first.dim
2442  << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2443  << (mismatchingDims.second.isScalable ? "[" : "")
2444  << mismatchingDims.second.dim
2445  << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2446  }
2447  if (res == BroadcastableToResult::SourceTypeNotAVector)
2448  return emitOpError("source type is not a vector");
2449  llvm_unreachable("unexpected vector.broadcast op error");
2450 }
2451 
2452 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2453  if (getSourceType() == getResultVectorType())
2454  return getSource();
2455  if (!adaptor.getSource())
2456  return {};
2457  auto vectorType = getResultVectorType();
2458  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2459  return DenseElementsAttr::get(vectorType, adaptor.getSource());
2460  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2461  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2462  return {};
2463 }
2464 
2465 namespace {
2466 
2467 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2468 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2470 
2471  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2472  PatternRewriter &rewriter) const override {
2473  auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2474  if (!srcBroadcast)
2475  return failure();
2476  rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2477  broadcastOp.getResultVectorType(),
2478  srcBroadcast.getSource());
2479  return success();
2480  }
2481 };
2482 } // namespace
2483 
2484 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2485  MLIRContext *context) {
2486  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2487  // calling `populateCastAwayVectorLeadingOneDimPatterns`
2488  results.add<BroadcastFolder>(context);
2489 }
2490 
2491 //===----------------------------------------------------------------------===//
2492 // ShuffleOp
2493 //===----------------------------------------------------------------------===//
2494 
2495 LogicalResult ShuffleOp::verify() {
2496  VectorType resultType = getResultVectorType();
2497  VectorType v1Type = getV1VectorType();
2498  VectorType v2Type = getV2VectorType();
2499  // Verify ranks.
2500  int64_t resRank = resultType.getRank();
2501  int64_t v1Rank = v1Type.getRank();
2502  int64_t v2Rank = v2Type.getRank();
2503  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2504  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2505  if (!wellFormed0DCase && !wellFormedNDCase)
2506  return emitOpError("rank mismatch");
2507 
2508  // Verify all but leading dimension sizes.
2509  for (int64_t r = 1; r < v1Rank; ++r) {
2510  int64_t resDim = resultType.getDimSize(r);
2511  int64_t v1Dim = v1Type.getDimSize(r);
2512  int64_t v2Dim = v2Type.getDimSize(r);
2513  if (resDim != v1Dim || v1Dim != v2Dim)
2514  return emitOpError("dimension mismatch");
2515  }
2516  // Verify mask length.
2517  ArrayRef<int64_t> mask = getMask();
2518  int64_t maskLength = mask.size();
2519  if (maskLength <= 0)
2520  return emitOpError("invalid mask length");
2521  if (maskLength != resultType.getDimSize(0))
2522  return emitOpError("mask length mismatch");
2523  // Verify all indices.
2524  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2525  (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2526  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2527  if (maskPos < 0 || maskPos >= indexSize)
2528  return emitOpError("mask index #") << (idx + 1) << " out of range";
2529  }
2530  return success();
2531 }
2532 
2533 LogicalResult
2534 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2535  ShuffleOp::Adaptor adaptor,
2536  SmallVectorImpl<Type> &inferredReturnTypes) {
2537  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2538  auto v1Rank = v1Type.getRank();
2539  // Construct resulting type: leading dimension matches mask
2540  // length, all trailing dimensions match the operands.
2541  SmallVector<int64_t> shape;
2542  shape.reserve(v1Rank);
2543  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2544  // In the 0-D case there is no trailing shape to append.
2545  if (v1Rank > 0)
2546  llvm::append_range(shape, v1Type.getShape().drop_front());
2547  inferredReturnTypes.push_back(
2548  VectorType::get(shape, v1Type.getElementType()));
2549  return success();
2550 }
2551 
2552 template <typename T>
2553 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2554  T expected = begin;
2555  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2556  return value == expected++;
2557  });
2558 }
2559 
2560 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2561  VectorType v1Type = getV1VectorType();
2562  // For consistency, a 0-D shuffle returns a 1-D vector; this cannot be done
2563  // as a fold and is instead handled by a canonicalization to vector.broadcast.
2564  if (v1Type.getRank() == 0)
2565  return {};
2566 
2567  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2568  if (!v1Type.isScalable() &&
2569  isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2570  return getV1();
2571  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2572  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2573  isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2574  getV2VectorType().getDimSize(0)))
2575  return getV2();
2576 
2577  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2578  if (!lhs || !rhs)
2579  return {};
2580 
2581  auto lhsType =
2582  llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2583  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2584  // manipulation.
2585  if (lhsType.getRank() != 1)
2586  return {};
2587  int64_t lhsSize = lhsType.getDimSize(0);
2588 
2589  SmallVector<Attribute> results;
2590  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2591  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2592  for (int64_t i : this->getMask()) {
2593  if (i >= lhsSize) {
2594  results.push_back(rhsElements[i - lhsSize]);
2595  } else {
2596  results.push_back(lhsElements[i]);
2597  }
2598  }
2599 
2600  return DenseElementsAttr::get(getResultVectorType(), results);
2601 }
2602 
2603 namespace {
2604 
2605 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2606 // to a broadcast.
2607 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2609 
2610  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2611  PatternRewriter &rewriter) const override {
2612  VectorType v1VectorType = shuffleOp.getV1VectorType();
2613  ArrayRef<int64_t> mask = shuffleOp.getMask();
2614  if (v1VectorType.getRank() > 0)
2615  return failure();
2616  if (mask.size() != 1)
2617  return failure();
2618  VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2619  if (mask[0] == 0)
2620  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2621  shuffleOp.getV1());
2622  else
2623  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2624  shuffleOp.getV2());
2625  return success();
2626  }
2627 };
2628 
2629 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2630 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2631 public:
2633 
2634  LogicalResult matchAndRewrite(ShuffleOp op,
2635  PatternRewriter &rewriter) const override {
2636  auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2637  auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2638 
2639  if (!v1Splat || !v2Splat)
2640  return failure();
2641 
2642  if (v1Splat.getInput() != v2Splat.getInput())
2643  return failure();
2644 
2645  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2646  return success();
2647  }
2648 };
2649 
2650 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2651 /// vector.interleave.
2652 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2653 public:
2655 
2656  LogicalResult matchAndRewrite(ShuffleOp op,
2657  PatternRewriter &rewriter) const override {
2658  VectorType resultType = op.getResultVectorType();
2659  if (resultType.isScalable())
2660  return rewriter.notifyMatchFailure(
2661  op, "ShuffleOp can't represent a scalable interleave");
2662 
2663  if (resultType.getRank() != 1)
2664  return rewriter.notifyMatchFailure(
2665  op, "ShuffleOp can't represent an n-D interleave");
2666 
2667  VectorType sourceType = op.getV1VectorType();
2668  if (sourceType != op.getV2VectorType() ||
2669  sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2670  return rewriter.notifyMatchFailure(
2671  op, "ShuffleOp types don't match an interleave");
2672  }
2673 
2674  ArrayRef<int64_t> shuffleMask = op.getMask();
2675  int64_t resultVectorSize = resultType.getNumElements();
2676  for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2677  int64_t maskValueA = shuffleMask[i * 2];
2678  int64_t maskValueB = shuffleMask[(i * 2) + 1];
2679  if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2680  return rewriter.notifyMatchFailure(op,
2681  "ShuffleOp mask not interleaving");
2682  }
2683 
2684  rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2685  return success();
2686  }
2687 };
2688 
2689 } // namespace
2690 
2691 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2692  MLIRContext *context) {
2693  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2694  context);
2695 }
2696 
2697 //===----------------------------------------------------------------------===//
2698 // InsertElementOp
2699 //===----------------------------------------------------------------------===//
2700 
2701 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2702  Value source, Value dest) {
2703  build(builder, result, source, dest, {});
2704 }
2705 
2706 LogicalResult InsertElementOp::verify() {
2707  auto dstVectorType = getDestVectorType();
2708  if (dstVectorType.getRank() == 0) {
2709  if (getPosition())
2710  return emitOpError("expected position to be empty with 0-D vector");
2711  return success();
2712  }
2713  if (dstVectorType.getRank() != 1)
2714  return emitOpError("unexpected >1 vector rank");
2715  if (!getPosition())
2716  return emitOpError("expected position for 1-D vector");
2717  return success();
2718 }
2719 
2720 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2721  // Skip the 0-D vector here.
2722  if (!adaptor.getPosition())
2723  return {};
2724 
2725  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2726  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2727  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2728  if (!src || !dst || !pos)
2729  return {};
2730 
2731  if (src.getType() != getDestVectorType().getElementType())
2732  return {};
2733 
2734  auto dstElements = dst.getValues<Attribute>();
2735 
2736  SmallVector<Attribute> results(dstElements);
2737 
2738  uint64_t posIdx = pos.getInt();
2739  if (posIdx >= results.size())
2740  return {};
2741  results[posIdx] = src;
2742 
2743  return DenseElementsAttr::get(getDestVectorType(), results);
2744 }
2745 
2746 //===----------------------------------------------------------------------===//
2747 // InsertOp
2748 //===----------------------------------------------------------------------===//
2749 
2750 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2751  Value source, Value dest, int64_t position) {
2752  build(builder, result, source, dest, ArrayRef<int64_t>{position});
2753 }
2754 
2755 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2756  Value source, Value dest, OpFoldResult position) {
2757  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2758 }
2759 
2760 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2761  Value source, Value dest,
2762  ArrayRef<int64_t> position) {
2763  SmallVector<OpFoldResult> posVals;
2764  posVals.reserve(position.size());
2765  llvm::transform(position, std::back_inserter(posVals),
2766  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2767  build(builder, result, source, dest, posVals);
2768 }
2769 
2770 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2771  Value source, Value dest,
2772  ArrayRef<OpFoldResult> position) {
2773  SmallVector<int64_t> staticPos;
2774  SmallVector<Value> dynamicPos;
2775  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2776  build(builder, result, source, dest, dynamicPos,
2777  builder.getDenseI64ArrayAttr(staticPos));
2778 }
2779 
2780 LogicalResult InsertOp::verify() {
2781  SmallVector<OpFoldResult> position = getMixedPosition();
2782  auto destVectorType = getDestVectorType();
2783  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2784  return emitOpError(
2785  "expected position attribute of rank no greater than dest vector rank");
2786  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2787  if (srcVectorType &&
2788  (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2789  static_cast<unsigned>(destVectorType.getRank())))
2790  return emitOpError("expected position attribute rank + source rank to "
2791  "match dest vector rank");
2792  if (!srcVectorType &&
2793  (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2794  return emitOpError(
2795  "expected position attribute rank to match the dest vector rank");
2796  for (auto [idx, pos] : llvm::enumerate(position)) {
2797  if (auto attr = pos.dyn_cast<Attribute>()) {
2798  int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2799  if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2800  return emitOpError("expected position attribute #")
2801  << (idx + 1)
2802  << " to be a non-negative integer smaller than the "
2803  "corresponding "
2804  "dest vector dimension";
2805  }
2806  }
2807  }
2808  return success();
2809 }
2810 
2811 namespace {
2812 
2813 // If insertOp is only inserting unit dimensions it can be transformed to a
2814 // broadcast.
2815 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2816 public:
2818 
2819  LogicalResult matchAndRewrite(InsertOp insertOp,
2820  PatternRewriter &rewriter) const override {
2821  auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2822  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2823  srcVecType.getNumElements())
2824  return failure();
2825  rewriter.replaceOpWithNewOp<BroadcastOp>(
2826  insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2827  return success();
2828  }
2829 };
2830 
2831 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2832 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2833 public:
2835 
2836  LogicalResult matchAndRewrite(InsertOp op,
2837  PatternRewriter &rewriter) const override {
2838  auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2839  auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2840 
2841  if (!srcSplat || !dstSplat)
2842  return failure();
2843 
2844  if (srcSplat.getInput() != dstSplat.getInput())
2845  return failure();
2846 
2847  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2848  return success();
2849  }
2850 };
2851 
2852 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2853 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2854 public:
2856 
2857  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2858  // unless the destination vector constant has a single use.
2859  static constexpr int64_t vectorSizeFoldThreshold = 256;
2860 
2861  LogicalResult matchAndRewrite(InsertOp op,
2862  PatternRewriter &rewriter) const override {
2863  // TODO: Canonicalization for dynamic position not implemented yet.
2864  if (op.hasDynamicPosition())
2865  return failure();
2866 
2867  // Return if 'InsertOp' operand is not defined by a compatible vector
2868  // ConstantOp.
2869  TypedValue<VectorType> destVector = op.getDest();
2870  Attribute vectorDestCst;
2871  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2872  return failure();
2873  auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2874  if (!denseDest)
2875  return failure();
2876 
2877  VectorType destTy = destVector.getType();
2878  if (destTy.isScalable())
2879  return failure();
2880 
2881  // Make sure we do not create too many large constants.
2882  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2883  !destVector.hasOneUse())
2884  return failure();
2885 
2886  Value sourceValue = op.getSource();
2887  Attribute sourceCst;
2888  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2889  return failure();
2890 
2891  // Calculate the linearized position of the contiguous chunk of elements to
2892  // insert.
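  // Illustrative example: for a destination of type vector<2x3x4xf32> and a
  // static position [1, 2], the strides are [12, 4, 1], so the inserted chunk
  // starts at linearized index 1 * 12 + 2 * 4 = 20.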
2893  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2894  copy(op.getStaticPosition(), completePositions.begin());
2895  int64_t insertBeginPosition =
2896  linearize(completePositions, computeStrides(destTy.getShape()));
2897 
2898  SmallVector<Attribute> insertedValues;
2899  Type destEltType = destTy.getElementType();
2900 
2901  // The `convertIntegerAttr` method specifically handles the case
2902  // for `llvm.mlir.constant` which can hold an attribute with a
2903  // different type than the return type.
2904  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2905  for (auto value : denseSource.getValues<Attribute>())
2906  insertedValues.push_back(convertIntegerAttr(value, destEltType));
2907  } else {
2908  insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
2909  }
2910 
2911  auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2912  copy(insertedValues, allValues.begin() + insertBeginPosition);
2913  auto newAttr = DenseElementsAttr::get(destTy, allValues);
2914 
2915  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2916  return success();
2917  }
2918 
2919 private:
2920  /// Converts the expected type to an IntegerAttr if there's
2921  /// a mismatch.
2922  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
2923  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
2924  if (intAttr.getType() != expectedType)
2925  return IntegerAttr::get(expectedType, intAttr.getInt());
2926  }
2927  return attr;
2928  }
2929 };
2930 
2931 } // namespace
2932 
2933 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2934  MLIRContext *context) {
2935  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2936  InsertOpConstantFolder>(context);
2937 }
2938 
2939 // Eliminates insert operations that produce values identical to their source
2940 // value. This happens when the source and destination vectors have identical
2941 // sizes.
2942 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2943  if (getNumIndices() == 0)
2944  return getSource();
2945  return {};
2946 }
2947 
2948 //===----------------------------------------------------------------------===//
2949 // InsertStridedSliceOp
2950 //===----------------------------------------------------------------------===//
2951 
2952 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2953  Value source, Value dest,
2954  ArrayRef<int64_t> offsets,
2955  ArrayRef<int64_t> strides) {
2956  result.addOperands({source, dest});
2957  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2958  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2959  result.addTypes(dest.getType());
2960  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
2961  offsetsAttr);
2962  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
2963  stridesAttr);
2964 }
2965 
2966 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
2967 template <typename OpType>
2968 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2969  ArrayAttr arrayAttr,
2970  ArrayRef<int64_t> shape,
2971  StringRef attrName) {
2972  if (arrayAttr.size() > shape.size())
2973  return op.emitOpError("expected ")
2974  << attrName << " attribute of rank no greater than vector rank";
2975  return success();
2976 }
2977 
2978 // Returns success if all integers in `arrayAttr` lie in the admissible
2979 // interval. If `halfOpen` is true then the admissible interval is [min, max).
2980 // Otherwise, the admissible interval is [min, max].
2981 template <typename OpType>
2982 static LogicalResult
2983 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2984  int64_t max, StringRef attrName,
2985  bool halfOpen = true) {
2986  for (auto attr : arrayAttr) {
2987  auto val = llvm::cast<IntegerAttr>(attr).getInt();
2988  auto upper = max;
2989  if (!halfOpen)
2990  upper += 1;
2991  if (val < min || val >= upper)
2992  return op.emitOpError("expected ") << attrName << " to be confined to ["
2993  << min << ", " << upper << ")";
2994  }
2995  return success();
2996 }
2997 
2998 // Returns success if all integers in `arrayAttr` are within the admissible
2999 // interval [min, max), where max is the corresponding dimension of `shape`.
3000 // If `halfOpen` is false the admissible interval is [min, max] instead.
3001 template <typename OpType>
3002 static LogicalResult
3003 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3004  ArrayRef<int64_t> shape, StringRef attrName,
3005  bool halfOpen = true, int64_t min = 0) {
3006  for (auto [index, attrDimPair] :
3007  llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3008  int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3009  int64_t max = std::get<1>(attrDimPair);
3010  if (!halfOpen)
3011  max += 1;
3012  if (val < min || val >= max)
3013  return op.emitOpError("expected ")
3014  << attrName << " dimension " << index << " to be confined to ["
3015  << min << ", " << max << ")";
3016  }
3017  return success();
3018 }
3019 
3020 // Returns success if, for all indices i = 0..shape.size()-1, the value
3021 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3022 // is within the admissible interval, where max is `shape[i]`.
3023 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
3024 // the admissible interval is [min, max].
3025 template <typename OpType>
3026 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3027  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3028  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3029  bool halfOpen = true, int64_t min = 1) {
3030  assert(arrayAttr1.size() <= shape.size());
3031  assert(arrayAttr2.size() <= shape.size());
3032  for (auto [index, it] :
3033  llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3034  auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3035  auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3036  int64_t max = std::get<2>(it);
3037  if (!halfOpen)
3038  max += 1;
3039  if (val1 + val2 < 0 || val1 + val2 >= max)
3040  return op.emitOpError("expected sum(")
3041  << attrName1 << ", " << attrName2 << ") dimension " << index
3042  << " to be confined to [" << min << ", " << max << ")";
3043  }
3044  return success();
3045 }
3046 
3047 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3048  MLIRContext *context) {
3049  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3050  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3051  });
3052  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3053 }
3054 
3055 LogicalResult InsertStridedSliceOp::verify() {
3056  auto sourceVectorType = getSourceVectorType();
3057  auto destVectorType = getDestVectorType();
3058  auto offsets = getOffsetsAttr();
3059  auto strides = getStridesAttr();
3060  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3061  return emitOpError(
3062  "expected offsets of same size as destination vector rank");
3063  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3064  return emitOpError("expected strides of same size as source vector rank");
3065  if (sourceVectorType.getRank() > destVectorType.getRank())
3066  return emitOpError(
3067  "expected source rank to be no greater than destination rank");
3068 
3069  auto sourceShape = sourceVectorType.getShape();
3070  auto destShape = destVectorType.getShape();
3071  SmallVector<int64_t, 4> sourceShapeAsDestShape(
3072  destShape.size() - sourceShape.size(), 0);
3073  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3074  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3075  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3076  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3077  offName)) ||
3078  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3079  /*max=*/1, stridesName,
3080  /*halfOpen=*/false)) ||
3081  failed(isSumOfIntegerArrayAttrConfinedToShape(
3082  *this, offsets,
3083  makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3084  offName, "source vector shape",
3085  /*halfOpen=*/false, /*min=*/1)))
3086  return failure();
3087 
3088  unsigned rankDiff = destShape.size() - sourceShape.size();
3089  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3090  if (sourceVectorType.getScalableDims()[idx] !=
3091  destVectorType.getScalableDims()[idx + rankDiff]) {
3092  return emitOpError("mismatching scalable flags (at source vector idx=")
3093  << idx << ")";
3094  }
3095  if (sourceVectorType.getScalableDims()[idx]) {
3096  auto sourceSize = sourceShape[idx];
3097  auto destSize = destShape[idx + rankDiff];
3098  if (sourceSize != destSize) {
3099  return emitOpError("expected size at idx=")
3100  << idx
3101  << (" to match the corresponding base size from the input "
3102  "vector (")
3103  << sourceSize << (" vs ") << destSize << (")");
3104  }
3105  }
3106  }
3107 
3108  return success();
3109 }
3110 
3111 namespace {
3112 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3113 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
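/// Illustrative example (hypothetical IR):
///   %a = vector.splat %x : vector<2xf32>
///   %b = vector.splat %x : vector<4xf32>
///   %c = vector.insert_strided_slice %a, %b {offsets = [1], strides = [1]}
///        : vector<2xf32> into vector<4xf32>
/// is replaced by %b, since writing a splat of %x into a splat of %x changes
/// nothing.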
3114 class FoldInsertStridedSliceSplat final
3115  : public OpRewritePattern<InsertStridedSliceOp> {
3116 public:
3117  using OpRewritePattern::OpRewritePattern;
3118 
3119  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3120  PatternRewriter &rewriter) const override {
3121  auto srcSplatOp =
3122  insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3123  auto destSplatOp =
3124  insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3125 
3126  if (!srcSplatOp || !destSplatOp)
3127  return failure();
3128 
3129  if (srcSplatOp.getInput() != destSplatOp.getInput())
3130  return failure();
3131 
3132  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3133  return success();
3134  }
3135 };
3136 
3137 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3138 /// to dst.
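/// Illustrative example (hypothetical IR):
///   %e = vector.extract_strided_slice %dst
///          {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %e, %dst {offsets = [2], strides = [1]}
///        : vector<4xf32> into vector<8xf32>
/// folds to %dst because the slice is written back to where it was read from.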
3139 class FoldInsertStridedSliceOfExtract final
3140  : public OpRewritePattern<InsertStridedSliceOp> {
3141 public:
3142  using OpRewritePattern::OpRewritePattern;
3143 
3144  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3145  PatternRewriter &rewriter) const override {
3146  auto extractStridedSliceOp =
3147  insertStridedSliceOp.getSource()
3148  .getDefiningOp<vector::ExtractStridedSliceOp>();
3149 
3150  if (!extractStridedSliceOp)
3151  return failure();
3152 
3153  if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3154  return failure();
3155 
3156  // Check if they have the same strides and offsets.
3157  if (extractStridedSliceOp.getStrides() !=
3158  insertStridedSliceOp.getStrides() ||
3159  extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3160  return failure();
3161 
3162  rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3163  return success();
3164  }
3165 };
3166 
3167 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3168 // ConstantOp.
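// Illustrative example (hypothetical IR):
//   %dst = arith.constant dense<0> : vector<4xi32>
//   %slc = arith.constant dense<[7, 8]> : vector<2xi32>
//   %r = vector.insert_strided_slice %slc, %dst {offsets = [1], strides = [1]}
//        : vector<2xi32> into vector<4xi32>
// becomes
//   %r = arith.constant dense<[0, 7, 8, 0]> : vector<4xi32>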
3169 class InsertStridedSliceConstantFolder final
3170  : public OpRewritePattern<InsertStridedSliceOp> {
3171 public:
3172  using OpRewritePattern::OpRewritePattern;
3173 
3174  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3175  // unless the destination vector constant has a single use.
3176  static constexpr int64_t vectorSizeFoldThreshold = 256;
3177 
3178  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3179  PatternRewriter &rewriter) const override {
3180  // Return if 'InsertStridedSliceOp' operand is not defined by a compatible vector
3181  // ConstantOp.
3182  TypedValue<VectorType> destVector = op.getDest();
3183  Attribute vectorDestCst;
3184  if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3185  return failure();
3186 
3187  VectorType destTy = destVector.getType();
3188  if (destTy.isScalable())
3189  return failure();
3190 
3191  // Make sure we do not create too many large constants.
3192  if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3193  !destVector.hasOneUse())
3194  return failure();
3195 
3196  auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3197 
3198  TypedValue<VectorType> sourceValue = op.getSource();
3199  Attribute sourceCst;
3200  if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3201  return failure();
3202 
3203  // TODO: Handle non-unit strides when they become available.
3204  if (op.hasNonUnitStrides())
3205  return failure();
3206 
3207  VectorType sliceVecTy = sourceValue.getType();
3208  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3209  int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3210  SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3211  SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3212 
3213  // Calculate the destination element indices by enumerating all slice
3214  // positions within the destination and linearizing them. The enumeration
3215  // order is lexicographic which yields a sequence of monotonically
3216  // increasing linearized position indices.
3217  // Because the destination may have higher dimensionality than the slice,
3218  // we keep track of two overlapping sets of positions and offsets.
3219  auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3220  auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3221  auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3222  SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3223  MutableArrayRef<int64_t> currSlicePosition(
3224  currDestPosition.begin() + rankDifference, currDestPosition.end());
3225  ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3226  offsets.end());
3227  do {
3228  int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3229  assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3230  assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3231  "Invalid slice element");
3232  newValues[linearizedPosition] = *sliceValuesIt;
3233  ++sliceValuesIt;
3234  } while (succeeded(
3235  incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3236 
3237  auto newAttr = DenseElementsAttr::get(destTy, newValues);
3238  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3239  return success();
3240  }
3241 };
3242 
3243 } // namespace
3244 
3245 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3246  RewritePatternSet &results, MLIRContext *context) {
3247  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3248  InsertStridedSliceConstantFolder>(context);
3249 }
3250 
3251 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3252  if (getSourceVectorType() == getDestVectorType())
3253  return getSource();
3254  return {};
3255 }
3256 
3257 //===----------------------------------------------------------------------===//
3258 // OuterProductOp
3259 //===----------------------------------------------------------------------===//
3260 
3261 /// Build an op without mask, use the type of `acc` as the return type.
3262 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3263  Value lhs, Value rhs, Value acc) {
3264  result.addOperands({lhs, rhs, acc});
3265  result.addTypes(acc.getType());
3266 }
3267 
3268 void OuterProductOp::print(OpAsmPrinter &p) {
3269  p << " " << getLhs() << ", " << getRhs();
3270  if (getAcc()) {
3271  p << ", " << getAcc();
3272  p.printOptionalAttrDict((*this)->getAttrs());
3273  }
3274  p << " : " << getLhs().getType() << ", " << getRhs().getType();
3275 }
3276 
3277 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3278  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3279  Type tLHS, tRHS;
3280  if (parser.parseOperandList(operandsInfo) ||
3281  parser.parseOptionalAttrDict(result.attributes) ||
3282  parser.parseColonType(tLHS) || parser.parseComma() ||
3283  parser.parseType(tRHS))
3284  return failure();
3285  if (operandsInfo.size() < 2)
3286  return parser.emitError(parser.getNameLoc(),
3287  "expected at least 2 operands");
3288  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3289  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3290  if (!vLHS)
3291  return parser.emitError(parser.getNameLoc(),
3292  "expected vector type for operand #1");
3293 
3294  VectorType resType;
3295  if (vRHS) {
3296  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3297  vRHS.getScalableDims()[0]};
3298  resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3299  vLHS.getElementType(), scalableDimsRes);
3300  } else {
3301  // Scalar RHS operand
3302  SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3303  resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3304  scalableDimsRes);
3305  }
3306 
3307  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3308  result.attributes.append(
3309  OuterProductOp::getKindAttrName(result.name),
3310  CombiningKindAttr::get(result.getContext(),
3311  OuterProductOp::getDefaultKind()));
3312  }
3313 
3314  return failure(
3315  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3316  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3317  (operandsInfo.size() > 2 &&
3318  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3319  parser.addTypeToList(resType, result.types));
3320 }
3321 
3322 LogicalResult OuterProductOp::verify() {
3323  Type tRHS = getOperandTypeRHS();
3324  VectorType vLHS = getOperandVectorTypeLHS(),
3325  vRHS = llvm::dyn_cast<VectorType>(tRHS),
3326  vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3327 
3328  if (vLHS.getRank() != 1)
3329  return emitOpError("expected 1-d vector for operand #1");
3330 
3331  if (vRHS) {
3332  // Proper OUTER operation.
3333  if (vRHS.getRank() != 1)
3334  return emitOpError("expected 1-d vector for operand #2");
3335  if (vRES.getRank() != 2)
3336  return emitOpError("expected 2-d vector result");
3337  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3338  return emitOpError("expected #1 operand dim to match result dim #1");
3339  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3340  return emitOpError("expected #2 operand dim to match result dim #2");
3341  if (vLHS.isScalable() && !vRHS.isScalable()) {
3342  // This restriction reflects what's currently supported in terms of
3343  // scalable vectors. However, we could relax this if there's a use case.
3344  return emitOpError(
3345  "expected either both or only #2 operand dim to be scalable");
3346  }
3347  } else {
3348  // An AXPY operation.
3349  if (vRES.getRank() != 1)
3350  return emitOpError("expected 1-d vector result");
3351  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3352  return emitOpError("expected #1 operand dim to match result dim #1");
3353  }
3354 
3355  if (vACC && vACC != vRES)
3356  return emitOpError("expected operand #3 of same type as result type");
3357 
3358  // Verify supported combining kind.
3359  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3360  return emitOpError("unsupported outerproduct type");
3361 
3362  return success();
3363 }
3364 
3365 // MaskableOpInterface methods.
3366 
3367 /// Returns the mask type expected by this operation. Mostly used for
3368 /// verification purposes. It requires the operation to be vectorized.
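/// E.g., a vector.outerproduct producing vector<4x[8]xf32> expects a mask of
/// type vector<4x[8]xi1> (same shape and scalable dims, with i1 elements).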
3369 Type OuterProductOp::getExpectedMaskType() {
3370  auto vecType = this->getResultVectorType();
3371  return VectorType::get(vecType.getShape(),
3372  IntegerType::get(vecType.getContext(), /*width=*/1),
3373  vecType.getScalableDims());
3374 }
3375 
3376 //===----------------------------------------------------------------------===//
3377 // ExtractStridedSliceOp
3378 //===----------------------------------------------------------------------===//
3379 
3380 // Inference works as follows:
3381 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3382 // 2. Add sizes from 'vectorType' for remaining dims.
3383 // Scalable flags are inherited from 'vectorType'.
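// Illustrative example: slicing a vector<4x8x16xf32> source with sizes = [2, 4]
// (and offsets/strides of the same rank) infers the result type
// vector<2x4x16xf32>: the leading dims come from 'sizes', the trailing dim from
// the source vector type.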
3384 static Type inferStridedSliceOpResultType(VectorType vectorType,
3385  ArrayAttr offsets, ArrayAttr sizes,
3386  ArrayAttr strides) {
3387  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3388  SmallVector<int64_t, 4> shape;
3389  shape.reserve(vectorType.getRank());
3390  unsigned idx = 0;
3391  for (unsigned e = offsets.size(); idx < e; ++idx)
3392  shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3393  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3394  shape.push_back(vectorType.getShape()[idx]);
3395 
3396  return VectorType::get(shape, vectorType.getElementType(),
3397  vectorType.getScalableDims());
3398 }
3399 
3400 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3401  Value source, ArrayRef<int64_t> offsets,
3402  ArrayRef<int64_t> sizes,
3403  ArrayRef<int64_t> strides) {
3404  result.addOperands(source);
3405  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3406  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3407  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3408  result.addTypes(
3409  inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3410  offsetsAttr, sizesAttr, stridesAttr));
3411  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3412  offsetsAttr);
3413  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3414  sizesAttr);
3415  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3416  stridesAttr);
3417 }
3418 
3419 LogicalResult ExtractStridedSliceOp::verify() {
3420  auto type = getSourceVectorType();
3421  auto offsets = getOffsetsAttr();
3422  auto sizes = getSizesAttr();
3423  auto strides = getStridesAttr();
3424  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3425  return emitOpError(
3426  "expected offsets, sizes and strides attributes of same size");
3427 
3428  auto shape = type.getShape();
3429  auto offName = getOffsetsAttrName();
3430  auto sizesName = getSizesAttrName();
3431  auto stridesName = getStridesAttrName();
3432  if (failed(
3433  isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3434  failed(
3435  isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3436  failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3437  stridesName)) ||
3438  failed(
3439  isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3440  failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3441  /*halfOpen=*/false,
3442  /*min=*/1)) ||
3443  failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3444  /*max=*/1, stridesName,
3445  /*halfOpen=*/false)) ||
3446  failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3447  shape, offName, sizesName,
3448  /*halfOpen=*/false)))
3449  return failure();
3450 
3451  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3452  offsets, sizes, strides);
3453  if (getResult().getType() != resultType)
3454  return emitOpError("expected result type to be ") << resultType;
3455 
3456  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3457  if (type.getScalableDims()[idx]) {
3458  auto inputDim = type.getShape()[idx];
3459  auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3460  if (inputDim != inputSize)
3461  return emitOpError("expected size at idx=")
3462  << idx
3463  << (" to match the corresponding base size from the input "
3464  "vector (")
3465  << inputSize << (" vs ") << inputDim << (")");
3466  }
3467  }
3468 
3469  return success();
3470 }
3471 
3472 // When the source of an ExtractStridedSliceOp comes from a chain of
3473 // InsertStridedSliceOps, try to use the source of one of those inserts if we
3474 // can detect that the extracted vector is a subset of an inserted vector.
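// Illustrative example (hypothetical IR):
//   %w = vector.insert_strided_slice %v, %dst {offsets = [2], strides = [1]}
//        : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %w
//          {offsets = [3], sizes = [2], strides = [1]}
//        : vector<8xf32> to vector<2xf32>
// The extracted range [3, 5) lies inside the inserted range [2, 6), so the
// extract can read directly from %v with its offsets rebased to [1].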
3475 static LogicalResult
3476 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3477  // Helper to extract integer out of ArrayAttr.
3478  auto getElement = [](ArrayAttr array, int idx) {
3479  return llvm::cast<IntegerAttr>(array[idx]).getInt();
3480  };
3481  ArrayAttr extractOffsets = op.getOffsets();
3482  ArrayAttr extractStrides = op.getStrides();
3483  ArrayAttr extractSizes = op.getSizes();
3484  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3485  while (insertOp) {
3486  if (op.getSourceVectorType().getRank() !=
3487  insertOp.getSourceVectorType().getRank())
3488  return failure();
3489  ArrayAttr insertOffsets = insertOp.getOffsets();
3490  ArrayAttr insertStrides = insertOp.getStrides();
3491  // If the rank of extract is greater than the rank of insert, we are likely
3492  // extracting a partial chunk of the vector inserted.
3493  if (extractOffsets.size() > insertOffsets.size())
3494  return failure();
3495  bool partialOverlap = false;
3496  bool disjoint = false;
3497  SmallVector<int64_t, 4> offsetDiffs;
3498  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3499  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3500  return failure();
3501  int64_t start = getElement(insertOffsets, dim);
3502  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3503  int64_t offset = getElement(extractOffsets, dim);
3504  int64_t size = getElement(extractSizes, dim);
3505  // Check if the start of the extract offset is in the interval inserted.
3506  if (start <= offset && offset < end) {
3507  // If the extract interval overlaps but is not fully included we may
3508  // have a partial overlap that will prevent any folding.
3509  if (offset + size > end)
3510  partialOverlap = true;
3511  offsetDiffs.push_back(offset - start);
3512  continue;
3513  }
3514  disjoint = true;
3515  break;
3516  }
3517  // The extract element chunk is a subset of the insert element.
3518  if (!disjoint && !partialOverlap) {
3519  op.setOperand(insertOp.getSource());
3520  // OpBuilder is only used as a helper to build an I64ArrayAttr.
3521  OpBuilder b(op.getContext());
3522  op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3523  return success();
3524  }
3525  // If the chunk extracted is disjoint from the chunk inserted, keep looking
3526  // in the insert chain.
3527  if (disjoint)
3528  insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3529  else {
3530  // The extracted vector partially overlaps the inserted vector; we cannot
3531  // fold.
3532  return failure();
3533  }
3534  }
3535  return failure();
3536 }
3537 
3538 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3539  if (getSourceVectorType() == getResult().getType())
3540  return getVector();
3541  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3542  return getResult();
3543  return {};
3544 }
3545 
3546 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3547  populateFromInt64AttrArray(getOffsets(), results);
3548 }
3549 
3550 namespace {
3551 
3552 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3553 // ConstantMaskOp.
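// Illustrative example (hypothetical IR):
//   %m = vector.constant_mask [3] : vector<8xi1>
//   %e = vector.extract_strided_slice %m
//          {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8xi1> to vector<4xi1>
// folds to
//   %e = vector.constant_mask [1] : vector<4xi1>
// since only lane 2 of the extracted range lies inside the masked region [0, 3).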
3554 class StridedSliceConstantMaskFolder final
3555  : public OpRewritePattern<ExtractStridedSliceOp> {
3556 public:
3557  using OpRewritePattern::OpRewritePattern;
3558 
3559  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3560  PatternRewriter &rewriter) const override {
3561  // Return if 'extractStridedSliceOp' operand is not defined by a
3562  // ConstantMaskOp.
3563  auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3564  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3565  if (!constantMaskOp)
3566  return failure();
3567  // Return if 'extractStridedSliceOp' has non-unit strides.
3568  if (extractStridedSliceOp.hasNonUnitStrides())
3569  return failure();
3570  // Gather constant mask dimension sizes.
3571  ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3572  // Gather strided slice offsets and sizes.
3573  SmallVector<int64_t, 4> sliceOffsets;
3574  populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3575  sliceOffsets);
3576  SmallVector<int64_t, 4> sliceSizes;
3577  populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3578 
3579  // Compute slice of vector mask region.
3580  SmallVector<int64_t, 4> sliceMaskDimSizes;
3581  sliceMaskDimSizes.reserve(maskDimSizes.size());
3582  for (auto [maskDimSize, sliceOffset, sliceSize] :
3583  llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3584  int64_t sliceMaskDimSize = std::max(
3585  static_cast<int64_t>(0),
3586  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3587  sliceMaskDimSizes.push_back(sliceMaskDimSize);
3588  }
3589  // Add unchanged dimensions.
3590  if (sliceMaskDimSizes.size() < maskDimSizes.size())
3591  for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3592  sliceMaskDimSizes.push_back(maskDimSizes[i]);
3593  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3594  // region is a conjunction of mask dim intervals).
3595  if (llvm::is_contained(sliceMaskDimSizes, 0))
3596  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3597 
3598  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3599  // region.
3600  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3601  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3602  sliceMaskDimSizes);
3603  return success();
3604  }
3605 };
3606 
3607 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3608 class StridedSliceSplatConstantFolder final
3609  : public OpRewritePattern<ExtractStridedSliceOp> {
3610 public:
3611  using OpRewritePattern::OpRewritePattern;
3612 
3613  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3614  PatternRewriter &rewriter) const override {
3615  // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3616  // ConstantOp.
3617  Value sourceVector = extractStridedSliceOp.getVector();
3618  Attribute vectorCst;
3619  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3620  return failure();
3621 
3622  auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3623  if (!splat)
3624  return failure();
3625 
3626  auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3627  splat.getSplatValue<Attribute>());
3628  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3629  newAttr);
3630  return success();
3631  }
3632 };
3633 
3634 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3635 // ConstantOp.
3636 class StridedSliceNonSplatConstantFolder final
3637  : public OpRewritePattern<ExtractStridedSliceOp> {
3638 public:
3639  using OpRewritePattern::OpRewritePattern;
3640 
3641  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3642  PatternRewriter &rewriter) const override {
3643  // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3644  // ConstantOp.
3645  Value sourceVector = extractStridedSliceOp.getVector();
3646  Attribute vectorCst;
3647  if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3648  return failure();
3649 
3650  // The splat case is handled by `StridedSliceSplatConstantFolder`.
3651  auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3652  if (!dense || dense.isSplat())
3653  return failure();
3654 
3655  // TODO: Handle non-unit strides when they become available.
3656  if (extractStridedSliceOp.hasNonUnitStrides())
3657  return failure();
3658 
3659  auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3660  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3661  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3662 
3663  VectorType sliceVecTy = extractStridedSliceOp.getType();
3664  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3665  int64_t sliceRank = sliceVecTy.getRank();
3666 
3667  // Expand offsets and sizes to match the vector rank.
3668  SmallVector<int64_t, 4> offsets(sliceRank, 0);
3669  copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3670 
3671  SmallVector<int64_t, 4> sizes(sourceShape);
3672  copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3673 
3674  // Calculate the slice elements by enumerating all slice positions and
3675  // linearizing them. The enumeration order is lexicographic which yields a
3676  // sequence of monotonically increasing linearized position indices.
3677  auto denseValuesBegin = dense.value_begin<Attribute>();
3678  SmallVector<Attribute> sliceValues;
3679  sliceValues.reserve(sliceVecTy.getNumElements());
3680  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3681  do {
3682  int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3683  assert(linearizedPosition < sourceVecTy.getNumElements() &&
3684  "Invalid index");
3685  sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3686  } while (
3687  succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3688 
3689  assert(static_cast<int64_t>(sliceValues.size()) ==
3690  sliceVecTy.getNumElements() &&
3691  "Invalid number of slice elements");
3692  auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3693  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3694  newAttr);
3695  return success();
3696  }
3697 };
3698 
3699 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3700 // BroadcastOp(ExtractStridedSliceOp).
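// Illustrative example (hypothetical IR):
//   %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
//   %e = vector.extract_strided_slice %b
//          {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]}
//        : vector<4x8xf32> to vector<2x8xf32>
// leaves the inner dimension of the broadcast source untouched (8 == 8), so it
// is rewritten as
//   %e = vector.broadcast %src : vector<8xf32> to vector<2x8xf32>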
3701 class StridedSliceBroadcast final
3702  : public OpRewritePattern<ExtractStridedSliceOp> {
3703 public:
3704  using OpRewritePattern::OpRewritePattern;
3705 
3706  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3707  PatternRewriter &rewriter) const override {
3708  auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3709  if (!broadcast)
3710  return failure();
3711  auto srcVecType =
3712  llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3713  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3714  auto dstVecType = llvm::cast<VectorType>(op.getType());
3715  unsigned dstRank = dstVecType.getRank();
3716  unsigned rankDiff = dstRank - srcRank;
3717  // Check if the innermost dimensions of the source of the broadcast are the
3718  // same as the destination of the extract. If this is the case we can just
3719  // use a broadcast as the original dimensions are untouched.
3720  bool lowerDimMatch = true;
3721  for (unsigned i = 0; i < srcRank; i++) {
3722  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3723  lowerDimMatch = false;
3724  break;
3725  }
3726  }
3727  Value source = broadcast.getSource();
3728  // If the inner dimensions don't match, it means we need to extract from the
3729  // source of the original broadcast and then broadcast the extracted value.
3730  // We also need to handle degenerate cases where the source is effectively
3731  // just a single scalar.
3732  bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3733  if (!lowerDimMatch && !isScalarSrc) {
3734  source = rewriter.create<ExtractStridedSliceOp>(
3735  op->getLoc(), source,
3736  getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3737  getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3738  getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3739  }
3740  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3741  return success();
3742  }
3743 };
3744 
3745 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3746 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3747 public:
3748  using OpRewritePattern::OpRewritePattern;
3749 
3750  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3751  PatternRewriter &rewriter) const override {
3752  auto splat = op.getVector().getDefiningOp<SplatOp>();
3753  if (!splat)
3754  return failure();
3755  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3756  return success();
3757  }
3758 };
3759 
3760 } // namespace
3761 
3762 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3763  RewritePatternSet &results, MLIRContext *context) {
3764  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3765  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3766  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3767  StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3768  StridedSliceSplat>(context);
3769 }
3770 
3771 //===----------------------------------------------------------------------===//
3772 // TransferReadOp
3773 //===----------------------------------------------------------------------===//
3774 
3775 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3776 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3777  VectorType vectorType, Value source,
3778  ValueRange indices, AffineMapAttr permutationMapAttr,
3779  /*optional*/ ArrayAttr inBoundsAttr) {
3780  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3781  Value padding = builder.create<arith::ConstantOp>(
3782  result.location, elemType, builder.getZeroAttr(elemType));
3783  build(builder, result, vectorType, source, indices, permutationMapAttr,
3784  padding, /*mask=*/Value(), inBoundsAttr);
3785 }
3786 
3787 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3788 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3789  VectorType vectorType, Value source,
3790  ValueRange indices, AffineMap permutationMap,
3791  std::optional<ArrayRef<bool>> inBounds) {
3792  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3793  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3794  ? builder.getBoolArrayAttr(inBounds.value())
3795  : builder.getBoolArrayAttr(
3796  SmallVector<bool>(vectorType.getRank(), false));
3797  build(builder, result, vectorType, source, indices, permutationMapAttr,
3798  inBoundsAttr);
3799 }
3800 
3801 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3802 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3803  VectorType vectorType, Value source,
3804  ValueRange indices, Value padding,
3805  std::optional<ArrayRef<bool>> inBounds) {
3806  AffineMap permutationMap = getTransferMinorIdentityMap(
3807  llvm::cast<ShapedType>(source.getType()), vectorType);
3808  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3809  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3810  ? builder.getBoolArrayAttr(inBounds.value())
3811  : builder.getBoolArrayAttr(
3812  SmallVector<bool>(vectorType.getRank(), false));
3813  build(builder, result, vectorType, source, indices, permutationMapAttr,
3814  padding,
3815  /*mask=*/Value(), inBoundsAttr);
3816 }
3817 
3818 /// 4. Builder that sets padding to zero and permutation map to
3819 /// 'getMinorIdentityMap'.
3820 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3821  VectorType vectorType, Value source,
3822  ValueRange indices,
3823  std::optional<ArrayRef<bool>> inBounds) {
3824  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3825  Value padding = builder.create<arith::ConstantOp>(
3826  result.location, elemType, builder.getZeroAttr(elemType));
3827  build(builder, result, vectorType, source, indices, padding, inBounds);
3828 }
3829 
3830 template <typename EmitFun>
3831 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3832  EmitFun emitOpError) {
3833  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3834  for (auto expr : permutationMap.getResults()) {
3835  auto dim = dyn_cast<AffineDimExpr>(expr);
3836  auto zero = dyn_cast<AffineConstantExpr>(expr);
3837  if (zero) {
3838  if (zero.getValue() != 0) {
3839  return emitOpError(
3840  "requires a projected permutation_map (at most one dim or the zero "
3841  "constant can appear in each result)");
3842  }
3843  continue;
3844  }
3845  if (!dim) {
3846  return emitOpError("requires a projected permutation_map (at most one "
3847  "dim or the zero constant can appear in each result)");
3848  }
3849  if (seen[dim.getPosition()]) {
3850  return emitOpError(
3851  "requires a permutation_map that is a permutation (found one dim "
3852  "used more than once)");
3853  }
3854  seen[dim.getPosition()] = true;
3855  }
3856  return success();
3857 }
3858 
3859 static LogicalResult
3860 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3861  VectorType vectorType, VectorType maskType,
3862  VectorType inferredMaskType, AffineMap permutationMap,
3863  ArrayAttr inBounds) {
3864  if (op->hasAttr("masked")) {
3865  return op->emitOpError("masked attribute has been removed. "
3866  "Use in_bounds instead.");
3867  }
3868 
3869  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3870  return op->emitOpError(
3871  "requires source to be a memref or ranked tensor type");
3872 
3873  auto elementType = shapedType.getElementType();
3874  DataLayout dataLayout = DataLayout::closest(op);
3875  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3876  // Memref or tensor has vector element type.
3877  unsigned sourceVecSize =
3878  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3879  vectorElementType.getShape().back();
3880  unsigned resultVecSize =
3881  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3882  vectorType.getShape().back();
3883  if (resultVecSize % sourceVecSize != 0)
3884  return op->emitOpError(
3885  "requires the bitwidth of the minor 1-D vector to be an integral "
3886  "multiple of the bitwidth of the minor 1-D vector of the source");
3887 
3888  unsigned sourceVecEltRank = vectorElementType.getRank();
3889  unsigned resultVecRank = vectorType.getRank();
3890  if (sourceVecEltRank > resultVecRank)
3891  return op->emitOpError(
3892  "requires source vector element and vector result ranks to match.");
3893  unsigned rankOffset = resultVecRank - sourceVecEltRank;
3894  // Check that permutation map results match 'rankOffset' of vector type.
3895  if (permutationMap.getNumResults() != rankOffset)
3896  return op->emitOpError("requires a permutation_map with result dims of "
3897  "the same rank as the vector type");
3898 
3899  if (maskType)
3900  return op->emitOpError("does not support masks with vector element type");
3901  } else {
3902  // Memref or tensor has scalar element type.
3903  unsigned minorSize =
3904  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3905  unsigned resultVecSize =
3906  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3907  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3908  return op->emitOpError(
3909  "requires the bitwidth of the minor 1-D vector to be an integral "
3910  "multiple of the bitwidth of the source element type");
3911 
3912  // Check that permutation map results match rank of vector type.
3913  if (permutationMap.getNumResults() != vectorType.getRank())
3914  return op->emitOpError("requires a permutation_map with result dims of "
3915  "the same rank as the vector type");
3916  }
3917 
3918  if (permutationMap.getNumSymbols() != 0)
3919  return op->emitOpError("requires permutation_map without symbols");
3920 
3921  if (permutationMap.getNumInputs() != shapedType.getRank())
3922  return op->emitOpError("requires a permutation_map with input dims of the "
3923  "same rank as the source type");
3924 
3925  if (maskType && maskType != inferredMaskType)
3926  return op->emitOpError("inferred mask type (")
3927  << inferredMaskType << ") and mask operand type (" << maskType
3928  << ") don't match";
3929 
3930  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3931  return op->emitOpError("expects the in_bounds attr of same rank "
3932  "as permutation_map results: ")
3933  << AffineMapAttr::get(permutationMap)
3934  << " vs inBounds of size: " << inBounds.size();
3935  for (unsigned int i = 0, e = permutationMap.getNumResults(); i < e; ++i)
3936  if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
3937  !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3938  return op->emitOpError("requires broadcast dimensions to be in-bounds");
3939 
3940  return success();
3941 }
3942 
3943 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3944  SmallVector<StringRef, 3> elidedAttrs;
3945  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3946  if (op.getPermutationMap().isMinorIdentity())
3947  elidedAttrs.push_back(op.getPermutationMapAttrName());
3948  // Elide in_bounds attribute if all dims are out-of-bounds.
3949  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
3950  elidedAttrs.push_back(op.getInBoundsAttrName());
3951  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3952 }
3953 
3954 void TransferReadOp::print(OpAsmPrinter &p) {
3955  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3956  if (getMask())
3957  p << ", " << getMask();
3958  printTransferAttrs(p, *this);
3959  p << " : " << getShapedType() << ", " << getVectorType();
3960 }
3961 
3962 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
3963  AffineMap permMap) {
3964  auto i1Type = IntegerType::get(permMap.getContext(), 1);
3965  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3966  assert(invPermMap && "Inversed permutation map couldn't be computed");
3967  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3968 
3969  SmallVector<bool> scalableDims =
3970  applyPermutationMap(invPermMap, vecType.getScalableDims());
3971 
3972  return VectorType::get(maskShape, i1Type, scalableDims);
3973 }
3974 
3975 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3976  auto &builder = parser.getBuilder();
3977  SMLoc typesLoc;
3978  OpAsmParser::UnresolvedOperand sourceInfo;
3979  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3980  OpAsmParser::UnresolvedOperand paddingInfo;
3981  SmallVector<Type, 2> types;
3982  OpAsmParser::UnresolvedOperand maskInfo;
3983  // Parsing with support for paddingValue.
3984  if (parser.parseOperand(sourceInfo) ||
3985  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
3986  parser.parseComma() || parser.parseOperand(paddingInfo))
3987  return failure();
3988  ParseResult hasMask = parser.parseOptionalComma();
3989  if (hasMask.succeeded()) {
3990  if (parser.parseOperand(maskInfo))
3991  return failure();
3992  }
3993  if (parser.parseOptionalAttrDict(result.attributes) ||
3994  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3995  return failure();
3996  if (types.size() != 2)
3997  return parser.emitError(typesLoc, "requires two types");
3998  auto indexType = builder.getIndexType();
3999  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4000  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4001  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4002  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4003  if (!vectorType)
4004  return parser.emitError(typesLoc, "requires vector type");
4005  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4006  Attribute permMapAttr = result.attributes.get(permMapAttrName);
4007  AffineMap permMap;
4008  if (!permMapAttr) {
4009  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4010  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4011  } else {
4012  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4013  }
4014  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4015  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4016  if (!inBoundsAttr) {
4017  result.addAttribute(inBoundsAttrName,
4018  builder.getBoolArrayAttr(
4019  SmallVector<bool>(permMap.getNumResults(), false)));
4020  }
4021  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4022  parser.resolveOperands(indexInfo, indexType, result.operands) ||
4023  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4024  result.operands))
4025  return failure();
4026  if (hasMask.succeeded()) {
4027  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4028  return parser.emitError(
4029  maskInfo.location, "does not support masks with vector element type");
4030  if (vectorType.getRank() != permMap.getNumResults()) {
4031  return parser.emitError(typesLoc,
4032  "expected the same rank for the vector and the "
4033  "results of the permutation map");
4034  }
4035  // Instead of adding the mask type as an op type, compute it based on the
4036  // vector type and the permutation map (to keep the type signature small).
4037  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4038  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4039  return failure();
4040  }
4041  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4042  builder.getDenseI32ArrayAttr(
4043  {1, static_cast<int32_t>(indexInfo.size()), 1,
4044  static_cast<int32_t>(hasMask.succeeded())}));
4045  return parser.addTypeToList(vectorType, result.types);
4046 }
4047 
4048 LogicalResult TransferReadOp::verify() {
4049  // Consistency of elemental types in source and vector.
4050  ShapedType shapedType = getShapedType();
4051  VectorType vectorType = getVectorType();
4052  VectorType maskType = getMaskType();
4053  auto paddingType = getPadding().getType();
4054  auto permutationMap = getPermutationMap();
4055  VectorType inferredMaskType =
4056  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4057  : VectorType();
4058  auto sourceElementType = shapedType.getElementType();
4059 
4060  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4061  return emitOpError("requires ") << shapedType.getRank() << " indices";
4062 
4063  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4064  shapedType, vectorType, maskType,
4065  inferredMaskType, permutationMap, getInBounds())))
4066  return failure();
4067 
4068  if (auto sourceVectorElementType =
4069  llvm::dyn_cast<VectorType>(sourceElementType)) {
4070  // Source has vector element type.
4071  // Check that 'sourceVectorElementType' and 'paddingType' types match.
4072  if (sourceVectorElementType != paddingType)
4073  return emitOpError(
4074  "requires source element type and padding type to match.");
4075 
4076  } else {
4077  // Check that 'paddingType' is valid to store in a vector type.
4078  if (!VectorType::isValidElementType(paddingType))
4079  return emitOpError("requires valid padding vector elemental type");
4080 
4081  // Check that padding type and vector element types match.
4082  if (paddingType != sourceElementType)
4083  return emitOpError(
4084  "requires formal padding and source of the same elemental type");
4085  }
4086 
4087  return verifyPermutationMap(permutationMap,
4088  [&](Twine t) { return emitOpError(t); });
4089 }
4090 
4091 // MaskableOpInterface methods.
4092 
4093 /// Returns the mask type expected by this operation. Mostly used for
4094 /// verification purposes. It requires the operation to be vectorized.
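/// E.g., a transfer_read producing vector<4x8xf32> through a minor identity
/// permutation map expects a mask of type vector<4x8xi1>.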
4095 Type TransferReadOp::getExpectedMaskType() {
4096  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4097 }
4098 
4099 template <typename TransferOp>
4100 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4101  // TODO: support more aggressive createOrFold on:
4102  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
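  // E.g., reading a vector dimension of size 4 starting at constant index 2
  // from a static source dimension of size 8 is in-bounds because 2 + 4 <= 8.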
4103  if (op.getShapedType().isDynamicDim(indicesIdx))
4104  return false;
4105  Value index = op.getIndices()[indicesIdx];
4106  std::optional<int64_t> cstOp = getConstantIntValue(index);
4107  if (!cstOp.has_value())
4108  return false;
4109 
4110  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4111  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4112 
4113  return cstOp.value() + vectorSize <= sourceSize;
4114 }
4115 
4116 template <typename TransferOp>
4117 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4118  // TODO: support 0-d corner case.
4119  // TODO: Be less conservative.
4120  if (op.getTransferRank() == 0)
4121  return failure();
4122  AffineMap permutationMap = op.getPermutationMap();
4123  bool changed = false;
4124  SmallVector<bool, 4> newInBounds;
4125  newInBounds.reserve(op.getTransferRank());
4126  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4127  // Already marked as in-bounds, nothing to see here.
4128  if (op.isDimInBounds(i)) {
4129  newInBounds.push_back(true);
4130  continue;
4131  }
4132  // Currently out-of-bounds, check whether we can statically determine it is
4133  // inBounds.
4134  auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4135  assert(dimExpr && "Broadcast dims must be in-bounds");
4136  auto inBounds =
4137  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
4138  newInBounds.push_back(inBounds);
4139  // We commit the pattern if it is "more inbounds".
4140  changed |= inBounds;
4141  }
4142  if (!changed)
4143  return failure();
4144  // OpBuilder is only used as a helper to build an I64ArrayAttr.
4145  OpBuilder b(op.getContext());
4146  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4147  return success();
4148 }
4149 
4150 template <typename TransferOp>
4151 static LogicalResult foldTransferFullMask(TransferOp op) {
4152  auto mask = op.getMask();
4153  if (!mask)
4154  return failure();
4155 
4156  if (getMaskFormat(mask) != MaskFormat::AllTrue)
4157  return failure();
4158 
4159  op.getMaskMutable().clear();
4160  return success();
4161 }
4162 
4163 /// ```
4164 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4165 /// : vector<1x4xf32>, tensor<4x4xf32>
4166 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4167 /// : tensor<4x4xf32>, vector<1x4xf32>
4168 /// ```
4169 /// -> Folds into
4170 /// ```
4171 /// %v0
4172 /// ```
4173 static Value foldRAW(TransferReadOp readOp) {
4174  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4175  return {};
4176  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4177  while (defWrite) {
4178  if (checkSameValueRAW(defWrite, readOp))
4179  return defWrite.getVector();
4180  if (!isDisjointTransferIndices(
4181  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4182  cast<VectorTransferOpInterface>(readOp.getOperation())))
4183  break;
4184  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4185  }
4186  return {};
4187 }
4188 
4189 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4190  if (Value vec = foldRAW(*this))
4191  return vec;
4192  /// transfer_read(memrefcast) -> transfer_read
4193  if (succeeded(foldTransferInBoundsAttribute(*this)))
4194  return getResult();
4195  if (succeeded(foldTransferFullMask(*this)))
4196  return getResult();
4197  if (succeeded(memref::foldMemRefCast(*this)))
4198  return getResult();
4199  if (succeeded(tensor::foldTensorCast(*this)))
4200  return getResult();
4201  return OpFoldResult();
4202 }
4203 
4204 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4205  return llvm::to_vector<4>(getVectorType().getShape());
4206 }
4207 
4208 void TransferReadOp::getEffects(
4209  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4210  &effects) {
4211  if (llvm::isa<MemRefType>(getShapedType()))
4212  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4213  SideEffects::DefaultResource::get());
4214 }
4215 
4216 namespace {
4217 /// Store to load forwarding for transfer operations with permutation maps.
4218 /// Even if the permutation maps are different we can still propagate the store
4219 /// into the load if the size of the dimensions read and written match. Then we
4220 /// can replace the transfer_read + transfer_write by vector.broadcast and
4221 /// vector.transpose.
4222 /// Example:
4223 /// ```
4224 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4225 /// {in_bounds = [true, true],
4226 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4227 /// vector<4x1xf32>, tensor<4x4x4xf32>
4228 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4229 /// {in_bounds = [true, true, true, true],
4230 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4231 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4232 /// ```
4233 /// To:
4234 /// ```
 4235 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
4236 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4237 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4238 /// ```
4239 struct TransferReadAfterWriteToBroadcast
4240  : public OpRewritePattern<TransferReadOp> {
 4241  using OpRewritePattern::OpRewritePattern;
4242 
4243  LogicalResult matchAndRewrite(TransferReadOp readOp,
4244  PatternRewriter &rewriter) const override {
4245  if (readOp.hasOutOfBoundsDim() ||
4246  !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4247  return failure();
4248  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4249  if (!defWrite)
4250  return failure();
4251  // TODO: If the written transfer chunk is a superset of the read transfer
4252  // chunk we could do an extract_strided_slice.
4253  if (readOp.getTransferChunkAccessed() !=
4254  defWrite.getTransferChunkAccessed())
4255  return failure();
4256  // TODO: Support cases where a dim is explicitly written but implicitly
4257  // read (i.e., a unit dim that is rank reduced).
4258  if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4259  getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4260  return failure();
4261  if (readOp.getIndices() != defWrite.getIndices() ||
4262  readOp.getMask() != defWrite.getMask())
4263  return failure();
4264  Value vec = defWrite.getVector();
4265  // TODO: loop through the chain of transfer_write if we can prove that they
 4266  // don't overlap with the transfer_read. This requires improving the
 4267  // `isDisjointTransferIndices` helper.
4268  AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4269  AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4270  AffineMap map = readMap.compose(writeMap);
4271  if (map.getNumResults() == 0)
4272  return failure();
4273  // Calculate the permutation to apply to go from the vector stored to the
4274  // vector read.
4275  SmallVector<unsigned> permutation;
4276  if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4277  return failure();
4278 
4279  Location loc = readOp.getLoc();
4280  // Calculate the broadcast shape by applying the reverse permutation to the
4281  // final shape we want.
4282  ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4283  SmallVector<int64_t> broadcastShape(destShape.size());
4284  SmallVector<bool> broadcastScalableFlags(destShape.size());
4285  for (const auto &pos : llvm::enumerate(permutation)) {
4286  broadcastShape[pos.value()] = destShape[pos.index()];
4287  broadcastScalableFlags[pos.value()] =
4288  readOp.getVectorType().getScalableDims()[pos.index()];
4289  }
4290  VectorType broadcastedType = VectorType::get(
4291  broadcastShape, defWrite.getVectorType().getElementType(),
4292  broadcastScalableFlags);
4293  vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4294  SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4295  rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4296  transposePerm);
4297  return success();
4298  }
4299 };
4300 } // namespace
4301 
4302 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4303  MLIRContext *context) {
4304  results.add<TransferReadAfterWriteToBroadcast>(context);
4305 }
4306 
4307 //===----------------------------------------------------------------------===//
4308 // TransferWriteOp
4309 //===----------------------------------------------------------------------===//
4310 
4311 /// 1. Builder with type inference.
4312 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4313  Value vector, Value dest, ValueRange indices,
4314  AffineMapAttr permutationMapAttr,
4315  /*optional*/ Value mask,
4316  /*optional*/ ArrayAttr inBoundsAttr) {
4317  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4318  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4319  mask, inBoundsAttr);
4320 }
4321 
4322 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4323 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4324  Value vector, Value dest, ValueRange indices,
4325  AffineMapAttr permutationMapAttr,
4326  /*optional*/ ArrayAttr inBoundsAttr) {
4327  build(builder, result, vector, dest, indices, permutationMapAttr,
4328  /*mask=*/Value(), inBoundsAttr);
4329 }
4330 
4331 /// 3. Builder with type inference that sets an empty mask (variant without
4332 /// attrs)
4333 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4334  Value vector, Value dest, ValueRange indices,
4335  AffineMap permutationMap,
4336  std::optional<ArrayRef<bool>> inBounds) {
4337  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4338  auto inBoundsAttr =
4339  (inBounds && !inBounds.value().empty())
4340  ? builder.getBoolArrayAttr(inBounds.value())
 4341  : builder.getBoolArrayAttr(SmallVector<bool>(
4342  llvm::cast<VectorType>(vector.getType()).getRank(), false));
4343  build(builder, result, vector, dest, indices, permutationMapAttr,
4344  /*mask=*/Value(), inBoundsAttr);
4345 }
4346 
4347 /// 4. Builder with type inference that sets an empty mask and sets permutation
4348 /// map to 'getMinorIdentityMap'.
4349 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4350  Value vector, Value dest, ValueRange indices,
4351  std::optional<ArrayRef<bool>> inBounds) {
4352  auto vectorType = llvm::cast<VectorType>(vector.getType());
4353  AffineMap permutationMap = getTransferMinorIdentityMap(
4354  llvm::cast<ShapedType>(dest.getType()), vectorType);
4355  build(builder, result, vector, dest, indices, permutationMap, inBounds);
4356 }
4357 
4358 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4359  OperationState &result) {
4360  auto &builder = parser.getBuilder();
4361  SMLoc typesLoc;
4362  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
 4363  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4364  SmallVector<Type, 2> types;
 4365  OpAsmParser::UnresolvedOperand maskInfo;
4366  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4367  parser.parseOperand(sourceInfo) ||
4368  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4369  return failure();
4370  ParseResult hasMask = parser.parseOptionalComma();
4371  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4372  return failure();
4373  if (parser.parseOptionalAttrDict(result.attributes) ||
4374  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4375  return failure();
4376  if (types.size() != 2)
4377  return parser.emitError(typesLoc, "requires two types");
4378  auto indexType = builder.getIndexType();
4379  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4380  if (!vectorType)
4381  return parser.emitError(typesLoc, "requires vector type");
4382  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4383  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4384  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4385  auto permMapAttrName =
4386  TransferWriteOp::getPermutationMapAttrName(result.name);
4387  auto permMapAttr = result.attributes.get(permMapAttrName);
4388  AffineMap permMap;
4389  if (!permMapAttr) {
4390  permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4391  result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4392  } else {
4393  permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4394  }
4395  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4396  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4397  if (!inBoundsAttr) {
4398  result.addAttribute(inBoundsAttrName,
4399  builder.getBoolArrayAttr(
4400  SmallVector<bool>(permMap.getNumResults(), false)));
4401  }
4402  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4403  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4404  parser.resolveOperands(indexInfo, indexType, result.operands))
4405  return failure();
4406  if (hasMask.succeeded()) {
4407  if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4408  return parser.emitError(
4409  maskInfo.location, "does not support masks with vector element type");
4410  if (vectorType.getRank() != permMap.getNumResults()) {
4411  return parser.emitError(typesLoc,
4412  "expected the same rank for the vector and the "
4413  "results of the permutation map");
4414  }
4415  auto maskType = inferTransferOpMaskType(vectorType, permMap);
4416  if (parser.resolveOperand(maskInfo, maskType, result.operands))
4417  return failure();
4418  }
4419  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4420  builder.getDenseI32ArrayAttr(
4421  {1, 1, static_cast<int32_t>(indexInfo.size()),
4422  static_cast<int32_t>(hasMask.succeeded())}));
4423  return failure(llvm::isa<RankedTensorType>(shapedType) &&
4424  parser.addTypeToList(shapedType, result.types));
4425 }
4426 
 4427 void TransferWriteOp::print(OpAsmPrinter &p) {
4428  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4429  if (getMask())
4430  p << ", " << getMask();
4431  printTransferAttrs(p, *this);
4432  p << " : " << getVectorType() << ", " << getShapedType();
4433 }
4434 
4435 LogicalResult TransferWriteOp::verify() {
4436  // Consistency of elemental types in shape and vector.
4437  ShapedType shapedType = getShapedType();
4438  VectorType vectorType = getVectorType();
4439  VectorType maskType = getMaskType();
4440  auto permutationMap = getPermutationMap();
4441  VectorType inferredMaskType =
4442  maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4443  : VectorType();
4444 
4445  if (llvm::size(getIndices()) != shapedType.getRank())
4446  return emitOpError("requires ") << shapedType.getRank() << " indices";
4447 
4448  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4449  // as the semantics is unclear. This can be revisited later if necessary.
4450  if (hasBroadcastDim())
4451  return emitOpError("should not have broadcast dimensions");
4452 
4453  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4454  shapedType, vectorType, maskType,
4455  inferredMaskType, permutationMap, getInBounds())))
4456  return failure();
4457 
4458  return verifyPermutationMap(permutationMap,
4459  [&](Twine t) { return emitOpError(t); });
4460 }
4461 
4462 // MaskableOpInterface methods.
4463 
4464 /// Returns the mask type expected by this operation. Mostly used for
4465 /// verification purposes.
4466 Type TransferWriteOp::getExpectedMaskType() {
4467  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4468 }
4469 
4470 /// Fold:
4471 /// ```
4472 /// %t1 = ...
4473 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4474 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4475 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4476 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4477 /// ```
4478 ///
4479 /// into:
4480 ///
4481 /// ```
4482 /// %t0
4483 /// ```
4484 ///
4485 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4486 /// block argument or has side effects.
4487 static LogicalResult foldReadInitWrite(TransferWriteOp write,
 4488  ArrayRef<Attribute>,
4489  SmallVectorImpl<OpFoldResult> &results) {
4490  // TODO: support 0-d corner case.
4491  if (write.getTransferRank() == 0)
4492  return failure();
4493  auto rankedTensorType =
4494  llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4495  // If not operating on tensors, bail.
4496  if (!rankedTensorType)
4497  return failure();
4498  // If no read, bail.
4499  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4500  if (!read)
4501  return failure();
4502  // TODO: support 0-d corner case.
4503  if (read.getTransferRank() == 0)
4504  return failure();
 4505  // For now, only accept minor identity maps. Future: also accept cases where
 4505  // the composition of the two permutation maps is a minor identity.
4506  if (!read.getPermutationMap().isMinorIdentity() ||
4507  !write.getPermutationMap().isMinorIdentity())
4508  return failure();
4509  // Bail on mismatching ranks.
4510  if (read.getTransferRank() != write.getTransferRank())
4511  return failure();
4512  // Bail on potential out-of-bounds accesses.
4513  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4514  return failure();
4515  // Tensor types must be the same.
4516  if (read.getSource().getType() != rankedTensorType)
4517  return failure();
4518  // Vector types must be the same.
4519  if (read.getVectorType() != write.getVectorType())
4520  return failure();
4521  // Vector and Tensor shapes must match.
4522  if (read.getVectorType().getShape() != rankedTensorType.getShape())
4523  return failure();
 4524  // Bail if any index is nonzero.
4525  auto isNotConstantZero = [](Value v) {
4526  auto cstOp = getConstantIntValue(v);
4527  return !cstOp.has_value() || cstOp.value() != 0;
4528  };
4529  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4530  llvm::any_of(write.getIndices(), isNotConstantZero))
4531  return failure();
4532  // Success.
4533  results.push_back(read.getSource());
4534  return success();
4535 }
4536 
4537 static bool checkSameValueWAR(vector::TransferReadOp read,
4538  vector::TransferWriteOp write) {
4539  return read.getSource() == write.getSource() &&
4540  read.getIndices() == write.getIndices() &&
4541  read.getPermutationMap() == write.getPermutationMap() &&
4542  read.getVectorType() == write.getVectorType() && !read.getMask() &&
4543  !write.getMask();
4544 }
4545 /// Fold transfer_write write after read:
4546 /// ```
4547 /// %t0 = ...
4548 /// %v = vector.transfer_read %t0[%c0...] :
4549 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4550 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4551 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4552 /// ```
4553 ///
4554 /// into:
4555 ///
4556 /// ```
4557 /// %t0
4558 /// ```
4559 static LogicalResult foldWAR(TransferWriteOp write,
4560  SmallVectorImpl<OpFoldResult> &results) {
4561  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4562  return failure();
4563  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4564  if (!read)
4565  return failure();
4566 
4567  if (!checkSameValueWAR(read, write))
4568  return failure();
4569  results.push_back(read.getSource());
4570  return success();
4571 }
4572 
4573 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4574  SmallVectorImpl<OpFoldResult> &results) {
4575  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4576  return success();
4577  if (succeeded(foldWAR(*this, results)))
4578  return success();
4579  if (succeeded(foldTransferInBoundsAttribute(*this)))
4580  return success();
4581  if (succeeded(foldTransferFullMask(*this)))
4582  return success();
4583  return memref::foldMemRefCast(*this);
4584 }
4585 
4586 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4587  return llvm::to_vector<4>(getVectorType().getShape());
4588 }
4589 
4590 void TransferWriteOp::getEffects(
 4591  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4592  &effects) {
4593  if (llvm::isa<MemRefType>(getShapedType()))
4594  effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
 4595  SideEffects::DefaultResource::get());
4596 }
4597 
4598 namespace {
 4599 /// Remove dead transfer write from the SSA chain so that it can be eliminated
 4600 /// by DCE:
4601 /// ```
4602 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4603 /// : vector<1x4xf32>, tensor<4x4xf32>
4604 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4605 /// : vector<1x4xf32>, tensor<4x4xf32>
4606 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4607 /// : vector<1x4xf32>, tensor<4x4xf32>
4608 /// ```
4609 ///
4610 /// into:
4611 ///
4612 /// ```
4613 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4614 /// : vector<1x4xf32>, tensor<4x4xf32>
4615 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4616 /// : vector<1x4xf32>, tensor<4x4xf32>
4617 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4618 /// : vector<1x4xf32>, tensor<4x4xf32>
4619 /// ```
4620 ///
4621 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4622 /// any other uses.
4623 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4624 public:
 4625  using OpRewritePattern::OpRewritePattern;
4626  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4627  PatternRewriter &rewriter) const override {
4628  if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4629  return failure();
4630  vector::TransferWriteOp writeToModify = writeOp;
4631 
4632  auto defWrite =
4633  writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4634  while (defWrite) {
4635  if (checkSameValueWAW(writeOp, defWrite)) {
4636  rewriter.modifyOpInPlace(writeToModify, [&]() {
4637  writeToModify.getSourceMutable().assign(defWrite.getSource());
4638  });
4639  return success();
4640  }
 4641  if (!isDisjointTransferIndices(
4642  cast<VectorTransferOpInterface>(defWrite.getOperation()),
4643  cast<VectorTransferOpInterface>(writeOp.getOperation())))
4644  break;
 4645  // If the previous write op doesn't have any other use we can safely look
4646  // at the previous store to see if it can be removed.
4647  if (!defWrite->hasOneUse())
4648  break;
4649  writeToModify = defWrite;
4650  defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4651  }
4652  return failure();
4653  }
4654 };
4655 
4656 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4657 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4658 /// overwritten and inserted into another tensor. After this rewrite, the
4659 /// operations bufferize in-place since all of them work on the same slice.
4660 ///
4661 /// For example:
4662 /// ```mlir
4663 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4664 /// : vector<8x16xf32>, tensor<8x16xf32>
4665 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4666 /// : tensor<8x16xf32> to tensor<?x?xf32>
4667 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4668 /// : tensor<?x?xf32> into tensor<27x37xf32>
4669 /// ```
4670 /// folds to
4671 /// ```mlir
4672 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4673 /// : tensor<27x37xf32> to tensor<?x?xf32>
4674 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4675 /// : vector<8x16xf32>, tensor<?x?xf32>
4676 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4677 /// : tensor<?x?xf32> into tensor<27x37xf32>
4678 /// ```
4679 struct SwapExtractSliceOfTransferWrite
4680  : public OpRewritePattern<tensor::InsertSliceOp> {
4681 public:
 4682  using OpRewritePattern::OpRewritePattern;
4683 
4684  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4685  PatternRewriter &rewriter) const override {
4686  if (!insertOp.hasUnitStride())
4687  return failure();
4688  auto extractOp =
4689  insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4690  if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4691  return failure();
4692  auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4693  if (!transferOp || !transferOp->hasOneUse())
4694  return failure();
4695 
4696  // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4697  // rank-reducing.
4698  if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4699  return rewriter.notifyMatchFailure(insertOp,
4700  "use-def chain is rank-reducing");
4701  }
4702 
4703  // Fail if tensor::ExtractSliceOp has non-zero offset.
4704  if (!extractOp.hasZeroOffset()) {
4705  return rewriter.notifyMatchFailure(insertOp,
4706  "ExtractSliceOp has non-zero offset");
4707  }
4708 
 4709  // Fail if vector::TransferWriteOp has non-zero offset.
4710  if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4711  return getConstantIntValue(value) == static_cast<int64_t>(0);
4712  })) {
4713  return rewriter.notifyMatchFailure(insertOp,
4714  "TranferWriteOp has non-zero offset");
4715  }
4716 
4717  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4718  if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4719  return rewriter.notifyMatchFailure(
4720  insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4721  }
4722 
4723  for (auto [insertSize, extractSize] :
4724  llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4725  if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4726  return rewriter.notifyMatchFailure(
4727  insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4728  }
4729  }
4730 
4731  // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4732  assert(transferOp.getVectorType().hasStaticShape() &&
4733  "expected vector to have a static shape");
4734  ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
 4735  SmallVector<int64_t> resultShape = applyPermutationMap(
4736  transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4737  if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4738  return rewriter.notifyMatchFailure(
4739  insertOp, "TransferWriteOp may not write the full tensor.");
4740  }
4741 
4742  // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4743  // Set all in_bounds to false and let the folder infer them.
4744  SmallVector<bool> newInBounds(vectorShape.size(), false);
4745  auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4746  extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4747  insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4748  insertOp.getMixedStrides());
4749  auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4750  transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4751  transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4752  rewriter.getBoolArrayAttr(newInBounds));
4753  rewriter.modifyOpInPlace(insertOp, [&]() {
4754  insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4755  });
4756  return success();
4757  }
4758 };
4759 
4760 } // namespace
4761 
4762 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4763  MLIRContext *context) {
4764  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4765 }
4766 
4767 //===----------------------------------------------------------------------===//
4768 // LoadOp
4769 //===----------------------------------------------------------------------===//
4770 
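/// The load/store verifiers below require the most-minor memref dimension to
/// have unit stride. As an illustrative (unverified) example,
/// memref<16x8xf32, strided<[8, 1]>> is accepted, while a layout such as
/// strided<[8, 2]> on the last dimension is rejected.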
4771 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4772  MemRefType memRefTy) {
4773  if (!isLastMemrefDimUnitStride(memRefTy))
4774  return op->emitOpError("most minor memref dim must have unit stride");
4775  return success();
4776 }
4777 
4778 LogicalResult vector::LoadOp::verify() {
4779  VectorType resVecTy = getVectorType();
4780  MemRefType memRefTy = getMemRefType();
4781 
4782  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4783  return failure();
4784 
4785  // Checks for vector memrefs.
4786  Type memElemTy = memRefTy.getElementType();
4787  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4788  if (memVecTy != resVecTy)
4789  return emitOpError("base memref and result vector types should match");
4790  memElemTy = memVecTy.getElementType();
4791  }
4792 
4793  if (resVecTy.getElementType() != memElemTy)
4794  return emitOpError("base and result element types should match");
4795  if (llvm::size(getIndices()) != memRefTy.getRank())
4796  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4797  return success();
4798 }
4799 
4800 OpFoldResult LoadOp::fold(FoldAdaptor) {
4801  if (succeeded(memref::foldMemRefCast(*this)))
4802  return getResult();
4803  return OpFoldResult();
4804 }
4805 
4806 //===----------------------------------------------------------------------===//
4807 // StoreOp
4808 //===----------------------------------------------------------------------===//
4809 
4810 LogicalResult vector::StoreOp::verify() {
4811  VectorType valueVecTy = getVectorType();
4812  MemRefType memRefTy = getMemRefType();
4813 
4814  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4815  return failure();
4816 
4817  // Checks for vector memrefs.
4818  Type memElemTy = memRefTy.getElementType();
4819  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4820  if (memVecTy != valueVecTy)
4821  return emitOpError(
4822  "base memref and valueToStore vector types should match");
4823  memElemTy = memVecTy.getElementType();
4824  }
4825 
4826  if (valueVecTy.getElementType() != memElemTy)
4827  return emitOpError("base and valueToStore element type should match");
4828  if (llvm::size(getIndices()) != memRefTy.getRank())
4829  return emitOpError("requires ") << memRefTy.getRank() << " indices";
4830  return success();
4831 }
4832 
4833 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4834  SmallVectorImpl<OpFoldResult> &results) {
4835  return memref::foldMemRefCast(*this);
4836 }
4837 
4838 //===----------------------------------------------------------------------===//
4839 // MaskedLoadOp
4840 //===----------------------------------------------------------------------===//
4841 
4842 LogicalResult MaskedLoadOp::verify() {
4843  VectorType maskVType = getMaskVectorType();
4844  VectorType passVType = getPassThruVectorType();
4845  VectorType resVType = getVectorType();
4846  MemRefType memType = getMemRefType();
4847 
4848  if (resVType.getElementType() != memType.getElementType())
4849  return emitOpError("base and result element type should match");
4850  if (llvm::size(getIndices()) != memType.getRank())
4851  return emitOpError("requires ") << memType.getRank() << " indices";
4852  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4853  return emitOpError("expected result dim to match mask dim");
4854  if (resVType != passVType)
4855  return emitOpError("expected pass_thru of same type as result type");
4856  return success();
4857 }
4858 
4859 namespace {
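/// Canonicalize vector.maskedload with a statically known mask: an all-true
/// mask is rewritten into a plain vector.load of the same base and indices,
/// while an all-false mask folds away entirely and the result is replaced by
/// the pass_thru vector (a sketch of the pattern implemented below).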
4860 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4861 public:
 4862  using OpRewritePattern::OpRewritePattern;
4863  LogicalResult matchAndRewrite(MaskedLoadOp load,
4864  PatternRewriter &rewriter) const override {
4865  switch (getMaskFormat(load.getMask())) {
4866  case MaskFormat::AllTrue:
4867  rewriter.replaceOpWithNewOp<vector::LoadOp>(
4868  load, load.getType(), load.getBase(), load.getIndices());
4869  return success();
4870  case MaskFormat::AllFalse:
4871  rewriter.replaceOp(load, load.getPassThru());
4872  return success();
4873  case MaskFormat::Unknown:
4874  return failure();
4875  }
4876  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4877  }
4878 };
4879 } // namespace
4880 
4881 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4882  MLIRContext *context) {
4883  results.add<MaskedLoadFolder>(context);
4884 }
4885 
4886 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
4887  if (succeeded(memref::foldMemRefCast(*this)))
4888  return getResult();
4889  return OpFoldResult();
4890 }
4891 
4892 //===----------------------------------------------------------------------===//
4893 // MaskedStoreOp
4894 //===----------------------------------------------------------------------===//
4895 
4896 LogicalResult MaskedStoreOp::verify() {
4897  VectorType maskVType = getMaskVectorType();
4898  VectorType valueVType = getVectorType();
4899  MemRefType memType = getMemRefType();
4900 
4901  if (valueVType.getElementType() != memType.getElementType())
4902  return emitOpError("base and valueToStore element type should match");
4903  if (llvm::size(getIndices()) != memType.getRank())
4904  return emitOpError("requires ") << memType.getRank() << " indices";
4905  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4906  return emitOpError("expected valueToStore dim to match mask dim");
4907  return success();
4908 }
4909 
4910 namespace {
4911 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4912 public:
 4913  using OpRewritePattern::OpRewritePattern;
4914  LogicalResult matchAndRewrite(MaskedStoreOp store,
4915  PatternRewriter &rewriter) const override {
4916  switch (getMaskFormat(store.getMask())) {
4917  case MaskFormat::AllTrue:
4918  rewriter.replaceOpWithNewOp<vector::StoreOp>(
4919  store, store.getValueToStore(), store.getBase(), store.getIndices());
4920  return success();
4921  case MaskFormat::AllFalse:
4922  rewriter.eraseOp(store);
4923  return success();
4924  case MaskFormat::Unknown:
4925  return failure();
4926  }
4927  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4928  }
4929 };
4930 } // namespace
4931 
4932 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4933  MLIRContext *context) {
4934  results.add<MaskedStoreFolder>(context);
4935 }
4936 
4937 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4938  SmallVectorImpl<OpFoldResult> &results) {
4939  return memref::foldMemRefCast(*this);
4940 }
4941 
4942 //===----------------------------------------------------------------------===//
4943 // GatherOp
4944 //===----------------------------------------------------------------------===//
4945 
4946 LogicalResult GatherOp::verify() {
4947  VectorType indVType = getIndexVectorType();
4948  VectorType maskVType = getMaskVectorType();
4949  VectorType resVType = getVectorType();
4950  ShapedType baseType = getBaseType();
4951 
4952  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4953  return emitOpError("requires base to be a memref or ranked tensor type");
4954 
4955  if (resVType.getElementType() != baseType.getElementType())
4956  return emitOpError("base and result element type should match");
4957  if (llvm::size(getIndices()) != baseType.getRank())
4958  return emitOpError("requires ") << baseType.getRank() << " indices";
4959  if (resVType.getShape() != indVType.getShape())
4960  return emitOpError("expected result dim to match indices dim");
4961  if (resVType.getShape() != maskVType.getShape())
4962  return emitOpError("expected result dim to match mask dim");
4963  if (resVType != getPassThruVectorType())
4964  return emitOpError("expected pass_thru of same type as result type");
4965  return success();
4966 }
4967 
4968 // MaskableOpInterface methods.
4969 
4970 /// Returns the mask type expected by this operation. Mostly used for
 4971 /// verification purposes. It requires the operation to be vectorized.
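/// E.g. (illustrative) with an index vector of type vector<4x[2]xindex>, the
/// expected mask type is vector<4x[2]xi1>.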
4972 Type GatherOp::getExpectedMaskType() {
4973  auto vecType = this->getIndexVectorType();
4974  return VectorType::get(vecType.getShape(),
4975  IntegerType::get(vecType.getContext(), /*width=*/1),
4976  vecType.getScalableDims());
4977 }
4978 
4979 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4980  return llvm::to_vector<4>(getVectorType().getShape());
4981 }
4982 
4983 namespace {
4984 class GatherFolder final : public OpRewritePattern<GatherOp> {
4985 public:
 4986  using OpRewritePattern::OpRewritePattern;
4987  LogicalResult matchAndRewrite(GatherOp gather,
4988  PatternRewriter &rewriter) const override {
4989  switch (getMaskFormat(gather.getMask())) {
4990  case MaskFormat::AllTrue:
4991  return failure(); // no unmasked equivalent
4992  case MaskFormat::AllFalse:
4993  rewriter.replaceOp(gather, gather.getPassThru());
4994  return success();
4995  case MaskFormat::Unknown:
4996  return failure();
4997  }
4998  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4999  }
5000 };
5001 } // namespace
5002 
5003 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5004  MLIRContext *context) {
5005  results.add<GatherFolder>(context);
5006 }
5007 
5008 //===----------------------------------------------------------------------===//
5009 // ScatterOp
5010 //===----------------------------------------------------------------------===//
5011 
5012 LogicalResult ScatterOp::verify() {
5013  VectorType indVType = getIndexVectorType();
5014  VectorType maskVType = getMaskVectorType();
5015  VectorType valueVType = getVectorType();
5016  MemRefType memType = getMemRefType();
5017 
5018  if (valueVType.getElementType() != memType.getElementType())
5019  return emitOpError("base and valueToStore element type should match");
5020  if (llvm::size(getIndices()) != memType.getRank())
5021  return emitOpError("requires ") << memType.getRank() << " indices";
5022  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5023  return emitOpError("expected valueToStore dim to match indices dim");
5024  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5025  return emitOpError("expected valueToStore dim to match mask dim");
5026  return success();
5027 }
5028 
5029 namespace {
5030 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5031 public:
 5032  using OpRewritePattern::OpRewritePattern;
5033  LogicalResult matchAndRewrite(ScatterOp scatter,
5034  PatternRewriter &rewriter) const override {
5035  switch (getMaskFormat(scatter.getMask())) {
5036  case MaskFormat::AllTrue:
5037  return failure(); // no unmasked equivalent
5038  case MaskFormat::AllFalse:
5039  rewriter.eraseOp(scatter);
5040  return success();
5041  case MaskFormat::Unknown:
5042  return failure();
5043  }
5044  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5045  }
5046 };
5047 } // namespace
5048 
5049 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5050  MLIRContext *context) {
5051  results.add<ScatterFolder>(context);
5052 }
5053 
5054 //===----------------------------------------------------------------------===//
5055 // ExpandLoadOp
5056 //===----------------------------------------------------------------------===//
5057 
5058 LogicalResult ExpandLoadOp::verify() {
5059  VectorType maskVType = getMaskVectorType();
5060  VectorType passVType = getPassThruVectorType();
5061  VectorType resVType = getVectorType();
5062  MemRefType memType = getMemRefType();
5063 
5064  if (resVType.getElementType() != memType.getElementType())
5065  return emitOpError("base and result element type should match");
5066  if (llvm::size(getIndices()) != memType.getRank())
5067  return emitOpError("requires ") << memType.getRank() << " indices";
5068  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5069  return emitOpError("expected result dim to match mask dim");
5070  if (resVType != passVType)
5071  return emitOpError("expected pass_thru of same type as result type");
5072  return success();
5073 }
5074 
5075 namespace {
5076 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5077 public:
 5078  using OpRewritePattern::OpRewritePattern;
5079  LogicalResult matchAndRewrite(ExpandLoadOp expand,
5080  PatternRewriter &rewriter) const override {
5081  switch (getMaskFormat(expand.getMask())) {
5082  case MaskFormat::AllTrue:
5083  rewriter.replaceOpWithNewOp<vector::LoadOp>(
5084  expand, expand.getType(), expand.getBase(), expand.getIndices());
5085  return success();
5086  case MaskFormat::AllFalse:
5087  rewriter.replaceOp(expand, expand.getPassThru());
5088  return success();
5089  case MaskFormat::Unknown:
5090  return failure();
5091  }
5092  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5093  }
5094 };
5095 } // namespace
5096 
5097 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5098  MLIRContext *context) {
5099  results.add<ExpandLoadFolder>(context);
5100 }
5101 
5102 //===----------------------------------------------------------------------===//
5103 // CompressStoreOp
5104 //===----------------------------------------------------------------------===//
5105 
5106 LogicalResult CompressStoreOp::verify() {
5107  VectorType maskVType = getMaskVectorType();
5108  VectorType valueVType = getVectorType();
5109  MemRefType memType = getMemRefType();
5110 
5111  if (valueVType.getElementType() != memType.getElementType())
5112  return emitOpError("base and valueToStore element type should match");
5113  if (llvm::size(getIndices()) != memType.getRank())
5114  return emitOpError("requires ") << memType.getRank() << " indices";
5115  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5116  return emitOpError("expected valueToStore dim to match mask dim");
5117  return success();
5118 }
5119 
5120 namespace {
5121 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5122 public:
 5123  using OpRewritePattern::OpRewritePattern;
5124  LogicalResult matchAndRewrite(CompressStoreOp compress,
5125  PatternRewriter &rewriter) const override {
5126  switch (getMaskFormat(compress.getMask())) {
5127  case MaskFormat::AllTrue:
5128  rewriter.replaceOpWithNewOp<vector::StoreOp>(
5129  compress, compress.getValueToStore(), compress.getBase(),
5130  compress.getIndices());
5131  return success();
5132  case MaskFormat::AllFalse:
5133  rewriter.eraseOp(compress);
5134  return success();
5135  case MaskFormat::Unknown:
5136  return failure();
5137  }
5138  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5139  }
5140 };
5141 } // namespace
5142 
5143 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5144  MLIRContext *context) {
5145  results.add<CompressStoreFolder>(context);
5146 }
5147 
5148 //===----------------------------------------------------------------------===//
5149 // ShapeCastOp
5150 //===----------------------------------------------------------------------===//
5151 
5152 /// Returns true if each element of 'a' is equal to the product of a contiguous
5153 /// sequence of the elements of 'b'. Returns false otherwise.
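/// For example (illustrative): a = [4, 6] and b = [2, 2, 3, 2] is a valid pair
/// because 4 = 2 * 2 and 6 = 3 * 2, whereas a = [4, 6] and b = [3, 8] is not.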
5154 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5155  unsigned rankA = a.size();
5156  unsigned rankB = b.size();
5157  assert(rankA < rankB);
5158 
5159  auto isOne = [](int64_t v) { return v == 1; };
5160 
5161  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5162  // casted to a 0-d vector.
5163  if (rankA == 0 && llvm::all_of(b, isOne))
5164  return true;
5165 
5166  unsigned i = 0;
5167  unsigned j = 0;
5168  while (i < rankA && j < rankB) {
5169  int64_t dimA = a[i];
5170  int64_t dimB = 1;
5171  while (dimB < dimA && j < rankB)
5172  dimB *= b[j++];
5173  if (dimA != dimB)
5174  break;
5175  ++i;
5176 
5177  // Handle the case when trailing dimensions are of size 1.
5178  // Include them into the contiguous sequence.
5179  if (i < rankA && llvm::all_of(a.slice(i), isOne))
5180  i = rankA;
5181  if (j < rankB && llvm::all_of(b.slice(j), isOne))
5182  j = rankB;
5183  }
5184 
5185  return i == rankA && j == rankB;
5186 }
5187 
5188 static LogicalResult verifyVectorShapeCast(Operation *op,
5189  VectorType sourceVectorType,
5190  VectorType resultVectorType) {
5191  // Check that element type is the same.
5192  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5193  return op->emitOpError("source/result vectors must have same element type");
5194  auto sourceShape = sourceVectorType.getShape();
5195  auto resultShape = resultVectorType.getShape();
5196 
5197  // Check that product of source dim sizes matches product of result dim sizes.
5198  int64_t sourceDimProduct = std::accumulate(
5199  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5200  int64_t resultDimProduct = std::accumulate(
5201  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5202  if (sourceDimProduct != resultDimProduct)
5203  return op->emitOpError("source/result number of elements must match");
5204 
 5205  // Check the expanding/contracting rank cases.
5206  unsigned sourceRank = sourceVectorType.getRank();
5207  unsigned resultRank = resultVectorType.getRank();
5208  if (sourceRank < resultRank) {
5209  if (!isValidShapeCast(sourceShape, resultShape))
5210  return op->emitOpError("invalid shape cast");
5211  } else if (sourceRank > resultRank) {
5212  if (!isValidShapeCast(resultShape, sourceShape))
5213  return op->emitOpError("invalid shape cast");
5214  }
5215 
5216  // Check that (non-)scalability is preserved
5217  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5218  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5219  if (sourceNScalableDims != resultNScalableDims)
5220  return op->emitOpError("different number of scalable dims at source (")
5221  << sourceNScalableDims << ") and result (" << resultNScalableDims
5222  << ")";
5223  sourceVectorType.getNumDynamicDims();
5224 
5225  return success();
5226 }
5227 
5228 LogicalResult ShapeCastOp::verify() {
5229  auto sourceVectorType =
5230  llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5231  auto resultVectorType =
5232  llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5233 
5234  // Check if source/result are of vector type.
5235  if (sourceVectorType && resultVectorType)
5236  return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5237 
5238  return success();
5239 }
5240 
5241 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5242  // No-op shape cast.
5243  if (getSource().getType() == getResult().getType())
5244  return getSource();
5245 
5246  // Canceling shape casts.
5247  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5248  if (getResult().getType() == otherOp.getSource().getType())
5249  return otherOp.getSource();
5250 
5251  // Only allows valid transitive folding.
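 // E.g. (illustrative) %a = shape_cast %x : vector<6x4xf32> to vector<2x3x4xf32>
 // followed by shape_cast %a : vector<2x3x4xf32> to vector<24xf32> folds to a
 // single shape_cast %x : vector<6x4xf32> to vector<24xf32>.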
5252  VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5253  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5254  if (srcType.getRank() < resultType.getRank()) {
5255  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5256  return {};
5257  } else if (srcType.getRank() > resultType.getRank()) {
5258  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5259  return {};
5260  } else {
5261  return {};
5262  }
5263 
5264  setOperand(otherOp.getSource());
5265  return getResult();
5266  }
5267 
 5268  // Canceling broadcast and shape cast ops.
5269  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5270  if (bcastOp.getSourceType() == getType())
5271  return bcastOp.getSource();
5272  }
5273 
5274  return {};
5275 }
5276 
5277 namespace {
5278 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5279 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5280 public:
 5281  using OpRewritePattern::OpRewritePattern;
5282 
5283  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5284  PatternRewriter &rewriter) const override {
5285  auto constantOp =
5286  shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5287  if (!constantOp)
5288  return failure();
5289  // Only handle splat for now.
5290  auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5291  if (!dense)
5292  return failure();
5293  auto newAttr =
5294  DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5295  dense.getSplatValue<Attribute>());
5296  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5297  return success();
5298  }
5299 };
5300 
5301 /// Helper function that computes a new vector type based on the input vector
5302 /// type by removing the trailing one dims:
5303 ///
 5304 /// vector<4x1x1xi1> --> vector<4xi1>
5305 ///
5306 static VectorType trimTrailingOneDims(VectorType oldType) {
5307  ArrayRef<int64_t> oldShape = oldType.getShape();
5308  ArrayRef<int64_t> newShape = oldShape;
5309 
5310  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5311  ArrayRef<bool> newScalableDims = oldScalableDims;
5312 
5313  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5314  newShape = newShape.drop_back(1);
5315  newScalableDims = newScalableDims.drop_back(1);
5316  }
5317 
5318  // Make sure we have at least 1 dimension.
5319  // TODO: Add support for 0-D vectors.
5320  if (newShape.empty()) {
5321  newShape = oldShape.take_back();
5322  newScalableDims = oldScalableDims.take_back();
5323  }
5324 
5325  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5326 }
5327 
5328 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5329 ///
5330 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5331 /// dimension. If the input vector comes from `vector.create_mask` for which
5332 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5333 /// to fold shape_cast into create_mask.
5334 ///
5335 /// BEFORE:
5336 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5337 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5338 /// AFTER:
5339 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5340 class ShapeCastCreateMaskFolderTrailingOneDim final
5341  : public OpRewritePattern<ShapeCastOp> {
5342 public:
 5343  using OpRewritePattern::OpRewritePattern;
5344 
5345  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5346  PatternRewriter &rewriter) const override {
5347  Value shapeOpSrc = shapeOp->getOperand(0);
5348  auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5349  auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5350  if (!createMaskOp && !constantMaskOp)
5351  return failure();
5352 
5353  VectorType shapeOpResTy = shapeOp.getResultVectorType();
5354  VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5355 
5356  VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5357  if (newVecType != shapeOpResTy)
5358  return failure();
5359 
5360  auto numDimsToDrop =
5361  shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5362 
5363  // No unit dims to drop
5364  if (!numDimsToDrop)
5365  return failure();
5366 
5367  if (createMaskOp) {
5368  auto maskOperands = createMaskOp.getOperands();
5369  auto numMaskOperands = maskOperands.size();
5370 
5371  // Check every mask dim size to see whether it can be dropped
5372  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5373  --i) {
5374  auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5375  if (!constant || (constant.value() != 1))
5376  return failure();
5377  }
5378  SmallVector<Value> newMaskOperands =
5379  maskOperands.drop_back(numDimsToDrop);
5380 
5381  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5382  newMaskOperands);
5383  return success();
5384  }
5385 
5386  if (constantMaskOp) {
5387  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5388  auto numMaskOperands = maskDimSizes.size();
5389 
5390  // Check every mask dim size to see whether it can be dropped
5391  for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5392  --i) {
5393  if (maskDimSizes[i] != 1)
5394  return failure();
5395  }
5396 
5397  auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5398  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5399  newMaskOperands);
5400  return success();
5401  }
5402 
5403  return failure();
5404  }
5405 };
5406 
5407 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5408 /// This only applies when the shape of the broadcast source
5409 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5410 /// reshape is expressive enough to capture the result in a single op), or
5411 /// 2. has the same element count as the shape cast result.
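///
/// For instance (an illustrative sketch, not taken from a test): a
/// vector.broadcast of vector<4xf32> to vector<2x3x4xf32> followed by a
/// vector.shape_cast to vector<6x4xf32> is rewritten as a single
/// vector.broadcast to vector<6x4xf32>, since <4> is a suffix of <6x4>.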
5412 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5413 public:
 5414  using OpRewritePattern::OpRewritePattern;
5415 
5416  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5417  PatternRewriter &rewriter) const override {
5418  auto broadcastOp =
5419  shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5420  if (!broadcastOp)
5421  return failure();
5422 
5423  ArrayRef<int64_t> broadcastSourceShape;
5424  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5425  broadcastSourceShape = srcType.getShape();
5426  ArrayRef<int64_t> shapeCastTargetShape =
5427  shapeCastOp.getResultVectorType().getShape();
5428 
5429  // If `broadcastSourceShape` is a suffix of the result, we can just replace
5430  // with a broadcast to the final shape.
5431  if (broadcastSourceShape ==
5432  shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5433  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5434  shapeCastOp, shapeCastOp.getResultVectorType(),
5435  broadcastOp.getSource());
5436  return success();
5437  }
5438 
5439  // Otherwise, if the final result has the same element count, we can replace
5440  // with a shape cast.
5441  if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5442  if (srcType.getNumElements() ==
5443  shapeCastOp.getResultVectorType().getNumElements()) {
5444  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5445  shapeCastOp, shapeCastOp.getResultVectorType(),
5446  broadcastOp.getSource());
5447  return success();
5448  }
5449  }
5450 
5451  return failure();
5452  }
5453 };
5454 
5455 } // namespace
5456 
5457 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5458  MLIRContext *context) {
5459  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5460  ShapeCastBroadcastFolder>(context);
5461 }
5462 
5463 //===----------------------------------------------------------------------===//
5464 // VectorBitCastOp
5465 //===----------------------------------------------------------------------===//
5466 
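/// For non-0-D vectors, bitcasting must preserve the total bit width of the
/// minor dimension, e.g. (illustrative) vector<4x8xi8> -> vector<4x2xi32> is
/// valid (8 x 8 bits == 2 x 32 bits), while vector<4x8xi8> -> vector<4x3xi32>
/// is not.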
5467 LogicalResult BitCastOp::verify() {
5468  auto sourceVectorType = getSourceVectorType();
5469  auto resultVectorType = getResultVectorType();
5470 
5471  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5472  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5473  return emitOpError("dimension size mismatch at: ") << i;
5474  }
5475 
5476  DataLayout dataLayout = DataLayout::closest(*this);
5477  auto sourceElementBits =
5478  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5479  auto resultElementBits =
5480  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5481 
5482  if (sourceVectorType.getRank() == 0) {
5483  if (sourceElementBits != resultElementBits)
5484  return emitOpError("source/result bitwidth of the 0-D vector element "
5485  "types must be equal");
5486  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5487  resultElementBits * resultVectorType.getShape().back()) {
5488  return emitOpError(
5489  "source/result bitwidth of the minor 1-D vectors must be equal");
5490  }
5491 
5492  return success();
5493 }
5494 
5495 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
 5496  // No-op bitcast.
5497  if (getSource().getType() == getResult().getType())
5498  return getSource();
5499 
5500  // Canceling bitcasts.
5501  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5502  if (getResult().getType() == otherOp.getSource().getType())
5503  return otherOp.getSource();
5504 
5505  setOperand(otherOp.getSource());
5506  return getResult();
5507  }
5508 
5509  Attribute sourceConstant = adaptor.getSource();
5510  if (!sourceConstant)
5511  return {};
5512 
5513  Type srcElemType = getSourceVectorType().getElementType();
5514  Type dstElemType = getResultVectorType().getElementType();
5515 
5516  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5517  if (floatPack.isSplat()) {
5518  auto splat = floatPack.getSplatValue<FloatAttr>();
5519 
5520  // Casting fp16 into fp32.
5521  if (srcElemType.isF16() && dstElemType.isF32()) {
5522  uint32_t bits = static_cast<uint32_t>(
5523  splat.getValue().bitcastToAPInt().getZExtValue());
5524  // Duplicate the 16-bit pattern.
5525  bits = (bits << 16) | (bits & 0xffff);
5526  APInt intBits(32, bits);
5527  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5528  return DenseElementsAttr::get(getResultVectorType(), floatBits);
5529  }
5530  }
5531  }
5532 
5533  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5534  if (intPack.isSplat()) {
5535  auto splat = intPack.getSplatValue<IntegerAttr>();
5536 
5537  if (llvm::isa<IntegerType>(dstElemType)) {
5538  uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5539  uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5540 
5541  // Casting to a larger integer bit width.
5542  if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5543  APInt intBits = splat.getValue().zext(dstBitWidth);
5544 
5545  // Duplicate the lower width element.
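 // E.g. (illustrative) an i8 splat of 0xAB bitcast to i32 becomes a splat
 // of 0xABABABAB after the three shift-or steps below.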
5546  for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5547  intBits = (intBits << srcBitWidth) | intBits;
5548  return DenseElementsAttr::get(getResultVectorType(), intBits);
5549  }
5550  }
5551  }
5552  }
5553 
5554  return {};
5555 }
5556 
5557 //===----------------------------------------------------------------------===//
5558 // TypeCastOp
5559 //===----------------------------------------------------------------------===//
5560 
5561 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5562  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5563  SmallVector<int64_t, 8> res(memRefType.getShape());
5564  if (vectorType)
5565  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5566  return res;
5567 }
5568 
5569 /// Build the canonical memRefType with a single vector.
5570 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5571 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5572  Value source) {
5573  result.addOperands(source);
5574  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5575  VectorType vectorType =
5576  VectorType::get(extractShape(memRefType),
 5577  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5578  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5579  memRefType.getMemorySpace()));
5580 }
5581 
5582 LogicalResult TypeCastOp::verify() {
5583  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5584  if (!canonicalType.getLayout().isIdentity())
5585  return emitOpError("expects operand to be a memref with identity layout");
5586  if (!getResultMemRefType().getLayout().isIdentity())
5587  return emitOpError("expects result to be a memref with identity layout");
5588  if (getResultMemRefType().getMemorySpace() !=
5589  getMemRefType().getMemorySpace())
5590  return emitOpError("expects result in same memory space");
5591 
5592  auto sourceType = getMemRefType();
5593  auto resultType = getResultMemRefType();
5594  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
 5595  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5596  return emitOpError(
5597  "expects result and operand with same underlying scalar type: ")
5598  << resultType;
5599  if (extractShape(sourceType) != extractShape(resultType))
5600  return emitOpError(
5601  "expects concatenated result and operand shapes to be equal: ")
5602  << resultType;
5603  return success();
5604 }
5605 
5606 //===----------------------------------------------------------------------===//
5607 // TransposeOp
5608 //===----------------------------------------------------------------------===//
5609 
5610 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5611  Value vector, ArrayRef<int64_t> permutation) {
5612  VectorType vt = llvm::cast<VectorType>(vector.getType());
5613  SmallVector<int64_t, 4> transposedShape(vt.getRank());
5614  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5615  for (unsigned i = 0; i < permutation.size(); ++i) {
5616  transposedShape[i] = vt.getShape()[permutation[i]];
5617  transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5618  }
5619 
5620  result.addOperands(vector);
5621  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5622  transposedScalableDims));
5623  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5624  builder.getDenseI64ArrayAttr(permutation));
5625 }
5626 
5627 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5628  // Eliminate splat constant transpose ops.
5629  if (auto attr =
5630  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5631  if (attr.isSplat())
5632  return attr.reshape(getResultVectorType());
5633 
5634  // Eliminate identity transpose ops. This happens when the dimensions of the
5635  // input vector remain in their original order after the transpose operation.
5636  ArrayRef<int64_t> perm = getPermutation();
5637 
5638  // Check if the permutation of the dimensions contains sequential values:
5639  // {0, 1, 2, ...}.
5640  for (int64_t i = 0, e = perm.size(); i < e; i++) {
5641  if (perm[i] != i)
5642  return {};
5643  }
5644 
5645  return getVector();
5646 }
5647 
5648 LogicalResult vector::TransposeOp::verify() {
5649  VectorType vectorType = getSourceVectorType();
5650  VectorType resultType = getResultVectorType();
5651  int64_t rank = resultType.getRank();
5652  if (vectorType.getRank() != rank)
5653  return emitOpError("vector result rank mismatch: ") << rank;
5654  // Verify transposition array.
5655  ArrayRef<int64_t> perm = getPermutation();
5656  int64_t size = perm.size();
5657  if (rank != size)
5658  return emitOpError("transposition length mismatch: ") << size;
5659  SmallVector<bool, 8> seen(rank, false);
5660  for (const auto &ta : llvm::enumerate(perm)) {
5661  if (ta.value() < 0 || ta.value() >= rank)
5662  return emitOpError("transposition index out of range: ") << ta.value();
5663  if (seen[ta.value()])
5664  return emitOpError("duplicate position index: ") << ta.value();
5665  seen[ta.value()] = true;
5666  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5667  return emitOpError("dimension size mismatch at: ") << ta.value();
5668  }
5669  return success();
5670 }
5671 
5672 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5673  return llvm::to_vector<4>(getResultVectorType().getShape());
5674 }
5675 
5676 namespace {
5677 
5678 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
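// For example (illustrative SSA names):
//   %0 = vector.transpose %arg, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %1 = vector.transpose %0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
// is rewritten to a single transpose with the composed permutation [0, 1],
// which the identity fold above then removes entirely.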
5679 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5680 public:
 5681  using OpRewritePattern::OpRewritePattern;
 5682 
5683  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5684  PatternRewriter &rewriter) const override {
5685  // Composes two permutations: result[i] = permutation1[permutation2[i]].
5686  auto composePermutations = [](ArrayRef<int64_t> permutation1,
5687  ArrayRef<int64_t> permutation2) {
5688  SmallVector<int64_t, 4> result;
5689  for (auto index : permutation2)
5690  result.push_back(permutation1[index]);
5691  return result;
5692  };
5693 
5694  // Return if the input of 'transposeOp' is not defined by another transpose.
5695  vector::TransposeOp parentTransposeOp =
5696  transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5697  if (!parentTransposeOp)
5698  return failure();
5699 
5700  SmallVector<int64_t, 4> permutation = composePermutations(
5701  parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5702  // Replace 'transposeOp' with a new transpose operation.
5703  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5704  transposeOp, transposeOp.getResult().getType(),
5705  parentTransposeOp.getVector(), permutation);
5706  return success();
5707  }
5708 };
5709 
 5710 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
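// For example (illustrative):
//   %b = vector.broadcast %scalar : f32 to vector<4x8xf32>
//   %t = vector.transpose %b, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
// becomes:
//   %t = vector.broadcast %scalar : f32 to vector<8x4xf32>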
5711 struct FoldTransposedScalarBroadcast final
5712  : public OpRewritePattern<vector::TransposeOp> {
 5713  using OpRewritePattern::OpRewritePattern;
 5714 
5715  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5716  PatternRewriter &rewriter) const override {
5717  auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5718  if (!bcastOp)
5719  return failure();
5720 
5721  auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5722  if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5723  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5724  transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5725  return success();
5726  }
5727 
5728  return failure();
5729  }
5730 };
5731 
5732 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
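// For example (illustrative):
//   %s = vector.splat %x : vector<4x8xf32>
//   %t = vector.transpose %s, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
// becomes:
//   %t = vector.splat %x : vector<8x4xf32>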
5733 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5734 public:
 5735  using OpRewritePattern::OpRewritePattern;
 5736 
5737  LogicalResult matchAndRewrite(TransposeOp transposeOp,
5738  PatternRewriter &rewriter) const override {
5739  auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5740  if (!splatOp)
5741  return failure();
5742 
5743  rewriter.replaceOpWithNewOp<vector::SplatOp>(
5744  transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5745  return success();
5746  }
5747 };
5748 
5749 /// Folds transpose(create_mask) into a new transposed create_mask.
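/// For example (illustrative):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes:
///   %t = vector.create_mask %b, %a : vector<8x4xi1>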
5750 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5751 public:
 5752  using OpRewritePattern::OpRewritePattern;
 5753 
5754  LogicalResult matchAndRewrite(TransposeOp transpOp,
5755  PatternRewriter &rewriter) const override {
5756  Value transposeSrc = transpOp.getVector();
5757  auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5758  auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5759  if (!createMaskOp && !constantMaskOp)
5760  return failure();
5761 
5762  // Get the transpose permutation and apply it to the vector.create_mask or
5763  // vector.constant_mask operands.
5764  ArrayRef<int64_t> permutation = transpOp.getPermutation();
5765 
5766  if (createMaskOp) {
5767  auto maskOperands = createMaskOp.getOperands();
5768  SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
5769  applyPermutationToVector(newOperands, permutation);
5770 
5771  rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5772  transpOp, transpOp.getResultVectorType(), newOperands);
5773  return success();
5774  }
5775 
5776  // ConstantMaskOp case.
5777  auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5778  auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
5779 
5780  rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5781  transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
5782  return success();
5783  }
5784 };
5785 
5786 } // namespace
5787 
5788 void vector::TransposeOp::getCanonicalizationPatterns(
5789  RewritePatternSet &results, MLIRContext *context) {
5790  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5791  TransposeFolder, FoldTransposeSplat>(context);
5792 }
5793 
5794 //===----------------------------------------------------------------------===//
5795 // ConstantMaskOp
5796 //===----------------------------------------------------------------------===//
5797 
5798 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
5799  VectorType type, ConstantMaskKind kind) {
5800  assert(kind == ConstantMaskKind::AllTrue ||
5801  kind == ConstantMaskKind::AllFalse);
5802  build(builder, result, type,
5803  kind == ConstantMaskKind::AllTrue
5804  ? type.getShape()
5805  : SmallVector<int64_t>(type.getRank(), 0));
5806 }
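// For example (illustrative), building with ConstantMaskKind::AllTrue for
// vector<4x3xi1> produces `vector.constant_mask [4, 3] : vector<4x3xi1>`,
// while AllFalse produces `vector.constant_mask [0, 0] : vector<4x3xi1>`.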
5807 
5808 LogicalResult ConstantMaskOp::verify() {
5809  auto resultType = llvm::cast<VectorType>(getResult().getType());
5810  // Check the corner case of 0-D vectors first.
5811  if (resultType.getRank() == 0) {
5812  if (getMaskDimSizes().size() != 1)
5813  return emitError("array attr must have length 1 for 0-D vectors");
5814  auto dim = getMaskDimSizes()[0];
5815  if (dim != 0 && dim != 1)
5816  return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5817  return success();
5818  }
5819 
5820  // Verify that array attr size matches the rank of the vector result.
5821  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5822  return emitOpError(
5823  "must specify array attr of size equal vector result rank");
5824  // Verify that each array attr element is in bounds of corresponding vector
5825  // result dimension size.
5826  auto resultShape = resultType.getShape();
5827  auto resultScalableDims = resultType.getScalableDims();
5828  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
5829  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
5830  if (maskDimSize < 0 || maskDimSize > resultShape[index])
5831  return emitOpError(
5832  "array attr of size out of bounds of vector result dimension size");
5833  if (resultScalableDims[index] && maskDimSize != 0 &&
5834  maskDimSize != resultShape[index])
5835  return emitOpError(
5836  "only supports 'none set' or 'all set' scalable dimensions");
5837  }
5838  // Verify that if one mask dim size is zero, they all should be zero (because
5839  // the mask region is a conjunction of each mask dimension interval).
5840  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5841  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5842  if (anyZeros && !allZeros)
5843  return emitOpError("expected all mask dim sizes to be zeros, "
5844  "as a result of conjunction with zero mask dim");
5845  return success();
5846 }
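// For example (illustrative), `vector.constant_mask [3, 2] : vector<4x3xi1>`
// verifies, while `vector.constant_mask [3, 0] : vector<4x3xi1>` is rejected:
// a zero dim size empties the whole mask, so all dim sizes must then be zero.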
5847 
5848 bool ConstantMaskOp::isAllOnesMask() {
5849  auto resultType = getVectorType();
5850  // Check the corner case of 0-D vectors first.
5851  if (resultType.getRank() == 0) {
5852  assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5853  return getMaskDimSizes()[0] == 1;
5854  }
5855  for (const auto [resultSize, maskDimSize] :
5856  llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5857  if (maskDimSize < resultSize)
5858  return false;
5859  }
5860  return true;
5861 }
5862 
5863 //===----------------------------------------------------------------------===//
5864 // CreateMaskOp
5865 //===----------------------------------------------------------------------===//
5866 
5867 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
5868  VectorType type,
5869  ArrayRef<OpFoldResult> mixedOperands) {
5870  SmallVector<Value> operands =
5871  getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
5872  build(builder, result, type, operands);
5873 }
5874 
5875 LogicalResult CreateMaskOp::verify() {
5876  auto vectorType = llvm::cast<VectorType>(getResult().getType());
 5877  // Verify that an operand was specified for each result vector dimension.
5878  if (vectorType.getRank() == 0) {
5879  if (getNumOperands() != 1)
5880  return emitOpError(
5881  "must specify exactly one operand for 0-D create_mask");
5882  } else if (getNumOperands() !=
5883  llvm::cast<VectorType>(getResult().getType()).getRank()) {
5884  return emitOpError(
5885  "must specify an operand for each result vector dimension");
5886  }
5887  return success();
5888 }
5889 
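/// Returns the constant multiplier of vector.vscale when `value` is either
/// vector.vscale itself (multiplier 1) or a multiplication of vector.vscale
/// by a constant, e.g. (illustrative) `%1 = arith.muli %vscale, %c4` yields 4;
/// otherwise returns std::nullopt.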
5890 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
5891  if (value.getDefiningOp<vector::VectorScaleOp>())
5892  return 1;
5893  auto mul = value.getDefiningOp<arith::MulIOp>();
5894  if (!mul)
5895  return {};
5896  auto lhs = mul.getLhs();
5897  auto rhs = mul.getRhs();
5898  if (lhs.getDefiningOp<vector::VectorScaleOp>())
5899  return getConstantIntValue(rhs);
5900  if (rhs.getDefiningOp<vector::VectorScaleOp>())
5901  return getConstantIntValue(lhs);
5902  return {};
5903 }
5904 
5905 namespace {
5906 
5907 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5908 ///
5909 /// Ex 1:
5910 /// %c2 = arith.constant 2 : index
5911 /// %c3 = arith.constant 3 : index
5912 /// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5913 /// Becomes:
5914 /// vector.constant_mask [3, 2] : vector<4x3xi1>
5915 ///
5916 /// Ex 2:
5917 /// %c_neg_1 = arith.constant -1 : index
5918 /// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5919 /// becomes:
5920 /// vector.constant_mask [0] : vector<[8]xi1>
5921 ///
5922 /// Ex 3:
5923 /// %c8 = arith.constant 8 : index
5924 /// %c16 = arith.constant 16 : index
5925 /// %0 = vector.vscale
5926 /// %1 = arith.muli %0, %c16 : index
5927 /// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5928 /// becomes:
5929 /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
5930 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5931 public:
 5932  using OpRewritePattern::OpRewritePattern;
 5933 
5934  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5935  PatternRewriter &rewriter) const override {
5936  VectorType maskType = createMaskOp.getVectorType();
5937  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
5938  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
5939 
5940  // Special case: Rank zero shape.
5941  constexpr std::array<int64_t, 1> rankZeroShape{1};
5942  constexpr std::array<bool, 1> rankZeroScalableDims{false};
5943  if (maskType.getRank() == 0) {
5944  maskTypeDimSizes = rankZeroShape;
5945  maskTypeDimScalableFlags = rankZeroScalableDims;
5946  }
5947 
5948  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
5949  // collect the `constantDims` (for the ConstantMaskOp).
5950  SmallVector<int64_t, 4> constantDims;
5951  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
5952  if (auto intSize = getConstantIntValue(dimSize)) {
5953  // Constant value.
5954  // If the mask dim is non-scalable this can be any value.
5955  // If the mask dim is scalable only zero (all-false) is supported.
5956  if (maskTypeDimScalableFlags[i] && intSize >= 0)
5957  return failure();
5958  constantDims.push_back(*intSize);
5959  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
5960  // Constant vscale multiple (e.g. 4 x vscale).
5961  // Must be all-true to fold to a ConstantMask.
5962  if (vscaleMultiplier < maskTypeDimSizes[i])
5963  return failure();
5964  constantDims.push_back(*vscaleMultiplier);
5965  } else {
5966  return failure();
5967  }
5968  }
5969 
5970  // Clamp values to constant_mask bounds.
5971  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
5972  value = std::clamp<int64_t>(value, 0, maskDimSize);
5973 
5974  // If one of dim sizes is zero, set all dims to zero.
5975  if (llvm::is_contained(constantDims, 0))
5976  constantDims.assign(constantDims.size(), 0);
5977 
5978  // Replace 'createMaskOp' with ConstantMaskOp.
5979  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
5980  constantDims);
5981  return success();
5982  }
5983 };
5984 
5985 } // namespace
5986 
5987 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
5988  MLIRContext *context) {
5989  results.add<CreateMaskFolder>(context);
5990 }
5991 
5992 //===----------------------------------------------------------------------===//
5993 // MaskOp
5994 //===----------------------------------------------------------------------===//
5995 
5996 void MaskOp::build(
5997  OpBuilder &builder, OperationState &result, Value mask,
5998  Operation *maskableOp,
5999  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6000  assert(maskRegionBuilder &&
6001  "builder callback for 'maskRegion' must be present");
6002 
6003  result.addOperands(mask);
6004  OpBuilder::InsertionGuard guard(builder);
6005  Region *maskRegion = result.addRegion();
6006  builder.createBlock(maskRegion);
6007  maskRegionBuilder(builder, maskableOp);
6008 }
6009 
6010 void MaskOp::build(
6011  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6012  Value mask, Operation *maskableOp,
6013  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6014  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6015  maskRegionBuilder);
6016 }
6017 
6018 void MaskOp::build(
6019  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6020  Value mask, Value passthru, Operation *maskableOp,
6021  function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6022  build(builder, result, mask, maskableOp, maskRegionBuilder);
6023  if (passthru)
6024  result.addOperands(passthru);
6025  result.addTypes(resultTypes);
6026 }
6027 
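// The custom assembly handled by parse/print below looks like (illustrative):
//   %0 = vector.mask %mask, %passthru {
//          vector.transfer_read %base[%i], %pad : memref<?xf32>, vector<8xf32>
//        } : vector<8xi1> -> vector<8xf32>
// where the passthru operand and the trailing result types are optional.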
6028 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6029  // Create the op region.
6030  result.regions.reserve(1);
6031  Region &maskRegion = *result.addRegion();
6032 
6033  auto &builder = parser.getBuilder();
6034 
6035  // Parse all the operands.
 6036  OpAsmParser::UnresolvedOperand mask;
 6037  if (parser.parseOperand(mask))
6038  return failure();
6039 
6040  // Optional passthru operand.
 6041  OpAsmParser::UnresolvedOperand passthru;
 6042  ParseResult parsePassthru = parser.parseOptionalComma();
6043  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6044  return failure();
6045 
6046  // Parse op region.
6047  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6048  return failure();
6049 
6050  MaskOp::ensureTerminator(maskRegion, builder, result.location);
6051 
6052  // Parse the optional attribute list.
6053  if (parser.parseOptionalAttrDict(result.attributes))
6054  return failure();
6055 
6056  // Parse all the types.
6057  Type maskType;
6058  if (parser.parseColonType(maskType))
6059  return failure();
6060 
6061  SmallVector<Type> resultTypes;
6062  if (parser.parseOptionalArrowTypeList(resultTypes))
6063  return failure();
6064  result.types.append(resultTypes);
6065 
6066  // Resolve operands.
6067  if (parser.resolveOperand(mask, maskType, result.operands))
6068  return failure();
6069 
6070  if (parsePassthru.succeeded())
6071  if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6072  return failure();
6073 
6074  return success();
6075 }
6076 
6078  p << " " << getMask();
6079  if (getPassthru())
6080  p << ", " << getPassthru();
6081 
6082  // Print single masked operation and skip terminator.
6083  p << " { ";
6084  Block *singleBlock = &getMaskRegion().getBlocks().front();
6085  if (singleBlock && !singleBlock->getOperations().empty())
6086  p.printCustomOrGenericOp(&singleBlock->front());
6087  p << " }";
6088 
6089  p.printOptionalAttrDict(getOperation()->getAttrs());
6090 
6091  p << " : " << getMask().getType();
6092  if (getNumResults() > 0)
6093  p << " -> " << getResultTypes();
6094 }
6095 
6096