MLIR 23.0.0git
VectorOps.cpp
Go to the documentation of this file.
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
15
27#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Builders.h"
33#include "mlir/IR/IRMapping.h"
37#include "mlir/IR/ValueRange.h"
40#include "mlir/Support/LLVM.h"
42#include "llvm/ADT/ArrayRef.h"
43#include "llvm/ADT/Repeated.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/SmallVector.h"
46#include "llvm/ADT/SmallVectorExtras.h"
47#include "llvm/ADT/StringSet.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/Support/Casting.h"
50
51#include <cassert>
52#include <cstdint>
53#include <numeric>
54
55#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
56// Pull in all enum type and utility function definitions.
57#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
58
59using namespace mlir;
60using namespace mlir::vector;
61
62/// Helper enum to classify mask value.
63enum class MaskFormat {
67};
68
69/// Helper method to classify a mask value. Currently, the method
70/// looks "under the hood" of a constant value with dense attributes
71/// and a constant mask operation (since the client may be called at
72/// various stages during progressive lowering).
74 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
75 // Inspect constant dense values. We count up for bits that
76 // are set, count down for bits that are cleared, and bail
77 // when a mix is detected.
78 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
79 int64_t val = 0;
80 for (bool b : denseElts.getValues<bool>())
81 if (b && val >= 0)
82 val++;
83 else if (!b && val <= 0)
84 val--;
85 else
87 if (val > 0)
89 if (val < 0)
91 }
92 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
93 // Inspect constant mask index. If the index exceeds the
94 // dimension size, all bits are set. If the index is zero
95 // or less, no bits are set.
96 ArrayRef<int64_t> masks = m.getMaskDimSizes();
97 auto shape = m.getType().getShape();
98 bool allTrue = true;
99 bool allFalse = true;
100 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
101 if (maskIdx < dimSize)
102 allTrue = false;
103 if (maskIdx > 0)
104 allFalse = false;
105 }
106 if (allTrue)
107 return MaskFormat::AllTrue;
108 if (allFalse)
110 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
111 // Finds all-false create_masks. An all-true create_mask requires all
112 // dims to be constants, so that'll be folded to a constant_mask, then
113 // detected in the constant_mask case.
114 auto maskOperands = m.getOperands();
115 for (Value operand : maskOperands) {
116 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
117 int64_t dimSize =
118 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
119 if (dimSize <= 0)
121 }
122 }
123 return MaskFormat::Unknown;
124 }
125 return MaskFormat::Unknown;
126}
127
128/// Default callback to build a region with a 'vector.yield' terminator with no
129/// arguments.
131 vector::YieldOp::create(builder, loc);
132}
133
134// Helper for verifying combining kinds in contractions and reductions.
135static bool isSupportedCombiningKind(CombiningKind combiningKind,
136 Type elementType) {
137 switch (combiningKind) {
138 case CombiningKind::ADD:
139 case CombiningKind::MUL:
140 return elementType.isIntOrIndexOrFloat();
141 case CombiningKind::MINUI:
142 case CombiningKind::MINSI:
143 case CombiningKind::MAXUI:
144 case CombiningKind::MAXSI:
145 case CombiningKind::AND:
146 case CombiningKind::OR:
147 case CombiningKind::XOR:
148 return elementType.isIntOrIndex();
149 case CombiningKind::MINNUMF:
150 case CombiningKind::MAXNUMF:
151 case CombiningKind::MINIMUMF:
152 case CombiningKind::MAXIMUMF:
153 return llvm::isa<FloatType>(elementType);
154 }
155 return false;
156}
157
158/// Returns the effective rank of the vector to read/write for Xfer Ops
159///
160/// When the element type of the shaped type is _a scalar_, this will simply
161/// return the rank of the vector ( the result for xfer_read or the value to
162/// store for xfer_write).
163///
164/// When the element type of the base shaped type is _a vector_, returns the
165/// difference between the original vector type and the element type of the
166/// shaped type.
167///
168/// EXAMPLE 1 (element type is _a scalar_):
169/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
170/// - shapedType.getElementType() = f32 (rank 0)
171/// - vectorType.getRank() = 2
172/// - Result = 2 - 0 = 2
173///
174/// EXAMPLE 2 (element type is _a vector_):
175/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
176/// - shapedType.getElementType() = vector<20xf32> (rank 1)
177/// - vectorType.getRank() = 1
178/// - Result = 1 - 1 = 0
179///
180/// This is used to determine the number of minor dimensions for identity maps
181/// in vector transfer Ops.
182static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
183 VectorType vectorType) {
184 unsigned elementVectorRank = 0;
185 VectorType elementVectorType =
186 llvm::dyn_cast<VectorType>(shapedType.getElementType());
187 if (elementVectorType)
188 elementVectorRank += elementVectorType.getRank();
189 return vectorType.getRank() - elementVectorRank;
190}
191
193 VectorType vectorType) {
194 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
195 // TODO: replace once we have 0-d vectors.
196 if (shapedType.getRank() == 0 &&
197 vectorType.getShape() == ArrayRef<int64_t>{1})
198 return AffineMap::get(
199 /*numDims=*/0, /*numSymbols=*/0,
200 getAffineConstantExpr(0, shapedType.getContext()));
202 shapedType.getRank(),
203 getEffectiveVectorRankForXferOp(shapedType, vectorType),
204 shapedType.getContext());
205}
206
207/// Check if `write` is of a constant splat and the masked `read` is padded with
208/// the same splat value -- meaning it could be the same value as the initial
209/// constant splat.
210static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
211 vector::TransferReadOp read) {
212 auto readMask = read.getMask();
213 auto writeMask = write.getMask();
214 // Check if the masks are consistent. The splat value could be the same if the
215 // read is masked (and padded with the splat value), and the write is unmasked
216 // or has the same mask. Note this does not allow the case where the write is
217 // masked and the read is unmasked, as then the read could be of more elements
218 // than the write (which may not be the same value).
219 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
220 if (!couldBeSameSplat)
221 return false;
222 // Check for constant splat (as the source of the write).
223 DenseElementsAttr splatAttr;
224 if (!matchPattern(write.getVector(),
225 m_Constant<DenseElementsAttr>(&splatAttr)) ||
226 !splatAttr.isSplat()) {
227 return false;
228 }
229 // The padding of the read and the constant splat value must be the same.
230 Attribute padAttr;
231 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
232 return false;
233 return padAttr == splatAttr.getSplatValue<Attribute>();
234}
235
236bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
237 vector::TransferReadOp read) {
238 return !defWrite.hasOutOfBoundsDim() &&
239 defWrite.getIndices() == read.getIndices() &&
240 defWrite.getVectorType() == read.getVectorType() &&
241 defWrite.getPermutationMap() == read.getPermutationMap() &&
242 ((!defWrite.getMask() && !read.getMask()) ||
244}
245
246bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
247 vector::TransferWriteOp priorWrite) {
248 return priorWrite.getIndices() == write.getIndices() &&
249 priorWrite.getMask() == write.getMask() &&
250 priorWrite.getVectorType() == write.getVectorType() &&
251 priorWrite.getPermutationMap() == write.getPermutationMap();
252}
253
255 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
256 bool testDynamicValueUsingBounds) {
257 // For simplicity only look at transfer of same type.
258 if (transferA.getVectorType() != transferB.getVectorType())
259 return false;
260 unsigned rankOffset = transferA.getLeadingShapedRank();
261 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
262 Value indexA = transferA.getIndices()[i];
263 Value indexB = transferB.getIndices()[i];
264 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
265 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
266
267 if (i < rankOffset) {
268 // For leading dimensions, if we can prove that index are different we
269 // know we are accessing disjoint slices.
270 if (cstIndexA.has_value() && cstIndexB.has_value()) {
271 if (*cstIndexA != *cstIndexB)
272 return true;
273 continue;
274 }
275 if (testDynamicValueUsingBounds) {
276 // First try to see if we can fully compose and simplify the affine
277 // expression as a fast track.
278 FailureOr<uint64_t> delta =
280 if (succeeded(delta) && *delta != 0)
281 return true;
282
283 FailureOr<bool> testEqual =
285 if (succeeded(testEqual) && !testEqual.value())
286 return true;
287 }
288 } else {
289 // For this dimension, we slice a part of the memref we need to make sure
290 // the intervals accessed don't overlap.
291 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
292 if (cstIndexA.has_value() && cstIndexB.has_value()) {
293 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
294 if (distance >= vectorDim)
295 return true;
296 continue;
297 }
298 if (testDynamicValueUsingBounds) {
299 // First try to see if we can fully compose and simplify the affine
300 // expression as a fast track.
301 FailureOr<int64_t> delta =
303 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
304 return true;
305
306 FailureOr<int64_t> computeDelta =
308 if (succeeded(computeDelta)) {
309 if (std::abs(computeDelta.value()) >= vectorDim)
310 return true;
311 }
312 }
313 }
314 }
315 return false;
316}
317
318bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
319 VectorTransferOpInterface transferB,
320 bool testDynamicValueUsingBounds) {
321 if (transferA.getBase() != transferB.getBase())
322 return false;
323 return isDisjointTransferIndices(transferA, transferB,
324 testDynamicValueUsingBounds);
325}
326
327// Helper to iterate over n-D vector slice elements. Calculate the next
328// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
329// Modifies the `position` in place. Returns a failure when `position` becomes
330// the end position.
331static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
333 ArrayRef<int64_t> offsets) {
334 for (auto [posInDim, dimSize, offsetInDim] :
335 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
336 ++posInDim;
337 if (posInDim < dimSize + offsetInDim)
338 return success();
339
340 // Carry the overflow to the next loop iteration.
341 posInDim = offsetInDim;
342 }
343
344 return failure();
345}
346
347/// Returns the integer numbers in `values`. `values` are expected to be
348/// constant operations.
351 llvm::transform(values, std::back_inserter(ints), [](Value value) {
352 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
353 assert(constOp && "Unexpected non-constant index");
354 return constOp.value();
355 });
356 return ints;
357}
358
359/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
360/// be constant operations.
363 llvm::transform(
364 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
365 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
366 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
367 });
368 return ints;
369}
370
371/// Convert `foldResults` into Values. Integer attributes are converted to
372/// constant op.
374 ArrayRef<OpFoldResult> foldResults) {
375 SmallVector<Value> values;
376 llvm::transform(foldResults, std::back_inserter(values),
377 [&](OpFoldResult foldResult) {
378 if (auto attr = dyn_cast<Attribute>(foldResult))
380 builder, loc, cast<IntegerAttr>(attr).getInt())
381 .getResult();
382
383 return cast<Value>(foldResult);
384 });
385 return values;
386}
387
388std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
389 if (value.getDefiningOp<vector::VectorScaleOp>())
390 return 1;
391 auto mul = value.getDefiningOp<arith::MulIOp>();
392 if (!mul)
393 return {};
394 auto lhs = mul.getLhs();
395 auto rhs = mul.getRhs();
396 if (lhs.getDefiningOp<vector::VectorScaleOp>())
397 return getConstantIntValue(rhs);
398 if (rhs.getDefiningOp<vector::VectorScaleOp>())
399 return getConstantIntValue(lhs);
400 return {};
401}
402
403/// Converts numeric attributes to the expected type. Supports
404/// integer-to-integer and float-to-integer conversions. Returns the original
405/// attribute if no conversion is needed or supported.
406static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
407 // Integer-to-integer conversion
408 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
409 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
410 if (intAttr.getType() != expectedType)
411 return IntegerAttr::get(expectedType, intAttr.getInt());
412 }
413 return attr;
414 }
415
416 // Float-to-integer bitcast (preserves bit representation)
417 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
418 auto intType = dyn_cast<IntegerType>(expectedType);
419 if (!intType)
420 return attr;
421
422 APFloat floatVal = floatAttr.getValue();
423 APInt intVal = floatVal.bitcastToAPInt();
424 return IntegerAttr::get(expectedType, intVal);
425 }
426
427 return attr;
428}
429
430//===----------------------------------------------------------------------===//
431// CombiningKindAttr
432//===----------------------------------------------------------------------===//
433
434namespace mlir {
435namespace vector {
436namespace detail {
438 using KeyTy = uint64_t;
439
441
442 bool operator==(const KeyTy &key) const { return value == key; }
443
445 const KeyTy &key) {
446 return new (allocator.allocate<BitmaskEnumStorage>())
448 }
449
451};
452} // namespace detail
453} // namespace vector
454} // namespace mlir
455
456//===----------------------------------------------------------------------===//
457// VectorDialect
458//===----------------------------------------------------------------------===//
459
460namespace {
461/// This class defines the interface for handling inlining with vector dialect
462/// operations.
463struct VectorInlinerInterface : public DialectInlinerInterface {
464 using DialectInlinerInterface::DialectInlinerInterface;
465
466 /// All vector dialect ops can be inlined.
467 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
468 return true;
469 }
470};
471} // namespace
472
void VectorDialect::initialize() {
  // Register all attributes generated from the dialect's TableGen definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  // Register all operations generated from the dialect's TableGen definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  // Promised interfaces are implemented in separate libraries; declaring them
  // here advertises availability without taking a hard dependency.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}
495
496/// Materialize a single constant operation from a given attribute value with
497/// the desired resultant type.
498Operation *VectorDialect::materializeConstant(OpBuilder &builder,
499 Attribute value, Type type,
500 Location loc) {
501 if (matchPattern(value, ub::m_Poison()))
502 return value.getDialect().materializeConstant(builder, value, type, loc);
503
504 return arith::ConstantOp::materialize(builder, value, type, loc);
505}
506
508 return builder.getIntegerType(64);
509}
510
512 ArrayRef<int64_t> values) {
513 return builder.getI64ArrayAttr(values);
514}
515
516//===----------------------------------------------------------------------===//
517// MultiDimReductionOp
518//===----------------------------------------------------------------------===//
519
520void vector::MultiDimReductionOp::build(OpBuilder &builder,
521 OperationState &result, Value source,
522 Value acc, ArrayRef<bool> reductionMask,
523 CombiningKind kind) {
524 SmallVector<int64_t> reductionDims;
525 for (const auto &en : llvm::enumerate(reductionMask))
526 if (en.value())
527 reductionDims.push_back(en.index());
528 build(builder, result, kind, source, acc, reductionDims);
529}
530
/// Folds a multi_reduction over a rank-1 vector that reduces no dimensions:
/// the op is then a no-op and the source is returned unchanged.
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}
537
/// Unrolling is driven by the full source vector shape.
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
542
543LogicalResult MultiDimReductionOp::verify() {
544 SmallVector<int64_t> targetShape;
545 SmallVector<bool> scalableDims;
546 Type inferredReturnType;
547 auto sourceScalableDims = getSourceVectorType().getScalableDims();
548 for (auto [dimIdx, dimSize] :
549 llvm::enumerate(getSourceVectorType().getShape()))
550 if (!llvm::any_of(getReductionDims(),
551 [dimIdx = dimIdx](int64_t reductionDimIdx) {
552 return reductionDimIdx == static_cast<int64_t>(dimIdx);
553 })) {
554 targetShape.push_back(dimSize);
555 scalableDims.push_back(sourceScalableDims[dimIdx]);
556 }
557 // TODO: update to also allow 0-d vectors when available.
558 if (targetShape.empty())
559 inferredReturnType = getSourceVectorType().getElementType();
560 else
561 inferredReturnType = VectorType::get(
562 targetShape, getSourceVectorType().getElementType(), scalableDims);
563 if (getType() != inferredReturnType)
564 return emitOpError() << "destination type " << getType()
565 << " is incompatible with source type "
566 << getSourceVectorType();
567
568 return success();
569}
570
571/// Returns the mask type expected by this operation.
572Type MultiDimReductionOp::getExpectedMaskType() {
573 auto vecType = getSourceVectorType();
574 return VectorType::get(vecType.getShape(),
575 IntegerType::get(vecType.getContext(), /*width=*/1),
576 vecType.getScalableDims());
577}
578
namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Bail out unless every reduced dimension has unit size.
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup. When the reduction is masked, the masking op (not
    // the reduction itself) is the op to replace, and replacement ops must be
    // created outside the mask region.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      // Result is still a vector: shape_cast the source (and mask) to the
      // destination type, dropping the unit reduced dimensions.
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    // Combine with the accumulator using the reduction kind (mask-aware).
    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
637
/// Registers canonicalization patterns for vector.multi_reduction.
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
642
643//===----------------------------------------------------------------------===//
644// ReductionOp
645//===----------------------------------------------------------------------===//
646
/// Builds a reduction without an accumulator.
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  // Delegate to the accumulator-taking overload with an empty accumulator.
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}
652
/// Builds a reduction with an accumulator; the result type is the source
/// vector's element type.
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}
660
661LogicalResult ReductionOp::verify() {
662 // Verify for 0-D and 1-D vector.
663 int64_t rank = getSourceVectorType().getRank();
664 if (rank > 1)
665 return emitOpError("unsupported reduction rank: ") << rank;
666
667 // Verify supported reduction kind.
668 Type eltType = getDest().getType();
669 if (!isSupportedCombiningKind(getKind(), eltType))
670 return emitOpError("unsupported reduction type '")
671 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
672 << "'";
673
674 return success();
675}
676
677// MaskableOpInterface methods.
678
679/// Returns the mask type expected by this operation.
680Type ReductionOp::getExpectedMaskType() {
681 auto vecType = getSourceVectorType();
682 return VectorType::get(vecType.getShape(),
683 IntegerType::get(vecType.getContext(), /*width=*/1),
684 vecType.getScalableDims());
685}
686
688 OpBuilder &builder, Location loc,
689 Value vector) {
690 switch (op) {
691 case arith::AtomicRMWKind::addf:
692 case arith::AtomicRMWKind::addi:
693 return vector::ReductionOp::create(builder, vector.getLoc(),
694 CombiningKind::ADD, vector);
695 case arith::AtomicRMWKind::mulf:
696 case arith::AtomicRMWKind::muli:
697 return vector::ReductionOp::create(builder, vector.getLoc(),
698 CombiningKind::MUL, vector);
699 case arith::AtomicRMWKind::minimumf:
700 return vector::ReductionOp::create(builder, vector.getLoc(),
701 CombiningKind::MINIMUMF, vector);
702 case arith::AtomicRMWKind::mins:
703 return vector::ReductionOp::create(builder, vector.getLoc(),
704 CombiningKind::MINSI, vector);
705 case arith::AtomicRMWKind::minu:
706 return vector::ReductionOp::create(builder, vector.getLoc(),
707 CombiningKind::MINUI, vector);
708 case arith::AtomicRMWKind::maximumf:
709 return vector::ReductionOp::create(builder, vector.getLoc(),
710 CombiningKind::MAXIMUMF, vector);
711 case arith::AtomicRMWKind::maxs:
712 return vector::ReductionOp::create(builder, vector.getLoc(),
713 CombiningKind::MAXSI, vector);
714 case arith::AtomicRMWKind::maxu:
715 return vector::ReductionOp::create(builder, vector.getLoc(),
716 CombiningKind::MAXUI, vector);
717 case arith::AtomicRMWKind::andi:
718 return vector::ReductionOp::create(builder, vector.getLoc(),
719 CombiningKind::AND, vector);
720 case arith::AtomicRMWKind::ori:
721 return vector::ReductionOp::create(builder, vector.getLoc(),
722 CombiningKind::OR, vector);
723 case arith::AtomicRMWKind::minnumf:
724 return vector::ReductionOp::create(builder, vector.getLoc(),
725 CombiningKind::MINNUMF, vector);
726 case arith::AtomicRMWKind::maxnumf:
727 return vector::ReductionOp::create(builder, vector.getLoc(),
728 CombiningKind::MAXNUMF, vector);
729 case arith::AtomicRMWKind::xori:
730 return vector::ReductionOp::create(builder, vector.getLoc(),
731 CombiningKind::XOR, vector);
732 default:
733 (void)emitOptionalError(loc, "Reduction operation type not supported");
734 break;
735 }
736 return nullptr;
737}
738
/// Unrolling is driven by the full source vector shape.
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
742
namespace {
/// Replaces a reduction of a 0-D or single-element 1-D vector with a plain
/// extraction, combined with the accumulator when one is present.
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup. When the reduction is masked, the masking op is the
    // root to replace and replacement ops are created outside the mask region.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    // Only applies to 0-D vectors or 1-D vectors with a single element.
    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    // Fold the accumulator in, honoring the kind, fastmath flags and mask.
    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
782
/// Registers canonicalization patterns for vector.reduction.
void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
787
788//===----------------------------------------------------------------------===//
789// ContractionOp
790//===----------------------------------------------------------------------===//
791
792void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
794 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
795 ArrayRef<IteratorType> iteratorTypes) {
796 result.addOperands({lhs, rhs, acc});
797 result.addTypes(acc.getType());
798 result.addAttribute(
799 getIndexingMapsAttrName(result.name),
800 builder.getAffineMapArrayAttr(
801 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
802 result.addAttribute(
803 getIteratorTypesAttrName(result.name),
804 builder.getArrayAttr(llvm::map_to_vector(
805 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
806 return IteratorTypeAttr::get(builder.getContext(), t);
807 })));
808}
809
810void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
812 ArrayAttr indexingMaps,
813 ArrayAttr iteratorTypes) {
814 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
815 ContractionOp::getDefaultKind());
816}
817
818void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
820 ArrayAttr indexingMaps,
821 ArrayAttr iteratorTypes, CombiningKind kind) {
822 result.addOperands({lhs, rhs, acc});
823 result.addTypes(acc.getType());
824 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
825 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
826 result.addAttribute(getKindAttrName(result.name),
827 CombiningKindAttr::get(builder.getContext(), kind));
828}
829
830ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
836 Type resultType;
837 auto loc = parser.getCurrentLocation();
838 DictionaryAttr dictAttr;
839 // TODO: Unify linalg op attribute parsing.
840 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
841 parser.parseComma() || parser.parseOperand(rhsInfo) ||
842 parser.parseComma() || parser.parseOperand(accInfo) ||
843 parser.parseTrailingOperandList(masksInfo) ||
844 parser.parseOptionalAttrDict(result.attributes) ||
845 parser.parseColonTypeList(types) ||
846 parser.parseKeywordType("into", resultType) ||
847 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
848 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
849 parser.resolveOperand(accInfo, resultType, result.operands) ||
850 parser.addTypeToList(resultType, result.types))
851 return failure();
852 result.attributes.append(dictAttr.getValue().begin(),
853 dictAttr.getValue().end());
854
855 // Convert array of string into an array of IteratyType enums. This is needed,
856 // because tests still use the old format when 'iterator_types' attribute is
857 // represented as an array of strings.
858 // TODO: Remove this conversion once tests are fixed.
859 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
860 result.attributes.get(getIteratorTypesAttrName(result.name)));
861 if (!iteratorTypes) {
862 return parser.emitError(loc)
863 << "expected " << getIteratorTypesAttrName(result.name)
864 << " array attribute";
865 }
866
867 SmallVector<Attribute> iteratorTypeAttrs;
868
869 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
870 auto maybeIteratorType = symbolizeIteratorType(s);
871 if (!maybeIteratorType.has_value())
872 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
873
874 iteratorTypeAttrs.push_back(
875 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
876 }
877 result.attributes.set(getIteratorTypesAttrName(result.name),
878 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
879
880 if (!result.attributes.get(getKindAttrName(result.name))) {
881 result.addAttribute(
882 getKindAttrName(result.name),
883 CombiningKindAttr::get(result.getContext(),
884 ContractionOp::getDefaultKind()));
885 }
886 if (masksInfo.empty())
887 return success();
888 if (masksInfo.size() != 2)
889 return parser.emitError(parser.getNameLoc(),
890 "expected zero or exactly 2 vector mask operands");
891 auto lhsType = llvm::cast<VectorType>(types[0]);
892 auto rhsType = llvm::cast<VectorType>(types[1]);
893 auto maskElementType = parser.getBuilder().getI1Type();
894 std::array<VectorType, 2> maskTypes = {
895 VectorType::Builder(lhsType).setElementType(maskElementType),
896 VectorType::Builder(rhsType).setElementType(maskElementType)};
897 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
898 return failure();
899 return success();
900}
901
/// Custom printer for vector.contract. Trait attributes (indexing maps,
/// iterator types, kind) are printed as an inline dictionary before the
/// operands; all remaining attributes go to the trailing attribute dict.
void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  // Collect the trait attributes into `attrs` (declared just above this loop;
  // not visible in this excerpt) so they can be printed as one dictionary.
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          });

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  // Trait attributes were printed inline above; elide them here.
  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
