//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}
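
// A sketch of how getMaskFormat classifies a few masks (illustrative IR with
// hypothetical SSA names, not taken from a test):
//
//   %m0 = arith.constant dense<true> : vector<4xi1>  // -> AllTrue
//   %m1 = vector.constant_mask [0] : vector<4xi1>    // -> AllFalse
//   %m2 = vector.create_mask %c0 : vector<4xi1>      // -> AllFalse if %c0 == 0
//   %m3 = vector.create_mask %arg0 : vector<4xi1>    // -> Unknown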

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  vector::YieldOp::create(builder, loc);
}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}

/// Returns the effective rank of the vector to read/write for Xfer Ops
///
/// When the element type of the shaped type is _a scalar_, this will simply
/// return the rank of the vector (the result for xfer_read or the value to
/// store for xfer_write).
///
/// When the element type of the base shaped type is _a vector_, returns the
/// difference between the rank of the original vector and the rank of the
/// shaped type's vector element type.
///
/// EXAMPLE 1 (element type is _a scalar_):
///  - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
///  - shapedType.getElementType() = f32 (rank 0)
///  - vectorType.getRank() = 2
///  - Result = 2 - 0 = 2
///
/// EXAMPLE 2 (element type is _a vector_):
///  - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
///  - shapedType.getElementType() = vector<20xf32> (rank 1)
///  - vectorType.getRank() = 1
///  - Result = 1 - 1 = 0
///
/// This is used to determine the number of minor dimensions for identity maps
/// in vector transfer Ops.
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(),
      getEffectiveVectorRankForXferOp(shapedType, vectorType),
      shapedType.getContext());
}
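
// For instance (an assumed example, not from the source): for
// memref<10x20x30xf32> and vector<2x4xf32> the effective vector rank is 2, so
// the minor identity map built above is (d0, d1, d2) -> (d1, d2).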

/// Check if `write` is of a constant splat and the masked `read` is padded
/// with the same splat value -- meaning it could be the same value as the
/// initial constant splat.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // Check if the masks are consistent. The splat value could be the same if
  // the read is masked (and padded with the splat value), and the write is
  // unmasked or has the same mask. Note this does not allow the case where
  // the write is masked and the read is unmasked, as then the read could be
  // of more elements than the write (which may not be the same value).
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for constant splat (as the source of the write).
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat()) {
    return false;
  }
  // The padding of the read and the constant splat value must be the same.
  Attribute padAttr;
  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
    return false;
  return padAttr == splatAttr.getSplatValue<Attribute>();
}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different we know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For this dimension, we slice a part of the memref; we need to make
      // sure the intervals accessed don't overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
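
// As a quick sanity check (an assumed example): two transfers of vector<4xf32>
// on the same 1-D memref at constant indices 0 and 4 are disjoint, since the
// index distance (4) is >= the vector dimension (4).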

// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();

    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }

  return failure();
}
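
// For example (assumed values): with shape = [2, 3] and offsets = [0, 0],
// successive calls advance `position` as
// [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> failure().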

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return arith::ConstantIndexOp::create(
                                 builder, loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return cast<Value>(foldResult);
                  });
  return values;
}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}
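
// Illustrative IR this matches (hypothetical SSA names):
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
// getConstantVscaleMultiplier(%n) returns 4, and returns 1 for %vs itself.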

/// Converts numeric attributes to the expected type. Supports
/// integer-to-integer and float-to-integer conversions. Returns the original
/// attribute if no conversion is needed or supported.
static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
  // Integer-to-integer conversion
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  }

  // Float-to-integer bitcast (preserves bit representation)
  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);
    if (!intType)
      return attr;

    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
  }

  return attr;
}
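
// For example, bitcasting the f32 attribute 1.0 to i32 yields the attribute
// 0x3F800000 (1065353216), preserving the bit pattern rather than the value.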

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (matchPattern(value, ub::m_Poison()))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
}

Type vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // No reduction dims: this is a noop regardless of rank.
  if (getReductionDims().empty())
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
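
// A sketch of the rewrite above (assumed IR): when the only reduced dim is a
// unit dim, e.g.
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// the reduction folds to a vector.shape_cast of %v to vector<4xf32> combined
// with %acc via arith.addf.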

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
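
// The verifier above accepts, e.g. (an illustrative case):
//   %s = vector.reduction <add>, %v : vector<8xf32> into f32
// and rejects rank > 1 sources or kinds that do not match the element type
// (such as <and> on f32).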

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::map_to_vector(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          })));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind,
                                  arith::FastMathFlags fastMathFlags) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
  if (fastMathFlags != arith::FastMathFlags::none)
    result.addAttribute(
        getFastmathAttrName(result.name),
        arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed because tests still use the old format where the 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}
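
// The custom assembly handled above looks like this (an illustrative,
// abridged example; #map attributes elided):
//   %res = vector.contract {indexing_maps = [#map0, #map1, #map2],
//                           iterator_types = ["parallel", "parallel",
//                                             "reduction"],
//                           kind = #vector.kind<add>}
//            %lhs, %rhs, %acc
//            : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>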

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed because tests still use the old format where the
      // 'iterator_types' attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          });

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
      // Omit fastmath when it equals the default (none) to keep output clean.
      if (attr.getName() == getFastmathAttrName() &&
          llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
              arith::FastMathFlags::none)
        continue;
      attrs.push_back(attr);
    }
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape =
        llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        });
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");
  }

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  // Delayed calling of IndexingMapOpInterface::verifyImpl.
  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}
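
// As an (assumed) example: for a matmul-style contract with indexing maps
// [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)],
// %lhs : vector<4x8xf32> and %rhs : vector<8x16xf32>, the loops above fill
// maskShape as [m, n, k] = [4, 16, 8], so the expected mask type is
// vector<4x16x8xi1>.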

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName(),
                                getFastmathAttrName()};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
/// %c0 = arith.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source) {
  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
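
// For example (illustrative): extracting at a 1-element position from
// vector<2x3xf32> infers vector<3xf32>, while a full-rank position [i, j]
// infers the element type f32.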

LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::map_to_vector<4>(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getSource());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
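
// A sketch of this fold (assumed IR): with
//   %1 = vector.extract %0[0] : vector<3x4xf32> from vector<2x3x4xf32>
//   %2 = vector.extract %1[1, 2] : f32 from vector<3x4xf32>
// the second op is updated in place to
//   %2 = vector.extract %0[0, 1, 2] : f32 from vector<2x3x4xf32>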

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >= 0 with the corresponding
  /// entries of `b`. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  }

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 2, 3]: vector<5> from vector<2x3x4x5>
  /// // can fold to:
  /// %ext = vector.extract %source[2, 3]: vector<5> from vector<3x4x5>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the
  /// result. They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace
1538
1539ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1540 ExtractOp e)
1541 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1542 extractedRank(extractOp.getNumIndices()) {
1543 assert(vectorRank >= extractedRank && "Extracted position overflow");
1544 sentinels.reserve(vectorRank - extractedRank);
1545 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1546 sentinels.push_back(-(i + 1));
1547 extractPosition.assign(extractOp.getStaticPosition().begin(),
1548 extractOp.getStaticPosition().end());
1549 llvm::append_range(extractPosition, sentinels);
1550}
1551
1552// Case 1. If we hit a transpose, just compose the map and iterate.
1553 // Invariant: insert + transpose do not change rank, so we can always compose.
1554LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1555 // TODO: Canonicalization for dynamic position not implemented yet.
1556 if (extractOp.hasDynamicPosition())
1557 return failure();
1558
1559 if (!nextTransposeOp)
1560 return failure();
1561 AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1562 nextTransposeOp.getPermutation(), extractOp.getContext()));
1563 extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1564 return success();
1565}
1566
1567// Case 2: the insert position matches extractPosition exactly, early return.
1568LogicalResult
1569ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1570 Value &res) {
1571 // TODO: Canonicalization for dynamic position not implemented yet.
1572 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1573 return failure();
1574
1575 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1576 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1577 return failure();
1578 // Case 2.a. early-exit fold.
1579 res = nextInsertOp.getValueToStore();
1580 // Case 2.b. if internal transposition is present, canFold will be false.
1581 return success(canFold());
1582}
1583
1584/// Case 3: if inserted position is a prefix of extractPosition,
1585/// extract a portion of the source of the insertion.
1586/// This method updates the internal state.
1587LogicalResult
1588ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1589 // TODO: Canonicalization for dynamic position not implemented yet.
1590 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1591 return failure();
1592
1593 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1594 if (!isContainedWithin(insertedPos, extractPosition))
1595 return failure();
1596 // Set leading dims to zero.
1597 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1598 // Drop extra leading dims.
1599 extractPosition.erase(extractPosition.begin(),
1600 extractPosition.begin() + insertedPos.size());
1601 extractedRank = extractPosition.size() - sentinels.size();
1602 // Case 3.a. early-exit fold (break and delegate to post-while path).
1603 res = nextInsertOp.getValueToStore();
1604 // Case 3.b. if internal transposition is present, canFold will be false.
1605 return success();
1606}
1607
1608/// Try to fold in place to extract(source, extractPosition) and return the
1609/// folded result. Return null if folding is not possible (e.g. due to an
1610/// internal transposition in the result).
1611Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1612 Value source) {
1613 // TODO: Canonicalization for dynamic position not implemented yet.
1614 if (extractOp.hasDynamicPosition())
1615 return Value();
1616
1617 // If we can't fold (either internal transposition, or nothing to fold), bail.
1618 bool nothingToFold = (source == extractOp.getSource());
1619 if (nothingToFold || !canFold())
1620 return Value();
1621
1622 // Otherwise, fold by updating the op inplace and return its result.
1623 OpBuilder b(extractOp.getContext());
1624 extractOp.setStaticPosition(
1625 ArrayRef(extractPosition).take_front(extractedRank));
1626 extractOp.getSourceMutable().assign(source);
1627 return extractOp.getResult();
1628}
1629
1630/// Iterate over producing insert and transpose ops until we find a fold.
1631Value ExtractFromInsertTransposeChainState::fold() {
1632 // TODO: Canonicalization for dynamic position not implemented yet.
1633 if (extractOp.hasDynamicPosition())
1634 return Value();
1635
1636 Value valueToExtractFrom = extractOp.getSource();
1637 updateStateForNextIteration(valueToExtractFrom);
1638 while (nextInsertOp || nextTransposeOp) {
1639 // Case 1. If we hit a transpose, just compose the map and iterate.
1640 // Invariant: insert + transpose do not change rank, so we can always compose.
1641 if (succeeded(handleTransposeOp())) {
1642 valueToExtractFrom = nextTransposeOp.getVector();
1643 updateStateForNextIteration(valueToExtractFrom);
1644 continue;
1645 }
1646
1647 Value result;
1648 // Case 2: the positions match exactly.
1649 if (succeeded(handleInsertOpWithMatchingPos(result)))
1650 return result;
1651
1652 // Case 3: if the inserted position is a prefix of extractPosition, we can
1653 // just extract a portion of the source of the insert.
1654 if (succeeded(handleInsertOpWithPrefixPos(result)))
1655 return tryToFoldExtractOpInPlace(result);
1656
1657 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1658 // values. This is a more difficult case and we bail.
1659 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1660 if (isContainedWithin(extractPosition, insertedPos) ||
1661 intersectsWhereNonNegative(extractPosition, insertedPos))
1662 return Value();
1663
1664 // Case 5: No intersection, we forward the extract to insertOp.dest().
1665 valueToExtractFrom = nextInsertOp.getDest();
1666 updateStateForNextIteration(valueToExtractFrom);
1667 }
1668 // If after all this we can fold, go for it.
1669 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1670}
1671
1672/// Returns true if the operation has a 0-D vector type operand or result.
1673 static bool hasZeroDimVectors(Operation *op) {
1674 auto hasZeroDimVectorType = [](Type type) -> bool {
1675 auto vecType = dyn_cast<VectorType>(type);
1676 return vecType && vecType.getRank() == 0;
1677 };
1678
1679 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1680 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1681}
1682
1683/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1684/// considered to be 'broadcastlike'.
1685static bool isBroadcastLike(Operation *op) {
1686 if (isa<BroadcastOp>(op))
1687 return true;
1688
1689 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1690 if (!shapeCast)
1691 return false;
1692
1693 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1694 // Checking that the destination shape has a prefix of 1s is not sufficient,
1695 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1696 // is that the source shape is a suffix of the destination shape.
1697 VectorType srcType = shapeCast.getSourceVectorType();
1698 ArrayRef<int64_t> srcShape = srcType.getShape();
1699 uint64_t srcRank = srcType.getRank();
1700 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1701 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1702}
1703
1704/// Fold extract(broadcast(X)) to either extract(X) or just X.
1705///
1706/// Example:
1707///
1708/// broadcast extract [1][2]
1709/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1710///
1711/// becomes
1712/// extract [1]
1713/// (3,4) -------------------------------------> (4)
1714///
1715///
1716/// The variable names used in this implementation correspond to the above
1717/// shapes as,
1718///
1719/// - (3, 4) is `input` shape.
1720/// - (2, 3, 4) is `broadcast` shape.
1721/// - (4) is `extract` shape.
1722///
1723/// This folding is possible when the suffix of `input` shape is the same as
1724/// `extract` shape.
1725static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1726
1727 Operation *defOp = extractOp.getSource().getDefiningOp();
1728 if (!defOp || !isBroadcastLike(defOp))
1729 return Value();
1730
1731 Value input = defOp->getOperand(0);
1732
1733 // Replace extract(broadcast(X)) with X
1734 if (extractOp.getType() == input.getType())
1735 return input;
1736
1737 // Get required types and ranks in the chain
1738 // input -> broadcast -> extract
1739 // (scalars are treated as rank-0).
1740 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1741 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1742 unsigned inputRank = inputType ? inputType.getRank() : 0;
1743 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1744 unsigned extractRank = extractType ? extractType.getRank() : 0;
1745
1746 // Cannot eliminate the broadcast if the overall rank increases.
1747 if (extractRank > inputRank)
1748 return Value();
1749
1750 // The above condition guarantees that input is a vector.
1751 assert(inputType && "input must be a vector type because of previous checks");
1752 ArrayRef<int64_t> inputShape = inputType.getShape();
1753
1754 // In the case where there is a broadcast dimension in the suffix, it is not
1755 // possible to replace extract(broadcast(X)) with extract(X). Example:
1756 //
1757 // broadcast extract
1758 // (1) --------> (3,4) ------> (4)
1759 if (extractType &&
1760 extractType.getShape() != inputShape.take_back(extractRank))
1761 return Value();
1762
1763 // Replace extract(broadcast(X)) with extract(X).
1764 // First, determine the new extraction position.
1765 unsigned deltaOverall = inputRank - extractRank;
1766 unsigned deltaBroadcast = broadcastRank - inputRank;
1767 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1768 SmallVector<OpFoldResult> newPositions(deltaOverall);
1769 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1770 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1771 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1772 }
1773 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1774 extractOp->setOperands(
1775 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1776 extractOp.setStaticPosition(staticPos);
1777 return extractOp.getResult();
1778}
1779
1780/// Fold extractOp coming from ShuffleOp.
1781///
1782/// Example:
1783///
1784/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1785/// : vector<8xf32>, vector<8xf32>
1786/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1787/// ->
1788/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1789///
1790static Value foldExtractFromShuffle(ExtractOp extractOp) {
1791 // Dynamic positions are not folded as the resulting code would be more
1792 // complex than the input code.
1793 if (extractOp.hasDynamicPosition())
1794 return Value();
1795
1796 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1797 if (!shuffleOp)
1798 return Value();
1799
1800 // TODO: 0-D or multi-dimensional vectors not supported yet.
1801 if (shuffleOp.getResultVectorType().getRank() != 1)
1802 return Value();
1803
1804 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1805 auto shuffleMask = shuffleOp.getMask();
1806 int64_t extractIdx = extractOp.getStaticPosition()[0];
1807 int64_t shuffleIdx = shuffleMask[extractIdx];
1808
1809 // Find the shuffled vector to extract from based on the shuffle index.
1810 if (shuffleIdx < inputVecSize) {
1811 extractOp.setOperand(0, shuffleOp.getV1());
1812 extractOp.setStaticPosition({shuffleIdx});
1813 } else {
1814 extractOp.setOperand(0, shuffleOp.getV2());
1815 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1816 }
1817
1818 return extractOp.getResult();
1819}
1820
1821// Fold extractOp with source coming from ShapeCast op.
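// Example (illustrative, hypothetical IR; assumes static positions):
//   %sc = vector.shape_cast %src : vector<4x4xf32> to vector<2x8xf32>
//   %e = vector.extract %sc[1, 5] : f32 from vector<2x8xf32>
// The linearized position is 1 * 8 + 5 = 13, which delinearizes in the
// source shape to [3, 1], so this folds to:
//   %e = vector.extract %src[3, 1] : f32 from vector<4x4xf32>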
1822static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1823 // TODO: Canonicalization for dynamic position not implemented yet.
1824 if (extractOp.hasDynamicPosition())
1825 return Value();
1826
1827 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1828 if (!shapeCastOp)
1829 return Value();
1830
1831 // Get the nth dimension size starting from lowest dimension.
1832 auto getDimReverse = [](VectorType type, int64_t n) {
1833 return type.getShape().take_back(n + 1).front();
1834 };
1835 int64_t destinationRank =
1836 llvm::isa<VectorType>(extractOp.getType())
1837 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1838 : 0;
1839 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1840 return Value();
1841 if (destinationRank > 0) {
1842 auto destinationType =
1843 llvm::cast<VectorType>(extractOp.getResult().getType());
1844 for (int64_t i = 0; i < destinationRank; i++) {
1845 // The lowest dimension of the destination must match the lowest
1846 // dimension of the shapecast op source.
1847 // TODO: This case could be supported in a canonicalization pattern.
1848 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1849 getDimReverse(destinationType, i))
1850 return Value();
1851 }
1852 }
1853 // Extract the strides associated with the extract op vector source. Then use
1854 // this to calculate a linearized position for the extract.
1855 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1856 std::reverse(extractedPos.begin(), extractedPos.end());
1857 SmallVector<int64_t, 4> strides;
1858 int64_t stride = 1;
1859 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1860 strides.push_back(stride);
1861 stride *=
1862 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1863 }
1864
1865 int64_t position = linearize(extractedPos, strides);
1866 // Then extract the strides associated to the shapeCast op vector source and
1867 // delinearize the position using those strides.
1868 SmallVector<int64_t, 4> newStrides;
1869 int64_t numDimension =
1870 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1871 stride = 1;
1872 for (int64_t i = 0; i < numDimension; i++) {
1873 newStrides.push_back(stride);
1874 stride *=
1875 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1876 }
1877 std::reverse(newStrides.begin(), newStrides.end());
1878 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1879 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1880 OpBuilder b(extractOp.getContext());
1881 extractOp.setStaticPosition(newPosition);
1882 extractOp.setOperand(0, shapeCastOp.getSource());
1883 return extractOp.getResult();
1884}
1885
1886/// Fold an ExtractOp from ExtractStridedSliceOp.
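/// Example (illustrative, hypothetical IR; assumes unit strides and a fully
/// extracted trailing dimension):
///   %ss = vector.extract_strided_slice %src
///       {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
///       : vector<4x8xf32> to vector<2x8xf32>
///   %e = vector.extract %ss[0] : vector<8xf32> from vector<2x8xf32>
/// folds to:
///   %e = vector.extract %src[1] : vector<8xf32> from vector<4x8xf32>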
1887static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1888 // TODO: Canonicalization for dynamic position not implemented yet.
1889 if (extractOp.hasDynamicPosition())
1890 return Value();
1891
1892 auto extractStridedSliceOp =
1893 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1894 if (!extractStridedSliceOp)
1895 return Value();
1896
1897 // 0-D vectors not supported.
1898 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1899 if (hasZeroDimVectors(extractStridedSliceOp))
1900 return Value();
1901
1902 // Return if 'extractStridedSliceOp' has non-unit strides.
1903 if (extractStridedSliceOp.hasNonUnitStrides())
1904 return Value();
1905
1906 // Trim offsets for dimensions fully extracted.
1907 auto sliceOffsets =
1908 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1909 while (!sliceOffsets.empty()) {
1910 size_t lastOffset = sliceOffsets.size() - 1;
1911 if (sliceOffsets.back() != 0 ||
1912 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1913 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1914 break;
1915 sliceOffsets.pop_back();
1916 }
1917 unsigned destinationRank = 0;
1918 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1919 destinationRank = vecType.getRank();
1920 // The dimensions of the result need to be untouched by the
1921 // extractStridedSlice op.
1922 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1923 sliceOffsets.size())
1924 return Value();
1925
1926 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1927 assert(extractedPos.size() >= sliceOffsets.size());
1928 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1929 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1930 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1931
1932 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1933 OpBuilder b(extractOp.getContext());
1934 extractOp.setStaticPosition(extractedPos);
1935 return extractOp.getResult();
1936}
1937
1938/// Fold extract_op fed from a chain of insertStridedSlice ops.
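/// Example (illustrative, hypothetical IR; assumes static positions and unit
/// strides):
///   %i = vector.insert_strided_slice %src, %dest
///       {offsets = [2, 1], strides = [1]} : vector<2xf32> into vector<4x4xf32>
///   %e = vector.extract %i[2, 2] : f32 from vector<4x4xf32>
/// The extract position lies inside the inserted chunk, so this folds to:
///   %e = vector.extract %src[1] : f32 from vector<2xf32>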
1939static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1940 // TODO: Canonicalization for dynamic position not implemented yet.
1941 if (extractOp.hasDynamicPosition())
1942 return Value();
1943
1944 int64_t destinationRank =
1945 llvm::isa<VectorType>(extractOp.getType())
1946 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1947 : 0;
1948 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1949 if (!insertOp)
1950 return Value();
1951
1952 // 0-D vectors not supported.
1953 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1954 if (hasZeroDimVectors(insertOp))
1955 return Value();
1956
1957 while (insertOp) {
1958 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1959 insertOp.getSourceVectorType().getRank();
1960 if (destinationRank > insertOp.getSourceVectorType().getRank())
1961 return Value();
1962 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1963 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1964
1965 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1966 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1967 }))
1968 return Value();
1969 bool disjoint = false;
1970 SmallVector<int64_t, 4> offsetDiffs;
1971 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1972 int64_t start = insertOffsets[dim];
1973 int64_t size =
1974 (dim < insertRankDiff)
1975 ? 1
1976 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1977 int64_t end = start + size;
1978 int64_t offset = extractOffsets[dim];
1979 // Check if the start of the extract offset is in the interval inserted.
1980 if (start <= offset && offset < end) {
1981 if (dim >= insertRankDiff)
1982 offsetDiffs.push_back(offset - start);
1983 continue;
1984 }
1985 disjoint = true;
1986 break;
1987 }
1988 // The extracted chunk overlaps with the inserted vector.
1989 if (!disjoint) {
1990 // If any of the inner dimensions are only partially inserted, we have a
1991 // partial overlap.
1992 int64_t srcRankDiff =
1993 insertOp.getSourceVectorType().getRank() - destinationRank;
1994 for (int64_t i = 0; i < destinationRank; i++) {
1995 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1996 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1997 insertRankDiff))
1998 return Value();
1999 }
2000 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2001 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2002 OpBuilder b(extractOp.getContext());
2003 extractOp.setStaticPosition(offsetDiffs);
2004 return extractOp.getResult();
2005 }
2006 // If the chunk extracted is disjoint from the chunk inserted, keep
2007 // looking in the insert chain.
2008 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2009 }
2010 return Value();
2011}
2012
2013/// Try to fold the extraction of a scalar from a vector defined by
2014/// vector.from_elements. E.g.:
2015///
2016/// %0 = vector.from_elements %a, %b : vector<2xf32>
2017/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2018/// ==> fold to %a
2019static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2020 // Dynamic extractions cannot be folded.
2021 if (extractOp.hasDynamicPosition())
2022 return {};
2023
2024 // Look for extract(from_elements).
2025 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2026 if (!fromElementsOp)
2027 return {};
2028
2029 // Scalable vectors are not supported.
2030 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2031 if (vecType.isScalable())
2032 return {};
2033
2034 // Only extractions of scalars are supported.
2035 int64_t rank = vecType.getRank();
2036 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2037 if (extractOp.getType() != vecType.getElementType())
2038 return {};
2039 assert(static_cast<int64_t>(indices.size()) == rank &&
2040 "unexpected number of indices");
2041
2042 // Compute flattened/linearized index and fold to operand.
2043 int flatIndex = 0;
2044 int stride = 1;
2045 for (int i = rank - 1; i >= 0; --i) {
2046 flatIndex += indices[i] * stride;
2047 stride *= vecType.getDimSize(i);
2048 }
2049 return fromElementsOp.getElements()[flatIndex];
2050}
2051
2052/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2053/// then fold it.
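/// Example (illustrative, hypothetical IR): with %c1 = arith.constant 1 : index,
///   vector.extract %v[%c1, %idx] : f32 from vector<4x4xf32>
/// is folded in place to
///   vector.extract %v[1, %idx] : f32 from vector<4x4xf32>
/// while %idx, which is not constant, remains a dynamic position operand.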
2054template <typename OpType, typename AdaptorType>
2055static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2056 SmallVectorImpl<Value> &operands) {
2057 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2058 OperandRange dynamicPosition = op.getDynamicPosition();
2059 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2060 ArrayRef<int64_t> vectorShape;
2061 if constexpr (std::is_same_v<OpType, ExtractOp>)
2062 vectorShape = op.getSourceVectorType().getShape();
2063 else
2064 vectorShape = op.getDestVectorType().getShape();
2065
2066 // If there are no dynamic position operands, there is nothing to fold.
2067 if (!dynamicPosition.size())
2068 return {};
2069
2070 // `index` is used to iterate over the `dynamicPosition`.
2071 unsigned index = 0;
2072
2073 // `opChange` is a flag. If it is true, it means to update `op` in place.
2074 bool opChange = false;
2075 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2076 if (ShapedType::isStatic(staticPosition[i]))
2077 continue;
2078 Attribute positionAttr = dynamicPositionAttr[index];
2079 Value position = dynamicPosition[index++];
2080 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2081 int64_t value = attr.getInt();
2082 // Do not fold if the value is out of bounds (-1 signifies a poison
2083 // value rather than OOB index).
2084 if (value >= -1 && value < vectorShape[i]) {
2085 staticPosition[i] = attr.getInt();
2086 opChange = true;
2087 continue;
2088 }
2089 }
2090 operands.push_back(position);
2091 }
2092
2093 if (opChange) {
2094 op.setStaticPosition(staticPosition);
2095 op.getOperation()->setOperands(operands);
2096 // Return the original result to indicate an in-place folding happened.
2097 return op.getResult();
2098 }
2099 return {};
2100}
2101
2102 /// Fold an insert or extract operation into a poison value when a poison index
2103/// is found at any dimension of the static position.
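/// Example (illustrative): with kPoisonIndex encoding a poison position,
///   %e = vector.extract %v[-1] : f32 from vector<4xf32>
/// folds to a ub.poison value.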
2104 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2105 ArrayRef<int64_t> staticPos,
2106 int64_t poisonVal) {
2107 if (!is_contained(staticPos, poisonVal))
2108 return {};
2109
2110 return ub::PoisonAttr::get(context);
2111}
2112
2113 /// Fold a vector extract whose source is a poison value.
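/// Example (illustrative, hypothetical IR):
///   %p = ub.poison : vector<4xf32>
///   %e = vector.extract %p[2] : f32 from vector<4xf32>
/// folds to poison.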
2114 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2115 if (matchPattern(srcAttr, ub::m_Poison()))
2116 return srcAttr;
2117
2118 return {};
2119}
2120
2121 /// Fold a vector extract that extracts from a DenseElementsAttr source.
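/// Example (illustrative, hypothetical IR; assumes a static position and a
/// non-scalable source):
///   %cst = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
///   %e = vector.extract %cst[1] : vector<2xi32> from vector<2x2xi32>
/// folds to the constant dense<[2, 3]> : vector<2xi32>.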
2122 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2123 Attribute srcAttr) {
2124 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2125 if (!denseAttr) {
2126 return {};
2127 }
2128
2129 if (denseAttr.isSplat()) {
2130 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2131 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2132 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2133 return newAttr;
2134 }
2135
2136 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2137 if (vecTy.isScalable())
2138 return {};
2139
2140 if (extractOp.hasDynamicPosition()) {
2141 return {};
2142 }
2143
2144 // Materializing subsets of a large constant array can generally lead to an
2145 // explosion in IR size because of the different combinations of subsets that
2146 // can exist. However, vector.extract is a restricted form of subset
2147 // extraction where you can only extract non-overlapping (or identical)
2148 // subsets for a given rank of the subset. Because of this property, the IR
2149 // size can increase by at most `rank * size(array)` from a single constant
2150 // array being extracted by multiple extracts.
2151
2152 // Calculate the linearized position of the continuous chunk of elements to
2153 // extract.
2154 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2155 copy(extractOp.getStaticPosition(), completePositions.begin());
2156 int64_t startPos =
2157 linearize(completePositions, computeStrides(vecTy.getShape()));
2158 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2159
2160 TypedAttr newAttr;
2161 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2162 SmallVector<Attribute> elementValues(
2163 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2164 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2165 } else {
2166 newAttr = *denseValuesBegin;
2167 }
2168
2169 return newAttr;
2170}
2171
2172OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2173 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2174 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2175 // mismatch).
2176 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2177 return getSource();
2178 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2179 return res;
2180 // Fold `arith.constant` indices into the `vector.extract` operation.
2181 // Do not stop here as this fold may enable subsequent folds that require
2182 // constant indices.
2183 SmallVector<Value> operands = {getSource()};
2184 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2185
2186 if (auto res = foldPoisonIndexInsertExtractOp(
2187 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2188 return res;
2189 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2190 return res;
2191 if (succeeded(foldExtractOpFromExtractChain(*this)))
2192 return getResult();
2193 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2194 return res;
2195 if (auto res = foldExtractFromBroadcast(*this))
2196 return res;
2197 if (auto res = foldExtractFromShuffle(*this))
2198 return res;
2199 if (auto res = foldExtractFromShapeCast(*this))
2200 return res;
2201 if (auto val = foldExtractFromExtractStrided(*this))
2202 return val;
2203 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2204 return val;
2205 if (auto val = foldScalarExtractFromFromElements(*this))
2206 return val;
2207
2208 return inplaceFolded;
2209}
2210
2211namespace {
2212
2213 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
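// Example (illustrative, hypothetical IR):
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// rewrites to the smaller broadcast:
//   %e = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>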
2214class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2215public:
2216 using Base::Base;
2217
2218 LogicalResult matchAndRewrite(ExtractOp extractOp,
2219 PatternRewriter &rewriter) const override {
2220
2221 Operation *defOp = extractOp.getSource().getDefiningOp();
2222 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2223 if (!defOp || !isBroadcastLike(defOp) || !outType)
2224 return failure();
2225
2226 Value source = defOp->getOperand(0);
2227 if (isBroadcastableTo(source.getType(), outType) !=
2228 BroadcastableToResult::Success)
2229 return failure();
2230
2231 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2232 return success();
2233 }
2234};
2235
2236 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
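// Example (illustrative, hypothetical IR; assumes %c2 and %c3 are
// arith.constant index values with %c2 = 2):
//   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// rewrites to
//   %e = vector.create_mask %c3 : vector<8xi1>
// because row 1 lies within the first %c2 masked rows.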
2237class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2238public:
2239 using Base::Base;
2240
2241 LogicalResult matchAndRewrite(ExtractOp extractOp,
2242 PatternRewriter &rewriter) const override {
2243 auto createMaskOp =
2244 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2245 if (!createMaskOp)
2246 return failure();
2247
2248 VectorType extractedMaskType =
2249 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2250
2251 if (!extractedMaskType)
2252 return failure();
2253
2254 auto maskOperands = createMaskOp.getOperands();
2255 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2256 VectorType maskType = createMaskOp.getVectorType();
2257
2258 bool containsUnknownDims = false;
2259 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2260
2261 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2262 dimIdx++) {
2263 int64_t pos = extractOpPos[dimIdx];
2264 Value operand = maskOperands[dimIdx];
2265 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2266 if (!constantOp) {
2267 // Bounds of this dim unknown.
2268 containsUnknownDims = true;
2269 continue;
2270 }
2271
2272 int64_t createMaskBound =
2273 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2274
2275 if (pos != ShapedType::kDynamic) {
2276 // If any position is outside the range from the `create_mask`, then the
2277 // extracted mask will be all-false.
2278 allFalse |= pos >= createMaskBound;
2279 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2280 // This dim is not all-true and since this is a dynamic index we don't
2281 // know if the extraction is within the true or false region.
2282 // Note: Zero dims have already been handled via getMaskFormat().
2283 containsUnknownDims = true;
2284 }
2285 }
2286
2287 if (allFalse) {
2288 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2289 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2290 } else if (!containsUnknownDims) {
2291 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2292 extractOp, extractedMaskType,
2293 maskOperands.drop_front(extractOpPos.size()));
2294 } else {
2295 return failure();
2296 }
2297 return success();
2298 }
2299};
2300
2301 // Pattern to rewrite an ExtractOp(ConstantMask) -> ConstantMask.
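// Example (illustrative, hypothetical IR):
//   %m = vector.constant_mask [2, 3] : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// rewrites to
//   %e = vector.constant_mask [3] : vector<8xi1>
// whereas extracting at a position >= 2 would produce an all-false constant.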
2302class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
2303public:
2304 using Base::Base;
2305
2306 LogicalResult matchAndRewrite(ExtractOp extractOp,
2307 PatternRewriter &rewriter) const override {
2308 auto constantMaskOp =
2309 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2310 if (!constantMaskOp)
2311 return failure();
2312
2313 Type resultType = extractOp.getResult().getType();
2314 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2315
2316 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2317 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2318
2319 VectorType maskType = constantMaskOp.getVectorType();
2320
2321 // Check if any extracted position is outside the mask bounds.
2322 for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2323 int64_t pos = extractOpPos[dimIdx];
2324 if (pos == ShapedType::kDynamic) {
2325 // If the dim is all-true, a dynamic index is fine — any position
2326 // is within the masked region.
2327 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2328 continue;
2329 // Otherwise we don't know if the position is inside or outside of
2330 // the masked area, so bail out.
2331 return failure();
2332 }
2333
2334 // If the position is statically outside of the masked area, the result
2335 // will be all-false.
2336 if (pos >= maskDimSizes[dimIdx]) {
2337 if (extractedMaskType) {
2338 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2339 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2340 } else {
2341 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2342 extractOp, rewriter.getIntegerAttr(resultType, false));
2343 }
2344 return success();
2345 }
2346 }
2347
2348 // All positions are within the mask bounds.
2349 if (extractedMaskType) {
2350 // Vector result: the result is a constant_mask with the remaining
2351 // dimensions.
2352 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
2353 extractOp, extractedMaskType,
2354 maskDimSizes.drop_front(extractOpPos.size()));
2355 } else {
2356 // Scalar result: all positions are within the masked region, so the
2357 // result is true.
2358 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2359 extractOp, rewriter.getIntegerAttr(resultType, true));
2360 }
2361 return success();
2362 }
2363};
2364
2365// Folds extract(shape_cast(..)) into shape_cast when the total element count
2366// does not change.
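// Example (illustrative, hypothetical IR):
//   %sc = vector.shape_cast %src : vector<6xf32> to vector<1x2x3xf32>
//   %e = vector.extract %sc[0] : vector<2x3xf32> from vector<1x2x3xf32>
// rewrites to
//   %e = vector.shape_cast %src : vector<6xf32> to vector<2x3xf32>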
2367LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2368 PatternRewriter &rewriter) {
2369 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2370 if (!castOp)
2371 return failure();
2372
2373 VectorType sourceType = castOp.getSourceVectorType();
2374 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2375 if (!targetType)
2376 return failure();
2377
2378 if (sourceType.getNumElements() != targetType.getNumElements())
2379 return failure();
2380
2381 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2382 castOp.getSource());
2383 return success();
2384}
2385
2386/// Try to canonicalize the extraction of a subvector from a vector defined by
2387/// vector.from_elements. E.g.:
2388///
2389/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2390/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2391/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2392LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2393 PatternRewriter &rewriter) {
2394 // Dynamic positions are not supported.
2395 if (extractOp.hasDynamicPosition())
2396 return failure();
2397
2398 // Scalar extracts are handled by the folder.
2399 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2400 if (!resultType)
2401 return failure();
2402
2403 // Look for extracts from a from_elements op.
2404 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2405 if (!fromElementsOp)
2406 return failure();
2407 VectorType inputType = fromElementsOp.getType();
2408
2409 // Scalable vectors are not supported.
2410 if (resultType.isScalable() || inputType.isScalable())
2411 return failure();
2412
2413 // Compute the position of first extracted element and flatten/linearize the
2414 // position.
2415 SmallVector<int64_t> firstElementPos =
2416 llvm::to_vector(extractOp.getStaticPosition());
2417 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2418 int flatIndex = 0;
2419 int stride = 1;
2420 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2421 flatIndex += firstElementPos[i] * stride;
2422 stride *= inputType.getDimSize(i);
2423 }
2424
2425 // Replace the op with a smaller from_elements op.
2426 rewriter.replaceOpWithNewOp<FromElementsOp>(
2427 extractOp, resultType,
2428 fromElementsOp.getElements().slice(flatIndex,
2429 resultType.getNumElements()));
2430 return success();
2431}
2432
2433/// Replace `vector.extract` with `vector.shape_cast`.
2434///
2435/// BEFORE:
2436/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2437/// AFTER:
2438/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2439///
2440/// The canonical form of vector operations that reshape vectors is shape_cast.
2441struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2442 using Base::Base;
2443 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2444 PatternRewriter &rewriter) const override {
2445 VectorType sourceType = extractOp.getSourceVectorType();
2446 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2447 if (!outType)
2448 return failure();
2449
2450 if (sourceType.getNumElements() != outType.getNumElements())
2451 return rewriter.notifyMatchFailure(
2452 extractOp, "extract to vector with fewer elements");
2453
2454 // Negative values in `position` mean that the extracted value is poison.
2455 // There is a vector.extract folder for this.
2456 if (llvm::any_of(extractOp.getMixedPosition(),
2457 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2458 return rewriter.notifyMatchFailure(extractOp,
2459 "leaving for extract poison folder");
2460
2461 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2462 extractOp.getSource());
2463
2464 return success();
2465 }
2466};
2467
2468/// Folds vector.extract from vector.insert when the extract position is a
2469/// prefix of the insert position and the remaining (un-indexed) dimensions
2470/// of the extracted sub-vector are all size 1. In that case the extracted
2471/// value is fully determined by the inserted value.
2472///
2473/// Examples:
2474/// %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
2475/// %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
2476/// folds to:
2477/// %ext = vector.broadcast %s : f32 to vector<1xf32>
2478///
2479/// %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
2480 /// %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
2481 /// folds to:
2482 /// %ext = vector.shape_cast %s : vector<1xf32> to vector<1x1xf32>
2483struct FoldExtractFromInsertUnitDim final
2484 : OpRewritePattern<vector::ExtractOp> {
2485 using Base::Base;
2486
2487 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2488 PatternRewriter &rewriter) const override {
2489 if (extractOp.hasDynamicPosition())
2490 return failure();
2491
2492 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2493 if (!insertOp || insertOp.hasDynamicPosition())
2494 return failure();
2495
2496 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2497 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2498
2499 // The extract position must be a strict prefix of the insert position.
2500 if (extractPos.size() >= insertPos.size() ||
2501 extractPos != insertPos.take_front(extractPos.size()))
2502 return failure();
2503
2504 // The remaining dimensions (those not indexed by the extract) must all
2505 // be size 1 in the source vector type. This guarantees that the inserted
2506 // value fully determines the extracted sub-vector.
2507 auto srcVecType = extractOp.getSourceVectorType();
2508 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2509 if (srcVecType.getDimSize(i) != 1)
2510 return failure();
2511
2512 Value inserted = insertOp.getValueToStore();
2513 Type extractedType = extractOp.getResult().getType();
2514 if (isa<VectorType>(inserted.getType())) {
2515 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
2516 inserted);
2517 } else {
2518 // The inserted value fully determines the extracted sub-vector; broadcast
2519 // it to the extracted type.
2520 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2521 extractOp, extractOp.getResult().getType(),
2522 insertOp.getValueToStore());
2523 }
2524 return success();
2525 }
2526};
2527
2528} // namespace
2529
2530void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2531 MLIRContext *context) {
2532 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2533 ExtractOpFromConstantMask, ExtractToShapeCast,
2534 FoldExtractFromInsertUnitDim>(context);
2535 results.add(foldExtractFromShapeCastToShapeCast);
2536 results.add(foldExtractFromFromElements);
2537}
2538
2539 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2540 SmallVectorImpl<int64_t> &results) {
2541 for (auto attr : arrayAttr)
2542 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2543}
2544
2545//===----------------------------------------------------------------------===//
2547 // FMAOp
2547//===----------------------------------------------------------------------===//
2548
2549std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2550 return llvm::to_vector<4>(getVectorType().getShape());
2551}
2552
2553//===----------------------------------------------------------------------===//
2554// ToElementsOp
2555//===----------------------------------------------------------------------===//
2556
2557/// Returns true if all the `operands` are defined by `defOp`.
2558/// Otherwise, returns false.
2559static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2560 if (operands.empty())
2561 return false;
2562
2563 return llvm::all_of(operands, [&](Value operand) {
2564 Operation *currentDef = operand.getDefiningOp();
2565 return currentDef == defOp;
2566 });
2567}
2568
2569/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2570/// (%e0, %e1, ...). For example:
2571///
2572/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2573/// %1:3 = vector.to_elements %0 : vector<3xf32>
2574/// user_op %1#0, %1#1, %1#2
2575///
2576/// becomes:
2577///
2578/// user_op %a, %b, %c
2579///
2580static LogicalResult
2581foldToElementsFromElements(ToElementsOp toElementsOp,
2582 SmallVectorImpl<OpFoldResult> &results) {
2583 auto fromElementsOp =
2584 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2585 if (!fromElementsOp)
2586 return failure();
2587
2588 llvm::append_range(results, fromElementsOp.getElements());
2589 return success();
2590}
2591
2592/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2593///
2594/// Example:
2595 /// %b = vector.broadcast %x : f32 to vector<3xf32>
2596/// %e:3 = vector.to_elements %b : vector<3xf32>
2597/// user_op %e#0, %e#1, %e#2
2598/// becomes:
2599/// user_op %x, %x, %x
2600///
2601/// The vector source case is handled by a canonicalization pattern.
2602static LogicalResult
2603foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2604 SmallVectorImpl<OpFoldResult> &results) {
2605 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2606 if (!bcastOp)
2607 return failure();
2608 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2609 if (isa<VectorType>(bcastOp.getSource().getType()))
2610 return failure();
2611
2612 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2613
2614 Value scalar = bcastOp.getSource();
2615 results.assign(resultVecType.getNumElements(), scalar);
2616 return success();
2617}
2618
2619LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2620 SmallVectorImpl<OpFoldResult> &results) {
2621 if (succeeded(foldToElementsFromElements(*this, results)))
2622 return success();
2623
2624 // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X)
2625 if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2626 setOperand(shapeCast.getSource());
2627 return success();
2628 }
2629
2630 return foldToElementsOfBroadcast(*this, results);
2631}
2632
2633LogicalResult
2634ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2635 ToElementsOp::Adaptor adaptor,
2636 SmallVectorImpl<Type> &inferredReturnTypes) {
2637 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2638 Type elType = vecType.getElementType();
2639 inferredReturnTypes.append(vecType.getNumElements(), elType);
2640 return success();
2641}
2642
2643/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2644/// vector.
2645/// - Build `vector.to_elements %v` and remap each destination element to the
2646/// corresponding source element using broadcast rules (match or 1 →
2647/// replicate).
2648///
2649/// Example:
2650/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2651/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2652/// becomes:
2653/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2654/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2655/// // %src_elems#1, %src_elems#0, %src_elems#1
2656struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2657 using Base::Base;
2658
2659 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2660 PatternRewriter &rewriter) const override {
2661 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2662 if (!bcastOp)
2663 return failure();
2664
2665 // Only handle broadcasts from a vector source here.
2666 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2667 if (!srcType)
2668 return failure();
2669
2670 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2671
2672 ArrayRef<int64_t> dstShape = dstType.getShape();
2673 ArrayRef<int64_t> srcShape = srcType.getShape();
2674
2675 int64_t dstRank = dstShape.size();
2676 int64_t srcRank = srcShape.size();
2677
2678 // Create elements for the broadcast source vector.
2679 auto srcElems = vector::ToElementsOp::create(
2680 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2681
2682 int64_t dstCount = llvm::product_of(dstShape);
2683
2684 SmallVector<Value> replacements;
2685 replacements.reserve(dstCount);
2686
2687 // For each element of the destination, determine which element of the
2688 // source should be used. We walk all destination positions using a single
2689 // counter, decode it into per-dimension indices, then build the matching
2690 // source position: use the same index where sizes match, and use 0 where
2691 // the source size is 1 (replication). This mapping is needed so we can
2692 // replace each result of to_elements with the corresponding element from
2693 // the broadcast source.
2694 // Inner-dimension stretch example:
2695 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2696 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2697 // becomes:
2698 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2699 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2700 // // %src_elems#1, %src_elems#0, %src_elems#1,
2701 // // %src_elems#2, %src_elems#3, %src_elems#2,
2702 // // %src_elems#3, %src_elems#2, %src_elems#3
2703
2704 // Row-major strides for the destination shape.
2705 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2706 // Row-major strides for the source shape.
2707 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2708 SmallVector<int64_t> dstIdx(dstRank);
2709 SmallVector<int64_t> srcIdx(srcRank);
2710 for (int64_t lin = 0; lin < dstCount; ++lin) {
2711 // Convert linear destination index to per-dimension indices.
2712 dstIdx = delinearize(lin, dstStrides);
2713 for (int64_t k = 0; k < srcRank; ++k)
2714 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2715 // Convert per-dimension source indices back to a linear index.
2716 int64_t srcLin = linearize(srcIdx, srcStrides);
2717 replacements.push_back(srcElems.getResult(srcLin));
2718 }
2719
2720 rewriter.replaceOp(toElementsOp, replacements);
2721 return success();
2722 }
2723};
2724
2725void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2726 MLIRContext *context) {
2727 results.add<ToElementsOfBroadcast>(context);
2728}
2729
2730//===----------------------------------------------------------------------===//
2731// FromElementsOp
2732//===----------------------------------------------------------------------===//
2733
2734/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2735///
2736/// Case #1: Input and output vectors are the same.
2737///
2738/// %0:3 = vector.to_elements %a : vector<3xf32>
2739/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2740/// user_op %1
2741///
2742/// becomes:
2743///
2744/// user_op %a
2745///
2746static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2747 OperandRange fromElemsOperands = fromElementsOp.getElements();
2748 if (fromElemsOperands.empty())
2749 return {};
2750
2751 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2752 if (!toElementsOp)
2753 return {};
2754
2755 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2756 return {};
2757
2758 // Case #1: Input and output vectors are the same. Forward the input vector.
2759 Value toElementsInput = toElementsOp.getSource();
2760 if (fromElementsOp.getType() == toElementsInput.getType() &&
2761 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2762 return toElementsInput;
2763 }
2764
2765 // TODO: Support cases with different input and output shapes and different
2766 // number of elements.
2767
2768 return {};
2769}
2770
2771/// Fold vector.from_elements to a constant when all operands are constants.
2772/// Example:
2773/// %c1 = arith.constant 1 : i32
2774/// %c2 = arith.constant 2 : i32
2775/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2776/// =>
2777/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2778///
2779static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2780 ArrayRef<Attribute> elements) {
2781 // Check for null or poison attributes before any processing.
2782 if (llvm::any_of(elements, [](Attribute attr) {
2783 return !attr || matchPattern(attr, ub::m_Poison());
2784 }))
2785 return {};
2786
2787 // DenseElementsAttr only supports int/index/float/complex types.
2788 auto destVecType = fromElementsOp.getDest().getType();
2789 auto destEltType = destVecType.getElementType();
2790 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2791 return {};
2792
2793 // Constant attributes might have a different type than the return type.
2794 // Convert them before creating the dense elements attribute.
2795 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2796 return convertNumericAttr(attr, destEltType);
2797 });
2798
2799 return DenseElementsAttr::get(destVecType, convertedElements);
2800}
2801
2802OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2803 if (auto res = foldFromElementsToElements(*this))
2804 return res;
2805 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2806 return res;
2807
2808 return {};
2809}
2810
2811/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2812/// same. Example:
2813/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2814/// =>
2815/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2816static LogicalResult
2817rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2818 PatternRewriter &rewriter) {
2819 if (!llvm::all_equal(fromElementsOp.getElements()))
2820 return failure();
2821 rewriter.replaceOpWithNewOp<BroadcastOp>(
2822 fromElementsOp, fromElementsOp.getType(),
2823 fromElementsOp.getElements().front());
2824 return success();
2825}
2826
2827/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2828/// on a single extract. Example:
2829/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2830/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2831/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2832///
2833/// becomes
2834/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2835/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2836///
2837/// The requirements for this to be valid are
2838///
2839/// i) The elements are extracted from the same vector (%source).
2840///
2841/// ii) The elements form a suffix of %source. Specifically, the number
2842/// of elements is the same as the product of the last N dimension sizes
2843/// of %source, for some N.
2844///
2845/// iii) The elements are extracted contiguously in ascending order.
2846
2847class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2848
2849 using Base::Base;
2850
2851 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2852 PatternRewriter &rewriter) const override {
2853
2854 // Handled by `rewriteFromElementsAsBroadcast`.
2855 if (fromElements.getType().getNumElements() == 1)
2856 return failure();
2857
2858 // The common source that all elements are extracted from, if one exists.
2859 TypedValue<VectorType> source;
2860 // The position of the combined extract operation, if one is created.
2861 ArrayRef<int64_t> combinedPosition;
2862 // The expected index of extraction of the current element in the loop, if
2863 // elements are extracted contiguously in ascending order.
2864 SmallVector<int64_t> expectedPosition;
2865
2866 for (auto [insertIndex, element] :
2867 llvm::enumerate(fromElements.getElements())) {
2868
2869 // Check that the element is from a vector.extract operation.
2870 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2871 if (!extractOp) {
2872 return rewriter.notifyMatchFailure(fromElements,
2873 "element not from vector.extract");
2874 }
2875
2876 // Check condition (i) by checking that all elements have the same source
2877 // as the first element.
2878 if (insertIndex == 0) {
2879 source = extractOp.getSource();
2880 } else if (extractOp.getSource() != source) {
2881 return rewriter.notifyMatchFailure(fromElements,
2882 "element from different vector");
2883 }
2884
2885 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2886 int64_t rank = position.size();
2887 assert(rank == source.getType().getRank() &&
2888 "scalar extract must have full rank position");
2889
2890 // Check condition (ii) by checking that the position that the first
2891 // element is extracted from has sufficient trailing 0s. For example, in
2892 //
2893 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2894 // [...]
2895 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2896 //
2897 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2898 // elements, which is the number of elements of %elms, so this is valid.
2899 if (insertIndex == 0) {
2900 const int64_t numElms = fromElements.getType().getNumElements();
2901 int64_t numSuffixElms = 1;
2902 int64_t index = rank;
2903 while (index > 0 && position[index - 1] == 0 &&
2904 numSuffixElms < numElms) {
2905 numSuffixElms *= source.getType().getDimSize(index - 1);
2906 --index;
2907 }
2908 if (numSuffixElms != numElms) {
2909 return rewriter.notifyMatchFailure(
2910 fromElements, "elements do not form a suffix of source");
2911 }
2912 expectedPosition = llvm::to_vector(position);
2913 combinedPosition = position.drop_back(rank - index);
2914 }
2915
2916 // Check condition (iii).
2917 else if (expectedPosition != position) {
2918 return rewriter.notifyMatchFailure(
2919 fromElements, "elements not in ascending order (static order)");
2920 }
2921 increment(expectedPosition, source.getType().getShape());
2922 }
2923
2924 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2925 fromElements.getLoc(), source, combinedPosition);
2926
2927 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2928 fromElements, fromElements.getType(), extracted);
2929
2930 return success();
2931 }
2932
2933 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2934 static void increment(MutableArrayRef<int64_t> indices,
2935 ArrayRef<int64_t> shape) {
2936 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2937 indices[dim] += 1;
2938 if (indices[dim] < shape[dim])
2939 break;
2940 indices[dim] = 0;
2941 }
2942 }
2943};
2944
2945void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2946 MLIRContext *context) {
2947 results.add(rewriteFromElementsAsBroadcast);
2948 results.add<FromElementsToShapeCast>(context);
2949}
2950
2951//===----------------------------------------------------------------------===//
2952// BroadcastOp
2953//===----------------------------------------------------------------------===//
2954
2955void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2956 SetIntRangeFn setResultRanges) {
2957 setResultRanges(getResult(), argRanges.front());
2958}
2959
2960std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2961 return llvm::to_vector<4>(getResultVectorType().getShape());
2962}
2963
2964/// Return the dimensions of the result vector that were formerly ones in the
2965 /// source vector and thus correspond to "dim-1" broadcasting.
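/// Example (illustrative): for srcShape = (1, 4, 1) and dstShape =
/// (2, 3, 4, 5), the result is {1, 3}; dimension 0 is a leading rank
/// expansion rather than "dim-1" broadcasting.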
2966static llvm::SetVector<int64_t>
2968 ArrayRef<int64_t> dstShape) {
2969 int64_t rankDiff = dstShape.size() - srcShape.size();
2970 int64_t dstDim = rankDiff;
2971 llvm::SetVector<int64_t> res;
2972 for (auto [s1, s2] :
2973 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2974 if (s1 != s2) {
2975 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2976 res.insert(dstDim);
2977 }
2978 ++dstDim;
2979 }
2980 return res;
2981}
2982
2983llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2984 // A scalar broadcast involves no unit-dim broadcasting.
2985 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2986 if (!srcVectorType)
2987 return {};
2988 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2989 getResultVectorType().getShape());
2990}
2991
2992/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2993/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2994/// This requires (and asserts) that the broadcast is free of "dim-1"
2995/// broadcasting.
2996/// Since vector.broadcast only allows expanding leading dimensions, an extra
2997/// vector.transpose may be inserted to make the broadcast possible.
2998/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2999/// the helper will assert. This means:
3000/// 1. `dstShape` must not be empty.
3001/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
3002 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
3003 /// must match the `value` shape.
3004Value BroadcastOp::createOrFoldBroadcastOp(
3005 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
3006 const llvm::SetVector<int64_t> &broadcastedDims) {
3007 assert(!dstShape.empty() && "unexpected empty dst shape");
3008
3009 // Step 1. Well-formedness check.
3010 SmallVector<int64_t> checkShape;
3011 for (int i = 0, e = dstShape.size(); i < e; ++i) {
3012 if (broadcastedDims.contains(i))
3013 continue;
3014 checkShape.push_back(dstShape[i]);
3015 }
3016 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3017 "ill-formed broadcastedDims contains values not confined to "
3018 "destVectorShape");
3019
3020 Location loc = value.getLoc();
3021 Type elementType = getElementTypeOrSelf(value.getType());
3022 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
3023 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3024
3025 // Step 2. If scalar -> dstShape broadcast, just do it.
3026 if (!srcVectorType) {
3027 assert(checkShape.empty() &&
3028 "ill-formed createOrFoldBroadcastOp arguments");
3029 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3030 }
3031
3032 assert(srcVectorType.getShape().equals(checkShape) &&
3033 "ill-formed createOrFoldBroadcastOp arguments");
3034
3035 // Step 3. Since vector.broadcast only allows creating leading dims,
3036 // vector -> dstShape broadcast may require a transpose.
3037 // Traverse the dims in order and construct:
3038 // 1. The leading entries of the broadcastShape that is guaranteed to be
3039 // achievable by a simple broadcast.
3040 // 2. The induced permutation for the subsequent vector.transpose that will
3041 // bring us from `broadcastShape` back to the desired `dstShape`.
3042 // If the induced permutation is not the identity, create a vector.transpose.
3043 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3044 broadcastShape.reserve(dstShape.size());
3045 // Consider the example:
3046 // srcShape = 2x4
3047 // dstShape = 1x2x3x4x5
3048 // broadcastedDims = [0, 2, 4]
3049 //
3050 // We want to build:
3051 // broadcastShape = 1x3x5x2x4
3052 // permutation = [0, 2, 4, 1, 3]
3053 // ---V--- -----V-----
3054 // leading broadcast part src shape part
3055 //
3056 // Note that the trailing dims of broadcastShape are exactly the srcShape
3057 // by construction.
3058 // nextSrcShapeDim is used to keep track of where in the permutation the
3059 // "src shape part" occurs.
3060 int64_t nextSrcShapeDim = broadcastedDims.size();
3061 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3062 if (broadcastedDims.contains(i)) {
3063 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
3064 // bring it to the head of the broadcastShape.
3065 // It will need to be permuted back from `broadcastShape.size() - 1` into
3066 // position `i`.
3067 broadcastShape.push_back(dstShape[i]);
3068 permutation[i] = broadcastShape.size() - 1;
3069 } else {
3070 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
3071 // shape and needs to be permuted into position `i`.
3072 // Don't touch `broadcastShape` here, the whole srcShape will be
3073 // appended after.
3074 permutation[i] = nextSrcShapeDim++;
3075 }
3076 }
3077 // 3.c. Append the srcShape.
3078 llvm::append_range(broadcastShape, srcVectorType.getShape());
3079
3080 // Ensure there are no "dim-1" broadcasts.
3081 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
3082 .empty() &&
3083 "unexpected \"dim-1\" broadcast");
3084
3085 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3086 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
3087 vector::BroadcastableToResult::Success &&
3088 "must be broadcastable");
3089 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3090 // Step 4. If we find any dimension that indeed needs to be permuted,
3091 // immediately return a new vector.transpose.
3092 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3093 if (permutation[i] != i)
3094 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3095 // Otherwise return res.
3096 return res;
3097}
3098
3099 BroadcastableToResult mlir::vector::isBroadcastableTo(
3100 Type srcType, VectorType dstVectorType,
3101 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3102 // Broadcast scalar to vector of the same element type.
3103 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3104 srcType == getElementTypeOrSelf(dstVectorType))
3105 return BroadcastableToResult::Success;
3106 // From now on, only vectors broadcast.
3107 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3108 if (!srcVectorType)
3109 return BroadcastableToResult::SourceTypeNotAVector;
3110
3111 int64_t srcRank = srcVectorType.getRank();
3112 int64_t dstRank = dstVectorType.getRank();
3113 if (srcRank > dstRank)
3114 return BroadcastableToResult::SourceRankHigher;
3115 // Source has an exact match or singleton value for all trailing dimensions
3116 // (all leading dimensions are simply duplicated).
3117 int64_t lead = dstRank - srcRank;
3118 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3119 // Have mismatching dims (in the sense of vector.broadcast semantics) been
3120 // encountered?
3121 bool foundMismatchingDims = false;
3122
3123 // Check fixed-width dims.
3124 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3125 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3126 if (srcDim != 1 && srcDim != dstDim)
3127 foundMismatchingDims = true;
3128
3129 // Check scalable flags.
3130 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3131 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3132 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3133 // 1 -> [N] is fine, everything else should be rejected when mixing
3134 // fixed-width and scalable dims
3135 (srcDimScalableFlag != dstDimScalableFlag &&
3136 (srcDim != 1 || srcDimScalableFlag)))
3137 foundMismatchingDims = true;
3138
3139 if (foundMismatchingDims) {
3140 if (mismatchingDims != nullptr) {
3141 mismatchingDims->first.dim = srcDim;
3142 mismatchingDims->first.isScalable = srcDimScalableFlag;
3143
3144 mismatchingDims->second.dim = dstDim;
3145 mismatchingDims->second.isScalable = dstDimScalableFlag;
3146 }
3147 return BroadcastableToResult::DimensionMismatch;
3148 }
3149 }
3150
3151 return BroadcastableToResult::Success;
3152 }
3153
3154LogicalResult BroadcastOp::verify() {
3155 std::pair<VectorDim, VectorDim> mismatchingDims;
3156 BroadcastableToResult res = isBroadcastableTo(
3157 getSourceType(), getResultVectorType(), &mismatchingDims);
3158 if (res == BroadcastableToResult::Success)
3159 return success();
3160 if (res == BroadcastableToResult::SourceRankHigher)
3161 return emitOpError("source rank higher than destination rank");
3162 if (res == BroadcastableToResult::DimensionMismatch) {
3163 return emitOpError("dimension mismatch (")
3164 << (mismatchingDims.first.isScalable ? "[" : "")
3165 << mismatchingDims.first.dim
3166 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
3167 << (mismatchingDims.second.isScalable ? "[" : "")
3168 << mismatchingDims.second.dim
3169 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
3170 }
3172 return emitOpError("source type is not a vector");
3173 llvm_unreachable("unexpected vector.broadcast op error");
3174}
3175
3176// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3177// with broadcast's result type and shape_cast only adds or removes ones in the
3178// leading dimensions.
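// Example:
//   %0 = vector.shape_cast %x : vector<4xi8> to vector<1x1x4xi8>
//   %1 = vector.broadcast %0 : vector<1x1x4xi8> to vector<2x3x4xi8>
// folds to
//   %1 = vector.broadcast %x : vector<4xi8> to vector<2x3x4xi8>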
3179static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3180 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3181 if (!srcShapeCast)
3182 return failure();
3183
3184 VectorType srcType = srcShapeCast.getSourceVectorType();
3185 VectorType destType = broadcastOp.getResultVectorType();
3186 // Check type compatibility.
3187 if (vector::isBroadcastableTo(srcType, destType) !=
3188 vector::BroadcastableToResult::Success)
3189 return failure();
3190
3191 ArrayRef<int64_t> srcShape = srcType.getShape();
3192 ArrayRef<int64_t> shapecastShape =
3193 srcShapeCast.getResultVectorType().getShape();
3194 // Trailing dimensions should be the same if shape_cast only alters the
3195 // leading dimensions.
3196 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3197 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3198 shapecastShape.take_back(numTrailingDims)))
3199 return failure();
3200
3201 assert(all_of(srcShape.drop_back(numTrailingDims),
3202 [](int64_t E) { return E == 1; }) &&
3203 all_of(shapecastShape.drop_back(numTrailingDims),
3204 [](int64_t E) { return E == 1; }) &&
3205 "ill-formed shape_cast");
3206
3207 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3208 return success();
3209}
3210
3211OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3212 if (getSourceType() == getResultVectorType())
3213 return getSource();
3214 if (succeeded(foldBroadcastOfShapeCast(*this)))
3215 return getResult();
3216
3217 if (!adaptor.getSource())
3218 return {};
3219 auto vectorType = getResultVectorType();
3220 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3221 if (vectorType.getElementType() != attr.getType())
3222 return {};
3223 return DenseElementsAttr::get(vectorType, attr);
3224 }
3225 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3226 if (vectorType.getElementType() != attr.getType())
3227 return {};
3228 return DenseElementsAttr::get(vectorType, attr);
3229 }
3230 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3231 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3232 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
3233 return ub::PoisonAttr::get(getContext());
3234 return {};
3235}
3236
3237namespace {
3238
3239// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3240struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3241 using Base::Base;
3242
3243 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3244 PatternRewriter &rewriter) const override {
3245 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3246 if (!srcBroadcast)
3247 return failure();
3248 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3249 broadcastOp.getResultVectorType(),
3250 srcBroadcast.getSource());
3251 return success();
3252 }
3253};
3254
3255/// Replace `vector.broadcast` with `vector.shape_cast`.
3256///
3257/// BEFORE:
3258/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3259/// AFTER:
3260/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3261///
3262/// The canonical form of vector operations that reshape vectors is shape_cast.
3263struct BroadcastToShapeCast final
3264 : public OpRewritePattern<vector::BroadcastOp> {
3265 using Base::Base;
3266 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3267 PatternRewriter &rewriter) const override {
3268
3269 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3270 if (!sourceType) {
3271 return rewriter.notifyMatchFailure(
3272 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3273 }
3274
3275 VectorType outType = broadcast.getType();
3276 if (sourceType.getNumElements() != outType.getNumElements()) {
3277 return rewriter.notifyMatchFailure(
3278 broadcast, "broadcast to a greater number of elements");
3279 }
3280
3281 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3282 broadcast.getSource());
3283 return success();
3284 }
3285};
3286} // namespace
3287
3288void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3289 MLIRContext *context) {
3290 results.add<BroadcastFolder, BroadcastToShapeCast>(context);
3291}
3292
3293//===----------------------------------------------------------------------===//
3294// ShuffleOp
3295//===----------------------------------------------------------------------===//
3296
3297LogicalResult ShuffleOp::verify() {
3298 VectorType resultType = getResultVectorType();
3299 VectorType v1Type = getV1VectorType();
3300 VectorType v2Type = getV2VectorType();
3301 // Verify ranks.
3302 int64_t resRank = resultType.getRank();
3303 int64_t v1Rank = v1Type.getRank();
3304 int64_t v2Rank = v2Type.getRank();
3305 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3306 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3307 if (!wellFormed0DCase && !wellFormedNDCase)
3308 return emitOpError("rank mismatch");
3309
3310 // Verify all but leading dimension sizes.
3311 for (int64_t r = 1; r < v1Rank; ++r) {
3312 int64_t resDim = resultType.getDimSize(r);
3313 int64_t v1Dim = v1Type.getDimSize(r);
3314 int64_t v2Dim = v2Type.getDimSize(r);
3315 if (resDim != v1Dim || v1Dim != v2Dim)
3316 return emitOpError("dimension mismatch");
3317 }
3318 // Verify mask length.
3319 ArrayRef<int64_t> mask = getMask();
3320 int64_t maskLength = mask.size();
3321 if (maskLength <= 0)
3322 return emitOpError("invalid mask length");
3323 if (maskLength != resultType.getDimSize(0))
3324 return emitOpError("mask length mismatch");
3325 // Verify all indices.
3326 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3327 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3328 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3329 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3330 return emitOpError("mask index #") << (idx + 1) << " out of range";
3331 }
3332 return success();
3333}
3334
3335LogicalResult
3336ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3337 ShuffleOp::Adaptor adaptor,
3338 SmallVectorImpl<Type> &inferredReturnTypes) {
3339 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3340 if (!v1Type) {
3341 return emitOptionalError(loc, "expected vector type");
3342 }
3343 auto v1Rank = v1Type.getRank();
3344 // Construct resulting type: leading dimension matches mask
3345 // length, all trailing dimensions match the operands.
3346 SmallVector<int64_t, 4> shape;
3347 shape.reserve(v1Rank);
3348 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3349 // In the 0-D case there is no trailing shape to append.
3350 if (v1Rank > 0)
3351 llvm::append_range(shape, v1Type.getShape().drop_front());
3352 inferredReturnTypes.push_back(
3353 VectorType::get(shape, v1Type.getElementType()));
3354 return success();
3355}
3356
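/// Returns true iff `idxArr` is exactly [begin, begin + 1, ...,
/// begin + width - 1]. For example, isStepIndexArray([4, 5], /*begin=*/4,
/// /*width=*/2) is true.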
3357template <typename T>
3358static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3359 T expected = begin;
3360 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3361 return value == expected++;
3362 });
3363}
3364
3365/// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3366 /// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3367 static Value foldShuffleIdentityMask(ShuffleOp op) {
3368 auto v1Type = op.getV1VectorType();
3369 auto v2Type = op.getV2VectorType();
3370 auto mask = op.getMask();
3371 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3372 return op.getV1();
3373 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3374 return op.getV2();
3375 return {};
3376}
3377
3378/// If a shuffle operand is poison, replace all mask indices that reference it
3379/// with kPoisonIndex. This is an in-place fold.
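/// Example: if %v1 (leading size 2) is poison, mask [0, 2, 1] becomes
/// [-1, 2, -1].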
3380 static Value foldShufflePoisonOperandToMask(ShuffleOp op) {
3381 bool isV1Poison = matchPattern(op.getV1(), ub::m_Poison());
3382 bool isV2Poison = matchPattern(op.getV2(), ub::m_Poison());
3383 if (!isV1Poison && !isV2Poison)
3384 return {};
3385
3386 int64_t v1Size = op.getV1VectorType().getDimSize(0);
3387 bool changed = false;
3388 SmallVector<int64_t> newMask = llvm::to_vector(op.getMask());
3389 for (int64_t &idx : newMask) {
3390 if (idx == ShuffleOp::kPoisonIndex)
3391 continue;
3392 if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
3393 idx = ShuffleOp::kPoisonIndex;
3394 changed = true;
3395 }
3396 }
3397
3398 if (!changed)
3399 return {};
3400
3401 op.setMask(newMask);
3402 return op.getResult();
3403}
3404
3405/// Fold shuffle poison, poison -> poison.
3406 static Attribute foldShufflePoisonInputs(MLIRContext *context,
3407 Attribute v1Attr,
3408 Attribute v2Attr) {
3409 if (matchPattern(v1Attr, ub::m_Poison()) &&
3410 matchPattern(v2Attr, ub::m_Poison()))
3411 return ub::PoisonAttr::get(context);
3412 return {};
3413}
3414
3415 /// Fold a shuffle of constant 1-D inputs by evaluating the mask.
3416 static Attribute foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr,
3417 Attribute v2Attr) {
3418 auto v1Type = op.getV1VectorType();
3419 if (v1Type.getRank() != 1)
3420 return {};
3421
3422 bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
3423 bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
3424
3425 // Poison input attributes need special handling as they are not
3426 // DenseElementsAttr. If an index is poison, we select the first element of
3427 // the first non-poison input.
3428 SmallVector<Attribute> v1Elements, v2Elements;
3429 Attribute poisonElement;
3430 if (!isV2Poison) {
3431 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3432 if (!v2DenseAttr)
3433 return {};
3434 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3435 poisonElement = v2Elements[0];
3436 }
3437 if (!isV1Poison) {
3438 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3439 if (!v1DenseAttr)
3440 return {};
3441 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3442 poisonElement = v1Elements[0];
3443 }
3444
3445 ArrayRef<int64_t> mask = op.getMask();
3446 SmallVector<Attribute> results;
3447 int64_t v1Size = v1Type.getDimSize(0);
3448 for (int64_t maskIdx : mask) {
3449 Attribute indexedElm;
3450 // TODO: Return a partial poison vector when supported by the UB dialect.
3451 if (maskIdx == ShuffleOp::kPoisonIndex) {
3452 indexedElm = poisonElement;
3453 } else {
3454 if (maskIdx < v1Size)
3455 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3456 else
3457 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3458 }
3459
3460 results.push_back(indexedElm);
3461 }
3462
3463 return DenseElementsAttr::get(op.getResultVectorType(), results);
3464}
3465
3466OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3467 auto v1Type = getV1VectorType();
3468
3469 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3470 "Vector shuffle does not support scalable vectors");
3471
3472 // For consistency: the 0-D shuffle return type is 1-D, so this cannot be a
3473 // fold and must instead be a canonicalization into a vector.broadcast.
3474 if (v1Type.getRank() == 0)
3475 return {};
3476
3477 if (auto res = foldShuffleIdentityMask(*this))
3478 return res;
3479 if (auto res = foldShufflePoisonOperandToMask(*this))
3480 return res;
3481
3482 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3483 if (!v1Attr || !v2Attr)
3484 return {};
3485
3486 if (auto res = foldShufflePoisonInputs(getContext(), v1Attr, v2Attr))
3487 return res;
3488 if (auto res = foldShuffleConstantInputs(*this, v1Attr, v2Attr))
3489 return res;
3490
3491 return {};
3492}
3493
3494namespace {
3495
3496// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3497// to a broadcast.
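// Example:
//   %r = vector.shuffle %a, %b [0] : vector<i32>, vector<i32>
// becomes
//   %r = vector.broadcast %a : vector<i32> to vector<1xi32>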
3498struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3499 using Base::Base;
3500
3501 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3502 PatternRewriter &rewriter) const override {
3503 VectorType v1VectorType = shuffleOp.getV1VectorType();
3504 ArrayRef<int64_t> mask = shuffleOp.getMask();
3505 if (v1VectorType.getRank() > 0)
3506 return failure();
3507 if (mask.size() != 1)
3508 return failure();
3509 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3510 if (mask[0] == 0)
3511 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3512 shuffleOp.getV1());
3513 else
3514 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3515 shuffleOp.getV2());
3516 return success();
3517 }
3518};
3519
3520/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3521/// vector.broadcast with a scalar operand, return the scalar value that is
3522/// splatted. Otherwise return null.
3523///
3524/// Example:
3525///
3526/// scalar_source --> vector.broadcast --> value - return scalar_source
3527static Value getScalarSplatSource(Value value) {
3528 // Block arguments have no defining op:
3529 Operation *defOp = value.getDefiningOp();
3530 if (!defOp)
3531 return {};
3532
3533 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3534
3535 // Not broadcast (and not splat):
3536 if (!broadcast)
3537 return {};
3538
3539 // Broadcast of a vector:
3540 if (isa<VectorType>(broadcast.getSourceType()))
3541 return {};
3542
3543 // Broadcast of a scalar:
3544 return broadcast.getSource();
3545}
3546
3547/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3548class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3549public:
3550 using Base::Base;
3551
3552 LogicalResult matchAndRewrite(ShuffleOp op,
3553 PatternRewriter &rewriter) const override {
3554 Value splat = getScalarSplatSource(op.getV1());
3555 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3556 return failure();
3557
3558 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3559 return success();
3560 }
3561};
3562
3563/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3564/// vector.interleave.
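/// Example:
///   %r = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
/// becomes a vector.interleave of %a and %b producing vector<4xi32>.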
3565class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3566public:
3567 using Base::Base;
3568
3569 LogicalResult matchAndRewrite(ShuffleOp op,
3570 PatternRewriter &rewriter) const override {
3571 VectorType resultType = op.getResultVectorType();
3572 if (resultType.isScalable())
3573 return rewriter.notifyMatchFailure(
3574 op, "ShuffleOp can't represent a scalable interleave");
3575
3576 if (resultType.getRank() != 1)
3577 return rewriter.notifyMatchFailure(
3578 op, "ShuffleOp can't represent an n-D interleave");
3579
3580 VectorType sourceType = op.getV1VectorType();
3581 if (sourceType != op.getV2VectorType() ||
3582 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3583 return rewriter.notifyMatchFailure(
3584 op, "ShuffleOp types don't match an interleave");
3585 }
3586
3587 ArrayRef<int64_t> shuffleMask = op.getMask();
3588 int64_t resultVectorSize = resultType.getNumElements();
3589 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3590 int64_t maskValueA = shuffleMask[i * 2];
3591 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3592 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3593 return rewriter.notifyMatchFailure(op,
3594 "ShuffleOp mask not interleaving");
3595 }
3596
3597 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3598 return success();
3599 }
3600};
3601
3602 /// Pattern to replace unused shuffle operands / results with poison.
3603///
3604/// Example Input:
3605/// %r = vector.shuffle %v1, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3606///
3607/// Example Output:
3608/// %0 = ub.poison : vector<2xi32>
3609/// %r = vector.shuffle %0, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3610class FoldUnusedShuffleOperand final : public OpRewritePattern<ShuffleOp> {
3611public:
3612 using Base::Base;
3613
3614 LogicalResult matchAndRewrite(ShuffleOp op,
3615 PatternRewriter &rewriter) const override {
3616 // Replace with poison if all mask elements are poison.
3617 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3618 return mask == ShuffleOp::kPoisonIndex;
3619 })) {
3620 rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType());
3621 return success();
3622 }
3623
3624 // Helper function to replace an operand with poison.
3625 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3626 // Do not replace if the operand is already poison.
3627 if (!matchPattern(operand.get(), ub::m_Poison())) {
3628 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3629 operand.get().getType());
3630 rewriter.modifyOpInPlace(op, [&]() { operand.set(poison); });
3631 return success();
3632 }
3633 return failure();
3634 };
3635
3636 // Replace V1 with poison if it is not used.
3637 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3638 ? op.getV1VectorType().getDimSize(0)
3639 : 1;
3640 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3641 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3642 });
3643 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3644 return success();
3645
3646 // Replace V2 with poison if it is not used.
3647 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3648 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3649 });
3650 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3651 return success();
3652
3653 return failure();
3654 }
3655};
3656} // namespace
3657
3658void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3659 MLIRContext *context) {
3660 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3661 FoldUnusedShuffleOperand>(context);
3662}
3663
3664//===----------------------------------------------------------------------===//
3665// InsertOp
3666//===----------------------------------------------------------------------===//
3667
3668void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3669 SetIntRangeFn setResultRanges) {
3670 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3671}
3672
3673void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3674 Value source, Value dest) {
3675 auto vectorTy = cast<VectorType>(dest.getType());
3676 build(builder, result, source, dest,
3677 SmallVector<int64_t>(vectorTy.getRank(), 0));
3678}
3679
3680void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3681 Value source, Value dest, int64_t position) {
3682 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3683}
3684
3685void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3686 Value source, Value dest, OpFoldResult position) {
3687 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3688}
3689
3690void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3691 Value source, Value dest,
3692 ArrayRef<int64_t> position) {
3693 SmallVector<OpFoldResult> posVals;
3694 posVals.reserve(position.size());
3695 llvm::transform(position, std::back_inserter(posVals),
3696 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3697 build(builder, result, source, dest, posVals);
3698}
3699
3700void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3701 Value source, Value dest,
3702 ArrayRef<OpFoldResult> position) {
3703 SmallVector<int64_t> staticPos;
3704 SmallVector<Value> dynamicPos;
3705 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3706 build(builder, result, source, dest, dynamicPos,
3707 builder.getDenseI64ArrayAttr(staticPos));
3708}
3709
3710LogicalResult InsertOp::verify() {
3711 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3712 if (srcTy.getRank() == 0)
3713 return emitError(
3714 "expected a scalar instead of a 0-d vector as the source operand");
3715
3716 SmallVector<OpFoldResult> position = getMixedPosition();
3717 auto destVectorType = getDestVectorType();
3718 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3719 return emitOpError(
3720 "expected position attribute of rank no greater than dest vector rank");
3721 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3722 if (srcVectorType &&
3723 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3724 static_cast<unsigned>(destVectorType.getRank())))
3725 return emitOpError("expected position attribute rank + source rank to "
3726 "match dest vector rank");
3727 if (!srcVectorType &&
3728 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3729 return emitOpError(
3730 "expected position attribute rank to match the dest vector rank");
3731 for (auto [idx, pos] : llvm::enumerate(position)) {
3732 if (auto attr = dyn_cast<Attribute>(pos)) {
3733 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3734 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3735 destVectorType.getDimSize(idx))) {
3736 return emitOpError("expected position attribute #")
3737 << (idx + 1)
3738 << " to be a non-negative integer smaller than the "
3739 "corresponding "
3740 "dest vector dimension";
3741 }
3742 }
3743 }
3744 return success();
3745}
3746
3747 // Calculate the linearized position of the contiguous chunk of elements to
3748 // insert, based on the shape of the value to insert and the positions to
3749 // insert at.
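// Example: destTy = vector<2x3x4xi8> and positions = [1, 2] give the
// completed position [1, 2, 0], strides [12, 4, 1], and thus the linearized
// position 1 * 12 + 2 * 4 = 20.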
3750static int64_t calculateInsertPosition(VectorType destTy,
3751 ArrayRef<int64_t> positions) {
3752 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3753 assert(positions.size() <= completePositions.size() &&
3754 "positions size must be less than or equal to destTy rank");
3755 copy(positions, completePositions.begin());
3756 return linearize(completePositions, computeStrides(destTy.getShape()));
3757}
3758
3759namespace {
3760
3761// If insertOp is only inserting unit dimensions it can be transformed to a
3762// broadcast.
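// Example:
//   %r = vector.insert %src, %dest [0] : vector<4xi8> into vector<1x4xi8>
// becomes
//   %r = vector.broadcast %src : vector<4xi8> to vector<1x4xi8>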
3763class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3764public:
3765 using Base::Base;
3766
3767 LogicalResult matchAndRewrite(InsertOp insertOp,
3768 PatternRewriter &rewriter) const override {
3769 auto srcVecType =
3770 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3771 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3772 srcVecType.getNumElements())
3773 return failure();
3774 rewriter.replaceOpWithNewOp<BroadcastOp>(
3775 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3776 return success();
3777 }
3778};
3779
3780/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
3781class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3782public:
3783 using Base::Base;
3784
3785 LogicalResult matchAndRewrite(InsertOp op,
3786 PatternRewriter &rewriter) const override {
3787
3788 Value splat = getScalarSplatSource(op.getValueToStore());
3789 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3790 return failure();
3791
3792 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3793 return success();
3794 }
3795};
3796
3797/// Pattern to optimize a chain of insertions.
3798///
3799/// This pattern identifies chains of vector.insert operations that:
3800/// 1. Only insert values at static positions.
3801/// 2. Completely initialize all elements in the resulting vector.
3802/// 3. All intermediate insert operations have only one use.
3803///
3804/// When these conditions are met, the entire chain can be replaced with a
3805/// single vector.from_elements operation.
3806///
3807/// To keep this pattern simple, and avoid spending too much time on matching
3808/// fragmented insert chains, this pattern only considers the last insert op in
3809/// the chain.
3810///
3811/// Example transformation:
3812/// %poison = ub.poison : vector<2xi32>
3813/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3814/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3815/// ->
3816/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3817class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3818public:
3819 using Base::Base;
3820 LogicalResult matchAndRewrite(InsertOp op,
3821 PatternRewriter &rewriter) const override {
3822
3823 VectorType destTy = op.getDestVectorType();
3824 if (destTy.isScalable())
3825 return failure();
3826 // Ensure this is the trailing vector.insert op in a chain of inserts.
3827 for (Operation *user : op.getResult().getUsers())
3828 if (auto insertOp = dyn_cast<InsertOp>(user))
3829 if (insertOp.getDest() == op.getResult())
3830 return failure();
3831
3832 InsertOp currentOp = op;
3833 SmallVector<InsertOp> chainInsertOps;
3834 while (currentOp) {
3835 // Check cond 1: Dynamic position is not supported.
3836 if (currentOp.hasDynamicPosition())
3837 return failure();
3838
3839 chainInsertOps.push_back(currentOp);
3840 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3841 // Check cond 3: Intermediate inserts have only one use to avoid an
3842 // explosion of vectors.
3843 if (currentOp && !currentOp->hasOneUse())
3844 return failure();
3845 }
3846
3847 int64_t vectorSize = destTy.getNumElements();
3848 int64_t initializedCount = 0;
3849 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3850 SmallVector<int64_t> pendingInsertPos;
3851 SmallVector<int64_t> pendingInsertSize;
3852 SmallVector<Value> pendingInsertValues;
3853
3854 for (auto insertOp : chainInsertOps) {
3856 // This pattern cannot handle poison indices.
3856 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3857 return failure();
3858
3859 // Calculate the linearized position for inserting elements.
3860 int64_t insertBeginPosition =
3861 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3862
3863 // The valueToStore operand may be a vector or a scalar. Need to handle
3864 // both cases.
3865 int64_t insertSize = 1;
3866 if (auto srcVectorType =
3867 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3868 insertSize = srcVectorType.getNumElements();
3869
3870 assert(insertBeginPosition + insertSize <= vectorSize &&
3871 "insert would overflow the vector");
3872
3873 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3874 insertBeginPosition + insertSize)) {
3875 if (initializedDestIdxs[index])
3876 continue;
3877 initializedDestIdxs[index] = true;
3878 ++initializedCount;
3879 }
3880
3881 // Defer creating any new ops until we are sure the pattern will
3882 // succeed.
3883 pendingInsertPos.push_back(insertBeginPosition);
3884 pendingInsertSize.push_back(insertSize);
3885 pendingInsertValues.push_back(insertOp.getValueToStore());
3886
3887 if (initializedCount == vectorSize)
3888 break;
3889 }
3890
3891 // Check cond 2: all positions must be initialized.
3892 if (initializedCount != vectorSize)
3893 return failure();
3894
3895 SmallVector<Value> elements(vectorSize);
3896 for (auto [insertBeginPosition, insertSize, valueToStore] :
3897 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3898 pendingInsertValues))) {
3899 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3900
3901 if (!srcVectorType) {
3902 elements[insertBeginPosition] = valueToStore;
3903 continue;
3904 }
3905
3906 Repeated<Type> elementToInsertTypes(insertSize,
3907 srcVectorType.getElementType());
3908 // Get all elements from the vector in row-major order.
3909 auto elementsToInsert = vector::ToElementsOp::create(
3910 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3911 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3912 elements[insertBeginPosition + linearIdx] =
3913 elementsToInsert.getResult(linearIdx);
3914 }
3915 }
3916
3917 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3918 return success();
3919 }
3920};
3921
3922} // namespace
3923
3924 static Attribute
3925 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3926 Attribute dstAttr,
3927 int64_t maxVectorSizeFoldThreshold) {
3928 if (insertOp.hasDynamicPosition())
3929 return {};
3930
3931 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3932 if (!denseDst)
3933 return {};
3934
3935 if (!srcAttr) {
3936 return {};
3937 }
3938
3939 VectorType destTy = insertOp.getDestVectorType();
3940 if (destTy.isScalable())
3941 return {};
3942
3943 // Make sure we do not create too many large constants.
3944 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3945 !insertOp->hasOneUse())
3946 return {};
3947
3948 // Bail out on poison indices (kPoisonIndex = -1) to avoid computing an
3949 // invalid (negative) linearized position which would cause UB below.
3950 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3951 return {};
3952
3953 // Calculate the linearized position for inserting elements.
3954 int64_t insertBeginPosition =
3955 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3956 SmallVector<Attribute> insertedValues;
3957 Type destEltType = destTy.getElementType();
3958
3959 /// Convert the attribute to the expected element type if there is
3960 /// a mismatch.
3961 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3962 for (auto value : denseSource.getValues<Attribute>())
3963 insertedValues.push_back(convertNumericAttr(value, destEltType));
3964 } else {
3965 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3966 }
3967
3968 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3969 copy(insertedValues, allValues.begin() + insertBeginPosition);
3970 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3971
3972 return newAttr;
3973}
3974
3975/// Folder to replace the `dest` operand of the insert op with the root dest of
3976/// the insert op use chain.
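/// Example:
///   %0 = vector.insert %e0, %dest [7] : i32 into vector<8xi32>
///   %1 = vector.insert %e1, %0 [7] : i32 into vector<8xi32>
/// The second insert overwrites the element written by the first, so its
/// `dest` can be replaced with %dest directly.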
3977static Value foldInsertUseChain(InsertOp insertOp) {
3978 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3979 if (!destInsert)
3980 return {};
3981
3982 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3983 return {};
3984
3985 insertOp.setOperand(1, destInsert.getDest());
3986 return insertOp.getResult();
3987}
3988
3989void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3990 MLIRContext *context) {
3991 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3992 InsertChainFullyInitialized>(context);
3993}
3994
3995OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3996 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3997 // unless the source vector constant has a single use.
3998 constexpr int64_t vectorSizeFoldThreshold = 256;
3999 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
4000 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
4001 // (type mismatch).
4002 if (getNumIndices() == 0 && getValueToStoreType() == getType())
4003 return getValueToStore();
4004 // Fold `arith.constant` indices into the `vector.insert` operation.
4005 // Do not stop here as this fold may enable subsequent folds that require
4006 // constant indices.
4007 SmallVector<Value> operands = {getValueToStore(), getDest()};
4008 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
4009
4010 if (auto res = foldInsertUseChain(*this))
4011 return res;
4012 if (auto res = foldPoisonIndexInsertExtractOp(
4013 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
4014 return res;
4015 if (auto res = foldDenseElementsAttrDestInsertOp(
4016 *this, adaptor.getValueToStore(), adaptor.getDest(),
4017 vectorSizeFoldThreshold)) {
4018 return res;
4019 }
4020
4021 return inplaceFolded;
4022}
4023
4024//===----------------------------------------------------------------------===//
4025// InsertStridedSliceOp
4026//===----------------------------------------------------------------------===//
4027
4028void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4029 Value source, Value dest,
4030 ArrayRef<int64_t> offsets,
4031 ArrayRef<int64_t> strides) {
4032 result.addOperands({source, dest});
4033 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4034 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4035 result.addTypes(dest.getType());
4036 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
4037 offsetsAttr);
4038 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
4039 stridesAttr);
4040}
4041
4042// TODO: Should be moved to Tablegen ConfinedAttr attributes.
4043template <typename OpType>
4044static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
4045 ArrayAttr arrayAttr,
4046 ArrayRef<int64_t> shape,
4047 StringRef attrName) {
4048 if (arrayAttr.size() > shape.size())
4049 return op.emitOpError("expected ")
4050 << attrName << " attribute of rank no greater than vector rank";
4051 return success();
4052}
4053
4054 // Returns success if all integers in `arrayAttr` are in the interval
4055 // [min, max). If `halfOpen` is false, the admissible interval widens to
4056 // [min, max].
4057template <typename OpType>
4058 static LogicalResult
4059 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
4060 int64_t max, StringRef attrName,
4061 bool halfOpen = true) {
4062 for (auto attr : arrayAttr) {
4063 auto val = llvm::cast<IntegerAttr>(attr).getInt();
4064 auto upper = max;
4065 if (!halfOpen)
4066 upper += 1;
4067 if (val < min || val >= upper)
4068 return op.emitOpError("expected ") << attrName << " to be confined to ["
4069 << min << ", " << upper << ")";
4070 }
4071 return success();
4072}
4073
4074 // Returns success if each integer in `arrayAttr` is confined to the interval
4075 // [min, shape[i]) for its dimension i. If `halfOpen` is false, the admissible
4076 // interval widens to [min, shape[i]].
4077template <typename OpType>
4078 static LogicalResult
4079 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
4080 ArrayRef<int64_t> shape, StringRef attrName,
4081 bool halfOpen = true, int64_t min = 0) {
4082 for (auto [index, attrDimPair] :
4083 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
4084 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
4085 int64_t max = std::get<1>(attrDimPair);
4086 if (!halfOpen)
4087 max += 1;
4088 if (val < min || val >= max)
4089 return op.emitOpError("expected ")
4090 << attrName << " dimension " << index << " to be confined to ["
4091 << min << ", " << max << ")";
4092 }
4093 return success();
4094}
4095
4096 // Returns success if, for every index i = 0..shape.size()-1, the sum
4097 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
4098 // lies in the interval [min, shape[i]). If `halfOpen` is false, the
4099 // admissible interval widens to [min, shape[i]].
4101 template <typename OpType>
4102 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
4103 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
4104 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
4105 bool halfOpen = true, int64_t min = 1) {
4106 assert(arrayAttr1.size() <= shape.size());
4107 assert(arrayAttr2.size() <= shape.size());
4108 for (auto [index, it] :
4109 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
4110 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
4111 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
4112 int64_t max = std::get<2>(it);
4113 if (!halfOpen)
4114 max += 1;
4115 if (val1 + val2 < 0 || val1 + val2 >= max)
4116 return op.emitOpError("expected sum(")
4117 << attrName1 << ", " << attrName2 << ") dimension " << index
4118 << " to be confined to [" << min << ", " << max << ")";
4119 }
4120 return success();
4121}
4122
4123 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
4124 MLIRContext *context) {
4125 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
4126 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
4127 });
4128 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
4129}
4130
4131LogicalResult InsertStridedSliceOp::verify() {
4132 auto sourceVectorType = getSourceVectorType();
4133 auto destVectorType = getDestVectorType();
4134 auto offsets = getOffsetsAttr();
4135 auto strides = getStridesAttr();
4136 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
4137 return emitOpError(
4138 "expected offsets of same size as destination vector rank");
4139 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
4140 return emitOpError("expected strides of same size as source vector rank");
4141 if (sourceVectorType.getRank() > destVectorType.getRank())
4142 return emitOpError(
4143 "expected source rank to be no greater than destination rank");
4144
4145 auto sourceShape = sourceVectorType.getShape();
4146 auto destShape = destVectorType.getShape();
4147 SmallVector<int64_t, 4> sourceShapeAsDestShape(
4148 destShape.size() - sourceShape.size(), 0);
4149 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
4150 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
4151 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
4152 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
4153 offName)) ||
4154 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4155 /*max=*/1, stridesName,
4156 /*halfOpen=*/false)) ||
4157 failed(isSumOfIntegerArrayAttrConfinedToShape(
4158 *this, offsets,
4159 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
4160 offName, "source vector shape",
4161 /*halfOpen=*/false, /*min=*/1)))
4162 return failure();
4163
4164 unsigned rankDiff = destShape.size() - sourceShape.size();
4165 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
4166 if (sourceVectorType.getScalableDims()[idx] !=
4167 destVectorType.getScalableDims()[idx + rankDiff]) {
4168 return emitOpError("mismatching scalable flags (at source vector idx=")
4169 << idx << ")";
4170 }
4171 if (sourceVectorType.getScalableDims()[idx]) {
4172 auto sourceSize = sourceShape[idx];
4173 auto destSize = destShape[idx + rankDiff];
4174 if (sourceSize != destSize) {
4175 return emitOpError("expected size at idx=")
4176 << idx
4177 << (" to match the corresponding base size from the input "
4178 "vector (")
4179 << sourceSize << (" vs ") << destSize << (")");
4180 }
4181 }
4182 }
4183
4184 return success();
4185}
4186
4187namespace {
4188/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
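/// Example: if both the inserted slice and the destination are
/// vector.broadcast ops of the same scalar %x, the insert_strided_slice
/// changes nothing and folds to the destination.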
4189class FoldInsertStridedSliceSplat final
4190 : public OpRewritePattern<InsertStridedSliceOp> {
4191public:
4192 using Base::Base;
4193
4194 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4195 PatternRewriter &rewriter) const override {
4196
4197 auto dst = insertStridedSliceOp.getDest();
4198 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4199 if (!splat || getScalarSplatSource(dst) != splat)
4200 return failure();
4201
4202 rewriter.replaceOp(insertStridedSliceOp, dst);
4203 return success();
4204 }
4205};
4206
4207/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
4208/// to dst.
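/// Example:
///   %e = vector.extract_strided_slice %dst
///        {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
///   %r = vector.insert_strided_slice %e, %dst {offsets = [2], strides = [1]}
///        : vector<4xf32> into vector<8xf32>
/// folds to %dst.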
4209class FoldInsertStridedSliceOfExtract final
4210 : public OpRewritePattern<InsertStridedSliceOp> {
4211public:
4212 using Base::Base;
4213
4214 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4215 PatternRewriter &rewriter) const override {
4216 auto extractStridedSliceOp =
4217 insertStridedSliceOp.getValueToStore()
4218 .getDefiningOp<vector::ExtractStridedSliceOp>();
4219
4220 if (!extractStridedSliceOp)
4221 return failure();
4222
4223 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4224 return failure();
4225
4226 // Check if they have the same strides and offsets.
4227 if (extractStridedSliceOp.getStrides() !=
4228 insertStridedSliceOp.getStrides() ||
4229 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4230 return failure();
4231
4232 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4233 return success();
4234 }
4235};
4236
4237// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
4238// ConstantOp.
4239class InsertStridedSliceConstantFolder final
4240 : public OpRewritePattern<InsertStridedSliceOp> {
4241public:
4242 using Base::Base;
4243
4244 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
4245 // unless the source vector constant has a single use.
4246 static constexpr int64_t vectorSizeFoldThreshold = 256;
4247
4248 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4249 PatternRewriter &rewriter) const override {
4250 // Return early if the destination operand is not defined by a compatible
4251 // vector ConstantOp.
4252 TypedValue<VectorType> destVector = op.getDest();
4253 Attribute vectorDestCst;
4254 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
4255 return failure();
4256
4257 VectorType destTy = destVector.getType();
4258 if (destTy.isScalable())
4259 return failure();
4260
4261 // Make sure we do not create too many large constants.
4262 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4263 !destVector.hasOneUse())
4264 return failure();
4265
4266 TypedValue<VectorType> sourceValue = op.getValueToStore();
4267 Attribute sourceCst;
4268 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
4269 return failure();
4270
4271 // TODO: Support poison.
4272 if (matchPattern(vectorDestCst, ub::m_Poison()) ||
4273 matchPattern(sourceCst, ub::m_Poison()))
4274 return failure();
4275
4276 // TODO: Handle non-unit strides when they become available.
4277 if (op.hasNonUnitStrides())
4278 return failure();
4279
4280 VectorType sliceVecTy = sourceValue.getType();
4281 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4282 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4283 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
4284 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
4285
4286 // Calculate the destination element indices by enumerating all slice
4287 // positions within the destination and linearizing them. The enumeration
4288 // order is lexicographic, which yields a sequence of monotonically
4289 // increasing linearized position indices.
4290 // Because the destination may have higher dimensionality than the slice,
4291 // we keep track of two overlapping sets of positions and offsets.
4292 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4293 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4294 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4295 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4296 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4297 MutableArrayRef<int64_t> currSlicePosition(
4298 currDestPosition.begin() + rankDifference, currDestPosition.end());
4299 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4300 offsets.end());
4301 do {
4302 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
4303 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
4304 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4305 "Invalid slice element");
4306 newValues[linearizedPosition] = *sliceValuesIt;
4307 ++sliceValuesIt;
4308 } while (succeeded(
4309 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
4310
4311 auto newAttr = DenseElementsAttr::get(destTy, newValues);
4312 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
4313 return success();
4314 }
4315};
4316
4317} // namespace
4318
4319void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4320 RewritePatternSet &results, MLIRContext *context) {
4321 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4322 InsertStridedSliceConstantFolder>(context);
4323}
4324
4325OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4326 if (getSourceVectorType() == getDestVectorType())
4327 return getValueToStore();
4328 return {};
4329}
4330
4331//===----------------------------------------------------------------------===//
4332// OuterProductOp
4333//===----------------------------------------------------------------------===//
4334
4335/// Build an op without mask, use the type of `acc` as the return type.
4336void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4337 Value lhs, Value rhs, Value acc) {
4338 result.addOperands({lhs, rhs, acc});
4339 result.addTypes(acc.getType());
4340}
4341
4342void OuterProductOp::print(OpAsmPrinter &p) {
4343 p << " " << getLhs() << ", " << getRhs();
4344 if (getAcc()) {
4345 p << ", " << getAcc();
4346 p.printOptionalAttrDict((*this)->getAttrs());
4347 }
4348 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4349}
4350
4351ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4352 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4353 Type tLHS, tRHS;
4354 if (parser.parseOperandList(operandsInfo) ||
4355 parser.parseOptionalAttrDict(result.attributes) ||
4356 parser.parseColonType(tLHS) || parser.parseComma() ||
4357 parser.parseType(tRHS))
4358 return failure();
4359 if (operandsInfo.size() < 2)
4360 return parser.emitError(parser.getNameLoc(),
4361 "expected at least 2 operands");
4362 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4363 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4364 if (!vLHS)
4365 return parser.emitError(parser.getNameLoc(),
4366 "expected vector type for operand #1");
4367
4368 VectorType resType;
4369 if (vRHS) {
4370 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4371 vRHS.getScalableDims()[0]};
4372 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4373 vLHS.getElementType(), scalableDimsRes);
4374 } else {
4375 // Scalar RHS operand
4376 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4377 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4378 scalableDimsRes);
4379 }
4380
4381 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4382 result.attributes.append(
4383 OuterProductOp::getKindAttrName(result.name),
4384 CombiningKindAttr::get(result.getContext(),
4385 OuterProductOp::getDefaultKind()));
4386 }
4387
4388 return failure(
4389 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4390 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4391 (operandsInfo.size() > 2 &&
4392 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4393 parser.addTypeToList(resType, result.types));
4394}
4395
4396LogicalResult OuterProductOp::verify() {
4397 Type tRHS = getOperandTypeRHS();
4398 VectorType vLHS = getOperandVectorTypeLHS(),
4399 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4400 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4401
4402 if (vLHS.getRank() != 1)
4403 return emitOpError("expected 1-d vector for operand #1");
4404
4405 if (vRHS) {
4406 // Proper OUTER operation.
4407 if (vRHS.getRank() != 1)
4408 return emitOpError("expected 1-d vector for operand #2");
4409 if (vRES.getRank() != 2)
4410 return emitOpError("expected 2-d vector result");
4411 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4412 return emitOpError("expected #1 operand dim to match result dim #1");
4413 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4414 return emitOpError("expected #2 operand dim to match result dim #2");
4415 if (vLHS.isScalable() && !vRHS.isScalable()) {
4416 // This restriction reflects what's currently supported in terms of
4417 // scalable vectors. However, we could relax this if there's a use case.
4418 return emitOpError(
4419 "expected either both or only #2 operand dim to be scalable");
4420 }
4421 } else {
4422 // An AXPY operation.
4423 if (vRES.getRank() != 1)
4424 return emitOpError("expected 1-d vector result");
4425 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4426 return emitOpError("expected #1 operand dim to match result dim #1");
4427 }
4428
4429 if (vACC && vACC != vRES)
4430 return emitOpError("expected operand #3 of same type as result type");
4431
4432 if (!getKindAttr()) {
4433 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
4434 "'vector.kind<add>')");
4435 }
4436
4437 // Verify supported combining kind.
4438 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4439 return emitOpError("unsupported outerproduct type");
4440
4441 return success();
4442}
4443
4444// MaskableOpInterface methods.
4445
4446/// Returns the mask type expected by this operation. Mostly used for
4447 /// verification purposes. It requires the operation to be vectorized.
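/// Example: a vector<4x[8]xf32> result expects a mask of type
/// vector<4x[8]xi1> (same shape and scalable dims, i1 elements).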
4448Type OuterProductOp::getExpectedMaskType() {
4449 auto vecType = this->getResultVectorType();
4450 return VectorType::get(vecType.getShape(),
4451 IntegerType::get(vecType.getContext(), /*width=*/1),
4452 vecType.getScalableDims());
4453}
4454
4455//===----------------------------------------------------------------------===//
4456// ExtractStridedSliceOp
4457//===----------------------------------------------------------------------===//
4458
4459// Inference works as follows:
4460 // 1. For the leading dims covered by 'offsets', take the corresponding 'sizes'.
4461 // 2. For the remaining dims, take the sizes from 'vectorType'.
4462// Scalable flags are inherited from 'vectorType'.
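// Example: vectorType = vector<4x8x16xf32> with sizes = [2, 4] (and offsets
// and strides of matching rank) infers vector<2x4x16xf32>.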
4463static Type inferStridedSliceOpResultType(VectorType vectorType,
4464 ArrayAttr offsets, ArrayAttr sizes,
4465 ArrayAttr strides) {
4466 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4467 SmallVector<int64_t, 4> shape;
4468 shape.reserve(vectorType.getRank());
4469 unsigned idx = 0;
4470 for (unsigned e = offsets.size(); idx < e; ++idx)
4471 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4472 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4473 shape.push_back(vectorType.getShape()[idx]);
4474
4475 return VectorType::get(shape, vectorType.getElementType(),
4476 vectorType.getScalableDims());
4477}
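// For example (a sketch of the inference above): with 'vectorType' =
// vector<4x8x16xf32>, 'offsets' = [1, 2], 'sizes' = [3, 4] and 'strides' =
// [1, 1], the two leading result dims come from 'sizes' and the trailing dim
// is carried over from the source, yielding vector<3x4x16xf32>.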
4478
4479void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4480 Value source, ArrayRef<int64_t> offsets,
4481 ArrayRef<int64_t> sizes,
4482 ArrayRef<int64_t> strides) {
4483 result.addOperands(source);
4484 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4485 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4486 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4487 result.addTypes(
4488 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4489 offsetsAttr, sizesAttr, stridesAttr));
4490 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4491 offsetsAttr);
4492 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4493 sizesAttr);
4494 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4495 stridesAttr);
4496}
4497
4498LogicalResult ExtractStridedSliceOp::verify() {
4499 auto type = getSourceVectorType();
4500 auto offsets = getOffsetsAttr();
4501 auto sizes = getSizesAttr();
4502 auto strides = getStridesAttr();
4503 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4504 return emitOpError(
4505 "expected offsets, sizes and strides attributes of same size");
4506
4507 auto shape = type.getShape();
4508 auto offName = getOffsetsAttrName();
4509 auto sizesName = getSizesAttrName();
4510 auto stridesName = getStridesAttrName();
4511 if (failed(
4512 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4513 failed(
4514 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4515 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4516 stridesName)) ||
4517 failed(
4518 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4519 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4520 /*halfOpen=*/false,
4521 /*min=*/1)) ||
4522 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4523 /*max=*/1, stridesName,
4524 /*halfOpen=*/false)) ||
4525 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4526 shape, offName, sizesName,
4527 /*halfOpen=*/false)))
4528 return failure();
4529
4530 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4531 offsets, sizes, strides);
4532 if (getResult().getType() != resultType)
4533 return emitOpError("expected result type to be ") << resultType;
4534
4535 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4536 if (type.getScalableDims()[idx]) {
4537 auto inputDim = type.getShape()[idx];
4538 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4539 if (inputDim != inputSize)
4540 return emitOpError("expected size at idx=")
4541 << idx
4542 << (" to match the corresponding base size from the input "
4543 "vector (")
4544 << inputSize << (" vs ") << inputDim << (")");
4545 }
4546 }
4547
4548 return success();
4549}
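// For illustration (a sketch with hypothetical shapes): for a scalable source
// such as vector<[4]xf32>, the slice size along the scalable dim must equal
// the full source size, so the following is rejected:
//
//   %1 = vector.extract_strided_slice %arg0
//       {offsets = [0], sizes = [2], strides = [1]}
//       : vector<[4]xf32> to vector<[2]xf32>
//
// whereas sizes = [4] (the whole scalable dim) verifies.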
4550
4551 // When the source of an ExtractStridedSliceOp comes from a chain of
4552 // InsertStridedSliceOps, try to use the source of one of those inserts when
4553 // the extracted vector is a subset of one of the inserted vectors.
4554static LogicalResult
4555foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4556 // Helper to extract integer out of ArrayAttr.
4557 auto getElement = [](ArrayAttr array, int idx) {
4558 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4559 };
4560 ArrayAttr extractOffsets = op.getOffsets();
4561 ArrayAttr extractStrides = op.getStrides();
4562 ArrayAttr extractSizes = op.getSizes();
4563 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4564 while (insertOp) {
4565 if (op.getSourceVectorType().getRank() !=
4566 insertOp.getSourceVectorType().getRank())
4567 return failure();
4568 ArrayAttr insertOffsets = insertOp.getOffsets();
4569 ArrayAttr insertStrides = insertOp.getStrides();
4570 // If the rank of extract is greater than the rank of insert, we are likely
4571 // extracting a partial chunk of the vector inserted.
4572 if (extractOffsets.size() > insertOffsets.size())
4573 return failure();
4574 bool partialOverlap = false;
4575 bool disjoint = false;
4576 SmallVector<int64_t, 4> offsetDiffs;
4577 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4578 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4579 return failure();
4580 int64_t start = getElement(insertOffsets, dim);
4581 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4582 int64_t offset = getElement(extractOffsets, dim);
4583 int64_t size = getElement(extractSizes, dim);
4584 // Check if the start of the extract offset is in the interval inserted.
4585 if (start <= offset && offset < end) {
4586 // If the extract interval overlaps but is not fully included we may
4587 // have a partial overlap that will prevent any folding.
4588 if (offset + size > end)
4589 partialOverlap = true;
4590 offsetDiffs.push_back(offset - start);
4591 continue;
4592 }
4593 disjoint = true;
4594 break;
4595 }
4596 // The extract element chunk is a subset of the insert element.
4597 if (!disjoint && !partialOverlap) {
4598 op.setOperand(insertOp.getValueToStore());
4599 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4600 OpBuilder b(op.getContext());
4601 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4602 return success();
4603 }
4604 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4605 // in the insert chain.
4606 if (disjoint)
4607 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4608 else {
4609 // The extracted vector partially overlaps the inserted vector; we cannot
4610 // fold.
4611 return failure();
4612 }
4613 }
4614 return failure();
4615}
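// Example of the fold above (a sketch; SSA names are hypothetical):
//
//   %i = vector.insert_strided_slice %small, %big
//       {offsets = [2, 0], strides = [1, 1]}
//       : vector<2x4xf32> into vector<8x4xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]}
//       : vector<8x4xf32> to vector<2x4xf32>
//
// The extracted chunk lies entirely within the inserted chunk, so the extract
// is rewritten to read from %small with offsets rebased to [0, 0] (and then
// folds away entirely, since it becomes an identity extract).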
4616
4617// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4618static OpFoldResult
4619 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4620 Attribute foldInput) {
4621
4622 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4623 if (!dense)
4624 return {};
4625
4626 // TODO: Handle non-unit strides when they become available.
4627 if (op.hasNonUnitStrides())
4628 return {};
4629
4630 VectorType sourceVecTy = op.getSourceVectorType();
4631 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4632 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4633
4634 VectorType sliceVecTy = op.getType();
4635 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4636 int64_t rank = sliceVecTy.getRank();
4637
4638 // Expand offsets and sizes to match the vector rank.
4639 SmallVector<int64_t, 4> offsets(rank, 0);
4640 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4641
4642 SmallVector<int64_t, 4> sizes(sourceShape);
4643 copy(getI64SubArray(op.getSizes()), sizes.begin());
4644
4645 // Calculate the slice elements by enumerating all slice positions and
4646 // linearizing them. The enumeration order is lexicographic which yields a
4647 // sequence of monotonically increasing linearized position indices.
4648 const auto denseValuesBegin = dense.value_begin<Attribute>();
4649 SmallVector<Attribute> sliceValues;
4650 sliceValues.reserve(sliceVecTy.getNumElements());
4651 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4652 do {
4653 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4654 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4655 "Invalid index");
4656 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4657 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4658
4659 assert(static_cast<int64_t>(sliceValues.size()) ==
4660 sliceVecTy.getNumElements() &&
4661 "Invalid number of slice elements");
4662 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4663}
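// For example (a sketch): extracting offsets = [1], sizes = [2] from a source
// holding dense<[0, 1, 2, 3]> : vector<4xi32> enumerates linearized positions
// 1 and 2 and so folds to dense<[1, 2]> : vector<2xi32>.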
4664
4665OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4666 if (getSourceVectorType() == getResult().getType())
4667 return getSource();
4668 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4669 return getResult();
4670
4671 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4672 if (auto splat =
4673 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4674 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4675
4676 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4677 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4678}
4679
4680void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4681 populateFromInt64AttrArray(getOffsets(), results);
4682}
4683
4684namespace {
4685
4686// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4687//
4688// Example:
4689//
4690// %0 = vector.extract_strided_slice %arg0
4691// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4692// : vector<4x8x16xf32> to vector<3x4x16xf32>
4693// %1 = vector.extract_strided_slice %0
4694// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4695// : vector<3x4x16xf32> to vector<2x2x16xf32>
4696//
4697// to
4698//
4699// %1 = vector.extract_strided_slice %arg0
4700// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4701// : vector<4x8x16xf32> to vector<2x2x16xf32>
4702class StridedSliceFolder final
4703 : public OpRewritePattern<ExtractStridedSliceOp> {
4704public:
4705 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4706
4707 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4708 PatternRewriter &rewriter) const override {
4709 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4710 if (!firstOp)
4711 return failure();
4712
4713 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4714 return failure();
4715
4716 SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
4717 SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
4718 SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
4719 SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
4720
4721 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4722 SmallVector<int64_t> combinedOffsets(newRank, 0);
4723 SmallVector<int64_t> combinedSizes(newRank);
4724 ArrayRef<int64_t> firstSourceShape =
4725 firstOp.getSourceVectorType().getShape();
4726 for (unsigned i = 0; i < newRank; ++i) {
4727 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4728 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4729 combinedOffsets[i] = off1 + off2;
4730
4731 if (i < secondSizes.size()) {
4732 combinedSizes[i] = secondSizes[i];
4733 } else if (i < firstSizes.size()) {
4734 combinedSizes[i] = firstSizes[i];
4735 } else {
4736 combinedSizes[i] = firstSourceShape[i];
4737 }
4738 }
4739
4740 SmallVector<int64_t> combinedStrides(newRank, 1);
4741 rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
4742 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4743 combinedStrides);
4744 return success();
4745 }
4746};
4747
4748// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4749// CreateMaskOp.
4750//
4751// Example:
4752//
4753// %mask = vector.create_mask %ub : vector<16xi1>
4754// %slice = vector.extract_strided_slice [%offset] [8] [1]
4755//
4756// to
4757//
4758// %new_ub = arith.subi %ub, %offset
4759// %mask = vector.create_mask %new_ub : vector<8xi1>
4760class StridedSliceCreateMaskFolder final
4761 : public OpRewritePattern<ExtractStridedSliceOp> {
4762 using Base::Base;
4763
4764public:
4765 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4766 PatternRewriter &rewriter) const override {
4767 Location loc = extractStridedSliceOp.getLoc();
4768 // Return if 'extractStridedSliceOp' operand is not defined by a
4769 // CreateMaskOp.
4770 auto createMaskOp =
4771 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4772 if (!createMaskOp)
4773 return failure();
4774 // Return if 'extractStridedSliceOp' has non-unit strides.
4775 if (extractStridedSliceOp.hasNonUnitStrides())
4776 return failure();
4777 // Gather constant mask dimension sizes.
4778 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4779 // Gather strided slice offsets and sizes.
4780 SmallVector<int64_t> sliceOffsets;
4781 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4782 sliceOffsets);
4783 SmallVector<int64_t> sliceSizes;
4784 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4785
4786 // Compute slice of vector mask region.
4787 SmallVector<Value> sliceMaskDimSizes;
4788 sliceMaskDimSizes.reserve(maskDimSizes.size());
4789 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4790 // only iterate on the leading dim sizes. The tail accounts for the
4791 // remaining dim sizes.
4792 for (auto [maskDimSize, sliceOffset, sliceSize] :
4793 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4794 // No need to clamp on min/max values, because create_mask has clamping
4795 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4796 // greater than the vector dim size.
4797 IntegerAttr offsetAttr =
4798 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4799 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4800 Value sliceMaskDimSize =
4801 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4802 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4803 }
4804 // Add unchanged dimensions.
4805 llvm::append_range(
4806 sliceMaskDimSizes,
4807 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4808 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4809 // region.
4810 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4811 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4812 sliceMaskDimSizes);
4813 return success();
4814 }
4815};
4816
4817// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4818// ConstantMaskOp.
4819class StridedSliceConstantMaskFolder final
4820 : public OpRewritePattern<ExtractStridedSliceOp> {
4821public:
4822 using Base::Base;
4823
4824 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4825 PatternRewriter &rewriter) const override {
4826 // Return if 'extractStridedSliceOp' operand is not defined by a
4827 // ConstantMaskOp.
4828 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4829 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4830 if (!constantMaskOp)
4831 return failure();
4832 // Return if 'extractStridedSliceOp' has non-unit strides.
4833 if (extractStridedSliceOp.hasNonUnitStrides())
4834 return failure();
4835 // Gather constant mask dimension sizes.
4836 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4837 // Gather strided slice offsets and sizes.
4838 SmallVector<int64_t> sliceOffsets;
4839 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4840 sliceOffsets);
4841 SmallVector<int64_t> sliceSizes;
4842 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4843
4844 // Compute slice of vector mask region.
4845 SmallVector<int64_t> sliceMaskDimSizes;
4846 sliceMaskDimSizes.reserve(maskDimSizes.size());
4847 for (auto [maskDimSize, sliceOffset, sliceSize] :
4848 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4849 int64_t sliceMaskDimSize = std::max(
4850 static_cast<int64_t>(0),
4851 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4852 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4853 }
4854 // Add unchanged dimensions.
4855 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4856 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4857 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4858 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4859 // region is a conjunction of mask dim intervals).
4860 if (llvm::is_contained(sliceMaskDimSizes, 0))
4861 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4862
4863 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4864 // region.
4865 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4866 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4867 sliceMaskDimSizes);
4868 return success();
4869 }
4870};
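// Example (a sketch with hypothetical values):
//
//   %mask = vector.constant_mask [8] : vector<16xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [4], sizes = [8], strides = [1]}
//       : vector<16xi1> to vector<8xi1>
//
// becomes
//
//   %slice = vector.constant_mask [4] : vector<8xi1>
//
// since max(0, min(4 + 8, 8) - 4) = 4 of the sliced positions are set.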
4871
4872// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4873// BroadcastOp(ExtractStrideSliceOp).
4874class StridedSliceBroadcast final
4875 : public OpRewritePattern<ExtractStridedSliceOp> {
4876public:
4877 using Base::Base;
4878
4879 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4880 PatternRewriter &rewriter) const override {
4881 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4882 if (!broadcast)
4883 return failure();
4884 auto srcVecType =
4885 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4886 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4887 auto dstVecType = llvm::cast<VectorType>(op.getType());
4888 unsigned dstRank = dstVecType.getRank();
4889 unsigned rankDiff = dstRank - srcRank;
4890 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4891 // (n -> m with n > m). If they are originally both broadcasted *and*
4892 // sliced, this can be simplified to just broadcasting.
4893 bool needsSlice = false;
4894 for (unsigned i = 0; i < srcRank; i++) {
4895 if (srcVecType.getDimSize(i) != 1 &&
4896 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4897 needsSlice = true;
4898 break;
4899 }
4900 }
4901 Value source = broadcast.getSource();
4902 if (needsSlice) {
4903 SmallVector<int64_t> offsets =
4904 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4905 SmallVector<int64_t> sizes =
4906 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4907 for (unsigned i = 0; i < srcRank; i++) {
4908 if (srcVecType.getDimSize(i) == 1) {
4909 // In case this dimension was broadcasted *and* sliced, the offset
4910 // and size need to be updated now that there is no broadcast before
4911 // the slice.
4912 offsets[i] = 0;
4913 sizes[i] = 1;
4914 }
4915 }
4916 source = ExtractStridedSliceOp::create(
4917 rewriter, op->getLoc(), source, offsets, sizes,
4918 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4919 }
4920 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4921 return success();
4922 }
4923};
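// Example (a sketch): when every source dim is either broadcast or untouched
// by the slice, no slice of the source is needed at all:
//
//   %b = vector.broadcast %v : vector<1x8xf32> to vector<4x8xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x8xf32>
//
// rewrites to
//
//   %s = vector.broadcast %v : vector<1x8xf32> to vector<2x8xf32>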
4924
4925/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
4926class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4927public:
4928 using Base::Base;
4929
4930 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4931 PatternRewriter &rewriter) const override {
4932
4933 Value splat = getScalarSplatSource(op.getSource());
4934 if (!splat)
4935 return failure();
4936 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4937 return success();
4938 }
4939};
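// Example (a sketch): slicing a splat-like value is itself splat-like:
//
//   %splat = vector.broadcast %s : f32 to vector<8xf32>
//   %slice = vector.extract_strided_slice %splat
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xf32> to vector<4xf32>
//
// rewrites to
//
//   %slice = vector.broadcast %s : f32 to vector<4xf32>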
4940
4941/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4942/// slice is contiguous, into extract and shape_cast.
4943///
4944/// Example:
4945/// Before:
4946/// %1 = vector.extract_strided_slice %arg0 {
4947/// offsets = [0, 0, 0, 0, 0],
4948/// sizes = [1, 1, 1, 1, 8],
4949/// strides = [1, 1, 1, 1, 1]
4950/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4951/// After:
4952/// %0 = vector.extract %arg0[0, 0, 0, 0]
4953/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4954/// %1 = vector.shape_cast %0
4955/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4956///
4957class ContiguousExtractStridedSliceToExtract final
4958 : public OpRewritePattern<ExtractStridedSliceOp> {
4959public:
4960 using Base::Base;
4961
4962 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4963 PatternRewriter &rewriter) const override {
4964 if (op.hasNonUnitStrides())
4965 return failure();
4966 Value source = op.getOperand();
4967 auto sourceType = cast<VectorType>(source.getType());
4968 if (sourceType.isScalable() || sourceType.getRank() == 0)
4969 return failure();
4970
4971 // Compute the number of offsets to pass to ExtractOp::build. That is the
4972 // difference between the source rank and the desired slice rank. We walk
4973 // the dimensions from innermost out, and stop when the next slice dimension
4974 // is not full-size.
4975 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4976 int numOffsets;
4977 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4978 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4979 break;
4980 }
4981
4982 // If the created extract op would have no offsets, then this whole
4983 // extract_strided_slice is the identity and should have been handled by
4984 // other canonicalizations.
4985 if (numOffsets == 0)
4986 return failure();
4987
4988 // If not even the inner-most dimension is full-size, this op can't be
4989 // rewritten as an ExtractOp.
4990 if (numOffsets == sourceType.getRank() &&
4991 static_cast<int>(sizes.size()) == sourceType.getRank())
4992 return failure();
4993
4994 // The outer dimensions must have unit size.
4995 for (int i = 0; i < numOffsets; ++i) {
4996 if (sizes[i] != 1)
4997 return failure();
4998 }
4999
5000 // Avoid generating slices that have leading unit dimensions. The shape_cast
5001 // op that we create below would otherwise be handled by slow generic
5002 // fallback patterns (ShapeCastOpRewritePattern).
5003 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
5004 sizes[numOffsets] == 1) {
5005 ++numOffsets;
5006 }
5007
5008 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
5009 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5010 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
5011 extractOffsets);
5012 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
5013 return success();
5014 }
5015};
5016
5017} // namespace
5018
5019void ExtractStridedSliceOp::getCanonicalizationPatterns(
5020 RewritePatternSet &results, MLIRContext *context) {
5021 // Patterns to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
5022 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
5023 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5024 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5025 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
5026 context);
5027}
5028
5029//===----------------------------------------------------------------------===//
5030// TransferReadOp
5031//===----------------------------------------------------------------------===//
5032
5033 /// 1. Builder that sets padding to a poison value and an empty mask (variant
5034 /// with attrs).
5034void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5035 VectorType vectorType, Value source,
5036 ValueRange indices, std::optional<Value> padding,
5037 AffineMapAttr permutationMapAttr,
5038 /*optional*/ ArrayAttr inBoundsAttr) {
5039
5040 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5041 if (!padding)
5042 padding = ub::PoisonOp::create(builder, result.location, elemType);
5043 build(builder, result, vectorType, source, indices, permutationMapAttr,
5044 *padding, /*mask=*/Value(), inBoundsAttr);
5045}
5046
5047 /// 2. Builder that sets padding to a poison value and an empty mask (variant
5048 /// without attrs).
5048void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5049 VectorType vectorType, Value source,
5050 ValueRange indices, std::optional<Value> padding,
5051 AffineMap permutationMap,
5052 std::optional<ArrayRef<bool>> inBounds) {
5053 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5054 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5055 ? builder.getBoolArrayAttr(inBounds.value())
5056 : builder.getBoolArrayAttr(
5057 SmallVector<bool>(vectorType.getRank(), false));
5058 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5059 if (!padding)
5060 padding = ub::PoisonOp::create(builder, result.location, elemType);
5061 build(builder, result, vectorType, source, indices, *padding,
5062 permutationMapAttr, inBoundsAttr);
5063}
5064
5065/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
5066void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5067 VectorType vectorType, Value source,
5068 ValueRange indices, std::optional<Value> padding,
5069 std::optional<ArrayRef<bool>> inBounds) {
5070 AffineMap permutationMap = getTransferMinorIdentityMap(
5071 llvm::cast<ShapedType>(source.getType()), vectorType);
5072 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5073 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5074 ? builder.getBoolArrayAttr(inBounds.value())
5075 : builder.getBoolArrayAttr(
5076 SmallVector<bool>(vectorType.getRank(), false));
5077 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5078 if (!padding)
5079 padding = ub::PoisonOp::create(builder, result.location, elemType);
5080 build(builder, result, vectorType, source, indices, permutationMapAttr,
5081 *padding,
5082 /*mask=*/Value(), inBoundsAttr);
5083}
5084
5085template <typename EmitFun>
5086static LogicalResult verifyPermutationMap(AffineMap permutationMap,
5087 EmitFun emitOpError) {
5088 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
5089 for (auto expr : permutationMap.getResults()) {
5090 auto dim = dyn_cast<AffineDimExpr>(expr);
5091 auto zero = dyn_cast<AffineConstantExpr>(expr);
5092 if (zero) {
5093 if (zero.getValue() != 0) {
5094 return emitOpError(
5095 "requires a projected permutation_map (at most one dim or the zero "
5096 "constant can appear in each result)");
5097 }
5098 continue;
5099 }
5100 if (!dim) {
5101 return emitOpError("requires a projected permutation_map (at most one "
5102 "dim or the zero constant can appear in each result)");
5103 }
5104 if (seen[dim.getPosition()]) {
5105 return emitOpError(
5106 "requires a permutation_map that is a permutation (found one dim "
5107 "used more than once)");
5108 }
5109 seen[dim.getPosition()] = true;
5110 }
5111 return success();
5112}
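// For illustration (a sketch): affine_map<(d0, d1, d2) -> (d2, 0, d0)> is a
// valid projected permutation (each result is a distinct dim or the zero
// constant), while affine_map<(d0, d1) -> (d0, d0)> is rejected because d0
// appears in two results.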
5113
5114static LogicalResult
5115verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
5116 VectorType vectorType, VectorType maskType,
5117 VectorType inferredMaskType, AffineMap permutationMap,
5118 ArrayAttr inBounds) {
5119 if (op->hasAttr("masked")) {
5120 return op->emitOpError("masked attribute has been removed. "
5121 "Use in_bounds instead.");
5122 }
5123
5124 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5125 return op->emitOpError(
5126 "requires source to be a memref or ranked tensor type");
5127
5128 auto elementType = shapedType.getElementType();
5129 DataLayout dataLayout = DataLayout::closest(op);
5130 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5131 // Memref or tensor has vector element type.
5132 unsigned sourceVecSize =
5133 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
5134 vectorElementType.getShape().back();
5135 unsigned resultVecSize =
5136 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
5137 vectorType.getShape().back();
5138 if (resultVecSize % sourceVecSize != 0)
5139 return op->emitOpError(
5140 "requires the bitwidth of the minor 1-D vector to be an integral "
5141 "multiple of the bitwidth of the minor 1-D vector of the source");
5142
5143 unsigned sourceVecEltRank = vectorElementType.getRank();
5144 unsigned resultVecRank = vectorType.getRank();
5145 if (sourceVecEltRank > resultVecRank)
5146 return op->emitOpError(
5147 "requires source vector element and vector result ranks to match.");
5148 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5149 // Check that permutation map results match 'rankOffset' of vector type.
5150 if (permutationMap.getNumResults() != rankOffset)
5151 return op->emitOpError("requires a permutation_map with result dims of "
5152 "the same rank as the vector type");
5153
5154 if (maskType)
5155 return op->emitOpError("does not support masks with vector element type");
5156 } else {
5157 // Memref or tensor has scalar element type.
5158 unsigned minorSize =
5159 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5160 unsigned resultVecSize =
5161 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
5162 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
5163 return op->emitOpError(
5164 "requires the bitwidth of the minor 1-D vector to be an integral "
5165 "multiple of the bitwidth of the source element type");
5166
5167 // Check that permutation map results match rank of vector type.
5168 if (permutationMap.getNumResults() != vectorType.getRank())
5169 return op->emitOpError("requires a permutation_map with result dims of "
5170 "the same rank as the vector type");
5171 }
5172
5173 if (permutationMap.getNumSymbols() != 0)
5174 return op->emitOpError("requires permutation_map without symbols");
5175
5176 if (permutationMap.getNumInputs() != shapedType.getRank())
5177 return op->emitOpError("requires a permutation_map with input dims of the "
5178 "same rank as the source type");
5179
5180 if (maskType && maskType != inferredMaskType)
5181 return op->emitOpError("inferred mask type (")
5182 << inferredMaskType << ") and mask operand type (" << maskType
5183 << ") don't match";
5184
5185 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
5186 return op->emitOpError("expects the in_bounds attr of same rank "
5187 "as permutation_map results: ")
5188 << AffineMapAttr::get(permutationMap)
5189 << " vs inBounds of size: " << inBounds.size();
5190
5191 return success();
5192}
5193
5194static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
5195 SmallVector<StringRef, 3> elidedAttrs;
5196 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5197 if (op.getPermutationMap().isMinorIdentity())
5198 elidedAttrs.push_back(op.getPermutationMapAttrName());
5199 // Elide in_bounds attribute if all dims are out-of-bounds.
5200 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
5201 elidedAttrs.push_back(op.getInBoundsAttrName());
5202 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5203}
5204
5205void TransferReadOp::print(OpAsmPrinter &p) {
5206 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
5207 if (getMask())
5208 p << ", " << getMask();
5209 printTransferAttrs(p, *this);
5210 p << " : " << getShapedType() << ", " << getVectorType();
5211}
5212
5213VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
5214 AffineMap permMap) {
5215 auto i1Type = IntegerType::get(permMap.getContext(), 1);
5216 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
5217 assert(invPermMap && "Inversed permutation map couldn't be computed");
5218 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
5219
5220 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
5221 // 0-D mask into a single-element 1-D mask.
5222 if (maskShape.empty())
5223 maskShape.push_back(1);
5224
5225 SmallVector<bool> scalableDims =
5226 applyPermutationMap(invPermMap, vecType.getScalableDims());
5227
5228 return VectorType::get(maskShape, i1Type, scalableDims);
5229}
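// For example (a sketch): with the transposing permutation map
// affine_map<(d0, d1) -> (d1, d0)> and a vector<2x4xf32> transfer, composing
// the inverse permutation with the vector shape yields a mask type of
// vector<4x2xi1>.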
5230
5231ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
5232 auto &builder = parser.getBuilder();
5233 SMLoc typesLoc;
5234 OpAsmParser::UnresolvedOperand sourceInfo;
5235 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5236 OpAsmParser::UnresolvedOperand paddingInfo;
5237 SmallVector<Type, 2> types;
5238 OpAsmParser::UnresolvedOperand maskInfo;
5239 // Parsing with support for paddingValue.
5240 if (parser.parseOperand(sourceInfo) ||
5241 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
5242 parser.parseComma() || parser.parseOperand(paddingInfo))
5243 return failure();
5244 ParseResult hasMask = parser.parseOptionalComma();
5245 if (hasMask.succeeded()) {
5246 if (parser.parseOperand(maskInfo))
5247 return failure();
5248 }
5249 if (parser.parseOptionalAttrDict(result.attributes) ||
5250 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5251 return failure();
5252 if (types.size() != 2)
5253 return parser.emitError(typesLoc, "requires two types");
5254 auto indexType = builder.getIndexType();
5255 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5256 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5257 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5258 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5259 if (!vectorType)
5260 return parser.emitError(typesLoc, "requires vector type");
5261 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
5262 Attribute permMapAttr = result.attributes.get(permMapAttrName);
5263 AffineMap permMap;
5264 if (!permMapAttr) {
5265 if (shapedType.getRank() <
5266 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5267 return parser.emitError(typesLoc,
5268 "expected a custom permutation_map when "
5269 "rank(source) != rank(destination)");
5270 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5271 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5272 } else {
5273 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5274 }
5275 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
5276 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5277 if (!inBoundsAttr) {
5278 result.addAttribute(inBoundsAttrName,
5279 builder.getBoolArrayAttr(
5280 SmallVector<bool>(permMap.getNumResults(), false)));
5281 }
5282 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5283 parser.resolveOperands(indexInfo, indexType, result.operands) ||
5284 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
5285 result.operands))
5286 return failure();
5287 if (hasMask.succeeded()) {
5288 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5289 return parser.emitError(
5290 maskInfo.location, "does not support masks with vector element type");
5291 if (vectorType.getRank() != permMap.getNumResults()) {
5292 return parser.emitError(typesLoc,
5293 "expected the same rank for the vector and the "
5294 "results of the permutation map");
5295 }
5296 // Instead of adding the mask type as an op type, compute it based on the
5297 // vector type and the permutation map (to keep the type signature small).
5298 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5299 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5300 return failure();
5301 }
5302 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5303 builder.getDenseI32ArrayAttr(
5304 {1, static_cast<int32_t>(indexInfo.size()), 1,
5305 static_cast<int32_t>(hasMask.succeeded())}));
5306 return parser.addTypeToList(vectorType, result.types);
5307}
5308
5309LogicalResult TransferReadOp::verify() {
5310 // Consistency of elemental types in source and vector.
5311 ShapedType shapedType = getShapedType();
5312 VectorType vectorType = getVectorType();
5313 VectorType maskType = getMaskType();
5314 auto paddingType = getPadding().getType();
5315 auto permutationMap = getPermutationMap();
5316 VectorType inferredMaskType =
5317 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5318 : VectorType();
5319 auto sourceElementType = shapedType.getElementType();
5320
5321 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5322 return emitOpError("requires ") << shapedType.getRank() << " indices";
5323
5324 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5325 shapedType, vectorType, maskType,
5326 inferredMaskType, permutationMap, getInBounds())))
5327 return failure();
5328
5329 if (auto sourceVectorElementType =
5330 llvm::dyn_cast<VectorType>(sourceElementType)) {
5331 // Source has vector element type.
5332 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5333 if (sourceVectorElementType != paddingType)
5334 return emitOpError(
5335 "requires source element type and padding type to match.");
5336
5337 } else {
5338 // Check that 'paddingType' is valid to store in a vector type.
5339 if (!VectorType::isValidElementType(paddingType))
5340 return emitOpError("requires valid padding vector elemental type");
5341
5342 // Check that padding type and vector element types match.
5343 if (paddingType != sourceElementType)
5344 return emitOpError(
5345 "requires formal padding and source of the same elemental type");
5346 }
5347
5348 return verifyPermutationMap(permutationMap,
5349 [&](Twine t) { return emitOpError(t); });
5350}
5351
5352// MaskableOpInterface methods.
5353
5354/// Returns the mask type expected by this operation. Mostly used for
5355 /// verification purposes. It requires the operation to be vectorized.
5356Type TransferReadOp::getExpectedMaskType() {
5357 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5358}
5359
5360//===----------------------------------------------------------------------===//
5361// TransferReadOp: VectorTransferOpInterface methods.
5362//===----------------------------------------------------------------------===//
5363VectorType TransferReadOp::getVectorType() {
5364 return cast<VectorType>(getVector().getType());
5365}
5366
5367template <typename TransferOp>
5368static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
5369 // TODO: support more aggressive createOrFold on:
5370 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
5371 if (op.getShapedType().isDynamicDim(indicesIdx))
5372 return false;
5373 Value index = op.getIndices()[indicesIdx];
5374 std::optional<int64_t> cstOp = getConstantIntValue(index);
5375 if (!cstOp.has_value())
5376 return false;
5377
5378 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5379 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5380
5381 return cstOp.value() + vectorSize <= sourceSize;
5382}
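// For example (a sketch): reading vector<4xf32> from memref<8xf32> at
// constant index 2 is provably in-bounds (2 + 4 <= 8), while index 5 is not
// (5 + 4 > 8); dynamic dims or non-constant indices conservatively return
// false.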
5383
5384template <typename TransferOp>
5385static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
5386 // TODO: support 0-d corner case.
5387 // TODO: Be less conservative.
5388 if (op.getTransferRank() == 0)
5389 return failure();
5390 AffineMap permutationMap = op.getPermutationMap();
5391 bool changed = false;
5392 SmallVector<bool, 4> newInBounds;
5393 newInBounds.reserve(op.getTransferRank());
5394 // Indices of non-broadcast dims - used when analyzing broadcast dims.
5395 SmallVector<unsigned> nonBcastDims;
5396
5397 // 1. Process non-broadcast dims
5398 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5399 // 1.1. Already marked as in-bounds, nothing to see here.
5400 if (op.isDimInBounds(i)) {
5401 newInBounds.push_back(true);
5402 continue;
5403 }
5404 // 1.2. Currently out-of-bounds, check whether we can statically determine
5405 // it is inBounds.
5406 bool inBounds = false;
5407 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5408 if (dimExpr) {
5409 inBounds = isInBounds(op, /*resultIdx=*/i,
5410 /*indicesIdx=*/dimExpr.getPosition());
5411 nonBcastDims.push_back(i);
5412 }
5413
5414 newInBounds.push_back(inBounds);
5415 // We commit the pattern if it is "more inbounds".
5416 changed |= inBounds;
5417 }
5418
5419 // 2. Handle broadcast dims
5420 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5421 // "in bounds" as well.
5422 bool allNonBcastDimsInBounds = llvm::all_of(
5423 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5424 if (allNonBcastDimsInBounds) {
5425 for (size_t idx : permutationMap.getBroadcastDims()) {
5426 changed |= !newInBounds[idx];
5427 newInBounds[idx] = true;
5428 }
5429 }
5430
5431 if (!changed)
5432 return failure();
5433 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5434 OpBuilder b(op.getContext());
5435 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5436 return success();
5437}
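// Example of the fold above (a sketch; names are hypothetical):
//
//   %r = vector.transfer_read %mem[%c2], %pad : memref<8xf32>, vector<4xf32>
//
// starts out with in_bounds = [false]; since 2 + 4 <= 8 can be established
// statically, the attribute is updated in place to in_bounds = [true].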
5438
5439template <typename TransferOp>
5440static LogicalResult foldTransferFullMask(TransferOp op) {
5441 auto mask = op.getMask();
5442 if (!mask)
5443 return failure();
5444
5445 if (getMaskFormat(mask) != MaskFormat::AllTrue)
5446 return failure();
5447
5448 op.getMaskMutable().clear();
5449 return success();
5450}
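// Example (a sketch): a mask that is statically all-true, e.g.
//
//   %mask = vector.constant_mask [4] : vector<4xi1>
//   %r = vector.transfer_read %mem[%c0], %pad, %mask
//       : memref<8xf32>, vector<4xf32>
//
// masks nothing out, so the fold simply drops the mask operand.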
5451
5452/// ```
5453/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5454/// : vector<1x4xf32>, tensor<4x4xf32>
5455/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5456/// : tensor<4x4xf32>, vector<1x4xf32>
5457/// ```
5458/// -> Folds into
5459/// ```
5460/// %v0
5461/// ```
5462static Value foldRAW(TransferReadOp readOp) {
5463 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5464 return {};
5465 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5466 while (defWrite) {
5467 if (checkSameValueRAW(defWrite, readOp))
5468 return defWrite.getVector();
5469 if (!isDisjointTransferIndices(
5470 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5471 cast<VectorTransferOpInterface>(readOp.getOperation())))
5472 break;
5473 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5474 }
5475 return {};
5476}
5477
5478OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5479 if (Value vec = foldRAW(*this))
5480 return vec;
5481 /// transfer_read(memrefcast) -> transfer_read
5482 if (succeeded(foldTransferInBoundsAttribute(*this)))
5483 return getResult();
5484 if (succeeded(foldTransferFullMask(*this)))
5485 return getResult();
5486 if (succeeded(memref::foldMemRefCast(*this)))
5487 return getResult();
5488 if (succeeded(tensor::foldTensorCast(*this)))
5489 return getResult();
5490 return OpFoldResult();
5491}
5492
5493std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5494 return llvm::to_vector<4>(getVectorType().getShape());
5495}
5496
5497void TransferReadOp::getEffects(
5498 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5499 &effects) {
5500 if (llvm::isa<MemRefType>(getShapedType()))
5501 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5502 SideEffects::DefaultResource::get());
5503}
5504
5505Speculation::Speculatability TransferReadOp::getSpeculatability() {
5506 if (hasPureTensorSemantics())
5507 return Speculation::Speculatable;
5508 return Speculation::NotSpeculatable;
5509}
5510
5511 /// Given a projected permutation, invert the affine map, making the unused
5512 /// dims 0 in the result.
5513static AffineMap inverseWithUnusedDims(AffineMap map) {
5514 assert(map.isProjectedPermutation() &&
5515 "expected a projected permutation map");
5516 SmallVector<AffineExpr> results(map.getNumInputs(),
5517 getAffineConstantExpr(0, map.getContext()));
5518 for (auto [idx, result] : llvm::enumerate(map.getResults())) {
5519 // We should only have dim exprs because this is a projected permutation.
5520 int64_t pos = cast<AffineDimExpr>(result).getPosition();
5521 results[pos] = getAffineDimExpr(idx, map.getContext());
5522 }
5523 return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
5524 results, map.getContext());
5525}
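// For example (a sketch): inverting affine_map<(d0, d1, d2) -> (d2, d1)>
// yields affine_map<(d0, d1) -> (0, d1, d0)>: result #0 (d2) sends input d0
// back to position 2, result #1 (d1) sends input d1 back to position 1, and
// the unused source dim d0 becomes the constant 0.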
5526
5527namespace {
5528 /// Store to load forwarding for transfer operations with permutation maps.
5529 /// Even if the permutation maps are different, we can still propagate the store
5530/// into the load if the size of the dimensions read and written match. Then we
5531/// can replace the transfer_read + transfer_write by vector.broadcast and
5532/// vector.transpose.
5533/// Example:
5534/// ```
5535/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5536/// {in_bounds = [true, true],
5537/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5538/// vector<4x1xf32>, tensor<4x4x4xf32>
5539/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5540/// {in_bounds = [true, true, true, true],
5541/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5542/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5543/// ```
5544/// To:
5545/// ```
5546/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5547/// %r = vector.transpose %0, [3, 0, 2, 1] :
5548/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5549/// ```
5550struct TransferReadAfterWriteToBroadcast
5551 : public OpRewritePattern<TransferReadOp> {
5552 using Base::Base;
5553
5554 LogicalResult matchAndRewrite(TransferReadOp readOp,
5555 PatternRewriter &rewriter) const override {
5556 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5557 if (!defWrite)
5558 return failure();
5559 // Bail if we need an alias analysis.
5560 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5561 return failure();
5562 // Bail in the masked case (too complex atm and needed to properly account
5563 // for padding).
5564 if (readOp.getMask() || defWrite.getMask())
5565 return failure();
5566 // If the indices are not the same, a shift may be required; bail.
5567 if (readOp.getIndices() != defWrite.getIndices())
5568 return failure();
5569 // Bail if we need a bounds analysis.
5570 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5571 return failure();
5572 // TODO: If the written transfer chunk is a superset of the read transfer
5573 // chunk we could do an extract_strided_slice.
5574 if (readOp.getTransferChunkAccessed() !=
5575 defWrite.getTransferChunkAccessed())
5576 return failure();
5577 // WriteMap: tensor -> w_vec
5578 // ReadMap: tensor -> r_vec
5579 //
5580 // inv(WriteMap): w_vec -> tensor
5581 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5582 AffineMap readMap = readOp.getPermutationMap();
5583 AffineMap writeMap = defWrite.getPermutationMap();
5584 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5585 AffineMap composedMap = readMap.compose(invWriteMap);
5586 // If there are any unused dims in the composedMap, we have to drop some
5587 // unit dims from the written vector before we can do transpose(broadcast).
5588 // TODO: Support this case.
5589 if (getUnusedDimsBitVector(composedMap).any())
5590 return failure();
5591 // readVec = transpose(broadcast(writeVec))
5592 //
5593 // Build a transpose permutation for the above transpose operation.
5594 //
5595 // Treat the composed map as having extra leading dimensions which are
5596 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5597 // dimensions.
5598 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5599 int64_t numBroadcastedDims = broadcastedDims.size();
5600 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5601 invPerm.resize(composedMap.getNumResults());
5602 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5603 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5604 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5605 invPerm[effectiveDim] = idx;
5606 }
5607 }
5608 // Applying the inverse permutation on the readVecTy will give us the
5609 // broadcast result type.
5610 VectorType readVecTy = readOp.getVectorType();
5611 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5612 auto broadcastedVecTy =
5613 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5614 readVecTy.getElementType(),
5615 applyPermutation(readVecTy.getScalableDims(), invPerm));
5616 // Build the transpose(broadcast) transformation.
5617 Value vec = defWrite.getVector();
5618 Location loc = readOp.getLoc();
5619 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5620 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5621 return success();
5622 }
5623};
5624} // namespace
5625
5626void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5627 MLIRContext *context) {
5628 results.add<TransferReadAfterWriteToBroadcast>(context);
5629}
5630
5631FailureOr<std::optional<SmallVector<Value>>>
5632TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5633 if (!hasPureBufferSemantics())
5634 return failure();
5636 getResult());
5637}
5638
5639//===----------------------------------------------------------------------===//
5640// TransferWriteOp
5641//===----------------------------------------------------------------------===//
5642
5643/// 1. Builder with type inference.
5644void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5645 Value vector, Value dest, ValueRange indices,
5646 AffineMapAttr permutationMapAttr,
5647 /*optional*/ Value mask,
5648 /*optional*/ ArrayAttr inBoundsAttr) {
5649 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5650 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5651 mask, inBoundsAttr);
5652}
5653
5654/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5655void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5656 Value vector, Value dest, ValueRange indices,
5657 AffineMapAttr permutationMapAttr,
5658 /*optional*/ ArrayAttr inBoundsAttr) {
5659 build(builder, result, vector, dest, indices, permutationMapAttr,
5660 /*mask=*/Value(), inBoundsAttr);
5661}
5662
5663/// 3. Builder with type inference that sets an empty mask (variant without
5664/// attrs)
5665void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5666 Value vector, Value dest, ValueRange indices,
5667 AffineMap permutationMap,
5668 std::optional<ArrayRef<bool>> inBounds) {
5669 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5670 auto inBoundsAttr =
5671 (inBounds && !inBounds.value().empty())
5672 ? builder.getBoolArrayAttr(inBounds.value())
5673 : builder.getBoolArrayAttr(SmallVector<bool>(
5674 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5675 build(builder, result, vector, dest, indices, permutationMapAttr,
5676 /*mask=*/Value(), inBoundsAttr);
5677}
5678
5679/// 4. Builder with type inference that sets an empty mask and sets permutation
5680/// map to 'getMinorIdentityMap'.
5681void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5682 Value vector, Value dest, ValueRange indices,
5683 std::optional<ArrayRef<bool>> inBounds) {
5684 auto vectorType = llvm::cast<VectorType>(vector.getType());
5685 AffineMap permutationMap = getTransferMinorIdentityMap(
5686 llvm::cast<ShapedType>(dest.getType()), vectorType);
5687 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5688}
5689
5690ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5691 OperationState &result) {
5692 auto &builder = parser.getBuilder();
5693 SMLoc typesLoc;
5694 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5695 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5696 SmallVector<Type, 2> types;
5697 OpAsmParser::UnresolvedOperand maskInfo;
5698 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5699 parser.parseOperand(sourceInfo) ||
5700 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5701 return failure();
5702 ParseResult hasMask = parser.parseOptionalComma();
5703 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5704 return failure();
5705 if (parser.parseOptionalAttrDict(result.attributes) ||
5706 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5707 return failure();
5708 if (types.size() != 2)
5709 return parser.emitError(typesLoc, "requires two types");
5710 auto indexType = builder.getIndexType();
5711 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5712 if (!vectorType)
5713 return parser.emitError(typesLoc, "requires vector type");
5714 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5715 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5716 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5717 auto permMapAttrName =
5718 TransferWriteOp::getPermutationMapAttrName(result.name);
5719 auto permMapAttr = result.attributes.get(permMapAttrName);
5720 AffineMap permMap;
5721 if (!permMapAttr) {
5722 if (shapedType.getRank() <
5723 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5724 return parser.emitError(typesLoc,
5725 "expected a custom permutation_map when "
5726 "rank(source) != rank(destination)");
5727 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5728 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5729 } else {
5730 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5731 }
5732 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5733 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5734 if (!inBoundsAttr) {
5735 result.addAttribute(inBoundsAttrName,
5736 builder.getBoolArrayAttr(
5737 SmallVector<bool>(permMap.getNumResults(), false)));
5738 }
5739 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5740 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5741 parser.resolveOperands(indexInfo, indexType, result.operands))
5742 return failure();
5743 if (hasMask.succeeded()) {
5744 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5745 return parser.emitError(
5746 maskInfo.location, "does not support masks with vector element type");
5747 if (vectorType.getRank() != permMap.getNumResults()) {
5748 return parser.emitError(typesLoc,
5749 "expected the same rank for the vector and the "
5750 "results of the permutation map");
5751 }
5752 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5753 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5754 return failure();
5755 }
5756 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5757 builder.getDenseI32ArrayAttr(
5758 {1, 1, static_cast<int32_t>(indexInfo.size()),
5759 static_cast<int32_t>(hasMask.succeeded())}));
5760 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5761 parser.addTypeToList(shapedType, result.types));
5762}
5763
5764void TransferWriteOp::print(OpAsmPrinter &p) {
5765 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5766 if (getMask())
5767 p << ", " << getMask();
5768 printTransferAttrs(p, *this);
5769 p << " : " << getVectorType() << ", " << getShapedType();
5770}
5771
5772LogicalResult TransferWriteOp::verify() {
5773 // Consistency of elemental types in shape and vector.
5774 ShapedType shapedType = getShapedType();
5775 VectorType vectorType = getVectorType();
5776 VectorType maskType = getMaskType();
5777 auto permutationMap = getPermutationMap();
5778 VectorType inferredMaskType =
5779 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5780 : VectorType();
5781
5782 if (llvm::size(getIndices()) != shapedType.getRank())
5783 return emitOpError("requires ") << shapedType.getRank() << " indices";
5784
5785 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5786 // as the semantics is unclear. This can be revisited later if necessary.
5787 if (hasBroadcastDim())
5788 return emitOpError("should not have broadcast dimensions");
5789
5790 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5791 shapedType, vectorType, maskType,
5792 inferredMaskType, permutationMap, getInBounds())))
5793 return failure();
5794
5795 return verifyPermutationMap(permutationMap,
5796 [&](Twine t) { return emitOpError(t); });
5797}
5798
5799//===----------------------------------------------------------------------===//
5800// TransferWriteOp: MaskableOpInterface methods.
5801//===----------------------------------------------------------------------===//
5802
5803/// Returns the mask type expected by this operation. Mostly used for
5804/// verification purposes.
5805Type TransferWriteOp::getExpectedMaskType() {
5806 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5807}
5808
5809//===----------------------------------------------------------------------===//
5810// TransferWriteOp: VectorTransferOpInterface methods.
5811//===----------------------------------------------------------------------===//
5812Value TransferWriteOp::getVector() { return getOperand(0); }
5813VectorType TransferWriteOp::getVectorType() {
5814 return cast<VectorType>(getValueToStore().getType());
5815}
5816
5817//===----------------------------------------------------------------------===//
5818// TransferWriteOp: fold methods.
5819//===----------------------------------------------------------------------===//
5820/// Fold:
5821/// ```
5822/// %t1 = ...
5823/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5824/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5825/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5826/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5827/// ```
5828///
5829/// into:
5830///
5831/// ```
5832/// %t0
5833/// ```
5834///
5835/// The producer of t1 may or may not be DCE'd depending on whether it is a
5836/// block argument or has side effects.
5837static LogicalResult foldReadInitWrite(TransferWriteOp write,
5838 ArrayRef<Attribute>,
5839 SmallVectorImpl<OpFoldResult> &results) {
5840 // TODO: support 0-d corner case.
5841 if (write.getTransferRank() == 0)
5842 return failure();
5843 auto rankedTensorType =
5844 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5845 // If not operating on tensors, bail.
5846 if (!rankedTensorType)
5847 return failure();
5848 // If no read, bail.
5849 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5850 if (!read)
5851 return failure();
5852 // TODO: support 0-d corner case.
5853 if (read.getTransferRank() == 0)
5854 return failure();
5855 // For now, only accept minor identity maps. Future: accept maps whose composition is a minor identity.
5856 if (!read.getPermutationMap().isMinorIdentity() ||
5857 !write.getPermutationMap().isMinorIdentity())
5858 return failure();
5859 // Bail on mismatching ranks.
5860 if (read.getTransferRank() != write.getTransferRank())
5861 return failure();
5862 // Bail on potential out-of-bounds accesses.
5863 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5864 return failure();
5865 // Masked transfers have padding/select semantics and are not identity folds.
5866 if (read.getMask() || write.getMask())
5867 return failure();
5868 // Tensor types must be the same.
5869 if (read.getBase().getType() != rankedTensorType)
5870 return failure();
5871 // Vector types must be the same.
5872 if (read.getVectorType() != write.getVectorType())
5873 return failure();
5874 // Vector and Tensor shapes must match.
5875 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5876 return failure();
5877 // Bail if any read or write index is nonzero.
5878 auto isNotConstantZero = [](Value v) {
5879 auto cstOp = getConstantIntValue(v);
5880 return !cstOp.has_value() || cstOp.value() != 0;
5881 };
5882 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5883 llvm::any_of(write.getIndices(), isNotConstantZero))
5884 return failure();
5885 // Success.
5886 results.push_back(read.getBase());
5887 return success();
5888}
5889
5890static bool checkSameValueWAR(vector::TransferReadOp read,
5891 vector::TransferWriteOp write) {
5892 return read.getBase() == write.getBase() &&
5893 read.getIndices() == write.getIndices() &&
5894 read.getPermutationMap() == write.getPermutationMap() &&
5895 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5896 !write.getMask();
5897}
5898/// Fold a write-after-read: a transfer_write that stores back a value just read:
5899/// ```
5900/// %t0 = ...
5901/// %v = vector.transfer_read %t0[%c0...] :
5902/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5903/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5904/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5905/// ```
5906///
5907/// into:
5908///
5909/// ```
5910/// %t0
5911/// ```
5912static LogicalResult foldWAR(TransferWriteOp write,
5913 SmallVectorImpl<OpFoldResult> &results) {
5914 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5915 return failure();
5916 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5917 if (!read)
5918 return failure();
5919
5920 if (!checkSameValueWAR(read, write))
5921 return failure();
5922 results.push_back(read.getBase());
5923 return success();
5924}
5925
5926LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5927 SmallVectorImpl<OpFoldResult> &results) {
5928 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5929 return success();
5930 if (succeeded(foldWAR(*this, results)))
5931 return success();
5932 if (succeeded(foldTransferInBoundsAttribute(*this)))
5933 return success();
5934 if (succeeded(foldTransferFullMask(*this)))
5935 return success();
5936 return memref::foldMemRefCast(*this);
5937}
5938
5939//===----------------------------------------------------------------------===//
5940// TransferWriteOp: other methods.
5941//===----------------------------------------------------------------------===//
5942std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5943 return llvm::to_vector<4>(getVectorType().getShape());
5944}
5945
5946void TransferWriteOp::getEffects(
5947 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5948 &effects) {
5949 if (llvm::isa<MemRefType>(getShapedType()))
5950 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5951 SideEffects::DefaultResource::get());
5952}
5953
5954Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5955 if (hasPureTensorSemantics())
5956 return Speculation::Speculatable;
5957 return Speculation::NotSpeculatable;
5958}
5959
5960namespace {
5961/// Remove a dead transfer write from the SSA chain so that it can be eliminated
5962/// by DCE.
5963/// ```
5964/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5965/// : vector<1x4xf32>, tensor<4x4xf32>
5966/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5967/// : vector<1x4xf32>, tensor<4x4xf32>
5968/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5969/// : vector<1x4xf32>, tensor<4x4xf32>
5970/// ```
5971///
5972/// into:
5973///
5974/// ```
5975/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5976/// : vector<1x4xf32>, tensor<4x4xf32>
5977/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5978/// : vector<1x4xf32>, tensor<4x4xf32>
5979/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5980/// : vector<1x4xf32>, tensor<4x4xf32>
5981/// ```
5982///
5983/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5984/// any other uses.
5985class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5986public:
5987 using Base::Base;
5988 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5989 PatternRewriter &rewriter) const override {
5990 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5991 return failure();
5992 vector::TransferWriteOp writeToModify = writeOp;
5993
5994 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5995 while (defWrite) {
5996 if (checkSameValueWAW(writeOp, defWrite)) {
5997 rewriter.modifyOpInPlace(writeToModify, [&]() {
5998 writeToModify.getBaseMutable().assign(defWrite.getBase());
5999 });
6000 return success();
6001 }
6002 if (!isDisjointTransferIndices(
6003 cast<VectorTransferOpInterface>(defWrite.getOperation()),
6004 cast<VectorTransferOpInterface>(writeOp.getOperation())))
6005 break;
6006 // If the previous write op doesn't have any other use, we can safely look
6007 // at the previous store to see if it can be removed.
6008 if (!defWrite->hasOneUse())
6009 break;
6010 writeToModify = defWrite;
6011 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
6012 }
6013 return failure();
6014 }
6015};
6016
6017/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
6018/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
6019/// overwritten and inserted into another tensor. After this rewrite, the
6020/// operations bufferize in-place since all of them work on the same slice.
6021///
6022/// For example:
6023/// ```mlir
6024/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
6025/// : vector<8x16xf32>, tensor<8x16xf32>
6026/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
6027/// : tensor<8x16xf32> to tensor<?x?xf32>
6028/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6029/// : tensor<?x?xf32> into tensor<27x37xf32>
6030/// ```
6031/// folds to
6032/// ```mlir
6033/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6034/// : tensor<27x37xf32> to tensor<?x?xf32>
6035/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
6036/// : vector<8x16xf32>, tensor<?x?xf32>
6037/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6038/// : tensor<?x?xf32> into tensor<27x37xf32>
6039/// ```
6040struct SwapExtractSliceOfTransferWrite
6041 : public OpRewritePattern<tensor::InsertSliceOp> {
6042public:
6043 using Base::Base;
6044
6045 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6046 PatternRewriter &rewriter) const override {
6047 if (!insertOp.hasUnitStride())
6048 return failure();
6049 auto extractOp =
6050 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6051 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6052 return failure();
6053 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6054 if (!transferOp || !transferOp->hasOneUse())
6055 return failure();
6056
6057 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
6058 // rank-reducing.
6059 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6060 return rewriter.notifyMatchFailure(insertOp,
6061 "use-def chain is rank-reducing");
6062 }
6063
6064 // Fail if tensor::ExtractSliceOp has non-zero offset.
6065 if (!extractOp.hasZeroOffset()) {
6066 return rewriter.notifyMatchFailure(insertOp,
6067 "ExtractSliceOp has non-zero offset");
6068 }
6069
6070 // Fail if vector::TransferWriteOp has non-zero offset.
6071 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6072 return getConstantIntValue(value) == static_cast<int64_t>(0);
6073 })) {
6074 return rewriter.notifyMatchFailure(insertOp,
6075 "TranferWriteOp has non-zero offset");
6076 }
6077
6078 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
6079 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6080 return rewriter.notifyMatchFailure(
6081 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
6082 }
6083
6084 for (auto [insertSize, extractSize] :
6085 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6086 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
6087 return rewriter.notifyMatchFailure(
6088 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
6089 }
6090 }
6091
6092 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
6093 assert(transferOp.getVectorType().hasStaticShape() &&
6094 "expected vector to have a static shape");
6095 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
6096 SmallVector<int64_t> resultShape = applyPermutationMap(
6097 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6098 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
6099 return rewriter.notifyMatchFailure(
6100 insertOp, "TransferWriteOp may not write the full tensor.");
6101 }
6102
6103 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
6104 // Set all in_bounds to false and let the folder infer them.
6105 SmallVector<bool> newInBounds(vectorShape.size(), false);
6106 auto newExtractOp = tensor::ExtractSliceOp::create(
6107 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6108 insertOp.getDest(), insertOp.getMixedOffsets(),
6109 insertOp.getMixedSizes(), insertOp.getMixedStrides());
6110 auto newTransferWriteOp = TransferWriteOp::create(
6111 rewriter, transferOp.getLoc(), transferOp.getVector(),
6112 newExtractOp.getResult(), transferOp.getIndices(),
6113 transferOp.getPermutationMapAttr(),
6114 rewriter.getBoolArrayAttr(newInBounds));
6115 rewriter.modifyOpInPlace(insertOp, [&]() {
6116 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6117 });
6118 return success();
6119 }
6120};
6121
6122} // namespace
6123
6124void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6125 MLIRContext *context) {
6126 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6127}
6128
6129FailureOr<std::optional<SmallVector<Value>>>
6130TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
6131 if (!hasPureBufferSemantics())
6132 return failure();
6133 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6134 ValueRange());
6135}
6136
6137//===----------------------------------------------------------------------===//
6138// LoadOp
6139//===----------------------------------------------------------------------===//
6140
6141static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6142 VectorType vecTy,
6143 MemRefType memRefTy) {
6144 // If rank==0 or size==1, this is equivalent to a scalar load/store, so we
6145 // don't need any stride limitations.
6146 if (!vecTy.isScalable() &&
6147 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6148 return success();
6149
6150 if (!memRefTy.isLastDimUnitStride())
6151 return op->emitOpError("most minor memref dim must have unit stride");
6152 return success();
6153}
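// Illustrative example (editorial sketch, not from the upstream tests): with
// the rule above, a multi-element load requires a unit minor stride, e.g.
//   %ok = vector.load %m0[%i] : memref<8xf32>, vector<4xf32>
// verifies, while the same load from a memref<8xf32, strided<[2]>> base is
// rejected with "most minor memref dim must have unit stride".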
6154
6155LogicalResult vector::LoadOp::verify() {
6156 VectorType resVecTy = getVectorType();
6157 MemRefType memRefTy = getMemRefType();
6158
6159 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
6160 return failure();
6161
6162 if (memRefTy.getRank() < resVecTy.getRank())
6163 return emitOpError(
6164 "destination memref has lower rank than the result vector");
6165
6166 // Checks for vector memrefs.
6167 Type memElemTy = memRefTy.getElementType();
6168 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6169 if (memVecTy != resVecTy)
6170 return emitOpError("base memref and result vector types should match");
6171 memElemTy = memVecTy.getElementType();
6172 }
6173
6174 if (resVecTy.getElementType() != memElemTy)
6175 return emitOpError("base and result element types should match");
6176 if (llvm::size(getIndices()) != memRefTy.getRank())
6177 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6178 return success();
6179}
6180
6181OpFoldResult LoadOp::fold(FoldAdaptor) {
6182 if (succeeded(memref::foldMemRefCast(*this)))
6183 return getResult();
6184 return OpFoldResult();
6185}
6186
6187std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6188 return llvm::to_vector<4>(getVectorType().getShape());
6189}
6190
6191FailureOr<std::optional<SmallVector<Value>>>
6192LoadOp::bubbleDownCasts(OpBuilder &builder) {
6193 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6194 getResult());
6195}
6196
6197//===----------------------------------------------------------------------===//
6198// StoreOp
6199//===----------------------------------------------------------------------===//
6200
6201LogicalResult vector::StoreOp::verify() {
6202 VectorType valueVecTy = getVectorType();
6203 MemRefType memRefTy = getMemRefType();
6204
6205 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
6206 return failure();
6207
6208 if (memRefTy.getRank() < valueVecTy.getRank())
6209 return emitOpError("source memref has lower rank than the vector to store");
6210
6211 // Checks for vector memrefs.
6212 Type memElemTy = memRefTy.getElementType();
6213 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6214 if (memVecTy != valueVecTy)
6215 return emitOpError(
6216 "base memref and valueToStore vector types should match");
6217 memElemTy = memVecTy.getElementType();
6218 }
6219
6220 if (valueVecTy.getElementType() != memElemTy)
6221 return emitOpError("base and valueToStore element type should match");
6222 if (llvm::size(getIndices()) != memRefTy.getRank())
6223 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6224 return success();
6225}
6226
6227LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6228 SmallVectorImpl<OpFoldResult> &results) {
6229 return memref::foldMemRefCast(*this);
6230}
6231
6232std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6233 return llvm::to_vector<4>(getVectorType().getShape());
6234}
6235
6236FailureOr<std::optional<SmallVector<Value>>>
6237StoreOp::bubbleDownCasts(OpBuilder &builder) {
6238 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6239 ValueRange());
6240}
6241
6242//===----------------------------------------------------------------------===//
6243// MaskedLoadOp
6244//===----------------------------------------------------------------------===//
6245
6246LogicalResult MaskedLoadOp::verify() {
6247 VectorType maskVType = getMaskVectorType();
6248 VectorType passVType = getPassThruVectorType();
6249 VectorType resVType = getVectorType();
6250 MemRefType memType = getMemRefType();
6251
6252 if (failed(
6253 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6254 return failure();
6255 if (llvm::size(getIndices()) != memType.getRank())
6256 return emitOpError("requires ") << memType.getRank() << " indices";
6257 if (resVType.getShape() != maskVType.getShape())
6258 return emitOpError("expected result shape to match mask shape");
6259 if (resVType != passVType)
6260 return emitOpError("expected pass_thru of same type as result type");
6261 return success();
6262}
6263
6264namespace {
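/// Canonicalization for masked loads whose mask is a compile-time constant
/// (editorial sketch of the two foldable cases):
///   %v = vector.maskedload %base[%i], %all_true, %pass : ...  // -> vector.load
///   %v = vector.maskedload %base[%i], %all_false, %pass : ... // -> %pass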
6265class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
6266public:
6267 using Base::Base;
6268 LogicalResult matchAndRewrite(MaskedLoadOp load,
6269 PatternRewriter &rewriter) const override {
6270 switch (getMaskFormat(load.getMask())) {
6271 case MaskFormat::AllTrue:
6272 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6273 load, load.getType(), load.getBase(), load.getIndices());
6274 return success();
6275 case MaskFormat::AllFalse:
6276 rewriter.replaceOp(load, load.getPassThru());
6277 return success();
6278 case MaskFormat::Unknown:
6279 return failure();
6280 }
6281 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
6282 }
6283};
6284} // namespace
6285
6286void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6287 MLIRContext *context) {
6288 results.add<MaskedLoadFolder>(context);
6289}
6290
6291OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6292 if (succeeded(memref::foldMemRefCast(*this)))
6293 return getResult();
6294 return OpFoldResult();
6295}
6296
6297FailureOr<std::optional<SmallVector<Value>>>
6298MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6299 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6300 getResult());
6301}
6302
6303//===----------------------------------------------------------------------===//
6304// MaskedStoreOp
6305//===----------------------------------------------------------------------===//
6306
6307LogicalResult MaskedStoreOp::verify() {
6308 VectorType maskVType = getMaskVectorType();
6309 VectorType valueVType = getVectorType();
6310 MemRefType memType = getMemRefType();
6311
6312 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6313 "valueToStore")))
6314 return failure();
6315 if (llvm::size(getIndices()) != memType.getRank())
6316 return emitOpError("requires ") << memType.getRank() << " indices";
6317 if (valueVType.getShape() != maskVType.getShape())
6318 return emitOpError("expected valueToStore shape to match mask shape");
6319 return success();
6320}
6321
6322namespace {
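/// Canonicalization for masked stores whose mask is a compile-time constant
/// (editorial note): an all-true mask degenerates to a plain vector.store,
/// and an all-false mask writes nothing, so the op can simply be erased.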
6323class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
6324public:
6325 using Base::Base;
6326 LogicalResult matchAndRewrite(MaskedStoreOp store,
6327 PatternRewriter &rewriter) const override {
6328 switch (getMaskFormat(store.getMask())) {
6329 case MaskFormat::AllTrue:
6330 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6331 store, store.getValueToStore(), store.getBase(), store.getIndices());
6332 return success();
6333 case MaskFormat::AllFalse:
6334 rewriter.eraseOp(store);
6335 return success();
6336 case MaskFormat::Unknown:
6337 return failure();
6338 }
6339 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
6340 }
6341};
6342} // namespace
6343
6344void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6345 MLIRContext *context) {
6346 results.add<MaskedStoreFolder>(context);
6347}
6348
6349LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6350 SmallVectorImpl<OpFoldResult> &results) {
6351 return memref::foldMemRefCast(*this);
6352}
6353
6354FailureOr<std::optional<SmallVector<Value>>>
6355MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6356 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6357 ValueRange());
6358}
6359
6360//===----------------------------------------------------------------------===//
6361// GatherOp
6362//===----------------------------------------------------------------------===//
6363
6364LogicalResult GatherOp::verify() {
6365 VectorType indVType = getIndexVectorType();
6366 VectorType maskVType = getMaskVectorType();
6367 VectorType resVType = getVectorType();
6368 ShapedType baseType = getBaseType();
6369
6370 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6371 return emitOpError("requires base to be a memref or ranked tensor type");
6372
6373 if (failed(
6374 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6375 return failure();
6376 if (llvm::size(getOffsets()) != baseType.getRank())
6377 return emitOpError("requires ") << baseType.getRank() << " indices";
6378 if (resVType.getShape() != indVType.getShape())
6379 return emitOpError("expected result dim to match indices dim");
6380 if (resVType.getShape() != maskVType.getShape())
6381 return emitOpError("expected result dim to match mask dim");
6382 if (resVType != getPassThruVectorType())
6383 return emitOpError("expected pass_thru of same type as result type");
6384 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6385 return emitOpError(
6386 "alignment is only supported for memref bases, not tensor bases");
6387 }
6388 return success();
6389}
6390
6391// MaskableOpInterface methods.
6392
6393/// Returns the mask type expected by this operation. Mostly used for
6394/// verification purposes. It requires the operation to be vectorized.
6395Type GatherOp::getExpectedMaskType() {
6396 auto vecType = this->getIndexVectorType();
6397 return VectorType::get(vecType.getShape(),
6398 IntegerType::get(vecType.getContext(), /*width=*/1),
6399 vecType.getScalableDims());
6400}
6401
6402std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6403 return llvm::to_vector<4>(getVectorType().getShape());
6404}
6405
6406/// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
6407static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6408 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6409 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6410 return failure();
6411
6412 if (indexVec.getDefiningOp<StepOp>())
6413 return success();
6414
6415 DenseIntElementsAttr elements;
6416 if (!matchPattern(indexVec, m_Constant(&elements)))
6417 return failure();
6418
6419 return success(
6420 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6421}
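// Editorial examples: a `vector.step` result or a constant
// `dense<[0, 1, 2, 3]> : vector<4xindex>` qualifies as a zero-based
// contiguous sequence; `dense<[1, 2, 3, 4]>`, higher-rank vectors, and
// scalable vectors do not.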
6422
6423namespace {
6424class GatherFolder final : public OpRewritePattern<GatherOp> {
6425public:
6426 using Base::Base;
6427 LogicalResult matchAndRewrite(GatherOp gather,
6428 PatternRewriter &rewriter) const override {
6429 switch (getMaskFormat(gather.getMask())) {
6430 case MaskFormat::AllTrue:
6431 return failure(); // no unmasked equivalent
6432 case MaskFormat::AllFalse:
6433 rewriter.replaceOp(gather, gather.getPassThru());
6434 return success();
6435 case MaskFormat::Unknown:
6436 return failure();
6437 }
6438 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6439 }
6440};
6441
6442/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6443/// maskedload. Only 1D fixed vectors are supported for now.
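///
/// Editorial sketch (indices %idx = [0, 1, 2, 3]):
///   %g = vector.gather %base[%c0] [%idx], %mask, %pass
/// becomes
///   %g = vector.maskedload %base[%c0], %mask, %pass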
6444class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6445public:
6446 using Base::Base;
6447 LogicalResult matchAndRewrite(GatherOp op,
6448 PatternRewriter &rewriter) const override {
6449 if (!isa<MemRefType>(op.getBase().getType()))
6450 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6451
6452 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6453 return failure();
6454
6455 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6456 op.getOffsets(), op.getMask(),
6457 op.getPassThru());
6458 return success();
6459 }
6460};
6461} // namespace
6462
6463void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6464 MLIRContext *context) {
6465 results.add<GatherFolder, FoldContiguousGather>(context);
6466}
6467
6468FailureOr<std::optional<SmallVector<Value>>>
6469GatherOp::bubbleDownCasts(OpBuilder &builder) {
6470 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6471 getResult());
6472}
6473
6474//===----------------------------------------------------------------------===//
6475// ScatterOp
6476//===----------------------------------------------------------------------===//
6477
6478LogicalResult ScatterOp::verify() {
6479 VectorType indVType = getIndexVectorType();
6480 VectorType maskVType = getMaskVectorType();
6481 VectorType valueVType = getVectorType();
6482 ShapedType baseType = getBaseType();
6483
6484 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6485 return emitOpError("requires base to be a memref or ranked tensor type");
6486
6487 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6488 "valueToStore")))
6489 return failure();
6490 if (llvm::size(getOffsets()) != baseType.getRank())
6491 return emitOpError("requires ") << baseType.getRank() << " indices";
6492 if (valueVType.getShape() != indVType.getShape())
6493 return emitOpError("expected valueToStore dim to match indices dim");
6494 if (valueVType.getShape() != maskVType.getShape())
6495 return emitOpError("expected valueToStore dim to match mask dim");
6496 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6497 return emitOpError(
6498 "alignment is only supported for memref bases, not tensor bases");
6499 }
6500 return success();
6501}
6502namespace {
6503class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6504public:
6505 using Base::Base;
6506 LogicalResult matchAndRewrite(ScatterOp scatter,
6507 PatternRewriter &rewriter) const override {
6508 ShapedType baseType = scatter.getBaseType();
6509 bool isMemRef = isa<MemRefType>(baseType);
6510 if (!isMemRef && !isa<RankedTensorType>(baseType))
6511 return failure();
6512
6513 // Memrefs have no result, so an all-false mask can simply erase the op.
6514 // Tensors carry the updated value, so we must replace uses with the
6515 // original base tensor instead of erasing.
6516 switch (getMaskFormat(scatter.getMask())) {
6517 case MaskFormat::AllTrue:
6518 return failure(); // no unmasked equivalent
6519 case MaskFormat::AllFalse:
6520 if (isMemRef)
6521 rewriter.eraseOp(scatter);
6522 else
6523 rewriter.replaceOp(scatter, scatter.getBase());
6524 return success();
6525 case MaskFormat::Unknown:
6526 return failure();
6527 }
6528 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6529 }
6530};
6531
6532/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6533/// maskedstore. Only 1D fixed vectors are supported for now.
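///
/// Editorial sketch (offsets %idx = [0, 1, 2, 3]):
///   vector.scatter %base[%c0] [%idx], %mask, %value
/// becomes
///   vector.maskedstore %base[%c0], %mask, %value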
6534class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6535public:
6536 using Base::Base;
6537 LogicalResult matchAndRewrite(ScatterOp op,
6538 PatternRewriter &rewriter) const override {
6539 // Fold only for memrefs: the replacement uses maskedstore, which does not
6540 // support tensor bases. Tensor cases intentionally bail out.
6541 if (!isa<MemRefType>(op.getBase().getType()))
6542 return failure();
6543
6544 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6545 return failure();
6546
6547 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6548 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6549 return success();
6550 }
6551};
6552} // namespace
6553
6554void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6555 MLIRContext *context) {
6556 results.add<ScatterFolder, FoldContiguousScatter>(context);
6557}
6558
6559FailureOr<std::optional<SmallVector<Value>>>
6560ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6561 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6562 ValueRange());
6563}
6564
6565//===----------------------------------------------------------------------===//
6566// ExpandLoadOp
6567//===----------------------------------------------------------------------===//
6568
6569LogicalResult ExpandLoadOp::verify() {
6570 VectorType maskVType = getMaskVectorType();
6571 VectorType passVType = getPassThruVectorType();
6572 VectorType resVType = getVectorType();
6573 MemRefType memType = getMemRefType();
6574
6575 if (failed(
6576 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6577 return failure();
6578 if (llvm::size(getIndices()) != memType.getRank())
6579 return emitOpError("requires ") << memType.getRank() << " indices";
6580 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6581 return emitOpError("expected result dim to match mask dim");
6582 if (resVType != passVType)
6583 return emitOpError("expected pass_thru of same type as result type");
6584 return success();
6585}
6586
6587namespace {
6588class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6589public:
6590 using Base::Base;
6591 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6592 PatternRewriter &rewriter) const override {
6593 switch (getMaskFormat(expand.getMask())) {
6594 case MaskFormat::AllTrue:
6595 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6596 expand, expand.getType(), expand.getBase(), expand.getIndices());
6597 return success();
6598 case MaskFormat::AllFalse:
6599 rewriter.replaceOp(expand, expand.getPassThru());
6600 return success();
6601 case MaskFormat::Unknown:
6602 return failure();
6603 }
6604 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6605 }
6606};
6607} // namespace
6608
6609void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6610 MLIRContext *context) {
6611 results.add<ExpandLoadFolder>(context);
6612}
6613
6614FailureOr<std::optional<SmallVector<Value>>>
6615ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6616 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6617 getResult());
6618}
6619
6620//===----------------------------------------------------------------------===//
6621// CompressStoreOp
6622//===----------------------------------------------------------------------===//
6623
6624LogicalResult CompressStoreOp::verify() {
6625 VectorType maskVType = getMaskVectorType();
6626 VectorType valueVType = getVectorType();
6627 MemRefType memType = getMemRefType();
6628
6629 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6630 "valueToStore")))
6631 return failure();
6632 if (llvm::size(getIndices()) != memType.getRank())
6633 return emitOpError("requires ") << memType.getRank() << " indices";
6634 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6635 return emitOpError("expected valueToStore dim to match mask dim");
6636 return success();
6637}
6638
6639namespace {
6640class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6641public:
6642 using Base::Base;
6643 LogicalResult matchAndRewrite(CompressStoreOp compress,
6644 PatternRewriter &rewriter) const override {
6645 switch (getMaskFormat(compress.getMask())) {
6646 case MaskFormat::AllTrue:
6647 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6648 compress, compress.getValueToStore(), compress.getBase(),
6649 compress.getIndices());
6650 return success();
6651 case MaskFormat::AllFalse:
6652 rewriter.eraseOp(compress);
6653 return success();
6654 case MaskFormat::Unknown:
6655 return failure();
6656 }
6657 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6658 }
6659};
6660} // namespace
6661
6662void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6663 MLIRContext *context) {
6664 results.add<CompressStoreFolder>(context);
6665}
6666
6667FailureOr<std::optional<SmallVector<Value>>>
6668CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6669 return mlir::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
6670 ValueRange());
6671}
6672
6673//===----------------------------------------------------------------------===//
6674// ShapeCastOp
6675//===----------------------------------------------------------------------===//
6676
6677void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6678 SetIntRangeFn setResultRanges) {
6679 setResultRanges(getResult(), argRanges.front());
6680}
6681
6682std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6683 return llvm::to_vector<4>(getResultVectorType().getShape());
6684}
6685
6686LogicalResult ShapeCastOp::verify() {
6687
6688 VectorType sourceType = getSourceVectorType();
6689 VectorType resultType = getResultVectorType();
6690
6691 // Check that element type is preserved
6692 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6693 "result")))
6694 return failure();
6695
6696 // Check that number of elements is preserved
6697 int64_t sourceNElms = sourceType.getNumElements();
6698 int64_t resultNElms = resultType.getNumElements();
6699 if (sourceNElms != resultNElms) {
6700 return emitOpError() << "has different number of elements at source ("
6701 << sourceNElms << ") and result (" << resultNElms
6702 << ")";
6703 }
6704
6705 // Check that (non-)scalability is preserved
6706 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6707 int64_t resultNScalableDims = resultType.getNumScalableDims();
6708 if (sourceNScalableDims != resultNScalableDims)
6709 return emitOpError() << "has different number of scalable dims at source ("
6710 << sourceNScalableDims << ") and result ("
6711 << resultNScalableDims << ")";
6712
6713 return success();
6714}
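// Editorial examples of the rules above: vector<2x3xf32> -> vector<6xf32> is
// a valid shape_cast (same element type and count); vector<2x3xf32> ->
// vector<8xf32> is rejected (element count), as is vector<[4]xf32> ->
// vector<4xf32> (the number of scalable dims must be preserved).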
6715
6716/// Return true if `transpose` does not permute a pair of non-unit dims.
6717/// By `order preserving` we mean that the flattened versions of the input and
6718/// output vectors are (numerically) identical. In other words `transpose` is
6719/// effectively a shape cast.
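///
/// Editorial examples: on vector<1x4x8xf32>, permutation [1, 0, 2] only moves
/// a unit dim and is order preserving; on vector<2x2xf32>, permutation [1, 0]
/// swaps two non-unit dims and is not.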
6720static bool isOrderPreserving(TransposeOp transpose) {
6721 ArrayRef<int64_t> permutation = transpose.getPermutation();
6722 VectorType sourceType = transpose.getSourceVectorType();
6723 ArrayRef<int64_t> inShape = sourceType.getShape();
6724 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6725 auto isNonScalableUnitDim = [&](int64_t dim) {
6726 return inShape[dim] == 1 && !inDimIsScalable[dim];
6727 };
6728 int64_t current = 0;
6729 for (auto p : permutation) {
6730 if (!isNonScalableUnitDim(p)) {
6731 if (p < current) {
6732 return false;
6733 }
6734 current = p;
6735 }
6736 }
6737 return true;
6738}
6739
6740OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6741
6742 VectorType resultType = getType();
6743
6744 // No-op shape cast.
6745 if (getSource().getType() == resultType)
6746 return getSource();
6747
6748 // shape_cast(shape_cast(x)) -> shape_cast(x)
6749 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6750 setOperand(precedingShapeCast.getSource());
6751 return getResult();
6752 }
6753
6754 // shape_cast(transpose(x)) -> shape_cast(x)
6755 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6756 if (isOrderPreserving(transpose)) {
6757 setOperand(transpose.getVector());
6758 return getResult();
6759 }
6760 return {};
6761 }
6762
6763 // Y = shape_cast(broadcast(X))
6764 // -> X, if X and Y have same type
6765 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6766 if (bcastOp.getSourceType() == resultType)
6767 return bcastOp.getSource();
6768 }
6769
6770 // shape_cast(constant) -> constant
6771 if (auto denseAttr =
6772 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6773 return denseAttr.reshape(getType());
6774
6775 // shape_cast(poison) -> poison
6776 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
6777 return ub::PoisonAttr::get(getContext());
6778
6779 return {};
6780}
6781
6782namespace {
6783
6784/// Helper function that computes a new vector type based on the input vector
6785/// type by removing the trailing one dims:
6786///
6787/// vector<4x1x1xi1> --> vector<4x1xi1>
6788///
6789static VectorType trimTrailingOneDims(VectorType oldType) {
6790 ArrayRef<int64_t> oldShape = oldType.getShape();
6791 ArrayRef<int64_t> newShape = oldShape;
6792
6793 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6794 ArrayRef<bool> newScalableDims = oldScalableDims;
6795
6796 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6797 newShape = newShape.drop_back(1);
6798 newScalableDims = newScalableDims.drop_back(1);
6799 }
6800
6801 // Make sure we have at least 1 dimension.
6802 // TODO: Add support for 0-D vectors.
6803 if (newShape.empty()) {
6804 newShape = oldShape.take_back();
6805 newScalableDims = oldScalableDims.take_back();
6806 }
6807
6808 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6809}
6810
6811/// Folds qualifying shape_cast(create_mask) into a new create_mask
6812///
6813/// Looks at `vector.shape_cast` ops that simply "drop" trailing unit
6814/// dimensions. If the input vector comes from a `vector.create_mask` (or a
6815/// `vector.constant_mask`) whose dropped mask dimension values are all 1
6816/// (e.g. `%c1` below), then it is safe to fold the shape_cast into the mask op.
6817///
6818/// BEFORE:
6819/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6820/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6821/// AFTER:
6822/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6823class ShapeCastCreateMaskFolderTrailingOneDim final
6824 : public OpRewritePattern<ShapeCastOp> {
6825public:
6826 using Base::Base;
6827
6828 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6829 PatternRewriter &rewriter) const override {
6830 Value shapeOpSrc = shapeOp->getOperand(0);
6831 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6832 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6833 if (!createMaskOp && !constantMaskOp)
6834 return failure();
6835
6836 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6837 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6838
6839 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6840 if (newVecType != shapeOpResTy)
6841 return failure();
6842
6843 auto numDimsToDrop =
6844 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6845
6846 // No unit dims to drop
6847 if (!numDimsToDrop)
6848 return failure();
6849
6850 if (createMaskOp) {
6851 auto maskOperands = createMaskOp.getOperands();
6852 auto numMaskOperands = maskOperands.size();
6853
6854 // Check every mask dim size to see whether it can be dropped
6855 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6856 --i) {
6857 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6858 if (!constant || (constant.value() != 1))
6859 return failure();
6860 }
6861 SmallVector<Value> newMaskOperands =
6862 maskOperands.drop_back(numDimsToDrop);
6863
6864 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6865 newMaskOperands);
6866 return success();
6867 }
6868
6869 if (constantMaskOp) {
6870 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6871 auto numMaskOperands = maskDimSizes.size();
6872
6873 // Check every mask dim size to see whether it can be dropped
6874 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6875 --i) {
6876 if (maskDimSizes[i] != 1)
6877 return failure();
6878 }
6879
6880 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6881 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6882 newMaskOperands);
6883 return success();
6884 }
6885
6886 return failure();
6887 }
6888};
6889
6890// vector.broadcast has two distinct semantic modes: duplication across leading
6891// dimensions, and stretching across inner dimensions. This helper returns the
6892// product of the inner-dimension stretching factors.
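// Editorial worked example: broadcasting vector<1x3x1xf32> to
// vector<2x3x4xf32> stretches dim 0 by 2 and dim 2 by 4, giving a stretching
// factor of 2 * 4 = 8; a broadcast that only duplicates across new leading
// dims has factor 1.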
6893int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6894 ArrayRef<int64_t> dstShape) {
6895 int64_t stretchingFactor = 1;
6896 int numLeadingDims = dstShape.size() - srcShape.size();
6897 for (int i = 0, e = srcShape.size(); i < e; i++) {
6898 int64_t dstDim = dstShape[numLeadingDims + i];
6899 if (srcShape[i] == 1 && dstDim != 1) {
6900 stretchingFactor *= dstDim;
6901 }
6902 }
6903 return stretchingFactor;
6904}
6905
6906/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
6907class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6908public:
6909 using Base::Base;
6910
6911 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6912 PatternRewriter &rewriter) const override {
6913 auto broadcastOp =
6914 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6915 if (!broadcastOp)
6916 return failure();
6917
6918 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6919 bool srcIsScalar = !srcVectorType;
6920
6921 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6922 // Example
6923 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6924 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6925 // to
6926 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6927 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6928 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6929 ArrayRef<int64_t> srcShape =
6930 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6931 ArrayRef<int64_t> broadcastShape =
6932 broadcastOp.getResultVectorType().getShape();
6933
6934 if (!srcIsScalar) {
6935 if (isBroadcastableTo(srcVectorType, dstVectorType) !=
6936 BroadcastableToResult::Success) {
6937 return failure();
6938 }
6939 // Avoid folding if this would result in switching between the two
6940 // distinct semantic modes of vector.broadcast (duplication vs
6941 // stretching). See https://github.com/llvm/llvm-project/issues/190614.
6942 // This is detected by a change in the stretching factor. However if the
6943 // source has a single element, there is no ambiguity.
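// Editorial example of a fold this guard rejects:
// %0 = vector.broadcast %in : vector<2x1xf32> to vector<4x2x3xf32>
// %1 = vector.shape_cast %0 : vector<4x2x3xf32> to vector<12x2x1xf32>
// Rewriting to `vector.broadcast %in : vector<2x1xf32> to vector<12x2x1xf32>`
// would turn a stretch by 3 into pure duplication and reorder elements.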
6944 if (srcVectorType.getNumElements() != 1) {
6945 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6946 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6947 return failure();
6948 }
6949 }
6950 }
6951
6952 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
6953 broadcastOp.getSource());
6954 return success();
6955 }
6956};
6957
6958/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
6959///
6960/// BEFORE:
6961/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
6962/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
6963/// AFTER:
6964/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
6965///
6966/// Note: this transformation is implemented as an OpRewritePattern, not as a
6967/// fold, because we have to create a new FromElementsOp with an updated result
6968/// type. This cannot be done with a fold, because fold cannot create new ops
6969/// and the existing FromElementsOp result type differs from the ShapeCastOp
6970/// result type. Mutating the FromElementsOp (not root op) would violate the
6971/// fold contract and break other users.
6972class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
6973public:
6974 using Base::Base;
6975
6976 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6977 PatternRewriter &rewriter) const override {
6978 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
6979 if (!fromElements)
6980 return failure();
6981
6982 rewriter.replaceOpWithNewOp<FromElementsOp>(
6983 shapeCastOp, shapeCastOp.getResultVectorType(),
6984 fromElements.getElements());
6985 return success();
6986 }
6987};
6988
6989} // namespace
6990
6991void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6992 MLIRContext *context) {
6993 results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
6994 FoldShapeCastOfFromElements>(context);
6995}
6996
6997//===----------------------------------------------------------------------===//
6998// VectorBitCastOp
6999//===----------------------------------------------------------------------===//
7000
7001LogicalResult BitCastOp::verify() {
7002 auto sourceVectorType = getSourceVectorType();
7003 auto resultVectorType = getResultVectorType();
7004
7005 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7006 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7007 return emitOpError("dimension size mismatch at: ") << i;
7008 }
7009
7010 DataLayout dataLayout = DataLayout::closest(*this);
7011 auto sourceElementBits =
7012 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
7013 auto resultElementBits =
7014 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
7015
7016 if (sourceVectorType.getRank() == 0) {
7017 if (sourceElementBits != resultElementBits)
7018 return emitOpError("source/result bitwidth of the 0-D vector element "
7019 "types must be equal");
7020 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
7021 resultElementBits * resultVectorType.getShape().back()) {
7022 return emitOpError(
7023 "source/result bitwidth of the minor 1-D vectors must be equal");
7024 }
7025
7026 return success();
7027}
7028
7029OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
7030 // Nop cast.
7031 if (getSource().getType() == getResult().getType())
7032 return getSource();
7033
7034 // Canceling bitcasts.
7035 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
7036 if (getResult().getType() == otherOp.getSource().getType())
7037 return otherOp.getSource();
7038
7039 setOperand(otherOp.getSource());
7040 return getResult();
7041 }
7042
7043 Attribute sourceConstant = adaptor.getSource();
7044 if (!sourceConstant)
7045 return {};
7046
7047 Type srcElemType = getSourceVectorType().getElementType();
7048 Type dstElemType = getResultVectorType().getElementType();
7049
7050 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
7051 if (floatPack.isSplat()) {
7052 auto splat = floatPack.getSplatValue<FloatAttr>();
7053
7054 // Casting fp16 into fp32.
7055 if (srcElemType.isF16() && dstElemType.isF32()) {
7056 uint32_t bits = static_cast<uint32_t>(
7057 splat.getValue().bitcastToAPInt().getZExtValue());
7058 // Duplicate the 16-bit pattern.
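// (Editorial note: e.g. an f16 splat of 1.0 = 0x3C00 becomes the f32 bit
// pattern 0x3C003C00.)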
7059 bits = (bits << 16) | (bits & 0xffff);
7060 APInt intBits(32, bits);
7061 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
7062 return DenseElementsAttr::get(getResultVectorType(), floatBits);
7063 }
7064 }
7065 }
7066
7067 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
7068 if (intPack.isSplat()) {
7069 auto splat = intPack.getSplatValue<IntegerAttr>();
7070
7071 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
7072 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
7073 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
7074
7075 // Casting to a larger integer bit width.
7076 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
7077 APInt intBits = splat.getValue().zext(dstBitWidth);
7078
7079 // Duplicate the lower width element.
7080 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
7081 intBits = (intBits << srcBitWidth) | intBits;
7082 return DenseElementsAttr::get(getResultVectorType(), intBits);
7083 }
7084 }
7085 }
7086 }
7087
7088 return {};
7089}
7090
7091//===----------------------------------------------------------------------===//
7092// TypeCastOp
7093//===----------------------------------------------------------------------===//
7094
7095static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7096 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7097 SmallVector<int64_t, 8> res(memRefType.getShape());
7098 if (vectorType)
7099 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7100 return res;
7101}
7102
7103/// Build the canonical memRefType with a single vector.
7104/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
7105void TypeCastOp::build(OpBuilder &builder, OperationState &result,
7106 Value source) {
7107 result.addOperands(source);
7108 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
7109 VectorType vectorType =
7110 VectorType::get(extractShape(memRefType),
7111 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
7112 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
7113 memRefType.getMemorySpace()));
7114}
7115
7116LogicalResult TypeCastOp::verify() {
7117 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
7118 if (!canonicalType.getLayout().isIdentity())
7119 return emitOpError("expects operand to be a memref with identity layout");
7120 if (!getResultMemRefType().getLayout().isIdentity())
7121 return emitOpError("expects result to be a memref with identity layout");
7122 if (getResultMemRefType().getMemorySpace() !=
7123 getMemRefType().getMemorySpace())
7124 return emitOpError("expects result in same memory space");
7125
7126 auto sourceType = getMemRefType();
7127 auto resultType = getResultMemRefType();
7128 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
7129 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
7130 return emitOpError(
7131 "expects result and operand with same underlying scalar type: ")
7132 << resultType;
7133 if (extractShape(sourceType) != extractShape(resultType))
7134 return emitOpError(
7135 "expects concatenated result and operand shapes to be equal: ")
7136 << resultType;
7137 return success();
7138}
7139
7140//===----------------------------------------------------------------------===//
7141// TransposeOp
7142//===----------------------------------------------------------------------===//
7143
7144void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
7145 Value vector, ArrayRef<int64_t> permutation) {
7146 VectorType vt = llvm::cast<VectorType>(vector.getType());
7147 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7148 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7149 for (unsigned i = 0; i < permutation.size(); ++i) {
7150 transposedShape[i] = vt.getShape()[permutation[i]];
7151 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7152 }
7153
7154 result.addOperands(vector);
7155 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7156 transposedScalableDims));
7157 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
7158 builder.getDenseI64ArrayAttr(permutation));
7159}
7160
7161OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7162 // Eliminate splat constant transpose ops.
7163 if (auto splat =
7164 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7165 return splat.reshape(getResultVectorType());
7166
7167 // Eliminate poison transpose ops.
7168 if (matchPattern(adaptor.getVector(), ub::m_Poison()))
7169 return ub::PoisonAttr::get(getContext());
7170
7171 // Eliminate identity transposes, and more generally any transposes that
7172 // preserve the shape without permuting elements.
7173 //
7174 // Examples of what to fold:
7175 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
7176 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
7177 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
7178 //
7179 // Example of what NOT to fold:
7180 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
7181 //
7182 if (getSourceVectorType() == getResultVectorType() &&
7183 isOrderPreserving(*this))
7184 return getVector();
7185
7186 return {};
7187}
7188
7189LogicalResult vector::TransposeOp::verify() {
7190 VectorType vectorType = getSourceVectorType();
7191 VectorType resultType = getResultVectorType();
7192 int64_t rank = resultType.getRank();
7193 if (vectorType.getRank() != rank)
7194 return emitOpError("vector result rank mismatch: ") << rank;
7195 // Verify transposition array.
7196 ArrayRef<int64_t> perm = getPermutation();
7197 int64_t size = perm.size();
7198 if (rank != size)
7199 return emitOpError("transposition length mismatch: ") << size;
7200 SmallVector<bool, 8> seen(rank, false);
7201 for (const auto &ta : llvm::enumerate(perm)) {
7202 if (ta.value() < 0 || ta.value() >= rank)
7203 return emitOpError("transposition index out of range: ") << ta.value();
7204 if (seen[ta.value()])
7205 return emitOpError("duplicate position index: ") << ta.value();
7206 seen[ta.value()] = true;
7207 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7208 return emitOpError("dimension size mismatch at: ") << ta.value();
7209 }
7210 return success();
7211}
7212
7213std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7214 return llvm::to_vector<4>(getResultVectorType().getShape());
7215}
7216
7217void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7218 SetIntRangeFn setResultRanges) {
7219 setResultRanges(getResult(), argRanges.front());
7220}
7221
7222namespace {
7223
7224// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
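// Editorial sketch of the composition:
//   %0 = vector.transpose %x, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
//   %1 = vector.transpose %0, [2, 1, 0] : vector<3x2x4xf32> to vector<4x2x3xf32>
// folds to
//   %1 = vector.transpose %x, [2, 0, 1] : vector<2x3x4xf32> to vector<4x2x3xf32>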
7225class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
7226public:
7227 using Base::Base;
7228
7229 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7230 PatternRewriter &rewriter) const override {
7231 // Composes two permutations: result[i] = permutation1[permutation2[i]].
7232 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7233 ArrayRef<int64_t> permutation2) {
7234 SmallVector<int64_t, 4> result;
7235 for (auto index : permutation2)
7236 result.push_back(permutation1[index]);
7237 return result;
7238 };
7239
7240 // Return if the input of 'transposeOp' is not defined by another transpose.
7241 vector::TransposeOp parentTransposeOp =
7242 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7243 if (!parentTransposeOp)
7244 return failure();
7245
7246 SmallVector<int64_t, 4> permutation = composePermutations(
7247 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7248 // Replace 'transposeOp' with a new transpose operation.
7249 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
7250 transposeOp, transposeOp.getResult().getType(),
7251 parentTransposeOp.getVector(), permutation);
7252 return success();
7253 }
7254};
7255
7256/// Replace transpose(splat-like(v)) with broadcast(v)
7257class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
7258public:
7259 using Base::Base;
7260
7261 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7262 PatternRewriter &rewriter) const override {
7263 Value splat = getScalarSplatSource(transposeOp.getVector());
7264 if (!splat)
7265 return failure();
7266
7267 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7268 transposeOp, transposeOp.getResultVectorType(), splat);
7269 return success();
7270 }
7271};
7272
7273/// Folds transpose(create_mask) into a new transposed create_mask.
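/// Editorial sketch: transposing %m = vector.create_mask %a, %b :
/// vector<4x8xi1> with permutation [1, 0] yields
/// vector.create_mask %b, %a : vector<8x4xi1>.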
7274class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
7275public:
7276 using Base::Base;
7277
7278 LogicalResult matchAndRewrite(TransposeOp transpOp,
7279 PatternRewriter &rewriter) const override {
7280 Value transposeSrc = transpOp.getVector();
7281 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
7282 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
7283 if (!createMaskOp && !constantMaskOp)
7284 return failure();
7285
7286 // Get the transpose permutation and apply it to the vector.create_mask or
7287 // vector.constant_mask operands.
7288 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7289
7290 if (createMaskOp) {
7291 auto maskOperands = createMaskOp.getOperands();
7292 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7293 applyPermutationToVector(newOperands, permutation);
7294
7295 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
7296 transpOp, transpOp.getResultVectorType(), newOperands);
7297 return success();
7298 }
7299
7300 // ConstantMaskOp case.
7301 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7302 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
7303
7304 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
7305 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7306 return success();
7307 }
7308};
7309
7310/// Folds transpose(shape_cast) into a new shape_cast.
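/// Editorial sketch (valid because the transpose only moves a unit dim and is
/// therefore order preserving):
///   %0 = vector.shape_cast %v : vector<4xf32> to vector<4x1xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
/// folds to
///   %1 = vector.shape_cast %v : vector<4xf32> to vector<1x4xf32>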
7311class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
7312public:
7313 using Base::Base;
7314
7315 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7316 PatternRewriter &rewriter) const override {
7317 auto shapeCastOp =
7318 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7319 if (!shapeCastOp)
7320 return failure();
7321 if (!isOrderPreserving(transposeOp))
7322 return failure();
7323
7324 VectorType resultType = transposeOp.getType();
7325
7326 // We don't need to check isValidShapeCast at this point, because it is
7327 // guaranteed that merging the transpose into the shape_cast is a valid
7328 // shape_cast, because the transpose just inserts/removes ones.
7329
7330 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
7331 shapeCastOp.getSource());
7332 return success();
7333 }
7334};
7335
7336/// Folds transpose(from_elements(...)) into a new from_elements with permuted
7337/// operands matching the transposed shape.
7338///
7339/// Example:
7340///
7341/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
7342/// vector<2x3xi32>
7343/// %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
7344///
7345/// becomes ->
7346///
7347/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
7348/// vector<3x2xi32>
7349///
7350class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
7351public:
7352 using Base::Base;
7353 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7354 PatternRewriter &rewriter) const override {
7355 auto fromElementsOp =
7356 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7357 if (!fromElementsOp)
7358 return failure();
7359
7360 VectorType srcTy = fromElementsOp.getDest().getType();
7361 VectorType dstTy = transposeOp.getType();
7362
7363 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7364 int64_t rank = srcTy.getRank();
7365
7366 // Build inverse permutation to map destination indices back to source.
7367 SmallVector<int64_t> inversePerm(rank, 0);
7368 for (int64_t i = 0; i < rank; ++i)
7369 inversePerm[permutation[i]] = i;
7370
7371 ArrayRef<int64_t> srcShape = srcTy.getShape();
7372 ArrayRef<int64_t> dstShape = dstTy.getShape();
7373 SmallVector<int64_t> srcIdx(rank, 0);
7374 SmallVector<int64_t> dstIdx(rank, 0);
7375 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
7376 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
7377
7378 auto elementsOld = fromElementsOp.getElements();
7379 SmallVector<Value> elementsNew;
7380 int64_t dstNumElements = dstTy.getNumElements();
7381 elementsNew.reserve(dstNumElements);
7382
7383 // For each element in destination row-major order, pick the corresponding
7384 // source element.
7385 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7386 // Pick the destination element index.
7387 dstIdx = delinearize(linearIdx, dstStrides);
7388 // Map the destination element index to the source element index.
7389 for (int64_t j = 0; j < rank; ++j)
7390 srcIdx[j] = dstIdx[inversePerm[j]];
7391 // Linearize the source element index.
7392 int64_t srcLin = linearize(srcIdx, srcStrides);
7393 // Add the source element to the new elements.
7394 elementsNew.push_back(elementsOld[srcLin]);
7395 }
7396
7397 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
7398 elementsNew);
7399 return success();
7400 }
7401};
7402
7403/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7404/// 'order preserving', where 'order preserving' means the flattened
7405/// inputs and outputs of the transpose have identical (numerical) values.
7406///
7407/// Example:
7408/// ```
7409/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7410/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7411/// to vector<8x1xi32>
7412/// ```
7413/// can be rewritten as the equivalent
7414/// ```
7415 /// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
7416/// ```
7417/// The algorithm works by partitioning dimensions into groups that can be
7418/// locally permuted while preserving order, and checks that the transpose
7419/// only permutes within these groups.
7420///
7421/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7422/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7423/// broadcasting from 1x1x4x1x1x7.
7424/// ^^^ ^ ^^^ ^
7425/// groups: 0 1 2 3
7426/// Order preserving permutations for this example are ones that only permute
7427/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7428class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7429public:
7430 using Base::Base;
7431 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7432 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7433
7434 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7435 PatternRewriter &rewriter) const override {
7436
7437 vector::BroadcastOp broadcast =
7438 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7439 if (!broadcast) {
7440 return rewriter.notifyMatchFailure(transpose,
7441 "not preceded by a broadcast");
7442 }
7443
7444 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7445 VectorType outputType = transpose.getResultVectorType();
7446
7447 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7448 bool inputIsScalar = !inputType;
7449 if (inputIsScalar) {
7450 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7451 broadcast.getSource());
7452 return success();
7453 }
7454
7455 ArrayRef<int64_t> permutation = transpose.getPermutation();
7456 ArrayRef<int64_t> inputShape = inputType.getShape();
7457 int64_t inputRank = inputType.getRank();
7458 int64_t outputRank = transpose.getType().getRank();
7459 int64_t deltaRank = outputRank - inputRank;
7460
7461 int low = 0;
7462 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7463 bool notOne = inputShape[inputIndex] != 1;
7464 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7465 bool groupEndFound = notOne || prevNotOne;
7466 if (groupEndFound) {
7467 int high = inputIndex + deltaRank;
7468 // Return failure if not all permutation destinations for indices in
7469 // [low, high) are in [low, high), i.e. the permutation is not local to
7470 // the group.
7471 for (int i = low; i < high; ++i) {
7472 if (permutation[i] < low || permutation[i] >= high) {
7473 return rewriter.notifyMatchFailure(
7474 transpose, "permutation not local to group");
7475 }
7476 }
7477 low = high;
7478 }
7479 }
7480
7481 // We don't need to check the final group [low, outputRank) because if it is
7482 // not locally bound, there must be a preceding group that already failed
7483 // the check (impossible to have just 1 non-locally bound group).
7484
7485 // The preceding logic also ensures that at this point, the output of the
7486 // transpose is definitely broadcastable from the input shape, assert so:
7487 assert(vector::isBroadcastableTo(inputType, outputType) ==
7488 vector::BroadcastableToResult::Success &&
7489 "not broadcastable directly to transpose output");
7490
7491 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7492 broadcast.getSource());
7493
7494 return success();
7495 }
7496};
7497
7498} // namespace
7499
7500void vector::TransposeOp::getCanonicalizationPatterns(
7501 RewritePatternSet &results, MLIRContext *context) {
7502 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7503 FoldTransposeSplat, FoldTransposeFromElements,
7504 FoldTransposeBroadcast>(context);
7505}
7506
7507//===----------------------------------------------------------------------===//
7508// ConstantMaskOp
7509//===----------------------------------------------------------------------===//
7510
7511void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7512 VectorType type, ConstantMaskKind kind) {
7513 assert(kind == ConstantMaskKind::AllTrue ||
7514 kind == ConstantMaskKind::AllFalse);
7515 build(builder, result, type,
7516 kind == ConstantMaskKind::AllTrue
7517 ? type.getShape()
7518 : SmallVector<int64_t>(type.getRank(), 0));
7519}
7520
7521LogicalResult ConstantMaskOp::verify() {
7522 auto resultType = llvm::cast<VectorType>(getResult().getType());
7523 // Check the corner case of 0-D vectors first.
7524 if (resultType.getRank() == 0) {
7525 if (getMaskDimSizes().size() != 1)
7526 return emitError("array attr must have length 1 for 0-D vectors");
7527 auto dim = getMaskDimSizes()[0];
7528 if (dim != 0 && dim != 1)
7529 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7530 return success();
7531 }
7532
7533 // Verify that array attr size matches the rank of the vector result.
7534 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7535 return emitOpError(
7536 "must specify array attr of size equal vector result rank");
7537 // Verify that each array attr element is in bounds of corresponding vector
7538 // result dimension size.
7539 auto resultShape = resultType.getShape();
7540 auto resultScalableDims = resultType.getScalableDims();
7541 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7542 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7543 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7544 return emitOpError(
7545 "array attr of size out of bounds of vector result dimension size");
7546 if (resultScalableDims[index] && maskDimSize != 0 &&
7547 maskDimSize != resultShape[index])
7548 return emitOpError(
7549 "only supports 'none set' or 'all set' scalable dimensions");
7550 }
7551 // Verify that if one mask dim size is zero, they all should be zero (because
7552 // the mask region is a conjunction of each mask dimension interval).
7553 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7554 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7555 if (anyZeros && !allZeros)
7556 return emitOpError("expected all mask dim sizes to be zeros, "
7557 "as a result of conjunction with zero mask dim");
7558 return success();
7559}
7560
7561bool ConstantMaskOp::isAllOnesMask() {
7562 auto resultType = getVectorType();
7563 // Check the corner case of 0-D vectors first.
7564 if (resultType.getRank() == 0) {
7565 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7566 return getMaskDimSizes()[0] == 1;
7567 }
7568 for (const auto [resultSize, maskDimSize] :
7569 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7570 if (maskDimSize < resultSize)
7571 return false;
7572 }
7573 return true;
7574}
7575
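/// Folds all-true and all-false constant masks to a boolean splat constant.
/// Schematic examples (shapes are illustrative):
///
///   vector.constant_mask [4, 3] : vector<4x3xi1>  ->  dense<true>
///   vector.constant_mask [0, 0] : vector<4x3xi1>  ->  dense<false>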
7576OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7577 ArrayRef<int64_t> bounds = getMaskDimSizes();
7578 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7579
7580 auto createBoolSplat = [&](bool x) {
7581 return SplatElementsAttr::get(getVectorType(),
7582 BoolAttr::get(getContext(), x));
7583 };
7584
7585 // Check the corner case of 0-D vectors first.
7586 if (vectorSizes.empty()) {
7587 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7588 return createBoolSplat(bounds[0] == 1);
7589 }
7590 // Fold vector.constant_mask to splat if possible.
7591 if (bounds == vectorSizes)
7592 return createBoolSplat(true);
7593 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7594 return createBoolSplat(false);
7595 return OpFoldResult();
7596}
7597
7598//===----------------------------------------------------------------------===//
7599// CreateMaskOp
7600//===----------------------------------------------------------------------===//
7601
7602void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7603 VectorType type,
7604 ArrayRef<OpFoldResult> mixedOperands) {
7605 SmallVector<Value> operands =
7606 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7607 build(builder, result, type, operands);
7608}
7609
7610LogicalResult CreateMaskOp::verify() {
7611 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7612 // Verify that an operand was specified for each result vector dimension.
7613 if (vectorType.getRank() == 0) {
7614 if (getNumOperands() != 1)
7615 return emitOpError(
7616 "must specify exactly one operand for 0-D create_mask");
7617 } else if (getNumOperands() !=
7618 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7619 return emitOpError(
7620 "must specify an operand for each result vector dimension");
7621 }
7622 return success();
7623}
7624
7625namespace {
7626
7627/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7628///
7629/// Ex 1:
7630/// %c2 = arith.constant 2 : index
7631/// %c3 = arith.constant 3 : index
7632/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7633/// Becomes:
7634/// vector.constant_mask [3, 2] : vector<4x3xi1>
7635///
7636/// Ex 2:
7637/// %c_neg_1 = arith.constant -1 : index
7638/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7639/// becomes:
7640/// vector.constant_mask [0] : vector<[8]xi1>
7641///
7642/// Ex 3:
7643/// %c8 = arith.constant 8 : index
7644/// %c16 = arith.constant 16 : index
7645/// %0 = vector.vscale
7646/// %1 = arith.muli %0, %c16 : index
7647/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7648/// becomes:
7649/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7650class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7651public:
7652 using Base::Base;
7653
7654 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7655 PatternRewriter &rewriter) const override {
7656 VectorType maskType = createMaskOp.getVectorType();
7657 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7658 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7659
7660 // Special case: Rank zero shape.
7661 constexpr std::array<int64_t, 1> rankZeroShape{1};
7662 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7663 if (maskType.getRank() == 0) {
7664 maskTypeDimSizes = rankZeroShape;
7665 maskTypeDimScalableFlags = rankZeroScalableDims;
7666 }
7667
7668 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7669 // collect the `constantDims` (for the ConstantMaskOp).
7670 SmallVector<int64_t, 4> constantDims;
7671 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7672 if (auto intSize = getConstantIntValue(dimSize)) {
7673 // Constant value.
7674 // If the mask dim is non-scalable this can be any value.
7675 // If the mask dim is scalable, only negative (i.e. all-false) values fold.
7676 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7677 return failure();
7678 constantDims.push_back(*intSize);
7679 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7680 // Constant vscale multiple (e.g. 4 x vscale).
7681 // Must be all-true to fold to a ConstantMask.
7682 if (vscaleMultiplier < maskTypeDimSizes[i])
7683 return failure();
7684 constantDims.push_back(*vscaleMultiplier);
7685 } else {
7686 return failure();
7687 }
7688 }
7689
7690 // Clamp values to constant_mask bounds.
7691 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7692 value = std::clamp<int64_t>(value, 0, maskDimSize);
7693
7694 // If one of dim sizes is zero, set all dims to zero.
7695 if (llvm::is_contained(constantDims, 0))
7696 constantDims.assign(constantDims.size(), 0);
7697
7698 // Replace 'createMaskOp' with ConstantMaskOp.
7699 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7700 constantDims);
7701 return success();
7702 }
7703};
7704
7705} // namespace
7706
7707void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7708 MLIRContext *context) {
7709 results.add<CreateMaskFolder>(context);
7710}
7711
7712//===----------------------------------------------------------------------===//
7713// MaskOp
7714//===----------------------------------------------------------------------===//
7715
7716void MaskOp::build(
7717 OpBuilder &builder, OperationState &result, Value mask,
7718 Operation *maskableOp,
7719 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7720 assert(maskRegionBuilder &&
7721 "builder callback for 'maskRegion' must be present");
7722
7723 result.addOperands(mask);
7724 OpBuilder::InsertionGuard guard(builder);
7725 Region *maskRegion = result.addRegion();
7726 builder.createBlock(maskRegion);
7727 maskRegionBuilder(builder, maskableOp);
7728}
7729
7730void MaskOp::build(
7731 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7732 Value mask, Operation *maskableOp,
7733 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7734 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7735 maskRegionBuilder);
7736}
7737
7738void MaskOp::build(
7739 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7740 Value mask, Value passthru, Operation *maskableOp,
7741 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7742 build(builder, result, mask, maskableOp, maskRegionBuilder);
7743 if (passthru)
7744 result.addOperands(passthru);
7745 result.addTypes(resultTypes);
7746}
7747
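/// The parser/printer below handle the custom assembly form, schematically:
///
///   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> }
///       : vector<8xi1> -> vector<8xf32>
///
/// where the passthru operand and the `->` result type list are optional.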
7748ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7749 // Create the op region.
7750 result.regions.reserve(1);
7751 Region &maskRegion = *result.addRegion();
7752
7753 auto &builder = parser.getBuilder();
7754
7755 // Parse all the operands.
7756 OpAsmParser::UnresolvedOperand mask;
7757 if (parser.parseOperand(mask))
7758 return failure();
7759
7760 // Optional passthru operand.
7761 OpAsmParser::UnresolvedOperand passthru;
7762 ParseResult parsePassthru = parser.parseOptionalComma();
7763 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7764 return failure();
7765
7766 // Parse op region.
7767 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7768 return failure();
7769
7770 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7771
7772 // Parse the optional attribute list.
7773 if (parser.parseOptionalAttrDict(result.attributes))
7774 return failure();
7775
7776 // Parse all the types.
7777 Type maskType;
7778 if (parser.parseColonType(maskType))
7779 return failure();
7780
7781 SmallVector<Type> resultTypes;
7782 if (parser.parseOptionalArrowTypeList(resultTypes))
7783 return failure();
7784 result.types.append(resultTypes);
7785
7786 // Resolve operands.
7787 if (parser.resolveOperand(mask, maskType, result.operands))
7788 return failure();
7789
7790 if (parsePassthru.succeeded()) {
7791 if (resultTypes.empty())
7792 return parser.emitError(
7793 parser.getNameLoc(),
7794 "expects a result if passthru operand is provided");
7795
7796 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7797 return failure();
7798 }
7799
7800 return success();
7801}
7802
7803void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7804 p << " " << getMask();
7805 if (getPassthru())
7806 p << ", " << getPassthru();
7807
7808 // Print single masked operation and skip terminator.
7809 p << " { ";
7810 Block *singleBlock = &getMaskRegion().getBlocks().front();
7811 if (singleBlock && !singleBlock->getOperations().empty())
7812 p.printCustomOrGenericOp(&singleBlock->front());
7813 p << " }";
7814
7815 p.printOptionalAttrDict(getOperation()->getAttrs());
7816
7817 p << " : " << getMask().getType();
7818 if (getNumResults() > 0)
7819 p << " -> " << getResultTypes();
7820}
7821
7822void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7823 // 1. For an empty `vector.mask`, create a default terminator.
7824 if (region.empty() || region.front().empty()) {
7825 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7826 MaskOp>::ensureTerminator(region, builder, loc);
7827 return;
7828 }
7829
7830 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7831 Block &block = region.front();
7832 if (isa<vector::YieldOp>(block.back()))
7833 return;
7834
7835 // 3. For a non-empty `vector.mask` without an explicit terminator:
7836
7837 // Create default terminator if the number of masked operations is not
7838 // one. This case will trigger a verification failure.
7839 if (block.getOperations().size() != 1) {
7840 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7841 MaskOp>::ensureTerminator(region, builder, loc);
7842 return;
7843 }
7844
7845 // Create a terminator that yields the results from the masked operation.
7846 OpBuilder opBuilder(builder.getContext());
7847 Operation *maskedOp = &block.front();
7848 opBuilder.setInsertionPointToEnd(&block);
7849 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7850}
7851
7852LogicalResult MaskOp::verify() {
7853 // Structural checks.
7854 Block &block = getMaskRegion().getBlocks().front();
7855 if (block.getOperations().empty())
7856 return emitOpError("expects a terminator within the mask region");
7857
7858 unsigned numMaskRegionOps = block.getOperations().size();
7859 if (numMaskRegionOps > 2)
7860 return emitOpError("expects only one operation to mask");
7861
7862 // Terminator checks.
7863 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7864 if (!terminator)
7865 return emitOpError("expects a terminator within the mask region");
7866
7867 if (terminator->getNumOperands() != getNumResults())
7868 return emitOpError(
7869 "expects number of results to match mask region yielded values");
7870
7871 // Empty vector.mask. Nothing else to check.
7872 if (numMaskRegionOps == 1)
7873 return success();
7874
7875 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7876 if (!maskableOp)
7877 return emitOpError("expects a MaskableOpInterface within the mask region");
7878
7879 // Result checks.
7880 if (maskableOp->getNumResults() != getNumResults())
7881 return emitOpError("expects number of results to match maskable operation "
7882 "number of results");
7883
7884 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7885 return emitOpError("expects all the results from the MaskableOpInterface "
7886 "to match all the values returned by the terminator");
7887
7888 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7889 return emitOpError(
7890 "expects result type to match maskable operation result type");
7891
7892 if (llvm::count_if(maskableOp->getResultTypes(),
7893 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7894 return emitOpError("multiple vector results not supported");
7895
7896 // Mask checks.
7897 Type expectedMaskType = maskableOp.getExpectedMaskType();
7898 if (getMask().getType() != expectedMaskType)
7899 return emitOpError("expects a ")
7900 << expectedMaskType << " mask for the maskable operation";
7901
7902 // Passthru checks.
7903 Value passthru = getPassthru();
7904 if (passthru) {
7905 if (!maskableOp.supportsPassthru())
7906 return emitOpError(
7907 "doesn't expect a passthru argument for this maskable operation");
7908
7909 if (maskableOp->getNumResults() != 1)
7910 return emitOpError("expects result when passthru argument is provided");
7911
7912 if (passthru.getType() != maskableOp->getResultTypes()[0])
7913 return emitOpError("expects passthru type to match result type");
7914 }
7915
7916 return success();
7917}
7918
7919/// Folds empty `vector.mask` with no passthru operand and with or without
7920/// return values. For example:
7921///
7922/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7923/// vector<8xi1> -> vector<8xf32>
7924/// %1 = user_op %0 : vector<8xf32>
7925///
7926/// becomes:
7927///
7928/// %0 = user_op %a : vector<8xf32>
7929///
7930 /// Empty `vector.mask` ops with a passthru operand are handled by the
7931 /// canonicalizer, as that requires creating new operations.
7932
7933static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7934 SmallVectorImpl<OpFoldResult> &results) {
7935 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7936 return failure();
7937
7938 Block *block = maskOp.getMaskBlock();
7939 auto terminator = cast<vector::YieldOp>(block->front());
7940 if (terminator.getNumOperands() == 0)
7941 return failure();
7942
7943 // `vector.mask` has results, propagate the results.
7944 llvm::append_range(results, terminator.getOperands());
7945 return success();
7946}
7947
7948LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7949 SmallVectorImpl<OpFoldResult> &results) {
7950 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7951 return success();
7952
7953 MaskFormat maskFormat = getMaskFormat(getMask());
7954 if (maskFormat != MaskFormat::AllTrue)
7955 return failure();
7956
7957 // Move maskable operation outside of the `vector.mask` region.
7958 // If there is no maskable op (empty body), the fold cannot proceed; the
7959 // canonicalizer handles this case instead.
7960 Operation *maskableOp = getMaskableOp();
7961 if (!maskableOp)
7962 return failure();
7963 maskableOp->dropAllUses();
7964 maskableOp->moveBefore(getOperation());
7965
7966 llvm::append_range(results, maskableOp->getResults());
7967 return success();
7968}
7969
7970 /// Canonicalize empty `vector.mask` operations that can't be handled in
7971 /// `MaskOp::fold` as they require creating new operations.
7972///
7973/// Example 1: Empty `vector.mask` with passthru operand.
7974///
7975/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7976/// vector<8xi1> -> vector<8xf32>
7977///
7978/// becomes:
7979///
7980/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7981///
7982 class CanonicalizeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7983 using Base::Base;
7984
7985 LogicalResult matchAndRewrite(MaskOp maskOp,
7986 PatternRewriter &rewriter) const override {
7987 if (!maskOp.isEmpty())
7988 return failure();
7989
7990 if (!maskOp.hasPassthru())
7991 return failure();
7992
7993 // arith.select with a vector condition requires the value types to be
7994 // vectors of the same shape. Since vector.mask always has a vector mask
7995 // type, bail out when any result type doesn't match the mask shape to
7996 // avoid creating invalid IR.
7997 VectorType maskType = maskOp.getMask().getType();
7998 for (Type resultType : maskOp.getResultTypes()) {
7999 auto vecResultType = dyn_cast<VectorType>(resultType);
8000 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8001 return failure();
8002 }
8003
8004 Block *block = maskOp.getMaskBlock();
8005 auto terminator = cast<vector::YieldOp>(block->front());
8006 assert(terminator.getNumOperands() == 1 &&
8007 "expected one result when passthru is provided");
8008
8009 rewriter.replaceOpWithNewOp<arith::SelectOp>(
8010 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8011 terminator.getOperand(0), maskOp.getPassthru());
8012
8013 return success();
8014 }
8015};
8016
8017void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
8018 MLIRContext *context) {
8019 results.add<CanonicalizeEmptyMaskOp>(context);
8020}
8021
8022// MaskingOpInterface definitions.
8023
8024/// Returns the operation masked by this 'vector.mask'.
8025Operation *MaskOp::getMaskableOp() {
8026 Block *block = getMaskBlock();
8027 if (block->getOperations().size() < 2)
8028 return nullptr;
8029
8030 return &block->front();
8031}
8032
8033/// Returns true if 'vector.mask' has a passthru value.
8034bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
8035
8036//===----------------------------------------------------------------------===//
8037// ScanOp
8038//===----------------------------------------------------------------------===//
8039
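/// For reference when reading the checks below, a well-formed scan looks
/// schematically like:
///
///   %1:2 = vector.scan <add>, %src, %init {inclusive = true,
///     reduction_dim = 1} : vector<4x8x16xf32>, vector<4x16xf32>
///
/// i.e. the initial value has the source shape with exactly the reduction
/// dimension removed.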
8040LogicalResult ScanOp::verify() {
8041 VectorType srcType = getSourceType();
8042 VectorType initialType = getInitialValueType();
8043 // Check reduction dimension < rank.
8044 int64_t srcRank = srcType.getRank();
8045 int64_t reductionDim = getReductionDim();
8046 if (reductionDim >= srcRank)
8047 return emitOpError("reduction dimension ")
8048 << reductionDim << " has to be less than " << srcRank;
8049
8050 // Check that rank(initial_value) = rank(src) - 1.
8051 int64_t initialValueRank = initialType.getRank();
8052 if (initialValueRank != srcRank - 1)
8053 return emitOpError("initial value rank ")
8054 << initialValueRank << " has to be equal to " << srcRank - 1;
8055
8056 // Check shapes of initial value and src.
8057 ArrayRef<int64_t> srcShape = srcType.getShape();
8058 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8059 SmallVector<int64_t> expectedShape;
8060 for (int i = 0; i < srcRank; i++) {
8061 if (i != reductionDim)
8062 expectedShape.push_back(srcShape[i]);
8063 }
8064 if (!llvm::equal(initialValueShapes, expectedShape)) {
8065 return emitOpError("incompatible input/initial value shapes");
8066 }
8067
8068 // Verify supported reduction kind.
8069 Type eltType = getDestType().getElementType();
8070 if (!isSupportedCombiningKind(getKind(), eltType))
8071 return emitOpError("unsupported reduction type ")
8072 << eltType << " for kind '" << stringifyCombiningKind(getKind())
8073 << "'";
8074
8075 return success();
8076}
8077 
8078 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
8079 RewritePatternSet &patterns, PatternBenefit benefit) {
8080 patterns
8081 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8082 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8083 StridedSliceConstantMaskFolder, TransposeFolder>(
8084 patterns.getContext(), benefit);
8085}
8086
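/// Combines `v1` and `acc` with the arith op matching `kind` and the element
/// type, then, if a mask is given, blends the combined value with `acc`.
/// Schematic result for kind = ADD on floats with a mask (fastmath flags
/// omitted):
///
///   %sum = arith.addf %v1, %acc : vector<8xf32>
///   %res = arith.select %mask, %sum, %acc : vector<8xi1>, vector<8xf32>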
8087Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
8088 CombiningKind kind, Value v1, Value acc,
8089 arith::FastMathFlagsAttr fastmath,
8090 Value mask) {
8091 Type t1 = getElementTypeOrSelf(v1.getType());
8092 Type tAcc = getElementTypeOrSelf(acc.getType());
8093 Value result;
8094
8095 switch (kind) {
8096 case CombiningKind::ADD:
8097 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8098 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
8099 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8100 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8101 else
8102 llvm_unreachable("invalid value types for ADD reduction");
8103 break;
8104 case CombiningKind::AND:
8105 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8106 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
8107 break;
8108 case CombiningKind::MAXNUMF:
8109 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8110 "expected float values");
8111 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8112 break;
8113 case CombiningKind::MAXIMUMF:
8114 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8115 "expected float values");
8116 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8117 break;
8118 case CombiningKind::MINNUMF:
8119 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8120 "expected float values");
8121 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8122 break;
8123 case CombiningKind::MINIMUMF:
8124 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8125 "expected float values");
8126 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8127 break;
8128 case CombiningKind::MAXSI:
8129 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8130 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8131 break;
8132 case CombiningKind::MINSI:
8133 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8134 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8135 break;
8136 case CombiningKind::MAXUI:
8137 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8138 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8139 break;
8140 case CombiningKind::MINUI:
8141 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8142 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8143 break;
8144 case CombiningKind::MUL:
8145 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8146 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
8147 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8148 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8149 else
8150 llvm_unreachable("invalid value types for MUL reduction");
8151 break;
8152 case CombiningKind::OR:
8153 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8154 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
8155 break;
8156 case CombiningKind::XOR:
8157 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8158 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8159 break;
8160 };
8161
8162 assert(result && "unknown CombiningKind");
8163 return selectPassthru(b, mask, result, acc);
8164}
8165
8166//===----------------------------------------------------------------------===//
8167// StepOp
8168//===----------------------------------------------------------------------===//
8169
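/// vector.step enumerates 0 .. N-1 along its single dimension, e.g.
/// `vector.step : vector<4xindex>` yields [0, 1, 2, 3]. Hence, for fixed-size
/// result types, both the signed and unsigned range below is [0, N-1].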
8170void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8171 SetIntRangeFn setResultRanges) {
8172 auto resultType = cast<VectorType>(getType());
8173 if (resultType.isScalable()) {
8174 return;
8175 }
8176 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
8177 APInt zero(bitwidth, 0);
8178 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8179 ConstantIntRanges result = {zero, high, zero, high};
8180 setResultRanges(getResult(), result);
8181}
8182
8183namespace {
8184
8185/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
8186/// constant large enough such that the result is the same at all indices.
8187///
8188/// For example, rewrite the 'greater than' comparison below,
8189///
8190/// ```mlir
8191/// %cst = arith.constant dense<7> : vector<3xindex>
8192/// %stp = vector.step : vector<3xindex>
8193/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
8194/// ```
8195///
8196/// as,
8197///
8198/// ```mlir
8199 /// %out = arith.constant dense<false> : vector<3xi1>
8200/// ```
8201///
8202/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
8203/// is false at ALL indices we fold. If the constant was 1, then
8204/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
8205/// conservatively preferring the 'compact' vector.step representation.
8206///
8207/// Note: this folder only works for the case where the constant (`%cst` above)
8208/// is the second operand of the comparison. The arith.cmpi canonicalizer will
8209/// ensure that constants are always second (on the right).
8210struct StepCompareFolder : public OpRewritePattern<StepOp> {
8211 using Base::Base;
8212
8213 LogicalResult matchAndRewrite(StepOp stepOp,
8214 PatternRewriter &rewriter) const override {
8215 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8216
8217 for (OpOperand &use : stepOp.getResult().getUses()) {
8218 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8219 if (!cmpiOp)
8220 continue;
8221
8222 // arith.cmpi canonicalizer makes constants final operands.
8223 const unsigned stepOperandNumber = use.getOperandNumber();
8224 if (stepOperandNumber != 0)
8225 continue;
8226
8227 // Check that operand 1 is a constant.
8228 unsigned constOperandNumber = 1;
8229 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8230 std::optional<int64_t> maybeConstValue =
8231 getConstantIntValue(otherOperand);
8232 if (!maybeConstValue.has_value())
8233 continue;
8234
8235 int64_t constValue = maybeConstValue.value();
8236 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8237
8238 auto maybeSplat = [&]() -> std::optional<bool> {
8239 // Handle ult (unsigned less than) and uge (unsigned greater equal).
8240 if ((pred == arith::CmpIPredicate::ult ||
8241 pred == arith::CmpIPredicate::uge) &&
8242 stepSize <= constValue)
8243 return pred == arith::CmpIPredicate::ult;
8244
8245 // Handle ule and ugt.
8246 if ((pred == arith::CmpIPredicate::ule ||
8247 pred == arith::CmpIPredicate::ugt) &&
8248 stepSize - 1 <= constValue) {
8249 return pred == arith::CmpIPredicate::ule;
8250 }
8251
8252 // Handle eq and ne.
8253 if ((pred == arith::CmpIPredicate::eq ||
8254 pred == arith::CmpIPredicate::ne) &&
8255 stepSize <= constValue)
8256 return pred == arith::CmpIPredicate::ne;
8257
8258 return std::nullopt;
8259 }();
8260
8261 if (!maybeSplat.has_value())
8262 continue;
8263
8264 rewriter.setInsertionPointAfter(cmpiOp);
8265
8266 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8267 if (!type)
8268 continue;
8269
8270 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
8271 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8272 type, boolAttr);
8273
8274 rewriter.replaceOp(cmpiOp, splat);
8275 return success();
8276 }
8277
8278 return failure();
8279 }
8280};
8281} // namespace
8282
8283void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8284 MLIRContext *context) {
8285 results.add<StepCompareFolder>(context);
8286}
8287
8288//===----------------------------------------------------------------------===//
8289// Vector Masking Utilities
8290//===----------------------------------------------------------------------===//
8291
8292/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
8293 /// as the masked operation.
8294void mlir::vector::createMaskOpRegion(OpBuilder &builder,
8295 Operation *maskableOp) {
8296 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
8297 Block *insBlock = builder.getInsertionBlock();
8298 // Create a block and move the op to that block.
8299 insBlock->getOperations().splice(
8300 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
8301 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
8302}
8303
8304/// Creates a vector.mask operation around a maskable operation. Returns the
8305/// vector.mask operation if the mask provided is valid. Otherwise, returns
8306/// the maskable operation itself.
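/// Schematic result for an 8-lane mask around a maskable op:
///
///   %0 = vector.mask %mask { vector.transfer_read ... } :
///     vector<8xi1> -> vector<8xf32>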
8307Operation *mlir::vector::maskOperation(OpBuilder &builder,
8308 Operation *maskableOp, Value mask,
8309 Value passthru) {
8310 if (!mask)
8311 return maskableOp;
8312 if (passthru)
8313 return MaskOp::create(builder, maskableOp->getLoc(),
8314 maskableOp->getResultTypes(), mask, passthru,
8315 maskableOp, createMaskOpRegion);
8316 return MaskOp::create(builder, maskableOp->getLoc(),
8317 maskableOp->getResultTypes(), mask, maskableOp,
8318 createMaskOpRegion);
8319 }
8320
8321/// Creates a vector select operation that picks values from `newValue` or
8322/// `passthru` for each result vector lane based on `mask`. This utility is used
8323/// to propagate the pass-thru value of vector.mask or for cases where only the
8324/// pass-thru value propagation is needed. VP intrinsics do not support
8325 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
8326 /// are usually able to match op + select patterns and fold them into native
8327 /// target instructions.
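/// Schematic result (shapes are illustrative):
///
///   %res = arith.select %mask, %newValue, %passthru
///        : vector<8xi1>, vector<8xf32>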
8328Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
8329 Value newValue, Value passthru) {
8330 if (!mask)
8331 return newValue;
8332
8333 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
8334 mask, newValue, passthru);
8335}
8336
8337//===----------------------------------------------------------------------===//
8338// TableGen'd op method definitions
8339//===----------------------------------------------------------------------===//
8340
8341#define GET_ATTRDEF_CLASSES
8342#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8343
8344#define GET_OP_CLASSES
8345#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:73
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:63
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
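A minimal sketch of how IRMapping drives OpBuilder::clone; the names `builder`, `op`, `oldVal`, and `newVal` are assumed to exist:
  IRMapping mapping;
  mapping.map(oldVal, newVal);                   // uses of oldVal become newVal in the copy
  Operation *copy = builder.clone(*op, mapping); // operands without a map entry are kept as-is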
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:414
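A short sketch of the usual insertion-point discipline these hooks enable, assuming `builder` and `someOp` are in scope:
  {
    OpBuilder::InsertionGuard guard(builder); // RAII: restores the point on scope exit
    builder.setInsertionPointAfter(someOp);   // new ops now go right after someOp
    // ... create ops here ...
  } // original insertion point restored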
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:860
void setOperand(unsigned idx, Value value)
Definition Operation.h:377
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
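A hedged sketch of typical accessor usage in a verifier-style check; `op` is an assumed Operation*:
  if (op->getNumResults() != 1)
    return op->emitOpError("expected a single result");
  Value lhs = op->getOperand(0);             // first operand
  Type resTy = op->getResultTypes().front(); // single result type
  Location loc = op->getLoc();               // for diagnostics or new-op creation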
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
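A sketch of how these rewriter hooks are typically combined inside a matchAndRewrite body; `MyOp` is a hypothetical op with one operand and one result:
  LogicalResult matchAndRewrite(MyOp op, PatternRewriter &rewriter) const override {
    auto cst = op.getOperand().getDefiningOp<arith::ConstantOp>();
    if (!cst)
      return rewriter.notifyMatchFailure(op, "operand is not a constant");
    // Replace op in one step; the new op's result types must match op's.
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, cst.getValue());
    return success();
  }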
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
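A small sketch of guarding a transformation on element-type properties; `value` is an assumed Value:
  Type elemTy = getElementTypeOrSelf(value.getType());
  if (!elemTy.isIntOrFloat())
    return failure();                             // only fixed-width int/float handled
  unsigned bits = elemTy.getIntOrFloatBitWidth(); // asserts on any other type
  bool isHalf = elemTy.isF16();                   // e.g. select an f16 fast path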
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
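A hedged sketch of the call, assuming `builder`, `loc`, and two same-typed vector values `lhs` and `acc`:
  // Combine lhs into acc element-wise with the ADD combining kind.
  Value sum = vector::makeArithReduction(builder, loc,
                                         vector::CombiningKind::ADD, lhs, acc);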
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for subscripts in the vector dialect.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
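A sketch of masking an existing maskable op; `reductionOp` and `mask` are assumed values of compatible types:
  // Wrap reductionOp in a vector.mask region; result 0 is the masked value.
  Operation *masked = vector::maskOperation(builder, reductionOp, mask);
  Value result = masked->getResult(0);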
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g. %cst * vector.vscale), return the multiplier.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector lane based on mask.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the accessed tensor/memref to be the same.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same tensor/memref.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated with the given atomic RMW kind.
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector.broadcast op.
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
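A worked sketch of the round trip between these indexing helpers (see linearize below) for a 2x3x4 shape:
  SmallVector<int64_t> strides = computeStrides({2, 3, 4}); // row-major: {12, 4, 1}
  SmallVector<int64_t> offsets = delinearize(17, strides);  // {1, 1, 1}
  int64_t back = linearize(offsets, strides);               // 1*12 + 1*4 + 1*1 == 17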
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult for a single OpFoldResult.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
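A short sketch of the usual static-vs-dynamic split when consuming an OpFoldResult; `b`, `loc`, and `ofr` are assumed:
  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    // Statically known: use *cst directly without creating any IR.
  } else {
    // Dynamic: materialize (or reuse) an index-typed SSA value.
    Value v = getValueOrCreateConstantIndexOp(b, loc, ofr);
  }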
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
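A worked sketch of how the permutation helpers compose:
  SmallVector<int64_t> perm = {2, 0, 1};
  SmallVector<int64_t> inv = invertPermutationVector(perm); // {1, 2, 0}
  SmallVector<int64_t> vals = {10, 20, 30};
  // applyPermutation picks input[permutation[i]] for each position i.
  SmallVector<int64_t> out = applyPermutation(ArrayRef<int64_t>(vals), perm); // {30, 10, 20}
  // Applying inv to out restores {10, 20, 30}.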
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const