936
937static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
938 const std::vector<std::pair<int64_t, int64_t>> &map) {
939 for (auto &dimPair : map) {
940 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
941 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
942 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
943 return false;
944 }
945 return true;
946}
947
/// Verifies that `accType` and `resType` have exactly the shape implied by
/// the operand vector types, the indexing maps, and the contracting/batch
/// dimension maps of `op`.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  // Batch dims are skipped too: they were already contributed by the lhs.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    // Record the constant extent of every iteration dimension, taken from
    // whichever operand shape maps onto it.
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape =
        llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        });
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
1036
1037LogicalResult ContractionOp::verify() {
1038 VectorType lhsType = getLhsType();
1039 VectorType rhsType = getRhsType();
1040 Type accType = getAccType();
1041 Type resType = getResultType();
1042
1043 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1044 if (!lhsType.getElementType().isSignlessInteger())
1045 return emitOpError("only supports signless integer types");
1046 }
1047
1048 // Verify that an indexing map was specified for each vector operand.
1049 if (getIndexingMapsArray().size() != 3)
1050 return emitOpError("expected an indexing map for each vector operand");
1051
1052 // Verify that each index map has 'numIterators' inputs, no symbols, and
1053 // that the number of map outputs equals the rank of its associated
1054 // vector operand.
1055 unsigned numIterators = getIteratorTypes().getValue().size();
1056 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1057 auto index = it.index();
1058 auto map = it.value();
1059 if (map.getNumSymbols() != 0)
1060 return emitOpError("expected indexing map ")
1061 << index << " to have no symbols";
1062 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1063 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1064 // Verify that the map has the right number of inputs, outputs, and indices.
1065 // This also correctly accounts for (..) -> () for rank-0 results.
1066 if (map.getNumDims() != numIterators)
1067 return emitOpError("expected indexing map ")
1068 << index << " to have " << numIterators << " number of inputs";
1069 if (map.getNumResults() != rank)
1070 return emitOpError("expected indexing map ")
1071 << index << " to have " << rank << " number of outputs";
1072 if (!map.isProjectedPermutation())
1073 return emitOpError("expected indexing map ")
1074 << index << " to be a projected permutation of its inputs";
1075 }
1076
1077 auto contractingDimMap = getContractingDimMap();
1078 auto batchDimMap = getBatchDimMap();
1079
1080 // Verify at least one contracting dimension pair was specified.
1081 if (contractingDimMap.empty())
1082 return emitOpError("expected at least one contracting dimension pair");
1083
1084 // Verify contracting dimension map was properly constructed.
1085 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1086 return emitOpError("invalid contracting dimension map");
1087
1088 // Verify batch dimension map was properly constructed.
1089 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1090 return emitOpError("invalid batch dimension map");
1091
1092 // Verify 'accType' and 'resType' shape.
1093 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1094 contractingDimMap, batchDimMap)))
1095 return failure();
1096
1097 if (!getKindAttr()) {
1098 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1099 "'vector.kind<add>')");
1100 }
1101
1102 // Verify supported combining kind.
1103 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1104 auto elementType = vectorType ? vectorType.getElementType() : resType;
1105 if (!isSupportedCombiningKind(getKind(), elementType))
1106 return emitOpError("unsupported contraction type");
1107
1108 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1109 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1110}
1111
1112// MaskableOpInterface methods.
1113
1114/// Returns the mask type expected by this operation. Mostly used for
1115/// verification purposes. It requires the operation to be vectorized."
1116Type ContractionOp::getExpectedMaskType() {
1117 auto indexingMaps = this->getIndexingMapsArray();
1118 AffineMap lhsIdxMap = indexingMaps[0];
1119 AffineMap rhsIdxMap = indexingMaps[1];
1120 VectorType lhsType = this->getLhsType();
1121 VectorType rhsType = this->getRhsType();
1122
1123 unsigned numVecDims = lhsIdxMap.getNumDims();
1124 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1125 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1126
1127 // Using the information in the indexing maps, extract the size of each
1128 // dimension in the vector.contract operation from the two input operands.
1129 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1130 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1131 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1132 lhsType.getScalableDims()[dimIdx];
1133 }
1134 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1135 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1136 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1137 rhsType.getScalableDims()[dimIdx];
1138 }
1139
1140 assert(ShapedType::isStaticShape(maskShape) &&
1141 "Mask shape couldn't be computed");
1142
1143 return VectorType::get(maskShape,
1144 IntegerType::get(lhsType.getContext(), /*width=*/1),
1145 maskShapeScalableDims);
1146}
1147
1148SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1149 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1150 getIteratorTypesAttrName(), getKindAttrName()};
1151}
1152
  // Linear scan over the map results; returns the index of the result equal
  // to `targetExpr`, or -1 if the expression does not appear.
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}
1159
1160static std::vector<std::pair<int64_t, int64_t>>
1161getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1162 IteratorType targetIteratorType, MLIRContext *context) {
1163 std::vector<std::pair<int64_t, int64_t>> dimMap;
1164 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1165 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1166 if (iteratorType != targetIteratorType)
1167 continue;
1168 // Search lhs/rhs map results for 'targetExpr'.
1169 auto targetExpr = getAffineDimExpr(it.index(), context);
1170 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1171 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1172 if (lhsDim >= 0 && rhsDim >= 0)
1173 dimMap.emplace_back(lhsDim, rhsDim);
1174 }
1175 return dimMap;
1176}
1177
1178void ContractionOp::getIterationBounds(
1179 SmallVectorImpl<int64_t> &iterationBounds) {
1180 auto lhsShape = getLhsType().getShape();
1181 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1182 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1183 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1184 // Search lhs/rhs map results for 'targetExpr'.
1185 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1186 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1187 if (iteratorType == IteratorType::reduction) {
1188 // Get reduction dim size from lhs shape (same size in rhsShape).
1189 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1190 assert(lhsDimIndex >= 0);
1191 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1192 continue;
1193 }
1194 // Get parallel dimension size from result shape.
1195 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1196 assert(resDimIndex >= 0);
1197 assert(resVectorType != nullptr);
1198 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1199 }
1200}
1201
1202void ContractionOp::getIterationIndexMap(
1203 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1204 unsigned numMaps = getIndexingMapsArray().size();
1205 iterationIndexMap.resize(numMaps);
1206 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1207 auto index = it.index();
1208 auto map = it.value();
1209 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1210 auto dim = cast<AffineDimExpr>(map.getResult(i));
1211 iterationIndexMap[index][dim.getPosition()] = i;
1212 }
1213 }
1214}
1215
1216std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1217 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1218 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1219 getContext());
1220}
1221
1222std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1223 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1224 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1225 getContext());
1226}
1227
// Unrolling iterates over the full iteration-space bounds of the contraction.
std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  getIterationBounds(shape);
  return shape;
}
1233
1234/// Return a fused vector::ContractionOp which represents a patterns such as:
1235///
1236/// ```mlir
1237/// %c0 = vector.constant 0: ...
1238/// %c = vector.contract %a, %b, %c0: ...
1239/// %e = add %c, %d: ...
1240/// ```
1241///
1242/// by:
1243///
1244/// ```mlir
1245/// %e = vector.contract %a, %b, %d: ...
1246/// ```
1247///
1248/// Return null if the canonicalization does not apply.
1249// TODO: This should be a folding of Add into Contract in core but while they
1250// live in different dialects, it is not possible without unnatural
1251// dependencies.
1252template <typename AddOpType>
1253struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1254 using OpRewritePattern<AddOpType>::OpRewritePattern;
1255
1256 LogicalResult matchAndRewrite(AddOpType addOp,
1257 PatternRewriter &rewriter) const override {
1258 auto canonicalize = [&](Value maybeContraction,
1259 Value otherOperand) -> vector::ContractionOp {
1260 vector::ContractionOp contractionOp =
1261 dyn_cast_or_null<vector::ContractionOp>(
1262 maybeContraction.getDefiningOp());
1263 if (!contractionOp)
1264 return vector::ContractionOp();
1265 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1266 contractionOp.getAcc().getDefiningOp())) {
1267 if (maybeZero.getValue() ==
1268 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1269 IRMapping bvm;
1270 bvm.map(contractionOp.getAcc(), otherOperand);
1271 auto newContraction =
1272 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1273 rewriter.replaceOp(addOp, newContraction.getResult());
1274 return newContraction;
1275 }
1276 }
1277 return vector::ContractionOp();
1278 };
1279
1280 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1281 vector::ContractionOp contract = canonicalize(a, b);
1282 contract = contract ? contract : canonicalize(b, a);
1283 return contract ? success() : failure();
1284 }
1285};
1286
void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // NOTE(review): the pattern registration (presumably CanonicalizeContractAdd
  // instantiations) is not visible in this excerpt — verify against the full
  // file.
}
1292
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue` (the sentinel used to encode a poison index, e.g. -1).
                                int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}
1299
1300//===----------------------------------------------------------------------===//
1301// ExtractOp
1302//===----------------------------------------------------------------------===//
1303
/// Integer range inference: the extracted value is taken from the source
/// vector, so forward the source operand's range to the result unchanged.
void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
1308
1309void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1310 Value source) {
1311 auto vectorTy = cast<VectorType>(source.getType());
1312 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1313}
1314
/// Builds an extract at a single static index `position`.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}
1319
/// Builds an extract at a single static-or-dynamic index `position`.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}
1324
/// Builds an extract at purely static indices `position` (no dynamic
/// position operands).
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}
1330
1331void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1332 Value source, ArrayRef<OpFoldResult> position) {
1333 SmallVector<int64_t> staticPos;
1334 SmallVector<Value> dynamicPos;
1335 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1336 build(builder, result, source, dynamicPos,
1337 builder.getDenseI64ArrayAttr(staticPos));
1338}
1339
1340LogicalResult
1341ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1342 ExtractOp::Adaptor adaptor,
1343 SmallVectorImpl<Type> &inferredReturnTypes) {
1344 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1345 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1346 vectorType.getRank()) {
1347 inferredReturnTypes.push_back(vectorType.getElementType());
1348 } else {
1349 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1350 vectorType.getRank());
1351 inferredReturnTypes.push_back(VectorType::get(
1352 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1353 vectorType.getScalableDims().drop_front(n)));
1354 }
1355 return success();
1356}
1357
/// Verifies an extract: non-0-d result, consistent static/dynamic position
/// operands, position rank bounded by the source rank, and every static index
/// in-bounds or poison.
LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  // Each static index must lie in [0, dimSize) or be the poison sentinel.
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
1390
/// Converts an ArrayAttr of IntegerAttrs into a SmallVector of `IntType`
/// values (each attribute's integer value is static_cast to IntType).
template <typename IntType>
  return llvm::map_to_vector<4>(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
}
1397
1398/// Fold the result of chains of ExtractOp in place by simply concatenating the
1399/// positions.
1400static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1401 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1402 return failure();
1403
1404 // TODO: Canonicalization for dynamic position not implemented yet.
1405 if (extractOp.hasDynamicPosition())
1406 return failure();
1407
1408 SmallVector<int64_t> globalPosition;
1409 ExtractOp currentOp = extractOp;
1410 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1411 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1412 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1413 currentOp = nextOp;
1414 // TODO: Canonicalization for dynamic position not implemented yet.
1415 if (currentOp.hasDynamicPosition())
1416 return failure();
1417 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1418 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1419 }
1420 extractOp.setOperand(0, currentOp.getSource());
1421 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1422 OpBuilder b(extractOp.getContext());
1423 std::reverse(globalPosition.begin(), globalPosition.end());
1424 extractOp.setStaticPosition(globalPosition);
1425 return success();
1426}
1427
namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as equality
  /// of all entries of `a` that are >=0 with the corresponding entries of b.
  /// Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      // Negative entries (sentinels) are wildcards and never disprove an
      // intersection.
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  /// The extract op being folded; may be rewritten in place on success.
  ExtractOp extractOp;
  /// Rank of the vector `extractOp` extracts from.
  int64_t vectorRank;
  /// Number of leading entries of `extractPosition` that are real extract
  /// indices (the remainder are sentinels).
  int64_t extractedRank;

  /// Producer of the current source, when it is an insert resp. transpose.
  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the result.
  /// They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace
1526
1527ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1528 ExtractOp e)
1529 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1530 extractedRank(extractOp.getNumIndices()) {
1531 assert(vectorRank >= extractedRank && "Extracted position overflow");
1532 sentinels.reserve(vectorRank - extractedRank);
1533 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1534 sentinels.push_back(-(i + 1));
1535 extractPosition.assign(extractOp.getStaticPosition().begin(),
1536 extractOp.getStaticPosition().end());
1537 llvm::append_range(extractPosition, sentinels);
1538}
1539
// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  // Apply the (inverted) transpose permutation to `extractPosition`.
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  return success();
}
1554
1555// Case 2: the insert position matches extractPosition exactly, early return.
1556LogicalResult
1557ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1558 Value &res) {
1559 // TODO: Canonicalization for dynamic position not implemented yet.
1560 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1561 return failure();
1562
1563 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1564 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1565 return failure();
1566 // Case 2.a. early-exit fold.
1567 res = nextInsertOp.getValueToStore();
1568 // Case 2.b. if internal transposition is present, canFold will be false.
1569 return success(canFold());
1570}
1571
1572/// Case 3: if inserted position is a prefix of extractPosition,
1573/// extract a portion of the source of the insertion.
1574/// This method updates the internal state.
1575LogicalResult
1576ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1577 // TODO: Canonicalization for dynamic position not implemented yet.
1578 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1579 return failure();
1580
1581 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1582 if (!isContainedWithin(insertedPos, extractPosition))
1583 return failure();
1584 // Set leading dims to zero.
1585 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1586 // Drop extra leading dims.
1587 extractPosition.erase(extractPosition.begin(),
1588 extractPosition.begin() + insertedPos.size());
1589 extractedRank = extractPosition.size() - sentinels.size();
1590 // Case 3.a. early-exit fold (break and delegate to post-while path).
1591 res = nextInsertOp.getValueToStore();
1592 // Case 3.b. if internal transposition is present, canFold will be false.
1593 return success();
1594}
1595
1596/// Try to fold in place to extract(source, extractPosition) and return the
1597/// folded result. Return null if folding is not possible (e.g. due to an
1598/// internal transposition in the result).
1599Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1600 Value source) {
1601 // TODO: Canonicalization for dynamic position not implemented yet.
1602 if (extractOp.hasDynamicPosition())
1603 return Value();
1604
1605 // If we can't fold (either internal transposition, or nothing to fold), bail.
1606 bool nothingToFold = (source == extractOp.getSource());
1607 if (nothingToFold || !canFold())
1608 return Value();
1609
1610 // Otherwise, fold by updating the op inplace and return its result.
1611 OpBuilder b(extractOp.getContext());
1612 extractOp.setStaticPosition(
1613 ArrayRef(extractPosition).take_front(extractedRank));
1614 extractOp.getSourceMutable().assign(source);
1615 return extractOp.getResult();
1616}
1617
1618/// Iterate over producing insert and transpose ops until we find a fold.
1619Value ExtractFromInsertTransposeChainState::fold() {
1620 // TODO: Canonicalization for dynamic position not implemented yet.
1621 if (extractOp.hasDynamicPosition())
1622 return Value();
1623
1624 Value valueToExtractFrom = extractOp.getSource();
1625 updateStateForNextIteration(valueToExtractFrom);
1626 while (nextInsertOp || nextTransposeOp) {
1627 // Case 1. If we hit a transpose, just compose the map and iterate.
1628 // Invariant: insert + transpose do not change rank, we can always compose.
1629 if (succeeded(handleTransposeOp())) {
1630 valueToExtractFrom = nextTransposeOp.getVector();
1631 updateStateForNextIteration(valueToExtractFrom);
1632 continue;
1633 }
1634
1635 Value result;
1636 // Case 2: the position match exactly.
1637 if (succeeded(handleInsertOpWithMatchingPos(result)))
1638 return result;
1639
1640 // Case 3: if the inserted position is a prefix of extractPosition, we can
1641 // just extract a portion of the source of the insert.
1642 if (succeeded(handleInsertOpWithPrefixPos(result)))
1643 return tryToFoldExtractOpInPlace(result);
1644
1645 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1646 // values. This is a more difficult case and we bail.
1647 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1648 if (isContainedWithin(extractPosition, insertedPos) ||
1649 intersectsWhereNonNegative(extractPosition, insertedPos))
1650 return Value();
1651
1652 // Case 5: No intersection, we forward the extract to insertOp.dest().
1653 valueToExtractFrom = nextInsertOp.getDest();
1654 updateStateForNextIteration(valueToExtractFrom);
1655 }
1656 // If after all this we can fold, go for it.
1657 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1658}
1659
/// Returns true if the operation has a 0-D vector type operand or result.
  // Predicate: is `type` a rank-0 vector?
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
1670
1671/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1672/// considered to be 'broadcastlike'.
1673static bool isBroadcastLike(Operation *op) {
1674 if (isa<BroadcastOp>(op))
1675 return true;
1676
1677 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1678 if (!shapeCast)
1679 return false;
1680
1681 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1682 // Checking that the destination shape has a prefix of 1s is not sufficient,
1683 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1684 // is that the source shape is a suffix of the destination shape.
1685 VectorType srcType = shapeCast.getSourceVectorType();
1686 ArrayRef<int64_t> srcShape = srcType.getShape();
1687 uint64_t srcRank = srcType.getRank();
1688 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1689 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1690}
1691
/// Fold extract(broadcast(X)) to either extract(X) or just X.
///
/// Example:
///
///          broadcast           extract [1][2]
/// (3, 4) ----------> (2, 3, 4) --------------> (4)
///
/// becomes
///                    extract [1]
/// (3, 4) ------------------------------------> (4)
///
///
/// The variable names used in this implementation correspond to the above
/// shapes as,
///
/// - (3, 4) is `input` shape.
/// - (2, 3, 4) is `broadcast` shape.
/// - (4) is `extract` shape.
///
/// This folding is possible when the suffix of `input` shape is the same as
/// `extract` shape.
///
/// On success this rewrites `extractOp` in place (new source, new position)
/// and returns its unchanged result value; otherwise returns a null Value.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {

  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (!defOp || !isBroadcastLike(defOp))
    return Value();

  Value input = defOp->getOperand(0);

  // Replace extract(broadcast(X)) with X
  if (extractOp.getType() == input.getType())
    return input;

  // Get required types and ranks in the chain
  //    input -> broadcast -> extract
  // (scalars are treated as rank-0).
  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  // Cannot do without the broadcast if overall the rank increases.
  if (extractRank > inputRank)
    return Value();

  // The above condition guarantees that input is a vector.
  assert(inputType && "input must be a vector type because of previous checks");
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // In the case where there is a broadcast dimension in the suffix, it is not
  // possible to replace extract(broadcast(X)) with extract(X). Example:
  //
  //     broadcast       extract
  // (1) --------> (3,4) ------> (4)
  if (extractType &&
      extractType.getShape() != inputShape.take_back(extractRank))
    return Value();

  // Replace extract(broadcast(X)) with extract(X).
  // First, determine the new extraction position:
  // - dims of `input` that were broadcast (size 1) must be indexed with 0;
  // - remaining leading dims reuse the old position, shifted left by the
  //   number of dims the broadcast prepended (`deltaBroadcast`).
  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
  SmallVector<OpFoldResult> newPositions(deltaOverall);
  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  }
  // Rewrite the extract in place: operand 0 becomes `input`, followed by the
  // dynamic position values; the static positions are updated to match.
  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
1767
1768/// Fold extractOp coming from ShuffleOp.
1769///
1770/// Example:
1771///
1772/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1773/// : vector<8xf32>, vector<8xf32>
1774/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1775/// ->
1776/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1777///
1778static Value foldExtractFromShuffle(ExtractOp extractOp) {
1779 // Dynamic positions are not folded as the resulting code would be more
1780 // complex than the input code.
1781 if (extractOp.hasDynamicPosition())
1782 return Value();
1783
1784 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1785 if (!shuffleOp)
1786 return Value();
1787
1788 // TODO: 0-D or multi-dimensional vectors not supported yet.
1789 if (shuffleOp.getResultVectorType().getRank() != 1)
1790 return Value();
1791
1792 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1793 auto shuffleMask = shuffleOp.getMask();
1794 int64_t extractIdx = extractOp.getStaticPosition()[0];
1795 int64_t shuffleIdx = shuffleMask[extractIdx];
1796
1797 // Find the shuffled vector to extract from based on the shuffle index.
1798 if (shuffleIdx < inputVecSize) {
1799 extractOp.setOperand(0, shuffleOp.getV1());
1800 extractOp.setStaticPosition({shuffleIdx});
1801 } else {
1802 extractOp.setOperand(0, shuffleOp.getV2());
1803 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1804 }
1805
1806 return extractOp.getResult();
1807}
1808
1809// Fold extractOp with source coming from ShapeCast op.
1810static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1811 // TODO: Canonicalization for dynamic position not implemented yet.
1812 if (extractOp.hasDynamicPosition())
1813 return Value();
1814
1815 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1816 if (!shapeCastOp)
1817 return Value();
1818
1819 // Get the nth dimension size starting from lowest dimension.
1820 auto getDimReverse = [](VectorType type, int64_t n) {
1821 return type.getShape().take_back(n + 1).front();
1822 };
1823 int64_t destinationRank =
1824 llvm::isa<VectorType>(extractOp.getType())
1825 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1826 : 0;
1827 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1828 return Value();
1829 if (destinationRank > 0) {
1830 auto destinationType =
1831 llvm::cast<VectorType>(extractOp.getResult().getType());
1832 for (int64_t i = 0; i < destinationRank; i++) {
1833 // The lowest dimension of the destination must match the lowest
1834 // dimension of the shapecast op source.
1835 // TODO: This case could be support in a canonicalization pattern.
1836 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1837 getDimReverse(destinationType, i))
1838 return Value();
1839 }
1840 }
1841 // Extract the strides associated with the extract op vector source. Then use
1842 // this to calculate a linearized position for the extract.
1843 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1844 std::reverse(extractedPos.begin(), extractedPos.end());
1846 int64_t stride = 1;
1847 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1848 strides.push_back(stride);
1849 stride *=
1850 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1851 }
1852
1853 int64_t position = linearize(extractedPos, strides);
1854 // Then extract the strides associated to the shapeCast op vector source and
1855 // delinearize the position using those strides.
1856 SmallVector<int64_t, 4> newStrides;
1857 int64_t numDimension =
1858 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1859 stride = 1;
1860 for (int64_t i = 0; i < numDimension; i++) {
1861 newStrides.push_back(stride);
1862 stride *=
1863 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1864 }
1865 std::reverse(newStrides.begin(), newStrides.end());
1866 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1867 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1868 OpBuilder b(extractOp.getContext());
1869 extractOp.setStaticPosition(newPosition);
1870 extractOp.setOperand(0, shapeCastOp.getSource());
1871 return extractOp.getResult();
1872}
1873
1874/// Fold an ExtractOp from ExtractStridedSliceOp.
1875static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1876 // TODO: Canonicalization for dynamic position not implemented yet.
1877 if (extractOp.hasDynamicPosition())
1878 return Value();
1879
1880 auto extractStridedSliceOp =
1881 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1882 if (!extractStridedSliceOp)
1883 return Value();
1884
1885 // 0-D vectors not supported.
1886 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1887 if (hasZeroDimVectors(extractStridedSliceOp))
1888 return Value();
1889
1890 // Return if 'extractStridedSliceOp' has non-unit strides.
1891 if (extractStridedSliceOp.hasNonUnitStrides())
1892 return Value();
1893
1894 // Trim offsets for dimensions fully extracted.
1895 auto sliceOffsets =
1896 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1897 while (!sliceOffsets.empty()) {
1898 size_t lastOffset = sliceOffsets.size() - 1;
1899 if (sliceOffsets.back() != 0 ||
1900 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1901 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1902 break;
1903 sliceOffsets.pop_back();
1904 }
1905 unsigned destinationRank = 0;
1906 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1907 destinationRank = vecType.getRank();
1908 // The dimensions of the result need to be untouched by the
1909 // extractStridedSlice op.
1910 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1911 sliceOffsets.size())
1912 return Value();
1913
1914 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1915 assert(extractedPos.size() >= sliceOffsets.size());
1916 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1917 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1918 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1919
1920 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1921 OpBuilder b(extractOp.getContext());
1922 extractOp.setStaticPosition(extractedPos);
1923 return extractOp.getResult();
1924}
1925
/// Fold extract_op fed from a chain of insertStridedSlice ops.
///
/// Walks the chain of insert_strided_slice ops feeding `extractOp`. If the
/// extracted chunk lies fully inside an inserted chunk, the extract is
/// redirected (in place) to the inserted value; if the two chunks are
/// disjoint, the walk continues into that insert's destination. Any partial
/// overlap aborts the fold.
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
  if (!insertOp)
    return Value();

  // 0-D vectors not supported.
  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
  if (hasZeroDimVectors(insertOp))
    return Value();

  while (insertOp) {
    // Number of leading dims the insert's destination has beyond its source.
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    // Only unit strides are handled.
    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      // Leading dims that exist only in the destination behave as size 1.
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        // Record the position relative to the inserted chunk.
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extract element chunk overlap with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      // Full containment: extract directly from the inserted value.
      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(extractOp.getContext());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
2000
2001/// Try to fold the extraction of a scalar from a vector defined by
2002/// vector.from_elements. E.g.:
2003///
2004/// %0 = vector.from_elements %a, %b : vector<2xf32>
2005/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2006/// ==> fold to %a
2007static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2008 // Dynamic extractions cannot be folded.
2009 if (extractOp.hasDynamicPosition())
2010 return {};
2011
2012 // Look for extract(from_elements).
2013 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2014 if (!fromElementsOp)
2015 return {};
2016
2017 // Scalable vectors are not supported.
2018 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2019 if (vecType.isScalable())
2020 return {};
2021
2022 // Only extractions of scalars are supported.
2023 int64_t rank = vecType.getRank();
2024 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2025 if (extractOp.getType() != vecType.getElementType())
2026 return {};
2027 assert(static_cast<int64_t>(indices.size()) == rank &&
2028 "unexpected number of indices");
2029
2030 // Compute flattened/linearized index and fold to operand.
2031 int flatIndex = 0;
2032 int stride = 1;
2033 for (int i = rank - 1; i >= 0; --i) {
2034 flatIndex += indices[i] * stride;
2035 stride *= vecType.getDimSize(i);
2036 }
2037 return fromElementsOp.getElements()[flatIndex];
2038}
2039
2040/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2041/// then fold it.
2042template <typename OpType, typename AdaptorType>
2043static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2044 SmallVectorImpl<Value> &operands) {
2045 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2046 OperandRange dynamicPosition = op.getDynamicPosition();
2047 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2049 if constexpr (std::is_same_v<OpType, ExtractOp>)
2050 vectorShape = op.getSourceVectorType().getShape();
2051 else
2052 vectorShape = op.getDestVectorType().getShape();
2053
2054 // If the dynamic operands is empty, it is returned directly.
2055 if (!dynamicPosition.size())
2056 return {};
2057
2058 // `index` is used to iterate over the `dynamicPosition`.
2059 unsigned index = 0;
2060
2061 // `opChange` is a flag. If it is true, it means to update `op` in place.
2062 bool opChange = false;
2063 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2064 if (ShapedType::isStatic(staticPosition[i]))
2065 continue;
2066 Attribute positionAttr = dynamicPositionAttr[index];
2067 Value position = dynamicPosition[index++];
2068 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2069 int64_t value = attr.getInt();
2070 // Do not fold if the value is out of bounds (-1 signifies a poison
2071 // value rather than OOB index).
2072 if (value >= -1 && value < vectorShape[i]) {
2073 staticPosition[i] = attr.getInt();
2074 opChange = true;
2075 continue;
2076 }
2077 }
2078 operands.push_back(position);
2079 }
2080
2081 if (opChange) {
2082 op.setStaticPosition(staticPosition);
2083 op.getOperation()->setOperands(operands);
2084 // Return the original result to indicate an in-place folding happened.
2085 return op.getResult();
2086 }
2087 return {};
2088}
2089
2090/// Fold an insert or extract operation into an poison value when a poison index
2091/// is found at any dimension of the static position.
2093 ArrayRef<int64_t> staticPos,
2094 int64_t poisonVal) {
2095 if (!is_contained(staticPos, poisonVal))
2096 return {};
2097
2098 return ub::PoisonAttr::get(context);
2099}
2100
2101/// Fold a vector extract from is a poison source.
2103 if (matchPattern(srcAttr, ub::m_Poison()))
2104 return srcAttr;
2105
2106 return {};
2107}
2108
2109/// Fold a vector extract extracting from a DenseElementsAttr.
2111 Attribute srcAttr) {
2112 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2113 if (!denseAttr) {
2114 return {};
2115 }
2116
2117 if (denseAttr.isSplat()) {
2118 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2119 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2120 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2121 return newAttr;
2122 }
2123
2124 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2125 if (vecTy.isScalable())
2126 return {};
2127
2128 if (extractOp.hasDynamicPosition()) {
2129 return {};
2130 }
2131
2132 // Materializing subsets of a large constant array can generally lead to
2133 // explosion in IR size because of different combination of subsets that
2134 // can exist. However, vector.extract is a restricted form of subset
2135 // extract where you can only extract non-overlapping (or the same) subset for
2136 // a given rank of the subset. Because of this property, the IR size can only
2137 // increase at most by `rank * size(array)` from a single constant array being
2138 // extracted by multiple extracts.
2139
2140 // Calculate the linearized position of the continuous chunk of elements to
2141 // extract.
2142 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2143 copy(extractOp.getStaticPosition(), completePositions.begin());
2144 int64_t startPos =
2145 linearize(completePositions, computeStrides(vecTy.getShape()));
2146 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2147
2148 TypedAttr newAttr;
2149 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2150 SmallVector<Attribute> elementValues(
2151 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2152 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2153 } else {
2154 newAttr = *denseValuesBegin;
2155 }
2156
2157 return newAttr;
2158}
2159
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
  // mismatch).
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
    return getSource();
  if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
    return res;
  // Fold `arith.constant` indices into the `vector.extract` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getSource()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  // The remaining folds are tried in order; the first non-null result wins.
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
    return res;
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShuffle(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  if (auto val = foldScalarExtractFromFromElements(*this))
    return val;

  // Report the constant-index in-place fold (if any) as a last resort.
  return inplaceFolded;
}
2198
2199namespace {
2200
2201// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2202class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2203public:
2204 using Base::Base;
2205
2206 LogicalResult matchAndRewrite(ExtractOp extractOp,
2207 PatternRewriter &rewriter) const override {
2208
2209 Operation *defOp = extractOp.getSource().getDefiningOp();
2210 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2211 if (!defOp || !isBroadcastLike(defOp) || !outType)
2212 return failure();
2213
2214 Value source = defOp->getOperand(0);
2215 if (isBroadcastableTo(source.getType(), outType) !=
2216 BroadcastableToResult::Success)
2217 return failure();
2218
2219 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2220 return success();
2221 }
2222};
2223
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
//
// For each dimension that the extract position indexes into, classify it as
// all-false (static position past the mask bound), known in-bounds, or
// unknown. If any dimension makes the result all-false, replace the extract
// with an all-false constant; if all indexed dimensions are known in-bounds,
// replace it with a create_mask over the remaining dimensions; otherwise
// leave the op alone.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    // Only vector results are handled; scalar extracts are left untouched.
    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());

    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range from the `create_mask`, then the
        // extracted mask will be all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true and since this is a dynamic index we don't
        // know if the extraction is within the true or false region.
        // Note: Zero dims have already handled via getMaskFormat().
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      // All indexed dims known in-bounds: keep a create_mask over the
      // remaining (non-indexed) dimensions.
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
2288
// Pattern to rewrite a ExtractOp(ConstantMask) -> ConstantMask.
//
// Unlike the create_mask case, all mask bounds are static here, so each
// indexed dimension is either provably all-false, provably in-bounds, or (for
// a dynamic index into a partially-set dim) unknown, in which case we bail.
class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto constantMaskOp =
        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
    if (!constantMaskOp)
      return failure();

    // Null when the extract produces a scalar.
    Type resultType = extractOp.getResult().getType();
    auto extractedMaskType = dyn_cast<VectorType>(resultType);

    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();

    VectorType maskType = constantMaskOp.getVectorType();

    // Check if any extracted position is outside the mask bounds.
    for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      if (pos == ShapedType::kDynamic) {
        // If the dim is all-true, a dynamic index is fine — any position
        // is within the masked region.
        if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
          continue;
        // Otherwise we don't know if the position is inside or outside of
        // the masked area, so bail out.
        return failure();
      }

      // If the position is statically outside of the masked area, the result
      // will be all-false.
      if (pos >= maskDimSizes[dimIdx]) {
        if (extractedMaskType) {
          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
              extractOp, DenseElementsAttr::get(extractedMaskType, false));
        } else {
          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
              extractOp, rewriter.getIntegerAttr(resultType, false));
        }
        return success();
      }
    }

    // All positions are within the mask bounds.
    if (extractedMaskType) {
      // Vector result: the result is a constant_mask with the remaining
      // dimensions.
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
          extractOp, extractedMaskType,
          maskDimSizes.drop_front(extractOpPos.size()));
    } else {
      // Scalar result: all positions are within the masked region, so the
      // result is true.
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, rewriter.getIntegerAttr(resultType, true));
    }
    return success();
  }
};
2352
2353// Folds extract(shape_cast(..)) into shape_cast when the total element count
2354// does not change.
2355LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2356 PatternRewriter &rewriter) {
2357 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2358 if (!castOp)
2359 return failure();
2360
2361 VectorType sourceType = castOp.getSourceVectorType();
2362 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2363 if (!targetType)
2364 return failure();
2365
2366 if (sourceType.getNumElements() != targetType.getNumElements())
2367 return failure();
2368
2369 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2370 castOp.getSource());
2371 return success();
2372}
2373
2374/// Try to canonicalize the extraction of a subvector from a vector defined by
2375/// vector.from_elements. E.g.:
2376///
2377/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2378/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2379/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2380LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2381 PatternRewriter &rewriter) {
2382 // Dynamic positions are not supported.
2383 if (extractOp.hasDynamicPosition())
2384 return failure();
2385
2386 // Scalar extracts are handled by the folder.
2387 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2388 if (!resultType)
2389 return failure();
2390
2391 // Look for extracts from a from_elements op.
2392 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2393 if (!fromElementsOp)
2394 return failure();
2395 VectorType inputType = fromElementsOp.getType();
2396
2397 // Scalable vectors are not supported.
2398 if (resultType.isScalable() || inputType.isScalable())
2399 return failure();
2400
2401 // Compute the position of first extracted element and flatten/linearize the
2402 // position.
2403 SmallVector<int64_t> firstElementPos =
2404 llvm::to_vector(extractOp.getStaticPosition());
2405 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2406 int flatIndex = 0;
2407 int stride = 1;
2408 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2409 flatIndex += firstElementPos[i] * stride;
2410 stride *= inputType.getDimSize(i);
2411 }
2412
2413 // Replace the op with a smaller from_elements op.
2414 rewriter.replaceOpWithNewOp<FromElementsOp>(
2415 extractOp, resultType,
2416 fromElementsOp.getElements().slice(flatIndex,
2417 resultType.getNumElements()));
2418 return success();
2419}
2420
2421/// Replace `vector.extract` with `vector.shape_cast`.
2422///
2423/// BEFORE:
2424/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2425/// AFTER:
2426/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2427///
2428/// The canonical form of vector operations that reshape vectors is shape_cast.
2429struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2430 using Base::Base;
2431 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2432 PatternRewriter &rewriter) const override {
2433 VectorType sourceType = extractOp.getSourceVectorType();
2434 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2435 if (!outType)
2436 return failure();
2437
2438 if (sourceType.getNumElements() != outType.getNumElements())
2439 return rewriter.notifyMatchFailure(
2440 extractOp, "extract to vector with fewer elements");
2441
2442 // Negative values in `position` means that the extacted value is poison.
2443 // There is a vector.extract folder for this.
2444 if (llvm::any_of(extractOp.getMixedPosition(),
2445 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2446 return rewriter.notifyMatchFailure(extractOp,
2447 "leaving for extract poison folder");
2448
2449 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2450 extractOp.getSource());
2451
2452 return success();
2453 }
2454};
2455
2456} // namespace
2457
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Class-based rewrite patterns.
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
              ExtractOpFromConstantMask, ExtractToShapeCast>(context);
  // Free-function rewrite patterns.
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
2465
2467 SmallVectorImpl<int64_t> &results) {
2468 for (auto attr : arrayAttr)
2469 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2470}
2471
2472//===----------------------------------------------------------------------===//
2473// FmaOp
2474//===----------------------------------------------------------------------===//
2475
2476std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2477 return llvm::to_vector<4>(getVectorType().getShape());
2478}
2479
2480//===----------------------------------------------------------------------===//
2481// ToElementsOp
2482//===----------------------------------------------------------------------===//
2483
2484/// Returns true if all the `operands` are defined by `defOp`.
2485/// Otherwise, returns false.
2486static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2487 if (operands.empty())
2488 return false;
2489
2490 return llvm::all_of(operands, [&](Value operand) {
2491 Operation *currentDef = operand.getDefiningOp();
2492 return currentDef == defOp;
2493 });
2494}
2495
2496/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2497/// (%e0, %e1, ...). For example:
2498///
2499/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2500/// %1:3 = vector.to_elements %0 : vector<3xf32>
2501/// user_op %1#0, %1#1, %1#2
2502///
2503/// becomes:
2504///
2505/// user_op %a, %b, %c
2506///
2507static LogicalResult
2508foldToElementsFromElements(ToElementsOp toElementsOp,
2510 auto fromElementsOp =
2511 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2512 if (!fromElementsOp)
2513 return failure();
2514
2515 llvm::append_range(results, fromElementsOp.getElements());
2516 return success();
2517}
2518
2519/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2520///
2521/// Example:
2522/// %b = vector.broadcast %x : i32 to vector<3xf32>
2523/// %e:3 = vector.to_elements %b : vector<3xf32>
2524/// user_op %e#0, %e#1, %e#2
2525/// becomes:
2526/// user_op %x, %x, %x
2527///
2528/// The vector source case is handled by a canonicalization pattern.
2529static LogicalResult
2530foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2532 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2533 if (!bcastOp)
2534 return failure();
2535 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2536 if (isa<VectorType>(bcastOp.getSource().getType()))
2537 return failure();
2538
2539 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2540
2541 Value scalar = bcastOp.getSource();
2542 results.assign(resultVecType.getNumElements(), scalar);
2543 return success();
2544}
2545
2546LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2547 SmallVectorImpl<OpFoldResult> &results) {
2548 if (succeeded(foldToElementsFromElements(*this, results)))
2549 return success();
2550
2551 // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X)
2552 if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2553 setOperand(shapeCast.getSource());
2554 return success();
2555 }
2556
2557 return foldToElementsOfBroadcast(*this, results);
2558}
2559
2560LogicalResult
2561ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2562 ToElementsOp::Adaptor adaptor,
2563 SmallVectorImpl<Type> &inferredReturnTypes) {
2564 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2565 Type elType = vecType.getElementType();
2566 inferredReturnTypes.append(vecType.getNumElements(), elType);
2567 return success();
2568}
2569
2570/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2571/// vector.
2572/// - Build `vector.to_elements %v` and remap each destination element to the
2573/// corresponding source element using broadcast rules (match or 1 →
2574/// replicate).
2575///
2576/// Example:
2577/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2578/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2579/// becomes:
2580/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2581/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2582/// // %src_elems#1, %src_elems#0, %src_elems#1
2583struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2584 using Base::Base;
2585
2586 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2587 PatternRewriter &rewriter) const override {
2588 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2589 if (!bcastOp)
2590 return failure();
2591
2592 // Only handle broadcasts from a vector source here.
2593 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2594 if (!srcType)
2595 return failure();
2596
2597 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2598
2599 ArrayRef<int64_t> dstShape = dstType.getShape();
2600 ArrayRef<int64_t> srcShape = srcType.getShape();
2601
2602 int64_t dstRank = dstShape.size();
2603 int64_t srcRank = srcShape.size();
2604
2605 // Create elements for the broadcast source vector.
2606 auto srcElems = vector::ToElementsOp::create(
2607 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2608
2609 int64_t dstCount = llvm::product_of(dstShape);
2610
2611 SmallVector<Value> replacements;
2612 replacements.reserve(dstCount);
2613
2614 // For each element of the destination, determine which element of the
2615 // source should be used. We walk all destination positions using a single
2616 // counter, decode it into per-dimension indices, then build the matching
2617 // source position: use the same index where sizes match, and use 0 where
2618 // the source size is 1 (replication). This mapping is needed so we can
2619 // replace each result of to_elements with the corresponding element from
2620 // the broadcast source.
2621 // Inner-dimension stretch example:
2622 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2623 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2624 // becomes:
2625 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2626 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2627 // // %src_elems#1, %src_elems#0, %src_elems#1,
2628 // // %src_elems#2, %src_elems#3, %src_elems#2,
2629 // // %src_elems#3, %src_elems#2, %src_elems#3
2630
2631 // Row-major strides for the destination shape.
2632 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2633 // Row-major strides for the source shape.
2634 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2635 SmallVector<int64_t> dstIdx(dstRank);
2636 SmallVector<int64_t> srcIdx(srcRank);
2637 for (int64_t lin = 0; lin < dstCount; ++lin) {
2638 // Convert linear destination index to per-dimension indices.
2639 dstIdx = delinearize(lin, dstStrides);
2640 for (int64_t k = 0; k < srcRank; ++k)
2641 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2642 // Convert per-dimension source indices back to a linear index.
2643 int64_t srcLin = linearize(srcIdx, srcStrides);
2644 replacements.push_back(srcElems.getResult(srcLin));
2645 }
2646
2647 rewriter.replaceOp(toElementsOp, replacements);
2648 return success();
2649 }
2650};
2651
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Expands to_elements(broadcast(vector)) into to_elements on the broadcast
  // source; the scalar-source case is handled by the folder instead.
  results.add<ToElementsOfBroadcast>(context);
}
2656
2657//===----------------------------------------------------------------------===//
2658// FromElementsOp
2659//===----------------------------------------------------------------------===//
2660
2661/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2662///
2663/// Case #1: Input and output vectors are the same.
2664///
2665/// %0:3 = vector.to_elements %a : vector<3xf32>
2666/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2667/// user_op %1
2668///
2669/// becomes:
2670///
2671/// user_op %a
2672///
2673static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2674 OperandRange fromElemsOperands = fromElementsOp.getElements();
2675 if (fromElemsOperands.empty())
2676 return {};
2677
2678 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2679 if (!toElementsOp)
2680 return {};
2681
2682 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2683 return {};
2684
2685 // Case #1: Input and output vectors are the same. Forward the input vector.
2686 Value toElementsInput = toElementsOp.getSource();
2687 if (fromElementsOp.getType() == toElementsInput.getType() &&
2688 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2689 return toElementsInput;
2690 }
2691
2692 // TODO: Support cases with different input and output shapes and different
2693 // number of elements.
2694
2695 return {};
2696}
2697
2698/// Fold vector.from_elements to a constant when all operands are constants.
2699/// Example:
2700/// %c1 = arith.constant 1 : i32
2701/// %c2 = arith.constant 2 : i32
2702/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2703/// =>
2704/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2705///
2706static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2707 ArrayRef<Attribute> elements) {
2708 // Check for null or poison attributes before any processing.
2709 if (llvm::any_of(elements, [](Attribute attr) {
2710 return !attr || matchPattern(attr, ub::m_Poison());
2711 }))
2712 return {};
2713
2714 // DenseElementsAttr only supports int/index/float/complex types.
2715 auto destVecType = fromElementsOp.getDest().getType();
2716 auto destEltType = destVecType.getElementType();
2717 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2718 return {};
2719
2720 // Constant attributes might have a different type than the return type.
2721 // Convert them before creating the dense elements attribute.
2722 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2723 return convertNumericAttr(attr, destEltType);
2724 });
2725
2726 return DenseElementsAttr::get(destVecType, convertedElements);
2727}
2728
2729OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2730 if (auto res = foldFromElementsToElements(*this))
2731 return res;
2732 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2733 return res;
2734
2735 return {};
2736}
2737
2738/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2739/// same. Example:
2740/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2741/// =>
2742/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2743static LogicalResult
2744rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2745 PatternRewriter &rewriter) {
2746 if (!llvm::all_equal(fromElementsOp.getElements()))
2747 return failure();
2748 rewriter.replaceOpWithNewOp<BroadcastOp>(
2749 fromElementsOp, fromElementsOp.getType(),
2750 fromElementsOp.getElements().front());
2751 return success();
2752}
2753
2754/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2755/// on a single extract. Example:
2756/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2757/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2758/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2759///
2760/// becomes
2761/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2762/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2763///
2764/// The requirements for this to be valid are
2765///
2766/// i) The elements are extracted from the same vector (%source).
2767///
2768/// ii) The elements form a suffix of %source. Specifically, the number
2769/// of elements is the same as the product of the last N dimension sizes
2770/// of %source, for some N.
2771///
2772/// iii) The elements are extracted contiguously in ascending order.
2773
2774class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2775
2776 using Base::Base;
2777
2778 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2779 PatternRewriter &rewriter) const override {
2780
2781 // Handled by `rewriteFromElementsAsBroadcast`.
2782 if (fromElements.getType().getNumElements() == 1)
2783 return failure();
2784
2785 // The common source that all elements are extracted from, if one exists.
2787 // The position of the combined extract operation, if one is created.
2788 ArrayRef<int64_t> combinedPosition;
2789 // The expected index of extraction of the current element in the loop, if
2790 // elements are extracted contiguously in ascending order.
2791 SmallVector<int64_t> expectedPosition;
2792
2793 for (auto [insertIndex, element] :
2794 llvm::enumerate(fromElements.getElements())) {
2795
2796 // Check that the element is from a vector.extract operation.
2797 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2798 if (!extractOp) {
2799 return rewriter.notifyMatchFailure(fromElements,
2800 "element not from vector.extract");
2801 }
2802
2803 // Check condition (i) by checking that all elements have the same source
2804 // as the first element.
2805 if (insertIndex == 0) {
2806 source = extractOp.getSource();
2807 } else if (extractOp.getSource() != source) {
2808 return rewriter.notifyMatchFailure(fromElements,
2809 "element from different vector");
2810 }
2811
2812 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2813 int64_t rank = position.size();
2814 assert(rank == source.getType().getRank() &&
2815 "scalar extract must have full rank position");
2816
2817 // Check condition (ii) by checking that the position that the first
2818 // element is extracted from has sufficient trailing 0s. For example, in
2819 //
2820 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2821 // [...]
2822 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2823 //
2824 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2825 // elements, which is the number of elements of %n, so this is valid.
2826 if (insertIndex == 0) {
2827 const int64_t numElms = fromElements.getType().getNumElements();
2828 int64_t numSuffixElms = 1;
2829 int64_t index = rank;
2830 while (index > 0 && position[index - 1] == 0 &&
2831 numSuffixElms < numElms) {
2832 numSuffixElms *= source.getType().getDimSize(index - 1);
2833 --index;
2834 }
2835 if (numSuffixElms != numElms) {
2836 return rewriter.notifyMatchFailure(
2837 fromElements, "elements do not form a suffix of source");
2838 }
2839 expectedPosition = llvm::to_vector(position);
2840 combinedPosition = position.drop_back(rank - index);
2841 }
2842
2843 // Check condition (iii).
2844 else if (expectedPosition != position) {
2845 return rewriter.notifyMatchFailure(
2846 fromElements, "elements not in ascending order (static order)");
2847 }
2848 increment(expectedPosition, source.getType().getShape());
2849 }
2850
2851 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2852 fromElements.getLoc(), source, combinedPosition);
2853
2854 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2855 fromElements, fromElements.getType(), extracted);
2856
2857 return success();
2858 }
2859
2860 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2861 static void increment(MutableArrayRef<int64_t> indices,
2863 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2864 indices[dim] += 1;
2865 if (indices[dim] < shape[dim])
2866 break;
2867 indices[dim] = 0;
2868 }
2869 }
2870};
2871
2872void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2873 MLIRContext *context) {
2875 results.add<FromElementsToShapeCast>(context);
2876}
2877
2878//===----------------------------------------------------------------------===//
2879// BroadcastOp
2880//===----------------------------------------------------------------------===//
2881
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  // Broadcasting only replicates the source value, so the result's integer
  // range is exactly the source operand's range.
  setResultRanges(getResult(), argRanges.front());
}
2886
2887std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2888 return llvm::to_vector<4>(getResultVectorType().getShape());
2889}
2890
2891/// Return the dimensions of the result vector that were formerly ones in the
2892/// source tensor and thus correspond to "dim-1" broadcasting.
2893static llvm::SetVector<int64_t>
2895 ArrayRef<int64_t> dstShape) {
2896 int64_t rankDiff = dstShape.size() - srcShape.size();
2897 int64_t dstDim = rankDiff;
2899 for (auto [s1, s2] :
2900 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2901 if (s1 != s2) {
2902 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2903 res.insert(dstDim);
2904 }
2905 ++dstDim;
2906 }
2907 return res;
2908}
2909
/// Returns the set of result dims that correspond to "dim-1" broadcasting,
/// delegating to the static shape-based helper.
llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // Scalar broadcast is without any unit dim broadcast.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                      getResultVectorType().getShape());
}
2918
2919/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2920/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2921/// This requires (and asserts) that the broadcast is free of "dim-1"
2922/// broadcasting.
2923/// Since vector.broadcast only allows expanding leading dimensions, an extra
2924/// vector.transpose may be inserted to make the broadcast possible.
2925/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2926/// the helper will assert. This means:
2927/// 1. `dstShape` must not be empty.
2928/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2929/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2930// must match the `value` shape.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Step 1. Well-formedness check: dstShape with the broadcasted dims removed
  // must equal the source vector shape (checked below).
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Since vector.broadcast only allows creating leading dims,
  // vector -> dstShape broadcast may require a transpose.
  // Traverse the dims in order and construct:
  //   1. The leading entries of the broadcastShape that is guaranteed to be
  //      achievable by a simple broadcast.
  //   2. The induced permutation for the subsequent vector.transpose that will
  //      bring us from `broadcastShape` back to the desired `dstShape`.
  // If the induced permutation is not the identity, create a vector.transpose.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  // Consider the example:
  //   srcShape = 2x4
  //   dstShape = 1x2x3x4x5
  //   broadcastedDims = [0, 2, 4]
  //
  // We want to build:
  //   broadcastShape = 1x3x5x2x4
  //   permutation = [0, 2, 4, 1, 3]
  //                  ---V---  -----V-----
  //             leading broadcast part      src shape part
  //
  // Note that the trailing dims of broadcastShape are exactly the srcShape
  // by construction.
  // nextSrcShapeDim is used to keep track of where in the permutation the
  // "src shape part" occurs.
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
      // bring it to the head of the broadcastShape.
      // It will need to be permuted back from `broadcastShape.size() - 1` into
      // position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
      // shape and needs to be permuted into position `i`.
      // Don't touch `broadcastShape` here, the whole srcShape will be
      // appended after.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the srcShape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 4. If we find any dimension that indeed needs to be permuted,
  // immediately return a new vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  // Otherwise return res.
  return res;
}
3025
3027 Type srcType, VectorType dstVectorType,
3028 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3029 // Broadcast scalar to vector of the same element type.
3030 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3031 srcType == getElementTypeOrSelf(dstVectorType))
3033 // From now on, only vectors broadcast.
3034 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3035 if (!srcVectorType)
3037
3038 int64_t srcRank = srcVectorType.getRank();
3039 int64_t dstRank = dstVectorType.getRank();
3040 if (srcRank > dstRank)
3042 // Source has an exact match or singleton value for all trailing dimensions
3043 // (all leading dimensions are simply duplicated).
3044 int64_t lead = dstRank - srcRank;
3045 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3046 // Have mismatching dims (in the sense of vector.broadcast semantics) been
3047 // encountered?
3048 bool foundMismatchingDims = false;
3049
3050 // Check fixed-width dims.
3051 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3052 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3053 if (srcDim != 1 && srcDim != dstDim)
3054 foundMismatchingDims = true;
3055
3056 // Check scalable flags.
3057 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3058 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3059 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3060 // 1 -> [N] is fine, everything else should be rejected when mixing
3061 // fixed-width and scalable dims
3062 (srcDimScalableFlag != dstDimScalableFlag &&
3063 (srcDim != 1 || srcDimScalableFlag)))
3064 foundMismatchingDims = true;
3065
3066 if (foundMismatchingDims) {
3067 if (mismatchingDims != nullptr) {
3068 mismatchingDims->first.dim = srcDim;
3069 mismatchingDims->first.isScalable = srcDimScalableFlag;
3070
3071 mismatchingDims->second.dim = dstDim;
3072 mismatchingDims->second.isScalable = dstDimScalableFlag;
3073 }
3075 }
3076 }
3077
3079}
3080
3081LogicalResult BroadcastOp::verify() {
3082 std::pair<VectorDim, VectorDim> mismatchingDims;
3084 getSourceType(), getResultVectorType(), &mismatchingDims);
3086 return success();
3088 return emitOpError("source rank higher than destination rank");
3090 return emitOpError("dimension mismatch (")
3091 << (mismatchingDims.first.isScalable ? "[" : "")
3092 << mismatchingDims.first.dim
3093 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
3094 << (mismatchingDims.second.isScalable ? "[" : "")
3095 << mismatchingDims.second.dim
3096 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
3097 }
3099 return emitOpError("source type is not a vector");
3100 llvm_unreachable("unexpected vector.broadcast op error");
3101}
3102
3103// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3104// with broadcast's result type and shape_cast only adds or removes ones in the
3105// leading dimensions.
3106static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3107 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3108 if (!srcShapeCast)
3109 return failure();
3110
3111 VectorType srcType = srcShapeCast.getSourceVectorType();
3112 VectorType destType = broadcastOp.getResultVectorType();
3113 // Check type compatibility.
3114 if (vector::isBroadcastableTo(srcType, destType) !=
3116 return failure();
3117
3118 ArrayRef<int64_t> srcShape = srcType.getShape();
3119 ArrayRef<int64_t> shapecastShape =
3120 srcShapeCast.getResultVectorType().getShape();
3121 // Trailing dimensions should be the same if shape_cast only alters the
3122 // leading dimensions.
3123 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3124 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3125 shapecastShape.take_back(numTrailingDims)))
3126 return failure();
3127
3128 assert(all_of(srcShape.drop_back(numTrailingDims),
3129 [](int64_t E) { return E == 1; }) &&
3130 all_of(shapecastShape.drop_back(numTrailingDims),
3131 [](int64_t E) { return E == 1; }) &&
3132 "ill-formed shape_cast");
3133
3134 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3135 return success();
3136}
3137
3138OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3139 if (getSourceType() == getResultVectorType())
3140 return getSource();
3141 if (succeeded(foldBroadcastOfShapeCast(*this)))
3142 return getResult();
3143
3144 if (!adaptor.getSource())
3145 return {};
3146 auto vectorType = getResultVectorType();
3147 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3148 if (vectorType.getElementType() != attr.getType())
3149 return {};
3150 return DenseElementsAttr::get(vectorType, attr);
3151 }
3152 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3153 if (vectorType.getElementType() != attr.getType())
3154 return {};
3155 return DenseElementsAttr::get(vectorType, attr);
3156 }
3157 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3158 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3159 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
3160 return ub::PoisonAttr::get(getContext());
3161 return {};
3162}
3163
3164namespace {
3165
3166// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3167struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3168 using Base::Base;
3169
3170 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3171 PatternRewriter &rewriter) const override {
3172 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3173 if (!srcBroadcast)
3174 return failure();
3175 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3176 broadcastOp.getResultVectorType(),
3177 srcBroadcast.getSource());
3178 return success();
3179 }
3180};
3181
3182/// Replace `vector.broadcast` with `vector.shape_cast`.
3183///
3184/// BEFORE:
3185/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3186/// AFTER:
3187/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3188///
3189/// The canonical form of vector operations that reshape vectors is shape_cast.
3190struct BroadcastToShapeCast final
3191 : public OpRewritePattern<vector::BroadcastOp> {
3192 using Base::Base;
3193 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3194 PatternRewriter &rewriter) const override {
3195
3196 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3197 if (!sourceType) {
3198 return rewriter.notifyMatchFailure(
3199 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3200 }
3201
3202 VectorType outType = broadcast.getType();
3203 if (sourceType.getNumElements() != outType.getNumElements()) {
3204 return rewriter.notifyMatchFailure(
3205 broadcast, "broadcast to a greater number of elements");
3206 }
3207
3208 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3209 broadcast.getSource());
3210 return success();
3211 }
3212};
3213} // namespace
3214
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Collapse chained broadcasts, and turn element-count-preserving broadcasts
  // into shape_cast (the canonical reshape op).
  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
3219
3220//===----------------------------------------------------------------------===//
3221// ShuffleOp
3222//===----------------------------------------------------------------------===//
3223
3224LogicalResult ShuffleOp::verify() {
3225 VectorType resultType = getResultVectorType();
3226 VectorType v1Type = getV1VectorType();
3227 VectorType v2Type = getV2VectorType();
3228 // Verify ranks.
3229 int64_t resRank = resultType.getRank();
3230 int64_t v1Rank = v1Type.getRank();
3231 int64_t v2Rank = v2Type.getRank();
3232 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3233 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3234 if (!wellFormed0DCase && !wellFormedNDCase)
3235 return emitOpError("rank mismatch");
3236
3237 // Verify all but leading dimension sizes.
3238 for (int64_t r = 1; r < v1Rank; ++r) {
3239 int64_t resDim = resultType.getDimSize(r);
3240 int64_t v1Dim = v1Type.getDimSize(r);
3241 int64_t v2Dim = v2Type.getDimSize(r);
3242 if (resDim != v1Dim || v1Dim != v2Dim)
3243 return emitOpError("dimension mismatch");
3244 }
3245 // Verify mask length.
3246 ArrayRef<int64_t> mask = getMask();
3247 int64_t maskLength = mask.size();
3248 if (maskLength <= 0)
3249 return emitOpError("invalid mask length");
3250 if (maskLength != resultType.getDimSize(0))
3251 return emitOpError("mask length mismatch");
3252 // Verify all indices.
3253 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3254 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3255 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3256 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3257 return emitOpError("mask index #") << (idx + 1) << " out of range";
3258 }
3259 return success();
3260}
3261
3262LogicalResult
3263ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3264 ShuffleOp::Adaptor adaptor,
3265 SmallVectorImpl<Type> &inferredReturnTypes) {
3266 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3267 if (!v1Type) {
3268 return emitOptionalError(loc, "expected vector type");
3269 }
3270 auto v1Rank = v1Type.getRank();
3271 // Construct resulting type: leading dimension matches mask
3272 // length, all trailing dimensions match the operands.
3273 SmallVector<int64_t, 4> shape;
3274 shape.reserve(v1Rank);
3275 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3276 // In the 0-D case there is no trailing shape to append.
3277 if (v1Rank > 0)
3278 llvm::append_range(shape, v1Type.getShape().drop_front());
3279 inferredReturnTypes.push_back(
3280 VectorType::get(shape, v1Type.getElementType()));
3281 return success();
3282}
3283
3284template <typename T>
3285static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3286 T expected = begin;
3287 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3288 return value == expected++;
3289 });
3290}
3291
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
  // but must be a canonicalization into a vector.broadcast.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  // Constant folding below requires both operands to be constant (or poison).
  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
  bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // Poison input attributes need special handling as they are not
  // DenseElementsAttr. If an index is poison, we select the first element of
  // the first non-poison input.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  // Gather one element per mask entry, reading from v1 for indices below
  // v1's size and from v2 (shifted by v1's size) otherwise.
  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
3366
3367namespace {
3368
3369// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3370// to a broadcast.
3371struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3372 using Base::Base;
3373
3374 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3375 PatternRewriter &rewriter) const override {
3376 VectorType v1VectorType = shuffleOp.getV1VectorType();
3377 ArrayRef<int64_t> mask = shuffleOp.getMask();
3378 if (v1VectorType.getRank() > 0)
3379 return failure();
3380 if (mask.size() != 1)
3381 return failure();
3382 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3383 if (mask[0] == 0)
3384 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3385 shuffleOp.getV1());
3386 else
3387 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3388 shuffleOp.getV2());
3389 return success();
3390 }
3391};
3392
3393/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3394/// vector.broadcast with a scalar operand, return the scalar value that is
3395/// splatted. Otherwise return null.
3396///
3397/// Example:
3398///
3399/// scalar_source --> vector.broadcast --> value - return scalar_source
3400static Value getScalarSplatSource(Value value) {
3401 // Block argument:
3402 Operation *defOp = value.getDefiningOp();
3403 if (!defOp)
3404 return {};
3405
3406 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3407
3408 // Not broadcast (and not splat):
3409 if (!broadcast)
3410 return {};
3411
3412 // Broadcast of a vector:
3413 if (isa<VectorType>(broadcast.getSourceType()))
3414 return {};
3415
3416 // Broadcast of a scalar:
3417 return broadcast.getSource();
3418}
3419
3420/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3421class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3422public:
3423 using Base::Base;
3424
3425 LogicalResult matchAndRewrite(ShuffleOp op,
3426 PatternRewriter &rewriter) const override {
3427 Value splat = getScalarSplatSource(op.getV1());
3428 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3429 return failure();
3430
3431 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3432 return success();
3433 }
3434};
3435
3436/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3437/// vector.interleave.
3438class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3439public:
3440 using Base::Base;
3441
3442 LogicalResult matchAndRewrite(ShuffleOp op,
3443 PatternRewriter &rewriter) const override {
3444 VectorType resultType = op.getResultVectorType();
3445 if (resultType.isScalable())
3446 return rewriter.notifyMatchFailure(
3447 op, "ShuffleOp can't represent a scalable interleave");
3448
3449 if (resultType.getRank() != 1)
3450 return rewriter.notifyMatchFailure(
3451 op, "ShuffleOp can't represent an n-D interleave");
3452
3453 VectorType sourceType = op.getV1VectorType();
3454 if (sourceType != op.getV2VectorType() ||
3455 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3456 return rewriter.notifyMatchFailure(
3457 op, "ShuffleOp types don't match an interleave");
3458 }
3459
3460 ArrayRef<int64_t> shuffleMask = op.getMask();
3461 int64_t resultVectorSize = resultType.getNumElements();
3462 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3463 int64_t maskValueA = shuffleMask[i * 2];
3464 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3465 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3466 return rewriter.notifyMatchFailure(op,
3467 "ShuffleOp mask not interleaving");
3468 }
3469
3470 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3471 return success();
3472 }
3473};
3474
3475} // namespace
3476
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the shuffle canonicalization patterns defined above.
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}
3482
3483//===----------------------------------------------------------------------===//
3484// InsertOp
3485//===----------------------------------------------------------------------===//
3486
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  // The result may contain elements from both the inserted value
  // (argRanges[0]) and the destination vector (argRanges[1]), so the result
  // range is the union of the two.
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
3491
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  // Default build: insert at position [0, 0, ..., 0] (one zero per dest dim).
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}
3498
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  // Convenience overload for a single static position index.
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}
3503
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  // Convenience overload for a single static-or-dynamic position index.
  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
}
3508
3509void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3510 Value source, Value dest,
3511 ArrayRef<int64_t> position) {
3512 SmallVector<OpFoldResult> posVals;
3513 posVals.reserve(position.size());
3514 llvm::transform(position, std::back_inserter(posVals),
3515 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3516 build(builder, result, source, dest, posVals);
3517}
3518
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  // Split the mixed static/dynamic position into the dynamic index operands
  // and the static position attribute expected by the generated builder.
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
3528
LogicalResult InsertOp::verify() {
  // A 0-D vector source must be passed as a scalar instead.
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  // For a vector source, the position indexes the outer dimensions and the
  // source supplies the remaining inner dimensions.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  // For a scalar source, the position must address exactly one element.
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  // Static indices must be in-bounds for their dimension or the poison
  // sentinel; dynamic indices cannot be checked here.
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
  return success();
}
3565
3566// Calculate the linearized position of the continuous chunk of elements to
3567// insert, based on the shape of the value to insert and the positions to insert
3568// at.
3569static int64_t calculateInsertPosition(VectorType destTy,
3570 ArrayRef<int64_t> positions) {
3571 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3572 assert(positions.size() <= completePositions.size() &&
3573 "positions size must be less than or equal to destTy rank");
3574 copy(positions, completePositions.begin());
3575 return linearize(completePositions, computeStrides(destTy.getShape()));
3576}
3577
3578namespace {
3579
3580// If insertOp is only inserting unit dimensions it can be transformed to a
3581// broadcast.
3582class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3583public:
3584 using Base::Base;
3585
3586 LogicalResult matchAndRewrite(InsertOp insertOp,
3587 PatternRewriter &rewriter) const override {
3588 auto srcVecType =
3589 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3590 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3591 srcVecType.getNumElements())
3592 return failure();
3593 rewriter.replaceOpWithNewOp<BroadcastOp>(
3594 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3595 return success();
3596 }
3597};
3598
3599/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
3600class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3601public:
3602 using Base::Base;
3603
3604 LogicalResult matchAndRewrite(InsertOp op,
3605 PatternRewriter &rewriter) const override {
3606
3607 Value splat = getScalarSplatSource(op.getValueToStore());
3608 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3609 return failure();
3610
3611 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3612 return success();
3613 }
3614};
3615
/// Pattern to optimize a chain of insertions.
///
/// This pattern identifies chains of vector.insert operations that:
/// 1. Only insert values at static positions.
/// 2. Completely initialize all elements in the resulting vector.
/// 3. All intermediate insert operations have only one use.
///
/// When these conditions are met, the entire chain can be replaced with a
/// single vector.from_elements operation.
///
/// To keep this pattern simple, and avoid spending too much time on matching
/// fragmented insert chains, this pattern only considers the last insert op in
/// the chain.
///
/// Example transformation:
///   %poison = ub.poison : vector<2xi32>
///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
/// ->
///   %result = vector.from_elements %c1, %c2 : vector<2xi32>
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {

    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Ensure this is the trailing vector.insert op in a chain of inserts.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    // Walk the chain backwards through the dest operands, collecting the
    // participating insert ops (last-to-first).
    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Check cond 1: Dynamic position is not supported.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Check cond 3: Intermediate inserts have only one use to avoid an
      // explosion of vectors.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    // Coverage bookkeeping over the linearized destination elements. Because
    // chainInsertOps is ordered last-to-first, the first insert seen for an
    // index is the one whose value survives.
    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // This pattern can do nothing with poison index.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position for inserting elements.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      // The valueToStore operand may be a vector or a scalar. Need to handle
      // both cases.
      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer the creation of ops before we can make sure the pattern can
      // succeed.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      // Stop early once every destination element is accounted for.
      if (initializedCount == vectorSize)
        break;
    }

    // Check cond 2: all positions must be initialized.
    if (initializedCount != vectorSize)
      return failure();

    // Materialize the element list oldest-first so later inserts in the chain
    // overwrite earlier ones.
    SmallVector<Value> elements(vectorSize);
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());

      // Scalar source: it occupies exactly one destination slot.
      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }

      Repeated<Type> elementToInsertTypes(insertSize,
                                          srcVectorType.getElementType());
      // Get all elements from the vector in row-major order.
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
3740
3741} // namespace
3742
/// Fold an insert of a constant `srcAttr` into a constant dense destination
/// `dstAttr`, producing the updated dense attribute. Returns a null attribute
/// when the fold does not apply (dynamic position, non-dense destination,
/// missing source constant, scalable destination, poison index, or a
/// too-large multi-use constant).
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  // Dynamic positions cannot be folded statically.
  if (insertOp.hasDynamicPosition())
    return {};

  // The destination must be a dense constant (this also rejects poison).
  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};

  if (!srcAttr) {
    return {};
  }

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  // Bail out on poison indices (kPoisonIndex = -1) to avoid computing an
  // invalid (negative) linearized position which would cause UB below.
  if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
    return {};

  // Calculate the linearized position for inserting elements.
  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  /// Converts attribute to the expected type if there's
  /// a mismatch.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertNumericAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
  }

  // Overwrite the covered destination elements and rebuild the constant.
  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  auto newAttr = DenseElementsAttr::get(destTy, allValues);

  return newAttr;
}
3793
3794/// Folder to replace the `dest` operand of the insert op with the root dest of
3795/// the insert op use chain.
3796static Value foldInsertUseChain(InsertOp insertOp) {
3797 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3798 if (!destInsert)
3799 return {};
3800
3801 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3802 return {};
3803
3804 insertOp.setOperand(1, destInsert.getDest());
3805 return insertOp.getResult();
3806}
3807
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Register the insert canonicalization patterns.
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}
3813
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
  // (type mismatch).
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  // Fold `arith.constant` indices into the `vector.insert` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  // Prefer folds that replace the op outright; fall back to the in-place
  // constant-index fold computed above.
  if (auto res = foldInsertUseChain(*this))
    return res;
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  return inplaceFolded;
}
3842
3843//===----------------------------------------------------------------------===//
3844// InsertStridedSliceOp
3845//===----------------------------------------------------------------------===//
3846
3847void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3848 Value source, Value dest,
3849 ArrayRef<int64_t> offsets,
3850 ArrayRef<int64_t> strides) {
3851 result.addOperands({source, dest});
3852 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3853 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3854 result.addTypes(dest.getType());
3855 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3856 offsetsAttr);
3857 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3858 stridesAttr);
3859}
3860
// Verifies that `arrayAttr` has no more entries than `shape` has dimensions.
// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
3872
// Verifies that all integers in `arrayAttr` lie in the admissible interval:
// [min, max) when `halfOpen` is true, [min, max] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    // For a closed interval, widen the exclusive upper bound by one.
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      // Note: the diagnostic always renders the upper bound exclusively
      // ("[min, upper)"), even when `halfOpen` is false.
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
3892
// Verifies that, for every i, `arrayAttr[i]` lies in the admissible interval
// for dimension i: [min, shape[i]) when `halfOpen` is true, [min, shape[i]]
// otherwise. Only the first arrayAttr.size() dimensions are checked
// (zip_first).
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    // For a closed interval, widen the exclusive upper bound by one.
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
3914
// Verifies that, for all indices i = 0..shape.size()-1, the sum
// val = `arrayAttr1[i]` + `arrayAttr2[i]` lies in the admissible interval:
// [min, shape[i]) when `halfOpen` is true, [min, shape[i]] otherwise.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    // For a closed interval, widen the exclusive upper bound by one.
    if (!halfOpen)
      max += 1;
    // NOTE(review): the lower-bound check uses 0 while the diagnostic prints
    // `min` — confirm this asymmetry is intended.
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
3941
/// Builds an ArrayAttr of 64-bit IntegerAttrs from `values`.
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
3949
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  // Offsets address every destination dimension; strides apply to every
  // source dimension.
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  // Left-pad the source shape with zeros so it aligns with the trailing
  // destination dimensions for the offset + size bound check below.
  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  // Offsets must be in-bounds, strides must all be exactly 1, and
  // offset + source size must fit inside the destination.
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  // Scalable flags must match pairwise, and a scalable source dim must have
  // the same base size as the corresponding destination dim.
  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
4005
4006namespace {
4007/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
4008class FoldInsertStridedSliceSplat final
4009 : public OpRewritePattern<InsertStridedSliceOp> {
4010public:
4011 using Base::Base;
4012
4013 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4014 PatternRewriter &rewriter) const override {
4015
4016 auto dst = insertStridedSliceOp.getDest();
4017 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4018 if (!splat || getScalarSplatSource(dst) != splat)
4019 return failure();
4020
4021 rewriter.replaceOp(insertStridedSliceOp, dst);
4022 return success();
4023 }
4024};
4025
4026/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
4027/// to dst.
4028class FoldInsertStridedSliceOfExtract final
4029 : public OpRewritePattern<InsertStridedSliceOp> {
4030public:
4031 using Base::Base;
4032
4033 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4034 PatternRewriter &rewriter) const override {
4035 auto extractStridedSliceOp =
4036 insertStridedSliceOp.getValueToStore()
4037 .getDefiningOp<vector::ExtractStridedSliceOp>();
4038
4039 if (!extractStridedSliceOp)
4040 return failure();
4041
4042 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4043 return failure();
4044
4045 // Check if have the same strides and offsets.
4046 if (extractStridedSliceOp.getStrides() !=
4047 insertStridedSliceOp.getStrides() ||
4048 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4049 return failure();
4050
4051 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4052 return success();
4053 }
4054};
4055
// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using Base::Base;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support poison.
    if (matchPattern(vectorDestCst, ub::m_Poison()) ||
        matchPattern(sourceCst, ub::m_Poison()))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them. The enumeration
    // order is lexicographic which yields a sequence of monotonically
    // increasing linearized position indices.
    // Because the destination may have higher dimensionality than the slice,
    // we keep track of two overlapping sets of positions and offsets.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    // The slice position is the trailing portion of the destination position.
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
4135
4136} // namespace
4137
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Register the insert_strided_slice canonicalization patterns defined above.
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}
4143
4144OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4145 if (getSourceVectorType() == getDestVectorType())
4146 return getValueToStore();
4147 return {};
4148}
4149
4150//===----------------------------------------------------------------------===//
4151// OuterProductOp
4152//===----------------------------------------------------------------------===//
4153
/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}
4160
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    // Note: the attribute dictionary is only printed when an accumulator is
    // present.
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
4169
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  // Syntax: $lhs `,` $rhs (`,` $acc)? attr-dict `:` type($lhs) `,` type($rhs)
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  // Infer the result type: 2-D when the RHS is a vector (proper outer
  // product), 1-D when the RHS is a scalar. Scalable flags come from the
  // operands' leading dimensions.
  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  // Default the combining kind when it was not spelled in the attr-dict.
  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  // The optional third operand (acc) shares the inferred result type.
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
4214
/// Checks operand/result shape compatibility for both forms of the op:
/// vector RHS (outer product, 2-D result) and scalar RHS (AXPY, 1-D result).
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  // The optional accumulator must match the result type exactly.
  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // The combining kind attribute is mandatory.
  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");
  }

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
4262
4263// MaskableOpInterface methods.
4264
4265/// Returns the mask type expected by this operation. Mostly used for
4266/// verification purposes. It requires the operation to be vectorized."
4267Type OuterProductOp::getExpectedMaskType() {
4268 auto vecType = this->getResultVectorType();
4269 return VectorType::get(vecType.getShape(),
4270 IntegerType::get(vecType.getContext(), /*width=*/1),
4271 vecType.getScalableDims());
4272}
4273
4274//===----------------------------------------------------------------------===//
4275// ExtractStridedSliceOp
4276//===----------------------------------------------------------------------===//
4277
4278// Inference works as follows:
4279// 1. Add 'sizes' from prefix of dims in 'offsets'.
4280// 2. Add sizes from 'vectorType' for remaining dims.
4281// Scalable flags are inherited from 'vectorType'.
4282static Type inferStridedSliceOpResultType(VectorType vectorType,
4283 ArrayAttr offsets, ArrayAttr sizes,
4284 ArrayAttr strides) {
4285 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4287 shape.reserve(vectorType.getRank());
4288 unsigned idx = 0;
4289 for (unsigned e = offsets.size(); idx < e; ++idx)
4290 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4291 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4292 shape.push_back(vectorType.getShape()[idx]);
4293
4294 return VectorType::get(shape, vectorType.getElementType(),
4295 vectorType.getScalableDims());
4296}
4297
4298void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4299 Value source, ArrayRef<int64_t> offsets,
4300 ArrayRef<int64_t> sizes,
4301 ArrayRef<int64_t> strides) {
4302 result.addOperands(source);
4303 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4304 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4305 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4306 result.addTypes(
4307 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4308 offsetsAttr, sizesAttr, stridesAttr));
4309 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4310 offsetsAttr);
4311 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4312 sizesAttr);
4313 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4314 stridesAttr);
4315}
4316
/// Verifies that the slice spec (offsets/sizes/strides) is self-consistent,
/// fits inside the source vector, and yields the declared result type.
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  // The three arrays describe a single slice spec and must agree in rank.
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // Bounds checks: arrays no longer than the rank, offsets/sizes inside the
  // shape, strides confined to 1 (only unit strides are supported), and
  // offset+size must not run past the source extent.
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(
          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  // The declared result type must match the inferred one.
  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  // Scalable dims cannot be partially sliced: the requested size must equal
  // the full base extent of the corresponding source dim.
  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
4369
4370// When the source of ExtractStrided comes from a chain of InsertStrided ops try
4371// to use the source of the InsertStrided ops if we can detect that the
4372// extracted vector is a subset of one of the vector inserted.
4373static LogicalResult
4374foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4375 // Helper to extract integer out of ArrayAttr.
4376 auto getElement = [](ArrayAttr array, int idx) {
4377 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4378 };
4379 ArrayAttr extractOffsets = op.getOffsets();
4380 ArrayAttr extractStrides = op.getStrides();
4381 ArrayAttr extractSizes = op.getSizes();
4382 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4383 while (insertOp) {
4384 if (op.getSourceVectorType().getRank() !=
4385 insertOp.getSourceVectorType().getRank())
4386 return failure();
4387 ArrayAttr insertOffsets = insertOp.getOffsets();
4388 ArrayAttr insertStrides = insertOp.getStrides();
4389 // If the rank of extract is greater than the rank of insert, we are likely
4390 // extracting a partial chunk of the vector inserted.
4391 if (extractOffsets.size() > insertOffsets.size())
4392 return failure();
4393 bool patialoverlap = false;
4394 bool disjoint = false;
4395 SmallVector<int64_t, 4> offsetDiffs;
4396 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4397 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4398 return failure();
4399 int64_t start = getElement(insertOffsets, dim);
4400 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4401 int64_t offset = getElement(extractOffsets, dim);
4402 int64_t size = getElement(extractSizes, dim);
4403 // Check if the start of the extract offset is in the interval inserted.
4404 if (start <= offset && offset < end) {
4405 // If the extract interval overlaps but is not fully included we may
4406 // have a partial overlap that will prevent any folding.
4407 if (offset + size > end)
4408 patialoverlap = true;
4409 offsetDiffs.push_back(offset - start);
4410 continue;
4411 }
4412 disjoint = true;
4413 break;
4414 }
4415 // The extract element chunk is a subset of the insert element.
4416 if (!disjoint && !patialoverlap) {
4417 op.setOperand(insertOp.getValueToStore());
4418 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4419 OpBuilder b(op.getContext());
4420 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4421 return success();
4422 }
4423 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4424 // in the insert chain.
4425 if (disjoint)
4426 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4427 else {
4428 // The extracted vector partially overlap the inserted vector, we cannot
4429 // fold.
4430 return failure();
4431 }
4432 }
4433 return failure();
4434}
4435
4436// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4437static OpFoldResult
4439 Attribute foldInput) {
4440
4441 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4442 if (!dense)
4443 return {};
4444
4445 // TODO: Handle non-unit strides when they become available.
4446 if (op.hasNonUnitStrides())
4447 return {};
4448
4449 VectorType sourceVecTy = op.getSourceVectorType();
4450 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4451 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4452
4453 VectorType sliceVecTy = op.getType();
4454 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4455 int64_t rank = sliceVecTy.getRank();
4456
4457 // Expand offsets and sizes to match the vector rank.
4458 SmallVector<int64_t, 4> offsets(rank, 0);
4459 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4460
4461 SmallVector<int64_t, 4> sizes(sourceShape);
4462 copy(getI64SubArray(op.getSizes()), sizes.begin());
4463
4464 // Calculate the slice elements by enumerating all slice positions and
4465 // linearizing them. The enumeration order is lexicographic which yields a
4466 // sequence of monotonically increasing linearized position indices.
4467 const auto denseValuesBegin = dense.value_begin<Attribute>();
4468 SmallVector<Attribute> sliceValues;
4469 sliceValues.reserve(sliceVecTy.getNumElements());
4470 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4471 do {
4472 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4473 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4474 "Invalid index");
4475 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4476 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4477
4478 assert(static_cast<int64_t>(sliceValues.size()) ==
4479 sliceVecTy.getNumElements() &&
4480 "Invalid number of slice elements");
4481 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4482}
4483
4484OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4485 if (getSourceVectorType() == getResult().getType())
4486 return getSource();
4487 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4488 return getResult();
4489
4490 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4491 if (auto splat =
4492 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4493 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4494
4495 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4496 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4497}
4498
/// Populates `results` with the op's static offsets as plain int64 values.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}
4502
4503namespace {
4504
// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
//
// Example:
//
//   %0 = vector.extract_strided_slice %arg0
//       {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
//     : vector<4x8x16xf32> to vector<3x4x16xf32>
//   %1 = vector.extract_strided_slice %0
//       {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
//     : vector<3x4x16xf32> to vector<2x2x16xf32>
//
// to
//
//   %1 = vector.extract_strided_slice %arg0
//       {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
//     : vector<4x8x16xf32> to vector<2x2x16xf32>
class StridedSliceFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
                                PatternRewriter &rewriter) const override {
    auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
    if (!firstOp)
      return failure();

    // Only unit strides compose by simple offset addition.
    if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
      return failure();

    SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
    SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
    SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
    SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());

    // The two slice specs may cover different numbers of leading dims; pad
    // the shorter one with offset 0 and combine per dimension.
    unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
    SmallVector<int64_t> combinedOffsets(newRank, 0);
    SmallVector<int64_t> combinedSizes(newRank);
    ArrayRef<int64_t> firstSourceShape =
        firstOp.getSourceVectorType().getShape();
    for (unsigned i = 0; i < newRank; ++i) {
      // Offsets are additive: slicing at off1 then at off2 is slicing at
      // off1 + off2 in the original source.
      int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
      int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
      combinedOffsets[i] = off1 + off2;

      // The innermost (second) slice wins for sizes; otherwise fall back to
      // the first slice's size, and finally the full source extent.
      if (i < secondSizes.size()) {
        combinedSizes[i] = secondSizes[i];
      } else if (i < firstSizes.size()) {
        combinedSizes[i] = firstSizes[i];
      } else {
        combinedSizes[i] = firstSourceShape[i];
      }
    }

    SmallVector<int64_t> combinedStrides(newRank, 1);
    rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
        secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
        combinedStrides);
    return success();
  }
};
4566
// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
// CreateMaskOp.
//
// Example:
//
//   %mask = vector.create_mask %ub : vector<16xi1>
//   %slice = vector.extract_strided_slice [%offset] [8] [1]
//
// to
//
//   %new_ub = arith.subi %ub, %offset
//   %mask = vector.create_mask %new_ub : vector<8xi1>
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
  using Base::Base;

public:
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute slice of vector mask region.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
    // only iterate on the leading dim sizes. The tail accounts for the
    // remaining dim sizes.
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      // No need to clamp on min/max values, because create_mask has clamping
      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
      // greater than the vector dim size.
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      // The new mask bound is the old (dynamic) bound minus the slice offset.
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    llvm::append_range(
        sliceMaskDimSizes,
        llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
    // region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
4635
4636// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4637// ConstantMaskOp.
4638class StridedSliceConstantMaskFolder final
4639 : public OpRewritePattern<ExtractStridedSliceOp> {
4640public:
4641 using Base::Base;
4642
4643 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4644 PatternRewriter &rewriter) const override {
4645 // Return if 'extractStridedSliceOp' operand is not defined by a
4646 // ConstantMaskOp.
4647 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4648 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4649 if (!constantMaskOp)
4650 return failure();
4651 // Return if 'extractStridedSliceOp' has non-unit strides.
4652 if (extractStridedSliceOp.hasNonUnitStrides())
4653 return failure();
4654 // Gather constant mask dimension sizes.
4655 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4656 // Gather strided slice offsets and sizes.
4657 SmallVector<int64_t> sliceOffsets;
4658 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4659 sliceOffsets);
4660 SmallVector<int64_t> sliceSizes;
4661 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4662
4663 // Compute slice of vector mask region.
4664 SmallVector<int64_t> sliceMaskDimSizes;
4665 sliceMaskDimSizes.reserve(maskDimSizes.size());
4666 for (auto [maskDimSize, sliceOffset, sliceSize] :
4667 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4668 int64_t sliceMaskDimSize = std::max(
4669 static_cast<int64_t>(0),
4670 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4671 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4672 }
4673 // Add unchanged dimensions.
4674 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4675 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4676 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4677 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4678 // region is a conjunction of mask dim intervals).
4679 if (llvm::is_contained(sliceMaskDimSizes, 0))
4680 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4681
4682 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4683 // region.
4684 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4685 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4686 sliceMaskDimSizes);
4687 return success();
4688 }
4689};
4690
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    // The broadcast source may be a scalar (srcRank == 0) or a vector.
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
    // (n -> m with n > m). If they are originally both broadcasted *and*
    // sliced, this can be simplified to just broadcasting.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      // Slice the broadcast source first; only the trailing (source-rank)
      // dims of the original slice spec apply to it.
      SmallVector<int64_t> offsets =
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // In case this dimension was broadcasted *and* sliced, the offset
          // and size need to be updated now that there is no broadcast before
          // the slice.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = ExtractStridedSliceOp::create(
          rewriter, op->getLoc(), source, offsets, sizes,
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    // Re-broadcast the (possibly sliced) source to the slice's result type.
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
4743
4744/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
4745class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4746public:
4747 using Base::Base;
4748
4749 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4750 PatternRewriter &rewriter) const override {
4751
4752 Value splat = getScalarSplatSource(op.getSource());
4753 if (!splat)
4754 return failure();
4755 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4756 return success();
4757 }
4758};
4759
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
///
/// Example:
///     Before:
///         %1 = vector.extract_strided_slice %arg0 {
///                offsets = [0, 0, 0, 0, 0],
///                sizes = [1, 1, 1, 1, 8],
///                strides = [1, 1, 1, 1, 1]
///              } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
///     After:
///         %0 = vector.extract %arg0[0, 0, 0, 0]
///                : vector<8xi8> from vector<8x1x1x2x8xi8>
///         %1 = vector.shape_cast %0
///                : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    // vector.extract does not support scalable vectors or 0-d sources here.
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern).
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    // Extract the contiguous inner chunk, then cast it to the slice's shape
    // (re-introducing the leading unit dims).
    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};
4835
4836} // namespace
4837
4838void ExtractStridedSliceOp::getCanonicalizationPatterns(
4839 RewritePatternSet &results, MLIRContext *context) {
4840 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
4841 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4842 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4843 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4844 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4845 context);
4846}
4847
4848//===----------------------------------------------------------------------===//
4849// TransferReadOp
4850//===----------------------------------------------------------------------===//
4851
/// 1. Builder that defaults the padding to a poison value and sets an empty
/// mask (variant with attrs). NOTE(review): the previous comment claimed the
/// padding defaults to zero; the code below creates a `ub.poison` constant.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMapAttr permutationMapAttr,
                           /*optional*/ ArrayAttr inBoundsAttr) {

  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  // When no padding value is supplied, use poison of the element type.
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
4865
/// 2. Builder that defaults the padding to a poison value and sets an empty
/// mask (variant without attrs). NOTE(review): the previous comment claimed
/// the padding defaults to zero; the code creates a `ub.poison` constant.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  // Default in_bounds to all-false (out-of-bounds) when absent or empty.
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  // When no padding value is supplied, use poison of the element type.
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}
4883
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'. Padding
/// defaults to a poison value when not provided.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  // Derive the minor identity permutation from the source/vector types.
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  // Default in_bounds to all-false (out-of-bounds) when absent or empty.
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  // When no padding value is supplied, use poison of the element type.
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding,
        /*mask=*/Value(), inBoundsAttr);
}
4903
/// Verifies that `permutationMap` is a projected permutation: each result is
/// either the constant 0 or a dim expression, and no dim is used twice.
/// Diagnostics are reported through the `emitOpError` callable so the helper
/// can serve different ops.
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  // Tracks which input dims have already appeared in a result.
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      // Only the zero constant is admissible among constants.
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    // Each input dim may feed at most one result.
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
4932
/// Verifies the invariants shared by vector.transfer_read and
/// vector.transfer_write: source/vector element-type compatibility,
/// permutation-map well-formedness, mask-type consistency, and in_bounds
/// arity. Emits a diagnostic on `op` and returns failure on the first
/// violated invariant.
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  // Reject the removed legacy "masked" attribute with an actionable message.
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    // Compare the bitwidths of the minor (innermost) 1-D slices of the source
    // element vector and the transferred vector.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    // The permutation map only covers the leading dims of the result vector
    // that are not part of the (trailing) source element vector.
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    // A 0-D vector transfers a single scalar element.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  // The mask operand, when present, must have exactly the type inferred from
  // the vector type and the permutation map.
  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  // One in_bounds entry per permutation-map result.
  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
5012
5013static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
5014 SmallVector<StringRef, 3> elidedAttrs;
5015 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5016 if (op.getPermutationMap().isMinorIdentity())
5017 elidedAttrs.push_back(op.getPermutationMapAttrName());
5018 // Elide in_bounds attribute if all dims are out-of-bounds.
5019 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
5020 elidedAttrs.push_back(op.getInBoundsAttrName());
5021 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5022}
5023
5024void TransferReadOp::print(OpAsmPrinter &p) {
5025 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
5026 if (getMask())
5027 p << ", " << getMask();
5028 printTransferAttrs(p, *this);
5029 p << " : " << getShapedType() << ", " << getVectorType();
5030}
5031
5032VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
5033 AffineMap permMap) {
5034 auto i1Type = IntegerType::get(permMap.getContext(), 1);
5035 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
5036 assert(invPermMap && "Inversed permutation map couldn't be computed");
5037 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
5038
5039 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
5040 // 0-D mask into a single-element 1-D mask.
5041 if (maskShape.empty())
5042 maskShape.push_back(1);
5043
5044 SmallVector<bool> scalableDims =
5045 applyPermutationMap(invPermMap, vecType.getScalableDims());
5046
5047 return VectorType::get(maskShape, i1Type, scalableDims);
5048}
5049
/// Parses the custom assembly of vector.transfer_read:
///   source `[` indices `]` `,` padding (`,` mask)? attr-dict
///       `:` shaped-type `,` vector-type
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  // An optional trailing comma introduces the mask operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the shaped source and the result vector.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  // When the permutation map is elided, synthesize the minor identity map.
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  // When in_bounds is elided, default every dimension to out-of-bounds.
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Record operand segment sizes: source, indices, padding, optional mask.
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
5127
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  // Only compute the inferred mask type when a mask is actually present.
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  // One index operand per dimension of the source.
  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // Checks shared with transfer_write.
  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");

  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
5170
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
5178
//===----------------------------------------------------------------------===//
// TransferReadOp: VectorTransferOpInterface methods.
//===----------------------------------------------------------------------===//
/// Returns the type of the vector this read produces.
VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
5185
/// Returns true when transfer dimension `resultIdx`, indexed through source
/// dimension `indicesIdx`, can be statically proven in-bounds: the source dim
/// is static, the index is a known constant, and index + vector size fits.
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  //
  // A dynamic source dimension can never be proven in-bounds statically.
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  // Likewise, a non-constant index defeats the static analysis.
  std::optional<int64_t> constIndex =
      getConstantIntValue(op.getIndices()[indicesIdx]);
  if (!constIndex)
    return false;

  int64_t dimSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vecDimSize = op.getVectorType().getDimSize(resultIdx);
  return *constIndex + vecDimSize <= dimSize;
}
5202
/// Tightens the in_bounds attribute of `op`: every transfer dimension that can
/// be statically proven in-bounds is marked as such. Returns success only when
/// the attribute actually changed (i.e. became "more in-bounds").
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  // TODO: Be less conservative.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Idxs of non-bcast dims - used when analysing bcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is inBounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      // Only plain dim results map to a source dimension we can analyse;
      // broadcast (constant-zero) results are handled in step 2.
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims
  // If all non-broadcast dims are "in bounds", then all bcast dims should be
  // "in bounds" as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}
5257
/// Drops the mask operand of `op` when it is provably redundant, so the
/// unmasked form of the transfer can be used instead.
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  // Nothing to fold for an unmasked transfer.
  auto mask = op.getMask();
  if (!mask)
    return failure();

    return failure();

  // The mask is redundant: remove the operand entirely.
  op.getMaskMutable().clear();
  return success();
}
5270
/// Read-after-write forwarding on tensors:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///   : tensor<4x4xf32>, vector<1x4xf32>
/// ```
/// -> Folds into
/// ```
/// %v0
/// ```
static Value foldRAW(TransferReadOp readOp) {
  // Only valid on tensors, where SSA use-def chains give exact aliasing.
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  // Walk up the chain of writes feeding this read, stopping at the first
  // write that might overlap the read region.
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
        cast<VectorTransferOpInterface>(defWrite.getOperation()),
        cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}
5296
OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  // Try read-after-write forwarding on tensors first.
  if (Value vec = foldRAW(*this))
    return vec;
  // In-place updates: tighten in_bounds and drop redundant masks; the op
  // itself (its result) is returned as the fold result.
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
5311
5312std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5313 return llvm::to_vector<4>(getVectorType().getShape());
5314}
5315
5316void TransferReadOp::getEffects(
5317 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5318 &effects) {
5319 if (llvm::isa<MemRefType>(getShapedType()))
5320 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5321 SideEffects::DefaultResource::get());
5322}
5323
/// Reads with pure tensor semantics have no side effects; memref reads may
/// touch memory and are presumably treated more conservatively.
Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
}
5329
/// Given a projected permutation, inverse an affine map, making the unused dims
/// 0 in the result.
static AffineMap inverseWithUnusedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected a projected permutation map");
  // One result slot per input dim of `map`; slots not assigned below keep
  // their initializer.
  SmallVector<AffineExpr> results(map.getNumInputs(),
  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
    // We should only have dim exprs because this is a projected permutation.
    int64_t pos = cast<AffineDimExpr>(result).getPosition();
    results[pos] = getAffineDimExpr(idx, map.getContext());
  }
  // The inverse maps the original results (now inputs) back to the inputs.
  return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
                        results, map.getContext());
}
5345
namespace {
/// Store to load forwarding for transfer operations with permuation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
/// can replace the transfer_read + transfer_write by vector.broadcast and
/// vector.transpose.
/// Example:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
///  {in_bounds = [true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
///   vector<4x1xf32>, tensor<4x4x4xf32>
///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
///   {in_bounds = [true, true, true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
/// ```
/// To:
/// ```
/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
/// %r = vector.transpose %0, [3, 0, 2, 1] :
///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
/// ```
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // The read must consume the result of a transfer_write.
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail in the masked case (too complex atm and needed to properly account
    // for padding).
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    // If indices are not the same a shift may be required, bail.
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();
    // Bail if we need a bounds analysis.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // WriteMap: tensor -> w_vec
    // ReadMap: tensor -> r_vec
    //
    // inv(WriteMap): w_vec -> tensor
    // inv(WriteMap) o ReadMap: w_vec -> r_vec
    AffineMap readMap = readOp.getPermutationMap();
    AffineMap writeMap = defWrite.getPermutationMap();
    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
    AffineMap composedMap = readMap.compose(invWriteMap);
    // If there are any unused dims in the composedMap, we have to drop some
    // unit dims from the written vector before we can do transpose(broadcast).
    // TODO: Support this case.
    if (getUnusedDimsBitVector(composedMap).any())
      return failure();
    // readVec = transpose(broadcast(writeVec))
    //
    // Build a transpose permutation for the above transpose operation.
    //
    // Treat the composed map as having extra leading dimensions which are
    // the broadcasted dimensions, and treat the zeros as these new broadcasted
    // dimensions.
    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
    int64_t numBroadcastedDims = broadcastedDims.size();
    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
    invPerm.resize(composedMap.getNumResults());
    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
        // Non-broadcast results are shifted past the leading broadcast dims.
        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
        invPerm[effectiveDim] = idx;
      }
    }
    // Applying the inverse permutation on the readVecTy will give us the
    // broadcast result type.
    VectorType readVecTy = readOp.getVectorType();
    SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
    auto broadcastedVecTy =
        VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
                        readVecTy.getElementType(),
                        applyPermutation(readVecTy.getScalableDims(), invPerm));
    // Build the transpose(broadcast) transformation.
    Value vec = defWrite.getVector();
    Location loc = readOp.getLoc();
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
    return success();
  }
};
} // namespace
5444
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Currently the only transfer_read-specific canonicalization pattern.
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
5449
/// Only reads with memref (buffer) semantics participate in cast
/// bubbling-down.
FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
      getResult());
}
5457
5458//===----------------------------------------------------------------------===//
5459// TransferWriteOp
5460//===----------------------------------------------------------------------===//
5461
5462/// 1. Builder with type inference.
5463void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5464 Value vector, Value dest, ValueRange indices,
5465 AffineMapAttr permutationMapAttr,
5466 /*optional*/ Value mask,
5467 /*optional*/ ArrayAttr inBoundsAttr) {
5468 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5469 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5470 mask, inBoundsAttr);
5471}
5472
/// 2. Builder with type inference that sets an empty mask (variant with attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  // Delegate to builder 1 with an absent mask.
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}
5481
5482/// 3. Builder with type inference that sets an empty mask (variant without
5483/// attrs)
5484void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5485 Value vector, Value dest, ValueRange indices,
5486 AffineMap permutationMap,
5487 std::optional<ArrayRef<bool>> inBounds) {
5488 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5489 auto inBoundsAttr =
5490 (inBounds && !inBounds.value().empty())
5491 ? builder.getBoolArrayAttr(inBounds.value())
5492 : builder.getBoolArrayAttr(SmallVector<bool>(
5493 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5494 build(builder, result, vector, dest, indices, permutationMapAttr,
5495 /*mask=*/Value(), inBoundsAttr);
5496}
5497
5498/// 4. Builder with type inference that sets an empty mask and sets permutation
5499/// map to 'getMinorIdentityMap'.
5500void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5501 Value vector, Value dest, ValueRange indices,
5502 std::optional<ArrayRef<bool>> inBounds) {
5503 auto vectorType = llvm::cast<VectorType>(vector.getType());
5504 AffineMap permutationMap = getTransferMinorIdentityMap(
5505 llvm::cast<ShapedType>(dest.getType()), vectorType);
5506 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5507}
5508
/// Parses the custom assembly of vector.transfer_write:
///   vector `,` source `[` indices `]` (`,` mask)? attr-dict
///       `:` vector-type `,` shaped-type
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  // An optional trailing comma introduces the mask operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the stored vector and the shaped dest.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  // When the permutation map is elided, synthesize the minor identity map.
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  // When in_bounds is elided, default every dimension to out-of-bounds.
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // The mask type is inferred rather than spelled out in the assembly.
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Record operand segment sizes: vector, source, indices, optional mask.
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  // Only tensor destinations produce a result.
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
5582
5583void TransferWriteOp::print(OpAsmPrinter &p) {
5584 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5585 if (getMask())
5586 p << ", " << getMask();
5587 printTransferAttrs(p, *this);
5588 p << " : " << getVectorType() << ", " << getShapedType();
5589}
5590
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  // Only compute the inferred mask type when a mask is actually present.
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  // One index operand per dimension of the destination.
  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  // Checks shared with transfer_read.
  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
5617
//===----------------------------------------------------------------------===//
// TransferWriteOp: MaskableOpInterface methods.
//===----------------------------------------------------------------------===//

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
5627
//===----------------------------------------------------------------------===//
// TransferWriteOp: VectorTransferOpInterface methods.
//===----------------------------------------------------------------------===//
/// The stored vector is always operand 0.
Value TransferWriteOp::getVector() { return getOperand(0); }
/// Returns the type of the vector being written.
VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
5635
//===----------------------------------------------------------------------===//
// TransferWriteOp: fold methods.
//===----------------------------------------------------------------------===//
/// Fold:
/// ```
///    %t1 = ...
///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity. Future: composition is minor identity.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success: the write stores back the full, unmodified contents of its
  // source tensor, so the written tensor is just that source.
  results.push_back(read.getBase());
  return success();
}
5705
5706static bool checkSameValueWAR(vector::TransferReadOp read,
5707 vector::TransferWriteOp write) {
5708 return read.getBase() == write.getBase() &&
5709 read.getIndices() == write.getIndices() &&
5710 read.getPermutationMap() == write.getPermutationMap() &&
5711 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5712 !write.getMask();
5713}
5714/// Fold transfer_write write after read:
5715/// ```
5716/// %t0 = ...
5717/// %v = vector.transfer_read %t0[%c0...] :
5718/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5719/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5720/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5721/// ```
5722///
5723/// into:
5724///
5725/// ```
5726/// %t0
5727/// ```
5728static LogicalResult foldWAR(TransferWriteOp write,
5729 SmallVectorImpl<OpFoldResult> &results) {
5730 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5731 return failure();
5732 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5733 if (!read)
5734 return failure();
5735
5736 if (!checkSameValueWAR(read, write))
5737 return failure();
5738 results.push_back(read.getBase());
5739 return success();
5740}
5741
5742LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5743 SmallVectorImpl<OpFoldResult> &results) {
5744 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5745 return success();
5746 if (succeeded(foldWAR(*this, results)))
5747 return success();
5748 if (succeeded(foldTransferInBoundsAttribute(*this)))
5749 return success();
5750 if (succeeded(foldTransferFullMask(*this)))
5751 return success();
5752 return memref::foldMemRefCast(*this);
5753}
5754
5755//===----------------------------------------------------------------------===//
5756// TransferWriteOp: other methods.
5757//===----------------------------------------------------------------------===//
5758std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5759 return llvm::to_vector<4>(getVectorType().getShape());
5760}
5761
5762void TransferWriteOp::getEffects(
5763 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5764 &effects) {
5765 if (llvm::isa<MemRefType>(getShapedType()))
5766 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5767 SideEffects::DefaultResource::get());
5768}
5769
5770Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5771 if (hasPureTensorSemantics())
5774}
5775
5776namespace {
5777/// Remove dead transfer write from the SSA chain so that it an be eliminated by
5778/// DCE
5779/// ```
5780/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5781/// : vector<1x4xf32>, tensor<4x4xf32>
5782/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5783/// : vector<1x4xf32>, tensor<4x4xf32>
5784/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5785/// : vector<1x4xf32>, tensor<4x4xf32>
5786/// ```
5787///
5788/// into:
5789///
5790/// ```
5791/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5792/// : vector<1x4xf32>, tensor<4x4xf32>
5793/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5794/// : vector<1x4xf32>, tensor<4x4xf32>
5795/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5796/// : vector<1x4xf32>, tensor<4x4xf32>
5797/// ```
5798///
5799/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5800/// any other uses.
5801class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5802public:
5803 using Base::Base;
5804 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5805 PatternRewriter &rewriter) const override {
5806 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5807 return failure();
5808 vector::TransferWriteOp writeToModify = writeOp;
5809
5810 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5811 while (defWrite) {
5812 if (checkSameValueWAW(writeOp, defWrite)) {
5813 rewriter.modifyOpInPlace(writeToModify, [&]() {
5814 writeToModify.getBaseMutable().assign(defWrite.getBase());
5815 });
5816 return success();
5817 }
5819 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5820 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5821 break;
5822 // If the previous write op doesn't have any other use we an safely look
5823 // at the previous store to see if it can be removed.
5824 if (!defWrite->hasOneUse())
5825 break;
5826 writeToModify = defWrite;
5827 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5828 }
5829 return failure();
5830 }
5831};
5832
5833/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5834/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5835/// overwritten and inserted into another tensor. After this rewrite, the
5836/// operations bufferize in-place since all of them work on the same slice.
5837///
5838/// For example:
5839/// ```mlir
5840/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5841/// : vector<8x16xf32>, tensor<8x16xf32>
5842/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5843/// : tensor<8x16xf32> to tensor<?x?xf32>
5844/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5845/// : tensor<?x?xf32> into tensor<27x37xf32>
5846/// ```
5847/// folds to
5848/// ```mlir
5849/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5850/// : tensor<27x37xf32> to tensor<?x?xf32>
5851/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5852/// : vector<8x16xf32>, tensor<?x?xf32>
5853/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5854/// : tensor<?x?xf32> into tensor<27x37xf32>
5855/// ```
5856struct SwapExtractSliceOfTransferWrite
5857 : public OpRewritePattern<tensor::InsertSliceOp> {
5858public:
5859 using Base::Base;
5860
5861 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5862 PatternRewriter &rewriter) const override {
5863 if (!insertOp.hasUnitStride())
5864 return failure();
5865 auto extractOp =
5866 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5867 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5868 return failure();
5869 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5870 if (!transferOp || !transferOp->hasOneUse())
5871 return failure();
5872
5873 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5874 // rank-reducing.
5875 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5876 return rewriter.notifyMatchFailure(insertOp,
5877 "use-def chain is rank-reducing");
5878 }
5879
5880 // Fail if tensor::ExtractSliceOp has non-zero offset.
5881 if (!extractOp.hasZeroOffset()) {
5882 return rewriter.notifyMatchFailure(insertOp,
5883 "ExtractSliceOp has non-zero offset");
5884 }
5885
5886 // Fail if tensor::TransferWriteOp has non-zero offset.
5887 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5888 return getConstantIntValue(value) == static_cast<int64_t>(0);
5889 })) {
5890 return rewriter.notifyMatchFailure(insertOp,
5891 "TranferWriteOp has non-zero offset");
5892 }
5893
5894 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5895 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5896 return rewriter.notifyMatchFailure(
5897 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5898 }
5899
5900 for (auto [insertSize, extractSize] :
5901 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5902 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5903 return rewriter.notifyMatchFailure(
5904 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5905 }
5906 }
5907
5908 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5909 assert(transferOp.getVectorType().hasStaticShape() &&
5910 "expected vector to have a static shape");
5911 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5912 SmallVector<int64_t> resultShape = applyPermutationMap(
5913 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5914 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5915 return rewriter.notifyMatchFailure(
5916 insertOp, "TransferWriteOp may not write the full tensor.");
5917 }
5918
5919 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5920 // Set all in_bounds to false and let the folder infer them.
5921 SmallVector<bool> newInBounds(vectorShape.size(), false);
5922 auto newExtractOp = tensor::ExtractSliceOp::create(
5923 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5924 insertOp.getDest(), insertOp.getMixedOffsets(),
5925 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5926 auto newTransferWriteOp = TransferWriteOp::create(
5927 rewriter, transferOp.getLoc(), transferOp.getVector(),
5928 newExtractOp.getResult(), transferOp.getIndices(),
5929 transferOp.getPermutationMapAttr(),
5930 rewriter.getBoolArrayAttr(newInBounds));
5931 rewriter.modifyOpInPlace(insertOp, [&]() {
5932 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5933 });
5934 return success();
5935 }
5936};
5937
5938} // namespace
5939
5940void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5941 MLIRContext *context) {
5942 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5943}
5944
// Bubbles memory-space casts on the base memref below this write.
FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  // Only applies to buffer (memref) semantics.
  if (!hasPureBufferSemantics())
    return failure();
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument
  // (the empty result range — writes produce no results) survives below.
  // Restore the full call from version control.
      ValueRange());
}
5952
5953//===----------------------------------------------------------------------===//
5954// LoadOp
5955//===----------------------------------------------------------------------===//
5956
5957static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5958 VectorType vecTy,
5959 MemRefType memRefTy) {
5960 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5961 // need any strides limitations.
5962 if (!vecTy.isScalable() &&
5963 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5964 return success();
5965
5966 if (!memRefTy.isLastDimUnitStride())
5967 return op->emitOpError("most minor memref dim must have unit stride");
5968 return success();
5969}
5970
5971LogicalResult vector::LoadOp::verify() {
5972 VectorType resVecTy = getVectorType();
5973 MemRefType memRefTy = getMemRefType();
5974
5975 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5976 return failure();
5977
5978 if (memRefTy.getRank() < resVecTy.getRank())
5979 return emitOpError(
5980 "destination memref has lower rank than the result vector");
5981
5982 // Checks for vector memrefs.
5983 Type memElemTy = memRefTy.getElementType();
5984 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5985 if (memVecTy != resVecTy)
5986 return emitOpError("base memref and result vector types should match");
5987 memElemTy = memVecTy.getElementType();
5988 }
5989
5990 if (resVecTy.getElementType() != memElemTy)
5991 return emitOpError("base and result element types should match");
5992 if (llvm::size(getIndices()) != memRefTy.getRank())
5993 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5994 return success();
5995}
5996
5997OpFoldResult LoadOp::fold(FoldAdaptor) {
5998 if (succeeded(memref::foldMemRefCast(*this)))
5999 return getResult();
6000 return OpFoldResult();
6001}
6002
6003std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6004 return llvm::to_vector<4>(getVectorType().getShape());
6005}
6006
// Bubbles memory-space casts on the base memref below this load.
FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // loaded result to forward) survives below. Restore from version control.
      getResult());
}
6012
6013//===----------------------------------------------------------------------===//
6014// StoreOp
6015//===----------------------------------------------------------------------===//
6016
6017LogicalResult vector::StoreOp::verify() {
6018 VectorType valueVecTy = getVectorType();
6019 MemRefType memRefTy = getMemRefType();
6020
6021 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
6022 return failure();
6023
6024 if (memRefTy.getRank() < valueVecTy.getRank())
6025 return emitOpError("source memref has lower rank than the vector to store");
6026
6027 // Checks for vector memrefs.
6028 Type memElemTy = memRefTy.getElementType();
6029 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6030 if (memVecTy != valueVecTy)
6031 return emitOpError(
6032 "base memref and valueToStore vector types should match");
6033 memElemTy = memVecTy.getElementType();
6034 }
6035
6036 if (valueVecTy.getElementType() != memElemTy)
6037 return emitOpError("base and valueToStore element type should match");
6038 if (llvm::size(getIndices()) != memRefTy.getRank())
6039 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6040 return success();
6041}
6042
6043LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6044 SmallVectorImpl<OpFoldResult> &results) {
6045 return memref::foldMemRefCast(*this);
6046}
6047
6048std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6049 return llvm::to_vector<4>(getVectorType().getShape());
6050}
6051
// Bubbles memory-space casts on the base memref below this store.
FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // empty result range — stores produce no results) survives below.
      ValueRange());
}
6057
6058//===----------------------------------------------------------------------===//
6059// MaskedLoadOp
6060//===----------------------------------------------------------------------===//
6061
6062LogicalResult MaskedLoadOp::verify() {
6063 VectorType maskVType = getMaskVectorType();
6064 VectorType passVType = getPassThruVectorType();
6065 VectorType resVType = getVectorType();
6066 MemRefType memType = getMemRefType();
6067
6068 if (failed(
6069 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6070 return failure();
6071 if (llvm::size(getIndices()) != memType.getRank())
6072 return emitOpError("requires ") << memType.getRank() << " indices";
6073 if (resVType.getShape() != maskVType.getShape())
6074 return emitOpError("expected result shape to match mask shape");
6075 if (resVType != passVType)
6076 return emitOpError("expected pass_thru of same type as result type");
6077 return success();
6078}
6079
6080namespace {
6081class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
6082public:
6083 using Base::Base;
6084 LogicalResult matchAndRewrite(MaskedLoadOp load,
6085 PatternRewriter &rewriter) const override {
6086 switch (getMaskFormat(load.getMask())) {
6088 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6089 load, load.getType(), load.getBase(), load.getIndices());
6090 return success();
6092 rewriter.replaceOp(load, load.getPassThru());
6093 return success();
6095 return failure();
6096 }
6097 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
6098 }
6099};
6100} // namespace
6101
6102void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6103 MLIRContext *context) {
6104 results.add<MaskedLoadFolder>(context);
6105}
6106
6107OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6108 if (succeeded(memref::foldMemRefCast(*this)))
6109 return getResult();
6110 return OpFoldResult();
6111}
6112
// Bubbles memory-space casts on the base memref below this masked load.
FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // loaded result to forward) survives below. Restore from version control.
      getResult());
}
6118
6119//===----------------------------------------------------------------------===//
6120// MaskedStoreOp
6121//===----------------------------------------------------------------------===//
6122
6123LogicalResult MaskedStoreOp::verify() {
6124 VectorType maskVType = getMaskVectorType();
6125 VectorType valueVType = getVectorType();
6126 MemRefType memType = getMemRefType();
6127
6128 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6129 "valueToStore")))
6130 return failure();
6131 if (llvm::size(getIndices()) != memType.getRank())
6132 return emitOpError("requires ") << memType.getRank() << " indices";
6133 if (valueVType.getShape() != maskVType.getShape())
6134 return emitOpError("expected valueToStore shape to match mask shape");
6135 return success();
6136}
6137
6138namespace {
6139class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
6140public:
6141 using Base::Base;
6142 LogicalResult matchAndRewrite(MaskedStoreOp store,
6143 PatternRewriter &rewriter) const override {
6144 switch (getMaskFormat(store.getMask())) {
6146 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6147 store, store.getValueToStore(), store.getBase(), store.getIndices());
6148 return success();
6150 rewriter.eraseOp(store);
6151 return success();
6153 return failure();
6154 }
6155 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
6156 }
6157};
6158} // namespace
6159
6160void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6161 MLIRContext *context) {
6162 results.add<MaskedStoreFolder>(context);
6163}
6164
6165LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6166 SmallVectorImpl<OpFoldResult> &results) {
6167 return memref::foldMemRefCast(*this);
6168}
6169
// Bubbles memory-space casts on the base memref below this masked store.
FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // empty result range — stores produce no results) survives below.
      ValueRange());
}
6175
6176//===----------------------------------------------------------------------===//
6177// GatherOp
6178//===----------------------------------------------------------------------===//
6179
6180LogicalResult GatherOp::verify() {
6181 VectorType indVType = getIndexVectorType();
6182 VectorType maskVType = getMaskVectorType();
6183 VectorType resVType = getVectorType();
6184 ShapedType baseType = getBaseType();
6185
6186 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6187 return emitOpError("requires base to be a memref or ranked tensor type");
6188
6189 if (failed(
6190 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6191 return failure();
6192 if (llvm::size(getOffsets()) != baseType.getRank())
6193 return emitOpError("requires ") << baseType.getRank() << " indices";
6194 if (resVType.getShape() != indVType.getShape())
6195 return emitOpError("expected result dim to match indices dim");
6196 if (resVType.getShape() != maskVType.getShape())
6197 return emitOpError("expected result dim to match mask dim");
6198 if (resVType != getPassThruVectorType())
6199 return emitOpError("expected pass_thru of same type as result type");
6200 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6201 return emitOpError(
6202 "alignment is only supported for memref bases, not tensor bases");
6203 }
6204 return success();
6205}
6206
6207// MaskableOpInterface methods.
6208
6209/// Returns the mask type expected by this operation. Mostly used for
6210/// verification purposes. It requires the operation to be vectorized."
6211Type GatherOp::getExpectedMaskType() {
6212 auto vecType = this->getIndexVectorType();
6213 return VectorType::get(vecType.getShape(),
6214 IntegerType::get(vecType.getContext(), /*width=*/1),
6215 vecType.getScalableDims());
6216}
6217
6218std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6219 return llvm::to_vector<4>(getVectorType().getShape());
6220}
6221
6222/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
6223static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6224 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6225 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6226 return failure();
6227
6228 if (indexVec.getDefiningOp<StepOp>())
6229 return success();
6230
6231 DenseIntElementsAttr elements;
6232 if (!matchPattern(indexVec, m_Constant(&elements)))
6233 return failure();
6234
6235 return success(
6236 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6237}
6238
6239namespace {
6240class GatherFolder final : public OpRewritePattern<GatherOp> {
6241public:
6242 using Base::Base;
6243 LogicalResult matchAndRewrite(GatherOp gather,
6244 PatternRewriter &rewriter) const override {
6245 switch (getMaskFormat(gather.getMask())) {
6247 return failure(); // no unmasked equivalent
6249 rewriter.replaceOp(gather, gather.getPassThru());
6250 return success();
6252 return failure();
6253 }
6254 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6255 }
6256};
6257
6258/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6259/// maskedload. Only 1D fixed vectors are supported for now.
6260class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6261public:
6262 using Base::Base;
6263 LogicalResult matchAndRewrite(GatherOp op,
6264 PatternRewriter &rewriter) const override {
6265 if (!isa<MemRefType>(op.getBase().getType()))
6266 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6267
6268 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6269 return failure();
6270
6271 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6272 op.getOffsets(), op.getMask(),
6273 op.getPassThru());
6274 return success();
6275 }
6276};
6277} // namespace
6278
6279void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6280 MLIRContext *context) {
6281 results.add<GatherFolder, FoldContiguousGather>(context);
6282}
6283
// Bubbles memory-space casts on the base below this gather.
FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // gathered result to forward) survives below. Restore from version control.
      getResult());
}
6289
6290//===----------------------------------------------------------------------===//
6291// ScatterOp
6292//===----------------------------------------------------------------------===//
6293
6294LogicalResult ScatterOp::verify() {
6295 VectorType indVType = getIndexVectorType();
6296 VectorType maskVType = getMaskVectorType();
6297 VectorType valueVType = getVectorType();
6298 ShapedType baseType = getBaseType();
6299
6300 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6301 return emitOpError("requires base to be a memref or ranked tensor type");
6302
6303 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6304 "valueToStore")))
6305 return failure();
6306 if (llvm::size(getOffsets()) != baseType.getRank())
6307 return emitOpError("requires ") << baseType.getRank() << " indices";
6308 if (valueVType.getShape() != indVType.getShape())
6309 return emitOpError("expected valueToStore dim to match indices dim");
6310 if (valueVType.getShape() != maskVType.getShape())
6311 return emitOpError("expected valueToStore dim to match mask dim");
6312 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6313 return emitOpError(
6314 "alignment is only supported for memref bases, not tensor bases");
6315 }
6316 return success();
6317}
6318namespace {
6319class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6320public:
6321 using Base::Base;
6322 LogicalResult matchAndRewrite(ScatterOp scatter,
6323 PatternRewriter &rewriter) const override {
6324 ShapedType baseType = scatter.getBaseType();
6325 bool isMemRef = isa<MemRefType>(baseType);
6326 if (!isMemRef && !isa<RankedTensorType>(baseType))
6327 return failure();
6328
6329 // Memrefs have no result, so an all-false mask can simply erase the op.
6330 // Tensors carry the updated value, so we must replace uses with the
6331 // original base tensor instead of erasing.
6332 switch (getMaskFormat(scatter.getMask())) {
6334 return failure(); // no unmasked equivalent
6336 if (isMemRef)
6337 rewriter.eraseOp(scatter);
6338 else
6339 rewriter.replaceOp(scatter, scatter.getBase());
6340 return success();
6342 return failure();
6343 }
6344 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6345 }
6346};
6347
6348/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6349/// maskedstore. Only 1D fixed vectors are supported for now.
6350class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6351public:
6352 using Base::Base;
6353 LogicalResult matchAndRewrite(ScatterOp op,
6354 PatternRewriter &rewriter) const override {
6355 // Fold only for memrefs: the replacement uses maskedstore, which does not
6356 // support tensor bases. Tensor cases intentionally bail out.
6357 if (!isa<MemRefType>(op.getBase().getType()))
6358 return failure();
6359
6360 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6361 return failure();
6362
6363 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6364 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6365 return success();
6366 }
6367};
6368} // namespace
6369
6370void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6371 MLIRContext *context) {
6372 results.add<ScatterFolder, FoldContiguousScatter>(context);
6373}
6374
// Bubbles memory-space casts on the base below this scatter.
FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // empty result range — scatters into memrefs produce no results) survives.
      ValueRange());
}
6380
6381//===----------------------------------------------------------------------===//
6382// ExpandLoadOp
6383//===----------------------------------------------------------------------===//
6384
6385LogicalResult ExpandLoadOp::verify() {
6386 VectorType maskVType = getMaskVectorType();
6387 VectorType passVType = getPassThruVectorType();
6388 VectorType resVType = getVectorType();
6389 MemRefType memType = getMemRefType();
6390
6391 if (failed(
6392 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6393 return failure();
6394 if (llvm::size(getIndices()) != memType.getRank())
6395 return emitOpError("requires ") << memType.getRank() << " indices";
6396 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6397 return emitOpError("expected result dim to match mask dim");
6398 if (resVType != passVType)
6399 return emitOpError("expected pass_thru of same type as result type");
6400 return success();
6401}
6402
6403namespace {
6404class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6405public:
6406 using Base::Base;
6407 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6408 PatternRewriter &rewriter) const override {
6409 switch (getMaskFormat(expand.getMask())) {
6411 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6412 expand, expand.getType(), expand.getBase(), expand.getIndices());
6413 return success();
6415 rewriter.replaceOp(expand, expand.getPassThru());
6416 return success();
6418 return failure();
6419 }
6420 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6421 }
6422};
6423} // namespace
6424
6425void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6426 MLIRContext *context) {
6427 results.add<ExpandLoadFolder>(context);
6428}
6429
// Bubbles memory-space casts on the base memref below this expandload.
FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // loaded result to forward) survives below. Restore from version control.
      getResult());
}
6435
6436//===----------------------------------------------------------------------===//
6437// CompressStoreOp
6438//===----------------------------------------------------------------------===//
6439
6440LogicalResult CompressStoreOp::verify() {
6441 VectorType maskVType = getMaskVectorType();
6442 VectorType valueVType = getVectorType();
6443 MemRefType memType = getMemRefType();
6444
6445 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6446 "valueToStore")))
6447 return failure();
6448 if (llvm::size(getIndices()) != memType.getRank())
6449 return emitOpError("requires ") << memType.getRank() << " indices";
6450 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6451 return emitOpError("expected valueToStore dim to match mask dim");
6452 return success();
6453}
6454
6455namespace {
6456class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6457public:
6458 using Base::Base;
6459 LogicalResult matchAndRewrite(CompressStoreOp compress,
6460 PatternRewriter &rewriter) const override {
6461 switch (getMaskFormat(compress.getMask())) {
6463 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6464 compress, compress.getValueToStore(), compress.getBase(),
6465 compress.getIndices());
6466 return success();
6468 rewriter.eraseOp(compress);
6469 return success();
6471 return failure();
6472 }
6473 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6474 }
6475};
6476} // namespace
6477
6478void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6479 MLIRContext *context) {
6480 results.add<CompressStoreFolder>(context);
6481}
6482
// Bubbles memory-space casts on the base memref below this compressstore.
FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the statement invoking the memory-space-cast bubbling
  // helper was truncated during extraction; only its trailing argument (the
  // empty result range — stores produce no results) survives below.
      ValueRange());
}
6488
6489//===----------------------------------------------------------------------===//
6490// ShapeCastOp
6491//===----------------------------------------------------------------------===//
6492
6493void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6494 SetIntRangeFn setResultRanges) {
6495 setResultRanges(getResult(), argRanges.front());
6496}
6497
6498std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6499 return llvm::to_vector<4>(getResultVectorType().getShape());
6500}
6501
6502LogicalResult ShapeCastOp::verify() {
6503
6504 VectorType sourceType = getSourceVectorType();
6505 VectorType resultType = getResultVectorType();
6506
6507 // Check that element type is preserved
6508 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6509 "result")))
6510 return failure();
6511
6512 // Check that number of elements is preserved
6513 int64_t sourceNElms = sourceType.getNumElements();
6514 int64_t resultNElms = resultType.getNumElements();
6515 if (sourceNElms != resultNElms) {
6516 return emitOpError() << "has different number of elements at source ("
6517 << sourceNElms << ") and result (" << resultNElms
6518 << ")";
6519 }
6520
6521 // Check that (non-)scalability is preserved
6522 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6523 int64_t resultNScalableDims = resultType.getNumScalableDims();
6524 if (sourceNScalableDims != resultNScalableDims)
6525 return emitOpError() << "has different number of scalable dims at source ("
6526 << sourceNScalableDims << ") and result ("
6527 << resultNScalableDims << ")";
6528
6529 return success();
6530}
6531
6532/// Return true if `transpose` does not permute a pair of non-unit dims.
6533/// By `order preserving` we mean that the flattened versions of the input and
6534/// output vectors are (numerically) identical. In other words `transpose` is
6535/// effectively a shape cast.
6536static bool isOrderPreserving(TransposeOp transpose) {
6537 ArrayRef<int64_t> permutation = transpose.getPermutation();
6538 VectorType sourceType = transpose.getSourceVectorType();
6539 ArrayRef<int64_t> inShape = sourceType.getShape();
6540 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6541 auto isNonScalableUnitDim = [&](int64_t dim) {
6542 return inShape[dim] == 1 && !inDimIsScalable[dim];
6543 };
6544 int64_t current = 0;
6545 for (auto p : permutation) {
6546 if (!isNonScalableUnitDim(p)) {
6547 if (p < current) {
6548 return false;
6549 }
6550 current = p;
6551 }
6552 }
6553 return true;
6554}
6555
// Folds shape_cast: identity casts, cast-of-cast, cast-of-transpose,
// cast-of-broadcast, and constant/poison sources. Note that some cases fold
// by mutating this op's operand in place (setOperand) rather than returning
// a replacement value, so the order of the checks below is significant.
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    // In-place fold: skip the intermediate cast and keep this op.
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x)
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    // Only valid when the transpose merely moves unit dims around, i.e. the
    // flattened element order is unchanged.
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // Y = shape_cast(broadcast(X))
  //      -> X, if X and Y have same type
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison
  if (matchPattern(adaptor.getSource(), ub::m_Poison()))
    return ub::PoisonAttr::get(getContext());

  return {};
}
6597
6598namespace {
6599
6600/// Helper function that computes a new vector type based on the input vector
6601/// type by removing the trailing one dims:
6602///
6603/// vector<4x1x1xi1> --> vector<4x1xi1>
6604///
6605static VectorType trimTrailingOneDims(VectorType oldType) {
6606 ArrayRef<int64_t> oldShape = oldType.getShape();
6607 ArrayRef<int64_t> newShape = oldShape;
6608
6609 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6610 ArrayRef<bool> newScalableDims = oldScalableDims;
6611
6612 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6613 newShape = newShape.drop_back(1);
6614 newScalableDims = newScalableDims.drop_back(1);
6615 }
6616
6617 // Make sure we have at least 1 dimension.
6618 // TODO: Add support for 0-D vectors.
6619 if (newShape.empty()) {
6620 newShape = oldShape.take_back();
6621 newScalableDims = oldScalableDims.take_back();
6622 }
6623
6624 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6625}
6626
/// Folds qualifying shape_cast(create_mask) into a new create_mask
///
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
/// dimension. If the input vector comes from `vector.create_mask` for which
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
/// to fold shape_cast into create_mask. The analogous `vector.constant_mask`
/// case (trailing mask dim sizes equal to 1) is handled as well.
///
/// BEFORE:
///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
/// AFTER:
///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    // Only masks produced directly by create_mask / constant_mask qualify.
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    // The shape_cast must do exactly what trimTrailingOneDims does (i.e.
    // drop only the trailing fixed-size unit dims), otherwise bail out.
    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every mask dim size to see whether it can be dropped. It can
      // only be dropped when the mask bound for that dim is the constant 1.
      // (trimTrailingOneDims keeps >= 1 dim, so numDimsToDrop <
      // numMaskOperands and the reverse size_t loop below cannot wrap.)
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check every mask dim size to see whether it can be dropped
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
6705
6706// vector.broadcast has two distinct semantic modes: duplication across leading
6707// dimensions, and stretching across inner dimensions. This helper returns the
6708// product of the inner-dimension stretching factors.
6709int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6710 ArrayRef<int64_t> dstShape) {
6711 int stretchingFactor = 1;
6712 int numLeadingDims = dstShape.size() - srcShape.size();
6713 for (int i = 0, e = srcShape.size(); i < e; i++) {
6714 int64_t dstDim = dstShape[numLeadingDims + i];
6715 if (srcShape[i] == 1 && dstDim != 1) {
6716 stretchingFactor *= dstDim;
6717 }
6718 }
6719 return stretchingFactor;
6720}
6721
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
///
/// Handles both scalar and vector broadcast sources. For a vector source the
/// rewrite is only performed when it cannot flip the broadcast between its
/// two semantic modes (duplication vs stretching); see the stretching-factor
/// check below.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    // A null `srcVectorType` means the broadcast source is a scalar.
    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
    // Example
    // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
    // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
    // to
    // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> dstShape = dstVectorType.getShape();
    ArrayRef<int64_t> srcShape =
        srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
    ArrayRef<int64_t> broadcastShape =
        broadcastOp.getResultVectorType().getShape();

    if (!srcIsScalar) {
      // The source must be directly broadcastable to the final result type.
      if (isBroadcastableTo(srcVectorType, dstVectorType) !=
          BroadcastableToResult::Success) {
        return failure();
      }
      // Avoid folding if this would result in switching between the two
      // distinct semantic modes of vector.broadcast (duplication vs
      // stretching). See https://github.com/llvm/llvm-project/issues/190614.
      // This is detected by a change in the stretching factor. However if the
      // source has a single element, there is no ambiguity.
      if (srcVectorType.getNumElements() != 1) {
        if (getBroadcastStretchingFactor(srcShape, dstShape) !=
            getBroadcastStretchingFactor(srcShape, broadcastShape)) {
          return failure();
        }
      }
    }

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
                                                     broadcastOp.getSource());
    return success();
  }
};
6773
6774/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
6775///
6776/// BEFORE:
6777/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
6778/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
6779/// AFTER:
6780/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
6781///
6782/// Note: this transformation is implemented as an OpRewritePattern, not as a
6783/// fold, because we have to create new op FromElementsOp with updated result
6784/// type. This cannot be done with a fold, because fold cannot create new ops
6785/// and the existing FromElementsOp result type differs from the ShapeCastOp
6786/// result type. Mutating the FromElementsOp (not root op) would violate the
6787/// fold contract and break other users.
6788class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
6789public:
6790 using Base::Base;
6791
6792 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6793 PatternRewriter &rewriter) const override {
6794 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
6795 if (!fromElements)
6796 return failure();
6797
6798 rewriter.replaceOpWithNewOp<FromElementsOp>(
6799 shapeCastOp, shapeCastOp.getResultVectorType(),
6800 fromElements.getElements());
6801 return success();
6802 }
6803};
6804
6805} // namespace
6806
6807void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6808 MLIRContext *context) {
6809 results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
6810 FoldShapeCastOfFromElements>(context);
6811}
6812
6813//===----------------------------------------------------------------------===//
6814// VectorBitCastOp
6815//===----------------------------------------------------------------------===//
6816
6817LogicalResult BitCastOp::verify() {
6818 auto sourceVectorType = getSourceVectorType();
6819 auto resultVectorType = getResultVectorType();
6820
6821 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6822 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6823 return emitOpError("dimension size mismatch at: ") << i;
6824 }
6825
6826 DataLayout dataLayout = DataLayout::closest(*this);
6827 auto sourceElementBits =
6828 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6829 auto resultElementBits =
6830 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6831
6832 if (sourceVectorType.getRank() == 0) {
6833 if (sourceElementBits != resultElementBits)
6834 return emitOpError("source/result bitwidth of the 0-D vector element "
6835 "types must be equal");
6836 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6837 resultElementBits * resultVectorType.getShape().back()) {
6838 return emitOpError(
6839 "source/result bitwidth of the minor 1-D vectors must be equal");
6840 }
6841
6842 return success();
6843}
6844
/// Fold bitcasts: no-op casts, chains of bitcasts, and bitcasts of constant
/// splats whose new bit pattern can be computed at fold time.
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    // bitcast(bitcast(x)) where this result type equals x's type -> x.
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Otherwise compose the chain by rebasing this cast directly on x.
    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern: two adjacent f16 values make up each
        // folded f32 value.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower width element: each wide result element is
          // the narrow splat value repeated dst/src times.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
6906
6907//===----------------------------------------------------------------------===//
6908// TypeCastOp
6909//===----------------------------------------------------------------------===//
6910
6911static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6912 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6913 SmallVector<int64_t, 8> res(memRefType.getShape());
6914 if (vectorType)
6915 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6916 return res;
6917}
6918
6919/// Build the canonical memRefType with a single vector.
6920/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6921void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6922 Value source) {
6923 result.addOperands(source);
6924 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6925 VectorType vectorType =
6926 VectorType::get(extractShape(memRefType),
6928 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6929 memRefType.getMemorySpace()));
6930}
6931
6932LogicalResult TypeCastOp::verify() {
6933 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6934 if (!canonicalType.getLayout().isIdentity())
6935 return emitOpError("expects operand to be a memref with identity layout");
6936 if (!getResultMemRefType().getLayout().isIdentity())
6937 return emitOpError("expects result to be a memref with identity layout");
6938 if (getResultMemRefType().getMemorySpace() !=
6939 getMemRefType().getMemorySpace())
6940 return emitOpError("expects result in same memory space");
6941
6942 auto sourceType = getMemRefType();
6943 auto resultType = getResultMemRefType();
6944 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6946 return emitOpError(
6947 "expects result and operand with same underlying scalar type: ")
6948 << resultType;
6949 if (extractShape(sourceType) != extractShape(resultType))
6950 return emitOpError(
6951 "expects concatenated result and operand shapes to be equal: ")
6952 << resultType;
6953 return success();
6954}
6955
6956//===----------------------------------------------------------------------===//
6957// TransposeOp
6958//===----------------------------------------------------------------------===//
6959
6960void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6961 Value vector, ArrayRef<int64_t> permutation) {
6962 VectorType vt = llvm::cast<VectorType>(vector.getType());
6963 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6964 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6965 for (unsigned i = 0; i < permutation.size(); ++i) {
6966 transposedShape[i] = vt.getShape()[permutation[i]];
6967 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6968 }
6969
6970 result.addOperands(vector);
6971 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6972 transposedScalableDims));
6973 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6974 builder.getDenseI64ArrayAttr(permutation));
6975}
6976
/// Fold transposes of splat constants, of poison, and transposes that do not
/// actually permute any elements.
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (matchPattern(adaptor.getVector(), ub::m_Poison()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transposes that
  // preserves the shape without permuting elements.
  //
  // Examples of what to fold:
  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
  //
  // Example of what NOT to fold:
  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
  //
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}
7004
7005LogicalResult vector::TransposeOp::verify() {
7006 VectorType vectorType = getSourceVectorType();
7007 VectorType resultType = getResultVectorType();
7008 int64_t rank = resultType.getRank();
7009 if (vectorType.getRank() != rank)
7010 return emitOpError("vector result rank mismatch: ") << rank;
7011 // Verify transposition array.
7012 ArrayRef<int64_t> perm = getPermutation();
7013 int64_t size = perm.size();
7014 if (rank != size)
7015 return emitOpError("transposition length mismatch: ") << size;
7016 SmallVector<bool, 8> seen(rank, false);
7017 for (const auto &ta : llvm::enumerate(perm)) {
7018 if (ta.value() < 0 || ta.value() >= rank)
7019 return emitOpError("transposition index out of range: ") << ta.value();
7020 if (seen[ta.value()])
7021 return emitOpError("duplicate position index: ") << ta.value();
7022 seen[ta.value()] = true;
7023 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7024 return emitOpError("dimension size mismatch at: ") << ta.value();
7025 }
7026 return success();
7027}
7028
7029std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7030 return llvm::to_vector<4>(getResultVectorType().getShape());
7031}
7032
/// Integer range inference: a transpose only reorders elements, so the result
/// carries exactly the source operand's value range.
void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
7037
7038namespace {
7039
7040// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
7041class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
7042public:
7043 using Base::Base;
7044
7045 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7046 PatternRewriter &rewriter) const override {
7047 // Composes two permutations: result[i] = permutation1[permutation2[i]].
7048 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7049 ArrayRef<int64_t> permutation2) {
7050 SmallVector<int64_t, 4> result;
7051 for (auto index : permutation2)
7052 result.push_back(permutation1[index]);
7053 return result;
7054 };
7055
7056 // Return if the input of 'transposeOp' is not defined by another transpose.
7057 vector::TransposeOp parentTransposeOp =
7058 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7059 if (!parentTransposeOp)
7060 return failure();
7061
7062 SmallVector<int64_t, 4> permutation = composePermutations(
7063 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7064 // Replace 'transposeOp' with a new transpose operation.
7065 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
7066 transposeOp, transposeOp.getResult().getType(),
7067 parentTransposeOp.getVector(), permutation);
7068 return success();
7069 }
7070};
7071
7072/// Replace transpose(splat-like(v)) with broadcast(v)
7073class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
7074public:
7075 using Base::Base;
7076
7077 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7078 PatternRewriter &rewriter) const override {
7079 Value splat = getScalarSplatSource(transposeOp.getVector());
7080 if (!splat)
7081 return failure();
7082
7083 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7084 transposeOp, transposeOp.getResultVectorType(), splat);
7085 return success();
7086 }
7087};
7088
7089/// Folds transpose(create_mask) into a new transposed create_mask.
7090class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
7091public:
7092 using Base::Base;
7093
7094 LogicalResult matchAndRewrite(TransposeOp transpOp,
7095 PatternRewriter &rewriter) const override {
7096 Value transposeSrc = transpOp.getVector();
7097 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
7098 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
7099 if (!createMaskOp && !constantMaskOp)
7100 return failure();
7101
7102 // Get the transpose permutation and apply it to the vector.create_mask or
7103 // vector.constant_mask operands.
7104 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7105
7106 if (createMaskOp) {
7107 auto maskOperands = createMaskOp.getOperands();
7108 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7109 applyPermutationToVector(newOperands, permutation);
7110
7111 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
7112 transpOp, transpOp.getResultVectorType(), newOperands);
7113 return success();
7114 }
7115
7116 // ConstantMaskOp case.
7117 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7118 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
7119
7120 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
7121 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7122 return success();
7123 }
7124};
7125
/// Folds transpose(shape_cast) into a new shape_cast.
///
/// Only fires when the transpose is order preserving (it merely moves unit
/// dims around), so the combined effect is expressible as a single
/// shape_cast of the original source.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // We don't need to check isValidShapeCast at this point, because it is
    // guaranteed that merging the transpose into the shape_cast is a valid
    // shape_cast, because the transpose just inserts/removes ones.

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
7151
/// Folds transpose(from_elements(...)) into a new from_elements with permuted
/// operands matching the transposed shape.
///
/// Example:
///
///  %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
///  vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
///  vector<3x2xi32>
///
///  becomes ->
///
///  %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
///  vector<3x2xi32>
///
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Build inverse permutation to map destination indices back to source.
    // permutation maps src dim -> dst dim, so inversePerm maps dst -> src.
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    // Row-major strides for linearizing/delinearizing element positions.
    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // For each element in destination row-major order, pick the corresponding
    // source element.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      // Pick the destination element index.
      dstIdx = delinearize(linearIdx, dstStrides);
      // Map the destination element index to the source element index.
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];
      // Linearize the source element index.
      int64_t srcLin = linearize(srcIdx, srcStrides);
      // Add the source element to the new elements.
      elementsNew.push_back(elementsOld[srcLin]);
    }

    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
                                                elementsNew);
    return success();
  }
};
7218
7219/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7220/// 'order preserving', where 'order preserving' means the flattened
7221/// inputs and outputs of the transpose have identical (numerical) values.
7222///
7223/// Example:
7224/// ```
7225/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7226/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7227/// to vector<8x1xi32>
7228/// ```
7229/// can be rewritten as the equivalent
7230/// ```
7231/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
7232/// ```
7233/// The algorithm works by partitioning dimensions into groups that can be
7234/// locally permuted while preserving order, and checks that the transpose
7235/// only permutes within these groups.
7236///
7237/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7238/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7239/// broadcasting from 1x1x4x1x1x7.
7240/// ^^^ ^ ^^^ ^
7241/// groups: 0 1 2 3
7242/// Order preserving permutations for this example are ones that only permute
7243/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7244class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7245public:
7246 using Base::Base;
7247 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7248 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7249
7250 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7251 PatternRewriter &rewriter) const override {
7252
7253 vector::BroadcastOp broadcast =
7254 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7255 if (!broadcast) {
7256 return rewriter.notifyMatchFailure(transpose,
7257 "not preceded by a broadcast");
7258 }
7259
7260 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7261 VectorType outputType = transpose.getResultVectorType();
7262
7263 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7264 bool inputIsScalar = !inputType;
7265 if (inputIsScalar) {
7266 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7267 broadcast.getSource());
7268 return success();
7269 }
7270
7271 ArrayRef<int64_t> permutation = transpose.getPermutation();
7272 ArrayRef<int64_t> inputShape = inputType.getShape();
7273 int64_t inputRank = inputType.getRank();
7274 int64_t outputRank = transpose.getType().getRank();
7275 int64_t deltaRank = outputRank - inputRank;
7276
7277 int low = 0;
7278 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7279 bool notOne = inputShape[inputIndex] != 1;
7280 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7281 bool groupEndFound = notOne || prevNotOne;
7282 if (groupEndFound) {
7283 int high = inputIndex + deltaRank;
7284 // Return failure if not all permutation destinations for indices in
7285 // [low, high) are in [low, high), i.e. the permutation is not local to
7286 // the group.
7287 for (int i = low; i < high; ++i) {
7288 if (permutation[i] < low || permutation[i] >= high) {
7289 return rewriter.notifyMatchFailure(
7290 transpose, "permutation not local to group");
7291 }
7292 }
7293 low = high;
7294 }
7295 }
7296
7297 // We don't need to check the final group [low, outputRank) because if it is
7298 // not locally bound, there must be a preceding group that already failed
7299 // the check (impossible to have just 1 non-locally bound group).
7300
7301 // The preceding logic also ensures that at this point, the output of the
7302 // transpose is definitely broadcastable from the input shape, assert so:
7303 assert(vector::isBroadcastableTo(inputType, outputType) ==
7304 vector::BroadcastableToResult::Success &&
7305 "not broadcastable directly to transpose output");
7306
7307 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7308 broadcast.getSource());
7309
7310 return success();
7311 }
7312};
7313
7314} // namespace
7315
7316void vector::TransposeOp::getCanonicalizationPatterns(
7317 RewritePatternSet &results, MLIRContext *context) {
7318 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7319 FoldTransposeSplat, FoldTransposeFromElements,
7320 FoldTransposeBroadcast>(context);
7321}
7322
7323//===----------------------------------------------------------------------===//
7324// ConstantMaskOp
7325//===----------------------------------------------------------------------===//
7326
7327void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7328 VectorType type, ConstantMaskKind kind) {
7329 assert(kind == ConstantMaskKind::AllTrue ||
7330 kind == ConstantMaskKind::AllFalse);
7331 build(builder, result, type,
7332 kind == ConstantMaskKind::AllTrue
7333 ? type.getShape()
7334 : SmallVector<int64_t>(type.getRank(), 0));
7335}
7336
7337LogicalResult ConstantMaskOp::verify() {
7338 auto resultType = llvm::cast<VectorType>(getResult().getType());
7339 // Check the corner case of 0-D vectors first.
7340 if (resultType.getRank() == 0) {
7341 if (getMaskDimSizes().size() != 1)
7342 return emitError("array attr must have length 1 for 0-D vectors");
7343 auto dim = getMaskDimSizes()[0];
7344 if (dim != 0 && dim != 1)
7345 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7346 return success();
7347 }
7348
7349 // Verify that array attr size matches the rank of the vector result.
7350 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7351 return emitOpError(
7352 "must specify array attr of size equal vector result rank");
7353 // Verify that each array attr element is in bounds of corresponding vector
7354 // result dimension size.
7355 auto resultShape = resultType.getShape();
7356 auto resultScalableDims = resultType.getScalableDims();
7357 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7358 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7359 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7360 return emitOpError(
7361 "array attr of size out of bounds of vector result dimension size");
7362 if (resultScalableDims[index] && maskDimSize != 0 &&
7363 maskDimSize != resultShape[index])
7364 return emitOpError(
7365 "only supports 'none set' or 'all set' scalable dimensions");
7366 }
7367 // Verify that if one mask dim size is zero, they all should be zero (because
7368 // the mask region is a conjunction of each mask dimension interval).
7369 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7370 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7371 if (anyZeros && !allZeros)
7372 return emitOpError("expected all mask dim sizes to be zeros, "
7373 "as a result of conjunction with zero mask dim");
7374 return success();
7375}
7376
7377bool ConstantMaskOp::isAllOnesMask() {
7378 auto resultType = getVectorType();
7379 // Check the corner case of 0-D vectors first.
7380 if (resultType.getRank() == 0) {
7381 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7382 return getMaskDimSizes()[0] == 1;
7383 }
7384 for (const auto [resultSize, maskDimSize] :
7385 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7386 if (maskDimSize < resultSize)
7387 return false;
7388 }
7389 return true;
7390}
7391
7392OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7393 ArrayRef<int64_t> bounds = getMaskDimSizes();
7394 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7395
7396 auto createBoolSplat = [&](bool x) {
7397 return SplatElementsAttr::get(getVectorType(),
7399 };
7400
7401 // Check the corner case of 0-D vectors first.
7402 if (vectorSizes.empty()) {
7403 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7404 return createBoolSplat(bounds[0] == 1);
7405 }
7406 // Fold vector.constant_mask to splat if possible.
7407 if (bounds == vectorSizes)
7408 return createBoolSplat(true);
7409 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7410 return createBoolSplat(false);
7411 return OpFoldResult();
7412}
7413
7414//===----------------------------------------------------------------------===//
7415// CreateMaskOp
7416//===----------------------------------------------------------------------===//
7417
7418void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7419 VectorType type,
7420 ArrayRef<OpFoldResult> mixedOperands) {
7421 SmallVector<Value> operands =
7422 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7423 build(builder, result, type, operands);
7424}
7425
7426LogicalResult CreateMaskOp::verify() {
7427 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7428 // Verify that an operand was specified for each result vector each dimension.
7429 if (vectorType.getRank() == 0) {
7430 if (getNumOperands() != 1)
7431 return emitOpError(
7432 "must specify exactly one operand for 0-D create_mask");
7433 } else if (getNumOperands() !=
7434 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7435 return emitOpError(
7436 "must specify an operand for each result vector dimension");
7437 }
7438 return success();
7439}
7440
7441namespace {
7442
7443/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7444///
7445/// Ex 1:
7446/// %c2 = arith.constant 2 : index
7447/// %c3 = arith.constant 3 : index
7448/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7449/// Becomes:
7450/// vector.constant_mask [3, 2] : vector<4x3xi1>
7451///
7452/// Ex 2:
7453/// %c_neg_1 = arith.constant -1 : index
7454/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7455/// becomes:
7456/// vector.constant_mask [0] : vector<[8]xi1>
7457///
7458/// Ex 3:
7459/// %c8 = arith.constant 8 : index
7460/// %c16 = arith.constant 16 : index
7461/// %0 = vector.vscale
7462/// %1 = arith.muli %0, %c16 : index
7463/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7464/// becomes:
7465/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: Rank zero shape. Treated as a 1-D mask of one
    // non-scalable element so the loop below can handle it uniformly.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims` (for the ConstantMaskOp).
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value.
        // If the mask dim is non-scalable this can be any value.
        // If the mask dim is scalable only zero (all-false) is supported.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale).
        // Must be all-true to fold to a ConstantMask.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        // Non-constant operand: cannot fold to a constant mask.
        return failure();
      }
    }

    // Clamp values to constant_mask bounds. Negative sizes (see Ex 2 in the
    // pattern doc) become 0, oversized values become the dim size.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of dim sizes is zero, set all dims to zero (the mask selects no
    // lanes regardless of the other dims).
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
7520
7521} // namespace
7522
/// Registers the CreateMaskOp -> ConstantMaskOp folding pattern.
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
7527
7528//===----------------------------------------------------------------------===//
7529// MaskOp
7530//===----------------------------------------------------------------------===//
7531
7532void MaskOp::build(
7533 OpBuilder &builder, OperationState &result, Value mask,
7534 Operation *maskableOp,
7535 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7536 assert(maskRegionBuilder &&
7537 "builder callback for 'maskRegion' must be present");
7538
7539 result.addOperands(mask);
7540 OpBuilder::InsertionGuard guard(builder);
7541 Region *maskRegion = result.addRegion();
7542 builder.createBlock(maskRegion);
7543 maskRegionBuilder(builder, maskableOp);
7544}
7545
/// Builds a `vector.mask` with results but without a passthru operand.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  // Delegate to the full builder with a null passthru value.
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}
7553
/// Builds a `vector.mask` with results and an optional passthru operand.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  // The base builder adds the mask operand and populates the region; the
  // optional passthru must be appended after the mask.
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
7563
/// Parses `vector.mask` in the form:
///   $mask (`,` $passthru)? $region attr-dict `:` mask-type (`->` results)?
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // Append the implicit vector.yield terminator if the region lacks one.
  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  // The passthru, when present, must match the first result type.
  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");

    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }

  return success();
}
7618
7619void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7620 p << " " << getMask();
7621 if (getPassthru())
7622 p << ", " << getPassthru();
7623
7624 // Print single masked operation and skip terminator.
7625 p << " { ";
7626 Block *singleBlock = &getMaskRegion().getBlocks().front();
7627 if (singleBlock && !singleBlock->getOperations().empty())
7628 p.printCustomOrGenericOp(&singleBlock->front());
7629 p << " }";
7630
7631 p.printOptionalAttrDict(getOperation()->getAttrs());
7632
7633 p << " : " << getMask().getType();
7634 if (getNumResults() > 0)
7635 p << " -> " << getResultTypes();
7636}
7637
/// Ensures the mask region ends with a vector.yield terminator. When the
/// region holds exactly one (masked) operation, the terminator yields that
/// op's results; otherwise a default empty terminator is created.
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator:

  // Create default terminator if the number of masked operations is not
  // one. This case will trigger a verification failure.
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Create a terminator that yields the results from the masked operation.
  // A fresh OpBuilder is used so the caller's Builder is left untouched.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
7667
/// Verifies the structural and typing invariants of `vector.mask`: a
/// single-block region holding at most one maskable operation plus a
/// vector.yield terminator whose operands match the op's results.
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  // At most one masked op plus the terminator.
  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  // Only a single vector result is currently supported.
  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks. The expected mask type is dictated by the maskable op.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
7734
7735/// Folds empty `vector.mask` with no passthru operand and with or without
7736/// return values. For example:
7737///
7738/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7739/// vector<8xi1> -> vector<8xf32>
7740/// %1 = user_op %0 : vector<8xf32>
7741///
7742/// becomes:
7743///
7744/// %0 = user_op %a : vector<8xf32>
7745///
/// Empty `vector.mask` ops with a passthru operand are handled by the
/// canonicalizer, as folding them requires creating new operations.
7748
7749static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7750 SmallVectorImpl<OpFoldResult> &results) {
7751 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7752 return failure();
7753
7754 Block *block = maskOp.getMaskBlock();
7755 auto terminator = cast<vector::YieldOp>(block->front());
7756 if (terminator.getNumOperands() == 0)
7757 return failure();
7758
7759 // `vector.mask` has results, propagate the results.
7760 llvm::append_range(results, terminator.getOperands());
7761 return success();
7762}
7763
/// Folds `vector.mask`: forwards yielded values for empty masks (via
/// foldEmptyMaskOp), and unmasks the body op when the mask is known all-true.
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  // Only an all-true mask permits dropping the masking entirely.
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move maskable operation outside of the `vector.mask` region.
  // If there is no maskable op (empty body), the fold cannot proceed; the
  // canonicalizer handles this case instead.
  Operation *maskableOp = getMaskableOp();
  if (!maskableOp)
    return failure();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
7785
/// Canonicalize empty `vector.mask` operations that can't be handled in
/// `MaskOp::fold`, as they require creating new operations.
7788///
7789/// Example 1: Empty `vector.mask` with passthru operand.
7790///
7791/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7792/// vector<8xi1> -> vector<8xf32>
7793///
7794/// becomes:
7795///
7796/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7797///
7798class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7799 using Base::Base;
7800
7801 LogicalResult matchAndRewrite(MaskOp maskOp,
7802 PatternRewriter &rewriter) const override {
7803 if (!maskOp.isEmpty())
7804 return failure();
7805
7806 if (!maskOp.hasPassthru())
7807 return failure();
7808
7809 // arith.select with a vector condition requires the value types to be
7810 // vectors of the same shape. Since vector.mask always has a vector mask
7811 // type, bail out when any result type doesn't match the mask shape to
7812 // avoid creating invalid IR.
7813 VectorType maskType = maskOp.getMask().getType();
7814 for (Type resultType : maskOp.getResultTypes()) {
7815 auto vecResultType = dyn_cast<VectorType>(resultType);
7816 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
7817 return failure();
7818 }
7819
7820 Block *block = maskOp.getMaskBlock();
7821 auto terminator = cast<vector::YieldOp>(block->front());
7822 assert(terminator.getNumOperands() == 1 &&
7823 "expected one result when passthru is provided");
7824
7825 rewriter.replaceOpWithNewOp<arith::SelectOp>(
7826 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7827 terminator.getOperand(0), maskOp.getPassthru());
7828
7829 return success();
7830 }
7831};
7832
/// Registers the empty-mask-with-passthru canonicalization pattern.
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
7837
7838// MaskingOpInterface definitions.
7839
7840/// Returns the operation masked by this 'vector.mask'.
7841Operation *MaskOp::getMaskableOp() {
7842 Block *block = getMaskBlock();
7843 if (block->getOperations().size() < 2)
7844 return nullptr;
7845
7846 return &block->front();
7847}
7848
7849/// Returns true if 'vector.mask' has a passthru value.
7850bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7851
7852//===----------------------------------------------------------------------===//
7853// ScanOp
7854//===----------------------------------------------------------------------===//
7855
7856LogicalResult ScanOp::verify() {
7857 VectorType srcType = getSourceType();
7858 VectorType initialType = getInitialValueType();
7859 // Check reduction dimension < rank.
7860 int64_t srcRank = srcType.getRank();
7861 int64_t reductionDim = getReductionDim();
7862 if (reductionDim >= srcRank)
7863 return emitOpError("reduction dimension ")
7864 << reductionDim << " has to be less than " << srcRank;
7865
7866 // Check that rank(initial_value) = rank(src) - 1.
7867 int64_t initialValueRank = initialType.getRank();
7868 if (initialValueRank != srcRank - 1)
7869 return emitOpError("initial value rank ")
7870 << initialValueRank << " has to be equal to " << srcRank - 1;
7871
7872 // Check shapes of initial value and src.
7873 ArrayRef<int64_t> srcShape = srcType.getShape();
7874 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7875 SmallVector<int64_t> expectedShape;
7876 for (int i = 0; i < srcRank; i++) {
7877 if (i != reductionDim)
7878 expectedShape.push_back(srcShape[i]);
7879 }
7880 if (!llvm::equal(initialValueShapes, expectedShape)) {
7881 return emitOpError("incompatible input/initial value shapes");
7882 }
7883
7884 // Verify supported reduction kind.
7885 Type eltType = getDestType().getElementType();
7886 if (!isSupportedCombiningKind(getKind(), eltType))
7887 return emitOpError("unsupported reduction type ")
7888 << eltType << " for kind '" << stringifyCombiningKind(getKind())
7889 << "'";
7890
7891 return success();
7892}
7893
7895 RewritePatternSet &patterns, PatternBenefit benefit) {
7896 patterns
7897 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7898 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7899 StridedSliceConstantMaskFolder, TransposeFolder>(
7900 patterns.getContext(), benefit);
7901}
7902
7903Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7904 CombiningKind kind, Value v1, Value acc,
7905 arith::FastMathFlagsAttr fastmath,
7906 Value mask) {
7907 Type t1 = getElementTypeOrSelf(v1.getType());
7908 Type tAcc = getElementTypeOrSelf(acc.getType());
7909 Value result;
7910
7911 switch (kind) {
7912 case CombiningKind::ADD:
7913 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7914 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7915 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7916 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7917 else
7918 llvm_unreachable("invalid value types for ADD reduction");
7919 break;
7920 case CombiningKind::AND:
7921 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7922 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7923 break;
7924 case CombiningKind::MAXNUMF:
7925 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7926 "expected float values");
7927 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7928 break;
7929 case CombiningKind::MAXIMUMF:
7930 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7931 "expected float values");
7932 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7933 break;
7934 case CombiningKind::MINNUMF:
7935 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7936 "expected float values");
7937 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7938 break;
7939 case CombiningKind::MINIMUMF:
7940 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7941 "expected float values");
7942 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7943 break;
7944 case CombiningKind::MAXSI:
7945 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7946 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7947 break;
7948 case CombiningKind::MINSI:
7949 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7950 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7951 break;
7952 case CombiningKind::MAXUI:
7953 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7954 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7955 break;
7956 case CombiningKind::MINUI:
7957 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7958 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7959 break;
7960 case CombiningKind::MUL:
7961 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7962 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7963 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7964 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7965 else
7966 llvm_unreachable("invalid value types for MUL reduction");
7967 break;
7968 case CombiningKind::OR:
7969 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7970 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7971 break;
7972 case CombiningKind::XOR:
7973 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7974 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7975 break;
7976 };
7977
7978 assert(result && "unknown CombiningKind");
7979 return selectPassthru(b, mask, result, acc);
7980}
7981
7982//===----------------------------------------------------------------------===//
7983// StepOp
7984//===----------------------------------------------------------------------===//
7985
7986void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7987 SetIntRangeFn setResultRanges) {
7988 auto resultType = cast<VectorType>(getType());
7989 if (resultType.isScalable()) {
7990 return;
7991 }
7992 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7993 APInt zero(bitwidth, 0);
7994 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7995 ConstantIntRanges result = {zero, high, zero, high};
7996 setResultRanges(getResult(), result);
7997}
7998
7999namespace {
8000
/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
/// constant large enough such that the result is the same at all indices.
///
/// For example, rewrite the 'greater than' comparison below,
///
/// ```mlir
/// %cst = arith.constant dense<7> : vector<3xindex>
/// %stp = vector.step : vector<3xindex>
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
/// ```
///
/// as,
///
/// ```mlir
/// %out = arith.constant dense<false> : vector<3xi1>.
/// ```
///
/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
/// is false at ALL indices we fold. If the constant was 1, then
/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
/// conservatively preferring the 'compact' vector.step representation.
///
/// Note: this folder only works for the case where the constant (`%cst` above)
/// is the second operand of the comparison. The arith.cmpi canonicalizer will
/// ensure that constants are always second (on the right).
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalizer makes constants final operands.
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // Check that operand 1 is a constant.
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          getConstantIntValue(otherOperand);
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      // Determine whether the comparison has the same result at every lane
      // and, if so, what that uniform boolean value is.
      auto maybeSplat = [&]() -> std::optional<bool> {
        // Handle ult (unsigned less than) and uge (unsigned greater equal).
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        // Handle ule and ugt.
        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue) {
          return pred == arith::CmpIPredicate::ule;
        }

        // Handle eq and ne.
        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      rewriter.setInsertionPointAfter(cmpiOp);

      // Only vector-typed comparison results can be replaced with a splat.
      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
      Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
                                                    type, boolAttr);

      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};
8097} // namespace
8098
/// Registers the step/compare folding pattern.
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}
8103
8104//===----------------------------------------------------------------------===//
8105// Vector Masking Utilities
8106//===----------------------------------------------------------------------===//
8107
8108/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
8109/// as masked operation.
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Create a block and move the op to that block. The splice transfers the
  // op into the mask region without cloning it.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  // Terminate the region by yielding the masked op's results.
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}
8119
8120/// Creates a vector.mask operation around a maskable operation. Returns the
8121/// vector.mask operation if the mask provided is valid. Otherwise, returns
8122/// the maskable operation itself.
8123Operation *mlir::vector::maskOperation(OpBuilder &builder,
8124 Operation *maskableOp, Value mask,
8125 Value passthru) {
8126 if (!mask)
8127 return maskableOp;
8128 if (passthru)
8129 return MaskOp::create(builder, maskableOp->getLoc(),
8130 maskableOp->getResultTypes(), mask, passthru,
8131 maskableOp, createMaskOpRegion);
8132 return MaskOp::create(builder, maskableOp->getLoc(),
8133 maskableOp->getResultTypes(), mask, maskableOp,
8135}
8136
8137/// Creates a vector select operation that picks values from `newValue` or
8138/// `passthru` for each result vector lane based on `mask`. This utility is used
8139/// to propagate the pass-thru value of vector.mask or for cases where only the
8140/// pass-thru value propagation is needed. VP intrinsics do not support
8141/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
8142/// usually able to match op + select patterns and fold them into a native
8143/// target instructions.
8144Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
8145 Value newValue, Value passthru) {
8146 if (!mask)
8147 return newValue;
8148
8149 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
8150 mask, newValue, passthru);
8151}
8152
8153//===----------------------------------------------------------------------===//
8154// TableGen'd op method definitions
8155//===----------------------------------------------------------------------===//
8156
8157#define GET_ATTRDEF_CLASSES
8158#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8159
8160#define GET_OP_CLASSES
8161#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
b getContext())
auto load
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:73
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:63
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:860
void setOperand(unsigned idx, Value value)
Definition Operation.h:377
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const